Asked by: Hrithik2212  Asked: 10/25/2023  Updated: 10/25/2023  Views: 37
ValueError: expected sequence of length 8 at dim 1 (got 9)
Q:
I'm not sure why I'm getting this error; I did not get it when I ran the code the night before. I have made sure that my translation_src and translation_trg both have the same sequence length, and that my padding collate function is correct.
import random

import torch
from torch.utils.data import DataLoader, Dataset, Sampler
from torch.nn.utils.rnn import pad_sequence


class TranslationDataset(Dataset):
    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        src_encoded = self.dataset[idx]['translation_src']
        trg_encoded = self.dataset[idx]['translation_trg']
        # Determine the maximum sequence length
        max_len = max(len(src_encoded), len(trg_encoded))
        # Pad the sequences to have the same length
        src_encoded = src_encoded + [0] * (max_len - len(src_encoded))
        trg_encoded = trg_encoded + [0] * (max_len - len(trg_encoded))
        return (
            torch.tensor(src_encoded),
            torch.tensor(trg_encoded),
        )


train_ds = TranslationDataset(data['train'])
val_ds = TranslationDataset(data['test'])


def pad_collate_fn(batch):
    src_sentences, trg_sentences = [], []
    for sample in batch:
        src_sentences += [sample[0]]
        trg_sentences += [sample[1]]
    src_sentences = pad_sequence(src_sentences, batch_first=True, padding_value=0)
    trg_sentences = pad_sequence(trg_sentences, batch_first=True, padding_value=0)
    return src_sentences, trg_sentences


def chunk(indices, chunk_size):
    return torch.split(torch.tensor(indices), chunk_size)


class CustomBatchSampler(Sampler):
    def __init__(self, dataset, batch_size):
        # Dataset is already sorted, so just chunk indices
        # into batches of indices for sampling
        self.batch_size = batch_size
        self.indices = range(len(dataset))
        self.batch_of_indices = list(chunk(self.indices, self.batch_size))
        self.batch_of_indices = [batch.tolist() for batch in self.batch_of_indices]

    def __iter__(self):
        random.shuffle(self.batch_of_indices)
        return iter(self.batch_of_indices)

    def __len__(self):
        return len(self.batch_of_indices)


custom_batcher_train = CustomBatchSampler(train_ds, config['BATCH_SIZE'])
custom_batcher_val = CustomBatchSampler(val_ds, config['BATCH_SIZE'])

# example use
dummy_batcher = CustomBatchSampler(train_ds, 3)
dummy_dl = DataLoader(train_ds, collate_fn=pad_collate_fn, batch_sampler=dummy_batcher, pin_memory=True)
for x, y in dummy_dl:
    print('Shapes: ')
    print('-' * 10)
    print(x.size())
    print(y.size())
    print()
    print('e.g. src batch (see there is minimal/no padding):')
    print('-' * 10)
    print(x.numpy())
    break
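For context on what pad_collate_fn is supposed to produce: torch.nn.utils.rnn.pad_sequence right-pads a list of 1-D tensors of different lengths into a single rectangular batch. A standalone sketch of that behavior (the values here are made up purely for illustration):

import torch
from torch.nn.utils.rnn import pad_sequence

# Three 1-D tensors of lengths 2, 3 and 1 become one (3, 3) batch,
# right-padded with padding_value, just as pad_collate_fn expects.
seqs = [torch.tensor([5, 6]), torch.tensor([1, 2, 3]), torch.tensor([9])]
batch = pad_sequence(seqs, batch_first=True, padding_value=0)
print(batch)
# tensor([[5, 6, 0],
#         [1, 2, 3],
#         [9, 0, 0]])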
Error
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-103-b82fc7198b80> in <cell line: 30>()
28 dummy_batcher = CustomBatchSampler(train_ds, 3)
29 dummy_dl=DataLoader(train_ds, collate_fn=pad_collate_fn , batch_sampler=dummy_batcher, pin_memory=True)
---> 30 for x ,y in dummy_dl:
31 print('Shapes: ')
32 print('-'*10)
4 frames
/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py in __next__(self)
628 # TODO(https://github.com/pytorch/pytorch/issues/76750)
629 self._reset() # type: ignore[call-arg]
--> 630 data = self._next_data()
631 self._num_yielded += 1
632 if self._dataset_kind == _DatasetKind.Iterable and \
/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py in _next_data(self)
672 def _next_data(self):
673 index = self._next_index() # may raise StopIteration
--> 674 data = self._dataset_fetcher.fetch(index) # may raise StopIteration
675 if self._pin_memory:
676 data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)
/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/fetch.py in fetch(self, possibly_batched_index)
47 if self.auto_collation:
48 if hasattr(self.dataset, "__getitems__") and self.dataset.__getitems__:
---> 49 data = self.dataset.__getitems__(possibly_batched_index)
50 else:
51 data = [self.dataset[idx] for idx in possibly_batched_index]
/usr/local/lib/python3.10/dist-packages/datasets/arrow_dataset.py in __getitems__(self, keys)
2805 def __getitems__(self, keys: List) -> List:
2806 """Can be used to get a batch using a list of integers indices."""
-> 2807 batch = self.__getitem__(keys)
2808 n_examples = len(batch[next(iter(batch))])
2809 return [{col: array[i] for col, array in batch.items()} for i in range(n_examples)]
<ipython-input-101-977a8cf27fd7> in __getitem__(self, idx)
17 trg_encoded = trg_encoded + [0]*(max_len - len(trg_encoded))
18 return [
---> 19 torch.tensor(src_encoded),
20 # torch.tensor(trg_encoded),
21 ]
ValueError: expected sequence of length 8 at dim 1 (got 9)
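For reference, this final ValueError is what torch.tensor raises when it is handed a ragged nested list, i.e. rows of unequal length that cannot be stacked into one rectangular tensor. A minimal reproduction, independent of the dataset code:

import torch

# Two rows of lengths 8 and 9 cannot form a rectangular 2-D tensor,
# producing the exact error message from the traceback above.
torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8],
              [1, 2, 3, 4, 5, 6, 7, 8, 9]])
# ValueError: expected sequence of length 8 at dim 1 (got 9)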
I had committed the notebook here when it ran without any errors, but when I ran the code the next day, this sequence-length-mismatch error appeared.
I have tried writing my own padding code inside the collate function.
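One hedged reading of the traceback (an assumption, since the question has no accepted answer): the __getitems__ frames show the whole list of batch indices being forwarded to __getitem__ as keys, in which case self.dataset[idx] returns a batch of unequal-length sequences and torch.tensor fails exactly as reproduced above. A defensive sketch of the dataset under that assumption, using a hypothetical helper _encode_one:

import torch
from torch.utils.data import Dataset

class TranslationDataset(Dataset):
    # Sketch only: assumes idx may arrive as a list of indices,
    # as the __getitems__ frames in the traceback suggest.
    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def _encode_one(self, row):
        # Hypothetical helper: pad one (src, trg) pair to a common length.
        src, trg = row['translation_src'], row['translation_trg']
        max_len = max(len(src), len(trg))
        src = src + [0] * (max_len - len(src))
        trg = trg + [0] * (max_len - len(trg))
        return torch.tensor(src), torch.tensor(trg)

    def __getitem__(self, idx):
        if isinstance(idx, list):
            # Batched access: encode each example separately instead of
            # letting torch.tensor see one ragged nested list.
            return [self._encode_one(self.dataset[i]) for i in idx]
        return self._encode_one(self.dataset[idx])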
A: No answers yet