资讯 文档
技术能力
语音技术
文字识别
人脸与人体
图像技术
语言与知识
视频技术

DataSetReader

基类定义

DataSetReader的基类是BaseDataLoader,定义在./wenxin/data/data_set_reader/base_dataset_dataloader.py中。

...
@RegisterSet.data_set_reader.register
class BaseDataLoader(object):
    """BaseDataLoader:将样本中数据组装成一个data_loader, 向外提供一个统一的接口。
    核心内容是读取明文文件,转换成id,按data_loader需要的tensor格式灌进去,然后通过调用run方法让整个循环跑起来。data_loader拿出的来的是lod-tensor形式的id,这些id可以用来做后面的embedding等计算。
    """

    def __init__(self, name, fields, config):
        self.name = name
        self.fields = fields
        self.config = config  # 常用参数,batch_size等,ReaderConfig类型变量
        self.paddle_data_loader = None
        self.input_data_list = []
        self.current_example = 0
        self.current_epoch = 0
        self.num_examples = 0
        # 迭代器生成数据的时候是否需要生成明文样本,目前来看,训练的时候不需要,预测的时候需要
        self.need_generate_examples = False

        if config.extra_params:
            self.iterable = config.extra_params.get("iterable", False)
            self.use_cuda = config.extra_params.get("use_cuda", False)
            self.use_data_parallel = config.extra_params.get("use_data_parallel", False)
            self.return_list = config.extra_params.get("return_list", False)
            self.use_multiprocess = config.extra_params.get("use_multiprocess", False)
        else:
            self.iterable = False
            self.use_cuda = False
            self.use_data_parallel = False
            self.return_list = False
            self.use_multiprocess = False

    def create_reader(self):
        """
        子类必须实现的接口,否则会抛出异常。
        用于初始化self.paddle_data_loader。
        :return:None
        """
        raise NotImplementedError

    def instance_fields_dict(self):
        """
        子类必须实现的接口,否则会抛出异常。
        实例化fields_dict, 得到fields_id, 视情况构造embedding,然后结构化成dict类型返回给组网部分。
        :return:dict
                {"field_name":
                    {"RECORD_ID":
                        {"SRC_IDS": [ids],
                         "MASK_IDS": [ids],
                         "SEQ_LENS": [ids]
                        }
                    }
                }
        实例化的dict,保存了各个field的id和embedding(可以没有,是情况而定), 给trainer用.

        """
        raise NotImplementedError

    def data_generator(self):
        """
        子类必须实现的接口,否则会抛出异常。
        数据生成器:读取明文文件,生成batch化的id数据,绑定到py_reader中
        :return:list
                [[src_ids],
                 [mask_ids],
                 [seq_lens]
                ]
        """
        raise NotImplementedError

    def convert_fields_to_dict(self, field_list, need_emb=False):
        """instance_fields_dict一般调用本方法实例化fields_dict,保存各个field的id和embedding(可以没有,是情况而定),
        当need_emb=False的时候,可以直接给predictor调用
        :param field_list:
        :param need_emb:
        :return: dict
        """
        raise NotImplementedError

    def run(self):
        """
        配置data_loader对应的数据生成器,并启动。该方法由基类定义,子类不需要自定义,直接调用即可。
        :return:
        """
        if self.paddle_data_loader:
            if self.iterable:
                # 若DataLoader可迭代,则必须设置places参数
                if self.use_data_parallel:
                    # 若进行多GPU卡训练,则取所有的CUDAPlace
                    # 若进行多CPU核训练,则取多个CPUPlace,这里以2个为例 # TODO: 这里需要从环境变量获取
                    places = paddle.CUDAPlace() if self.use_cuda else paddle.CPUPlace(2)
                else:
                    # 若进行单GPU卡训练,则取单个CUDAPlace,本例中0代表0号GPU卡
                    # 若进行单CPU核训练,则取单个CPUPlace,本例中1代表1个CPUPlace
                    gpus = os.getenv('FLAGS_selected_gpus', '0').split(",")
                    id = int(gpus[0])
                    places = paddle.CUDAPlace(int(gpus[0])) if self.use_cuda else paddle.CPUPlace(1)
                self.paddle_data_loader.set_batch_generator(self.data_generator(), places=places)
                logging.info("set data_loader's generator with iterable.......")
            else:
                # 若DataLoader不可迭代,则不需要设置places参数
                places = None
                self.paddle_data_loader.set_batch_generator(self.data_generator(), places=places)
                self.paddle_data_loader.start()
                logging.info("set data_loader's generator and start.......")

        else:
            raise ValueError("paddle_data_loader is None")

    def stop(self):
        """
        停止一个data_loader,该方法由基类定义,子类不需要自定义,直接调用即可。
        :return:
        """
        if self.paddle_data_loader:
            self.paddle_data_loader.reset()
        else:
            raise ValueError("paddle_data_loader is None")

    def get_train_progress(self):
        """
        获取当前reader的读取进度
        """
        return self.current_example, self.current_epoch

    def get_num_examples(self):
        """
        获取当前reader对应的数据集的样本总数,视情况由子类实现
        """
        return self.num_examples
    
    def api_generator(self, query_list):
        """
        对单条query进行id化,一般在预测api接口调用的情况下使用,基类实现了最基础的版本,其他视情况由子类实现。
        """
        if len(query_list) <= 0:
            raise ValueError("query can't be None")

        field_names = []
        for filed in self.fields:
            field_names.append(filed.name)

        Example = namedtuple('Example', field_names)
        examples = []
        for query in query_list:
            examples.append(Example(**query))

        for batch_data in self.prepare_batch_data(examples, self.config.batch_size):
            yield batch_data

核心函数

BaseDataSetReader做为基类,需要用户按自己业务场景在子类中自定义的核心函数为以下3个:

  • create_reader(self):
  • instance_fields_dict(self):
  • data_generator(self):

自定义实现示例

以文心目前提供的能够覆盖绝大部分任务的BasicDataSetReader为例,解释BaseDataSetReader中3个核心函数的实现,详见以下代码及核心部分的注释。

@RegisterSet.data_set_reader.register
class BasicDataSetReader(BaseDataLoader):
    """BasicDataSetReader:一个基础的data_set_reader,实现了文件读取,id序列化,token embedding化等基本操作
    """

    def __init__(self, name, fields, config):
        """__init__
        """
        BaseDataLoader.__init__(self, name, fields, config)

    def create_reader(self):
        """ 对基类接口的实现,初始化paddle_data_loader,必须要初始化,否则会抛出异常
        :return:
        """
        if not self.fields:
            raise ValueError("fields can't be None")

        for item in self.fields:
            if not item.field_reader:
                raise ValueError("{0}'s field_reader is None".format(item.name))
            # step1:分别获取当前数据集样本中每个域的定义,拿到各个域的数据类型、形状、lod_level,构造fluid.layers.data类型变量
                self.input_data_list.extend(item.field_reader.init_reader(dataset_type=InstanceName.TYPE_DATA_LOADER))
				
        # step2: 由step1中初始化好的变量对data_loader进行初始化。
        self.paddle_data_loader = paddle.io.DataLoader.from_generator(
            feed_list=self.input_data_list,
            capacity=64,
            iterable=self.iterable,
            return_list=self.return_list)

    def instance_fields_dict(self):
        """对基类接口的实现
        将输入进来的tensor数组经过转换,得到fields_id, 视情况构造embedding,然后结构化成dict类型返回给组网部分
        :return: 实例化的dict,保存了各个field的id和embedding(可以没有,是情况而定), 给trainer用
        """
        fields_instance = self.convert_fields_to_dict(self.input_data_list)
        return fields_instance

    def convert_fields_to_dict(self, field_list, need_emb=True):
        """
        实例化fields_dict,保存了各个field的id和embedding(可以没有,是情况而定),
        当need_emb=False的时候,可以直接给predictor调用
        :param field_list:
        :param need_emb:
        :return: dict
        """
        start_index = 0
        fields_instance = {}
        for index, filed in enumerate(self.fields):
            item_dict = filed.field_reader.structure_fields_dict(field_list, start_index, need_emb=need_emb)
            fields_instance[filed.name] = item_dict
            start_index += filed.field_reader.get_field_length()

        return fields_instance

    def data_generator(self):
        """
        对基类接口的实现,构造一个迭代器,实现了对目录文件的读取,划分batch,明文转id。
        :return:
        """
        assert os.path.isdir(self.config.data_path), "%s must be a directory that stores data files" % self.config.data_path
        data_files = os.listdir(self.config.data_path)

        assert len(data_files) > 0, "%s is an empty directory" % self.config.data_path

        def wrapper():
            """
            :return:
            """
            # 以超参epoch循环读取数据集明文数据
            for epoch_index in range(self.config.epoch):
                self.current_example = 0
                self.current_epoch = epoch_index

                for input_file in data_files:
                    examples = self.read_files(os.path.join(self.config.data_path, input_file))
                    if self.config.shuffle:
                        np.random.shuffle(examples)
										# 按batch迭代输出id化完成的明文样本
                    for batch_data in self.prepare_batch_data(examples, self.config.batch_size):
                        yield batch_data

        return wrapper

    def read_files(self, file_path, quotechar=None):
        """读取明文文件
        :param file_path
        :return: 以namedtuple数组形式输出明文样本对应的实例
        """
        with open(file_path, "r") as f:
            try:
                examples = []
                reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
                len_fields = len(self.fields)
                field_names = []
                for filed in self.fields:
                    field_names.append(filed.name)

                self.Example = namedtuple('Example', field_names)

                for line in reader:
                    if len(line) == len_fields:
                        example = self.Example(*line)
                        examples.append(example)
                    else:
                        logging.warn('fileds in file %s not match: got %d, expect %d'\
                                % (file_path, len(line), len_fields))
                return examples

            except Exception:
                logging.error("error in read tsv")
                logging.error("traceback.format_exc():\n%s" % traceback.format_exc())

    def prepare_batch_data(self, examples, batch_size):
        """将明文样本按照data_loader需要的格式序列化成一个个batch输出
        :param examples:
        :param batch_size:
        :return:
        """
        batch_records = []
        for index, example in enumerate(examples):
            self.current_example += 1
            if len(batch_records) < batch_size:
                batch_records.append(example)
            else:
                yield self.pad_batch_records(batch_records)
                batch_records = [example]

        if batch_records:
            yield self.pad_batch_records(batch_records)

    def pad_batch_records(self, batch_records):
        """
        按batch处理一组样本,一个batch内的样本对应的文本长度会被padding到当前batch中最大的样本长度
        :param batch_records:
        :return:
        """
        return_list = []
        example = batch_records[0]
        for index, key in enumerate(example._fields):
            text_batch = []
            for record in batch_records:
                text_batch.append(record[index])

            id_list = self.fields[index].field_reader.convert_texts_to_ids(text_batch)
            return_list.extend(id_list)

        if self.need_generate_examples:
            return return_list, batch_records
        else:
            return return_list

    def get_train_progress(self):
        """获取当前reader的读取进度."""
        return self.current_example, self.current_epoch

    def get_num_examples(self):
        """获取当前数据集上的样本总数"""
        data_files = os.listdir(self.config.data_path)
        assert len(data_files) > 0, "%s is an empty directory" % self.config.data_path
        sum_examples = 0
        for input_file in data_files:
            examples = self.read_files(os.path.join(self.config.data_path, input_file))
            sum_examples += len(examples)

        self.num_examples = sum_examples

        return self.num_examples
上一篇
文心core框架设计
下一篇
FieldReader