已经是最新一篇文章了!
已经是最后一篇文章了!
GAN原理与变体详解
前言
生成对抗网络(GAN)通过对抗训练的方式学习数据分布,能够生成逼真的图像、文本等内容。本文介绍GAN的原理、训练技巧和主要变体。
GAN基本原理
对抗训练
GAN由两个网络组成:
- 生成器(Generator):从噪声生成假样本
- 判别器(Discriminator):区分真假样本
import numpy as np
import matplotlib.pyplot as plt
# Fix NumPy's global RNG seed so the sampled histograms below are reproducible.
np.random.seed(42)
# Visualize the GAN training process (toy 1-D distributions).
def visualize_gan_training():
    """Illustrate how a generator's output distribution moves toward the real
    data distribution during GAN training, using toy 1-D Gaussian samples
    shown as three side-by-side histograms (real / untrained / trained).
    """
    fig, panels = plt.subplots(1, 3, figsize=(15, 4))

    # Samples from the "real" data distribution: N(2, 0.5^2).
    real_samples = np.random.randn(500) * 0.5 + 2
    panel = panels[0]
    panel.hist(real_samples, bins=30, alpha=0.7, label='真实数据', color='blue')
    panel.set_title('真实数据分布')
    panel.legend()

    # An untrained generator: wide N(0, 2^2), far from the data.
    panel = panels[1]
    untrained_samples = np.random.randn(500) * 2
    panel.hist(untrained_samples, bins=30, alpha=0.7, label='初始生成', color='red')
    panel.hist(real_samples, bins=30, alpha=0.3, label='真实数据', color='blue')
    panel.set_title('训练初期')
    panel.legend()

    # After training: N(2, 0.6^2), nearly matching the real distribution.
    panel = panels[2]
    trained_samples = np.random.randn(500) * 0.6 + 2
    panel.hist(trained_samples, bins=30, alpha=0.7, label='训练后生成', color='red')
    panel.hist(real_samples, bins=30, alpha=0.3, label='真实数据', color='blue')
    panel.set_title('训练后期')
    panel.legend()

    plt.tight_layout()
    plt.show()

visualize_gan_training()
损失函数
GAN的目标函数(极小极大博弈):
\[\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{data}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]\]def gan_loss_visualization():
"""可视化GAN损失函数"""
d_output = np.linspace(0.01, 0.99, 100)
# 判别器对真实样本的损失: -log(D(x))
d_real_loss = -np.log(d_output)
# 判别器对假样本的损失: -log(1-D(G(z)))
d_fake_loss = -np.log(1 - d_output)
# 生成器损失: -log(D(G(z)))
g_loss = -np.log(d_output)
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
ax = axes[0]
ax.plot(d_output, d_real_loss, label='真实样本损失 -log(D(x))', color='blue')
ax.plot(d_output, d_fake_loss, label='假样本损失 -log(1-D(G(z)))', color='red')
ax.set_xlabel('判别器输出 D')
ax.set_ylabel('损失')
ax.set_title('判别器损失')
ax.legend()
ax.grid(True, alpha=0.3)
ax = axes[1]
ax.plot(d_output, g_loss, label='生成器损失 -log(D(G(z)))', color='green')
ax.set_xlabel('判别器输出 D(G(z))')
ax.set_ylabel('损失')
ax.set_title('生成器损失')
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
gan_loss_visualization()
PyTorch实现
简单GAN
try:
    import torch
    import torch.nn as nn
    import torch.optim as optim

    class Generator(nn.Module):
        """MLP generator: maps a latent noise vector to a flat image in [-1, 1]."""

        def __init__(self, latent_dim=100, output_dim=784):
            super().__init__()
            # Widening MLP trunk; Tanh keeps outputs in [-1, 1].
            layers = []
            in_features = latent_dim
            for hidden in (256, 512, 1024):
                layers += [
                    nn.Linear(in_features, hidden),
                    nn.LeakyReLU(0.2),
                    nn.BatchNorm1d(hidden),
                ]
                in_features = hidden
            layers += [nn.Linear(in_features, output_dim), nn.Tanh()]
            self.model = nn.Sequential(*layers)

        def forward(self, z):
            return self.model(z)

    class Discriminator(nn.Module):
        """MLP discriminator: scores a flat image with a real-vs-fake probability."""

        def __init__(self, input_dim=784):
            super().__init__()
            # Narrowing MLP with dropout; Sigmoid yields a probability in (0, 1).
            layers = []
            in_features = input_dim
            for hidden in (512, 256):
                layers += [
                    nn.Linear(in_features, hidden),
                    nn.LeakyReLU(0.2),
                    nn.Dropout(0.3),
                ]
                in_features = hidden
            layers += [nn.Linear(in_features, 1), nn.Sigmoid()]
            self.model = nn.Sequential(*layers)

        def forward(self, x):
            return self.model(x)

    # Smoke test: push a noise batch through G, then score it with D.
    latent_dim = 100
    G = Generator(latent_dim=latent_dim)
    D = Discriminator()
    z = torch.randn(32, latent_dim)
    fake_images = G(z)
    d_output = D(fake_images)
    print(f"噪声输入: {z.shape}")
    print(f"生成图像: {fake_images.shape}")
    print(f"判别器输出: {d_output.shape}")
except ImportError:
    print("PyTorch未安装")
训练GAN
try:
    def train_gan(G, D, train_loader, epochs=50, latent_dim=100, lr=0.0002):
        """Alternating GAN training loop.

        For each batch: (1) update the discriminator D on a real batch and a
        detached fake batch, then (2) update the generator G so that D labels
        fresh fakes as real.  Returns per-epoch mean losses (G_losses, D_losses).
        """
        criterion = nn.BCELoss()
        # betas=(0.5, 0.999) follows the DCGAN training recommendation.
        opt_g = optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
        opt_d = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))

        g_history, d_history = [], []
        n_batches = len(train_loader)
        for epoch in range(epochs):
            g_total = 0.0
            d_total = 0.0
            for (real_batch,) in train_loader:
                n = real_batch.size(0)
                ones = torch.ones(n, 1)
                zeros = torch.zeros(n, 1)

                # ---- discriminator step ----
                opt_d.zero_grad()
                loss_real = criterion(D(real_batch), ones)
                noise = torch.randn(n, latent_dim)
                fakes = G(noise)
                # detach() stops gradients from flowing into G during D's update
                loss_fake = criterion(D(fakes.detach()), zeros)
                d_loss = loss_real + loss_fake
                d_loss.backward()
                opt_d.step()

                # ---- generator step ----
                opt_g.zero_grad()
                noise = torch.randn(n, latent_dim)
                g_loss = criterion(D(G(noise)), ones)
                g_loss.backward()
                opt_g.step()

                g_total += g_loss.item()
                d_total += d_loss.item()

            g_history.append(g_total / n_batches)
            d_history.append(d_total / n_batches)
            if (epoch + 1) % 10 == 0:
                print(f"Epoch {epoch+1}/{epochs}, D_loss: {d_history[-1]:.4f}, G_loss: {g_history[-1]:.4f}")
        return g_history, d_history

    # Build a synthetic dataset of flat 784-dim "images".
    from torch.utils.data import DataLoader, TensorDataset
    X_train = torch.randn(1000, 784)
    train_dataset = TensorDataset(X_train)
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

    # Short demonstration run.
    G = Generator()
    D = Discriminator()
    G_losses, D_losses = train_gan(G, D, train_loader, epochs=20)
except NameError:
    print("需要先定义Generator和Discriminator")
深度卷积GAN(DCGAN)
try:
    class DCGenerator(nn.Module):
        """DCGAN-style generator producing 28x28 images from a latent vector."""

        def __init__(self, latent_dim=100, channels=1):
            super().__init__()
            self.model = nn.Sequential(
                # (latent_dim, 1, 1) -> (256, 7, 7)
                nn.ConvTranspose2d(latent_dim, 256, 7, 1, 0, bias=False),
                nn.BatchNorm2d(256),
                nn.ReLU(True),
                # (256, 7, 7) -> (128, 14, 14)
                nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
                nn.BatchNorm2d(128),
                nn.ReLU(True),
                # (128, 14, 14) -> (channels, 28, 28)
                nn.ConvTranspose2d(128, channels, 4, 2, 1, bias=False),
                nn.Tanh(),
            )

        def forward(self, z):
            # Lift the flat latent vector to a 1x1 spatial map before deconvs.
            return self.model(z.view(z.size(0), -1, 1, 1))

    class DCDiscriminator(nn.Module):
        """DCGAN-style discriminator scoring 28x28 images."""

        def __init__(self, channels=1):
            super().__init__()
            self.model = nn.Sequential(
                # (channels, 28, 28) -> (64, 14, 14)
                nn.Conv2d(channels, 64, 4, 2, 1, bias=False),
                nn.LeakyReLU(0.2, inplace=True),
                # (64, 14, 14) -> (128, 7, 7)
                nn.Conv2d(64, 128, 4, 2, 1, bias=False),
                nn.BatchNorm2d(128),
                nn.LeakyReLU(0.2, inplace=True),
                # (128, 7, 7) -> a single probability per image
                nn.Conv2d(128, 1, 7, 1, 0, bias=False),
                nn.Sigmoid(),
            )

        def forward(self, x):
            return self.model(x).view(-1, 1)

    # Smoke test with an 8-sample noise batch.
    dc_G = DCGenerator()
    dc_D = DCDiscriminator()
    z = torch.randn(8, 100)
    fake = dc_G(z)
    score = dc_D(fake)
    print(f"DCGAN生成图像: {fake.shape}")
    print(f"判别器得分: {score.shape}")
except NameError:
    print("需要先导入PyTorch")
GAN变体
主要变体对比
| 变体 | 特点 | 解决问题 |
|---|---|---|
| DCGAN | 卷积结构 | 图像生成 |
| WGAN | Wasserstein距离 | 训练稳定性 |
| CGAN | 条件生成 | 可控生成 |
| StyleGAN | 风格控制 | 高质量人脸 |
| CycleGAN | 无配对翻译 | 图像转换 |
WGAN损失
def wasserstein_loss_visualization():
    """Print a short summary of the WGAN objective and its key improvements."""
    summary = (
        "WGAN损失函数:",
        " 判别器(Critic)损失: E[D(x)] - E[D(G(z))]",
        " 生成器损失: -E[D(G(z))]",
        "\n关键改进:",
        " - 移除Sigmoid,输出无界",
        " - 使用权重裁剪或梯度惩罚",
        " - 更稳定的训练",
    )
    for line in summary:
        print(line)

wasserstein_loss_visualization()
条件GAN(CGAN)
try:
    class ConditionalGenerator(nn.Module):
        """CGAN generator: produces a sample conditioned on a class label."""

        def __init__(self, latent_dim=100, num_classes=10, output_dim=784):
            super().__init__()
            # Learned label embedding, concatenated with the noise vector.
            self.label_emb = nn.Embedding(num_classes, num_classes)
            self.model = nn.Sequential(
                nn.Linear(latent_dim + num_classes, 256),
                nn.LeakyReLU(0.2),
                nn.BatchNorm1d(256),
                nn.Linear(256, 512),
                nn.LeakyReLU(0.2),
                nn.BatchNorm1d(512),
                nn.Linear(512, output_dim),
                nn.Tanh(),
            )

        def forward(self, z, labels):
            # Condition the noise on the embedded label before the MLP.
            conditioned = torch.cat([z, self.label_emb(labels)], dim=1)
            return self.model(conditioned)

    class ConditionalDiscriminator(nn.Module):
        """CGAN discriminator: judges a sample together with its class label."""

        def __init__(self, input_dim=784, num_classes=10):
            super().__init__()
            self.label_emb = nn.Embedding(num_classes, num_classes)
            self.model = nn.Sequential(
                nn.Linear(input_dim + num_classes, 512),
                nn.LeakyReLU(0.2),
                nn.Dropout(0.3),
                nn.Linear(512, 256),
                nn.LeakyReLU(0.2),
                nn.Dropout(0.3),
                nn.Linear(256, 1),
                nn.Sigmoid(),
            )

        def forward(self, x, labels):
            conditioned = torch.cat([x, self.label_emb(labels)], dim=1)
            return self.model(conditioned)

    # Smoke test: generate 8 label-conditioned samples and score them.
    cG = ConditionalGenerator()
    cD = ConditionalDiscriminator()
    z = torch.randn(8, 100)
    labels = torch.randint(0, 10, (8,))
    fake = cG(z, labels)
    score = cD(fake, labels)
    print(f"条件生成: {fake.shape}")
    print(f"标签: {labels}")
except NameError:
    print("需要先导入PyTorch")
训练技巧
常见问题与解决
| 问题 | 解决方案 |
|---|---|
| 模式崩溃 | Mini-batch discrimination, Unrolled GAN |
| 训练不稳定 | 使用WGAN, 梯度惩罚 |
| 梯度消失 | 使用LeakyReLU, 标签平滑 |
| 生成质量差 | 更深网络, 渐进式训练 |
# Practical GAN training tips, printed as one block.
print("\n".join((
    "GAN训练技巧:",
    "1. 标签平滑: 使用0.9代替1.0",
    "2. 噪声标签: 偶尔翻转标签",
    "3. 两步训练: 先训练D多次,再训练G",
    "4. 谱归一化: 限制判别器权重",
    "5. 特征匹配: 匹配中间层特征",
)))
常见问题
Q1: GAN训练为什么困难?
- 需要平衡G和D
- 容易模式崩溃
- 评估困难
Q2: 如何评估GAN质量?
- FID(Fréchet Inception Distance)
- IS(Inception Score)
- 人工评估
Q3: GAN和VAE的区别?
| 特性 | GAN | VAE |
|---|---|---|
| 训练方式 | 对抗 | 最大似然 |
| 生成质量 | 更清晰 | 可能模糊 |
| 训练稳定性 | 不稳定 | 稳定 |
| 潜在空间 | 无结构 | 有结构 |
Q4: 如何解决模式崩溃?
- 增加判别器多样性
- 使用不同的损失函数
- Mini-batch discrimination
总结
| 概念 | 描述 |
|---|---|
| 生成器 | 从噪声生成样本 |
| 判别器 | 区分真假样本 |
| 对抗训练 | 两个网络相互博弈 |
| 模式崩溃 | 生成多样性不足 |
参考资料
- Goodfellow, I. et al. (2014). “Generative Adversarial Nets”
- Radford, A. et al. (2015). “Unsupervised Representation Learning with DCGANs”
- Arjovsky, M. et al. (2017). “Wasserstein GAN”
- Karras, T. et al. (2019). “A Style-Based Generator Architecture for GANs”
版权声明: 如无特别声明,本文版权归 sshipanoo 所有,转载请注明本文链接。
(采用 CC BY-NC-SA 4.0 许可协议进行授权)
本文标题:《 机器学习基础系列——生成对抗网络 》
本文链接:http://localhost:3015/ai/%E7%94%9F%E6%88%90%E5%AF%B9%E6%8A%97%E7%BD%91%E7%BB%9C.html
本文最后一次更新时间较早,文章中的某些内容可能已过时!