Asked by: LearnToGrow · Asked: 11/8/2023 · Last edited by: LearnToGrow · Updated: 11/8/2023 · Views: 34
IterableDataset on multiple files
Q:
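I have 1000 files, and each one is processed with a sliding window to generate many training samples. I wrote the iterable-style dataset below. Everything works, except that every batch consists of samples from a single file. Is there a way to implement a custom collate function, or something similar, that mixes data from different files within a batch?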
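To see why this happens, the plain-Python sketch below (no PyTorch; `fake_file_samples`, `sequential_stream` and `batches` are hypothetical stand-ins) mimics the dataset's `__iter__`: each file is exhausted before the next one starts, and batching takes consecutive items from that stream, roughly as a `DataLoader` does for an iterable-style dataset. Because one file yields many windows, consecutive items, and therefore whole batches, share the same file id.

```python
from itertools import islice
from typing import Dict, Iterator, List


def fake_file_samples(file_id: int, n_windows: int) -> Iterator[Dict[str, int]]:
    # Hypothetical stand-in for process_file(): one sample per sliding-window position
    for w in range(n_windows):
        yield {"idx": file_id, "window": w}


def sequential_stream(n_files: int, n_windows: int) -> Iterator[Dict[str, int]]:
    # Mirrors the question's __iter__: finish one file before starting the next
    for file_id in range(n_files):
        yield from fake_file_samples(file_id, n_windows)


def batches(stream: Iterator[Dict[str, int]], batch_size: int) -> List[List[Dict[str, int]]]:
    # Rough model of DataLoader auto-batching: group consecutive items
    out = []
    it = iter(stream)
    while True:
        batch = list(islice(it, batch_size))
        if not batch:
            break
        out.append(batch)
    return out


result = batches(sequential_stream(n_files=3, n_windows=8), batch_size=4)
print([sorted({s["idx"] for s in b}) for b in result])  # one file id per batch
```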
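```python
from typing import Dict, Iterator, List

import numpy as np
import torch
from torch.utils.data import IterableDataset


class TextFileDataset(IterableDataset):
    def __init__(self, file_pathes: List[str]) -> None:
        self.file_pathes = file_pathes
        # Map each file path to an integer id that is returned with every sample
        self.file_id_map = {file_path: idx for idx, file_path in enumerate(file_pathes)}

    def process_file(self, file_path: str) -> Iterator[Dict[str, torch.Tensor]]:
        doc_indices = self.load_data(file_path=file_path)  # load_data is not shown in the question
        idx = self.file_id_map[file_path]
        for i in range(len(doc_indices)):
            # run a sliding window
            # .......... (elided in the original; produces text1 and text2)
            yield {"idx": torch.tensor(idx), "text1": torch.tensor(text1), "text2": torch.tensor(text2)}

    def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            # Single-process loading: this process reads every file
            for file_path in self.file_pathes:
                yield from self.process_file(file_path=file_path)
            # Stop here; otherwise execution falls through and worker_info.num_workers fails on None
            return
        # Multi-process loading: each worker gets a contiguous shard of the files
        per_worker = int(np.ceil(len(self.file_pathes) / float(worker_info.num_workers)))
        worker_id = worker_info.id
        self.iter_file_paths = self.file_pathes[worker_id * per_worker:(worker_id + 1) * per_worker]
        for file_path in self.iter_file_paths:
            yield from self.process_file(file_path=file_path)
```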
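A custom `collate_fn` cannot fix this by itself: it only receives the samples the `DataLoader` has already grouped into a batch (and with `num_workers > 0`, each batch is built from a single worker's stream), so the mixing has to happen in the order `__iter__` yields samples. One possible approach, sketched here under assumptions, is to keep several file generators open at once, draw from them at random, and pass the result through a small shuffle buffer. `interleave_files`, `shuffle_buffer`, `num_open` and `fake_process_file` are hypothetical names, not a PyTorch API; the code is plain Python, so it runs without PyTorch.

```python
import random
from typing import Callable, Dict, Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def interleave_files(file_paths: List[str], process_file: Callable[[str], Iterator[T]],
                     num_open: int, rng: random.Random) -> Iterator[T]:
    """Hypothetical helper: draw samples at random from up to num_open (>= 1) files at once."""
    pending = list(file_paths)
    rng.shuffle(pending)  # randomize the order in which files are opened
    active: List[Iterator[T]] = []
    while pending or active:
        # Top up the active set with fresh per-file generators
        while pending and len(active) < num_open:
            active.append(iter(process_file(pending.pop())))
        i = rng.randrange(len(active))
        try:
            yield next(active[i])
        except StopIteration:
            active.pop(i)  # this file is exhausted; drop it


def shuffle_buffer(stream: Iterable[T], size: int, rng: random.Random) -> Iterator[T]:
    """Hypothetical helper: approximate shuffling with a fixed-size buffer."""
    buf: List[T] = []
    for item in stream:
        if len(buf) < size:
            buf.append(item)
            continue
        j = rng.randrange(size)
        yield buf[j]  # emit a random buffered item and replace it with the new one
        buf[j] = item
    rng.shuffle(buf)
    yield from buf  # drain what is left at the end of the stream


def fake_process_file(path: str) -> Iterator[Dict[str, object]]:
    # Stand-in for TextFileDataset.process_file: 8 windows per file
    for w in range(8):
        yield {"file": path, "window": w}


rng = random.Random(0)
files = [f"file_{i}.txt" for i in range(6)]
samples = list(shuffle_buffer(interleave_files(files, fake_process_file, num_open=4, rng=rng),
                              size=16, rng=rng))
print([sorted({s["file"] for s in samples[i:i + 4]}) for i in range(0, 12, 4)])
```

Inside `TextFileDataset.__iter__`, the final loop could become something like `yield from shuffle_buffer(interleave_files(self.iter_file_paths, self.process_file, num_open=8, rng=rng), size=1024, rng=rng)`, with `rng = random.Random(seed + worker_id)` (`seed` being a hypothetical per-epoch value). Each worker then mixes only the files in its own shard. The trade-off is memory: because `process_file` loads a whole file before windowing, memory grows with `num_open`, while larger `num_open` and buffer sizes give better mixing.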
A: No answers yet.