资讯 文档
技术能力
语音技术
文字识别
人脸与人体
图像技术
语言与知识
视频技术

文本匹配任务

任务简介

  • 文本匹配是自然语言处理中一个重要的基础问题,自然语言处理中的许多任务都可以抽象为文本匹配任务。
  • 例如网页搜索可抽象为网页内容与用户搜索 Query 的一个相关性匹配问题,自动问答可抽象为候选答案与问题的满足度匹配问题,文本去重可以抽象为文本与文本的相似度匹配问题。
  • 主要应用场景有:搜索,推荐,FAQ,自动拼写修正,重复数据删除,甚至基因组分析等等。

快速开始

1. 代码结构说明

以下是本项目主要代码结构及说明:

代码目录: wenxin-premium/tasks/text_matching

.
├── __init__.py                                                                                 
├── env.sh                                                 ## 环境变量配置脚本
├── run_with_json.py                                       ## 只依靠json进行模型训练的入口脚本
├── run_infer.py                                           ## 只依靠json进行模型预测的入口脚本
├── examples                                               ## 各典型网络的json配置文件
│   ├── mtch_cnn_pointwise_ch.jso
│   ├── mtch_cnn_pointwise_ch_infer.json
│   ├── mtch_ernie_1.0_base_fc_pointwise_ch.json
│   └── ...
├── data                                                   ## 示例数据文件夹,包括各任务所需训练集(train_data)、测试集(test_data)、验证集(dev_data)和预测集(predict_data)
│   ├── train_data_pairwise                                ## 示例未分词pairwise模式训练集,用于ERNIE任务的训练
│   │   └── train.txt
│   ├── train_data_pairwise                                ## 示例未分词pointwise模式训练集,用于ERNIE任务的训练
│   │   └── train.txt
│   ├── train_data_pairwise_tokenized                      ## 示例分词后的pairwise模式训练集,用于非ERNIE任务的训练                    
│   │   └── train.txt
│   ├── test_data
│   │   └── test.txt
│   ├── dev_data
│   │   └── dev.txt
│   ├── predict_data
│   │   └── infer.txt
│   └── ...
└── dict                                                   ## 示例非ERNIE任务的词表文件夹
     └── vocab.txt

2. 数据准备

  • 文本匹配任务的训练集根据其训练方式的不同,分为pointwise和pairwise两种格式。
  • 注:数据集(包含词表)均为utf-8格式。
非ERNIE数据
  • 这里我们提供一份已标注的、经过分词预处理的示例数据集。
  • pointwise训练集、pairwise训练集、测试集、验证集和预测集分别存放在./data目录下的train_data_pointwise_tokenized、train_data_pairwise_tokenized、test_data_tokenized、dev_data_tokenized和predict_data_tokenized文件夹下,对应的示例词表存放在./dict目录下。
  1. pointwise训练集

    • pointwise格式的训练集样例如下所示。数据分为三列,列与列之间用\t进行分隔。前两列为文本,最后一列为标签。
    喜欢 打篮球 的 男生 喜欢 什么样 的 女生      爱 打篮球 的 男生 喜欢 什么样 的 女生      1
    我 手机 丢 了 , 我 想 换 个 手机      我 想 买 个 新手机 , 求 推荐      1
    大家 觉得 她 好看 吗      大家 觉得 跑 男 好看 吗 ?      0
  2. pairwise训练集

    • pairwise格式的训练集样例如下所示。数据分为三列,列与列之间用\t进行分隔。以query和文章标题匹配任务为例,第一列为query,第二列为正例标题pos_titile,第三列为负例标题neg_title。
    喜欢 打篮球 的 男生 喜欢 什么样 的 女生      爱 打篮球 的 男生 喜欢 什么样 的 女生      爱情 里 没有 谁 对 谁错 吗 ?
    我 手机 丢 了 , 我 想 换 个 手机      我 想 买 个 新手机 , 求 推荐      剑灵 高级 衣料 怎么 得
    大家 觉得 她 好看 吗      大家 觉得 跑 男 好看 吗 ?      照片 怎么 变成 漫画
  3. 测试集/验证集

    • 测试集和验证集与pointwise格式训练集的数据格式保持一致,如下所示。数据分为三列,列与列之间用\t进行分隔。前两列为文本,最后一列为标签。
    尺有所短 , 后面 是 什么      尺有所短 , 后面 写 什么      1
    为什么 恐怖片 会 吓死人      为什么 恐怖片 会 吓死人 ?      1
    这 是 什么 舞 ? ( 图片 )      这 是 什么 枪 图片 如下      0
  4. 预测集

    • 预测集无需进行标签预占位,两列文本之间使用\t进行分隔,其格式如下所示:
    图片 上 得 牌子 是 什么      图片 上 是 什么 牌子 的 包
    芹菜 包 什么 肉 好吃      芹菜 炒 啥 好吃
    汽车 坐垫 什么 牌子 好 ?      什么 牌子 的 汽车 坐垫 好
  5. 词表

    • 词表分为两列,第一列为词,第二列为id(从0开始),列与列之间用\t进行分隔。若用户自备词表,需保持[UNK]项与示例词表一致。部分词表示例如下所示:
    [PAD]	0
    [CLS]	1
    [SEP]	2
    [MASK]	3
    [UNK]	4
        5
    郑重	6
    丁约翰	7
    工地	8
    神圣	9
ERNIE数据
  • 若使用基于ERNIE的模型进行训练,那么数据集不需要分词且无需准备词表,其格式与非ERNIE的数据集相同。
  • 这里我们提供一份已标注的ERNIE示例数据集。pointwise训练集、pairwise训练集、测试集、验证集和预测集分别存放在./data目录下的train_data_pointwise、train_data_pairwise、test_data、dev_data和predict_data文件夹下。

3. 训练第一个模型

开始训练
  • 使用预置网络进行训练的方式为使用./run_with_json.py入口脚本,通过--param_path参数来传入./examples/目录下的json配置文件。
  • 以基于预置BOW网络的pointwise文本匹配模型为例,训练分为以下几个步骤:

    1. 请在./env.sh中根据提示配置相应环境变量的路径;
    2. 基于示例的数据集,可以运行以下命令在训练集(train.txt)上进行模型训练,并在测试集(test.txt)上进行验证;
    # BOW pointwise 模型
    # 需要提前参照env.sh进行环境变量配置,在当前shell内去读取
    source env.sh
    # 基于json实现预置网络训练。其调用了配置文件./examples/mtch_cnn_pairwise_ch.json
    python run_with_json.py --param_path ./examples/mtch_cnn_pairwise_ch.json
    1. 训练运行的日志会自动保存在./log/test.log文件中;
    2. 训练中以及结束后产生的模型文件会默认保存在./output/mtch_cnn_pairwise_ch/目录下,其中save_inference_model/文件夹会保存用于预测的模型文件,save_checkpoint/文件夹会保存用于热启动的模型文件。
配置说明
  • 使用预置网络训练时,可以通过修改所加载的json文件来进行参数的自定义配置。
  • json配置文件主要分为三个部分:dataset_reader、model和trainer。
  • 以./examples/mtch_cnn_pairwise_ch.json为例,上述三个部分的配置与说明如下所示。
dataset_reader部分
  • dataset_reader用于配置模型训练时的数据读取。
  • 以下为./examples/mtch_cnn_pairwise_ch.json中抽取出来的dataset_reader部分配置,并通过注释说明。
{
  "dataset_reader": {                                  
    "train_reader": {                                   ## 训练、验证、测试各自基于不同的数据集,数据格式也可能不一样,可以在json中配置不同的reader,此处为训练集的reader。
      "name": "train_reader",
      "type": "BasicDataSetReader",                     ## 采用BasicDataSetReader,其封装了常见的读取tsv文件、组batch等操作。
      "fields": [                                       ## 域(field)是wenxin的高阶封装,对于同一个样本存在不同域的时候,不同域有单独的数据类型(文本、数值、整型、浮点型)、单独的词表(vocabulary)等,可以根据不同域进行语义表示,如文本转id等操作,field_reader是实现这些操作的类。
        {
          "name": "text_a",                             ## 文本匹配有两个文本特征域,分别命名为"text_a"和"text_b"。
          "data_type": "string",                        ## data_type定义域的数据类型,文本域的类型为string,整型数值为int,浮点型数值为float。
          "reader": {"type":"CustomTextFieldReader"},   ## 采用针对文本域的通用reader "CustomTextFieldReader"。数值数组类型域为"ScalarArrayFieldReader",数值标量类型域为"ScalarFieldReader"。
          "tokenizer":{
              "type":"CustomTokenizer",                 ## 指定该文本域的tokenizer为CustomTokenizer。
              "split_char":" ",                         ## 通过空格区分不同的token。
              "unk_token":"[UNK]",                      ## unk标记为"[UNK]"。
              "params":null
            },
          "need_convert": true,                         ## "need_convert"为true说明数据格式是明文字符串,需要通过词表转换为id。
          "vocab_path": "./dict/vocab.txt",             ## 指定该文本域的词表。
          "max_seq_len": 512,                           ## 设定每个域的最大长度。
          "truncation_type": 0,                         ## 选择截断策略,0为从头开始到最大长度截断,1为从头开始到max_len-1的位置截断,末尾补上最后一个id(词或字),2为保留头和尾两个位置,然后按从头开始到最大长度方式截断。
          "padding_id": 0                               ## 设定padding时对应的id值。
        },                                              ## 如果每一个样本有多个特征域(文本类型、数值类型均可),可以仿照前面对每个域进行设置,依次增加每个域的配置即可。此时样本的域之间是以\t分隔的。
        {
          "name": "text_b",                            
          "data_type": "string",                       
          "reader":{"type":"CustomTextFieldReader"},
          "tokenizer":{
              "type":"CustomTokenizer",
              "split_char":" ",
              "unk_token":"[UNK]",
              "params":null
          },
          "need_convert": true,
          "vocab_path": "./dict/vocab.txt",   
          "max_seq_len": 512,
          "truncation_type": 0,
          "padding_id": 0
        },
        {
          "name": "label",                                                  ## 标签也是一个单独的域,命名为"label"。如果多个不同任务体系的标签存在于多个域中,则可实现最基本的多任务学习。
          "data_type": "int",                                               ## 标签是整型数值。
          "reader":{
            "type":"ScalarFieldReader"                                      ## 整型数值域的reader为"ScalarFieldReader"。
          },
          "tokenizer":null,
          "need_convert": false,
          "vocab_path": "",
          "max_seq_len": 1,
          "truncation_type": 0,
          "padding_id": 0,
          "embedding": null
        }
      ],
      "config": {
        "data_path": "./data/train_data_pointwise_tokenized/",              ## 训练数据train_reader的数据路径,写到文件夹目录。
        "shuffle": false,
        "batch_size": 8,
        "epoch": 10,
        "sampling_rate": 1.0
      }
    },
    "test_reader": {                                                        ## 若要评估测试集,需配置test_reader,其配置方式与train_reader类似。
    ……
    },
    "dev_reader": {                                                         ## 若要评估验证集,需配置dev_reader,其配置方式与test_reader类似。
    ……
    }
  },
  ……
}
model部分
  • model用于配置模型训练时的预置网络,包括预置网络的类别及其优化器的参数等。
  • 以下为./examples/mtch_cnn_pairwise_ch.json中抽取出来的model部分配置,并通过注释说明。
{
 ...
  "model": {
    "type": "CnnMatchingPairwise",                      ## wenxin采用模型(models)的方式定义神经网络的基本操作,本例采用预置的模型CnnMatchingPairwise实现文本匹配,具体网络可参考models目录。
    "optimization": {                                    ## 预置模型的优化器所需的参数配置,如学习率等。
      "learning_rate": 2e-05
    }
  },
 ...
}
trainer部分
  • trainer用于配置模型训练的启动器,包括保存模型时的间隔步数、进行测试集或验证集评估的间隔步数等。
  • 以下为./examples/mtch_cnn_pairwise_ch.json中抽取出来的trainer部分配置,并通过注释说明。
{
  ...
  "trainer": {
    "PADDLE_USE_GPU": 0,                                   ## 是否使用GPU进行训练,1为使用GPU。
    "PADDLE_IS_LOCAL": 1,                                  ## 是否单机训练,默认值为0,若要单机训练需要设置为1。
    "train_log_step": 20,                                  ## 训练时打印训练日志的间隔步数。
    "is_eval_dev": 0,                                      ## 是否在训练的时候评估开发集,如果取值为1一定需要配置dev_reader及其数据路径。
    "is_eval_test": 1,                                     ## 是否在训练的时候评估测试集,如果取值为1一定需要配置test_reader及其数据路径。
    "eval_step": 100,                                      ## 进行测试集或验证集评估的间隔步数。
    "save_model_step": 10000,                              ## 保存模型时的间隔步数,建议设置为eval_step的整数倍。
    "load_parameters": 0,                                  ## 加载包含各op参数值的训练好的模型,用于预测。
    "load_checkpoint": " ",                                ## 加载包含学习率等所有参数的训练模型,用于热启动。此处填写checkpoint路径
    "use_fp16": 0,                                         ## 是否使用fp16精度。
    "pre_train_model": [],                                 ## 加载预训练模型,例如ernie。使用时需要填写预训练模型的名称name和预训练模型的目录params_path。
    "output_path": "./output/mtch_cnn_pairwise_ch"        ## 保存模型的输出路径,如置空或者不配置则默认输出路径为"./output"。
  }
}

4. 模型预测

开始预测
  • 使用预置网络进行预测的方式为使用./run_infer.py入口脚本,通过--param_path参数来传入./examples/目录下的json配置文件。
  • 以基于预置BOW网络的pointwise文本匹配所训练出的模型为例,其预测分为以下几个步骤:

    1. 基于./examples/mtch_cnn_pairwise_ch.json训练出的模型默认储存在./output/mtch_cnn_pairwise_ch/save_inference_model/中,在该目录下找到被保存的inference_model文件夹,例如inference_step_251/;
    2. 在./examples/mtch_cnn_pairwise_ch_infer.json中修改"inference_model_path"参数,填入上述模型保存路径,如下所示:
    {
      ...
      "inference":{   
         ...
        "inference_model_path":"./output/mtch_cnn_pairwise_ch/save_inference_model/inference_step_251" 
      }
    }
    1. 基于示例的数据集,可以运行以下命令在预测集(infer.txt)上进行预测:
    # 基于json实现预测。其调用了配置文件./examples/mtch_cnn_pairwise_ch_infer.json
    python run_infer.py --param_path ./examples/mtch_cnn_pairwise_ch_infer.json
    1. 预测运行的日志会自动保存在./output/predict_result.txt文件中。
配置说明
  • 基于json实现预测时,可以通过修改所加载的json文件来进行参数的自定义配置。
  • json配置文件主要分为三个部分:dataset_reader、model和inference。 以./examples/mtch_cnn_pairwise_ch_infer.json为例,上述三个部分的配置与说明如下所示。
{
  "dataset_reader": {
    "predict_reader": {                                                                      ## 预测部分需要单独配置reader。
      "name": "predict_reader",
      "type": "BasicDataSetReader",
      "fields": [                                                                            ## 本样例中有两个文本域需要配置,配置方式与训练过程类似,注意无需配label。
        {
          "name": "text_a",
          "data_type": "string",
          "reader": {"type":"CustomTextFieldReader"},
          "tokenizer":{
              ……
            },
          ……
        },
        {
          "name": "text_b",
          "data_type": "string",
          "reader": {"type":"CustomTextFieldReader"},
          ……
        }
      ],
      "config": {
        "data_path": "./data/predict_data_tokenized/",                                                  ## 需要配置预测数据路径。
        ……
      }
    }
 
 
  },
  "model": {
    "type": "CnnMatchingPairwise"                                                                      ## 如果使用checkpoint需要说明其模型网络。
  },
  "trainer": {
  ……
  },
  "inference":{                                                                                         ## 需配置预测时所需的参数。
    "output_path": "./output/predict_result.txt",                                                       ## 预测结果文件的输出路径。
    "PADDLE_USE_GPU": 0,                                                                                ## 是否采用GPU预测。
    "PADDLE_IS_LOCAL": 1,                                                                               ## 是否单机预测,默认值为0,若要单机预测需要设置为1。
    "inference_model_path":"./output/mtch_cnn_pairwise_ch/save_inference_model/inference_step_251"     ## 训练好的模型路径。
  }
}

5. 使用ERNIE中文模型进行训练

  • ERNIE模型参数文件下载脚本及相关配置文件和词典保存在../model_files/目录下。
  • 不同ernie版本的参数文件(params)、词表(vocab.txt)、网络配置参数(ernie*config.json)大部分不相同,需要注意好对应关系。不同ERNIE模型间的差别及相关配置文件与词典请参考ERNIE预训练模型介绍。
  • 若希望通过开关控制ernie训练时不更新ernie参数,请参考ernie网络配置超参数说明中关于freeze_emb和freeze_num_layers的说明。
开始训练
  • 与非ERNIE的训练方式相同,使用预置网络进行训练的方式为使用./run_with_json.py入口脚本,通过--param_path参数来传入./examples/目录下ernie相关的json配置文件。
  • 以预置基于ernie_1.0_base的FC pointwise文本匹配模型为例,训练分为以下几个步骤:

    1. 请使用以下命令在../model_files/中通过对应脚本下载ernie_1.0_base模型参数文件,其对应配置文件ernie_1.0_base_ch_config.json和词典vocab_ernie_1.0_base_ch.txt分别位于../model_files/目录下的config/和dict/文件夹,用户无需更改;
    # ernie_1.0_base 模型下载
    # 进入model_files目录
    cd ../model_files/
    # 运行下载脚本
    sh download_ernie_1.0_base_ch.sh
    1. 基于示例的数据集,运行以下命令在训练集(train.txt)上进行模型训练,并在测试集(test.txt)上进行验证;
    # 基于json实现预置网络训练。其调用了配置文件./examples/mtch_ernie_1.0_base_fc_pointwise_ch.json
    python run_with_json.py --param_path ./examples/mtch_ernie_1.0_base_fc_pointwise_ch.json
    1. 训练运行的日志会自动保存在./log/test.log文件中;
切换不同版本的ERNIE模型:以ERNIE-Tiny为例
通过加载预置json配置文件切换
  • 通过../model_files目录下对应的download.sh脚本将不同ernie模型文件下载好之后,一般只需要在训练时加载对应的json文件即可切换不同ernie模型。
通过修改预置json配置文件切换
  • 若./examples/目录下没有所需的预置ernie训练配置文件,则可以通过修改配置来实现ernie模型的切换。
  • 以将./examples/mtch_ernie_1.0_base_fc_pointwise_ch.json修改为./examples/mtch_ernie_2.0_base_fc_pointwise_ch.json为例,主要修改的参数有三个,分别为:“vocab_path”、“vocab_path”和“pre_train_model”。修改部分示例如下所示:
{
  "dataset_reader": {
    "train_reader": {
      "fields": [
        {
          "vocab_path": "../model_files/dict/vocab_ernie_2.0_base_ch.txt",            ## ernie_2.0_base的词表文件。
          "embedding": {
            "config_path":"../model_files/config/ernie_2.0_base_ch_config.json"       ## ernie_2.0_base的网络配置参数。
          }
        },
      ],
    }
  },
  "trainer": {
    "pre_train_model": [
      {
        "name":"ernie-tiny",
        "params_path":"../model_files/ernie_2.0_base_ch_dir/params"                  ## ernie_2.0_base的参数文件。
      }
    ]
  }
}

6. 进阶使用

数据处理模块(data)详细说明
  • Data部分为wenxin数据处理模块,通过data_set_reader实现读取数据文件、转换数据格式、组batch、shuffle等操作。
dataset_reader部分
  • dataset_reader用于配置模型训练时的数据读取。
  • 以下为./examples/mtch_cnn_pairwise_ch.json中抽取出来的dataset_reader部分配置,并通过注释说明。
{
  "dataset_reader": {                                  
    "train_reader": {                                   ## 训练、验证、测试各自基于不同的数据集,数据格式也可能不一样,可以在json中配置不同的reader,此处为训练集的reader。
      "name": "train_reader",
      "type": "BasicDataSetReader",                     ## 采用BasicDataSetReader,其封装了常见的读取tsv文件、组batch等操作。
      "fields": [                                       ## 域(field)是wenxin的高阶封装,对于同一个样本存在不同域的时候,不同域有单独的数据类型(文本、数值、整型、浮点型)、单独的词表(vocabulary)等,可以根据不同域进行语义表示,如文本转id等操作,field_reader是实现这些操作的类。
        {
          "name": "text_a",                             ## 文本匹配有两个文本特征域,分别命名为"text_a"和"text_b"。
          "data_type": "string",                        ## data_type定义域的数据类型,文本域的类型为string,整型数值为int,浮点型数值为float。
          "reader": {"type":"CustomTextFieldReader"},   ## 采用针对文本域的通用reader "CustomTextFieldReader"。数值数组类型域为"ScalarArrayFieldReader",数值标量类型域为"ScalarFieldReader"。
          "tokenizer":{
              "type":"CustomTokenizer",                 ## 指定该文本域的tokenizer为CustomTokenizer。
              "split_char":" ",                         ## 通过空格区分不同的token。
              "unk_token":"[UNK]",                      ## unk标记为"[UNK]"。
              "params":null
            },
          "need_convert": true,                         ## "need_convert"为true说明数据格式是明文字符串,需要通过词表转换为id。
          "vocab_path": "./dict/vocab.txt",             ## 指定该文本域的词表。
          "max_seq_len": 512,                           ## 设定每个域的最大长度。
          "truncation_type": 0,                         ## 选择截断策略,0为从头开始到最大长度截断,1为从头开始到max_len-1的位置截断,末尾补上最后一个id(词或字),2为保留头和尾两个位置,然后按从头开始到最大长度方式截断。
          "padding_id": 0                               ## 设定padding时对应的id值。
        },                                              ## 如果每一个样本有多个特征域(文本类型、数值类型均可),可以仿照前面对每个域进行设置,依次增加每个域的配置即可。此时样本的域之间是以\t分隔的。
        {
          "name": "text_b",                            
          "data_type": "string",                       
          "reader":{"type":"CustomTextFieldReader"},
          "tokenizer":{
              "type":"CustomTokenizer",
              "split_char":" ",
              "unk_token":"[UNK]",
              "params":null
          },
          "need_convert": true,
          "vocab_path": "./dict/vocab.txt",   
          "max_seq_len": 512,
          "truncation_type": 0,
          "padding_id": 0
        },
        {
          "name": "label",                                                  ## 标签也是一个单独的域,命名为"label"。如果多个不同任务体系的标签存在于多个域中,则可实现最基本的多任务学习。
          "data_type": "int",                                               ## 标签是整型数值。
          "reader":{
            "type":"ScalarFieldReader"                                      ## 整型数值域的reader为"ScalarFieldReader"。
          },
          "tokenizer":null,
          "need_convert": false,
          "vocab_path": "",
          "max_seq_len": 1,
          "truncation_type": 0,
          "padding_id": 0,
          "embedding": null
        }
      ],
      "config": {
        "data_path": "./data/train_data_pointwise_tokenized/",              ## 训练数据train_reader的数据路径,写到文件夹目录。
        "shuffle": false,
        "batch_size": 8,
        "epoch": 10,
        "sampling_rate": 1.0
      }
    },
    "test_reader": {                                                        ## 若要评估测试集,需配置test_reader,其配置方式与train_reader类似。
    ……
    },
    "dev_reader": {                                                         ## 若要评估验证集,需配置dev_reader,其配置方式与test_reader类似。
    ……
    }
  },
  ……
}
自定义reader配置
  • 自定义reader配置根据具体项目情况通过对base_dataset_reader基类重写来实现。
  • 变量设置规则在common.rule.InstanceName中,该部分囊括了model和data部分的全局变量,实现了数据部分与组网部分的衔接,前向传播loss与优化器反向传播loss、计算metric的loss的衔接。部分与数据相关示例如下所示:
...
    RECORD_ID = "id"
    RECORD_EMB = "emb"
    SRC_IDS = "src_ids"
    MASK_IDS = "mask_ids"
    SEQ_LENS = "seq_lens"
    SENTENCE_IDS = "sent_ids"
    POS_IDS = "pos_ids"
    TASK_IDS = "task_ids"
...
模型核心操作接口(models)详细说明

wenxin采用模型(models)的方式定义神经网络的基本操作,包括前向传播网络(foreward)、优化策略(optimizer)、评估指标(metrics)等部分,均可实现自定义。

预置网络介绍

目录结构

  • wenxin所有预置网络位于../../wenxin/models/目录下,与文本匹配相关的预置网络如下所示:
.
├── __init__.py
├── model.py
├── cnn_matching_pairwise.py
├── cnn_matching_pointwise.py
├── fc_matching_pointwise.py
├── ernie_matching_fc_pointwise.py
├── ernie_matching_siamese_pairwise.py
└── ernie_matching_siamese_pointwise.py
.....

基类model

  • 基类model主要由前向计算组网forward()、优化器optimizer()、预测结果解析器parse_predict_result()和指标评估get_metrics()四个主要部分构成,其功能和参数如下所示。
...
def forward(self, fields_dict, phase):
    """
    必须选项,否则会抛出异常。
    核心内容是模型的前向计算组网部分,包括loss值的计算,必须由子类实现。输出即为对输入数据执行变换计算后的结果。
    :param: fields_dict
            {"field_name":
                {"RECORD_ID":
                    {"InstanceName.SRC_IDS": [ids],
                     "InstanceName.MASK_IDS": [ids],
                     "InstanceName.SEQ_LENS": [ids]
                    }
                }
            }
    序列化好的id,供网络计算使用。
    :param: phase: 当前调用的阶段,包含训练、评估和预测,不同的阶段组网可以不一样。
            训练:InstanceName.TRAINING
            测试集评估:InstanceName.TEST
            验证集评估:InstanceName.EVALUATE
            预测:InstanceName.SAVE_INFERENCE
             
    :return: 训练:forward_return_dict
                 {
                    "InstanceName.PREDICT_RESULT": [predictions],
                    "InstanceName.LABEL": [label],
                    "InstanceName.LOSS": [avg_cost]
                    "自定义变量名": [用户需获取的其余变量]
                 }
             预测:forward_return_dict
                 {
                    # 保存预测模型时需要的入参:模型预测时所需的输入变量
                    "InstanceName.TARGET_FEED_NAMES": [ids, id_lens],
                    # 保存预测模型时需要的入参:模型预测时输出的结果
                    "InstanceName.TARGET_PREDICTS": [predictions]
                 }     
    实例化的dict,存放TARGET_FEED_NAMES, TARGET_PREDICTS, PREDICT_RESULT,LABEL,LOSS等希望从前向网络中获取的数据。
    """
    raise NotImplementedError
 
def optimizer(self, loss, is_fleet=False):
    """
    必须选项,否则会抛出异常。
    设置优化器,如Adam,Adagrad,SGD等。
    :param loss:前向计算得到的损失值。
    :param is_fleet:是否为多机。
    :return:OrderedDict: 该dict中存放的是需要在运行过程中fetch出来的tensor,大多数情况下为空,可以按需求添加内容。
    """
    raise NotImplementedError
 
def parse_predict_result(self, predict_result):
    """按需解析模型预测出来的结果
    :param predict_result: 模型预测出来的结果
    :return:None
    """
    raise NotImplementedError
 
def get_metrics(self, fetch_output_dict, meta_info, phase):
    """指标评估部分的动态计算和打印
    :param fetch_output_dict: executor.run过程中fetch出来的forward中定义的tensor
    :param meta_info:常用的meta信息,如step, used_time, gpu_id等
    :param phase: 当前调用的阶段,包含训练和评估
    :return:metrics_return_dict:该dict中存放的是各个指标的结果,如下所示:
             {
                 "acc": acc,
                 "precision": precision
             }
    """
 
    raise NotImplementedError
 
...

预置网络配置

  • 通过json文件中的model部分对预置网络进行配置,以文本匹配任务mtch_cnn_pairwise_ch.json为例,其model部分如下所示:
{
 ...
  "model": {
    "type": "CnnMatchingPairwise",                      ## wenxin采用模型(models)的方式定义神经网络的基本操作,本例采用预置的模型CnnMatchingPairwise实现文本匹配,具体网络可参考models目录。
    "optimization": {                                    ## 预置模型的优化器所需的参数配置,如学习率等。
      "learning_rate": 2e-05
    }
  },
 ...
}

自定义网络配置

  • 自定义网络配置根据具体项目情况通过对model基类重写来实现。
  • 变量设置规则在common.rule.InstanceName中,该部分囊括了model和data部分的全局变量,实现了数据部分与组网部分的衔接,前向传播loss与优化器反向传播loss、计算metric的loss的衔接。部分与网络相关示例如下所示:
...
    TARGET_FEED_NAMES = "target_feed_name"       # 保存模型时需要的入参:表示模型预测时需要输入的变量名称和顺序
    TARGET_PREDICTS = "target_predicts"          # 保存模型时需要的入参:表示预测时最终输出的结果
    PREDICT_RESULT = "predict_result"            # 训练过程中需要传递的预测结果
    LABEL = "label"                              # label
    LOSS = "loss"                                # loss
 
    TRAINING = "training"                        # 训练过程
    EVALUATE = "evaluate"                        # 评估过程
    TEST = "test"                                # 测试过程
    SAVE_INFERENCE = "save_inference"            # 保存inference model的过程
...