多模态动态融合模型Predictive Dynamic Fusion阅读与代码分析运行1-信度概念与基础参数指标

张开发
2026/4/16 15:01:09 15 分钟阅读

分享文章

多模态动态融合模型Predictive Dynamic Fusion阅读与代码分析运行1-信度概念与基础参数指标
参考文Cao B, Xia Y, Ding Y, et al. Predictive Dynamic Fusion[J]. arXiv preprint arXiv:2406.04802, 2024.[2406.04802] Predictive Dynamic Fusion一、理论今天就先看看论文中的各个指标含义和多模态训练代码的参数吧文章中一个比较重要的概念就是置信度的概念了在论文前段对置信度的扩展比较多同时没有什么具体说明不知道概念的话读着还是很混乱的置信度在机器学习中置信度表示模型对其预测结果“有多确定”。它刻画的是模型认为自己预测是正确的程度例如在分类任务中“这是正类的概率是 0.92”那么 0.92 就可以视为模型对该预测的置信度在监督学习中给定输入样本 xxx模型预测类别为 y^\hat{y}y^​则置信度通常定义为即模型对预测类别的后验概率估计置信度 和 不确定性补充文中用熵来衡量整体不确定性算是置信度的一种扩展关于熵的概念之前在b站看到的一位up主讲的很生动https://www.bilibili.com/video/BV15V411W7VB/置信度高 熵低分类评价指标对比指标含义对照指标一句话解释Accuracy模型整体准不准Precision模型说“是”的时候靠谱吗Recall真正“是”的有没有被找全F1Precision 和 Recall 的折中ROC-AUC正样本排在负样本前面的能力The Mono-Confidences and Holo-Confidences该文的目的之一是为了解决模态权重融合的权重问题也就是多个模态分别从多个维度评价目标的状态给出不一样的结果怎么融合这几个结果的问题。目前可以确定的是融合权重 ω 应当与损失 l 呈负相关并且与其他模态的损失呈正相关。也就是当前模态越可靠 → 权重越大其他模态越不可靠 → 当前模态权重越大对单个模态的模型权重 ω 是要求的权重损失loss是所以就有人两个信度指标The Mono-ConfidencesHolo-Confidences当前模态本身有多可靠相对其他模态我有多可靠将他们统合Co-Belief协同信度Mono-Confidence只看自己Holo-Confidence只看别人但多模态融合需要既考虑自身可靠性又考虑整体模态状态。故有再由协同信度确定该模态的权重。理论先到这里其他的后面再看二、代码1、运行环境代码训练环境没有明确说明但根据结构可以看得出来用的是autodl里的云服务器Ubuntu20.04python3.11的版本卡随便租一个都一样。论文附带代码只有2mb明显缺失了很多预训练结构与数据集文件2、数据集文件这里选用了代码中可选的第二个训练集MVSA_Single需要自己到网站下好转到autodl服务器上MVSA_Single训练集之类的划分源代码已有了自己按要求放到同一目录下即可。3、词向量文件源代码缺失了预训练好的词向量文件glove.840B.300d需要自己使用指令下载到指定目录wget https://nlp.stanford.edu/data/glove.840B.300d.zip4、源代码逻辑错误训练代码中的forward函数存在运行逻辑错误文本和图像的losstxt_clf_loss和img_clf_loss定义在了if之外会运行不成功估计是作者没有仔细整理代码算法逻辑倒没什么问题原代码150行左右def model_forward(i_epoch, model, args, criterion,optimizer, batch,modeeval): txt, segment, mask, img, tgt,idx batch freeze_img i_epoch args.freeze_img freeze_txt i_epoch args.freeze_txt if args.model bow: txt txt.cuda() out model(txt) elif args.model img: img img.cuda() out model(img) elif args.model concatbow: txt, img txt.cuda(), img.cuda() out model(txt, img) elif args.model bert: txt, mask, segment txt.cuda(), mask.cuda(), segment.cuda() out model(txt, mask, segment) elif args.model concatbert: txt, img txt.cuda(), img.cuda() mask, segment mask.cuda(), segment.cuda() out model(txt, mask, segment, img) elif args.model latefusion_pdf: txt, img txt.cuda(), img.cuda() mask, segment mask.cuda(), segment.cuda() tgt tgt.cuda() maeloss nn.L1Loss(reductionmean) out, txt_logits, img_logits, txt_tcp_pred, img_tcp_pred model(txt, mask,segment,img,pdf_train) label F.one_hot(tgt, num_classesargs.n_classes) # [b,c] if args.task_type multilabel: txt_pred torch.sigmoid(txt_logits) img_pred torch.sigmoid(img_logits) else: txt_pred torch.nn.functional.softmax(txt_logits, dim1) img_pred torch.nn.functional.softmax(img_logits, dim1) txt_tcp, _ torch.max(txt_pred * label, dim1,keepdimTrue) img_tcp, _ torch.max(img_pred * label, dim1,keepdimTrue) tcp_pred_loss maeloss(txt_tcp_pred, txt_tcp.detach()) maeloss(img_tcp_pred, img_tcp.detach()) else: assert args.model mmbt for param in model.enc.img_encoder.parameters(): param.requires_grad not freeze_img for param in model.enc.encoder.parameters(): param.requires_grad not freeze_txt txt, img txt.cuda(), img.cuda() mask, segment mask.cuda(), segment.cuda() out model(txt, mask, segment, img) tgt tgt.cuda() txt_clf_loss nn.CrossEntropyLoss()(txt_logits, tgt) img_clf_loss nn.CrossEntropyLoss()(img_logits, tgt) clf_losstxt_clf_lossimg_clf_lossnn.CrossEntropyLoss()(out,tgt) if modetrain: loss torch.mean(clf_loss)torch.mean(tcp_pred_loss) return loss,out,tgt else: loss torch.mean(clf_loss)torch.mean(tcp_pred_loss) return loss,out,tgt修改后def model_forward(i_epoch, model, args, criterion, optimizer, batch, modeeval): txt, segment, mask, img, tgt, idx batch tgt tgt.cuda() clf_loss 0.0 tcp_pred_loss 0.0 # ⭐ 先初始化避免炸 # ---------- 普通单 / 早期融合模型 ---------- if args.model bow: txt txt.cuda() out model(txt) clf_loss criterion(out, tgt) elif args.model img: img img.cuda() out model(img) clf_loss criterion(out, tgt) elif args.model concatbow: txt, img txt.cuda(), img.cuda() out model(txt, img) clf_loss criterion(out, tgt) elif args.model bert: txt, mask, segment txt.cuda(), mask.cuda(), segment.cuda() out model(txt, mask, segment) clf_loss criterion(out, tgt) elif args.model concatbert: txt, img txt.cuda(), img.cuda() mask, segment mask.cuda(), segment.cuda() out model(txt, mask, segment, img) clf_loss criterion(out, tgt) # ---------- late fusion特例 ---------- elif args.model latefusion_pdf: txt, img txt.cuda(), img.cuda() mask, segment mask.cuda(), segment.cuda() out, txt_logits, img_logits, txt_tcp_pred, img_tcp_pred \ model(txt, mask, segment, img, pdf_train) # 分类 loss txt_loss criterion(txt_logits, tgt) img_loss criterion(img_logits, tgt) clf_loss txt_loss img_loss # TCP loss maeloss nn.L1Loss(reductionmean) label F.one_hot(tgt, num_classesargs.n_classes) if args.task_type multilabel: txt_pred torch.sigmoid(txt_logits) img_pred torch.sigmoid(img_logits) else: txt_pred F.softmax(txt_logits, dim1) img_pred F.softmax(img_logits, dim1) txt_tcp, _ torch.max(txt_pred * label, dim1, keepdimTrue) img_tcp, _ torch.max(img_pred * label, dim1, keepdimTrue) tcp_pred_loss ( maeloss(txt_tcp_pred, txt_tcp.detach()) maeloss(img_tcp_pred, img_tcp.detach()) ) # ---------- mmbt ---------- else: assert args.model mmbt txt, img txt.cuda(), img.cuda() mask, segment mask.cuda(), segment.cuda() out model(txt, mask, segment, img) clf_loss criterion(out, tgt) # ---------- 总 loss ---------- loss clf_loss tcp_pred_loss return loss, out, tgt四、各训练参数主要是get_args里面的参数解释训练与优化相关参数参数名默认值含义说明影响阶段备注 / 建议batch_sz128每个 batch 的样本数量训练大 batch 更稳定但占显存gradient_accumulation_steps24梯度累积步数训练等效 batch batch_sz × stepslr1e-4初始学习率训练BERT 微调常用 1e-55e-5weight_decay0.0权重衰减系数L2 正则训练防止过拟合dropout0.1Dropout 概率模型Transformer 常用 0.1max_epochs100最大训练轮数训练搭配 early stoppingpatience10Early stopping 容忍轮数训练验证集无提升时停止warmup0.1学习率 warmup 比例训练防止初期梯度震荡lr_factor0.5学习率衰减倍率训练ReduceLROnPlateaulr_patience2学习率衰减等待轮数训练验证集不提升则降 lrseed123随机种子全局保证实验可复现n_workers12DataLoader 线程数数据加载与 CPU 核数相关文本模态参数名默认值含义说明影响阶段备注bert_model./bert-base-uncasedBERT 预训练模型路径模型可换成 largefreeze_txt0是否冻结文本编码器训练1 表示不更新 BERTmax_seq_len512文本最大 token 长度数据BERT 上限embed_sz300词向量维度模型对应 GloVeglove_pathglove.840B.300d.txtGloVe 文件路径数据300 维hidden_sz768文本隐藏层维度模型BERT-base 默认图像模态Image相关参数参数名默认值含义说明影响阶段备注img_hidden_sz2048图像特征维度模型ResNet 输出num_image_embeds1图像 token 数模型MMBT 中常见img_embed_pool_typeavg图像特征池化方式模型avg / maxfreeze_img0是否冻结图像编码器训练1 表示冻结drop_img_percent0.0随机丢弃图像比例数据增强模态缺失模拟融合参数参数名默认值含义说明影响阶段备注modellatefusion_pdf使用的模型结构模型PDF Predictive Dynamic Fusionhidden[]额外隐藏层结构模型如 [512,256]include_bnTrue是否使用 BatchNorm模型提高训练稳定性dfTrue是否启用动态融合模型PDF 核心开关baselineNone对比方法名称实验仅用于记录任务与数据相关参数参数名默认值含义说明影响阶段备注taskMVSA_Single使用的数据集数据多模态情绪识别task_typeclassification任务类型训练单标签 / 多标签weight_classes1是否类别加权loss类别不平衡时用noise0.0标签噪声比例数据鲁棒性实验data_path/path/to/data_dir/数据集路径数据必须配置savedir/path/to/save_dir/模型保存路径输出checkpoint其中很多任务数据相关参数都需要调整

更多文章