Inference
更新时间:2022-12-17
Infernece
基类定义
Inference的基类定义在wenxin/controller/inference.py
# -*- coding: utf-8 -*
"""
模型的预测控制器,核心成员有model,reader;核心方法有:
1.模型加载
2.预测:批量预测与单query预测
3.结果解析与回传
"""
class BaseInference(object):
def __init__(self, params, data_set_reader, parser_handler):
"""
:param params:前端json中设置的参数
:param data_set_reader: 预测集reader
:param parser_handler: 飞桨预测结果通过parser_handler参数回调到具体的任务中,由用户控制具体结果解析
"""
def inference_batch(self):
"""批量预测
"""
raise NotImplementedError
def inference_query(self, query):
"""单query预测
"""
raise NotImplementedError
def init_env(self):
"""预测环境初始化
"""
raise NotImplementedError
def load_inference_model(self, model_path, thread_num=1):
"""加载预训练模型
:param model_path:
:param thread_num
:return: inference
"""
def parser_input_keys(self):
"""从meta文件中解析出模型预测过程中需要feed的变量名称,与model.forward的fields_dict对应起来
"""
核心函数
BaseInference做为基类,将预测过程中常用的操作(如运行时环境初始化、模型加载)已经统一封装实现,不需要在子类中重新实现。需要用户按自己业务场景在子类中自定义的核心函数为以下两个:
- inference_batch(self):从数据集文件夹读取数据,进行批量预测。
- inference_query(self, query):单条样本预测。
自定义实现示例
文心中目前BaseInference基类已经能覆盖所有预测任务,以文本匹配的wenxin_appzoo/tasks/text_matching/inference/custom_inference.py为例,详见以下代码及核心部分的注释。
# -*- coding: utf-8 -*
"""
对内工具包(major)中最常用的inference,必须继承自文心core中的BaseInference基类,必须实现inference_batch, inference_query方法。
"""
class CustomInference(BaseInference):
"""CustomInference
"""
def __init__(self, params, data_set_reader, parser_handler):
"""
:param params:前端json中设置的参数
:param data_set_reader: 预测集reader
:param parser_handler: 飞桨预测结果通过parser_handler参数回调到具体的任务中,由用户控制具体结果解析
"""
BaseInference.__init__(self, params, data_set_reader, parser_handler)
def inference_batch(self):
"""
批量预测
"""
logging.info("start do inference....")
total_time = 0
# 预测结果回写到指定文件
output_path = self.params.get("output_path", None)
if not output_path or output_path == "":
if not os.path.exists("./output"):
os.makedirs("./output")
output_path = "./output/predict_result.txt"
output_file = open(output_path, "w+")
# step1: 从预测数据集读取数据,组成适合model计算的tensor结构。这里使用预置的datasetreader来做数据迭代器。
dg = self.data_set_reader.predict_reader
# step2:启动数据迭代器,按batch读取明文样本,将其转换为id
for batch_id, data_t in enumerate(dg()):
data = data_t[0]
samples = data_t[1]
feed_dict = dg.dataset.convert_fields_to_dict(data)
predict_results = []
# step3: 根据配置读取需要的输入
for index, item in enumerate(self.input_keys):
kv = item.split("#")
name = kv[0]
key = kv[1]
item_instance = feed_dict[name]
input_item = item_instance[InstanceName.RECORD_ID][key]
# input_item是tensor类型
self.input_handles[index].copy_from_cpu(input_item.numpy())
wrap_load(self.input_handles, self.input_keys)
begin_time = time.time()
# step4: 调用PaddlePaddle的预测对象run()方法进行预测
self.predictor.run()
end_time = time.time()
total_time += end_time - begin_time
# step5: 调用PaddlePaddle的预测对象get_output_names()方法获取预测输出
output_names = self.predictor.get_output_names()
for i in range(len(output_names)):
output_tensor = self.predictor.get_output_handle(output_names[i])
predict_results.append(output_tensor)
# step6: 回调给解析函数
write_result_list = self.parser_handler(predict_results, sample_list=samples, params_dict=self.params)
# step7: 将预测结果写入文件
for write_item in write_result_list:
size = len(write_item)
for index, item in enumerate(write_item):
output_file.write(str(item))
if index != size - 1:
output_file.write("\t")
output_file.write("\n")
logging.info("total_time:{}".format(total_time))
output_file.close()
def inference_query(self, query):
"""单条query预测
:param query
"""
total_time = 0
reader = self.data_set_reader.predict_reader.dataset
data, sample = reader.api_generator(query)
feed_dict = reader.convert_fields_to_dict(data)
predict_results = []
for index, item in enumerate(self.input_keys):
kv = item.split("#")
name = kv[0]
key = kv[1]
item_instance = feed_dict[name]
input_item = item_instance[InstanceName.RECORD_ID][key]
# input_item 是ndarray
self.input_handles[index].copy_from_cpu(np.array(input_item))
wrap_load(self.input_handles, self.input_keys)
begin_time = time.time()
self.predictor.run()
end_time = time.time()
total_time += end_time - begin_time
output_names = self.predictor.get_output_names()
for i in range(len(output_names)):
output_tensor = self.predictor.get_output_handle(output_names[i])
predict_results.append(output_tensor)
# 回调给解析函数
result_list = self.parser_handler(predict_results, sample_list=sample, params_dict=self.params)
return result_list