资讯 文档
技术能力
语音技术
文字识别
人脸与人体
图像技术
语言与知识
视频技术

Inference

基类定义

Inference的基类定义在./wenxin/inference/inference.py

class Inference(object):
    """Inferece:模型预测的基类,主要操作有:
     1.解析input_data的结构 
     2.解析参数,构造inference  
     3.启动data_generator,开始预测
     4.回掉预测结果到parser_handler中进行解析
    """
    def __init__(self, param, data_set_reader, parser_handler):
        """
        :param param: 运行的基本参数设置
        :param data_set_reader: 运行的基本参数设置
        :param parser_handler: 解析预测结果的回掉接口
        """
        pass

    def init_env(self, meta_info):
        """
        运行时环境初始化,模型加载,基类已经完成,用户无需自定义。
        :return:
        """
        pass

    def do_inference(self):
        """
        从数据集文件夹读取数据,进行批量预测。
        :return:
        """
        pass

    def inference_query(self, query_list):
        """
        单条样本预测
        :param query_list: 待预测的单条样本,这里的类型必须是list,list的每一项是这条样本的一个域。比如一条单句分类任务,其query_list=[text], 一个匹配任务,其query_list=[text_a, text_b]
        :return:
        """
        pass

核心函数

Inference做为基类,将预测过程中常用的操作(如运行时环境初始化、模型加载)已经统一封装实现,不需要在子类中重新实现。需要用户按自己业务场景在子类中自定义的核心函数为以下两个:

  • do_inference(self):从数据集文件夹读取数据,进行批量预测。
  • inference_query(self, query_list):单条样本预测。

自定义实现示例

文心中目前Inference基类已经能覆盖所有预测任务,所以直接看其对do_inference和inference_query的实现即可,详见以下代码及核心部分的注释。

    def do_inference(self):
        """从数据集文件夹读取数据,进行批量预测。
        :return:
        """
        logging.info("start do inference....")
        total_time = 0
        output_path = self.param.get("output_path", None)
        # 预测结果回写到指定文件
        if not output_path or output_path == "":
            if not os.path.exists("./output"):
                os.makedirs("./output")
            output_path = "./output/predict_result.txt"

        output_file = open(output_path, "w+")
        
        # step1: 从预测数据集读取数据,组成适合model计算的tensor结构。这里使用预置的datasetreader来做数据迭代器。
        reader = self.data_set_reader.predict_reader.data_generator()
        
        # step2:启动数据迭代器,按batch读取明文样本,将其转换为id
        for sample in reader():
            sample_ids = sample[0]
            sample_entity_list = sample[1]
            sample_dict = self.data_set_reader.predict_reader.convert_fields_to_dict(sample_ids, need_emb=False)
            input_list = []

            for item in self.input_keys:
                kv = item.split("#")
                name = kv[0]
                key = kv[1]
                item_instance = sample_dict[name]
                input_item = item_instance[InstanceName.RECORD_ID][key]
                input_list.append(input_item)
						
            # step3: 将array的id数值转换成Paddle需要的tensor类型,这里直接调用文心封装好的方法(array2tensor)即可,该方法位于./wenxin/utils/util_helper.py
            inputs = [array2tensor(ndarray) for ndarray in input_list]
            begin_time = time.time()
            # step4: 调用Paddle的预测对象(self.inference,基类的init方法中已经初始化,用户无需重写,直接调用即可),调用 self.inference.run(inputs) 
            result = self.inference.run(inputs)
            end_time = time.time()
            total_time += end_time - begin_time
            # step5:将预测结果回调给外部的解析函数进行解析,每个task的解析都不一样,需要用户自定义
            write_result_list = self.parser_handler(result, sample_list=sample_entity_list, params_dict=self.param)
            for write_item in write_result_list:
                size = len(write_item)
                for index, item in enumerate(write_item):
                  	# step6: 将解析后的预测结构回写到指定文件,当然也可以不写文件,直接print也行。
                    output_file.write(item)
                    if index != size - 1:
                        output_file.write("\t")

                output_file.write("\n")

        logging.info("total_time:{}".format(total_time))
        output_file.close()

    def inference_query(self, query_list):
        """单条样本预测
        :param query_list:
        :return:
        """
        total_time = 0
        # step1:使用datasetreader的api_generator接口获取单条query的id值。
        sample_id = self.data_set_reader.predict_reader.api_generator(query_list)
        sample_dict = self.data_set_reader.predict_reader.convert_fields_to_dict(sample_id, need_emb=False)
        input_list = []
        for item in self.input_keys:
            kv = item.split("#")
            name = kv[0]
            key = kv[1]
            item_instance = sample_dict[name]
            input_item = item_instance[InstanceName.RECORD_ID][key]
            input_list.append(input_item)
				# step2: 将array的id数值转换成Paddle需要的tensor类型,这里直接调用文心封装好的方法(array2tensor)即可,该方法位于./wenxin/utils/util_helper.py
        inputs = [array2tensor(ndarray) for ndarray in input_list]
        begin_time = time.time()
         # step3: 调用Paddle的预测对象(self.inference,基类的init方法中已经初始化,用户无需重写,直接调用即可),调用 self.inference.run(inputs)
        result = self.inference.run(inputs)
        end_time = time.time()
        total_time += end_time - begin_time
        # step4:将预测结果回调给外部的解析函数进行解析,每个task的解析都不一样,需要用户自定义
        return self.parser_handler(result, params_dict=self.param)
上一篇
FieldReader
下一篇
Model