开始训练与预测
更新时间:2022-07-05
环境安装
目录结构
阅读理解任务位于/wenxin/tasks/mrc
.
├── __init__.py
├── examples ## 各典型网络的json配置文件
├── cls_ernie_2.0_base_dureader_ch.json
├── run_with_json.py ## 只依靠json进行模型训练的入口脚本
├── run_mrc.py ## 阅读理解任务的脚本
└── package/task_data/dureader ## 示例数据文件夹,包括各任务所需训练集(train.json)、测试集(test.json)和验证集(dev.json)
├── dev.json
├── test.json
└── train.json
预置Reader配置
通过json文件中的dataset_reader配置模型的训练数据读取,以下为./examples/cls_ernie_2.0_base_dureader_ch.json中抽取出来的dataset_reader部分配置,并通过注释说明。
{
"dataset_reader": {
"train_reader": { ## 训练、验证、测试各自基于不同的数据集,数据格式也可能不一样,可以在json中配置不同的reader,此处为训练集的reader。
"name": "train_reader",
"type": "MRCReader", ## 采用MRCReader,其封装了常见的读取tsv文件、组batch等操作。
"fields": [], ## mrc任务中未使用fields_reader。 域(field)是文心的高阶封装,对于同一个样本存在不同域的时候,不同域有单独的数据类型(文本、数值、整型、浮点型)、单独的词表(vocabulary)等,可以根据不同域进行语义表示,如文本转id等操作,field_reader是实现这些操作的类。
"config": {
"data_path": "./package/task_data/dureader/train.json", ## 训练数据train_reader的数据路径,写到具体文件名。
"shuffle": true,
"batch_size": 16,
"epoch": 5,
"sampling_rate": 1.0,
"extra_params":{ ## MRCReader所需的额外配置参数
"vocab_path":"../model_files/dict/vocab_ernie_2.0_base_ch.txt", ## 指定模型运行所需词表。
"label_map_config":"", ## 指定文本label的映射文件,无需映射则为空。
"max_seq_len":512, ## 指定文本的最大长度。
"do_lower_case":true, ## 指定tokenizer时是否对英文字母做小写处理。
"in_tokens":false, ## 指定batch的组成是已token为单位还是sequence为单位。
"tokenizer": "FullTokenizer", ## 指定tokenizer为FullTokenizer。
"for_cn": true, ## 指定训练文本的语言是中文。
"task_id": 0, ## 指定task_id,默认为0即可。
"doc_stride": 128, ##当context的文本序列长度超过max_seq_len,需要将context分块,改参数指定分块间隔。
"max_query_length": 64, ## 问题文本序列的最大长度。
"use_multi_gpu_test":true ## 使用多卡预测
}
}
},
"test_reader": { ## 若要评估测试集,需配置test_reader,其配置方式与train_reader类似。
……
},
"dev_reader": { ## 若要评估验证集,需配置dev_reader,其配置方式与test_reader类似。
……
}
},
……
}
开始训练
- 如您使用镜像开发套件,您可直接进入下一步骤。如您将文心开发套件与本地已有的开发环境相结合,您需要在./env.sh中配置对应的环境变量,并执行source env.sh ,如需了解更多详情,请参考环境配置。
-
需提前下载对应的ERNIE预训练模型,例如:
# ernie_2.0_base 模型下载 # 进入model_files目录 cd ../model_files/ # 运行下载脚本 sh download_ernie_2.0_base_ch.sh
-
模型训练的入口脚本为./run_with_json.py , 通过--param_path参数来传入./examples/目录下的json配置文件。例如:
python run_with_json.py --param_path ./examples/cls_ernie_2.0_base_dureader_ch.json
- 训练运行的日志会自动保存在./log/test.log文件中。
- 训练中以及结束后产生的模型文件会保存在json配置文件中的output_path字段值的目录下,其中save_checkpoint文件夹会保存用于热启动的模型文件, 目前阅读理解任务暂不支持save_inference_model的配置。