Model
更新时间:2022-12-17
Model
基类定义
Model的基类定义在./wenxin/model/model.py中(即后文示例中 `from wenxin.model.model import BaseModel` 所导入的模块)。
# -*- coding: utf-8 -*
"""
文心中的深度学习模型对象,使用飞桨最新的动态图方式建模,同时支持静态图和动态图的运行方式,其核心方法是:
1.组织网络结构:动静结合的组网方式,包含structure(构造方法)和forward(前向计算)两个主要方法
2.设定当前网络的模型评估方式
3.同时支持动态图和静态图
"""
import paddle
class BaseModel(paddle.nn.Layer):
    """Common parent class of wenxin models.

    Concrete models override ``structure``, ``forward``, ``get_metrics``,
    ``fields_process`` and ``set_optimizer``; every one of them raises
    ``NotImplementedError`` at this level.
    """

    def __init__(self, model_params):
        """
        :param model_params: model configuration mapping; the ``is_dygraph``
            entry is read from it (default 0)
        """
        paddle.nn.Layer.__init__(self)
        self.model_params = model_params
        self.is_dygraph = self.model_params.get("is_dygraph", 0)
        # Training components: left empty here, subclasses are required to fill them in.
        self.optimizer = None  # optimizer instance
        self.lr_scheduler = None  # learning-rate decay schedule
        self.lr = None  # learning rate

    def structure(self):
        """Declare the layers (member variables) that make up the network.

        :return: None
        """
        raise NotImplementedError

    def forward(self, fields_dict, phase):
        """Run the forward computation.

        :param fields_dict: input fields of the current batch
        :param phase: stage the call belongs to (training, evaluation, ...)
        :return: dict of forward results
        """
        raise NotImplementedError

    def get_metrics(self, forward_return_dict, meta_info, phase):
        """Evaluate model quality from the forward results.

        :param forward_return_dict: results produced by ``forward``
        :param meta_info: common meta information such as step, used_time, gpu_id
        :param phase: stage the call belongs to, covering training and evaluation
        :return: computed metrics
        """
        raise NotImplementedError

    def fields_process(self, fields_dict, phase):
        """Apply optional secondary processing to the serialized ids in ``fields_dict``.

        :return: the processed fields
        """
        raise NotImplementedError

    def set_optimizer(self):
        """Configure the optimizer.

        :return: optimizer
        """
        raise NotImplementedError
核心函数
Model作为基类,需要用户按自己业务场景在子类中自定义的核心函数为以下5个:
- structure(self)
- forward(self, fields_dict, phase)
- get_metrics(self, forward_return_dict, meta_info, phase)
- fields_process(self, fields_dict, phase)
- set_optimizer(self)
自定义实现示例
下面以文心提供的基于bow的分类网络BowClassification为例(可参考文件wenxin_appzoo/tasks/text_classification/model/bow_classification.py),解释Model中5个核心函数的实现,详见以下代码及核心部分的注释。
import paddle
from wenxin.common.register import RegisterSet
from wenxin.common.rule import InstanceName
from wenxin.model.model import BaseModel
from wenxin.modules.encoder import BoWEncoder
from model.base_cls import BaseClassification
@RegisterSet.models.register
class BowClassification(BaseClassification):
    """Bag-of-words text classifier.

    Embeds ``text_a``, pools the embeddings with ``BoWEncoder``, applies two
    tanh-activated fully connected layers and projects to ``num_labels``
    classes. Optimizer setup and metrics come from ``BaseClassification``.
    """

    def __init__(self, model_params):
        """
        :param model_params: model configuration dict (values are read with
            ``.get`` in ``structure``)
        """
        # Initialise through the direct parent so setup added to
        # BaseClassification.__init__ is never skipped.
        super().__init__(model_params)

    def structure(self):
        """Define the layers used by ``forward`` (member variables only).

        Hyper-parameters are read from ``model_params`` with these defaults:
        vocab_size=33261, emb_dim=128, hid_dim=128, hid_dim2=96, num_labels=2.

        :return: None
        """
        self.dict_dim = self.model_params.get('vocab_size', 33261)
        self.emb_dim = self.model_params.get('emb_dim', 128)
        self.hid_dim = self.model_params.get('hid_dim', 128)
        self.hid_dim2 = self.model_params.get('hid_dim2', 96)
        self.num_labels = self.model_params.get('num_labels', 2)
        self.embedding = paddle.nn.Embedding(num_embeddings=self.dict_dim, embedding_dim=self.emb_dim)
        self.bow_encoder = BoWEncoder(self.emb_dim)
        # fc_1 consumes the pooled BoW vector, which is emb_dim wide (the encoder
        # sums embeddings over the sequence axis). Sizing it by hid_dim broke any
        # config where emb_dim != hid_dim; the defaults (128/128) are unaffected.
        self.fc_1 = paddle.nn.Linear(in_features=self.emb_dim, out_features=self.hid_dim)
        self.fc_2 = paddle.nn.Linear(in_features=self.hid_dim, out_features=self.hid_dim2)
        self.fc_prediction = paddle.nn.Linear(in_features=self.hid_dim2, out_features=self.num_labels)
        # use_softmax=False: forward() passes softmax probabilities, not logits.
        self.loss = paddle.nn.CrossEntropyLoss(use_softmax=False)

    def forward(self, fields_dict, phase):
        """Run the forward pass.

        :param fields_dict: inputs keyed by field name; tensors in dygraph mode,
            python arrays in static-graph mode
        :param phase: an ``InstanceName`` phase constant
        :return: dict whose contents depend on ``phase``:
            TRAINING / EVALUATE / TEST -> predictions, label and loss;
            SAVE_INFERENCE -> feed tensors, predict tensors and feed names for jit export;
            any other phase -> predictions only.
        """
        instance_text_a = fields_dict["text_a"]
        record_id_text_a = instance_text_a[InstanceName.RECORD_ID]
        text_src = record_id_text_a[InstanceName.SRC_IDS]

        emb_output = self.embedding(text_src)
        bow_output = self.bow_encoder(emb_output)
        fc_1_output = paddle.tanh(self.fc_1(bow_output))
        fc_2_output = paddle.tanh(self.fc_2(fc_1_output))
        prediction = self.fc_prediction(fc_2_output)
        probs = paddle.nn.functional.softmax(prediction)

        if phase in (InstanceName.TRAINING, InstanceName.EVALUATE, InstanceName.TEST):
            instance_label = fields_dict["label"]
            record_id_label = instance_label[InstanceName.RECORD_ID]
            label = record_id_label[InstanceName.SRC_IDS]
            cost = self.loss(probs, label)
            return {
                InstanceName.PREDICT_RESULT: probs,
                InstanceName.LABEL: label,
                InstanceName.LOSS: cost,
            }

        if phase == InstanceName.SAVE_INFERENCE:
            # Save the inference model with jit: feed text_a's src ids, predict probs.
            # The feed names are written as JSON into the model's meta file and used
            # for offline prediction, formatted as "field_name#field_tensor_name".
            return {
                InstanceName.TARGET_FEED: [text_src],
                InstanceName.TARGET_PREDICTS: [probs],
                InstanceName.TARGET_FEED_NAMES: ["text_a#src_ids"],
            }

        return {InstanceName.PREDICT_RESULT: probs}
如果您需要自行编写网络,例如希望将网络改为基于LSTM,只需修改该文件中的structure和forward函数即可。上述网络继承自BaseClassification,set_optimizer、get_metrics和fields_process函数均在该类中实现,具体可参考wenxin_appzoo/tasks/text_classification/model/base_cls.py,详见以下代码及核心部分的注释。
@RegisterSet.models.register
class BaseClassification(BaseModel):
    """Base class for text-classification models.

    Subclasses (e.g. ``BowClassification``) implement ``structure`` and
    ``forward``; this class supplies ``set_optimizer``, ``get_metrics`` and a
    pass-through ``fields_process``.
    """

    def __init__(self, model_params):
        """
        :param model_params: model configuration dict, forwarded to BaseModel
        """
        BaseModel.__init__(self, model_params)

    def structure(self):
        """Define the network layers; must be implemented by subclasses."""
        raise NotImplementedError

    def forward(self, fields_dict, phase):
        """Run the forward pass; must be implemented by subclasses."""
        raise NotImplementedError

    def set_optimizer(self):
        """Create an Adam optimizer over all parameters of this model.

        The learning rate is ``model_params['optimization']['learning_rate']``;
        it falls back to 2e-5 when the ``optimization`` section is missing or
        empty, or lacks a ``learning_rate`` key.

        :return: the optimizer (also stored on ``self.optimizer``)
        """
        opt_param = self.model_params.get('optimization', None)
        if opt_param:
            self.lr = opt_param.get('learning_rate', 2e-5)
        else:
            self.lr = 2e-5
        self.optimizer = paddle.optimizer.Adam(learning_rate=self.lr, parameters=self.parameters())
        return self.optimizer

    @staticmethod
    def _to_numpy(value):
        """Convert a tensor, or a list of tensors, to numpy via ``.numpy()``."""
        if isinstance(value, list):
            return [item.numpy() for item in value]
        return value.numpy()

    def get_metrics(self, forward_return_dict, meta_info, phase):
        """Compute accuracy and precision for a batch and log them.

        :param forward_return_dict: output of ``forward``; must contain
            ``InstanceName.PREDICT_RESULT`` and ``InstanceName.LABEL``, plus
            ``InstanceName.LOSS`` in the training phase
        :param meta_info: common meta information such as step, used_time,
            gpu_id; ``STEP`` and ``TIME_COST`` are read for logging
        :param phase: current phase, covering training and evaluation
        :return: OrderedDict with keys ``"acc"`` and ``"precision"``
        """
        import collections
        import logging

        import numpy as np

        predictions = forward_return_dict[InstanceName.PREDICT_RESULT]
        label = forward_return_dict[InstanceName.LABEL]
        if self.is_dygraph:
            # In dygraph mode these are paddle tensors; the metric objects get numpy data.
            predictions = self._to_numpy(predictions)
            label = self._to_numpy(label)

        # NOTE(review): `metrics` is not imported in this snippet; base_cls.py
        # presumably imports it from the wenxin metrics package — confirm.
        acc = metrics.Acc().eval([predictions, label])
        precision = metrics.Precision().eval([predictions, label])

        if phase == InstanceName.TRAINING:
            step = meta_info[InstanceName.STEP]
            time_cost = meta_info[InstanceName.TIME_COST]
            loss = forward_return_dict[InstanceName.LOSS]
            if isinstance(loss, paddle.Tensor):
                loss = loss.numpy()
            mean_loss = np.mean(loss)
            logging.info("phase = {0} loss = {1} acc = {2} precision = {3} step = {4} time_cost = {5}".format(
                phase, mean_loss, acc, precision, step, round(time_cost, 4)))
        if phase in (InstanceName.EVALUATE, InstanceName.TEST):
            time_cost = meta_info[InstanceName.TIME_COST]
            step = meta_info[InstanceName.STEP]
            logging.info("phase = {0} acc = {1} precision = {2} time_cost = {3} step = {4}".format(
                phase, acc, precision, round(time_cost, 4), step))

        metrics_return_dict = collections.OrderedDict()
        metrics_return_dict["acc"] = acc
        metrics_return_dict["precision"] = precision
        return metrics_return_dict

    def fields_process(self, fields_dict, phase):
        """Optionally post-process the serialized ids in ``fields_dict``.

        The default applies no processing and returns the fields unchanged, so
        it honours the documented return contract instead of returning None.

        :param fields_dict: serialized input fields
        :param phase: current phase
        :return: the processed fields (here, ``fields_dict`` itself)
        """
        return fields_dict