资讯 文档
技术能力
语音技术
文字识别
人脸与人体
图像技术
语言与知识
视频技术

Inference

Infernece

基类定义

Inference的基类定义在wenxin/controller/inference.py

# -*- coding: utf-8 -*
"""
模型的预测控制器,核心成员有model,reader;核心方法有:
1.模型加载
2.预测:批量预测与单query预测
3.结果解析与回传
"""

class BaseInference(object):
    def __init__(self, params, data_set_reader, parser_handler):
        """
        :param params:前端json中设置的参数
        :param data_set_reader: 预测集reader
        :param parser_handler: 飞桨预测结果通过parser_handler参数回调到具体的任务中,由用户控制具体结果解析
        """

    def inference_batch(self):
        """批量预测
        """
        raise NotImplementedError

    def inference_query(self, query):
        """单query预测
        """
        raise NotImplementedError

    def init_env(self):
        """预测环境初始化
        """
        raise NotImplementedError

    def load_inference_model(self, model_path, thread_num=1):
        """加载预训练模型
        :param model_path:
        :param thread_num
        :return: inference
        """

    def parser_input_keys(self):
        """从meta文件中解析出模型预测过程中需要feed的变量名称,与model.forward的fields_dict对应起来
        """

核心函数

BaseInference做为基类,将预测过程中常用的操作(如运行时环境初始化、模型加载)已经统一封装实现,不需要在子类中重新实现。需要用户按自己业务场景在子类中自定义的核心函数为以下两个:

  • inference_batch(self):从数据集文件夹读取数据,进行批量预测。
  • inference_query(self, query):单条样本预测。

自定义实现示例

文心中目前BaseInference基类已经能覆盖所有预测任务,以文本匹配的wenxin_appzoo/tasks/text_matching/inference/custom_inference.py为例,详见以下代码及核心部分的注释。

# -*- coding: utf-8 -*
"""
对内工具包(major)中最常用的inference,必须继承自文心core中的BaseInference基类,必须实现inference_batch, inference_query方法。
"""

class CustomInference(BaseInference):
    """CustomInference
    """
    def __init__(self, params, data_set_reader, parser_handler):
        """
        :param params:前端json中设置的参数
        :param data_set_reader: 预测集reader
        :param parser_handler: 飞桨预测结果通过parser_handler参数回调到具体的任务中,由用户控制具体结果解析
        """
        BaseInference.__init__(self, params, data_set_reader, parser_handler)

    def inference_batch(self):
        """
        批量预测
        """
        logging.info("start do inference....")
        total_time = 0
        # 预测结果回写到指定文件
        output_path = self.params.get("output_path", None)
        if not output_path or output_path == "":
            if not os.path.exists("./output"):
                os.makedirs("./output")
            output_path = "./output/predict_result.txt"
        output_file = open(output_path, "w+")
        # step1: 从预测数据集读取数据,组成适合model计算的tensor结构。这里使用预置的datasetreader来做数据迭代器。
        dg = self.data_set_reader.predict_reader
        # step2:启动数据迭代器,按batch读取明文样本,将其转换为id
        for batch_id, data_t in enumerate(dg()):
            data = data_t[0]
            samples = data_t[1]
            feed_dict = dg.dataset.convert_fields_to_dict(data)
            predict_results = []
             # step3: 根据配置读取需要的输入
            for index, item in enumerate(self.input_keys):
                kv = item.split("#")
                name = kv[0]
                key = kv[1]
                item_instance = feed_dict[name]
                input_item = item_instance[InstanceName.RECORD_ID][key]
                # input_item是tensor类型
                self.input_handles[index].copy_from_cpu(input_item.numpy())
            
            wrap_load(self.input_handles, self.input_keys)
            begin_time = time.time()
            # step4: 调用PaddlePaddle的预测对象run()方法进行预测
            self.predictor.run()
            end_time = time.time()
            total_time += end_time - begin_time
            # step5: 调用PaddlePaddle的预测对象get_output_names()方法获取预测输出
            output_names = self.predictor.get_output_names()
            for i in range(len(output_names)):
                output_tensor = self.predictor.get_output_handle(output_names[i])
                predict_results.append(output_tensor)
            # step6: 回调给解析函数
            write_result_list = self.parser_handler(predict_results, sample_list=samples, params_dict=self.params)
            # step7: 将预测结果写入文件
            for write_item in write_result_list:
                size = len(write_item)
                for index, item in enumerate(write_item):
                    output_file.write(str(item))
                    if index != size - 1:
                        output_file.write("\t")
                output_file.write("\n")
        logging.info("total_time:{}".format(total_time))
        output_file.close()
    def inference_query(self, query):
        """单条query预测
        :param query
        """
        total_time = 0
        reader = self.data_set_reader.predict_reader.dataset

        data, sample = reader.api_generator(query)
        feed_dict = reader.convert_fields_to_dict(data)

        predict_results = []
        for index, item in enumerate(self.input_keys):
            kv = item.split("#")
            name = kv[0]
            key = kv[1]
            item_instance = feed_dict[name]
            input_item = item_instance[InstanceName.RECORD_ID][key]
            # input_item 是ndarray
            self.input_handles[index].copy_from_cpu(np.array(input_item))

        wrap_load(self.input_handles, self.input_keys)
        begin_time = time.time()
        self.predictor.run()
        end_time = time.time()
        total_time += end_time - begin_time

        output_names = self.predictor.get_output_names()
        for i in range(len(output_names)):
            output_tensor = self.predictor.get_output_handle(output_names[i])
            predict_results.append(output_tensor)

        # 回调给解析函数
        result_list = self.parser_handler(predict_results, sample_list=sample, params_dict=self.params)
        return result_list
上一篇
FieldReader
下一篇
Model