注册机制
文心中的注册机制
简介
通过之前的介绍我们可以看出,如果要完成一个模型训练任务,我们需要一个Model对象用来搭建神经网络,一个Reader对象用来读取数据,一个Trainer对象用来对Model和Reader对象进行调度。在文心中,为了满足多种场景和需求,我们有很多个不同的Trainer类、Model类和Reader类,这些类的作用和调用方式都是一样的。所以如果能把这些类的选择、初始化、调度都统一到一个脚本里面,我们通过指定类名,就能构造出不同的实例对象,这样就能大大的节省开发成本。文心中目前已经通过Python的注册机制实现了这种能力,通过在json文件中配置不同的类名,调用统一的运行脚本即可完成不同任务的训练和预测。
技术原理
注册机制要实现的功能就是,通过一个类名字符串去实例化出一个对象,然后给该对象传入需要的参数进行初始化。其核心就是怎么把类名字符串和Python中module(type)引用关联起来。文心中注册机制的实现方式在./wenxin/common/register.py中,结合在json文件中配置各类参数,其具体原理为:
1.把需要实现映射的类文件导入(import)到Python的解释器环境中,这一步操作对应Python中常用的import 语法,部分实现如下所示:
def import_new_module(package_name, file_name):
"""import一个新的类
:param package_name: 包名
:param file_name: 文件名,不需要文件后缀
:return:
"""
try:
if package_name != "":
full_name = package_name + "." + file_name
else:
full_name = file_name
importlib.import_module(full_name)
except Exception:
logging.error("error in import %s" % file_name)
logging.error("traceback.format_exc():\n%s" % traceback.format_exc())
2.将步骤1中导入的类文件与唯一的可识别的字符串(类名)进行关联,文心在这一部分的操作通过在装饰器的方式在每个类被import的时候,以类名为key,自身import之后的module类型为value,存入了全局变量RegisterSet中,部分实现如下:
## 以一个Model类举例
@RegisterSet.models.register
class CnnClassification(Model):
"""CnnClassification
"""
def __init__(self, model_params):
Model.__init__(self, model_params)
self.params = model_params
.......
class RegisterSet(object):
"""RegisterSet"""
field_reader = Register("field_reader") ## 存储各种FieldReader
data_set_reader = Register("data_set_reader") ## 存储各种DataSetReader
models = Register("models") ## 存储各种Model
tokenizer = Register("tokenizer") ## 存储各种Tokenizer
trainer = Register("trainer") ## 存储各种Trainer
.....
class Register(object):
"""Register"""
def __init__(self, registry_name):
self._dict = {}
self._name = registry_name
def __setitem__(self, key, value):
if not callable(value):
raise Exception("Value of a Registry must be a callable.")
if key is None:
key = value.__name__
if key in self._dict:
logging.warning("Key %s already in registry %s." % (key, self._name))
self._dict[key] = value
def register(self, param):
"""Decorator to register a function or class."""
def decorator(key, value):
"""decorator"""
self[key] = value
return value
if callable(param):
# @reg.register
return decorator(None, param)
# @reg.register('alias')
return lambda x: decorator(param, x)
3.解析json文件,拿到需要实例化的类名。
4.用步骤3中解析出的类名,在步骤2中所注册的全局变量RegisterSet中查找已经import成功的module,然后调用callable方法(Python系统方法,会调用_ new ()、 init _() )完成实例对象的构造和初始化。
....
## 从json中解析类名
model_name = params_dict.get("type")
## 从全局变量RegisterSet中查找已经import成功的module, 这里以查找Model类型的类举例。
model_class = RegisterSet.models.__getitem__(model_name)
## 调用callable方法
model = model_class(params_dict)
....
整体结构如下图所示:
使用方式
-
使用范围
目前文心中内置的可以通过json配置完成的类型包括:
- field_reader:域数据相关类,位于wenxin/data/field_reader/
- dataset_reader:数据集相关类,位于wenxin/data/data_set_reader/
- model:神经网络相关类,位于wenxin/models/
- trainer:训练调度相关类,位于wenxin/training/
- tokenizer:分词相关类,位于wenxin/data/tokenizer/
-
如何在文心中为新类别添加注册机制
- 在./wenxin/common/register.py文件中查看当前的全局变量RegisterSet中是否已有对应类别的Register,如果没有则新初始化一个Register类型的成员变量,参考已有的Register即可。
-
在你的Python文件中,找到新类的定义,在类名上方添加register装饰器。
## 添加register装饰器。 @RegisterSet.models.register class CnnClassification(Model): """CnnClassification """ def __init__(self, model_params): Model.__init__(self, model_params) self.params = model_params
-
在json中配置你的类名,在运行脚本中添加实例化的代码,如下所示:
.... ## 从json中解析类名 model_name = params_dict.get("type") ## 从全局变量RegisterSet中查找已经import成功的module, 这里以查找Model类型的类举例。 model_class = RegisterSet.models.__getitem__(model_name) ## 调用callable方法 model = model_class(params_dict) ....