ConditionalGAN error, mat1 and mat2 shapes cannot be multiplied (12006400x1 and 103x256)

Asked by: ROOT31415 · Asked: 8/10/2023 · Updated: 8/10/2023 · Views: 19

Q:

Here is the code I have:

import torch
import torch.nn as nn
from torchvision.utils import save_image


class Opt:
    def __init__(self):
        super(Opt, self).__init__()

        self.n_epochs = 10
        self.batch_size = 64
        self.lr = 0.0002
        self.b1 = 0.5
        self.b2 = 0.999
        self.latent_dim = 100
        self.img_size = 64
        self.channels = 3
        self.sample_interval = 400
        self.n_cpu = 14

 
opt = Opt()
img_shape = (opt.channels, opt.img_size, opt.img_size)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class Generator(nn.Module):
    def __init__(self, image_channels, age_embedding_size, latent_dim):
        super(Generator, self).__init__()
        # Embedding layer for age
        self.age_embedding = nn.Embedding(age_embedding_size, latent_dim)
        self.fc = nn.Sequential(
            nn.Linear(latent_dim + image_channels, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, image_channels),
            nn.Tanh()  # To ensure output is between -1 and 1
        )

    def forward(self, current_image, current_age, future_age):
        batch_size = current_image.size(0)
        age_embedded = self.age_embedding(future_age)
        age_embedded = age_embedded.view(batch_size, -1)  # Flatten
        x = torch.cat((current_image.view(batch_size, -1), age_embedded), dim=1)
        generated_image = self.fc(x.view(batch_size, -1, 1, 1))
        return generated_image

class Discriminator(nn.Module):
    def __init__(self, image_channels, age_embedding_size):
        super(Discriminator, self).__init__()
        self.age_embedding = nn.Embedding(age_embedding_size, image_channels)
        self.fc = nn.Sequential(
            nn.Linear(image_channels + image_channels, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, image, current_age, future_age):
        batch_size = image.size(0)
        age_embedded = self.age_embedding(future_age)
        age_embedded = age_embedded.view(batch_size, -1)  # Flatten
        x = torch.cat((image.view(batch_size, -1), age_embedded), dim=1)
        validity = self.fc(x.view(batch_size, -1, 1, 1))
        return validity

# Initialize the generator and discriminator
generator = Generator(image_channels=opt.channels,
                      age_embedding_size=opt.latent_dim,
                      latent_dim=opt.latent_dim)
discriminator = Discriminator(image_channels=opt.channels,
                              age_embedding_size=opt.latent_dim)

generator.to(device)
discriminator.to(device)


for epoch in range(3):
    for i, batch in enumerate(dataloader):

        real_images = batch['original_image'].to(device)
        current_ages = batch['current_age'].to(device)
        future_ages = batch['desired_age'].to(device)
        print(real_images.shape, current_ages.shape, future_ages.shape)
        # Adversarial ground truths
        valid = torch.ones(real_images.size(0), 1).to(device)
        fake = torch.zeros(real_images.size(0), 1).to(device)
        
        # ---------------------
        #  Train Discriminator
        # ---------------------
        
        optimizer_D.zero_grad()

        
        # Generate fake images
        fake_images = generator(real_images, current_ages, future_ages)
        
        # Discriminator loss for real images
        d_real_loss = adversarial_loss(discriminator(real_images, current_ages, future_ages), valid)
        
        # Discriminator loss for fake images
        d_fake_loss = adversarial_loss(discriminator(fake_images.detach(), current_ages, future_ages), fake)
        
        d_loss = d_real_loss + d_fake_loss
        d_loss.backward()
        optimizer_D.step()
        
        # -----------------
        #  Train Generator
        # -----------------

        optimizer_G.zero_grad()
        
        # Generate fake images
        gen_images = generator(real_images, current_ages, future_ages)
        
        # Generator loss
        g_loss = adversarial_loss(discriminator(gen_images, current_ages, future_ages), valid)
        
        g_loss.backward()
        optimizer_G.step()

        # Print progress
        if (i + 1) % opt.sample_interval == 0:
            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
            )

            # Save generated images
            save_image(gen_images.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)
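
The loop uses optimizer_D, optimizer_G and adversarial_loss, which are not defined in the excerpt above. A setup consistent with the Opt hyperparameters and the Sigmoid output of the Discriminator could look like the following sketch (illustrative only, not taken from the original post):

# Illustrative definitions, not from the original post:
adversarial_loss = nn.BCELoss()  # pairs with the Sigmoid at the end of the Discriminator
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))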


My dataset has the following format:

{'original_image': 'CACD2000\\16_Christopher_Mintz-Plasse_0005.jpg',
   'target_image': 'CACD2000\\16_Christopher_Mintz-Plasse_0016.jpg',
   'current_age': 16,
   'desired_age': 27},
  {'original_image': 'CACD2000\\16_Chris_Brown_0003.jpg',
   'target_image': 'CACD2000\\16_Chris_Brown_0004.jpg',
   'current_age': 14,
   'desired_age': 15},
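
The training loop indexes batch['original_image'] as an image tensor, so these records are presumably loaded and converted by a Dataset before batching. A minimal sketch of such a loader, assuming torchvision transforms (the class name AgingDataset and the transform choice are illustrative, not from the original post):

import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

class AgingDataset(Dataset):
    # Hypothetical loader for the record dicts shown above.
    def __init__(self, records, transform=None):
        self.records = records
        self.transform = transform or transforms.ToTensor()

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        rec = self.records[idx]
        image = self.transform(Image.open(rec['original_image']).convert('RGB'))  # (3, 250, 250) for CACD2000
        return {
            'original_image': image,
            'current_age': torch.tensor(rec['current_age'], dtype=torch.long),
            'desired_age': torch.tensor(rec['desired_age'], dtype=torch.long),
        }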

However, I am getting an error about the dimensions of the matrices being multiplied inside the network.

I am trying to build a face-aging project that takes one image and two ages and gives me an aged face, but I am having trouble with the architecture. How can I fix this?

print(real_images.shape, current_ages.shape, future_ages.shape) 

gives the output torch.Size([64, 3, 250, 250]) torch.Size([64]) torch.Size([64])
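
Those sizes line up with the two matrices named in the error. Tracing the Generator's forward pass with dummy tensors of the printed shapes shows where 12006400x1 and 103x256 come from (a standalone sketch using the values above, not the original script):

import torch
import torch.nn as nn

batch_size, channels, img_size, latent_dim = 64, 3, 250, 100

image = torch.randn(batch_size, channels, img_size, img_size)       # torch.Size([64, 3, 250, 250])
age_embedded = torch.randn(batch_size, latent_dim)                  # stand-in for the age embedding, (64, 100)

x = torch.cat((image.view(batch_size, -1), age_embedded), dim=1)    # (64, 3*250*250 + 100) = (64, 187600)
x = x.view(batch_size, -1, 1, 1)                                    # (64, 187600, 1, 1)

fc = nn.Linear(latent_dim + channels, 256)                          # in_features = 103, weight shape (256, 103)

# nn.Linear multiplies along the last dimension, so x is flattened to
# (64 * 187600) x 1 = 12006400 x 1, while the layer's transposed weight
# is 103 x 256 (latent_dim + image_channels = 103 input features).
fc(x)  # RuntimeError: mat1 and mat2 shapes cannot be multiplied (12006400x1 and 103x256)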

How do I solve this?

I have tried changing some of the sizes as well as the embedding shapes, but I ran into several errors and got lost.

matrix pytorch dimensions adversarial-network



A: No answers yet