训练与预测:开放域抽取
开始训练
使用预置网络进行训练的方式为使用./run_trainer.py入口脚本,通过--param_path参数来传入./examples/目录下的json配置文件。 模型结构是基于ernie2.0_large的阅读理解网络,训练分为以下几个步骤:
- 基于示例的数据集,可以运行以下命令在训练集(train.json)上进行模型训练,并在测试集(dev.json)上进行验证;
# 下载模型
cd ../../models_hub
sh download_ernie_ie_2.0_large_extraction_ch.sh
# 基于ernie2.0_large的阅读理解网络, 配置文件./examples/ernie_ie_2.0_large_extraction_ch.json
cd ../tasks/openie
python run_trainer.py --param_path ./examples/ernie_ie_2.0_large_extraction_ch.json
- 训练运行的日志会自动保存在./log/test.log文件中;
- 训练结束后产生的模型文件会默认保存在./output/ernie_ie_2.0_large_extraction_ch/目录下,其中save_inference_model/文件夹会保存用于预测的模型文件,save_checkpoint/文件夹会保存用于热启动的模型文件;训练中产生评估结果和对评估集的预测结果分别保存在output/evaluations/文件夹和output/predictions文件夹
开始预测
使用预置网络进行预测的方式为使用./run_infer.py入口脚本,通过--param_path参数来传入./examples/目录下的json配置文件。
其预测分为以下几个步骤:
- 基于./examples/ernie_ie_2.0_large_extraction_ch_infer.json训练出的模型默认储存在./output/ernie_ie_2.0_large_extraction_ch/save_inference_model/中,在该目录下找到被保存的inference_model文件夹,修改./examples/ernie_ie_2.0_large_extraction_ch_infer.json中"inference_model_path"参数,如下所示:
{
...
"inference":{
...
"inference_model_path":"output/ernie_ie_2.0_large_extraction_ch/save_inference_model/inference_step_31"
}
}
- 预测运行的日志会自动保存在./output/predict_result.txt、predict_result.txt.gold_read(如果有标注信息,会显示正确的标签)和predict_result.txt.pred_read(预测的标签)文件中。
进阶使用:使用Prompt tuning训练方式
ERNIE-IE 2.0开创性地将各种类型的信息抽取任务统一转化为自然语言的形式,模型的输入是待抽取文本(content)和自然语言描述的抽取目标(prompt),其中抽取目标是用户构造的离散Prompt。Prompt tuning,采用的是连续的Prompt向量,认为参数化的Prompt相对离散的Prompt形式具有更强的表达能力,不关注Prompt的自然语言性,其模型效果更优。
文心ERNIE信息抽取任务在ERNIE-IE 2.0的离散Prompt基础上接入Prompt tuning的训练方式,用户可根据自己的任务需求只需要简单更改json配置即可,我们在相关数据集上进行了效果评测,其中在互联网和医疗场景的小样本数据集上平均提升2.82%, 政务场景的全量数据效果提升0.42%。
-
训练
与ERNIE-IE 2.0的数据准备和训练方式相同,其中仅需更改训练时的--param_path参数传入的./examples/目录下的json配置文件。
python run_trainer.py --param_path ./examples/ernie_ie_2.0_large_extraction_prompt_tuning_ch.json
其中与Prompt tuning相关的json配置说明如下:
{ "dataset_reader": { "train_reader": { ... "config": { ... "extra_params":{ "is_prompt_tuning": true, # 是否采用Prompt tuning的训练方式 "prompt_len": 20, # 连续Prompt embedding的长度 ... } ... } }, ... }, "model": { "type": "DoieMRCExtractionModelPromptTuning", # 相关组网 "is_prompt_tuning": true, # 是否采用Prompt tuning的训练方式 "prompt_len": 20, # 连续Prompt embedding的长度 "use_mlp_prompt": true, # 是否采用mlp网络构造Prompt embedding "use_lstm_prompt": false, # 是否采用lstm网络构造Prompt embedding ... }, ... }
-
预测
与ERNIE-IE 2.0的数据准备和预测方式相同,其中仅需更改预测时的--param_path参数传入的./examples/目录下的json配置文件。
python run_infer.py --param_path ./examples/ernie_ie_2.0_large_extraction_prompt_tuning_ch_infer.json
其中与Prompt tuning相关的json配置说明如下:
{ "dataset_reader": { "predict_reader": { ... "extra_params":{ "is_prompt_tuning": true, # 是否采用Prompt tuning的训练方式 "prompt_len": 20, # 连续Prompt embedding的长度 ... } } } }, "model": { "type": "DoieMRCExtractionModelPromptTuning", # 相关组网 ... }, ... }