资讯 文档
技术能力
语音技术
文字识别
人脸与人体
图像技术
语言与知识
视频技术

通过热启动进行多阶段训练

任务简介

  • 在训练模型时难免会出现中断的情况,我们自然希望能够将训练得到的参数通过保存检查点(checkpoint)的方式保存下来,免得在中断之后重新训练。
  • 模型的热启动指的是通过加载检查点来恢复模型的训练,此时还可以修改网络的参数,比如学习率等来达到多阶段训练的目的。

快速开始

  • 本文档以基于CNN预置网络的文本分类任务为例。
  • 代码目录: textone-public/tasks/text_classification

数据准备

非ERNIE数据
  • 这里我们提供一份已标注的、经过分词预处理的示例数据集。
  • 训练集、测试集、验证集和预测集分别存放在./data目录下的train_data、test_data、dev_data和predict_data文件夹下,对应的示例词表存放在./dict目录下。
训练集/测试集/验证集
  • 训练集、测试集和验证集的数据格式相同,如下所示。数据分为两列,列与列之间用\t进行分隔。第一列为文本,第二列为标签。
房间 太 小 。 其他 的 都 一般 。 。 。 。 。 。 。 。 。         0
LED屏 就是 爽 , 基本 硬件 配置 都 很 均衡 , 镜面 考 漆 不错 , 小黑 , 我喜欢 。         1
差 得 要命 , 很大 股霉味 , 勉强 住 了 一晚 , 第二天 大早 赶紧 溜。         0
预测集
  • 预测集无需进行标签预占位,其格式如下所示:
USB接口 只有 2个 , 太 少 了 点 , 不能 接 太多 外 接 设备 ! 表面 容易 留下 污垢 !
平时 只 用来 工作 , 上 上网 , 挺不错 的 , 没有 冗余 的 功能 , 样子 也 比较 正式 !
还 可以 吧 , 价格 实惠   宾馆 反馈   2008417日   :   谢谢 ! 欢迎 再次 入住 其士 大酒店 。
词表
  • 词表分为两列,第一列为词,第二列为id(从0开始),列与列之间用\t进行分隔。部分词表示例如下所示:
[PAD]	0
[CLS]	1
[SEP]	2
[MASK]	3
[UNK]	4
	5
郑重	6
丁约翰	7
工地	8
神圣	9

训练模型获得检查点

开始训练
  • 使用预置网络进行训练的方式为使用./run_with_json.py入口脚本,通过--param_path参数来传入./examples/目录下的json配置文件。
  • 以基于预置CNN网络的文本分类模型为例,训练分为以下几个步骤:

    1. 请在./env.sh中根据提示配置相应环境变量的路径
    2. 基于示例的数据集,可以运行以下命令在训练集(train.txt)上进行模型训练,并在测试集(test.txt)上进行验证;
    # CNN 模型
    # 需要提前参照env.sh进行环境变量配置,在当前shell内去读取
    source env.sh
    # 基于json实现预置网络训练。其调用了配置文件./examples/cls_cnn_ch.json
    python run_with_json.py --param_path ./examples/cls_cnn_ch.json
    1. 训练运行的日志会自动保存在./log/test.log文件中;
    2. 训练中以及结束后产生的模型文件会默认保存在./output/cls_cnn_ch/目录下,其中save_checkpoints/文件夹会保存用于热启动的检查点。

通过加载检查点来热启动

继续训练
  • 以基于预置CNN网络所训练出的检查点为例,其热启动分为以下几个步骤:

    1. 基于./examples/cls_cnn_ch.json训练出的检查点默认储存在./output/cls_cnn_ch/save_checkpoints/中,在该目录下找到被保存的checkpoints文件夹,例如checkpoints_step_251/;
    2. 在./examples/cls_cnn_ch.json中修改"load_checkpoint"参数,填入上述模型保存路径,如下所示:
    {
      ...
      "trainer":{   
         ...
        "load_checkpoint":"./output/cls_cnn_ch/save_checkpoints/checkpoints_step_251" ,
         ...
      }
    }
    1. 基于示例的数据集,可以运行以下命令通过热启动继续之前的训练:
    # 基于json实现预置网络训练。其调用了配置文件./examples/cls_cnn_ch.json,并且在checkpoints_step_251检查点基础上进行了热启动
    python run_with_json.py --param_path ./examples/cls_cnn_ch.json
    1. 训练运行的日志会自动保存在./log/test.log文件中。