资讯 文档
技术能力
语音技术
文字识别
人脸与人体
图像技术
语言与知识
视频技术

进阶任务:小样本文本分类

本页提供小样本文本分类的示例代码。建议提前阅读:

30s上手文心

文本分类任务:适用场景

文本分类任务:准备工作

文本分类任务:开始训练与预测

任务简介

  • 小样本文本分类是指训练集中样本数较少情况下的文本分类任务。特别是当类别数较多(大于50),样本数较小(每个类别有5个样本左右)的时候,能够显著提升其相对直接使用ERNIE分类网络finetune的效果。
  • 本案例实现的技术是将分类任务变为pairwise匹配任务,样本数可相应增多,有助于小样本学习。因此需要将分类训练集处理得到匹配训练集,训练匹配双塔网络,loss可由cosine或l2距离等计算方式得到,提供的示例model文件采取余弦相似度的距离计算方式
  • 测试时有两种计算方式:

    1. 将待测试数据与训练集各样本计算相似度,取相似度最高的训练数据类别标签作为测试数据的类别标签。

      • 就相当于KNN算法中K为1的情况。
      • 当样本数较小时建议采用这个方案。
    2. 将训练集同类别数据计算类平均表示作为该类的表示,将待测试数据与训练集各类别的类平均表示计算相似度,取相似度最高的训练数据类别作为测试数据的类别标签。

30s上手小样本文本分类

其基本流程与30s上手文心介绍的一致。

进入目录

首先进入文本分类的示例代码目录./wenxin/tasks/few_shot_text_classification/ 。

cd ./wenxin/tasks/few_shot_text_classification/

代码结构说明

.                                                      
├── data                     ## 示例数据文件夹,包括任务所需训练集(train_data)、测试集(test_data)、验证集(dev_data)和预测集(predict_data),以及处理数据的代码文件
│   ├── data_for_pairwise_match.py  ## 处理分类数据生成pairwise匹配数据
│   ├── process_data_and_train.sh   ## 处理数据并训练的脚本文件
│   ├── dev_data
│   │   └── dev.txt
│   ├── predict_data
│   │   └── predict.txt
│   ├── test_data
│   │   └── test.txt
│   └── train_data
│       └── train.txt
├── examples
│   ├── fstc_ernie_sim_1.0_pairwise_simnet_ch.json
│   └── fstc_ernie_sim_1.0_pairwise_simnet_ch_infer.json
├── __init__.py
├── run_infer.py       ## 只依靠json进行模型预测的入口脚本
└── run_with_json.py   ## 只依靠json进行模型训练的入口脚本

准备训练数据

小样本文本分类的预置示例数据目录为./data,其中./data/train_data为训练集、./data/dev_data为验证集、./data/test_data为测试集、./data/predict_data为预测集。

如训练集、验证集、测试集数据所示,其格式与文本分类任务:准备工作中介绍的ERNIE数据集的格式一致:数据分为两列,列与列之间用\t进行分隔。第一列为文本,第二列为标签。文本列不需要分词,且长度不受限制;预测集格式也如文本分类任务:准备工作中介绍的预测集的格式一致:仅一列为文本。

训练数据增强(附加)

背景: 数据增强是扩充数据样本规模的一种有效地方法,数据的规模越大,模型才能够有着更好的泛化能力。因此增加了基于回译的数据增强方法,在训练前可以自行选择对训练数据进行增强,提升模型效果。

  1. 申请百度翻译api(全量免费QPS=1的接口)的权限,网址:https://api.fanyi.baidu.com/product/11 将申请的appid和secret字段将data_aug/back_translate_thread.py中相应字段替换
  2. 运行数据增强脚本
cd wenxin/data/data_aug
python back_translate_thread.py

会在../data/train_data/目录下面生成 train.txt.ag 数据增强后的数据,然后和原来的训练数据合并后打乱顺序

附: 提供了其他公开数据集在public_data目录下,其中包括amazon、huffpost和reuters。

方法 rcv
(51way-5shot)
amazon
(24way5shot)
huffpost
(41way5shot)
reuters
(31way-5shot)
平均提升
baseline 67.6% 70% 36% 87.3% ----
基于回译 71% 73% 37.3% 89.2% +2.4%

启动训练

  1. ERNIE预训练模型下载

    # 以ernie_1.0_sim预训练模型为例
    # 进入model_files目录
    cd ../model_files/
    # 运行下载脚本
    sh download_ernie_sim_1.0_ch.sh
  2. 需生成pairwise数据并训练模型,则需先运行pairwise数据处理代码,数据生成成功后训练模型,其中json配置文件为./examples/fstc_ernie_sim_1.0_pairwise_simnet_ch.json。

    # 处理分类数据为pairwise数据格式
    python ./data/data_for_pairwise_match.py
    # 模型训练
    python run_with_json.py --param_path ./examples/fstc_ernie_sim_1.0_pairwise_simnet_ch.json

    若已生成了pairwise数据则可直接训练模型

    # 模型训练
    python run_with_json.py --param_path ./examples/fstc_ernie_sim_1.0_pairwise_simnet_ch.json

启动预测

  1. 基于示例的数据集,可以运行以下命令在预测集(./data/predict_data)上进行预测:

    # 基于json实现预测。其调用了配置文件./examples/fstc_ernie_sim_1.0_pairwise_simnet_ch_infer.json
    python run_infer.py --param_path ./examples/fstc_ernie_sim_1.0_pairwise_simnet_ch_infer.json

    预测运行的日志会自动保存在./output/predict_result.txt文件中。

  2. 预测结果为每个text的embedding向量,格式为text\t[embedding],可利用预测得到的预测集的embedding和训练集的embedding进行相似度计算得出分类结果,具体计算方式可参考ernie_recall_siamese_pairwise.py中的get_metrics函数。

进阶说明

  • pairwise数据处理说明

    在训练前需先将分类数据处理为pairwise匹配数据,./data/data_for_pairwise_match.py可处理分类数据,默认处理的数据文件路径为训练集./data/train_data、测试集./data/test_data、验证集./data/dev_data。其中根据需求可调节生成pairwise数据的方式,neg_sampling_rate为生成pairwise格式数据时负样本的采样率max_data_num为pairwise格式训练数据的最大数据量,-1表示不设置最大数据量。生成后的数据各数据集存放的路径为:

    训练集:./data/train_pairwise/ 测试集:./data/test_combine/ 验证集:./data/dev_combine/

    • 处理后将对应的文件夹路径填写在训练所使用的json配置中。
    • 具体生成的pairwise数据格式如下:

      pairwise训练集

      • 训练集一行三列(三个key),列之间用\t分隔,都为text。第一列和第二列为同类别的text,第三列为其他类别的负样本text。

        WernerDieter前首席执行官曼内斯曼(MannesmannAG)表示,将寻求法律咨询,以决定是否接受检察官在周二出版的《BoersenZeitung》报纸上提供的交易	法国住房部给出了以下未经调整的数据:7月住房开工率5月底5月底7月住房批准总数量同比变动百分比非住宅建筑批准面积百万年同比变动开工率百万巴黎新闻编辑室	墨西哥足球锦标赛星期二在蒙特雷举行

      pairwise所需的测试集/验证集

      • 测试集和验证集格式一样,一行三列(三个key),列之间用\t分隔。第一列文本text,第二列为文本对应的分类label,第三列为标识该数据是否为训练集,1为训练集的数据,0为测试集/验证集的数据。
      • 具体操作为将训练集(./data/train_data/)中的数据与测试集/验证集(./data/test_data 或 ./data/dev_data)中的数据进行拼接,通过第三列标识符进行区分。在具体测试过程中,会分别读出训练集和测试集/验证集数据,根据训练集的embedding与测试集/验证集的embedding进行相似度的计算得到测试集/验证集样本的类别。

        捷克国家银行周二公布的捷克股市综合指数上涨8点,捷克国家银行表示,7-10个行业指数上涨,食品饮料行业指数上涨最多,布拉格新闻编辑室    3   0
上一篇
开始训练与预测
下一篇
进阶任务:数据增强文本分类