使用热启动进行多阶段训练
简介
在模型训练过程中难免会出现中断的情况,我们自然希望能够将训练得到的参数通过保存检查点(checkpoints)的方式保存下来,免得在中断之后重新训练。模型的热启动指的就是通过加载检查点文件(checkpoints)来恢复模型的训练,此时还可以修改网络的超参数,比如学习率等来达到多阶段训练(持续训练)的目的。
这里我们以基于CNN预置网络的文本分类任务为例,介绍如何使用热启动进行多阶段训练。
环境安装
准备数据
非ERNIE数据:这里我们提供一份已经标注的、分词之后的示例数据集。
- 文心文本分类任务非ERNIE的训练数据、测试数据、验证数据和预测数据分别存放在./data文件夹中的train_data、test_data、dev_data和predict_data文件夹下,对应的示例词典存放在dict文件夹下。
- 示例数据为二分类,标签标注分别为0和1,文本经过分词预处理。
-
训练数据、测试数据和验证数据格式相同。数据分为两列,第一列为分词处理后的文本,第二列为标签。列与列之间用\t进行分隔,如下所示:
房间 太 小 。 其他 的 都 一般 。 。 。 。 。 。 。 。 。 0 LED屏 就是 爽 , 基本 硬件 配置 都 很 均衡 , 镜面 考 漆 不错 , 小黑 , 我喜欢 。 1 差 得 要命 , 很大 股霉味 , 勉强 住 了 一晚 , 第二天 大早 赶紧 溜 0
-
预测数据没有标签预占位,其格式如下所示:
USB接口 只有 2个 , 太 少 了 点 , 不能 接 太多 外 接 设备 ! 表面 容易 留下 污垢 ! 平时 只 用来 工作 , 上 上网 , 挺不错 的 , 没有 冗余 的 功能 , 样子 也 比较 正式 ! 还 可以 吧 , 价格 实惠 宾馆 反馈 2008年4月17日 : 谢谢 ! 欢迎 再次 入住 其士 大酒店 。
-
词表分为两列,第一列为词,第二列为id(从0开始),列与列之间用\t进行分隔。文心的词表中,[PAD]、[CLS]、[SEP]、[MASK]、[UNK]这5个词是必须要有的,若用户自备词表,需保证这5个词是存在的。部分词表示例如下所示:
[PAD] 0 [CLS] 1 [SEP] 2 [MASK] 3 [UNK] 4 郑重 5 天空 6 工地 7 神圣 8
通过训练模型获得checkpoints文件
这里我们以基于预置CNN网络的文本分类模型为例,通过训练模型,获得对应的checkpoints文件,主要分为以下几个步骤:
- 配置CNN文本分类的网络的训练任务的json文件:json文件位于./tasks/text_classification/examples/cls_cnn_ch.json,详细配置内容可参考实战演练:使用文心进行模型训练中的参数配置章节。
-
启动模型训练:
# CNN 模型 # 需要提前参照env.sh进行环境变量配置,在当前shell内去读取 source env.sh # 基于json实现预置网络训练。其调用了配置文件./examples/cls_cnn_ch.json python run_with_json.py --param_path ./examples/cls_cnn_ch.json
-
产出checkpoints文件:训练中以及结束后产生的模型文件会默认保存在./output/cls_cnn_ch/目录下,其中save_inference_model/文件夹会保存用于预测的模型文件,save_checkpoint/文件夹会保存用于热启动的checkpoints文件,结果如下所示:
├── save_checkpoints │ ├── checkpoints_step_251 ## 神经网络中的所有参数文件,可以用来做热启动。 │ │ ├── embedding_0.w_0 │ │ ├── embedding_0.w_0_beta1_pow_acc_0 │ │ ├── .... │ │ ├── fc_0.b_0 │ │ ├── .... │ │ ├── sequence_conv_0.b_0 │ │ ├── model.meta ## meta文件,存放了文心自定义的网络基本介绍信息。 │ │ ├── .... ├── save_inference_model │ ├── inference_step_251 │ │ ├── infer_data_params.json ## 模型预测过程中需要解析的字段,模型训练过程中自动生成,模型预测过程中自动解析。其写入内容由model文件(组网文件)中的forward方法的返回值决定。 │ │ ├── model ## paddle框架保存出来的模型结构文件 │ │ ├── model.meta ## meta文件,存放了文心自定义的网络基本介绍信息。同checkpoints中的model.meta │ │ ├── params ## 经过优化裁剪之后的参数文件(所有参数压缩保存在一个文件中)
使用checkpoints文件进行热启动
- 基于./examples/cls_cnn_ch.json训练出的检查点默认储存在./output/cls_cnn_ch/save_checkpoints/中,在该目录下找到被保存的checkpoints文件夹,例如checkpoints_step_251/
-
在./examples/cls_cnn_ch.json中"trainer"部分修改"load_checkpoint"参数,填入上述模型保存路径,如下所示:
{ ... "trainer": { ... "load_checkpoint": "./output/cls_cnn_ch/save_checkpoints/checkpoints_step_251", ## 必填参数,待进行热启动的checkpoints目录地址。 ... }
-
如果想使用新设置的学习率进行训练,则在./examples/cls_cnn_ch.json中"trainer"部分修改"load_parameters"参数,填入上述模型保存路径,并在"model"部分修改"learning_rate"参数,如下所示:
{ ... "model": { "type": "CnnClassification", ## 文心采用模型(models)的方式定义神经网络的基本操作,本例采用预置的模型CnnClassification实现文本分类,具体网络可参考models目录。 "optimization": { "learning_rate": 2e-05 ## 预置模型的优化器所需的参数配置,如学习率等。 }, ... }, "trainer": { ... "load_parameters": "./output/cls_cnn_ch/save_checkpoints/checkpoints_step_251", ... }
-
基于示例的数据集,可以运行以下命令通过热启动继续之前的训练:
# 基于json实现预置网络训练。其调用了配置文件./examples/cls_cnn_ch.json,并且在checkpoints_step_251检查点基础上进行了热启动 python run_with_json.py --param_path ./examples/cls_cnn_ch.json
- 训练运行的日志会自动保存在./log/test.log文件中,生成的模型文件保存在./output/cls_cnn_ch目录下。