10.Inference
更新时间:2021-03-23
简介
在文心中,我们把模型推理预测过程中常用的操作进行了统一封装、统一调度,形成一套标准流程,这套标准流程的调度就是Inference(预测器)。一个Inference实例中的标准操作有初始化运行环境、加载模型文件、数据读取(通过调用Reader实现)、预测推理、预测结果解析。
基本流程
一个Inference实例中包含了一个Reader实例和一个Paddle.Inference实例,Reader实例负责读取明文数据集并将其处理为飞桨可用的Tensor结构;Paddle.Inference实例对Reader传入的Tensor数据按加载的网络结构进行计算,并将计算结果返回给文心的Inference实例,Inference对结果进行自定义解析回传给用户。其具体流程下所示:

基础操作
-
批量预测:
def do_inference(self): """ :return: """ logging.info("start do inference....") total_time = 0 output_path = self.param.get("output_path", None) if not output_path or output_path == "": if not os.path.exists("./output"): os.makedirs("./output") output_path = "./output/predict_result.txt" output_file = open(output_path, "w+", encoding='utf-8') ## 通过Reader迭代读取预测数据集中的数据,组成batch。 reader = self.data_set_reader.predict_reader.data_generator() for sample in reader(): sample_ids = sample[0] sample_entity_list = sample[1] sample_dict = self.data_set_reader.predict_reader.convert_fields_to_dict(sample_ids, need_emb=False) input_list = [] for item in self.input_keys: kv = item.split("#") name = kv[0] key = kv[1] item_instance = sample_dict[name] input_item = item_instance[InstanceName.RECORD_ID][key] input_list.append(input_item) ## 将明文id转换为飞桨需要的tensor类型 inputs = [array2tensor(ndarray) for ndarray in input_list] begin_time = time.time() ## 使用Paddle.Inference实例进行预测推理 result = self.inference.run(inputs) end_time = time.time() total_time += end_time - begin_time ## 飞桨的预测结果回调给文心inference实例的解析函数,解析函数有各个任务自己定义,文心已经提供了常见任务的解析方法,用户也可以自定义。 write_result_list = self.parser_handler(result, sample_list=sample_entity_list, params_dict=self.param) for write_item in write_result_list: size = len(write_item) for index, item in enumerate(write_item): ## 用户解析后的预测结果回写到文件中。 output_file.write(item) if index != size - 1: output_file.write("\t") output_file.write("\n") logging.info("total_time:{}".format(total_time)) output_file.close() -
单样本预测:与批量预测的方式一样,不再赘述。
def inference_query(self, query_list): """ :param query_list: :return: """ total_time = 0 sample_id = self.data_set_reader.predict_reader.api_generator(query_list) sample_dict = self.data_set_reader.predict_reader.convert_fields_to_dict(sample_id, need_emb=False) input_list = [] for item in self.input_keys: kv = item.split("#") name = kv[0] key = kv[1] item_instance = sample_dict[name] input_item = item_instance[InstanceName.RECORD_ID][key] input_list.append(input_item) inputs = [array2tensor(ndarray) for ndarray in input_list] begin_time = time.time() result = self.inference.run(inputs) end_time = time.time() total_time += end_time - begin_time # 回调给解析函数 return self.parser_handler(result, params_dict=self.param)
文心现有的预置Inference
文心目前提供了1个通用的Inference,覆盖了一些比较常见的NLP领域的经典任务,包括文本分类、文本匹配、序列标注、信息抽取等,位于./wenxin/inference/目录下,如下所示:
├── __init__.py
├── inference.py ## 通用Inference,支持文本分类、匹配、序列标注、信息抽取等常见任务。进阶使用
文心中提供了通用的Inference流程,如果用户需要针对自己的业务场景进行自定义优化使用的话,请参考详细的接口设计与自定义Inference。
