TokenLearner in Action: How Adaptive Tokenization Boosts ViT Efficiency Without Sacrificing Accuracy

张开发
2026/4/16 19:29:07 15 分钟阅读

分享文章

TokenLearner in Action: How Adaptive Tokenization Boosts ViT Efficiency Without Sacrificing Accuracy
1. 什么是TokenLearner如果你用过Vision TransformerViT肯定知道它有个吃显存怪兽的称号。一张512x512的图片切成16x16的小块瞬间变成1024个token——这还只是静态图片要是换成视频每帧都这么处理计算量直接爆炸。TokenLearner就是来解决这个问题的它能让ViT动态学习该关注哪些区域把上千个token压缩到8-16个计算量直接砍半性能却不降反升。举个生活中的例子就像你看电影时大脑不会逐帧分析每个像素而是自动聚焦在主角的脸部动作或关键道具上。TokenLearner做的正是这件事——通过空间注意力机制让模型学会盯住图像中真正重要的区域。在NeurIPS 2021的实验中这个模块让ViT在ImageNet分类任务上节省50%计算量的同时准确率还提高了0.5%。2. TokenLearner的核心原理2.1 动态token生成机制传统ViT的tokenization简单粗暴把图像均匀分割成16x16的网格每个网格视为一个token。TokenLearner则像智能剪刀——用可学习的空间注意力图决定裁剪哪里。具体实现分三步特征提取输入图像经过卷积层生成H×W大小的空间特征图注意力加权通过并行的S个如8个注意力头生成S张空间权重图token压缩每张权重图与特征图逐元素相乘后全局池化最终输出S×C的token矩阵# 伪代码示例 def TokenLearner(x): # x: [H, W, C]的特征图 spatial_weights [conv_block(x) for _ in range(S)] # 生成S张注意力图 tokens [] for weight in spatial_weights: weighted_feature x * weight # Hadamard积 token global_avg_pool(weighted_feature) # 压缩为1xC向量 tokens.append(token) return stack(tokens) # 输出[S, C]2.2 计算效率的飞跃Transformer的计算复杂度与token数量呈平方关系。当token从196个ViT-B/16降到8个单层MHSA计算量从O(196²d)降到O(8²d)实际测试中插入TokenLearner后ViT-L模型的FLOPs从190G降至84G内存占用减少60%这在处理视频时尤为关键Kinetics数据集单样本可达4096个token3. 实战中的性能表现3.1 图像分类任务在ImageNet上对比三种配置模型类型位置策略Top-1准确率FLOPsViT-B/16基线无TokenLearner79.3%17.6G早置(1/4处)固定位置79.1%5.8G晚置(3/4处)自适应位置79.8%9.2G关键发现放在网络后1/4处效果最佳。因为浅层需要保留更多空间细节深层更适合做语义筛选。3.2 视频理解任务在Kinetics-400视频数据集上每帧提取8个token时空token总数8×帧数如64帧共512个token相比ViViT模型计算量减少8倍准确率从78.6%提升到79.3%在Charades长视频数据集上提升更明显2.1%可视化显示TokenLearner会随着人物移动动态调整关注区域。比如乒乓球比赛中它会自动追踪球拍和球的轨迹忽略静止的背景。4. 如何集成到现有模型4.1 插入位置选择通过消融实验发现中间层插入如12层Transformer的第6层后保持前6层完整处理所有token后6层仅处理8个学习到的token平衡精度与效率的最佳选择多阶段插入graph LR 输入图像 -- ViT_1-4层 ViT_1-4层 -- TokenLearner1(生成64token) TokenLearner1 -- ViT_5-8层 ViT_5-8层 -- TokenLearner2(生成8token) TokenLearner2 -- ViT_9-12层4.2 配合TokenFuser使用TokenFuser是配套模块主要做两件事信息融合通过token间的线性变换类似MLP-Mixer增强交互空间还原将压缩后的token重新映射回原始分辨率实测表明单独使用TokenLearner可使ViT-B/16的FLOPs降至9.8G配合TokenFuser后进一步降至8.3G且准确率提升0.4%5. 开发者实践指南5.1 在PyTorch中实现class TokenLearner(nn.Module): def __init__(self, in_dim768, num_tokens8): super().__init__() self.conv nn.Sequential( nn.Conv2d(in_dim, in_dim, 3, padding1, groups8), nn.GELU(), nn.Conv2d(in_dim, num_tokens, 1) ) def forward(self, x): # x: [B, H*W, C] B, N, C x.shape H W int(N**0.5) x x.permute(0,2,1).view(B, C, H, W) attn self.conv(x) # [B, S, H, W] attn attn.view(B, -1, H*W).softmax(dim-1) x x.view(B, C, H*W) tokens torch.einsum(bsn,bcn-bsc, attn, x) return tokens # [B, S, C]5.2 训练技巧渐进式训练第一阶段冻结TokenLearner正常训练ViT第二阶段解冻TokenLearner用1e-4小学习率微调token数量调度初始16个token每10个epoch减半最终稳定在8个token时效果最佳注意力的温度系数attn attn / temperature # 初始temperature2.0逐渐降到1.0我在实际项目中发现配合LayerScale技术给每个注意力头加可学习标量能提升稳定性。另外要注意当输入分辨率变化时需要调整注意力头的感受野大小。

更多文章