Trainer
更新时间:2022-12-17
Trainer
基类定义
Trainer的两个基类定义在./wenxin/controller/中,动态图trainer的基类为BaseDynamicTrainer,静态图trainer的基类为BaseStaticTrainer。
BaseDynamicTrainer
- 基本定义:BaseDynamicTrainer作为基类,将训练过程中常用的操作(如运行时环境初始化、网络结构初始化、模型保存等)已经统一封装实现,不需要在子类中重新实现,其基本定义如下:
class BaseDynamicTrainer(object):
    """Base class for dynamic-graph trainers.

    Operations that every training run needs (runtime environment setup,
    network construction, model saving, ...) are implemented once here so
    subclasses do not have to re-implement them.
    """

    def __init__(self, params, data_set_reader, model_class):
        """
        :param params: run-time parameter dict
        :param data_set_reader: reader providing the train/dev/test data sets
        :param model_class: the model (network) instance to train
        """
        ...
        # Build the network, warm-start from pre-trained parameters if any,
        # then let the model create its own optimizer.
        self.model_class.structure()
        self.load_pretrain_model()
        self.optimizer = self.model_class.set_optimizer()
        ...

    def do_train(self):
        """Run the training loop over the data set.

        Must be implemented by each task; WenXin already ships a generic
        CustomDynamicTrainer in every task that covers most training needs.
        """
        raise NotImplementedError

    def do_evaluate(self, reader, phase, step):
        """Evaluate the model.

        Must be implemented by each task; WenXin already ships a generic
        CustomDynamicTrainer in every task that covers most training needs.

        :param reader: data set reader to evaluate on
        :param phase: current run phase
        :param step: current training step
        :return: loss
        """
        raise NotImplementedError

    def save_models(self, step, fields_dict, save_checkpoint=True, save_inference=True):
        """Save the model: checkpoint files are used for warm starts,
        inference files for prediction. Implemented in the base class;
        users do not need to override it.

        :param step: current training step
        :param fields_dict: example fields used when exporting for inference
        :param save_checkpoint: whether to save a checkpoint
        :param save_inference: whether to save an inference model
        :return:
        """

    def load_pretrain_model(self):
        """Load pre-trained model parameters or warm-start parameters."""
        ...
-
核心函数
- do_train(self):模型训练函数
- do_evaluate(self, reader, phase, step): 模型评估函数
-
自定义实现
以wenxin_appzoo/tasks/text_classification分类任务为例。
def do_train(self):
    """Run the training loop over the training data set.

    :return:
    """
    dg = self.data_set_reader.train_reader
    steps = 1
    # Mixed-precision (AMP) settings. The 'optimization' section is optional,
    # so fall back to an empty dict instead of crashing on .get() of None.
    opt_params = self.original_model.model_params.get('optimization', None) or {}
    init_loss_scaling = opt_params.get("init_loss_scaling", 1.0)
    incr_every_n_steps = opt_params.get("incr_every_n_steps", 1000)
    decr_every_n_nan_or_inf = opt_params.get("decr_every_n_nan_or_inf", 2)
    incr_ratio = opt_params.get("incr_ratio", 2.0)
    decr_ratio = opt_params.get("decr_ratio", 0.8)
    if self.use_amp:
        self.scaler = paddle.amp.GradScaler(enable=self.use_amp,
                                            init_loss_scaling=init_loss_scaling,
                                            incr_ratio=incr_ratio,
                                            decr_ratio=decr_ratio,
                                            incr_every_n_steps=incr_every_n_steps,
                                            decr_every_n_nan_or_inf=decr_every_n_nan_or_inf)
        if self.multi_devices:
            self.scaler = fleet.distributed_scaler(self.scaler)
    time_begin = time.time()
    # Iterate over the training data set.
    for batch_id, data in enumerate(dg()):
        self.model_class.train()
        with paddle.amp.auto_cast(enable=self.use_amp):
            example = self.data_set_reader.train_reader.dataset.convert_fields_to_dict(data, need_emb=False)
            forward_out = self.model_class(example, phase=InstanceName.TRAINING)
            loss = forward_out[InstanceName.LOSS]
        if self.use_amp:
            # Scale the loss, back-propagate, then let the scaler unscale the
            # gradients and apply the update. NOTE: scaler.minimize() already
            # performs the optimizer step, so no extra optimizer.step() here
            # (the original called both, updating the parameters twice).
            loss = self.scaler.scale(loss)
            loss.backward()
            self.scaler.minimize(self.optimizer, loss)
        else:
            loss.backward()
            self.optimizer.step()
        if self.original_model.lr_scheduler:
            cur_lr = self.original_model.lr_scheduler.get_lr()
            self.original_model.lr_scheduler.step()
        else:
            cur_lr = self.original_model.lr
        # Clear gradients exactly once per step (the original cleared them
        # twice: model_class.clear_gradients() and optimizer.clear_grad()).
        self.optimizer.clear_grad()
        # Log training metrics at the configured interval.
        if steps % self.params["train_log_step"] == 0:
            # BUG FIX: meta_info was used but never built in the original;
            # construct it the same way the static trainer does.
            meta_info = collections.OrderedDict()
            meta_info[InstanceName.STEP] = steps
            meta_info[InstanceName.TIME_COST] = time.time() - time_begin
            metrics_output = self.original_model.get_metrics(forward_out, meta_info, InstanceName.TRAINING)
            time_begin = time.time()
        # Evaluate on the dev/test sets at the configured interval.
        if steps % self.params["eval_step"] == 0:
            if self.params["is_eval_dev"]:
                self.do_evaluate(self.data_set_reader.dev_reader, InstanceName.EVALUATE, steps)
            if self.params["is_eval_test"]:
                self.do_evaluate(self.data_set_reader.test_reader, InstanceName.TEST, steps)
        # Save the model at the configured interval (worker 0 only).
        if steps % self.params["save_model_step"] == 0 and self.worker_index == 0:
            self.save_models(steps, example)
        steps += 1
def do_evaluate(self, reader, phase, step):
    """Evaluate the model on one data set.

    :param reader: data set reader to evaluate on
    :param phase: current run phase
    :param step: current training step; note the original immediately
        overwrites it with a batch counter, and that counter is what gets
        reported (behavior preserved here)
    :return: loss
    """
    step = 0
    with paddle.no_grad():
        start = time.time()
        # Switch to eval mode before the forward passes.
        self.model_class.eval()
        fetch_output_dict = collections.OrderedDict()
        for batch_id, data in enumerate(reader()):
            step += 1
            example = reader.dataset.convert_fields_to_dict(data, need_emb=False)
            forward_out = self.model_class(example, phase=phase)
            # Accumulate every fetched tensor per key across batches.
            for key, value in forward_out.items():
                fetch_output_dict.setdefault(key, []).append(value)
        meta_info = collections.OrderedDict()
        meta_info[InstanceName.STEP] = step
        meta_info[InstanceName.TIME_COST] = time.time() - start
        metrics_output = self.original_model.get_metrics(fetch_output_dict, meta_info, phase)
        # Restore train mode for the caller.
        self.model_class.train()
        logging.info("eval step = {0}".format(step))
BaseStaticTrainer
- 基本定义:BaseStaticTrainer作为基类,将训练过程中常用的操作(如运行时环境初始化、网络结构初始化、模型保存等)已经统一封装实现,不需要在子类中重新实现,其基本定义如下:
....
class BaseStaticTrainer(object):
    """Base class for static-graph trainers."""

    def __init__(self, params, data_set_reader, model_class):
        """
        1. init runtime environment  2. init programs  3. build the network
        4. load model parameters     5. run readers    6. export the model

        :param params: run-time parameter dict
        :param data_set_reader: reader providing the train/dev/test data sets
        :param model_class: the model (network) instance to train
        """
        self.params = params
        self.data_set_reader = data_set_reader
        # BUG FIX: the original assigned the undefined name `model`
        # instead of the `model_class` parameter.
        self.model_class = model_class
        # Run in static-graph mode.
        self.enable_static = True
        self.is_recompute = self.params.get("is_recompute", 0)
        if 'output_path' in self.params.keys() and self.params["output_path"]:
            self.save_checkpoints_path = os.path.join(self.params["output_path"], "save_checkpoints")
            self.save_inference_model_path = os.path.join(self.params["output_path"], "save_inference_model")
        else:
            self.save_checkpoints_path = "./output/save_checkpoints/"
            self.save_inference_model_path = "./output/save_inference_model/"
        self.forward_train_output = {}
        self.fetch_list_train = []
        self.fetch_list_evaluate = []
        self.fetch_list_train_key = []
        self.fetch_list_evaluate_key = []
        self.parser_meta()
        self.use_fleet = False
        self.init_env_static()
        ...

    def init_env_static(self):
        """Initialize the static-graph runtime environment: program,
        executor, fleet, cuda, place. Implemented in the base class;
        users do not need to override it.
        """
        ...

    def init_static_model_net(self):
        """Build the network. Implemented in the base class; users do not
        need to override it.
        """
        ...

    def do_train(self):
        """Run the training loop (evaluation may happen during training).
        Must be implemented by each task; WenXin already ships a generic
        CustomTrainer in every task that covers most training needs.

        :return:
        """
        raise NotImplementedError

    def do_evaluate(self, reader, phase, step):
        """Evaluate the model. Must be implemented by each task; WenXin
        already ships a generic CustomTrainer covering most training needs.

        :param reader: data set reader to evaluate on
        :param phase: current run phase
        :param step: current training step
        :return: loss
        """
        raise NotImplementedError

    def save_models(self, steps, save_checkpoint=True, save_inference=True):
        """Save the model. Implemented in the base class; users do not
        need to override it.

        :param steps: current training step
        :param save_checkpoint: whether to save a checkpoint
        :param save_inference: whether to save an inference model
        :return:
        """
        ...
-
核心函数:需要用户按自己业务场景在子类中自定义的核心函数为以下两个:
- do_train:模型训练函数
- do_evaluate:模型评估函数
- 自定义实现示例:以wenxin_appzoo/tasks/text_classification分类任务为例。
def do_train(self):
    """Run the training loop over the training data set.

    :return:
    """
    ...
    dg = self.data_set_reader.train_reader
    steps = 1
    time_begin = time.time()
    # Iterate over the training data set.
    for batch_id, data in enumerate(dg()):
        feed_dict = self.data_set_reader.train_reader.dataset.convert_input_list_to_dict(data)
        if steps % self.params["train_log_step"] != 0:
            # run() executes the network once, i.e. on one batch. When no
            # outputs are needed pass an empty fetch_list; otherwise pass the
            # tensors to fetch. fetch_list_train is used directly here; it is
            # determined by the return value of the model's forward().
            if self.use_fleet:
                self.train_exe.run(program=self.train_program, feed=feed_dict, fetch_list=[], return_numpy=True)
            else:
                self.train_exe.run(feed=feed_dict, fetch_list=[], return_numpy=True)
        else:
            if self.use_fleet:
                fetch_output = self.train_exe.run(program=self.train_program,
                                                  feed=feed_dict,
                                                  fetch_list=self.fetch_list_train,
                                                  return_numpy=True)
            else:
                fetch_output = self.train_exe.run(feed=feed_dict,
                                                  fetch_list=self.fetch_list_train,
                                                  return_numpy=True)
            fetch_output_dict = collections.OrderedDict()
            for key, value in zip(self.fetch_list_train_key, fetch_output):
                if key == InstanceName.LOSS and not self.return_numpy:
                    value = np.array(value)
                fetch_output_dict[key] = value
            used_time = time.time() - time_begin
            meta_info = collections.OrderedDict()
            meta_info[InstanceName.STEP] = steps
            meta_info[InstanceName.GPU_ID] = self.gpu_id
            meta_info[InstanceName.TIME_COST] = used_time
            # Compute metrics via the model's get_metrics().
            metrics_output = self.model_class.get_metrics(fetch_output_dict, meta_info, InstanceName.TRAINING)
        if self.model_class.lr_scheduler:
            # Required: without this step() the lr_scheduler never takes
            # effect and the learning rate stays at 0.
            self.model_class.lr_scheduler.step()
        # Evaluate on the dev/test sets at the configured interval.
        if steps % self.params["eval_step"] == 0:
            if self.params["is_eval_dev"]:
                self.do_evaluate(self.data_set_reader.dev_reader, InstanceName.EVALUATE, steps)
            if self.params["is_eval_test"]:
                self.do_evaluate(self.data_set_reader.test_reader, InstanceName.TEST, steps)
        if self.trainer_id == 0:
            # Save the model at the configured interval.
            if steps % self.params["save_model_step"] == 0:
                # BUG FIX: the original called the undefined self.save_model();
                # the base class defines save_models().
                self.save_models(steps)
        steps += 1
    if self.params["is_eval_dev"]:
        logging.info("Final evaluate result: ")
        metrics_output = self.do_evaluate(self.data_set_reader.dev_reader, InstanceName.EVALUATE, steps)
        self.eval_metrics = metrics_output
    if self.params["is_eval_test"]:
        logging.info("Final test result: ")
        self.do_evaluate(self.data_set_reader.test_reader, InstanceName.TEST, steps)
    if self.trainer_id == 0:
        self.save_models(steps)
def do_evaluate(self, reader, phase, step):
    """Evaluate the current model state on one data set.

    :param reader: data set to evaluate on
    :param phase: current run phase
    :param step: current training step
    """
    # Per-fetch-key lists of tensors collected over the whole data set.
    all_metrics_tensor_value = []
    time_begin = time.time()
    # Iterate over the data set.
    for batch_id, data in enumerate(reader()):
        feed_dict = reader.dataset.convert_input_list_to_dict(data)
        # Run the network on one batch and fetch the evaluation tensors.
        metrics_tensor_value = self.executor.run(program=self.test_program,
                                                 feed=feed_dict,
                                                 fetch_list=self.fetch_list_evaluate,
                                                 return_numpy=True)
        # Accumulate the batch results into all_metrics_tensor_value.
        if batch_id == 0:
            all_metrics_tensor_value = [[tensor] for tensor in metrics_tensor_value]
        else:
            for j, tensor in enumerate(metrics_tensor_value):
                all_metrics_tensor_value[j].append(tensor)
    fetch_output_dict = collections.OrderedDict()
    # Convert the collected results into the dict structure expected by the
    # model's get_metrics(). This part is customizable: instead of calling
    # model.get_metrics() you may compute your own metrics right here.
    for key, value in zip(self.fetch_list_evaluate_key, all_metrics_tensor_value):
        if key == InstanceName.LOSS and not self.return_numpy:
            value = [np.array(item) for item in value]
        fetch_output_dict[key] = value
    meta_info = collections.OrderedDict()
    meta_info[InstanceName.STEP] = step
    meta_info[InstanceName.GPU_ID] = self.gpu_id
    meta_info[InstanceName.TIME_COST] = time.time() - time_begin
    metrics_output = self.model_class.get_metrics(fetch_output_dict, meta_info, phase)
    return metrics_output