nn.Embedding
在这个代码片段中,TokenEmbedding
类继承了 torch.nn.Embedding
类,并在 __init__
方法中通过调用 super(TokenEmbedding, self).__init__(vocab_size, d_model, padding_idx=1)
来初始化父类 nn.Embedding
。由于 TokenEmbedding
没有定义新的方法,默认情况下它会使用 nn.Embedding
的行为来提供返回值。
nn.Embedding
的行为
nn.Embedding
是一个嵌入层,用于将词汇表中的单词映射为稠密的向量表示。它的作用是查找输入索引对应的嵌入向量,具体步骤如下:
- 当你传入词汇的索引(整数)时,它会从权重矩阵中查找对应的嵌入向量。
- 它不需要定义一个显式的
forward
方法,因为调用nn.Embedding
实例时,自动会执行这个查找操作。
使用方式
-
实例化
TokenEmbedding
:实例化时会初始化一个嵌入矩阵,矩阵的维度是vocab_size x d_model
,其中vocab_size
是词汇表的大小,d_model
是每个单词的向量维度。 -
调用实例:传入单词索引(整数序列),实例会返回对应的嵌入向量。
示例:
import torch
import torch.nn as nn
class TokenEmbedding(nn.Embedding):
def __init__(self, vocab_size, d_model):
super(TokenEmbedding, self).__init__(vocab_size, d_model, padding_idx=1)
# 假设词汇表大小为 100,嵌入维度为 64
vocab_size = 100
d_model = 64
# 实例化 TokenEmbedding
embedding_layer = TokenEmbedding(vocab_size, d_model)
# 创建输入张量,表示单词的索引
input_indices = torch.LongTensor([2, 5, 10])
# 调用实例,将词汇索引转换为嵌入向量
output = embedding_layer(input_indices)
print(output.shape) # 输出形状为 (3, 64),因为输入中有 3 个单词,每个单词的嵌入向量是 64 维
解释:
vocab_size
: 词汇表的大小,即可以表示多少个不同的单词。d_model
: 每个单词的嵌入向量的维度。padding_idx=1
: 用于指定填充标记的索引,通常是为了忽略填充标记在训练中的影响。
在这个类中,TokenEmbedding
类实际上没有显式返回值的方法,但是通过调用 __call__
方法(继承自 nn.Embedding
),它会查找并返回对应的嵌入向量。
原文地址:https://blog.csdn.net/qq_45809323/article/details/142442787
免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!