
Dataloader with shuffle=False, but images order change in each epoch

Asked by: mor hale · Asked: 11/5/2023 · Last edited by: dan1st / mor hale · Updated: 11/9/2023 · Views: 33

Q:

The images are shuffled in every epoch even though I use 'shuffle=False'.

Here is the code used to create the loaders:

import random

import torchvision.datasets as dset
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, SubsetRandomSampler

def create_loader_from_data_set(data_set, n_samples, batch_size, num_workers, test_size=0.2):
    # draw n_samples random indices from the full dataset
    indices = list(range(len(data_set)))
    selected_indices = random.sample(indices, n_samples)

    # split the selected indices into train / test parts
    train_indices, test_indices = train_test_split(selected_indices, test_size=test_size, random_state=42)

    train_sampler = SubsetRandomSampler(train_indices)
    test_sampler = SubsetRandomSampler(test_indices)

    train_loader = DataLoader(data_set, batch_size=batch_size, num_workers=num_workers, sampler=train_sampler, shuffle=False)
    test_loader = DataLoader(data_set, batch_size=batch_size, num_workers=num_workers, sampler=test_sampler, shuffle=False)
    return train_loader, test_loader

# transform, n_samples, batch_size and num_workers are defined elsewhere in the script
data_set = dset.CIFAR10(root='./data/cifar10', train=True, transform=transform, download=True)
train_loader, test_loader = create_loader_from_data_set(data_set, n_samples, batch_size, num_workers)

And this is the training loop:

def train_epoch(epoch, network, loader, optimizer, batch_size):
    network.train()
    for batch_index, sample_tensor in enumerate(loader):
        batch_images, _ = sample_tensor 

I get a different order of images in every epoch (and the batches are not the same either). Shouldn't shuffle=False keep the order fixed?

Thanks!

I also tried using a generator, but it didn't help:

gen = torch.Generator()

train_loader = DataLoader(data_set, batch_size=batch_size, num_workers=num_workers, sampler=train_sampler, generator=gen)

image  deep-learning  training-data  pytorch-dataloader



Answers:

0 votes · arturo-bandini-jr · 11/9/2023 · #1

You should try train_test_split(..., shuffle=False), since the default value of shuffle for this function is True.

reference -> https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.train_test_split.html#sklearn.model_selection.train_test_split
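
As a minimal sketch of the suggested call (reusing selected_indices and test_size from the question; note that with shuffle=False the random_state argument has no effect):

from sklearn.model_selection import train_test_split

# disable the shuffling that train_test_split performs by default,
# so the split itself no longer reorders the selected indices
train_indices, test_indices = train_test_split(selected_indices, test_size=test_size, shuffle=False)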

Comments:

0 votes · mor hale · 11/10/2023
Hi, first of all thank you, but it didn't work... still a different order. This is the code I tried: train_indices, test_indices = train_test_split(selected_indices, test_size=test_size, random_state=42, shuffle=False); train_sampler = SubsetRandomSampler(train_indices); test_sampler = SubsetRandomSampler(test_indices); train_loader = DataLoader(data_set, batch_size=batch_size, num_workers=num_workers, sampler=train_sampler, shuffle=False); test_loader = DataLoader(data_set, batch_size=batch_size, num_workers=num_workers, sampler=test_sampler, shuffle=False)
0 votes · arturo-bandini-jr · 11/10/2023
Although I don't use PyTorch, I think your problem stems from the SubsetRandomSampler function (ref -> pytorch.org/docs/stable/...
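
For context, SubsetRandomSampler samples elements randomly from the given list of indices, so the subset is re-permuted on every pass over the loader regardless of shuffle=False. A minimal fixed-order alternative (a sketch reusing the question's variable names, not code from the thread) is to wrap the chosen indices in torch.utils.data.Subset and let the DataLoader iterate sequentially:

from torch.utils.data import DataLoader, Subset

# Subset keeps the given index order; with shuffle=False the DataLoader
# then yields the samples in exactly that order in every epoch
train_set = Subset(data_set, train_indices)
test_set = Subset(data_set, test_indices)

train_loader = DataLoader(train_set, batch_size=batch_size, num_workers=num_workers, shuffle=False)
test_loader = DataLoader(test_set, batch_size=batch_size, num_workers=num_workers, shuffle=False)

With this setup the train/test split is still random (via train_test_split), but the order in which the loader visits the samples stays the same from epoch to epoch.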