DataSetReader
更新时间:2021-05-13
基类定义
DataSetReader的基类是BaseDataLoader,定义在./wenxin/data/data_set_reader/base_dataset_dataloader.py中。
...
@RegisterSet.data_set_reader.register
class BaseDataLoader(object):
"""BaseDataLoader:将样本中数据组装成一个data_loader, 向外提供一个统一的接口。
核心内容是读取明文文件,转换成id,按data_loader需要的tensor格式灌进去,然后通过调用run方法让整个循环跑起来。data_loader拿出的来的是lod-tensor形式的id,这些id可以用来做后面的embedding等计算。
"""
def __init__(self, name, fields, config):
self.name = name
self.fields = fields
self.config = config # 常用参数,batch_size等,ReaderConfig类型变量
self.paddle_data_loader = None
self.input_data_list = []
self.current_example = 0
self.current_epoch = 0
self.num_examples = 0
# 迭代器生成数据的时候是否需要生成明文样本,目前来看,训练的时候不需要,预测的时候需要
self.need_generate_examples = False
if config.extra_params:
self.iterable = config.extra_params.get("iterable", False)
self.use_cuda = config.extra_params.get("use_cuda", False)
self.use_data_parallel = config.extra_params.get("use_data_parallel", False)
self.return_list = config.extra_params.get("return_list", False)
self.use_multiprocess = config.extra_params.get("use_multiprocess", False)
else:
self.iterable = False
self.use_cuda = False
self.use_data_parallel = False
self.return_list = False
self.use_multiprocess = False
def create_reader(self):
"""
子类必须实现的接口,否则会抛出异常。
用于初始化self.paddle_data_loader。
:return:None
"""
raise NotImplementedError
def instance_fields_dict(self):
"""
子类必须实现的接口,否则会抛出异常。
实例化fields_dict, 得到fields_id, 视情况构造embedding,然后结构化成dict类型返回给组网部分。
:return:dict
{"field_name":
{"RECORD_ID":
{"SRC_IDS": [ids],
"MASK_IDS": [ids],
"SEQ_LENS": [ids]
}
}
}
实例化的dict,保存了各个field的id和embedding(可以没有,是情况而定), 给trainer用.
"""
raise NotImplementedError
def data_generator(self):
"""
子类必须实现的接口,否则会抛出异常。
数据生成器:读取明文文件,生成batch化的id数据,绑定到py_reader中
:return:list
[[src_ids],
[mask_ids],
[seq_lens]
]
"""
raise NotImplementedError
def convert_fields_to_dict(self, field_list, need_emb=False):
"""instance_fields_dict一般调用本方法实例化fields_dict,保存各个field的id和embedding(可以没有,是情况而定),
当need_emb=False的时候,可以直接给predictor调用
:param field_list:
:param need_emb:
:return: dict
"""
raise NotImplementedError
def run(self):
"""
配置data_loader对应的数据生成器,并启动。该方法由基类定义,子类不需要自定义,直接调用即可。
:return:
"""
if self.paddle_data_loader:
if self.iterable:
# 若DataLoader可迭代,则必须设置places参数
if self.use_data_parallel:
# 若进行多GPU卡训练,则取所有的CUDAPlace
# 若进行多CPU核训练,则取多个CPUPlace,这里以2个为例 # TODO: 这里需要从环境变量获取
places = paddle.CUDAPlace() if self.use_cuda else paddle.CPUPlace(2)
else:
# 若进行单GPU卡训练,则取单个CUDAPlace,本例中0代表0号GPU卡
# 若进行单CPU核训练,则取单个CPUPlace,本例中1代表1个CPUPlace
gpus = os.getenv('FLAGS_selected_gpus', '0').split(",")
id = int(gpus[0])
places = paddle.CUDAPlace(int(gpus[0])) if self.use_cuda else paddle.CPUPlace(1)
self.paddle_data_loader.set_batch_generator(self.data_generator(), places=places)
logging.info("set data_loader's generator with iterable.......")
else:
# 若DataLoader不可迭代,则不需要设置places参数
places = None
self.paddle_data_loader.set_batch_generator(self.data_generator(), places=places)
self.paddle_data_loader.start()
logging.info("set data_loader's generator and start.......")
else:
raise ValueError("paddle_data_loader is None")
def stop(self):
"""
停止一个data_loader,该方法由基类定义,子类不需要自定义,直接调用即可。
:return:
"""
if self.paddle_data_loader:
self.paddle_data_loader.reset()
else:
raise ValueError("paddle_data_loader is None")
def get_train_progress(self):
"""
获取当前reader的读取进度
"""
return self.current_example, self.current_epoch
def get_num_examples(self):
"""
获取当前reader对应的数据集的样本总数,视情况由子类实现
"""
return self.num_examples
def api_generator(self, query_list):
"""
对单条query进行id化,一般在预测api接口调用的情况下使用,基类实现了最基础的版本,其他视情况由子类实现。
"""
if len(query_list) <= 0:
raise ValueError("query can't be None")
field_names = []
for filed in self.fields:
field_names.append(filed.name)
Example = namedtuple('Example', field_names)
examples = []
for query in query_list:
examples.append(Example(**query))
for batch_data in self.prepare_batch_data(examples, self.config.batch_size):
yield batch_data核心函数
BaseDataSetReader做为基类,需要用户按自己业务场景在子类中自定义的核心函数为以下3个:
- create_reader(self):
- instance_fields_dict(self):
- data_generator(self):
自定义实现示例
以文心目前提供的能够覆盖绝大部分任务的BasicDataSetReader为例,解释BaseDataSetReader中3个核心函数的实现,详见以下代码及核心部分的注释。
@RegisterSet.data_set_reader.register
class BasicDataSetReader(BaseDataLoader):
"""BasicDataSetReader:一个基础的data_set_reader,实现了文件读取,id序列化,token embedding化等基本操作
"""
def __init__(self, name, fields, config):
"""__init__
"""
BaseDataLoader.__init__(self, name, fields, config)
def create_reader(self):
""" 对基类接口的实现,初始化paddle_data_loader,必须要初始化,否则会抛出异常
:return:
"""
if not self.fields:
raise ValueError("fields can't be None")
for item in self.fields:
if not item.field_reader:
raise ValueError("{0}'s field_reader is None".format(item.name))
# step1:分别获取当前数据集样本中每个域的定义,拿到各个域的数据类型、形状、lod_level,构造fluid.layers.data类型变量
self.input_data_list.extend(item.field_reader.init_reader(dataset_type=InstanceName.TYPE_DATA_LOADER))
# step2: 由step1中初始化好的变量对data_loader进行初始化。
self.paddle_data_loader = paddle.io.DataLoader.from_generator(
feed_list=self.input_data_list,
capacity=64,
iterable=self.iterable,
return_list=self.return_list)
def instance_fields_dict(self):
"""对基类接口的实现
将输入进来的tensor数组经过转换,得到fields_id, 视情况构造embedding,然后结构化成dict类型返回给组网部分
:return: 实例化的dict,保存了各个field的id和embedding(可以没有,是情况而定), 给trainer用
"""
fields_instance = self.convert_fields_to_dict(self.input_data_list)
return fields_instance
def convert_fields_to_dict(self, field_list, need_emb=True):
"""
实例化fields_dict,保存了各个field的id和embedding(可以没有,是情况而定),
当need_emb=False的时候,可以直接给predictor调用
:param field_list:
:param need_emb:
:return: dict
"""
start_index = 0
fields_instance = {}
for index, filed in enumerate(self.fields):
item_dict = filed.field_reader.structure_fields_dict(field_list, start_index, need_emb=need_emb)
fields_instance[filed.name] = item_dict
start_index += filed.field_reader.get_field_length()
return fields_instance
def data_generator(self):
"""
对基类接口的实现,构造一个迭代器,实现了对目录文件的读取,划分batch,明文转id。
:return:
"""
assert os.path.isdir(self.config.data_path), "%s must be a directory that stores data files" % self.config.data_path
data_files = os.listdir(self.config.data_path)
assert len(data_files) > 0, "%s is an empty directory" % self.config.data_path
def wrapper():
"""
:return:
"""
# 以超参epoch循环读取数据集明文数据
for epoch_index in range(self.config.epoch):
self.current_example = 0
self.current_epoch = epoch_index
for input_file in data_files:
examples = self.read_files(os.path.join(self.config.data_path, input_file))
if self.config.shuffle:
np.random.shuffle(examples)
# 按batch迭代输出id化完成的明文样本
for batch_data in self.prepare_batch_data(examples, self.config.batch_size):
yield batch_data
return wrapper
def read_files(self, file_path, quotechar=None):
"""读取明文文件
:param file_path
:return: 以namedtuple数组形式输出明文样本对应的实例
"""
with open(file_path, "r") as f:
try:
examples = []
reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
len_fields = len(self.fields)
field_names = []
for filed in self.fields:
field_names.append(filed.name)
self.Example = namedtuple('Example', field_names)
for line in reader:
if len(line) == len_fields:
example = self.Example(*line)
examples.append(example)
else:
logging.warn('fileds in file %s not match: got %d, expect %d'\
% (file_path, len(line), len_fields))
return examples
except Exception:
logging.error("error in read tsv")
logging.error("traceback.format_exc():\n%s" % traceback.format_exc())
def prepare_batch_data(self, examples, batch_size):
"""将明文样本按照data_loader需要的格式序列化成一个个batch输出
:param examples:
:param batch_size:
:return:
"""
batch_records = []
for index, example in enumerate(examples):
self.current_example += 1
if len(batch_records) < batch_size:
batch_records.append(example)
else:
yield self.pad_batch_records(batch_records)
batch_records = [example]
if batch_records:
yield self.pad_batch_records(batch_records)
def pad_batch_records(self, batch_records):
"""
按batch处理一组样本,一个batch内的样本对应的文本长度会被padding到当前batch中最大的样本长度
:param batch_records:
:return:
"""
return_list = []
example = batch_records[0]
for index, key in enumerate(example._fields):
text_batch = []
for record in batch_records:
text_batch.append(record[index])
id_list = self.fields[index].field_reader.convert_texts_to_ids(text_batch)
return_list.extend(id_list)
if self.need_generate_examples:
return return_list, batch_records
else:
return return_list
def get_train_progress(self):
"""获取当前reader的读取进度."""
return self.current_example, self.current_epoch
def get_num_examples(self):
"""获取当前数据集上的样本总数"""
data_files = os.listdir(self.config.data_path)
assert len(data_files) > 0, "%s is an empty directory" % self.config.data_path
sum_examples = 0
for input_file in data_files:
examples = self.read_files(os.path.join(self.config.data_path, input_file))
sum_examples += len(examples)
self.num_examples = sum_examples
return self.num_examples