pytorch重要函数介绍

时间:2021-02-01 12:58:42   收藏:0   阅读:0

一、torch.nn.Embedding

模块可以看做一个字典,字典中每个索引对应一个词和词的embedding形式。利用这个模块,可以给词做embedding的初始化操作

torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=None, max_norm=None, norm_type=2, scale_grad_by_freq=False, sparse=False)

num_embeddings :字典中词的个数

embedding_dim:embedding的维度

padding_idx(索引指定填充):如果给定,则遇到padding_idx中的索引,则将其位置填0(0是默认值)。

输入输出:

input:(?) , LongTensor 结构

output:(*,e):*是input的大小,e是embedding_dim,即每个词的embedding的维度

注:embeddings中的值是正态分布N(0,1)中随机取值。

import torch
import torch.nn as nn
x = torch.LongTensor([[1,2,4],[4,3,2]])
embeddings = nn.Embedding(5,5,padding_idx=4) #5个词,每个词也是5维
print(embeddings(x))
print(embeddings(x).size())
 
 
output:
tensor([[[ 0.8839, -1.2889,  0.0697, -0.9998, -0.7471],
         [-0.5681,  0.8486,  0.8176,  0.8349,  0.1719],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000]],  ->index=4 赋值 0
 
        [[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],   ->index=4 赋值 0
         [ 1.4224,  0.2333,  1.9383, -0.7320,  0.9987],
         [-0.5681,  0.8486,  0.8176,  0.8349,  0.1719]]],
       grad_fn=<EmbeddingBackward>)
torch.Size([2, 3, 5])

 

评论(0
© 2014 mamicode.com 版权所有 京ICP备13008772号-2  联系我们:gaon5@hotmail.com
迷上了代码!