1.任务简介
基于ERNIE预训练模型效果上达到业界领先,但是由于模型比较大,预测性能可能无法满足上线需求。
直接使用ERNIE-Tiny系列轻量模型fine-tune,效果可能不够理想。如果采用数据蒸馏策略,又需要提供海量未标注数据,可能并不具备客观条件。
因此,本专题采用主流的知识蒸馏的方案来压缩模型,在满足用户预测性能、预测效果的需求同时,不依赖海量未标注数据,提升开发效率。
文心提供多种不同大小的基于字粒度的ERNIE-Tiny学生模型,满足不同用户的需求。
1.1 模型蒸馏原理
知识蒸馏是一种模型压缩常见方法,指的是在teacher-student框架中,将复杂、学习能力强的网络(teacher)学到的特征表示"知识"蒸馏出来,传递给参数量小、学习能力弱的网络(student)。
在训练过程中,往往以最优化训练集的准确率作为训练目标,但真实目标其实应该是最优化模型的泛化能力。显然如果能直接以提升模型的泛化能力为目标进行训练是最好的,但这需要正确的关于泛化能力的信息,而这些信息通常不可用。如果我们使用由大型模型产生的所有类概率作为训练小模型的目标,就可以让小模型得到不输大模型的性能。这种把大模型的知识迁移到小模型的方式就是蒸馏。
基本原理可参考Hinton经典论文:https://arxiv.org/abs/1503.02531
1.2 ERNIE-Tiny 模型蒸馏
- 模型蒸馏原理可参考论文 ERNIE-Tiny : A Progressive Distillation Framework for Pretrained Transformer Compression 2021。不同于原论文的实现,为了和开发套件中的通用蒸馏学生模型保持一致,我们将蒸馏loss替换为Attention矩阵的KQ loss和 VV loss,原理可参考论文 MiniLM: Deep Self-Attention Distillation for Task-Agnostic Compression of Pre-Trained Transformers 2020 和 MiniLMv2: Multi-Head Self-Attention Relation Distillation for Compressing Pretrained Transformers 2021。实验表明通用蒸馏阶段和任务蒸馏阶段的蒸馏loss不匹配时,学生模型的效果会受到影响。
-
二阶段蒸馏:
- 通用蒸馏(General Distillation,GD):在预训练阶段训练,使用大规模无监督的数据, 帮助学生网络学习到尚未微调的教师网络中的知识,有利于提高泛化能力。为方便用户使用,我们提供了多种尺寸的通用蒸馏学生模型,可用的通用蒸馏学生模型可参考文档:通用模型 - ERNIE3.0 Tiny。
- 任务蒸馏(Task-specific Distillation,TD):使用具体任务的数据,学习到更多任务相关的具体知识。
-
如果已提供的通用蒸馏学生模型尺寸符合需求,用户可以主要关注接下来的任务蒸馏过程。
1.3任务蒸馏步骤
- FT阶段:基于ERNIE 3.0 Large蒸馏模型fine-tune得到教师模型,注意这里用到的教师模型和ERNIE 3.0 Large是两个不同的模型;
- GED阶段(可选):使用fine-tuned教师模型和通用数据继续用通用蒸馏的方式蒸馏学生模型,进一步提升学生模型的效果;
- TD1阶段:蒸馏中间层
- TD2阶段:蒸馏预测层,产出最终的学生模型。
注:关于GED阶段使用的通用数据:开发套件中的通用数据是由开源项目 https://github.com/brightmart/nlp_chinese_corpus 中的中文维基百科语料(wiki2019zh)经过预处理得到。该数据只用于demo展示,实际使用时替换为业务无标注数据效果提升更明显。
2. 常见问题
问题1:怎么修改学生模型的层数?上面提供了多种不同的学生模型,但每个学生模型的层数、hidden size等都是固定的,如果想更改,需要在哪些地方更改?
文心提供了三种不同结构的预训练学生模型,如果修改层数、hidden size等,会导致预训练学生模型参数无法加载,在此情况下,蒸馏出来的学生模型效果无法保证,建议用户还是使用文心提供的预训练模型,不更改模型结构;如果用户认为更改学生模型的结构非常有必要,需要对文心做以下改动:
- 修改TD1阶段json配置文件的pre_train_model配置项,删除预训练学生模型的加载,只保留微调后的教师模型
"pre_train_model": [
# 热启动fine-tune的teacher模型
{
"name": "finetuned_teacher_model",
"params_path": "./output/cls_ernie_3.0_large_ft/save_checkpoints/checkpoints_step_6000"
}
]
- 将json文件中的"student_embedding"替换为自定义的学生模型
"student_embedding": {
"config_path": "../../models_hub/ernie_3.0_tiny_ch_dir/ernie_config.json"
},
- 再次强调,上述修改后,由于无法加载预训练学生模型,蒸馏出来的学生模型效果无法保证。(用户训练数据量到百万样本以上可以考虑尝试一下)
3.数据蒸馏任务
3.1 简介
在ERNIE强大的语义理解能力背后,是需要同样强大的算力才能支撑起如此大规模模型的训练和预测。很多工业应用场景对性能要求较高,若不能有效压缩则无法实际应用。
因此,我们基于数据蒸馏技术构建了数据蒸馏系统。其原理是通过数据作为桥梁,将ERNIE模型的知识迁移至小模型,以达到损失很小的效果却能达到上千倍的预测速度提升的效果。
目录结构
数据蒸馏任务位于 wenxin_appzoo/tasks/data_distillation
├── data
│ ├── dev_data
│ ├── dict
│ ├── download_data.sh
│ ├── predict_data
│ ├── test_data
│ └── train_data
├── distill
│ └── chnsenticorp
│ ├── student
│ └── teacher
├── examples
│ ├── cls_bow_ch.json
│ ├── cls_cnn_ch.json
│ ├── cls_ernie_fc_ch_infer.json
│ └── cls_ernie_fc_ch.json
├── inference
│ ├── custom_inference.py
│ ├── __init__.py
├── model
│ ├── base_cls.py
│ ├── bow_classification.py
│ ├── cnn_classification.py
│ ├── ernie_classification.py
│ ├── __init__.py
├── run_distill.sh
├── run_infer.py
├── run_trainer.py
└── trainer
├── custom_dynamic_trainer.py
├── __init__.py
3.2 数据准备
目前采用三种数据增强策略策略,对于不用的任务可以特定的比例混合。三种数据增强策略包括:
(1)添加噪声:对原始样本中的词,以一定的概率(如0.1)替换为”UNK”标签
(2)同词性词替换:对原始样本中的所有词,以一定的概率(如0.1)替换为本数据集中随机一个同词性的词
(3)N-sampling:从原始样本中,随机选取位置截取长度为m的片段作为新的样本,其中片段的长度m为0到原始样本长度之间的随机值
数据增强策略可参考数据增强,我们已准备好了采用上述3种增强策略制作的chnsenticorp的增强数据。
3.3 离线蒸馏
- 使用预置的ERNIE 2.0 base模型
cd wenxin_appzoo/models_hub
bash download_ernie_2.0_base_ch.sh
- 下载预置的原始数据以及增强数据。
cd wenxin_appzoo/tasks/data_distillation/distill
bash download_data.sh
- 运行以下命令,开始数据蒸馏
cd wenxin_appzoo/tasks/data_distillation
bash run_distill.sh
3.3.1蒸馏过程说明
- run_distill.sh脚本会进行前述的三步:
- run_distill.sh脚本涉及教师和学生模型的json文件,其中:
- 事先构造好的增强数据放在./distill/chnsenticorp/student/unsup_train_aug
- 在脚本的第二步中,使用 ./examples/cls_ernie_fc_ch_infer.json 进行预测:脚本从标准输入获取明文输入,并将打分输出到标准输出。用这种方式对数据增强后的无监督训练预料进行标注。最终的标注结果放在 ./distill/chnsenticorp/student/train/part.1文件中。标注结果包含两列, 第一列为明文,第二列为标注label。
- 在第三步开始student模型的训练,其训练数据放在 distill/chnsenticorp/student/train/ 中,part.0 为原监督数据 part.1 为 ERNIE 标注数据。
- 注:如果用户已经拥有了无监督数据,则可以将无监督数据放入distill/chnsenticorp/student/unsup_train_aug 即可。