FieldReader

Base Class Definition

The base class of FieldReader is BaseFieldReader, defined in ./wenxin/data/field_reader/base_field_reader.py.

@RegisterSet.field_reader.register
class BaseFieldReader(object):
    """BaseFieldReader: 作用于field的reader,主要是定义data_loader的格式,完成id序列化和embedding的操作
    """
    def __init__(self, field_config):
        self.field_config = field_config
        self.tokenizer = None  # tokenizer used for word segmentation; must be set by each subclass
        self.token_embedding = None  # generates embedding vectors; must be set by each subclass

    def init_reader(self, dataset_type=InstanceName.TYPE_PY_READER):
        """ 初始化reader格式,两种模式,如果是py_reader模式的话,返回reader的shape、type、level;
        如果是data_loader模式,返回fluid.data数组
        :param dataset_type : dataset的类型,目前有两种:py_reader、data_loader, 默认是py_reader
        :return:
        """
        raise NotImplementedError

    def convert_texts_to_ids(self, batch_text):
        """ 明文序列化,转为数字id
        :param:batch_text
        :return: id_list
        """
        raise NotImplementedError

    def get_field_length(self):
        """获取当前这个field在进行了序列化之后,在field_id_list中占多少长度
        :return:
        """
        raise NotImplementedError

    def structure_fields_dict(self, fields_id, start_index, need_emb=True):
        """
        静态图调用的方法,生成一个dict, dict有两个key:id , emb. id对应的是pyreader读出来的各个field产出的id,emb对应的是各个field对应的embedding
        :param fields_id: pyreader输出的完整的id序列
        :param start_index:当前需要处理的field在field_id_list中的起始位置
        :param need_emb:是否需要embedding(预测过程中是不需要embedding的)
        :return:
        """
        raise NotImplementedError
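
Note the protocol implied by get_field_length and structure_fields_dict: the serialized outputs of all fields are concatenated into one flat field_id_list, and each reader consumes its own slice starting at start_index. The driver loop below is a minimal sketch of how a caller might walk that list; it is inferred from the method signatures above, not taken from the framework source.

def structure_all_fields(field_readers, fields_id, need_emb=True):
    """Hypothetical driver: let each reader slice its own span out of the flat id list."""
    records = []
    start_index = 0
    for reader in field_readers:
        records.append(reader.structure_fields_dict(fields_id, start_index, need_emb=need_emb))
        start_index += reader.get_field_length()  # advance past this field's span
    return records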

Core Functions

As the base class, BaseFieldReader leaves the following five core functions for users to implement in a subclass according to their own business scenario (a minimal subclass skeleton follows the list):

  • __init__(self, field_config)
  • init_reader(self, dataset_type=InstanceName.TYPE_PY_READER)
  • convert_texts_to_ids(self, batch_text)
  • get_field_length(self)
  • structure_fields_dict(self, fields_id, start_index, need_emb=True)
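
Below is a minimal sketch of such a subclass, for a hypothetical field of whitespace-separated integer ids. Everything here is illustrative: the class name, the two-tensor layout (a padded src_ids tensor plus a seq_lens tensor), and the hard-coded field length of 2 are assumptions for this example, and the imports reuse those from the ErnieTextFieldReader listing in the next section.

import numpy as np
import paddle
from ...common.register import RegisterSet
from ...common.rule import InstanceName
from .base_field_reader import BaseFieldReader


@RegisterSet.field_reader.register
class MySimpleFieldReader(BaseFieldReader):
    """Hypothetical reader for a field of whitespace-separated integer ids."""

    def __init__(self, field_config):
        BaseFieldReader.__init__(self, field_config=field_config)

    def init_reader(self, dataset_type=InstanceName.TYPE_PY_READER):
        # One [-1, -1] int64 tensor for src_ids, one [-1] int64 tensor for seq_lens.
        shape = [[-1, -1], [-1]]
        types = ['int64', 'int64']
        levels = [0, 0]
        feed_names = [self.field_config.name + "_" + InstanceName.SRC_IDS,
                      self.field_config.name + "_" + InstanceName.SEQ_LENS]
        if dataset_type == InstanceName.TYPE_DATA_LOADER:
            return [paddle.static.data(name=n, shape=s, dtype=t, lod_level=l)
                    for n, s, t, l in zip(feed_names, shape, types, levels)]
        return shape, types, levels

    def convert_texts_to_ids(self, batch_text):
        # Parse each text into an id list, then pad the batch to its longest sequence.
        batch_ids = [[int(tok) for tok in text.split(" ")] for text in batch_text]
        seq_lens = np.array([len(ids) for ids in batch_ids], dtype='int64')
        max_len = max(len(ids) for ids in batch_ids)
        padded = np.array([ids + [0] * (max_len - len(ids)) for ids in batch_ids], dtype='int64')
        return [padded, seq_lens]

    def get_field_length(self):
        return 2  # this field contributes two entries (src_ids, seq_lens) to field_id_list

    def structure_fields_dict(self, fields_id, start_index, need_emb=True):
        record_id_dict = {InstanceName.SRC_IDS: fields_id[start_index],
                          InstanceName.SEQ_LENS: fields_id[start_index + 1]}
        record_emb_dict = None
        if need_emb and self.token_embedding:
            record_emb_dict = self.token_embedding.get_token_embedding(record_id_dict)
        return {InstanceName.RECORD_ID: record_id_dict,
                InstanceName.RECORD_EMB: record_emb_dict}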

Custom Implementation Example

Taking ErnieTextFieldReader, which Wenxin currently provides for processing text fields in ERNIE tasks, as an example, the implementation of the five core BaseFieldReader functions is explained below; see the following code and the comments on its core parts.

# -*- coding: utf-8 -*-
"""
:py:class:`ErnieTextFieldReader`

"""
import paddle
# import logging
from paddle import fluid
from ...common.register import RegisterSet
from ...common.rule import DataShape, FieldLength, InstanceName
from .base_field_reader import BaseFieldReader
from ..util_helper import pad_batch_data, get_random_pos_id
# from wenxin.modules.token_embedding.ernie_embedding import ErnieTokenEmbedding
from ...utils.util_helper import truncation_words


@RegisterSet.field_reader.register
class ErnieTextFieldReader(BaseFieldReader):
    """使用ernie的文本类型的field_reader,用户不需要自己分词
        处理规则是:自动添加padding,mask,position,task,sentence,并返回length
        """
    def __init__(self, field_config):
        """
        :param field_config:
        """
        BaseFieldReader.__init__(self, field_config=field_config)

        if self.field_config.tokenizer_info:
            tokenizer_class = RegisterSet.tokenizer.__getitem__(self.field_config.tokenizer_info["type"])
            params = None
            if self.field_config.tokenizer_info.__contains__("params"):
                params = self.field_config.tokenizer_info["params"]
            self.tokenizer = tokenizer_class(vocab_file=self.field_config.vocab_path,
                                             split_char=self.field_config.tokenizer_info["split_char"],
                                             unk_token=self.field_config.tokenizer_info["unk_token"],
                                             params=params)

     
    def init_reader(self, dataset_type=InstanceName.TYPE_PY_READER):
        """ 初始化reader格式,两种模式,如果是py_reader模式的话,返回reader的shape、type、level;
        如果是data_loader模式,返回fluid.data数组
        :param dataset_type : dataset的类型,目前有两种:py_reader、data_loader, 默认是py_reader
        :return:
        """

        shape = []
        types = []
        levels = []
        feed_names = []
        data_list = []

        if self.field_config.data_type == DataShape.STRING:
            """src_ids"""
            shape.append([-1, -1])
            levels.append(0)
            types.append('int64')
            feed_names.append(self.field_config.name + "_" + InstanceName.SRC_IDS)
        else:
            raise TypeError("ErnieTextFieldReader's data_type must be string")

        """sentence_ids"""
        shape.append([-1, -1])
        levels.append(0)
        types.append('int64')
        feed_names.append(self.field_config.name + "_" + InstanceName.SENTENCE_IDS)

        """position_ids"""
        shape.append([-1, -1])
        levels.append(0)
        types.append('int64')
        feed_names.append(self.field_config.name + "_" + InstanceName.POS_IDS)

        """mask_ids"""
        shape.append([-1, -1])
        levels.append(0)
        types.append('float32')
        feed_names.append(self.field_config.name + "_" + InstanceName.MASK_IDS)

        """task_ids"""
        shape.append([-1, -1])
        levels.append(0)
        types.append('int64')
        feed_names.append(self.field_config.name + "_" + InstanceName.TASK_IDS)

        """seq_lens"""
        shape.append([-1])
        levels.append(0)
        types.append('int64')
        feed_names.append(self.field_config.name + "_" + InstanceName.SEQ_LENS)

        if dataset_type == InstanceName.TYPE_DATA_LOADER:
            for i in range(len(feed_names)):
                data_list.append(paddle.static.data(name=feed_names[i], shape=shape[i],
                                                    dtype=types[i], lod_level=levels[i]))
            return data_list
        else:
            return shape, types, levels

    def convert_texts_to_ids(self, batch_text, use_random_pos=False, max_pos_id=2048):
        """将一个batch的明文text转成id
        :param batch_text:
        :return:
        """
        src_ids = []
        position_ids = []
        task_ids = []
        sentence_ids = []
        for text in batch_text:
            if self.field_config.need_convert:
                tokens_text = self.tokenizer.tokenize(text)
                # apply the truncation policy
                if len(tokens_text) > self.field_config.max_seq_len - 2:
                    tokens_text = truncation_words(tokens_text, self.field_config.max_seq_len - 2,
                                                   self.field_config.truncation_type)
                tokens = []
                tokens.append("[CLS]")
                for token in tokens_text:
                    tokens.append(token)
                tokens.append("[SEP]")
                src_id = self.tokenizer.convert_tokens_to_ids(tokens)
            else:
                if isinstance(text, str):
                    text = text.split(" ")
                src_id = [int(i) for i in text]
                if len(src_id) > self.field_config.max_seq_len - 2:
                    src_id = truncation_words(src_id, self.field_config.max_seq_len - 2,
                                              self.field_config.truncation_type)
                src_id.insert(0, self.tokenizer.covert_token_to_id("[CLS]"))
                src_id.append(self.tokenizer.covert_token_to_id("[SEP]"))

            src_ids.append(src_id)
            pos_id = list(range(len(src_id)))
            task_id = [0] * len(src_id)
            sentence_id = [0] * len(src_id)
            position_ids.append(pos_id)
            task_ids.append(task_id)
            sentence_ids.append(sentence_id)

        return_list = []
        if use_random_pos:
            position_ids = get_random_pos_id(position_ids, max_pos_id)
        padded_ids, input_mask, batch_seq_lens = pad_batch_data(src_ids,
                                                                pad_idx=self.field_config.padding_id,
                                                                return_input_mask=True,
                                                                return_seq_lens=True)
        sent_ids_batch = pad_batch_data(sentence_ids, pad_idx=self.field_config.padding_id)
        pos_ids_batch = pad_batch_data(position_ids, pad_idx=self.field_config.padding_id)
        task_ids_batch = pad_batch_data(task_ids, pad_idx=self.field_config.padding_id)

        return_list.append(padded_ids)  # append src_ids
        return_list.append(sent_ids_batch)  # append sent_ids
        return_list.append(pos_ids_batch)  # append pos_ids
        return_list.append(input_mask)  # append mask
        return_list.append(task_ids_batch)  # append task_ids
        return_list.append(batch_seq_lens)  # append seq_lens

        return return_list

    def structure_fields_dict(self, fields_id, start_index, need_emb=True):
        """静态图调用的方法,生成一个dict, dict有两个key:id , emb. id对应的是pyreader读出来的各个field产出的id,emb对应的是各个
        field对应的embedding
        :param fields_id: pyreader输出的完整的id序列
        :param start_index:当前需要处理的field在field_id_list中的起始位置
        :param need_emb:是否需要embedding(预测过程中是不需要embedding的)
        :return:
        """
        record_id_dict = {}
        record_id_dict[InstanceName.SRC_IDS] = fields_id[start_index]
        record_id_dict[InstanceName.SENTENCE_IDS] = fields_id[start_index + 1]
        record_id_dict[InstanceName.POS_IDS] = fields_id[start_index + 2]
        record_id_dict[InstanceName.MASK_IDS] = fields_id[start_index + 3]
        record_id_dict[InstanceName.TASK_IDS] = fields_id[start_index + 4]
        record_id_dict[InstanceName.SEQ_LENS] = fields_id[start_index + 5]

        record_emb_dict = None
        if need_emb and self.token_embedding:
            record_emb_dict = self.token_embedding.get_token_embedding(record_id_dict)

        record_dict = {}
        record_dict[InstanceName.RECORD_ID] = record_id_dict
        record_dict[InstanceName.RECORD_EMB] = record_emb_dict

        return record_dict

    def get_field_length(self):
        """获取当前这个field在进行了序列化之后,在field_id_list中占多少长度
        :return:
        """
        return FieldLength.ERNIE_TEXT_FIELD
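
Usage might look like the following sketch. SimpleNamespace stands in for the framework's real field-config object, and every attribute value below (vocab path, tokenizer name, truncation type, etc.) is an illustrative assumption rather than a documented default; the snippet also assumes the module-level imports from the listing above (DataShape, ErnieTextFieldReader).

from types import SimpleNamespace

# Hypothetical field config; the attribute names follow the fields referenced in
# ErnieTextFieldReader above, but the values are illustrative assumptions.
field_config = SimpleNamespace(
    name="text_a",
    data_type=DataShape.STRING,
    vocab_path="./vocab.txt",                 # hypothetical vocab file
    tokenizer_info={"type": "FullTokenizer",  # hypothetical registered tokenizer name
                    "split_char": " ",
                    "unk_token": "[UNK]"},
    need_convert=True,
    max_seq_len=128,
    truncation_type=0,
    padding_id=0,
)

reader = ErnieTextFieldReader(field_config=field_config)

# convert_texts_to_ids returns the six arrays declared in init_reader, in order:
# src_ids, sentence_ids, position_ids, mask_ids, task_ids, seq_lens.
src_ids, sent_ids, pos_ids, mask_ids, task_ids, seq_lens = \
    reader.convert_texts_to_ids(["这是第一条文本", "第二条"])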