fastNLP.models.seq2seq_generator module

undocumented

class fastNLP.models.seq2seq_generator.SequenceGeneratorModel(seq2seq_model: fastNLP.models.seq2seq_model.Seq2SeqModel, bos_token_id, eos_token_id=None, max_length=30, max_len_a=0.0, num_beams=1, do_sample=True, temperature=1.0, top_k=50, top_p=1.0, repetition_penalty=1, length_penalty=1.0, pad_token_id=0)[源代码]

基类:torch.nn.modules.module.Module

通过使用本模型封装seq2seq_model使得其既可以用于训练也可以用于生成。训练的时候,本模型的forward函数会被调用,生成的时候本模型的predict

函数会被调用。

__init__(seq2seq_model: fastNLP.models.seq2seq_model.Seq2SeqModel, bos_token_id, eos_token_id=None, max_length=30, max_len_a=0.0, num_beams=1, do_sample=True, temperature=1.0, top_k=50, top_p=1.0, repetition_penalty=1, length_penalty=1.0, pad_token_id=0)[源代码]
参数
  • seq2seq_model (Seq2SeqModel) – 序列到序列模型

  • bos_token_id (int,None) – 句子开头的token id

  • eos_token_id (int,None) – 句子结束的token id

  • max_length (int) – 生成句子的最大长度, 每句话的decode长度为max_length + max_len_a*src_len

  • max_len_a (float) – 每句话的decode长度为max_length + max_len_a*src_len。 如果不为0,需要保证State中包含encoder_mask

  • num_beams (int) – beam search的大小

  • do_sample (bool) – 是否通过采样的方式生成

  • temperature (float) – 只有在do_sample为True才有意义

  • top_k (int) – 只从top_k中采样

  • top_p (float) – 只从top_p的token中采样,nucles sample

  • repetition_penalty (float) – 多大程度上惩罚重复的token

  • length_penalty (float) – 对长度的惩罚,小于1鼓励长句,大于1鼓励短剧

  • pad_token_id (int) – 当某句话生成结束之后,之后生成的内容用pad_token_id补充

forward(src_tokens, tgt_tokens, src_seq_len=None, tgt_seq_len=None)[源代码]

透传调用seq2seq_model的forward。

参数
  • src_tokens (torch.LongTensor) – bsz x max_len

  • tgt_tokens (torch.LongTensor) – bsz x max_len’

  • src_seq_len (torch.LongTensor) – bsz

  • tgt_seq_len (torch.LongTensor) – bsz

返回

predict(src_tokens, src_seq_len=None)[源代码]

给定source的内容,输出generate的内容。

参数
  • src_tokens (torch.LongTensor) – bsz x max_len

  • src_seq_len (torch.LongTensor) – bsz

返回

training: bool