Asked by: Ramon Griffo · Asked: 11/17/2023 · Updated: 11/17/2023 · Views: 31
Train DCGAN with variable input_shape
Q:
I have a dataset of 2D grayscale images and I want to build a GAN for it. The generator is not my concern; what I actually want is a very good discriminator, capable of telling real images apart from fake ones (produced by the generator).
The problem is that the data I'm dealing with is not traditional images: the image size is semantic, and every pixel is always relevant, so I can't simply resize the images. I came up with a possible solution for the discriminator, in which I use adaptive max pooling at the end to get a single output regardless of the input shape; that way, if any part of the image looks fake, it returns fake.
I haven't tested the discriminator below; it's just an example of the change I have in mind.
class Discriminator(nn.Module):
    def __init__(self) -> None:
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(1, 64, (4, 4), (2, 2), (1, 1), bias=True),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(64, 128, (4, 4), (2, 2), (1, 1), bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(128, 256, (4, 4), (2, 2), (1, 1), bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(256, 512, (4, 4), (2, 2), (1, 1), bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, True),
            nn.AdaptiveMaxPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(512, 1),
            nn.Sigmoid()
        )

    def forward(self, x: Tensor) -> Tensor:
        out = self.main(x)
        out = torch.flatten(out, 1)
        return out
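Something like the following smoke test (a minimal sketch of mine, with arbitrary example sizes, assuming the imports from the full listing below) should confirm that the adaptive pooling makes the forward pass shape-agnostic:

import torch

# Hypothetical smoke test: the same discriminator accepts different input sizes.
d = Discriminator()
for size in [(2, 1, 64, 64), (2, 1, 96, 80)]:
    out = d(torch.randn(size))
    print(size, "->", tuple(out.shape))  # both print (2, 1)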
The problem is that I don't know how to do something similar for the generator; it would make sense for it to produce images of random sizes for the discriminator. How can I achieve this?
For reference, my whole code is below, the way I'm currently running it, with the resizing. It works and is training at the moment, and the partial results are OK, but I really need this change.
import torch.nn as nn
from torch import Tensor
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import os
import torch
import numpy as np
from torch.utils.data import Dataset
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from PIL import Image
class CustomDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_filename = os.listdir(root_dir)

    def __len__(self):
        return len(self.image_filename)

    def __getitem__(self, idx):
        filename = os.path.join(self.root_dir, self.image_filename[idx])
        image = np.load(filename, allow_pickle=True)
        image = Image.fromarray(image.astype('uint8'))
        if self.transform:
            image = self.transform(image)
        return image
class Discriminator(nn.Module):
    def __init__(self) -> None:
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            # Input is 1 x 64 x 64
            nn.Conv2d(1, 64, (4, 4), (2, 2), (1, 1), bias=True),
            nn.LeakyReLU(0.2, True),
            # State size. 64 x 32 x 32
            nn.Conv2d(64, 128, (4, 4), (2, 2), (1, 1), bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, True),
            # State size. 128 x 16 x 16
            nn.Conv2d(128, 256, (4, 4), (2, 2), (1, 1), bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, True),
            # State size. 256 x 8 x 8
            nn.Conv2d(256, 512, (4, 4), (2, 2), (1, 1), bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, True),
            # State size. 512 x 4 x 4
            nn.AdaptiveMaxPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(512, 1),
            nn.Sigmoid()
            # nn.Conv2d(512, 1, (4, 4), (1, 1), (0, 0), bias=True),
            # nn.Sigmoid()
        )

    def forward(self, x: Tensor) -> Tensor:
        out = self.main(x)
        out = torch.flatten(out, 1)
        return out
class Generator(nn.Module):
    def __init__(self) -> None:
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            # Input is the 100 x 1 x 1 noise vector.
            nn.ConvTranspose2d(100, 512, (4, 4), (1, 1), (0, 0), bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            # State size. 512 x 4 x 4
            nn.ConvTranspose2d(512, 256, (4, 4), (2, 2), (1, 1), bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            # State size. 256 x 8 x 8
            nn.ConvTranspose2d(256, 128, (4, 4), (2, 2), (1, 1), bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            # State size. 128 x 16 x 16
            nn.ConvTranspose2d(128, 64, (4, 4), (2, 2), (1, 1), bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            # State size. 64 x 32 x 32
            nn.ConvTranspose2d(64, 1, (4, 4), (2, 2), (1, 1), bias=True),
            nn.Tanh()
            # Output is 1 x 64 x 64
        )

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)

    # Support PyTorch.script function.
    def _forward_impl(self, x: Tensor) -> Tensor:
        out = self.main(x)
        return out

    def _initialize_weights(self) -> None:
        for m in self.modules():
            # Include ConvTranspose2d here; otherwise none of the generator's
            # conv layers would match and the init would be a no-op.
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.normal_(m.weight, 0.0, 0.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.normal_(m.weight, 1.0, 0.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
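# --- Hypothetical sketch (my idea for the question above, untested): because the
# generator is fully convolutional, nothing forces the noise to be 1 x 1 spatially.
# Each ConvTranspose2d maps a spatial size s to (s - 1) * stride - 2 * padding + kernel,
# so noise of shape [N, 100, h, w] leaves this stack as [N, 1, 16 * (h + 3), 16 * (w + 3)]:
# h = w = 1 gives the usual 64 x 64, h = w = 2 gives 80 x 80, and so on
# (only sizes of that form are reachable without cropping). For example:
#     g = Generator()
#     for h, w in [(1, 1), (2, 3)]:
#         fake = g(torch.randn(2, 100, h, w))
#         print((h, w), "->", tuple(fake.shape))  # (2, 1, 64, 64), then (2, 1, 80, 96)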
def epoch(
        discriminator,
        generator,
        dataloader,
        epoch_n,
        n_epochs,
        device=torch.device("cuda:0"),
        exp_name="exp000",
        criterion=nn.BCELoss(),
        d_lr=0.0002,
        g_lr=0.0002,
):
    writer = SummaryWriter(os.path.join("samples", "logs", exp_name))
    criterion = criterion.to(device)
    # Note: the optimizers are re-created on every call, so Adam state does not
    # persist across epochs.
    d_optimizer = optim.Adam(discriminator.parameters(), d_lr, (0.5, 0.999))
    g_optimizer = optim.Adam(generator.parameters(), g_lr, (0.5, 0.999))
    # Calculate how many iterations there are in each epoch.
    batches = len(dataloader)
    # Set both models to training mode.
    discriminator.train()
    generator.train()
    for index, real in enumerate(dataloader):
        # Copy the data to the specified device.
        real = real.to(device)
        label_size = real.size(0)
        # Create labels. Set the real sample label to 1 and the fake sample label to 0.
        real_label = torch.full([label_size, 1], 1.0, dtype=real.dtype, device=device)
        fake_label = torch.full([label_size, 1], 0.0, dtype=real.dtype, device=device)
        # Create noise drawn from a Gaussian distribution.
        noise = torch.randn([label_size, 100, 1, 1], device=device)
        # Zero the discriminator's gradients.
        discriminator.zero_grad()
        # Calculate the discriminator's loss on the real images.
        output = discriminator(real)
        d_loss_real = criterion(output, real_label)
        d_loss_real.backward()
        d_real = output.mean().item()
        # Generate fake images.
        fake = generator(noise)
        # Calculate the discriminator's loss on the fake images.
        output = discriminator(fake.detach())
        d_loss_fake = criterion(output, fake_label)
        d_loss_fake.backward()
        d_fake1 = output.mean().item()
        # Update the weights of the discriminator model.
        d_loss = d_loss_real + d_loss_fake
        d_optimizer.step()
        # Zero the generator's gradients.
        generator.zero_grad()
        # Run the discriminator on the (non-detached) fake images.
        output = discriminator(fake)
        # Adversarial loss.
        g_loss = criterion(output, real_label)
        # Update the weights of the generator model.
        g_loss.backward()
        g_optimizer.step()
        d_fake2 = output.mean().item()
        # Write the losses during training to TensorBoard.
        iters = index + epoch_n * batches + 1
        writer.add_scalar("Train_Adversarial/D_Loss", d_loss.item(), iters)
        writer.add_scalar("Train_Adversarial/G_Loss", g_loss.item(), iters)
        writer.add_scalar("Train_Adversarial/D_Real", d_real, iters)
        writer.add_scalar("Train_Adversarial/D_Fake1", d_fake1, iters)
        writer.add_scalar("Train_Adversarial/D_Fake2", d_fake2, iters)
        # Print the losses every ten iterations and on the last iteration of this epoch.
        if (index + 1) % 10 == 0 or (index + 1) == batches:
            print(f"Train stage: adversarial "
                  f"Epoch[{epoch_n + 1:04d}/{n_epochs:04d}]({index + 1:05d}/{batches:05d}) "
                  f"D Loss: {d_loss.item():.6f} G Loss: {g_loss.item():.6f} "
                  f"D(Real): {d_real:.6f} D(Fake1)/D(Fake2): {d_fake1:.6f}/{d_fake2:.6f}.")
def train_model(
        dataset_path,
        image_size,
        experiment_samples_path,
        experiment_results_path,
        batch_size,
        device,
        n_epochs,
        resume_params=None,
):
    discriminator = Discriminator().to(device)
    generator = Generator().to(device)
    # Create the experiment result folders.
    if not os.path.exists(experiment_samples_path):
        os.makedirs(experiment_samples_path)
    if not os.path.exists(experiment_results_path):
        os.makedirs(experiment_results_path)
    # Create fixed noise drawn from a Gaussian distribution.
    fixed_noise = torch.randn([batch_size, 100, 1, 1], device=device)
    # Set up transformations.
    transform = transforms.Compose([
        transforms.Resize([image_size, image_size]),
        transforms.ToTensor(),
        transforms.Normalize(0, 0.5)
    ])
    # Create the custom dataset.
    dataset = CustomDataset(root_dir=dataset_path, transform=transform)
    # Create the data loader.
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    # Check whether to restore the training progress of a run that ended abnormally,
    # for example if the power was cut in the middle of training.
    if resume_params is not None:
        if resume_params["resume_d_weight"] != "":
            print(f"Resuming Discriminator... {resume_params['resume_d_weight']}")
            discriminator.load_state_dict(torch.load(resume_params["resume_d_weight"]))
        if resume_params["resume_g_weight"] != "":
            print(f"Resuming Generator... {resume_params['resume_g_weight']}")
            generator.load_state_dict(torch.load(resume_params["resume_g_weight"]))
    for epoch_n in range(0, n_epochs):
        # Train the models for one epoch.
        epoch(
            discriminator,
            generator,
            dataloader,
            epoch_n,
            n_epochs,
            device=device,  # pass the caller's device through instead of hard-coding it
            exp_name="exp000",
            criterion=nn.BCELoss(),
            d_lr=0.0002,
            g_lr=0.0002
        )
        # Save the model weights for this epoch.
        torch.save(discriminator.state_dict(), os.path.join(experiment_samples_path, f"d_epoch{epoch_n + 1}.pth"))
        torch.save(generator.state_dict(), os.path.join(experiment_samples_path, f"g_epoch{epoch_n + 1}.pth"))
        # Validate the model once per epoch.
        with torch.no_grad():
            # Switch the generator to eval mode (epoch() switches it back to train mode).
            generator.eval()
            fake = generator(fixed_noise).detach()
            torchvision.utils.save_image(fake, os.path.join(experiment_samples_path, f"epoch_{epoch_n + 1}.bmp"),
                                         normalize=True)
    # Save the model weights from the last epoch of this stage.
    torch.save(discriminator.state_dict(), os.path.join(experiment_results_path, "d-last.pth"))
    torch.save(generator.state_dict(), os.path.join(experiment_results_path, "g-last.pth"))
train_model(
    "../images",
    64,
    "../experimento_1/samples",
    "../experimento_1/results",
    64,
    torch.device("cuda:0"),
    100,
    resume_params=None,
)
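One caveat I'd flag about my own plan (untested): if I drop transforms.Resize so real images keep their native sizes, the default DataLoader collate can no longer stack a batch into one tensor. batch_size=1 sidesteps this; otherwise a custom collate_fn that returns a plain list is a minimal sketch, assuming a dataset built without the Resize transform:

# Hypothetical workaround for batching variable-size real images.
def list_collate(batch):
    # Return the samples as a list instead of torch.stack-ing them,
    # since their shapes may differ.
    return list(batch)

loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=list_collate)
# Each sample (or each group of same-sized samples) still has to go through
# the discriminator as its own tensor, since Conv2d needs one size per forward pass.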
A: No answers yet