fastNLP.embeddings.static_embedding module

class fastNLP.embeddings.static_embedding.StaticEmbedding(vocab: fastNLP.core.vocabulary.Vocabulary, model_dir_or_name: Optional[str] = 'en', embedding_dim=- 1, requires_grad: bool = True, init_method=None, lower=False, dropout=0, word_dropout=0, normalize=False, min_freq=1, **kwargs)[源代码]

基类:fastNLP.embeddings.embedding.TokenEmbedding

基类 fastNLP.embeddings.TokenEmbedding

别名 fastNLP.embeddings.StaticEmbedding fastNLP.embeddings.static_embedding.StaticEmbedding

StaticEmbedding组件. 给定预训练embedding的名称或路径,根据vocab从embedding中抽取相应的数据(只会将出现在vocab中的词抽取出来, 如果没有找到,则会随机初始化一个值(但如果该word是被标记为no_create_entry的话,则不会单独创建一个值,而是会被指向unk的index))。 当前支持自动下载的预训练vector有:

en: 实际为en-glove-840b-300d(常用)
en-glove-6b-50d: glove官方的50d向量
en-glove-6b-100d: glove官方的100d向量
en-glove-6b-200d: glove官方的200d向量
en-glove-6b-300d: glove官方的300d向量
en-glove-42b-300d: glove官方使用42B数据训练版本
en-glove-840b-300d:
en-glove-twitter-27b-25d:
en-glove-twitter-27b-50d:
en-glove-twitter-27b-100d:
en-glove-twitter-27b-200d:
en-word2vec-300d: word2vec官方发布的300d向量
en-fasttext-crawl: fasttext官方发布的300d英文预训练
cn-char-fastnlp-100d: fastNLP训练的100d的character embedding
cn-bi-fastnlp-100d: fastNLP训练的100d的bigram embedding
cn-tri-fastnlp-100d: fastNLP训练的100d的trigram embedding
cn-fasttext: fasttext官方发布的300d中文预训练embedding

Example:

>>> from fastNLP import Vocabulary
>>> from fastNLP.embeddings import StaticEmbedding
>>> vocab = Vocabulary().add_word_lst("The whether is good .".split())
>>> embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-50d')

>>> vocab = Vocabulary().add_word_lst(["The", 'the', "THE"])
>>> embed = StaticEmbedding(vocab, model_dir_or_name="en-glove-50d", lower=True)
>>> # "the", "The", "THE"它们共用一个vector,且将使用"the"在预训练词表中寻找它们的初始化表示。

>>> vocab = Vocabulary().add_word_lst(["The", "the", "THE"])
>>> embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=5, lower=True)
>>> words = torch.LongTensor([[vocab.to_index(word) for word in ["The", "the", "THE"]]])
>>> embed(words)
>>> tensor([[[ 0.5773,  0.7251, -0.3104,  0.0777,  0.4849],
             [ 0.5773,  0.7251, -0.3104,  0.0777,  0.4849],
             [ 0.5773,  0.7251, -0.3104,  0.0777,  0.4849]]],
           grad_fn=<EmbeddingBackward>)  # 每种word的输出是一致的。
__init__(vocab: fastNLP.core.vocabulary.Vocabulary, model_dir_or_name: Optional[str] = 'en', embedding_dim=- 1, requires_grad: bool = True, init_method=None, lower=False, dropout=0, word_dropout=0, normalize=False, min_freq=1, **kwargs)[源代码]
参数
  • vocab (Vocabulary) – 词表. StaticEmbedding只会加载包含在词表中的词的词向量,在预训练向量中没找到的使用随机初始化

  • model_dir_or_name – 可以有两种方式调用预训练好的static embedding:第一种是传入embedding文件夹(文件夹下应该只有一个 以.txt作为后缀的文件)或文件路径;第二种是传入embedding的名称,第二种情况将自动查看缓存中是否存在该模型,没有的话将自动下载。 如果输入为None则使用embedding_dim的维度随机初始化一个embedding。

  • embedding_dim (int) – 随机初始化的embedding的维度,当该值为大于0的值时,将忽略model_dir_or_name。

  • requires_grad (bool) – 是否需要gradient. 默认为True

  • init_method (callable) – 如何初始化没有找到的值。可以使用torch.nn.init.*中各种方法, 传入的方法应该接受一个tensor,并 inplace地修改其值。

  • lower (bool) – 是否将vocab中的词语小写后再和预训练的词表进行匹配。如果你的词表中包含大写的词语,或者就是需要单独 为大写的词语开辟一个vector表示,则将lower设置为False。

  • dropout (float) – 以多大的概率对embedding的表示进行Dropout。0.1即随机将10%的值置为0。

  • word_dropout (float) – 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。

  • normalize (bool) – 是否对vector进行normalize,使得每个vector的norm为1。

  • min_freq (int) – Vocabulary词频数小于这个数量的word将被指向unk。

  • kwargs (dict) – bool only_train_min_freq: 仅对train中的词语使用min_freq筛选; bool only_norm_found_vector: 默认为False, 是否仅对在预训练中找到的词语使用normalize; bool only_use_pretrain_word: 默认为False, 仅使用出现在pretrain词表中的词,如果该词没有在预训练的词表中出现则为unk。如果embedding不需要更新建议设置为True。

forward(words)[源代码]

传入words的index

参数

words – torch.LongTensor, [batch_size, max_len]

返回

torch.FloatTensor, [batch_size, max_len, embed_size]

save(folder)[源代码]

将embedding存储到folder下,之后可以通过使用load方法读取

参数

folder (str) – 会在该folder下生成三个文件, vocab.txt, static_embed_hyper.txt, static_embed_hyper.json. 其中vocab.txt可以用Vocabulary通过load读取; embedding.txt按照word2vec的方式存储,以空格的方式隔开元素, 第一行只有两个元素,剩下的行首先是word然后是各个维度的值; static_embed_hyper.json是StaticEmbedding的超参数

返回

classmethod load(folder)[源代码]
参数

folder (str) – 该folder下应该有以下三个文件vocab.txt, static_embed.txt, static_hyper.json

返回

training: bool