ValueError:dim 1 上预期长度为 8 的序列(实际得到 9)

ValueError: expected sequence of length 8 at dim 1 (got 9)

提问人:Hrithik2212 提问时间:10/25/2023 更新时间:10/25/2023 访问量:37

问:

我不确定为什么会出现这个错误——前一天晚上运行同样的代码时并没有报错。我已经确认 translation_src 和 translation_trg 的序列长度相同,而且用于填充的 collate 函数也是正确的。

class TranslationDataset(Dataset):
    """Map-style dataset returning one ``(src, trg)`` token-id tensor pair per example.

    Each row of the wrapped dataset must provide ``'translation_src'`` and
    ``'translation_trg'`` token-id sequences. Within a single example the
    shorter sequence is right-padded with 0, so both returned tensors have
    the same length. Padding across a batch is left to the collate function.
    """

    def __init__(self, dataset):
        # Indexable collection of rows; presumably a Hugging Face split
        # (e.g. data['train']) -- TODO confirm against caller.
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(src_tensor, trg_tensor)`` for the single example at ``idx``."""
        # Index the underlying dataset once and read both columns from that row.
        row = self.dataset[idx]
        src_encoded = list(row['translation_src'])
        trg_encoded = list(row['translation_trg'])

        # Pad the shorter of the two sequences so src and trg match in length.
        max_len = max(len(src_encoded), len(trg_encoded))
        src_encoded = src_encoded + [0] * (max_len - len(src_encoded))
        trg_encoded = trg_encoded + [0] * (max_len - len(trg_encoded))
        return (
            torch.tensor(src_encoded),
            torch.tensor(trg_encoded),
        )

    def __getitems__(self, indices):
        """Fetch a batch as a list of per-example ``(src, trg)`` pairs.

        When a dataset exposes ``__getitems__``, torch's DataLoader fetcher
        calls it with the whole list of batch indices. If ``Dataset`` resolves
        to Hugging Face's ``datasets.Dataset``, the inherited version calls
        ``self.__getitem__(indices)`` with that list. The row lookup then
        returns ragged lists of token ids, and ``torch.tensor`` fails with
        "expected sequence of length N at dim 1 (got M)". Overriding the
        method fetches examples one at a time, so the collate function
        receives a list of samples it can pad.
        """
        return [self[i] for i in indices]

# Wrap the 'train' and 'test' splits so each item is a (src, trg) tensor pair.
train_ds, val_ds = (TranslationDataset(data[split]) for split in ('train', 'test'))

def pad_collate_fn(batch):
    """Collate ``(src, trg)`` samples into two zero-padded, batch-first tensors.

    Args:
        batch: samples whose element 0 is the source tensor and element 1
            the target tensor.

    Returns:
        ``(src_batch, trg_batch)``. Each is padded with 0 up to the longest
        sequence in the batch, with the batch on dim 0 (``batch_first=True``).
    """
    # Materialise once so the batch is iterated a single time.
    samples = list(batch)
    sources = [sample[0] for sample in samples]
    targets = [sample[1] for sample in samples]
    return (
        pad_sequence(sources, batch_first=True, padding_value=0),
        pad_sequence(targets, batch_first=True, padding_value=0),
    )

def chunk(indices, chunk_size):
    """Split ``indices`` into consecutive tensor pieces of length ``chunk_size``.

    The final piece is shorter when the length is not a multiple of
    ``chunk_size``. Returns the tuple produced by ``Tensor.split``.
    """
    as_tensor = torch.tensor(indices)
    return as_tensor.split(chunk_size)

class CustomBatchSampler(Sampler):
    """Batch sampler yielding fixed, contiguous index batches in shuffled order.

    The dataset is assumed to be already sorted (presumably by sequence
    length -- confirm), so each contiguous batch needs little padding. Only
    the order of the batches is randomised; batch membership never changes.
    """

    def __init__(self, dataset, batch_size):
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

        self.batch_size = batch_size
        self.indices = range(len(dataset))
        # Slice the index range directly rather than going through
        # torch.split: torch.split on an empty tensor returns one empty piece,
        # which would give an empty dataset a single empty batch the collate
        # function cannot pad. Here an empty dataset yields no batches.
        self.batch_of_indices = [
            list(self.indices[start:start + batch_size])
            for start in range(0, len(self.indices), batch_size)
        ]

    def __iter__(self):
        # Shuffle the stored batch list in place so each iteration (epoch)
        # sees a new batch order.
        random.shuffle(self.batch_of_indices)
        return iter(self.batch_of_indices)

    def __len__(self):
        return len(self.batch_of_indices)
     

# One shuffled batch sampler per split, each sized by config['BATCH_SIZE'].
custom_batcher_train, custom_batcher_val = (
    CustomBatchSampler(ds, config['BATCH_SIZE']) for ds in (train_ds, val_ds)
)

     

# Example use: pull a single batch of 3 through the DataLoader and print
# its shapes plus the raw source ids.
dummy_batcher = CustomBatchSampler(train_ds, 3)
dummy_dl = DataLoader(
    train_ds,
    collate_fn=pad_collate_fn,
    batch_sampler=dummy_batcher,
    pin_memory=True,
)
for x, y in dummy_dl:
    print('Shapes: ', '-'*10, x.size(), y.size(), '', sep='\n')

    print('e.g. src batch (see there is minimal/no padding):', '-'*10, sep='\n')
    print(x.numpy())

    break

错误

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-103-b82fc7198b80> in <cell line: 30>()
     28 dummy_batcher = CustomBatchSampler(train_ds, 3)
     29 dummy_dl=DataLoader(train_ds, collate_fn=pad_collate_fn , batch_sampler=dummy_batcher, pin_memory=True)
---> 30 for x ,y  in dummy_dl:
     31     print('Shapes: ')
     32     print('-'*10)

4 frames
/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py in __next__(self)
    628                 # TODO(https://github.com/pytorch/pytorch/issues/76750)
    629                 self._reset()  # type: ignore[call-arg]
--> 630             data = self._next_data()
    631             self._num_yielded += 1
    632             if self._dataset_kind == _DatasetKind.Iterable and \

/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py in _next_data(self)
    672     def _next_data(self):
    673         index = self._next_index()  # may raise StopIteration
--> 674         data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    675         if self._pin_memory:
    676             data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)

/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/fetch.py in fetch(self, possibly_batched_index)
     47         if self.auto_collation:
     48             if hasattr(self.dataset, "__getitems__") and self.dataset.__getitems__:
---> 49                 data = self.dataset.__getitems__(possibly_batched_index)
     50             else:
     51                 data = [self.dataset[idx] for idx in possibly_batched_index]

/usr/local/lib/python3.10/dist-packages/datasets/arrow_dataset.py in __getitems__(self, keys)
   2805     def __getitems__(self, keys: List) -> List:
   2806         """Can be used to get a batch using a list of integers indices."""
-> 2807         batch = self.__getitem__(keys)
   2808         n_examples = len(batch[next(iter(batch))])
   2809         return [{col: array[i] for col, array in batch.items()} for i in range(n_examples)]

<ipython-input-101-977a8cf27fd7> in __getitem__(self, idx)
     17         trg_encoded = trg_encoded + [0]*(max_len - len(trg_encoded))
     18         return [
---> 19             torch.tensor(src_encoded),
     20             # torch.tensor(trg_encoded),
     21         ]

ValueError: expected sequence of length 8 at dim 1 (got 9)

github 上的完整代码

在这里,我在没有收到任何错误时提交了笔记本,但是当我第二天运行代码时出现了这个序列长度不匹配错误

我尝试在 collate 函数中创建自己的填充代码

python pytorch NLP 序列 翻译

评论

0赞 Djinn 10/25/2023
在 CPU 而不是 GPU 上运行,看看错误消息是什么。除了某些错误之外,CPU 错误报告比 GPU 报告的错误更有助于故障排除。
0赞 Hrithik2212 10/26/2023
嗨 Djinn,我就是在 CPU 上运行的。

答: 暂无答案