第2.1讲、《The Annotated Transformer》代码精讲:从理论到PyTorch实现

张开发
2026/4/19 18:38:48 15 分钟阅读

分享文章

第2.1讲、《The Annotated Transformer》代码精讲:从理论到PyTorch实现
1. 从理论到代码Transformer的核心实现第一次看到《Attention is All You Need》论文时我被那些复杂的矩阵运算和公式吓到了。直到发现了哈佛大学NLP团队的《The Annotated Transformer》项目才真正理解如何用PyTorch实现这个革命性的模型。这个项目就像一位耐心的导师把论文中的每个公式都转化成了可运行的代码。让我们从一个具体例子开始理解。论文中著名的注意力计算公式是 $$Attention(Q,K,V)softmax(\frac{QK^T}{\sqrt{d_k}})V$$在PyTorch中这个公式的实现出奇地简洁def attention(query, key, value, maskNone, dropoutNone): d_k query.size(-1) scores torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k) if mask is not None: scores scores.masked_fill(mask 0, -1e9) p_attn F.softmax(scores, dim-1) if dropout is not None: p_attn dropout(p_attn) return torch.matmul(p_attn, value), p_attn这段代码有几个关键点值得注意torch.matmul实现了Q和K的矩阵乘法除以$\sqrt{d_k}$的操作防止梯度消失论文中的核心发现之一mask机制在处理变长序列时至关重要2. 位置编码的魔法实现Transformer没有RNN的序列处理能力所以需要手动加入位置信息。论文使用正弦和余弦函数来编码位置$$PE_{(pos,2i)}sin(pos/10000^{2i/d_{model}})$$ $$PE_{(pos,2i1)}cos(pos/10000^{2i/d_{model}})$$PyTorch实现展示了如何高效生成这些编码class PositionalEncoding(nn.Module): def __init__(self, d_model, dropout, max_len5000): super(PositionalEncoding, self).__init__() self.dropout nn.Dropout(pdropout) pe torch.zeros(max_len, d_model) position torch.arange(0, max_len).unsqueeze(1) div_term torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)) pe[:, 0::2] torch.sin(position * div_term) pe[:, 1::2] torch.cos(position * div_term) pe pe.unsqueeze(0) self.register_buffer(pe, pe) def forward(self, x): x x self.pe[:, :x.size(1)] return self.dropout(x)这段代码有几个精妙之处使用对数空间计算div_term避免数值不稳定向量化操作同时生成所有位置编码register_buffer确保位置编码能随模型一起保存和加载3. 多头注意力的并行计算多头注意力是Transformer的核心创新它允许模型同时关注不同位置的多个特征子空间。代码实现展示了如何高效拆分和重组注意力头class MultiHeadedAttention(nn.Module): def __init__(self, h, d_model, dropout0.1): super(MultiHeadedAttention, self).__init__() assert d_model % h 0 self.d_k d_model // h self.h h self.linears clones(nn.Linear(d_model, d_model), 4) self.attn None self.dropout nn.Dropout(pdropout) def forward(self, query, key, value, maskNone): if mask is not None: mask mask.unsqueeze(1) nbatches query.size(0) # 1) 线性投影并分头 query, key, value [ l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2) for l, x in zip(self.linears, (query, key, value)) ] # 2) 计算注意力 x, self.attn attention(query, key, value, maskmask, dropoutself.dropout) # 3) 合并多头结果 x x.transpose(1, 2).contiguous() \ .view(nbatches, -1, self.h * self.d_k) return self.linears[-1](x)关键实现细节使用clones函数复制线性层后面会讲到这个工具函数view和transpose操作实现高效的分头处理最后的线性层将所有头的结果融合4. 残差连接与层归一化Transformer使用残差连接和层归一化来稳定深层网络的训练。这个看似简单的设计对模型性能至关重要class SublayerConnection(nn.Module): def __init__(self, size, dropout): super(SublayerConnection, self).__init__() self.norm nn.LayerNorm(size) self.dropout nn.Dropout(dropout) def forward(self, x, sublayer): return x self.dropout(sublayer(self.norm(x)))这个实现采用了Post-Norm方式先归一化再残差连接。我在实际项目中发现对小模型Pre-Norm先残差再归一化通常更稳定Dropout的位置影响模型正则化效果归一化维度需要与输入特征维度一致5. 编码器层的完整实现一个完整的编码器层结合了多头注意力和前馈网络class EncoderLayer(nn.Module): def __init__(self, size, self_attn, feed_forward, dropout): super(EncoderLayer, self).__init__() self.self_attn self_attn self.feed_forward feed_forward self.sublayer clones(SublayerConnection(size, dropout), 2) self.size size def forward(self, x, mask): x self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask)) return self.sublayer[1](x, self.feed_forward)这里有几个值得注意的设计使用lambda表达式延迟计算保持接口统一相同的输入x作为Q、K、V自注意力机制前馈网络采用两层的全连接结构6. 解码器的特殊处理解码器比编码器更复杂需要处理两种不同的注意力机制class DecoderLayer(nn.Module): def __init__(self, size, self_attn, src_attn, feed_forward, dropout): super(DecoderLayer, self).__init__() self.size size self.self_attn self_attn self.src_attn src_attn self.feed_forward feed_forward self.sublayer clones(SublayerConnection(size, dropout), 3) def forward(self, x, memory, src_mask, tgt_mask): m memory x self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask)) x self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask)) return self.sublayer[2](x, self.feed_forward)解码器的关键特点第一层是带掩码的自注意力防止看到未来信息第二层是源注意力连接编码器的输出每层都保留残差连接7. 模型训练技巧Transformer使用了一些特殊的训练技巧包括标签平滑和学习率调度class LabelSmoothing(nn.Module): def __init__(self, size, padding_idx, smoothing0.0): super(LabelSmoothing, self).__init__() self.criterion nn.KLDivLoss(reductionsum) self.padding_idx padding_idx self.confidence 1.0 - smoothing self.smoothing smoothing self.size size self.true_dist None def forward(self, x, target): true_dist x.data.clone() true_dist.fill_(self.smoothing / (self.size - 2)) true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence) true_dist[:, self.padding_idx] 0 mask torch.nonzero(target.data self.padding_idx) if mask.dim() 0: true_dist.index_fill_(0, mask.squeeze(), 0.0) self.true_dist true_dist return self.criterion(x, true_dist) class NoamOpt: def __init__(self, model_size, factor, warmup, optimizer): self.optimizer optimizer self._step 0 self.warmup warmup self.factor factor self.model_size model_size self._rate 0 def step(self): self._step 1 rate self.rate() for p in self.optimizer.param_groups: p[lr] rate self._rate rate self.optimizer.step() def rate(self, stepNone): if step is None: step self._step return self.factor * \ (self.model_size ** (-0.5) * min(step ** (-0.5), step * self.warmup ** (-1.5)))这些技巧的实际效果标签平滑防止模型对预测结果过于自信Noam优化器实现动态学习率调整预热期(warmup)对稳定训练很重要8. 实际应用示例让我们看一个完整的翻译示例def greedy_decode(model, src, src_mask, max_len, start_symbol): memory model.encode(src, src_mask) ys torch.ones(1, 1).fill_(start_symbol).type_as(src.data) for i in range(max_len-1): out model.decode(ys, memory, src_mask, subsequent_mask(ys.size(1)).type_as(src.data)) prob model.generator(out[:, -1]) _, next_word torch.max(prob, dim1) next_word next_word.data[0] ys torch.cat([ys, torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim1) return ys这个贪心解码器展示了如何逐步生成输出序列使用subsequent_mask防止解码器看到未来信息编码器-解码器的交互方式9. 注意力可视化理解模型关注什么是调试Transformer的重要技能def draw(data, x, y, ax): seaborn.heatmap(data, xticklabelsx, squareTrue, yticklabelsy, vmin0.0, vmax1.0, cbarFalse, axax) # 可视化编码器注意力 for layer in range(1, 6, 2): fig, axs plt.subplots(1,4, figsize(20, 10)) print(Encoder Layer, layer1) for h in range(4): draw(model.encoder.layers[layer].self_attn.attn[0, h].data, sent, sent if h 0 else [], axaxs[h]) plt.show()这种可视化可以帮助我们发现注意力头学习到的不同模式检查模型是否关注了正确的词语诊断层与层之间的注意力变化10. 工程实践建议在实际项目中实现Transformer时我总结了一些经验从小规模开始先验证单个注意力头的效果使用梯度检查点(gradient checkpointing)节省显存混合精度训练可以显著加快速度注意初始化策略对深层Transformer的影响不同的任务可能需要调整注意力头的数量调试Transformer模型时一个好的做法是从最简单的配置开始逐步增加复杂度。比如先去掉所有残差连接和归一化层只测试基本的注意力机制是否能工作然后再逐个添加其他组件。

更多文章