Asked by: ROOT31415 · Asked: 8/10/2023 · Updated: 8/10/2023 · Views: 19
ConditionalGAN error, mat1 and mat2 shapes cannot be multiplied (12006400x1 and 103x256)
Q:
The code I have is as follows:
import torch
import torch.nn as nn
from torchvision.utils import save_image


class Opt:
    def __init__(self):
        super(Opt, self).__init__()
        self.n_epochs = 10
        self.batch_size = 64
        self.lr = 0.0002
        self.b1 = 0.5
        self.b2 = 0.999
        self.latent_dim = 100
        self.img_size = 64
        self.channels = 3
        self.sample_interval = 400
        self.n_cpu = 14


opt = Opt()
img_shape = (opt.channels, opt.img_size, opt.img_size)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class Generator(nn.Module):
    def __init__(self, image_channels, age_embedding_size, latent_dim):
        super(Generator, self).__init__()
        # Embedding layer for age
        self.age_embedding = nn.Embedding(age_embedding_size, latent_dim)
        self.fc = nn.Sequential(
            nn.Linear(latent_dim + image_channels, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, image_channels),
            nn.Tanh()  # To ensure output is between -1 and 1
        )

    def forward(self, current_image, current_age, future_age):
        batch_size = current_image.size(0)
        age_embedded = self.age_embedding(future_age)
        age_embedded = age_embedded.view(batch_size, -1)  # Flatten
        x = torch.cat((current_image.view(batch_size, -1), age_embedded), dim=1)
        generated_image = self.fc(x.view(batch_size, -1, 1, 1))
        return generated_image


class Discriminator(nn.Module):
    def __init__(self, image_channels, age_embedding_size):
        super(Discriminator, self).__init__()
        self.age_embedding = nn.Embedding(age_embedding_size, image_channels)
        self.fc = nn.Sequential(
            nn.Linear(image_channels + image_channels, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, image, current_age, future_age):
        batch_size = image.size(0)
        age_embedded = self.age_embedding(future_age)
        age_embedded = age_embedded.view(batch_size, -1)  # Flatten
        x = torch.cat((image.view(batch_size, -1), age_embedded), dim=1)
        validity = self.fc(x.view(batch_size, -1, 1, 1))
        return validity


# Initialize the generator and discriminator
generator = Generator(image_channels=opt.channels,
                      age_embedding_size=opt.latent_dim,
                      latent_dim=opt.latent_dim)
discriminator = Discriminator(image_channels=opt.channels,
                              age_embedding_size=opt.latent_dim)
generator.to(device)
discriminator.to(device)

for epoch in range(3):
    for i, batch in enumerate(dataloader):
        real_images = batch['original_image'].to(device)
        current_ages = batch['current_age'].to(device)
        future_ages = batch['desired_age'].to(device)
        print(real_images.shape, current_ages.shape, future_ages.shape)

        # Adversarial ground truths
        valid = torch.ones(real_images.size(0), 1).to(device)
        fake = torch.zeros(real_images.size(0), 1).to(device)

        # ---------------------
        #  Train Discriminator
        # ---------------------
        optimizer_D.zero_grad()

        # Generate fake images
        fake_images = generator(real_images, current_ages, future_ages)

        # Discriminator loss for real images
        d_real_loss = adversarial_loss(discriminator(real_images, current_ages, future_ages), valid)
        # Discriminator loss for fake images
        d_fake_loss = adversarial_loss(discriminator(fake_images.detach(), current_ages, future_ages), fake)
        d_loss = d_real_loss + d_fake_loss
        d_loss.backward()
        optimizer_D.step()

        # -----------------
        #  Train Generator
        # -----------------
        optimizer_G.zero_grad()

        # Generate fake images
        gen_images = generator(real_images, current_ages, future_ages)

        # Generator loss
        g_loss = adversarial_loss(discriminator(gen_images, current_ages, future_ages), valid)
        g_loss.backward()
        optimizer_G.step()

        # Print progress
        if (i + 1) % opt.sample_interval == 0:
            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
            )
            # Save generated images
            save_image(gen_images.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)
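(adversarial_loss, optimizer_D, optimizer_G and batches_done are defined elsewhere in my code. For reference, a minimal sketch of the kind of setup the loop above assumes, using the hyperparameters from Opt; the exact definitions below are illustrative, not my originals.)

adversarial_loss = nn.BCELoss()  # matches the Sigmoid output of the Discriminator
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
# batches_done is typically computed inside the loop, e.g. batches_done = epoch * len(dataloader) + i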
My dataset has the following format:
{'original_image': 'CACD2000\\16_Christopher_Mintz-Plasse_0005.jpg',
'target_image': 'CACD2000\\16_Christopher_Mintz-Plasse_0016.jpg',
'current_age': 16,
'desired_age': 27},
{'original_image': 'CACD2000\\16_Chris_Brown_0003.jpg',
'target_image': 'CACD2000\\16_Chris_Brown_0004.jpg',
'current_age': 14,
'desired_age': 15},
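For context, a minimal sketch of how entries like these could be fed to the DataLoader; the class name AgingDataset, the samples list, and the transform argument are illustrative assumptions, not my actual loading code:

from PIL import Image
from torch.utils.data import Dataset

class AgingDataset(Dataset):  # hypothetical name, for illustration only
    def __init__(self, samples, transform=None):
        self.samples = samples      # list of dicts like the entries above
        self.transform = transform  # e.g. torchvision transforms ending in ToTensor()

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        entry = self.samples[idx]
        image = Image.open(entry['original_image']).convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return {
            'original_image': image,
            'current_age': torch.tensor(entry['current_age'], dtype=torch.long),
            'desired_age': torch.tensor(entry['desired_age'], dtype=torch.long),
        }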
However, I am getting an error about the dimensions of the matrices being multiplied inside the network.
I am trying to build a face-aging project that takes one image and two ages and gives me back an aged face, but I am having trouble with the architecture. How can I fix this?
print(real_images.shape, current_ages.shape, future_ages.shape)
gives the output torch.Size([64, 3, 250, 250]) torch.Size([64]) torch.Size([64]).
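For reference, a minimal sketch (with random stand-in tensors) tracing where the 12006400x1 and 103x256 shapes in the error appear to come from, given the batch shapes above and latent_dim = 100:

batch_size = 64
latent_dim = 100
image = torch.randn(batch_size, 3, 250, 250)        # stand-in for real_images
age_embedded = torch.randn(batch_size, latent_dim)  # stand-in for the embedded future_ages

x = torch.cat((image.view(batch_size, -1), age_embedded), dim=1)
print(x.shape)                    # torch.Size([64, 187600])  (3*250*250 + 100 features)

x = x.view(batch_size, -1, 1, 1)
print(x.shape)                    # torch.Size([64, 187600, 1, 1]) -- last dimension is 1

# The Generator's first layer, nn.Linear(latent_dim + image_channels, 256) = nn.Linear(103, 256),
# multiplies over that last dimension, so mat1 is effectively (64*187600) x 1 = 12006400 x 1,
# while the layer's weight (mat2 in the error) is 103 x 256 -- the mismatch in the error message.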
How can I fix this?
I have tried changing some of the sizes as well as the embedding shapes, but that led to multiple other errors and I got lost.
A: No answers yet