IterableDataset on multiple files

Asked by: LearnToGrow Asked: 11/8/2023 Last edited by: LearnToGrow Updated: 11/8/2023 Views: 34

Q:

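I have 1000 files, and each file is processed with a sliding window to generate many training samples. I wrote the iterable-style dataset below. Everything works; the only problem is that each batch consists of data from a single file. Is there a way to implement a custom collate function, or something similar, that mixes data from different files?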

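import numpy as np
import torch
from typing import Dict, Iterator, List
from torch.utils.data import IterableDataset


class TextFileDataset(IterableDataset):
    def __init__(self, file_pathes: List[str]) -> None:
        self.file_pathes = file_pathes
        # Map each file path to an integer id so a sample can be traced back to its file.
        self.file_id_map = {file_path: idx for idx, file_path in enumerate(file_pathes)}

    def process_file(self, file_path: str) -> Iterator[Dict[str, torch.Tensor]]:
        # load_data is not shown in the question.
        doc_indices = self.load_data(file_path=file_path)
        idx = self.file_id_map[file_path]
        for i in range(len(doc_indices)):
            # run a sliding window
            ..........
            yield {"idx": torch.tensor(idx), "text1": torch.tensor(text1), "text2": torch.tensor(text2)}

    def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            # Single-process loading: iterate over every file.
            for file_path in self.file_pathes:
                yield from self.process_file(file_path=file_path)
            return  # otherwise the code below would dereference worker_info, which is None

        # Multi-process loading: give each worker a contiguous slice of the files.
        per_worker = int(np.ceil(len(self.file_pathes) / float(worker_info.num_workers)))
        worker_id = worker_info.id
        self.iter_file_paths = self.file_pathes[worker_id * per_worker:(worker_id + 1) * per_worker]
        for file_path in self.iter_file_paths:
            yield from self.process_file(file_path=file_path)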
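
A note on the mixing asked about above: a collate function only sees the samples already placed in a batch, so it cannot pull in samples from other files. The mixing would have to happen inside `__iter__`. The following is a minimal sketch of one possible approach, not a definitive implementation. It assumes two hypothetical helpers that are not part of PyTorch: `interleave`, which draws from several per-file generators in round-robin order, and `shuffle_buffer`, which applies an approximate shuffle using a bounded buffer. `fake_file` is a stand-in for `process_file` so the sketch runs without torch.

```python
import random
from typing import Iterable, Iterator, List, Optional, Tuple, TypeVar

T = TypeVar("T")


def interleave(sources: Iterable[Iterable[T]]) -> Iterator[T]:
    """Yield one item from each source in turn until all are exhausted."""
    active: List[Iterator[T]] = [iter(s) for s in sources]
    while active:
        still_active: List[Iterator[T]] = []
        for it in active:
            try:
                item = next(it)
            except StopIteration:
                continue  # this source is finished; drop it
            still_active.append(it)
            yield item
        active = still_active


def shuffle_buffer(items: Iterable[T], buffer_size: int,
                   seed: Optional[int] = None) -> Iterator[T]:
    """Approximate shuffle: keep buffer_size items and emit a random one."""
    if buffer_size < 1:
        raise ValueError("buffer_size must be at least 1")
    rng = random.Random(seed)
    buffer: List[T] = []
    for item in items:
        if len(buffer) < buffer_size:
            buffer.append(item)  # still filling the buffer
            continue
        j = rng.randrange(buffer_size)
        yield buffer[j]          # emit a random buffered item
        buffer[j] = item         # and put the new item in its slot
    rng.shuffle(buffer)          # flush what is left in random order
    yield from buffer


def fake_file(file_id: int, n: int) -> Iterator[Tuple[int, int]]:
    """Hypothetical stand-in for process_file: yields (file_id, sample_index)."""
    for i in range(n):
        yield (file_id, i)


round_robin = list(interleave(fake_file(f, 4) for f in range(3)))
mixed = list(shuffle_buffer(interleave(fake_file(f, 4) for f in range(3)),
                            buffer_size=5, seed=0))
```

In the dataset above, the per-file loop in `__iter__` could then become something like `yield from shuffle_buffer(interleave(self.process_file(file_path=p) for p in self.iter_file_paths), buffer_size=...)`. Two caveats: `interleave` keeps every file's generator alive at once, so with many files per worker it may be preferable to interleave them in smaller groups; and with `num_workers > 0` each batch is assembled by a single worker, so only the files assigned to that worker can end up in the same batch.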
python torch iterable pytorch-dataloader

A: No answers yet