fastNLP.modules.decoder.seq2seq_decoder module

undocumented

class fastNLP.modules.decoder.seq2seq_decoder.Seq2SeqDecoder[源代码]

基类:torch.nn.modules.module.Module

别名 fastNLP.modules.Seq2SeqDecoder fastNLP.modules.decoder.Seq2SeqDecoder

Sequence-to-Sequence Decoder的基类。一定需要实现forward、decode函数,剩下的函数根据需要实现。每个Seq2SeqDecoder都应该有相应的State对象

用来承载该Decoder所需要的Encoder输出、Decoder需要记录的历史信息(例如LSTM的hidden信息)。

forward(tokens, state, **kwargs)[源代码]
参数
  • tokens (torch.LongTensor) – bsz x max_len

  • state (State) – state包含了encoder的输出以及decode之前的内容

返回

返回值可以为bsz x max_len x vocab_size的Tensor,也可以是一个list,但是第一个元素必须是词的预测分布

reorder_states(indices, states)[源代码]

根据indices重新排列states中的状态,在beam search进行生成时,会用到该函数。

参数
  • indices (torch.LongTensor) –

  • states (State) –

返回

init_state(encoder_output, encoder_mask)[源代码]

初始化一个state对象,用来记录了encoder的输出以及decode已经完成的部分。

参数
  • list, tuple] encoder_output (Union[torch.Tensor,) – 如果不为None,内部元素需要为torch.Tensor, 默认其中第一维是batch 维度

  • list, tuple] encoder_mask (Union[torch.Tensor,) – 如果部位None,内部元素需要torch.Tensor, 默认其中第一维是batch 维度

  • kwargs

返回

State, 返回一个State对象,记录了encoder的输出

decode(tokens, state)[源代码]

根据states中的内容,以及tokens中的内容进行之后的生成。

参数
  • tokens (torch.LongTensor) – bsz x max_len, 截止到上一个时刻所有的token输出。

  • state (State) – 记录了encoder输出与decoder过去状态

返回

torch.FloatTensor: bsz x vocab_size, 输出的是下一个时刻的分布

training: bool