资讯 文档
技术能力
语音技术
文字识别
人脸与人体
图像技术
语言与知识
视频技术

进阶任务:Prompt tuning文本分类

任务简介

近年来,Pretain+Finetuning的训练方式已广泛流行在NLP领域,近期来出现了新的训练方式——Prompt Learning, Prompt Learning是指对输入文本信息按照特定模板进行处理,把任务重构成一个更能充分利用预训练语言模型处理的形式。

Prompt Learning的做法可由Prompt的构造方式分为以下几种方式:

  • Hand-crafted,人工模板构造Prompt的方式,需要人工的参与,而且不同的Prompt的构造方式模型的效果差距较大,模型效果不稳定。
  • Automate prompt design (离散),该部分工作提出自动Prompt的构造方式,根据任务搜索合适的prompt的token。但因为这些token仍是具体的文字,从深度学习的角度来说,这种方式的效果是次优的。
  • Prompt tuning,采用的是连续的Prompt向量,认为参数化的Prompt相对离散的Prompt形式具有更强的表达能力,不关注Prompt的自然语言性,其模型效果更优。

文心ERNIE分类任务接入Prompt tuning的训练方式,用户可根据自己的任务需求只需要简单更改json配置即可。我们在FewCLUE的文本分类任务中进行了效果评测:

eprstmt tnews csldcp iflytek
Finetune 82.438% 53.75% 55.687% 47.058%
Prompt tuning 85.2464% 54.328% 56.0476% 48.0868%

数据集样本量:

Train Test
eprstmt 32 610
tnews 240 2010
csldcp 536 1784
iflytek 928 2279

(⚠️注意:eprstmt:电商评论情感分析;tnews:新闻分类;csldcp:科学文献学科分类;iflytek:APP应用描述主题分类,其中tnews、csldcp、iflytek三个数据集难度较大分类类别较多,所以准确率相对较低)

准备数据

示例数据位于text_classification/data/fewclue/eprstmt

.
├── dict	# label词表文件
│   └── label_map.txt
├── predict_data	# 预测数据
│   └── test.txt
├── test_public	# 测试数据
│   └── test_public.txt
└── train_1	# 训练数据
    └── train_1.txt

其中训练、测试、预测数据集可参考正常的文本分类数据集格式:数据准备

需要用户额外准备一份label的词表文件,为语言模型待预测的mask处的值,用\t分割,要求字数一致,如:

01
故事	0
文化	1
娱乐	2
体育	3
财经	4

模型准备

模型均存放于wenxin_appzoo/wenxin_appzoo/models_hub文件夹下,进入文件夹执行sh download_ernie_3.0_base_ch.sh下载预训练模型参数,字典,模型配置文件。

训练参数配置

  • 文心中的各种参数都是在json文件中进行配置的,你可以通过修改所加载的json文件来进行参数的自定义配置。json配置文件主要分为三个部分:dataset_reader(数据部分)、model(网络部分)、trainer(训练任务)或inference(预测部分),在模型训练的时候,json文件中需要配置dataset_reader、model和trainer这三个部分;在预测推理的时候,json文件中需要配置dataset_reader、inference这两个部分。
  • Prompt tuning的json配置大体与正常的文本分类json配置类似,以下介绍与prompt tuning相关的json配置。
{
  "dataset_reader": {
    "train_reader": {
      "name": "train_reader",
      "type": "BasicDataSetReader",
      "fields": [
        {
          "name": "text_a",
          "data_type": "string",
          "reader": {
            "type": "ErniePromptTextFieldReader"	# 采用prompt tuning的reader
          },
          "tokenizer": {
            "type": "FullTokenizer",
            "split_char": " ",
            "unk_token": "[UNK]"
          },
          "need_convert": true,
          "vocab_path": "../../models_hub/ernie_3.0_base_ch_dir/vocab.txt",
          "label_map_path": "./data/fewclue/eprstmt/dict/label_map.txt",	# 需填入label的词表文件
          "max_seq_len": 512,
          "truncation_type": 0,
          "padding_id": 0,
          "is_prompt_tuning": true,	# 是否采用prompt tuning的组网方式
          "prompt_len": 100,	# prompt embedding的长度
          "is_mask_res": 1,	# 是否采用mask预测的方式,默认开启。
          "mask_res_len": 1,	# mask的label长度,如第一个label_map示例为1,第二个示例该值需改为2
          "prompt": "text_a, "	# 用户可构造输入prompt,其中待分类的文本用text_a指代。
        },
        {
          "name": "label",
          "data_type": "int",
          "reader": {
            "type": "ScalarVerbalizerFieldReader"	# 采用verbalizer的reader
          },
					......
        }
      ],
      "config": {
        "data_path": "./data/fewclue/eprstmt/train_1",
				......
      }
    },
    "test_reader": {
			......
  },
  "model": {
    "type": "ErnieClassificationPromptMultiMask",
    "need_mask_id": true,	# 若采用mask预测的方式,该处需设为true
    "is_prompt_tuning": true,	# 是否为prompt tuning训练方式
    "prompt_len": 100,	# prompt embedding的长度
    "use_mlp_prompt": true,	# 是否采用mlp网络构造Prompt embedding
    "use_lstm_prompt": false,	# 是否采用lstm网络构造Prompt embedding
    "is_dygraph": 1, # 仅支持动态图
    "num_labels": 2, # 类别数
    ......
  },
  "trainer": {
    "type": "CustomDynamicTrainer",
    "is_freeze": false,	# 是否冻结预训练网络参数
		......
  }
}

开始训练

在开始训练前需要确定您已经配置好env.sh文件(否则会报错 No such file or directory: 'auth.txt'),并且在env.sh目录下执行以下命令source env.sh(注意:请确保env文件中的各个路径配置正确,每次打开新的终端窗口都需要执行source env.sh命令,env.sh文件的配置请参考:环境配置

目录结构

Prompt tuning文本分类任务位于/wenxin/tasks/text_classification,仅支持动态图。

  • 指定AUTH的配置路径,找到你自己的auth.txt,获取它的绝对路径,如/home/zhangsang/wenxin/wenxin_appzoo,那么将这个绝对路径设给环境变量WENXIN_AUTH_PATH,如下所示:

    export WENXIN_AUTH_PATH=/home/zhangsang/wenxin/wenxin_appzoo
  • 进入指定目录

    cd wenxin_appzoo/wenxin_appzoo/tasks/text_classification
  • 如果是gpu训练,先指定卡号,如下所示指定在0号卡上进行训练

    export CUDA_VISIBLE_DEVICES=0
  • 模型训练的入口脚本为./run_trainer.py , 通过--param_path参数来传入./examples/目录下的json配置文件。 python run_trainer.py --param_path=./examples/cls_prompt_tuning_ernie_fc_ch.json

训练运行的日志会自动保存在./log/test.log文件中。

训练中以及结束后产生的模型文件会默认保存在./output/cls_ernie_3.0_base_prompt_fc_ch_dy目录下,其中save_inference_model/文件夹会保存用于预测的模型文件,save_checkpoint/文件夹会保存用于热启动的模型文件。

开始预测

  • 指定AUTH的配置路径,找到你自己的auth.txt,获取它的绝对路径,如/home/zhangsang/wenxin/wenxin_appzoo,那么将这个绝对路径设给环境变量WENXIN_AUTH_PATH,如下所示:

    export WENXIN_AUTH_PATH=/home/zhangsang/wenxin/wenxin_appzoo
  • 进入指定目录

    cd wenxin_appzoo/wenxin_appzoo/tasks/text_classification
  • 如果是gpu训练,先指定卡号,如下所示指定在0号卡上进行预测

    export CUDA_VISIBLE_DEVICES=0
  • 选定配置好的json文件,把你将要预测的模型对应的inference_model文件路径填入json文件的inference_model_path参数中。(其中网络中gather_nd op,在预测时有限制,目前预测的batch_size需小于等于训练时的batch_size,若用户想开大预测的batch_size,可load训练再save,如有疑问可联系客服同学解决。)
  • 模型训练的入口脚本为./run_infer.py , 通过--param_path参数来传入./examples/目录下的json配置文件。

    python run_infer.py --param_path=./examples/cls_prompt_tuning_ernie_fc_ch_infer.json

预测运行的日志会自动保存在./output/predict_result.txt文件中。

上一篇
进阶任务-使用ERNIE-Word进行文本分类
下一篇
进阶任务:使用ERNIE-Doc进行长文本分类