Comparing Discriminator and Generator Networks in Python
Published: 2024-01-02 23:45:16
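In deep learning, the discriminator network and the generator network are the two models that together make up a generative adversarial network (GAN).
The discriminator is a classifier: given an input, it decides whether that input is real data or generated data. It is usually built from several layers and learns the features of the real dataset so that it can tell generated samples apart from real ones. Its output is typically a probability between 0 and 1 that the input is real.
The generator is a model that produces new data; its goal is to learn to generate samples that resemble the real data. It is also usually built from several layers and takes a random noise vector as input. During training it pushes its outputs closer to the real data distribution in order to fool the discriminator.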
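To make the probability output concrete, here is a minimal sketch of how a discriminator's outputs are scored against real/generated labels with binary cross-entropy. The probabilities and labels are made-up illustrative values, not results from any trained model.

```python
import torch
import torch.nn as nn

# Hypothetical discriminator outputs for a batch of four inputs:
# each value is the predicted probability that the input is real.
probs = torch.tensor([[0.9], [0.8], [0.3], [0.1]])
# Ground-truth labels: 1 = real, 0 = generated.
labels = torch.tensor([[1.0], [1.0], [0.0], [0.0]])

# Binary cross-entropy as computed by PyTorch ...
bce = nn.BCELoss()(probs, labels)
# ... and the same quantity written out by hand.
manual = -(labels * torch.log(probs)
           + (1 - labels) * torch.log(1 - probs)).mean()

# Thresholding at 0.5 turns the probabilities into hard decisions.
predicted_real = probs > 0.5
```

The loss is small when real inputs receive probabilities near 1 and generated inputs receive probabilities near 0, which is the behavior the discriminator is trained toward.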
The rest of this article uses an image generation task to show how the two networks are used.
First, import the required libraries and modules:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
Next, define the structure of the discriminator and the generator:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        # 784 = 28 * 28 pixels of a flattened MNIST image
        self.fc1 = nn.Linear(784, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Flatten (N, 1, 28, 28) images to (N, 784)
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        x = nn.functional.relu(x)
        x = self.fc2(x)
        x = nn.functional.relu(x)
        x = self.fc3(x)
        # Probability that the input is real
        x = self.sigmoid(x)
        return x


class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        # 100-dimensional noise vector in, flattened 28 * 28 image out
        self.fc1 = nn.Linear(100, 256)
        self.fc2 = nn.Linear(256, 512)
        self.fc3 = nn.Linear(512, 784)
        self.tanh = nn.Tanh()

    def forward(self, x):
        x = self.fc1(x)
        x = nn.functional.relu(x)
        x = self.fc2(x)
        x = nn.functional.relu(x)
        x = self.fc3(x)
        # Tanh keeps outputs in [-1, 1], matching the normalized training images
        x = self.tanh(x)
        return x
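Before training, it can help to check that the tensor shapes line up. The sketch below uses compact `nn.Sequential` stand-ins with the same layer sizes as the two classes above (they are illustrative substitutes, not the classes themselves) and pushes a batch of noise through both networks.

```python
import torch
import torch.nn as nn

# Stand-ins mirroring the layer sizes of Generator and Discriminator above.
generator = nn.Sequential(
    nn.Linear(100, 256), nn.ReLU(),
    nn.Linear(256, 512), nn.ReLU(),
    nn.Linear(512, 784), nn.Tanh(),
)
discriminator = nn.Sequential(
    nn.Flatten(),
    nn.Linear(784, 512), nn.ReLU(),
    nn.Linear(512, 256), nn.ReLU(),
    nn.Linear(256, 1), nn.Sigmoid(),
)

noise = torch.randn(8, 100)             # batch of 8 noise vectors
fake = generator(noise)                 # shape (8, 784), values in [-1, 1]
fake_images = fake.view(-1, 1, 28, 28)  # reshape to MNIST image layout
scores = discriminator(fake_images)     # shape (8, 1), values in [0, 1]

print(fake.shape, fake_images.shape, scores.shape)
```

The generator's 784-dimensional output is exactly what the discriminator expects after flattening, so the two networks can be chained directly.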
Then prepare the dataset and define the training procedure:
# Scale pixel values from [0, 1] to [-1, 1] to match the generator's Tanh output
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5,), (0.5,))])
trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                      download=True, transform=transform)
# Note: with num_workers > 0, platforms that spawn worker processes (e.g. Windows)
# need this script's body wrapped in `if __name__ == "__main__":`
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64,
                                          shuffle=True, num_workers=2)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
discriminator = Discriminator().to(device)
generator = Generator().to(device)

criterion = nn.BCELoss()
optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0002)
optimizer_g = optim.Adam(generator.parameters(), lr=0.0002)

num_epochs = 50
for epoch in range(num_epochs):
    for i, data in enumerate(trainloader, 0):
        real_images, _ = data
        real_images = real_images.to(device)

        # Train the discriminator
        discriminator.zero_grad()
        real_labels = torch.ones(real_images.size(0), 1).to(device)
        fake_labels = torch.zeros(real_images.size(0), 1).to(device)

        # Real images should be scored as 1
        outputs = discriminator(real_images)
        loss_d_real = criterion(outputs, real_labels)
        loss_d_real.backward()
        real_score = outputs.mean().item()

        # Generated images should be scored as 0; detach() keeps this
        # backward pass from reaching the generator
        noise = torch.randn(real_images.size(0), 100).to(device)
        fake_images = generator(noise)
        outputs = discriminator(fake_images.detach())
        loss_d_fake = criterion(outputs, fake_labels)
        loss_d_fake.backward()
        fake_score = outputs.mean().item()

        loss_d = loss_d_real + loss_d_fake  # used for logging only
        optimizer_d.step()

        # Train the generator: it is rewarded when the discriminator
        # scores its images as real
        generator.zero_grad()
        outputs = discriminator(fake_images)
        loss_g = criterion(outputs, real_labels)
        loss_g.backward()
        optimizer_g.step()

        if i % 100 == 0:
            print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f'
                  % (epoch+1, num_epochs, i, len(trainloader),
                     loss_d.item(), loss_g.item(), real_score, fake_score))
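Once training has finished, new images are obtained by feeding fresh noise to the generator. The following is a minimal sketch under stated assumptions: `sample_images` is a hypothetical helper that does not appear in the original code, and the `nn.Sequential` model is an untrained stand-in with the same layer sizes as the Generator, so its outputs here are just noise. In practice you would pass the trained `generator` from the loop.

```python
import torch
import torch.nn as nn

# Untrained stand-in with the Generator's layer sizes (illustration only).
generator = nn.Sequential(
    nn.Linear(100, 256), nn.ReLU(),
    nn.Linear(256, 512), nn.ReLU(),
    nn.Linear(512, 784), nn.Tanh(),
)


def sample_images(gen, n, noise_dim=100, device="cpu"):
    """Draw n samples from gen; returns a (n, 1, 28, 28) tensor in [0, 1]."""
    gen.eval()                 # no effect on this model, but a safe habit
    with torch.no_grad():      # no gradients are needed for sampling
        noise = torch.randn(n, noise_dim, device=device)
        flat = gen(noise)      # (n, 784), values in [-1, 1]
    images = flat.view(n, 1, 28, 28)
    # Undo Normalize((0.5,), (0.5,)): map [-1, 1] back to [0, 1]
    return (images + 1) / 2


samples = sample_images(generator, 16)
print(samples.shape)
```

The returned tensor is in the layout and value range that image utilities such as `torchvision.utils.save_image` expect, so it can be written to disk for visual inspection.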
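This is a simple GAN. The discriminator takes image data as input and outputs a probability between 0 and 1; the generator takes a random noise vector as input and outputs a generated image. In each training iteration the two networks are updated in turn: the discriminator is trained to classify real and generated images correctly, while the generator is trained so that the discriminator can no longer tell its images apart from real ones.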
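For reference, the two alternating updates can be written as losses. The formulas below are the standard GAN formulation rather than something stated in the original text. Here $D$ is the discriminator, $G$ the generator, $x$ a real image and $z$ a noise vector; with `nn.BCELoss` and the labels used in the training loop, each expectation is estimated by a batch mean.

```latex
% Minimax objective of the original GAN formulation
\min_G \max_D V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}}\bigl[\log D(x)\bigr]
  + \mathbb{E}_{z \sim p_z}\bigl[\log\bigl(1 - D(G(z))\bigr)\bigr]

% Discriminator loss in the loop: loss_d_real + loss_d_fake
L_D = -\,\mathbb{E}_{x}\bigl[\log D(x)\bigr]
      - \mathbb{E}_{z}\bigl[\log\bigl(1 - D(G(z))\bigr)\bigr]

% Generator loss in the loop: BCE against real_labels
L_G = -\,\mathbb{E}_{z}\bigl[\log D(G(z))\bigr]
```

Note that $L_G$ is the commonly used non-saturating variant: instead of minimizing $\log(1 - D(G(z)))$ directly, the generator maximizes $\log D(G(z))$, which is what computing the loss against `real_labels` amounts to.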
