欢迎访问宙启技术站
智能推送

PyTorch实现生成对抗网络(GAN)

发布时间:2023-12-23 09:07:50

PyTorch是一个基于Python的科学计算库,用于构建深度学习模型。生成对抗网络(GAN)是一种非监督式学习方法,通过训练生成器和判别器两个网络相互对抗来生成逼真的样本。下面将介绍如何使用PyTorch实现GAN,并提供一个简单的例子。

首先,我们需要导入PyTorch和其他必要的库:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

接下来,我们需要定义生成器和判别器网络的架构。生成器网络接收一个随机噪声向量作为输入,并输出一个逼真的样本。判别器网络接收一个样本作为输入,并输出该样本是真实样本的概率。以下是生成器和判别器的简单实现:

class Generator(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, output_dim),
            nn.Tanh()
        )

    def forward(self, x):
        return self.model(x)


class Discriminator(nn.Module):
    def __init__(self, input_dim):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.model(x)

在定义好生成器和判别器之后,我们需要设置超参数和数据加载器。以下是一些常用的超参数,根据需要进行调整:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
lr = 0.0002
batch_size = 128
input_dim = 100
output_dim = 784
epochs = 30

下面是数据加载器的代码,我们将使用MNIST数据集作为训练数据:

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_data = datasets.MNIST(root="data/", train=True, transform=transform, download=True)
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

接下来,我们可以初始化生成器和判别器,并设置损失函数和优化器:

generator = Generator(input_dim, output_dim).to(device)
discriminator = Discriminator(output_dim).to(device)
criterion = nn.BCELoss()
gen_optimizer = optim.Adam(generator.parameters(), lr=lr)
disc_optimizer = optim.Adam(discriminator.parameters(), lr=lr)

然后,我们可以开始进行训练:

for epoch in range(epochs):
    for batch_idx, (real_images, _) in enumerate(train_loader):
        real_images = real_images.view(-1, output_dim).to(device)
        batch_size = real_images.shape[0]

        # 训练判别器
        disc_optimizer.zero_grad()
        real_labels = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)

        # 真实样本的损失
        real_outputs = discriminator(real_images)
        real_loss = criterion(real_outputs, real_labels)

        # 生成样本的损失
        z = torch.randn(batch_size, input_dim).to(device)
        fake_images = generator(z)
        fake_outputs = discriminator(fake_images.detach())
        fake_loss = criterion(fake_outputs, fake_labels)

        # 判别器的总损失
        disc_loss = real_loss + fake_loss
        disc_loss.backward()
        disc_optimizer.step()

        # 训练生成器
        gen_optimizer.zero_grad()
        z = torch.randn(batch_size, input_dim).to(device)
        fake_images = generator(z)
        outputs = discriminator(fake_images)
        gen_loss = criterion(outputs, real_labels)
        gen_loss.backward()
        gen_optimizer.step()

        if batch_idx % 100 == 0:
            print(f"Epoch [{epoch}/{epochs}], Batch [{batch_idx}/{len(train_loader)}], "
                  f"Gen Loss: {gen_loss.item():.4f}, Disc Loss: {disc_loss.item():.4f}")

最后,我们可以使用生成器生成一些样本并进行可视化:

import matplotlib.pyplot as plt

z = torch.randn(10, input_dim).to(device)
generated_images = generator(z).detach().cpu()
fig, axs = plt.subplots(1, 10, figsize=(10, 1))
for i in range(10):
    axs[i].imshow(generated_images[i].view(28, 28), cmap='gray')
    axs[i].axis('off')
plt.show()

以上就是使用PyTorch实现生成对抗网络(GAN)的基本步骤和一个简单的例子。当然,GAN的训练过程需要花费一定的时间和计算资源,为了得到更好的生成结果,可以进一步优化网络架构和调整超参数。