GAN原理与变体详解

前言

生成对抗网络(GAN)通过对抗训练的方式学习数据分布,能够生成逼真的图像、文本等内容。本文介绍GAN的原理、训练技巧和主要变体。


GAN基本原理

对抗训练

GAN由两个网络组成:

  • 生成器(Generator):从噪声生成假样本
  • 判别器(Discriminator):区分真假样本
import numpy as np
import matplotlib.pyplot as plt

# Fix the global NumPy RNG seed so all sampled distributions below are reproducible.
np.random.seed(42)

# Visualize the GAN training process as three histogram snapshots
def visualize_gan_training():
    """Plot real vs. generated 1-D distributions at three training stages."""
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))

    # The "real" data: a narrow Gaussian centered at 2.
    real_samples = np.random.randn(500) * 0.5 + 2

    # Panel 1: the real data distribution alone.
    axes[0].hist(real_samples, bins=30, alpha=0.7, label='真实数据', color='blue')
    axes[0].set_title('真实数据分布')
    axes[0].legend()

    # Panel 2: an untrained generator produces a wide, misaligned distribution.
    early_fake = np.random.randn(500) * 2
    axes[1].hist(early_fake, bins=30, alpha=0.7, label='初始生成', color='red')
    axes[1].hist(real_samples, bins=30, alpha=0.3, label='真实数据', color='blue')
    axes[1].set_title('训练初期')
    axes[1].legend()

    # Panel 3: after training, the generated distribution closely matches the real one.
    late_fake = np.random.randn(500) * 0.6 + 2
    axes[2].hist(late_fake, bins=30, alpha=0.7, label='训练后生成', color='red')
    axes[2].hist(real_samples, bins=30, alpha=0.3, label='真实数据', color='blue')
    axes[2].set_title('训练后期')
    axes[2].legend()

    plt.tight_layout()
    plt.show()

visualize_gan_training()

损失函数

GAN的目标函数(极小极大博弈):

\[\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{data}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]\]
def gan_loss_visualization():
    """Plot the vanilla GAN loss terms as functions of the discriminator output."""

    probs = np.linspace(0.01, 0.99, 100)

    # Discriminator loss on real samples: -log(D(x)).
    loss_d_real = -np.log(probs)

    # Discriminator loss on fakes: -log(1 - D(G(z))).
    loss_d_fake = -np.log(1 - probs)

    # Non-saturating generator loss: -log(D(G(z))).
    loss_g = -np.log(probs)

    fig, (ax_d, ax_g) = plt.subplots(1, 2, figsize=(14, 5))

    # Left panel: both discriminator loss terms over D's output probability.
    ax_d.plot(probs, loss_d_real, label='真实样本损失 -log(D(x))', color='blue')
    ax_d.plot(probs, loss_d_fake, label='假样本损失 -log(1-D(G(z)))', color='red')
    ax_d.set_xlabel('判别器输出 D')
    ax_d.set_ylabel('损失')
    ax_d.set_title('判别器损失')
    ax_d.legend()
    ax_d.grid(True, alpha=0.3)

    # Right panel: the generator loss over D(G(z)).
    ax_g.plot(probs, loss_g, label='生成器损失 -log(D(G(z)))', color='green')
    ax_g.set_xlabel('判别器输出 D(G(z))')
    ax_g.set_ylabel('损失')
    ax_g.set_title('生成器损失')
    ax_g.legend()
    ax_g.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

gan_loss_visualization()

PyTorch实现

简单GAN

try:
    import torch
    import torch.nn as nn
    import torch.optim as optim

    class Generator(nn.Module):
        """Fully-connected generator: latent noise -> flattened sample in [-1, 1]."""

        def __init__(self, latent_dim=100, output_dim=784):
            super().__init__()

            widths = [latent_dim, 256, 512, 1024]
            layers = []
            # Hidden stack: Linear -> LeakyReLU -> BatchNorm at each width.
            for n_in, n_out in zip(widths[:-1], widths[1:]):
                layers += [nn.Linear(n_in, n_out), nn.LeakyReLU(0.2), nn.BatchNorm1d(n_out)]
            # Output head: Tanh squashes each component into [-1, 1].
            layers += [nn.Linear(widths[-1], output_dim), nn.Tanh()]
            self.model = nn.Sequential(*layers)

        def forward(self, z):
            return self.model(z)


    class Discriminator(nn.Module):
        """Fully-connected discriminator: flattened sample -> probability of being real."""

        def __init__(self, input_dim=784):
            super().__init__()

            layers = []
            # Hidden stack: Linear -> LeakyReLU -> Dropout at each width.
            for n_in, n_out in [(input_dim, 512), (512, 256)]:
                layers += [nn.Linear(n_in, n_out), nn.LeakyReLU(0.2), nn.Dropout(0.3)]
            # Sigmoid head: single probability output.
            layers += [nn.Linear(256, 1), nn.Sigmoid()]
            self.model = nn.Sequential(*layers)

        def forward(self, x):
            return self.model(x)

    # Smoke test: push one noise batch through G, then score it with D.
    latent_dim = 100
    G = Generator(latent_dim=latent_dim)
    D = Discriminator()

    z = torch.randn(32, latent_dim)
    fake_images = G(z)
    d_output = D(fake_images)

    print(f"噪声输入: {z.shape}")
    print(f"生成图像: {fake_images.shape}")
    print(f"判别器输出: {d_output.shape}")

except ImportError:
    print("PyTorch未安装")

训练GAN

try:
    def train_gan(G, D, train_loader, epochs=50, latent_dim=100, lr=0.0002):
        """Train a GAN with the standard alternating scheme.

        Each batch first updates the discriminator on real samples (target 1)
        and detached fakes (target 0), then updates the generator so the
        discriminator outputs 1 on fresh fakes (non-saturating G loss).

        Args:
            G: generator module mapping (batch, latent_dim) noise to samples.
            D: discriminator module mapping samples to probabilities in (0, 1).
            train_loader: iterable yielding 1-tuples of real-sample batches.
            epochs: number of passes over train_loader.
            latent_dim: size of the noise vector fed to G.
            lr: Adam learning rate shared by both optimizers.

        Returns:
            Tuple (G_losses, D_losses) of per-epoch mean losses.
        """

        criterion = nn.BCELoss()
        # betas=(0.5, 0.999): the DCGAN-style Adam setting (reduced momentum).
        optimizer_G = optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
        optimizer_D = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))

        G_losses = []
        D_losses = []

        for epoch in range(epochs):
            g_loss_epoch = 0
            d_loss_epoch = 0

            # train_loader yields 1-tuples, hence the trailing-comma unpacking.
            for real_images, in train_loader:
                batch_size = real_images.size(0)

                # Target labels: 1 for real samples, 0 for fakes.
                real_labels = torch.ones(batch_size, 1)
                fake_labels = torch.zeros(batch_size, 1)

                # ========== Train the discriminator ==========
                optimizer_D.zero_grad()

                # Real samples should be classified as real.
                d_real = D(real_images)
                d_loss_real = criterion(d_real, real_labels)

                # Fake samples should be classified as fake; detach() stops
                # gradients from flowing back into the generator here.
                z = torch.randn(batch_size, latent_dim)
                fake_images = G(z)
                d_fake = D(fake_images.detach())
                d_loss_fake = criterion(d_fake, fake_labels)

                d_loss = d_loss_real + d_loss_fake
                d_loss.backward()
                optimizer_D.step()

                # ========== Train the generator ==========
                optimizer_G.zero_grad()

                # Non-saturating loss: push D's output on fakes toward 1.
                z = torch.randn(batch_size, latent_dim)
                fake_images = G(z)
                d_fake = D(fake_images)
                g_loss = criterion(d_fake, real_labels)

                g_loss.backward()
                optimizer_G.step()

                g_loss_epoch += g_loss.item()
                d_loss_epoch += d_loss.item()

            G_losses.append(g_loss_epoch / len(train_loader))
            D_losses.append(d_loss_epoch / len(train_loader))

            if (epoch + 1) % 10 == 0:
                print(f"Epoch {epoch+1}/{epochs}, D_loss: {D_losses[-1]:.4f}, G_loss: {G_losses[-1]:.4f}")

        return G_losses, D_losses

    # Build a synthetic dataset of 1000 flattened 784-dim samples.
    from torch.utils.data import DataLoader, TensorDataset
    X_train = torch.randn(1000, 784)
    train_dataset = TensorDataset(X_train)
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

    # Shortened training demo (20 epochs instead of the default 50).
    G = Generator()
    D = Discriminator()
    G_losses, D_losses = train_gan(G, D, train_loader, epochs=20)

except NameError:
    # Raised when Generator/Discriminator (or torch) were not defined above.
    print("需要先定义Generator和Discriminator")

深度卷积GAN(DCGAN)

try:
    class DCGenerator(nn.Module):
        """DCGAN generator: latent vector -> (channels, 28, 28) image in [-1, 1]."""

        def __init__(self, latent_dim=100, channels=1):
            super().__init__()

            stages = [
                # latent_dim x 1 x 1 -> 256 x 7 x 7
                nn.ConvTranspose2d(latent_dim, 256, 7, 1, 0, bias=False),
                nn.BatchNorm2d(256),
                nn.ReLU(True),
                # 256 x 7 x 7 -> 128 x 14 x 14
                nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
                nn.BatchNorm2d(128),
                nn.ReLU(True),
                # 128 x 14 x 14 -> channels x 28 x 28
                nn.ConvTranspose2d(128, channels, 4, 2, 1, bias=False),
                nn.Tanh(),
            ]
            self.model = nn.Sequential(*stages)

        def forward(self, z):
            # Reshape flat latent vectors to (N, latent_dim, 1, 1) feature maps.
            batch, dim = z.size(0), z.size(1)
            return self.model(z.view(batch, dim, 1, 1))


    class DCDiscriminator(nn.Module):
        """DCGAN discriminator: (channels, 28, 28) image -> probability of being real."""

        def __init__(self, channels=1):
            super().__init__()

            stages = [
                # channels x 28 x 28 -> 64 x 14 x 14 (no BatchNorm on the first conv)
                nn.Conv2d(channels, 64, 4, 2, 1, bias=False),
                nn.LeakyReLU(0.2, inplace=True),
                # 64 x 14 x 14 -> 128 x 7 x 7
                nn.Conv2d(64, 128, 4, 2, 1, bias=False),
                nn.BatchNorm2d(128),
                nn.LeakyReLU(0.2, inplace=True),
                # 128 x 7 x 7 -> 1 x 1 x 1
                nn.Conv2d(128, 1, 7, 1, 0, bias=False),
                nn.Sigmoid(),
            ]
            self.model = nn.Sequential(*stages)

        def forward(self, x):
            # Flatten the 1x1 spatial score map to an (N, 1) tensor.
            return self.model(x).view(-1, 1)

    # Smoke test: generate a small batch of images and score them.
    dc_G = DCGenerator()
    dc_D = DCDiscriminator()

    z = torch.randn(8, 100)
    fake = dc_G(z)
    score = dc_D(fake)

    print(f"DCGAN生成图像: {fake.shape}")
    print(f"判别器得分: {score.shape}")

except NameError:
    print("需要先导入PyTorch")

GAN变体

主要变体对比

| 变体 | 特点 | 解决问题 |
| --- | --- | --- |
| DCGAN | 卷积结构 | 图像生成 |
| WGAN | Wasserstein距离 | 训练稳定性 |
| CGAN | 条件生成 | 可控生成 |
| StyleGAN | 风格控制 | 高质量人脸 |
| CycleGAN | 无配对翻译 | 图像转换 |

WGAN损失

def wasserstein_loss_visualization():
    """Print a summary of the WGAN objective and its key improvements."""

    # WGAN replaces the original GAN objective with the Wasserstein distance,
    # which mitigates vanishing gradients when the critic is well trained.
    summary = [
        "WGAN损失函数:",
        "  判别器(Critic)损失: E[D(x)] - E[D(G(z))]",
        "  生成器损失: -E[D(G(z))]",
        "\n关键改进:",
        "  - 移除Sigmoid,输出无界",
        "  - 使用权重裁剪或梯度惩罚",
        "  - 更稳定的训练",
    ]
    for line in summary:
        print(line)

wasserstein_loss_visualization()

条件GAN(CGAN)

try:
    class ConditionalGenerator(nn.Module):
        """Conditional generator: (noise, class label) -> flattened sample in [-1, 1]."""

        def __init__(self, latent_dim=100, num_classes=10, output_dim=784):
            super().__init__()

            # Learned label embedding, concatenated with the noise vector.
            self.label_emb = nn.Embedding(num_classes, num_classes)

            layers = []
            # Hidden stack: Linear -> LeakyReLU -> BatchNorm at each width.
            for n_in, n_out in [(latent_dim + num_classes, 256), (256, 512)]:
                layers += [nn.Linear(n_in, n_out), nn.LeakyReLU(0.2), nn.BatchNorm1d(n_out)]
            layers += [nn.Linear(512, output_dim), nn.Tanh()]
            self.model = nn.Sequential(*layers)

        def forward(self, z, labels):
            # Condition the noise by concatenating the embedded class label.
            conditioned = torch.cat([z, self.label_emb(labels)], dim=1)
            return self.model(conditioned)


    class ConditionalDiscriminator(nn.Module):
        """Conditional discriminator: (sample, class label) -> probability of being real."""

        def __init__(self, input_dim=784, num_classes=10):
            super().__init__()

            self.label_emb = nn.Embedding(num_classes, num_classes)

            layers = []
            # Hidden stack: Linear -> LeakyReLU -> Dropout at each width.
            for n_in, n_out in [(input_dim + num_classes, 512), (512, 256)]:
                layers += [nn.Linear(n_in, n_out), nn.LeakyReLU(0.2), nn.Dropout(0.3)]
            layers += [nn.Linear(256, 1), nn.Sigmoid()]
            self.model = nn.Sequential(*layers)

        def forward(self, x, labels):
            # Score the sample jointly with its embedded class label.
            conditioned = torch.cat([x, self.label_emb(labels)], dim=1)
            return self.model(conditioned)

    # Smoke test: generate class-conditioned samples and score them.
    cG = ConditionalGenerator()
    cD = ConditionalDiscriminator()

    z = torch.randn(8, 100)
    labels = torch.randint(0, 10, (8,))
    fake = cG(z, labels)
    score = cD(fake, labels)

    print(f"条件生成: {fake.shape}")
    print(f"标签: {labels}")

except NameError:
    print("需要先导入PyTorch")

训练技巧

常见问题与解决

| 问题 | 解决方案 |
| --- | --- |
| 模式崩溃 | Mini-batch discrimination, Unrolled GAN |
| 训练不稳定 | 使用WGAN, 梯度惩罚 |
| 梯度消失 | 使用LeakyReLU, 标签平滑 |
| 生成质量差 | 更深网络, 渐进式训练 |
# Print a numbered list of practical GAN training tricks.
print("GAN训练技巧:")
for _tip in (
    "1. 标签平滑: 使用0.9代替1.0",
    "2. 噪声标签: 偶尔翻转标签",
    "3. 两步训练: 先训练D多次,再训练G",
    "4. 谱归一化: 限制判别器权重",
    "5. 特征匹配: 匹配中间层特征",
):
    print(_tip)

常见问题

Q1: GAN训练为什么困难?

  • 需要平衡G和D
  • 容易模式崩溃
  • 评估困难

Q2: 如何评估GAN质量?

  • FID(Fréchet Inception Distance)
  • IS(Inception Score)
  • 人工评估

Q3: GAN和VAE的区别?

| 特性 | GAN | VAE |
| --- | --- | --- |
| 训练方式 | 对抗 | 最大似然 |
| 生成质量 | 更清晰 | 可能模糊 |
| 训练稳定性 | 不稳定 | 稳定 |
| 潜在空间 | 无结构 | 有结构 |

Q4: 如何解决模式崩溃?

  • 增加判别器多样性
  • 使用不同的损失函数
  • Mini-batch discrimination

总结

| 概念 | 描述 |
| --- | --- |
| 生成器 | 从噪声生成样本 |
| 判别器 | 区分真假样本 |
| 对抗训练 | 两个网络相互博弈 |
| 模式崩溃 | 生成多样性不足 |

参考资料

  • Goodfellow, I. et al. (2014). “Generative Adversarial Nets”
  • Radford, A. et al. (2015). “Unsupervised Representation Learning with DCGANs”
  • Arjovsky, M. et al. (2017). “Wasserstein GAN”
  • Karras, T. et al. (2019). “A Style-Based Generator Architecture for GANs”

版权声明: 如无特别声明,本文版权归 sshipanoo 所有,转载请注明本文链接。

(采用 CC BY-NC-SA 4.0 许可协议进行授权)

本文标题:《 机器学习基础系列——生成对抗网络 》

本文链接:http://localhost:3015/ai/%E7%94%9F%E6%88%90%E5%AF%B9%E6%8A%97%E7%BD%91%E7%BB%9C.html

本文最后一次更新时间较早,文章中的某些内容可能已过时!