Inference
更新时间:2021-05-13
基类定义
Inference的基类定义在./wenxin/inference/inference.py
class Inference(object):
"""Inferece:模型预测的基类,主要操作有:
1.解析input_data的结构
2.解析参数,构造inference
3.启动data_generator,开始预测
4.回掉预测结果到parser_handler中进行解析
"""
def __init__(self, param, data_set_reader, parser_handler):
"""
:param param: 运行的基本参数设置
:param data_set_reader: 运行的基本参数设置
:param parser_handler: 解析预测结果的回掉接口
"""
pass
def init_env(self, meta_info):
"""
运行时环境初始化,模型加载,基类已经完成,用户无需自定义。
:return:
"""
pass
def do_inference(self):
"""
从数据集文件夹读取数据,进行批量预测。
:return:
"""
pass
def inference_query(self, query_list):
"""
单条样本预测
:param query_list: 待预测的单条样本,这里的类型必须是list,list的每一项是这条样本的一个域。比如一条单句分类任务,其query_list=[text], 一个匹配任务,其query_list=[text_a, text_b]
:return:
"""
pass核心函数
Inference做为基类,将预测过程中常用的操作(如运行时环境初始化、模型加载)已经统一封装实现,不需要在子类中重新实现。需要用户按自己业务场景在子类中自定义的核心函数为以下两个:
- do_inference(self):从数据集文件夹读取数据,进行批量预测。
- inference_query(self, query_list):单条样本预测。
自定义实现示例
文心中目前Inference基类已经能覆盖所有预测任务,所以直接看其对do_inference和inference_query的实现即可,详见以下代码及核心部分的注释。
def do_inference(self):
"""从数据集文件夹读取数据,进行批量预测。
:return:
"""
logging.info("start do inference....")
total_time = 0
output_path = self.param.get("output_path", None)
# 预测结果回写到指定文件
if not output_path or output_path == "":
if not os.path.exists("./output"):
os.makedirs("./output")
output_path = "./output/predict_result.txt"
output_file = open(output_path, "w+")
# step1: 从预测数据集读取数据,组成适合model计算的tensor结构。这里使用预置的datasetreader来做数据迭代器。
reader = self.data_set_reader.predict_reader.data_generator()
# step2:启动数据迭代器,按batch读取明文样本,将其转换为id
for sample in reader():
sample_ids = sample[0]
sample_entity_list = sample[1]
sample_dict = self.data_set_reader.predict_reader.convert_fields_to_dict(sample_ids, need_emb=False)
input_list = []
for item in self.input_keys:
kv = item.split("#")
name = kv[0]
key = kv[1]
item_instance = sample_dict[name]
input_item = item_instance[InstanceName.RECORD_ID][key]
input_list.append(input_item)
# step3: 将array的id数值转换成Paddle需要的tensor类型,这里直接调用文心封装好的方法(array2tensor)即可,该方法位于./wenxin/utils/util_helper.py
inputs = [array2tensor(ndarray) for ndarray in input_list]
begin_time = time.time()
# step4: 调用Paddle的预测对象(self.inference,基类的init方法中已经初始化,用户无需重写,直接调用即可),调用 self.inference.run(inputs)
result = self.inference.run(inputs)
end_time = time.time()
total_time += end_time - begin_time
# step5:将预测结果回调给外部的解析函数进行解析,每个task的解析都不一样,需要用户自定义
write_result_list = self.parser_handler(result, sample_list=sample_entity_list, params_dict=self.param)
for write_item in write_result_list:
size = len(write_item)
for index, item in enumerate(write_item):
# step6: 将解析后的预测结构回写到指定文件,当然也可以不写文件,直接print也行。
output_file.write(item)
if index != size - 1:
output_file.write("\t")
output_file.write("\n")
logging.info("total_time:{}".format(total_time))
output_file.close()
def inference_query(self, query_list):
"""单条样本预测
:param query_list:
:return:
"""
total_time = 0
# step1:使用datasetreader的api_generator接口获取单条query的id值。
sample_id = self.data_set_reader.predict_reader.api_generator(query_list)
sample_dict = self.data_set_reader.predict_reader.convert_fields_to_dict(sample_id, need_emb=False)
input_list = []
for item in self.input_keys:
kv = item.split("#")
name = kv[0]
key = kv[1]
item_instance = sample_dict[name]
input_item = item_instance[InstanceName.RECORD_ID][key]
input_list.append(input_item)
# step2: 将array的id数值转换成Paddle需要的tensor类型,这里直接调用文心封装好的方法(array2tensor)即可,该方法位于./wenxin/utils/util_helper.py
inputs = [array2tensor(ndarray) for ndarray in input_list]
begin_time = time.time()
# step3: 调用Paddle的预测对象(self.inference,基类的init方法中已经初始化,用户无需重写,直接调用即可),调用 self.inference.run(inputs)
result = self.inference.run(inputs)
end_time = time.time()
total_time += end_time - begin_time
# step4:将预测结果回调给外部的解析函数进行解析,每个task的解析都不一样,需要用户自定义
return self.parser_handler(result, params_dict=self.param)