(New)进阶任务:模型可解释性-特征分析
任务简介
模型的预测一般情况下是黑盒的,我们不知道模型为什么预测出这个结果,所以我们需要可信分析的输出结果帮助用户理解模型做出这种判断的机制(模型可解释性),诊断模型的问题所在并提出优化方案(可信增强工具)。特征分析就是模型可解释性方向比较常用的方法,根据模型预测结果从输入文本中提取模型预测所依赖的证据,即输入文本中那些词(token)是比较重要的。
快速开始
代码结构
任务位于/wenxin_appzoo/tasks/text_classification目录下,是分类任务的一个进阶使用,目录结构如下:
.
├── data
│ ├── analysis_data ## 可信分析(实例分析、特征分析)的demo数据
│ │ ├── example_analysis ## 实例分析
│ │ │ └── demo.txt
│ │ └── feature_analysis ## 特征分析
│ │ └── demo.txt
......
│ ├── dict
│ │ ├── sentencepiece.bpe.model
│ │ └── vocab.txt
......
│ ├── predict_data
│ │ └── infer.txt
│ ├── test_data
│ │ └── test.txt
│ ├── train_data
│ │ └── train.txt
......
├── data_set_reader
│ ......
├── ernie_doc_infer_server.py
├── ernie_run_infer_server.py
├── examples
│ ├── cls_ernie_fc_ch_data_distribution_correct.json ## 获取数据集样本分布情况任务对应的配置文件
│ ├── cls_ernie_fc_ch_example_analysis.json ## 实例分析任务的配置文件
│ ├── cls_ernie_fc_ch_feature_analysis.json ## 特征分析任务的配置文件
│ ├── cls_ernie_fc_ch_find_dirty_data.json ## 筛选脏数据任务的配置文件
│ ├── cls_ernie_fc_ch_infer.json
│ ├── cls_ernie_fc_ch_infer_with_active_learning.json
│ ├── cls_ernie_fc_ch_infer_with_iflytek.json
│ ├── cls_ernie_fc_ch.json
│ ......
├── find_dirty_data.py ## 基于实例分析的结果查找脏数据的脚本
├── inference
│ ......
├── model
│ ├── base_cls.py
│ ├── ernie_classification.py
│ ......
├── reader
│ ......
├── run_balance_data.py ## 均衡训练集数据分布的脚本
├── run_data_distribution_correct.py ## 获取数据集样本分布情况的脚本
├── run_example_analysis.py ## 运行实例分析的启动脚本
├── run_features_analysis.py ## 运行特征分析的启动脚本
├── ......
├── trainer
│ ......
└── trust_analysis
├── gradient_similarity_wenxin.py ## 基于梯度相似度的实例分析脚本
├── integrated_gradients_wenxin.py ## 基于积分梯度的特征分析脚本
└── respresenter_point_wenxin.py ## 基于表示点方法的实例分析脚本
准备工作
在进行模型特征分析之前,首先需要训练出一个可预测的模型,本次文心套件中仅提供在文本分类任务上的特征分析功能,所以用户首先需要在分类任务中训练出一个模型,这里以ERNIE 3.0 Base模型举例。
模型准备
预训练模型均存放于wenxin_appzoo/wenxin_appzoo/models_hub文件夹下,进入文件夹下,执行sh download_ernie_3.0_base_ch.sh 即可下载ERNIE 3.0 Base模型的模型参数、字典、网络配置文件。
训练准备
数据准备可以参考:V2.1.0准备工作 训练配置文件如下(examples/cls_ernie_fc_ch.json):
{
"dataset_reader": {
"train_reader": {
"name": "train_reader", ## 训练、验证、测试各自基于不同的数据集,数据格式也可能不一样,可以在json中配置不同的reader,此处为训练集的reader。
"type": "BasicDataSetReader", ## 采用BasicDataSetReader,其封装了常见的读取tsv、txt文件、组batch等操作。
"fields": [## 域(field)是文心的高阶封装,对于同一个样本存在不同域的时候,不同域有单独的数据类型(文本、数值、整型、浮点型)、单独的词表(vocabulary)等,可以根据不同域进行语义表示,如文本转id等操作,field_reader是实现这些操作的类。
{
"name": "text_a", ## 文本分类任务的第一个特征域,命名为"text_a"。
"data_type": "string",
"reader": {
"type": "ErnieTextFieldReader"
},
"tokenizer": {
"type": "FullTokenizer", ## 指定text_a分词器,除ernie-tiny模型之外,其余基本上都固定为FullTokenizer分词器。
"split_char": " ",
"unk_token": "[UNK]"
},
"need_convert": true,
"vocab_path": "../../models_hub/ernie_3.0_base_ch_dir/vocab.txt", ## 词表地址
"max_seq_len": 512,
"truncation_type": 0,
"padding_id": 0
},
{
"name": "label",
"data_type": "int",
"reader": {
"type": "ScalarFieldReader"
},
"tokenizer": null,
"need_convert": false,
"vocab_path": "",
"max_seq_len": 1,
"truncation_type": 0,
"padding_id": 0,
"embedding": null
}
],
"config": {
"data_path": "./data/train_data", ## 数据路径。
"shuffle": true,
"batch_size": 8,
"epoch": 10,
"sampling_rate": 1.0,
"need_data_distribute": true,
"need_generate_examples": false
}
},
"test_reader": { ## 此处为测试集的reader。
"name": "test_reader",
"type": "BasicDataSetReader",
"fields": [
{
"name": "text_a",
"data_type": "string",
"reader": {
"type": "ErnieTextFieldReader"
},
"tokenizer": {
"type": "FullTokenizer",
"split_char": " ",
"unk_token": "[UNK]"
},
"need_convert": true,
"vocab_path": "../../models_hub/ernie_3.0_base_ch_dir/vocab.txt",
"max_seq_len": 512,
"truncation_type": 0,
"padding_id": 0
},
{
"name": "label",
"data_type": "int",
"need_convert": false,
"reader": {
"type": "ScalarFieldReader"
},
"tokenizer": null,
"vocab_path": "",
"max_seq_len": 1,
"truncation_type": 0,
"padding_id": 0,
"embedding": null
}
],
"config": {
"data_path": "./data/test_data",
"shuffle": false,
"batch_size": 8,
"epoch": 1,
"sampling_rate": 1.0,
"need_data_distribute": false,
"need_generate_examples": false
}
}
},
"model": {
"type": "ErnieClassification",
"is_dygraph": 1,
"optimization": { ## 优化器设置,文心ERNIE推荐的默认设置。
"learning_rate": 2e-05,
"use_lr_decay": true,
"warmup_steps": 0,
"warmup_proportion": 0.1,
"weight_decay": 0.01,
"use_dynamic_loss_scaling": false,
"init_loss_scaling": 128,
"incr_every_n_steps": 100,
"decr_every_n_nan_or_inf": 2,
"incr_ratio": 2.0,
"decr_ratio": 0.8
},
"embedding": {
"config_path": "../../models_hub/ernie_3.0_base_ch_dir/ernie_config.json"
},
"num_labels": 2
},
"trainer": {
"type": "CustomDynamicTrainer",
"PADDLE_PLACE_TYPE": "gpu",
"PADDLE_IS_FLEET": 0, ## 是否启用fleetrun运行,多卡运行时必须使用fleetrun,单卡时即可以使用fleetrun启动也可以直接python启动
"train_log_step": 10,
"use_amp": true,
"is_eval_dev": 0,
"is_eval_test": 1,
"eval_step": 100,
"save_model_step": 200,
"load_parameters": "",
"load_checkpoint": "",
"pre_train_model": [
{
"name": "ernie_3.0_base_ch",
"params_path": "../../models_hub/ernie_3.0_base_ch_dir/params"
}
],
"output_path": "./output/cls_ernie_3.0_base_fc_ch_dy",
"extra_param": {
"meta":{
"job_type": "text_classification"
}
}
}
}
开始训练
# 进入指定任务的目录
cd wenxin_appzoo/wenxin_appzoo/tasks/text_classification
# 单卡训练,如果fleetrun设置为0,则使用下面命令
python ./run_trainer.py --param_path "./examples/cls_ernie_fc_ch.json"
# 单卡训练,如果fleetrun设置为1,则使用下面的命令
fleetrun --log_dir log ./run_trainer.py --param_path "./examples/cls_ernie_fc_ch.json" 1>log/lanch.log 2>&1
# 多卡训练,fleetrun必须设置为1,使用下面的命令
fleetrun --log_dir log ./run_trainer.py --param_path "./examples/cls_ernie_fc_ch.json" 1>log/lanch.log 2>&1
- 通过上述脚本调用json文件开启训练
- 训练阶段日志文件于log文件夹下,workerlog.N 保存了第N张卡的log日志内容,如遇到程序报错可以通过查看不同卡的workerlog.N定位到有效的报错信息。
- 训练模型保存于./output/cls_ernie_3.0_base_fc_ch_dy文件夹下,保存好的模型,我们选择checkpoints文件进行下一步的特征分析。
开始特征分析
特征分析是将对预测数据和其结果进行分析,所以需要按预测数据的格式进行数据准备,仅需要文本即可,不需要标签,如下所示:
选择珠江花园的原因就是方便,有电动扶梯直接到达海边,周围餐馆、食廊、商场、超市、摊位一应俱全。酒店装修一般,但还算整洁。
15.4寸笔记本的键盘确实爽,基本跟台式机差不多了,蛮喜欢数字小键盘,输数字特方便,样子也很美观,做工也相当不错
房间太小。其他的都一般。。。。。。。。。
运行脚本配置如下(examples/cls_ernie_fc_ch_feature_analysis.json)
{
"dataset_reader": {
"predict_reader": {
"name": "predict_reader", ## 模型预测必须配置predict_reader,其配置方式与上文的train_reader、test_reader类似,需要注意的是predict_reader不需要label域,shuffle参数必须是false,epoch参数必须是1。
"type": "BasicDataSetReader",
"fields": [
{
"name": "text_a",
"data_type": "string",
"reader": {
"type": "ErnieTextFieldReader"
},
"tokenizer": {
"type": "FullTokenizer",
"split_char": " ",
"unk_token": "[UNK]",
"params": null
},
"need_convert": true,
"vocab_path": "../../models_hub/ernie_3.0_base_ch_dir/vocab.txt",
"max_seq_len": 512,
"truncation_type": 0,
"padding_id": 0,
"embedding": null
}
],
"config": {
"data_path": "./data/analysis_data/feature_analysis", ## 数据路径
"shuffle": false, ## 这里必须设置为false,不进行数据打乱
"batch_size": 1, ## 这里的batch_size必须设置为1,否则会有padding数据的生成的噪音
"epoch": 1, ## 迭代次数必须为1
"sampling_rate": 1.0,
"need_data_distribute": false, ## 因为预测仅支持单卡,所以这里必须设置为false,即数据不能进行分布式分发
"need_generate_examples": true ## 表示在数据读取过程中除了id化好的tensor数据外,是否需要返回原始明文样本,为了方便查看,预测集必须设置为true。
}
}
},
"model": { ## 和训练时保持一致即可
"type": "ErnieClassification",
"is_dygraph": 1,
"optimization": {
"learning_rate": 2e-05,
"use_lr_decay": false, ## 这里必须设置为false
"warmup_steps": 0,
"warmup_proportion": 0.1,
"weight_decay": 0.01,
"use_dynamic_loss_scaling": false,
"init_loss_scaling": 128,
"incr_every_n_steps": 100,
"decr_every_n_nan_or_inf": 2,
"incr_ratio": 2.0,
"decr_ratio": 0.8
},
"embedding": {
"config_path": "../../models_hub/ernie_3.0_base_ch_dir/ernie_config.json"
},
"num_labels": 2
},
"trainer": {
"type": "CustomDynamicTrainer",
"PADDLE_PLACE_TYPE": "gpu",
"PADDLE_IS_FLEET": 0, ## 和训练保持一致
"train_log_step": 10,
"use_amp": true,
"is_eval_dev": 0, ## 必须设置为0
"is_eval_test": 0, ## 必须设置为0
"eval_step": 1, ## 必须设置为1
"save_model_step": 200,
"load_parameters": "./output/cls_ernie_3.0_base_fc_ch_dy/save_checkpoints/checkpoints_step_501/", ## 上一步训练出来的模型checkpoints文件
"load_checkpoint": "",
"pre_train_model": [],
"output_path": "./output/analysis_result.txt", ## 输出结果的保存路径
"extra_param": {
"meta":{
"job_type": "text_classification"
}
}
}
}
开始运行:
# 使用当前目录下的
run_features_analysis.py 脚本进行特征分析
python run_features_analysis.py --param_path=./examples/cls_ernie_fc_ch_feature_analysis.json
# 运行完成后,结果将保存在
./output/analysis_result.txt中。
分析结果如下所示: word_attributions:样本中每个token的特征概率。 pred_label:模型预测结果。 pred_proba:模型在分类类别上的概率分布。 rationale:对当前预测结果起了重要依据的token位置。 non_rationale:对当前预测结果没起什么依据的token位置。 rationale_tokens:关键依据的token。 non_rationale_tokens:非关键依据的token。
InterpretResult(words=['[CLS]', '这个', '宾馆', '比较', '陈旧', '了', ',', '特价', '的', '房间', '也', '很', '一般', '。', '总体', '来说', '一般', '[SEP]'], word_attributions=[0.08131204545497894, -0.015465758740901947, 0.15092851594090462, -0.11724942922592163, 0.19505570828914642, -0.0005141766741871834, -0.05417150259017944, 0.24456731975078583, 0.05290137231349945, 0.19248447567224503, 0.0065753161907196045, 0.0394124835729599, -0.0720921941101551, -0.07563091069459915, 0.002225611824542284, 0.03168269246816635, -0.0702002951875329, 0.28575998544692993], pred_label=1, pred_proba=array([0.0016131 , 0.99838686], dtype=float32), rationale=(2, 4, 7, 8, 9), non_rationale=(1, 3, 5, 6, 10, 11, 12, 13, 14, 15, 16), rationale_tokens=('宾馆', '陈旧', '特价', '的', '房间'), non_rationale_tokens=('这个', '比较', '了', ',', '也', '很', '一般', '。', '总体', '来说', '一般'), rationale_pred_proba=None, non_rationale_pred_proba=None)
InterpretResult(words=['[CLS]', '怀着', '十分', '激动', '的', '心情', '放映', ',', '可是', '看着', '看着', '发现', ',', '在', '放映', '完毕', '后', ',', '出现', '一集', '米老鼠', '的', '动画片', '!', '开始', '还', '怀疑', '是不是', '赠送', '的', '个别现象', ',', '可是', '后来', '发现', '每张', 'dvd', '后面', '都', '有', '!', '真不知道', '生产商', '怎么', '想', '的', ',', '我', '想', '看', '的', '是', '猫和老鼠', ',', '不是', '米老鼠', '!', '如果', '厂家', '是', '想', '赠送', '的话', ',', '那', '就', '全套', '米老鼠', '和', '唐老鸭', '都', '赠送', ',', '只', '在', '每张', 'dvd', '后面', '添加', '一集', '算', '什么', '?', '?', '简直', '是', '画蛇添足', '!', '!', '[SEP]'], word_attributions=[0.002076566219329834, -0.05280791595578194, 0.0025696037337183952, 0.015039321035146713, -0.008998476900160313, 0.030619521159678698, 0.025006775744259357, -0.0036676046438515186, 0.041291133500635624, 0.03564849589020014, 0.024641045834869146, 0.032352319452911615, -0.004960148595273495, 0.017146483063697815, 0.04032287187874317, 0.03004542924463749, 0.01996348798274994, -0.013987114652991295, 0.024433945305645466, -0.0013266168534755707, 0.048110346775501966, -0.015303543768823147, 0.03397809900343418, -0.03299016132950783, -0.005744744557887316, 0.004264251794666052, 0.017910479567945004, -0.01459295628592372, 0.00665617361664772, -0.0013534692116081715, 0.08922316692769527, -0.016856784000992775, 0.021419867873191833, 0.02062097378075123, 0.010543567826971412, 0.0038286359049379826, 0.02364981919527054, -0.004806260112673044, 0.0034431591629981995, -0.0012499039294198155, -0.03161320090293884, 0.01552544953301549, 0.033262948505580425, 0.011698258924297988, 0.006797433365136385, 0.007958417758345604, -0.04121304303407669, -0.008125465363264084, -0.01731642335653305, -0.00459657609462738, 0.0049750651232898235, 0.00012094812700524926, 0.026585438288748264, -0.0071816458366811275, -0.023575767874717712, -0.01287288754247129, -0.03258969262242317, -0.015483793802559376, -0.006610091193579137, 0.0020591255743056536, 0.0035222265869379044, 0.0018081237794831395, -0.011568047106266022, -0.01671186089515686, -0.014911983162164688, -0.012085910886526108, 0.004526615142822266, 0.038233909755945206, 0.01656595803797245, 0.03270253702066839, 0.015023697167634964, 0.0056251659989356995, -0.024716133251786232, -0.0006697603967040777, -0.008279712870717049, 0.027053192257881165, 0.03592480719089508, 0.01851442491170019, 0.013252677395939827, 0.00812239944934845, 0.04180645942687988, 0.011230507167056203, -0.002986354287713766, 0.020911958068609238, 0.056243724189698696, 0.025230616331100464, 0.17893190309405327, 0.013372702524065971, 0.040163636207580566, 0.08028377592563629], pred_label=0, pred_proba=array([9.9933136e-01, 6.6862709e-04], dtype=float32), rationale=(20, 30, 80, 84, 86), non_rationale=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 21, 22, 23, 24, 25, 26, 27, 28, 29, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 81, 82, 83, 85, 87, 88), rationale_tokens=('米老鼠', '个别现象', '算', '简直', '画蛇添足'), non_rationale_tokens=('怀着', '十分', '激动', '的', '心情', '放映', ',', '可是', '看着', '看着', '发现', ',', '在', '放映', '完毕', '后', ',', '出现', '一集', '的', '动画片', '!', '开始', '还', '怀疑', '是不是', '赠送', '的', ',', '可是', '后来', '发现', '每张', 'dvd', '后面', '都', '有', '!', '真不知道', '生产商', '怎么', '想', '的', ',', '我', '想', '看', '的', '是', '猫和老鼠', ',', '不是', '米老鼠', '!', '如果', '厂家', '是', '想', '赠送', '的话', ',', '那', '就', '全套', '米老鼠', '和', '唐老鸭', '都', '赠送', ',', '只', '在', '每张', 'dvd', '后面', '添加', '一集', '什么', '?', '?', '是', '!', '!'), rationale_pred_proba=None, non_rationale_pred_proba=None)
InterpretResult(words=['[CLS]', '还', '稍微', '重', '了', '点', ',', '可能', '是', '硬盘', '大', '的', '原故', ',', '还要', '再轻', '半斤', '就', '好', '了', '。', '其他', '要', '进一步', '验证', '。', '贴', '的', '几种', '膜', '气泡', '较', '多', ',', '用', '不了', '多久', '就要', '更换', '了', ',', '屏幕', '膜', '稍', '好点', ',', '但', '比', '没有', '要强', '多', '了', '。', '建议', '配赠', '几张', '膜', '让', '用', '用户', '自己', '贴', '。', '[SEP]'], word_attributions=[0.0006664171814918518, -0.00232806708663702, 0.01713430415838957, 0.006945062894374132, -0.009011978283524513, 0.011891393922269344, 0.007345870137214661, 0.0029482142999768257, -0.00351709988899529, 0.013950992841273546, 0.014718687161803246, -0.004140591714531183, 0.017135773319751024, -0.016810597851872444, 0.00512023433111608, 0.03256610780954361, -0.007932081818580627, 0.006141374818980694, 0.022120775654911995, 0.021381715312600136, 0.0022811200469732285, 0.018820897676050663, 0.007859316654503345, 0.02857692865654826, -0.015171912964433432, -0.020635122433304787, 0.012667935341596603, -0.01707950420677662, 0.02939440740738064, -0.0042770239524543285, 0.010003181174397469, 0.0014542816206812859, 0.005212768912315369, 0.012886474840342999, -0.0065664309076964855, 0.02783005405217409, 0.020129258511587977, 0.00193664466496557, -0.0004319334402680397, -0.0035491527523845434, 0.01658567786216736, 0.01721367542631924, 0.010507157072424889, -0.0003093094564974308, 0.010007902281358838, -0.0011728676036000252, 0.006455680355429649, 0.013318296521902084, 0.024643402080982924, 0.02000664547085762, 0.007289891596883535, 0.01610838621854782, 0.017332401126623154, 0.01150076906196773, 0.05950666218996048, 0.015790967270731926, 0.007413539104163647, 0.013358291238546371, 0.017596140503883362, 0.031549626495689154, 0.003153653349727392, 0.002609523944556713, 0.09512047469615936, 0.289874404668808], pred_label=0, pred_proba=array([9.992785e-01, 7.215710e-04], dtype=float32), rationale=(15, 28, 54, 59, 62), non_rationale=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 55, 56, 57, 58, 60, 61), rationale_tokens=('再轻', '几种', '配赠', '用户', '。'), non_rationale_tokens=('还', '稍微', '重', '了', '点', ',', '可能', '是', '硬盘', '大', '的', '原故', ',', '还要', '半斤', '就', '好', '了', '。', '其他', '要', '进一步', '验证', '。', '贴', '的', '膜', '气泡', '较', '多', ',', '用', '不了', '多久', '就要', '更换', '了', ',', '屏幕', '膜', '稍', '好点', ',', '但', '比', '没有', '要强', '多', '了', '。', '建议', '几张', '膜', '让', '用', '自己', '贴'), rationale_pred_proba=None, non_rationale_pred_proba=None)
InterpretResult(words=['[CLS]', '交通', '方便', ';', '环境', '很', '好', ';', '服务态度', '很', '好', '房间', '较', '小', '[SEP]'], word_attributions=[0.13132205605506897, 0.18499934300780296, 0.035210857167840004, 0.07993949204683304, 0.031213230453431606, -0.0022432543337345123, 0.06973284482955933, 0.06583131849765778, 0.11495713517069817, -0.01573021709918976, 0.01941096968948841, 0.03553070407360792, -0.006632696837186813, -0.01640896499156952, 0.210178405046463], pred_label=1, pred_proba=array([5.9603900e-04, 9.9940395e-01], dtype=float32), rationale=(1, 3, 6, 7, 8), non_rationale=(2, 4, 5, 9, 10, 11, 12, 13), rationale_tokens=('交通', ';', '好', ';', '服务态度'), non_rationale_tokens=('方便', '环境', '很', '很', '好', '房间', '较', '小'), rationale_pred_proba=None, non_rationale_pred_proba=None)