(New)微调策略升级
更新时间:2022-12-17
策略简介
-
在下游任务进行fine-tune时,wenxin提供了3中微调策略和一种模型融合策略,分别是:
- 对抗训练策略(adv)
- Rdrop策略
- Layer decay策略
- Fine-tuned model average策略
对抗训练策略(adv)简介
- 对抗训练是一种引入噪声的训练方式,可以对参数进行正则化,提升模型鲁棒性和泛化能力。其可以应用于预训练阶段和finetune阶段,当前已将该任务适配至文本分类任务、文本匹配任务。
-
学术界当前对抗训练中引入噪声的方式分为两种:
- 针对原始文本的噪声引入,例如,token替换
- 针对embedding的噪声引入,例如,在word embedding上直接添加高斯白噪声
- 本任务在第二种方式的基础上进行了进一步的探索,即将噪声引入至attention weight。如下图所示:
- ⚠️注意:该算法在每一个step的训练过程中会前向计算3次,反向计算2次,因此显存占用和训练时长会大大增加,显存占用对比参考如下:
#同样超参下,不加adv策略的显存占用
base:4698MiB / 32510MiB
#同样超参下,加上adv策略的显存占用
+adv:11960MiB / 32510MiB (adv策略每训练一步会进行3次前向和2次反向)
Rdrop策略简介
- Rdrop策略是通过拉近同一个样本经过不同Dropout模型后的输出分布,从而提高模型的鲁棒性的策略,该策略相当于在模型的所有参数上加上正则化约束,具体论文链接:https://arxiv.org/abs/2106.14448
- 在Rdrop策略中使用KL散度来度量同一个样本经过不同Dropout模型的输出分布差异,具体是算法流程如下:
- 该算法并不是同一个样本输出模型分别输入模型两次,而是将样本进行了复制,然后在batch size维度进行拼接,然后在模型输出位置进行拆分实现,这样可以缩短训练时间。
- ⚠️注意:由于rdrop策略对输入样本进行复制扩充,因此显存占用会有一定的上升,显存占用对比参考如下:
#同样超参下,不加adv策略的显存占用
base:4698MiB / 32510MiB
#同样超参下,加上rdrop策略的显存占用
+rdrop:8400MiB / 32510MiB
Layer decay
- Layer decay策略是指分别为模型的每层参数设置一个权重衰减系数,并使用该系数与学习率的乘积作为该层新的学习率,同时该权重衰减系数是根据模型的层数以指数衰减的形式进行计算,具体计算公式如下:
new_lr=lr*decay_rate ** (n_layers + 2 - depth)
- lr为模型的学习率
- decay_rate为衰减指数
- n_layers为模型总层数
- depth为当前参数所在模型的层数
- new_lr为当前参数的学习率
- ⚠️注意:使用Layer decay策略策略时,设置的学习率需要比正常学习率要大,例如不加Layer decay策略训练时学习率为5e-5,那么加上该策略学习率需要设置为1e-4。
Fine-tuned model average策略
Fine-tuned model average策略是指将多个fine-tuned后的模型进行参数平均
策略测评效果
- 文本分类(clue-iflytek数据集)
策略 | ernie3.0 base(不加任何策略) | +fma | +adv | +rdrop | +layer_decay | +adv + layer_decay + fma | +rdrop + layer_decay + fma |
---|---|---|---|---|---|---|---|
acc(%) | 61.14 | 61.52(+0.38) | 61.83(+0.68) | 61.37(+0.23) | 61.56(+0.42) | 63.29(+2.51) | 62.37(+1.23) |
- 文本匹配(单塔pointwise)(clue-afqmc数据集)
策略 | ernie3.0 base | +fma | +adv | +rdrop | +layer_decay | +adv + layer_decay + fma | +rdrop + layer_decay + fma |
---|---|---|---|---|---|---|---|
acc(%) | 76.11 | 76.90(+0.79) | 76.39(+0.28) | 76.30(+0.19) | 76.44(+0.33) | 77.36(+1.25) fma: 76.81(+0.70) | 77.22(+1.11) |
- ⚠️注意:rdop和layer_decay一起使用可能导致性能变差。