资讯 文档
技术能力
语音技术
文字识别
人脸与人体
图像技术
语言与知识
视频技术

注册机制

文心中的注册机制

简介

通过之前的介绍我们可以看出,如果要完成一个模型训练任务,我们需要一个Model对象用来搭建神经网络,一个Reader对象用来读取数据,一个Trainer对象用来对Model和Reader对象进行调度。在文心中,为了满足多种场景和需求,我们有很多个不同的Trainer类、Model类和Reader类,这些类的作用和调用方式都是一样的。所以如果能把这些类的选择、初始化、调度都统一到一个脚本里面,我们通过指定类名,就能构造出不同的实例对象,这样就能大大的节省开发成本。文心中目前已经通过Python的注册机制实现了这种能力,通过在json文件中配置不同的类名,调用统一的运行脚本即可完成不同任务的训练和预测。

技术原理

注册机制要实现的功能就是,通过一个类名字符串去实例化出一个对象,然后给该对象传入需要的参数进行初始化。其核心就是怎么把类名字符串和Python中module(type)引用关联起来。文心中注册机制的实现方式在./wenxin/common/register.py中,结合在json文件中配置各类参数,其具体原理为:

1.把需要实现映射的类文件导入(import)到Python的解释器环境中,这一步操作对应Python中常用的import 语法,部分实现如下所示:

def import_new_module(package_name, file_name):
    """import一个新的类
    :param package_name: 包名
    :param file_name: 文件名,不需要文件后缀
    :return:
    """
    try:
        if package_name != "":
            full_name = package_name + "." + file_name
        else:
            full_name = file_name
        importlib.import_module(full_name)

    except Exception:
        logging.error("error in import %s" % file_name)
        logging.error("traceback.format_exc():\n%s" % traceback.format_exc())

2.将步骤1中导入的类文件与唯一的可识别的字符串(类名)进行关联,文心在这一部分的操作通过在装饰器的方式在每个类被import的时候,以类名为key,自身import之后的module类型为value,存入了全局变量RegisterSet中,部分实现如下:

## 以一个Model类举例
@RegisterSet.models.register
class CnnClassification(Model):
    """CnnClassification
    """
    def __init__(self, model_params):
        Model.__init__(self, model_params)
        self.params = model_params
        
  .......
class RegisterSet(object):
    """RegisterSet"""
    field_reader = Register("field_reader")  ## 存储各种FieldReader
    data_set_reader = Register("data_set_reader")  ## 存储各种DataSetReader
    models = Register("models")  ## 存储各种Model
    tokenizer = Register("tokenizer")   ## 存储各种Tokenizer
    trainer = Register("trainer")  ## 存储各种Trainer
    .....
    
class Register(object):
    """Register"""

    def __init__(self, registry_name):
        self._dict = {}
        self._name = registry_name

    def __setitem__(self, key, value):
        if not callable(value):
            raise Exception("Value of a Registry must be a callable.")
        if key is None:
            key = value.__name__
        if key in self._dict:
            logging.warning("Key %s already in registry %s." % (key, self._name))
        self._dict[key] = value

    def register(self, param):
        """Decorator to register a function or class."""

        def decorator(key, value):
            """decorator"""
            self[key] = value
            return value

        if callable(param):
            # @reg.register
            return decorator(None, param)
        # @reg.register('alias')
        return lambda x: decorator(param, x)

3.解析json文件,拿到需要实例化的类名。

4.用步骤3中解析出的类名,在步骤2中所注册的全局变量RegisterSet中查找已经import成功的module,然后调用callable方法(Python系统方法,会调用_ new ()、 init _() )完成实例对象的构造和初始化。

....
## 从json中解析类名
model_name = params_dict.get("type")
## 从全局变量RegisterSet中查找已经import成功的module, 这里以查找Model类型的类举例。
model_class = RegisterSet.models.__getitem__(model_name)
## 调用callable方法
model = model_class(params_dict)

....
 

整体结构如下图所示:

bj-259058f6a109ce84fe469e526b20668fe5d0152d.png

使用方式

  • 使用范围

    目前文心中内置的可以通过json配置完成的类型包括:

    • field_reader:域数据相关类,位于wenxin/data/field_reader/
    • dataset_reader:数据集相关类,位于wenxin/data/data_set_reader/
    • model:神经网络相关类,位于wenxin/models/
    • trainer:训练调度相关类,位于wenxin/training/
    • tokenizer:分词相关类,位于wenxin/data/tokenizer/
  • 如何在文心中为新类别添加注册机制

    • 在./wenxin/common/register.py文件中查看当前的全局变量RegisterSet中是否已有对应类别的Register,如果没有则新初始化一个Register类型的成员变量,参考已有的Register即可。
    • 在你的Python文件中,找到新类的定义,在类名上方添加register装饰器。

      ## 添加register装饰器。
      @RegisterSet.models.register    
      class CnnClassification(Model):
          """CnnClassification
          """
          def __init__(self, model_params):
              Model.__init__(self, model_params)
              self.params = model_params
    • 在json中配置你的类名,在运行脚本中添加实例化的代码,如下所示:

      ....
      ## 从json中解析类名
      model_name = params_dict.get("type")
      ## 从全局变量RegisterSet中查找已经import成功的module, 这里以查找Model类型的类举例。
      model_class = RegisterSet.models.__getitem__(model_name)
      ## 调用callable方法
      model = model_class(params_dict)
      
      ....
       
上一篇
模型超参数自动搜索
下一篇
V2.1.1文生图开发工具