Sklearn服务代码文件示例
更新时间:2021-04-12
Sklearn服务代码文件示例
在模型仓库中导入基于Sklearn库的机器学习模型时,除需导入模型文件外,也需要导入服务代码文件,其中服务代码文件用于在线部署模型时进行模型文件的加载以及进行必要的预处理和后处理逻辑。
Sklearn模型服务代码示例如下所示:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# *******************************************************************************
#
# Copyright (c) 2020 Baidu.com, Inc. All Rights Reserved
#
# *******************************************************************************
import numpy as np
ERROR_CODE_FAILED_DECODING_INPUT = 336005
MESSAGE_FAILED_DECODING_INPUT = "Failed decoding input"
ERROR_CODE_MISSING_REQUIRED_PARAMETER = 336006
MESSAGE_MISSING_REQUIRED_PARAMETER = "Missing required parameter"
class CustomException(RuntimeError):
"""
进行模型验证和部署服务必需的异常类,缺少该类在代码验证时将会失败
在处理异常数据或者请求时,推荐在`PredictWrapper`中的自定义预处理preprocess和后处理postprocess函数中抛出`CustomException`类,
并为`message`指定准确可读的错误信息,以便在服务响应包中的`error_msg`参数中返回。
"""
def __init__(self, error_code, message, orig_error=None):
""" init with error_code, message and origin exception """
super(CustomException, self).__init__(message)
self.error_code = error_code
self.orig_error = orig_error
class PredictWrapper(object):
""" 模型服务预测封装类,支持用户自定义对服务请求数据的预处理和模型预测结果的后处理函数 """
def __init__(self, model_path, use_gpu, logger):
"""
根据`model_path`初始化`PredictWrapper`类,如解析label_list.txt,加载模型输出标签id和标签名称的映射关系
:param model_path: 该目录下存放了用户选择的模型版本中包含的所有文件
"""
# 加载模型
model_filename = model_path + '/model.pkl'
with open(model_filename, 'rb') as pk_fin:
import pickle
self._model = pickle.load(pk_fin)
def preprocess(self, request_body):
"""
自定义对请求体的预处理,针对图像类模型服务,包括对图片对图像的解析、转化等
:param request_body: 请求体的json字典
:return:
data: 用于模型预测的输入。
infer_args: 用于模型预测的其他参数
request_context: 透传给自定义后处理函数`postprocess`的参数,例如指定返回预测结果的top N,过滤低score的阈值threshold.
"""
try:
features = request_body['features']
features_np = np.array(features)
except KeyError:
raise CustomException(error_code=ERROR_CODE_MISSING_REQUIRED_PARAMETER,
message=MESSAGE_FAILED_DECODING_INPUT)
except Exception as e:
raise CustomException(error_code=ERROR_CODE_FAILED_DECODING_INPUT,
message=MESSAGE_FAILED_DECODING_INPUT, orig_error=e)
return features_np, {}, {}
def predict(self, data, infer_args):
"""
模型预测
:param data: 预处理后的数据
:param infer_args: 预处理返回的`infer_args`
:return: infer_result 预测结果
"""
return self._model.predict(data, **infer_args)
def postprocess(self, infer_result, request_context):
"""
对ml模型预测结果进行后处理
:param infer_result: ML模型的预测结果
:param request_context: 自定义预处理函数中返回的`request context`
:return: request results 请求的处理结果
"""
if not isinstance(infer_result, np.ndarray):
infer_result = np.array(infer_result)
return {'categories': infer_result.tolist()}