资讯 文档
技术能力
语音技术
文字识别
人脸与人体
图像技术
语言与知识
视频技术

10.Inference

简介

在文心中,我们把模型推理预测过程中常用的操作进行了统一封装、统一调度,形成一套标准流程,这套标准流程的调度就是Inference(预测器)。一个Inference实例中的标准操作有初始化运行环境、加载模型文件、数据读取(通过调用Reader实现)、预测推理、预测结果解析。

基本流程

一个Inference实例中包含了一个Reader实例和一个Paddle.Inference实例,Reader实例负责读取明文数据集并将其处理为飞桨可用的Tensor结构;Paddle.Inference实例对Reader传入的Tensor数据按加载的网络结构进行计算,并将计算结果返回给文心的Inference实例,Inference对结果进行自定义解析回传给用户。其具体流程下所示:

image.png

基础操作

  • 批量预测:

        def do_inference(self):
            """
            :return:
            """
            logging.info("start do inference....")
            total_time = 0
            output_path = self.param.get("output_path", None)
            if not output_path or output_path == "":
                if not os.path.exists("./output"):
                    os.makedirs("./output")
                output_path = "./output/predict_result.txt"
    
            output_file = open(output_path, "w+", encoding='utf-8')
            
            ## 通过Reader迭代读取预测数据集中的数据,组成batch。
            reader = self.data_set_reader.predict_reader.data_generator()
            for sample in reader():
                sample_ids = sample[0]
                sample_entity_list = sample[1]
                sample_dict = self.data_set_reader.predict_reader.convert_fields_to_dict(sample_ids, need_emb=False)
                input_list = []
    
                for item in self.input_keys:
                    kv = item.split("#")
                    name = kv[0]
                    key = kv[1]
                    item_instance = sample_dict[name]
                    input_item = item_instance[InstanceName.RECORD_ID][key]
                    input_list.append(input_item)
    
                ## 将明文id转换为飞桨需要的tensor类型
                inputs = [array2tensor(ndarray) for ndarray in input_list]
                begin_time = time.time()
                
                ## 使用Paddle.Inference实例进行预测推理
                result = self.inference.run(inputs)
                
                end_time = time.time()
                total_time += end_time - begin_time
                
                ## 飞桨的预测结果回调给文心inference实例的解析函数,解析函数有各个任务自己定义,文心已经提供了常见任务的解析方法,用户也可以自定义。
                write_result_list = self.parser_handler(result, sample_list=sample_entity_list, params_dict=self.param)
                for write_item in write_result_list:
                    size = len(write_item)
                    for index, item in enumerate(write_item):
                      
                      	## 用户解析后的预测结果回写到文件中。
                        output_file.write(item)
                        if index != size - 1:
                            output_file.write("\t")
    
                    output_file.write("\n")
    
            logging.info("total_time:{}".format(total_time))
            output_file.close()
  • 单样本预测:与批量预测的方式一样,不再赘述。

        def inference_query(self, query_list):
            """
            :param query_list:
            :return:
            """
            total_time = 0
            sample_id = self.data_set_reader.predict_reader.api_generator(query_list)
            sample_dict = self.data_set_reader.predict_reader.convert_fields_to_dict(sample_id, need_emb=False)
            input_list = []
            for item in self.input_keys:
                kv = item.split("#")
                name = kv[0]
                key = kv[1]
                item_instance = sample_dict[name]
                input_item = item_instance[InstanceName.RECORD_ID][key]
                input_list.append(input_item)
    
            inputs = [array2tensor(ndarray) for ndarray in input_list]
            begin_time = time.time()
            result = self.inference.run(inputs)
            end_time = time.time()
            total_time += end_time - begin_time
            # 回调给解析函数
            return self.parser_handler(result, params_dict=self.param)

文心现有的预置Inference

文心目前提供了1个通用的Inference,覆盖了一些比较常见的NLP领域的经典任务,包括文本分类、文本匹配、序列标注、信息抽取等,位于./wenxin/inference/目录下,如下所示:

├── __init__.py
├── inference.py           ## 通用Inference,支持文本分类、匹配、序列标注、信息抽取等常见任务。

进阶使用

文心中提供了通用的Inference流程,如果用户需要针对自己的业务场景进行自定义优化使用的话,请参考详细的接口设计与自定义Inference

上一篇
9.Trainer
下一篇
任务详解