fastNLP.modules.generator.seq2seq_generator module

class fastNLP.modules.generator.seq2seq_generator.SequenceGenerator(decoder: fastNLP.modules.decoder.seq2seq_decoder.Seq2SeqDecoder, max_length=20, max_len_a=0.0, num_beams=1, do_sample=True, temperature=1.0, top_k=50, top_p=1.0, bos_token_id=None, eos_token_id=None, repetition_penalty=1, length_penalty=1.0, pad_token_id=0)[源代码]

基类:object

别名 fastNLP.modules.SequenceGenerator fastNLP.modules.generator.seq2seq_generator.SequenceGenerator

给定一个Seq2SeqDecoder,decode出句子。输入的decoder对象需要有decode()函数, 接受的第一个参数为decode的到目前位置的所有输出,

第二个参数为state。SequenceGenerator不会对state进行任何操作。

__init__(decoder: fastNLP.modules.decoder.seq2seq_decoder.Seq2SeqDecoder, max_length=20, max_len_a=0.0, num_beams=1, do_sample=True, temperature=1.0, top_k=50, top_p=1.0, bos_token_id=None, eos_token_id=None, repetition_penalty=1, length_penalty=1.0, pad_token_id=0)[源代码]
参数
  • decoder (Seq2SeqDecoder) – Decoder对象

  • max_length (int) – 生成句子的最大长度, 每句话的decode长度为max_length + max_len_a*src_len

  • max_len_a (float) – 每句话的decode长度为max_length + max_len_a*src_len。 如果不为0,需要保证State中包含encoder_mask

  • num_beams (int) – beam search的大小

  • do_sample (bool) – 是否通过采样的方式生成

  • temperature (float) – 只有在do_sample为True才有意义

  • top_k (int) – 只从top_k中采样

  • top_p (float) – 只从top_p的token中采样,nucles sample

  • bos_token_id (int,None) – 句子开头的token id

  • eos_token_id (int,None) – 句子结束的token id

  • repetition_penalty (float) – 多大程度上惩罚重复的token

  • length_penalty (float) – 对长度的惩罚,小于1鼓励长句,大于1鼓励短剧

  • pad_token_id (int) – 当某句话生成结束之后,之后生成的内容用pad_token_id补充

generate(state, tokens=None)[源代码]
参数
  • state (State) – encoder结果的State, 是与Decoder配套是用的

  • tokens (torch.LongTensor,None) – batch_size x length, 开始的token。如果为None,则默认添加bos_token作为开头的token 进行生成。

返回

bsz x max_length’ 生成的token序列。如果eos_token_id不为None, 每个sequence的结尾一定是eos_token_id