提问人:mor hale 提问时间:11/5/2023 最后编辑:dan1stmor hale 更新时间:11/9/2023 访问量:33
带有 shuffle=False 的数据加载器,但图像顺序在每个纪元中都会发生变化
Dataloader with shuffle=False, but images order change in each epoch
问:
即使我使用图像随机化每个时代。'shuffle=False'
下面是用于创建加载程序的代码:
data_set = dset.CIFAR10(root='./data/cifar10', train=True, transform=transform, download=True)
train_loader, test_loader = create_loader_from_data_set(data_set, n_samples, batch_size, num_workers)
def create_loader_from_data_set(data_set, n_samples, batch_size, num_workers, test_size=0.2):
indices = list(range(len(data_set)))
selected_indices = random.sample(indices, n_samples)
train_indices, test_indices = train_test_split(selected_indices, test_size=test_size, random_state=42)
train_sampler = SubsetRandomSampler(train_indices)
test_sampler = SubsetRandomSampler(test_indices)
train_loader = DataLoader(data_set, batch_size=batch_size, num_workers=num_workers, sampler=train_sampler, shuffle=False)
test_loader = DataLoader(data_set, batch_size=batch_size, num_workers=num_workers, sampler=test_sampler, shuffle=False)
return train_loader, test_loader
这是针对训练循环的:
def train_epoch(epoch, network, loader, optimizer, batch_size):
network.train()
for batch_index, sample_tensor in enumerate(loader):
batch_images, _ = sample_tensor
我在每个时期得到的图像顺序不同(批次也不相同)。 shuffle=False 不应该保持顺序不变吗?
谢谢!
我也尝试过使用发电机,但没有用:
gen = torch.Generator()
train_loader = DataLoader(data_set, batch_size=batch_size, num_workers=num_workers, sampler=train_sampler, generator=gen)
答:
0赞
arturo-bandini-jr
11/9/2023
#1
您应该尝试,因为此函数的默认值为 True。train_test_split(..., shuffle = False)
评论
0赞
mor hale
11/10/2023
嗨,首先谢谢你,但它没有用......还是不同的顺序。这是我尝试过的代码: train_indices, test_indices = train_test_split(selected_indices, test_size=test_size, random_state=42, shuffle=False) train_sampler = SubsetRandomSampler(train_indices) test_sampler = SubsetRandomSampler(test_indices) train_loader = DataLoader(data_set, batch_size=batch_size, num_workers=num_workers, sampler=train_sampler, shuffle=False) test_loader = DataLoader(data_set, batch_size=batch_size, num_workers=num_workers, 采样器=test_sampler, shuffle=False)
0赞
arturo-bandini-jr
11/10/2023
虽然我不使用 pytorch,但我认为您的问题源于 SubsetRandomSampler 函数(ref -> pytorch.org/docs/stable/...
评论