生成对抗网络(Generative Adversarial Nets,GAN)

核心思想

生成对抗网络是一种基于对抗学习的深度生成模型,最早由Ian Goodfellow于2014年在《Generative Adversarial Nets》中提出,一经提出便成为了学术界研究的热点,也将生成模型的热度推向了另一个新的高峰。上节有讨论到,直接用图片做监督存带来均值灾难,我们又无法得到真实分布从而监督训练。因此,借助变分推断的思想做一个概率分布近似。从一个简单的已知分布(如标准高斯分布)出发,通过某种方式或手段,将其近似为真实数据的概率分布。GAN正是遵循这一理论,但实现过程中直接对齐分布是很难的,因为我们并不知道概率分布的函数形式,所以也无法得知它到底有几个参数。

所以可以换一个思想,既然无法得到概率分布函数的具体形式,没有参数,不好近似,那我就不去近似他了。对于两个分布而言,如果它们的大多数随机采样的样本概率都是对齐的,那不就说明这两个概率分布函数已经接近了吗。很好,你已经掌握了生成对抗网络的要领,试着自己实现一下吧。(-_-||)

网络架构

生成对抗网络采用双网络架构设计,由生成器(Generator, G)和判别器(Discriminator, D)两个神经网络组成,它们在训练过程中相互对抗、共同进化。

生成器网络:作为整个系统的"创造者",生成器的任务是学习从简单的噪声分布(通常是标准高斯分布)到复杂数据分布的映射关系。具体而言,它接收一个低维的随机噪声向量 \(z \sim p_z(z)\) 作为输入,通过多层神经网络的非线性变换,输出与真实数据维度相同的生成样本 \(G(z)\)。生成器的目标是使生成的样本在分布上尽可能接近真实数据分布 \(p_{data}(x)\)

判别器网络:作为系统的"鉴别专家",判别器本质上是一个二分类器。它接收来自两个不同源的样本——真实数据集中的样本 \(x \sim p_{data}(x)\) 和生成器产生的样本 \(G(z)\),输出一个概率值 \(D(x) \in [0,1]\),表示输入样本来自真实数据分布的概率。判别器的目标是最大化正确分类的概率:对真实样本输出接近1,对生成样本输出接近0。

GAN Net

这种架构设计的巧妙之处在于,通过样本对齐来实现分布对齐。如果生成器产生的样本能够骗过判别器,说明生成分布已经在某种程度上接近了真实分布,而无需显式地建模分布函数。随着训练的进行,两个网络在对抗中相互提升,最终达到一个动态平衡状态。实现了样本概率之间的对齐,而不是样本对齐。(上一节说了直接图像对齐带来均值灾难,其本质原因是训练过程对齐的是样本而不是样本的概率)

对抗训练机制

GAN 的损失函数核心思想在于 对抗训练,生成器 G 和判别器 D 进行一场零和博弈:

  • 生成器(Generator):试图生成足以 "以假乱真" 的样本,欺骗判别器
  • 判别器(Discriminator):试图准确区分真实样本和生成样本

这种对抗关系可以用一个形象的比喻来理解:生成器就像一个造假币的人,而判别器就像一个验钞机。造假者不断提高造假技术,而验钞机也在不断提升鉴别能力。在这个博弈过程中,两者的能力都在不断提升。论文所述训练步骤如图所示。

GAN Train

训练步骤

  • 首先,使用交叉熵损失函数来训练判别器参数 \(\theta_d\)。从高斯分布 \(p_g(z)\) 中采样 \(m\) 个噪声向量:\(\{ z^{(1)}, \ldots, z^{(m)} \}\)。从数据分布 \(p_{\text{data}}(x)\) 中采样 \(m\) 个真实样本:\(\{ x^{(1)}, \ldots, x^{(m)} \}\)。使用交叉熵损失训练判别器网络 \(\theta_d\)

    \[ \mathop{\arg\min}_{\theta_d} \left( \frac{1}{m} \sum_{i=1}^{m} \left[ \log D(x^{(i)}) + \log \left( 1 - D(G(z^{(i)})) \right) \right] \right) \]

    从损失函数可以看出,目标是使 \(\log D(x^{(i)})\) 趋于 1,\(D(G(z^{(i)}))\) 趋于 0,以达到判别器识别生成数据的目的。

  • 判别器更新后,固定判别器参数。重新再次采样 \(m\) 个噪声向量:\(\{ z^{(1)}, \ldots, z^{(m)} \}\)。最小化判别损失以使生成器能够骗过判别器(即让生成的样本更像真实样本):

    \[ \mathop{\arg\min}_{\theta_g} \left( \frac{1}{m} \sum_{i=1}^{m} \left[ \log \left( 1 - D(G(z^{(i)})) \right) \right] \right) \]

    💡 在实际应用中,为了提升生成器的梯度,也常使用替代形式:\(-\log(D(G(z)))\),其目标是让判别器认为生成样本为真。

  • 重复以上步骤,直到判别器和生成器都收敛。

最后,整个训练的损失函数可以整合为对抗损失函数(Adversarial Loss Function / Minimax Objective)

\[ \min_G \max_D \; V(D, G) = E_{x \sim p_{\text{data}}(x)} \left[ \log D(x) \right] + E_{z \sim p_g(z)} \left[ \log \left( 1 - D(G(z)) \right) \right] \]

其中:

  • \(D(x)\) 表示判别器对真实样本 \(x\) 判定为真的概率;
  • \(D(G(z))\) 表示判别器对生成样本为真的概率;
  • \(G(z)\) 是生成器对输入噪声 \(z\) 的映射输出;
  • \(E_{x \sim p_{\text{data}}(x)}\) 是对真实数据分布的期望;
  • \(E_{z \sim p_g(z)}\) 是对生成器输入噪声分布的期望。

GAN能达到平衡的原因以及背后的数学证明

平衡原因的直观解释

这里首先给出作者画的图来解释为什么对抗训练能够收敛。图中绿色线代表生成器的概率分布,蓝色虚线代表判别器的分布,黑色虚线代表真实数据分布。字母 \(z\) 代表采样的高斯噪声,\(x\) 代表真实数据,我们从 \(z\) 中采样,通过生成器将 \(z\) 的样本映射成符合 \(x\) 分布的样本。在图\((a)\) 中,判别器和生成器的概率分布都是随机设定,大部分样本的映射都不正确。首先训练判别器,既让判别器能够识别出哪些是映射过来的数据,哪些是真实样本数据。训练完成之后就是图\((b)\) 所展示的分布。真实数据判别接近1,生成数据判别接近0。然后再训练生成器,让生成器朝着让判别器无法判别的方向移动,即为图\((c)\) ,最后经过循环拉扯,达到了图\((d)\) 的平衡。即生成器的分布和真实数据的分布相同。

GAN theory

背后的数学原理

在深入理解GAN的数学基础之前,我们需要先回顾KL散度(Kullback-Leibler Divergence)的两个关键局限性:

KL散度的局限性:

之前我们有讨论到KL散度的非负性,其实它还具有其他两个特性,即:

  1. 非对称性:KL散度不满足对称性 \[KL(P\|Q) \neq KL(Q\|P)\]

  2. 数值不稳定性:当两个分布的支撑集不重叠时,KL散度趋向于无穷大

\[\text{当 } P(x) > 0 \text{ 且 } Q(x) = 0 \text{ 时,} KL(P\|Q) \to +\infty\] \[\text{当 } Q(x) > 0 \text{ 且 } P(x) = 0 \text{ 时,} KL(P\|Q) \to -\infty\]

这两个问题在GAN训练中会导致严重的梯度消失和数值不稳定问题。因此,引入Jensen-Shannon散度(JS散度)作为更适合的距离度量。

JS 散度

\[ JS(P \| Q) = \frac{1}{2} KL(P \| M) + \frac{1}{2} KL(Q \| M) \] 其中 \(M = \frac{1}{2}(P + Q)\)

积分形式 \[ JS(P\|Q) = \frac{1}{2}\int p(x)\log\frac{2p(x)}{p(x)+q(x)}dx + \frac{1}{2}\int q(x)\log\frac{2q(x)}{p(x)+q(x)}dx \]

JS散度的重要特性:

  1. 对称性\(JS(P\|Q) = JS(Q\|P)\)

  2. 非负性\(JS(P\|Q) \geq 0\),当且仅当 \(P = Q\) 时等号成立

  3. 有界性\(0 \leq JS(P\|Q) \leq \log 2\)

  4. 平滑性:相比KL散度,JS散度在处理不重叠分布时更加稳定

但其实以上两种概率分布度量方式都有局限性,在 \(PQ\) 不重叠的情况下,KL散度趋于无穷,这会导致梯度爆炸,而JS散度趋于常值,导致梯度消失。更多感兴趣的可以看:两者分布不重合JS散度为log2的数学证明

另外,我们再介绍一下狄拉克函数:

狄拉克函数 \(\delta(x)\) 与生成器的概率分布 \(p_g(x)\) 表示

狄拉克函数 \(\delta(x)\) 并不是数学中一个严格意义上的函数,而是在泛函分析中被称为广义函数(generalized function)或分布(distribution),它在除零以外的点上都等于零,且其在整个定义域上的积分等于1。用数学定义的方式可以写为: \[ \delta(x) = \begin{cases} 0, & x \neq 0, \\ \infty, & x = 0, \end{cases} \quad\text{且}\quad \int_{-\infty}^{\infty} \delta(x) \, dx = 1. \]

以上定义方式并不严谨,感兴趣的朋友请移步泛函分析,(再展开就没完没了了)。它具有抽样性质(Sifting Property):

\[ \int_{-\infty}^{\infty} f(x) \, \delta(x-a) \, dx = f(a). \]

没错,学过信号系统的同学立马反应过来了,这他么不就是脉冲函数吗!是的,只是在不同领域叫法不同。

生成器 \(G\) 将噪声 \(z \sim p_z(z)\) 映射为样本 \(x\),那么生成的样本 \(x\) 的概率密度可以通过以下方式计算:

\[p_g(x) = \int_z p_z(z)\delta(x-G(z))dz\] 对于确定性映射 \(G\),只有在 \(x = G(z)\) 的某个 \(z\) 值处才有非零概率密度。狄拉克δ函数 \(\delta(x-G(z))\) 正好捕获了这一点:

  • \(x = G(z)\) 时,\(\delta(x-G(z)) = \infty\)
  • \(x \neq G(z)\) 时,\(\delta(x-G(z)) = 0\)

意思是我们先按 \(p_z(z)\) 采样一个潜变量 \(z\),生成器 \(G(z)\) 会把这个 \(z\) 映射成一个数据样本 \(x' = G(z)\),δ 函数 \(\delta(x - G(z))\) 表示:只有当 \(x\) 恰好等于 \(G(z)\) 时(生成的样本符合真实图像概率分布),这个积分才有贡献,所以积分结果就是——所有能生成 \(x\)\(z\) 的概率质量总和

也就是说,\(p_g(x)\)通过生成器将潜空间的分布 \(p_z(z)\) 推送(pushforward)到数据空间后得到的分布

更数学化的理解(推送分布)

这个公式本质上是 概率分布的变换公式,在测度论的语言里就是: \[ p_g = G_\# p_z \] 其中 \(G_\#\) 表示 推送测度(pushforward measure)——用生成器 \(G\)\(z\) 空间的概率分布搬到 \(x\) 空间。如果生成器是确定性的(没有噪声),那么:

  • 每个 \(z\) 只会映射到一个 \(x\)
  • δ 函数就表示了这种确定映射关系

举个简单例子

假设:

  • \(p_z(z)\) 是均匀分布在区间 \([0,1]\)
  • \(G(z) = 2z\)

那么:

\[ p_g(x) = \int_{z=0}^1 1 \cdot \delta(x - 2z) \, dz \]

用 δ 函数的性质,解得:

\[ p_g(x) = \frac{1}{2} \quad \text{当 } x \in [0, 2] \quad \text{否则 } 0 \]

这正是把 \([0,1]\) 的均匀分布线性拉伸到 \([0,2]\) 后的概率密度。

回归GAN的正题,在理想情况下,当GAN训练达到纳什均衡时

  • 判别器无法区分真假样本,即 \(D(x) = 0.5\)
  • 生成器的生成分布与真实数据分布一致,即 \(p_g = p_{data}\)

用数学表达有:

  • 对于固定生成器 \(G\),判别器的最优解为 \[ D^*(x)=\frac{p_{data}(x)}{p_{data}(x)+p_g(x)}. \]

把该最优判别器代回目标函数后,生成器的目标等价于最小化真实分布与生成分布之间的JS散度: \[ C(G)\equiv V\big(G,D^*\big) = -\log 4 + 2\,\mathrm{JS}\big(p_{data}\parallel p_g\big). \]

因为 \(JS \ge 0\) 且当且仅当 \(p_g=p_{data}\) 时为 0,所以在纳什均衡处 \(p_g=p_{data}\),此时 \(D^*(x)=\tfrac{1}{2}\)

详细证明

步骤1:求解最优判别器

对于固定的生成器G,判别器D的目标是最大化:

\[V(G,D) = E_{x \sim p_{data}}[\log D(x)] + E_{z \sim p_z}[\log(1-D(G(z)))]\]

其中 \(E\) 表示数学期望,展开得:

\[V(G,D) = \int_x p_{data}(x)\log D(x)dx + \int_z p_z(z)\log(1-D(G(z)))dz\]

因此把第二项用狄拉克函数展开有:
\[ \log(1 - D(G(z))) = \int_x \log(1 - D(x)) \, \delta\big(x - G(z)\big) \, dx. \] 把上面的表达式代回原积分:

\[ \begin{aligned} &\int_z p_z(z) \log(1 - D(G(z))) \, dz \\ &= \int_z p_z(z) \left[ \int_x \log(1 - D(x)) \, \delta(x - G(z)) \, dx \right] dz. \end{aligned} \]

交换积分顺序(Fubini 定理,在广义函数意义下成立):

\[ = \int_x \log(1 - D(x)) \left[ \int_z p_z(z) \, \delta(x - G(z)) \, dz \right] dx. \]

又因为 \(p_g(x) = \int_z p_z(z)\delta(x-G(z))dz\),代入有:

\[V(G,D) = \int_x p_{data}(x)\log D(x) + p_g(x)\log(1-D(x))dx\]

写成期望的形式: \[V(G,D) = E_{x \sim p_{data}}[\log D(x)] + E_{x \sim p_g}[\log(1-D(x))]\]

注意:\(D(x)\) 在每个 \(x\) 处独立出现,因此对于固定的 \(G\),求使 \(V(G,D)\) 最大化的 \(D\) 可以逐点(对每个 \(x\))独立求解,因此可以通过求被积函数的极值点来最大化概率。

\[\text{令} f(D) = p_{data}(x)\log D(x) + p_g(x)\log(1-D(x))\]

对D求导并令其为0: \[\frac{\partial f}{\partial D} = \frac{p_{data}(x)}{D(x)} - \frac{p_g(x)}{1-D(x)} = 0\]

二阶导数为

\[\frac{\partial^2 f}{\partial D^2} = -\frac{p_{data}(x)}{D(x)^2} - \frac{p_g(x)}{(1-D(x))^2} < 0\]

因此该临界点是全局最大值(对每个 \(x\))。若 \(p_{data}(x)=p_g(x)=0\),则该点在积分意义上不影响结果(可任意定义)。

解得最优判别器: \[D^*_G(x) = \frac{p_{data}(x)}{p_{data}(x) + p_g(x)}\]

步骤2:将最优判别器代回目标函数

\(D^*_G(x)\) 代入原始目标函数: \[ \begin{aligned} C(G) \equiv V\big(G,D^*\big) &= \int \Big[ p_{data}(x)\log D^*(x) + p_g(x)\log\big(1-D^*(x)\big) \Big] dx \\ &= \int \Big[ p_{data}\log\frac{p_{data}}{p_{data}+p_g} + p_g\log\frac{p_g}{p_{data}+p_g} \Big] dx \\ &= \int \Big[ p_{data}\log p_{data} + p_g\log p_g - (p_{data}+p_g)\log(p_{data}+p_g) \Big] dx \\ &= E_{x\sim p_{data}}\left[\log\frac{p_{data}(x)}{p_{data}(x) + p_g(x)}\right] + E_{x\sim p_g}\left[\log\frac{p_g(x)}{p_{data}(x) + p_g(x)}\right] \end{aligned} \]

步骤3:变换为JS散度

注意到: \[\log\frac{p_{data}(x)}{p_{data}(x) + p_g(x)} = \log\frac{p_{data}(x)}{2 \cdot \frac{p_{data}(x) + p_g(x)}{2}} = \log\frac{p_{data}(x)}{2M(x)} = -\log 2 + \log\frac{p_{data}(x)}{M(x)}\]

其中 \(M(x) = \frac{p_{data}(x) + p_g(x)}{2}\) 是两个分布的平均。

因此: \[C(G) = -\log 4 + KL(p_{data} \| M) + KL(p_g \| M)\]

根据JS散度的定义:

\[JS(p_{data} \| p_g) = \frac{1}{2}KL(p_{data} \| M) + \frac{1}{2}KL(p_g \| M)\]

我们得到

\[C(G) = -\log 4 + 2 \cdot JS(p_{data} \| p_g)\]

由于生成器G要最小化 \(C(G)\),而 \(\log 4\) 是常数,因此: \[\min_G C(G) \Leftrightarrow \min_G JS(p_{data} \| p_g)\]

根据JS散度的性质和公式,有:

  1. 当且仅当 \(p_g=p_{data}\) 时,JS散度为0,达到全局最小值
  2. 此时最优判别器 \(D^*_G(x) = \frac{1}{2}\),无法区分真假样本
  3. 这证明了GAN在理论上能够学习到真实数据分布

这个证明揭示了几个重要事实:

  1. GAN隐式地最小化JS散度:虽然我们没有显式计算两个分布之间的距离,但通过对抗训练,实际上在最小化JS散度

  2. 判别器的作用:判别器不仅仅是一个分类器,它实际上在估计两个分布的密度比

  3. 收敛性保证:在理想条件下(无限容量、充分训练),GAN能够收敛到真实数据分布

然而,这个理论分析基于几个理想假设(如无限容量的模型、全局最优等),在实践中难以满足,这也是GAN训练不稳定的原因之一。

GAN的优势与挑战

优势:

  1. 无需显式建模概率密度:不需要像VAE那样引入变分下界
  2. 生成质量高:能够生成非常逼真的样本
  3. 理论优雅:基于博弈论的框架简洁而有力

挑战:

  1. 训练不稳定:生成器和判别器的平衡很难把握
  2. 模式坍塌(Mode Collapse):生成器可能只学会生成某几种模式的样本
  3. 梯度消失:当判别器过于强大时,生成器的梯度会消失
  4. 评估困难:难以定量评估生成模型的质量

GAN的变体与改进

为了解决原始GAN的问题,研究者们提出了许多改进版本:

  1. DCGAN(Deep Convolutional GAN):使用卷积神经网络架构,提高了训练稳定性
  2. WGAN(Wasserstein GAN):使用Wasserstein距离替代JS散度,缓解了梯度消失问题
  3. LSGAN(Least Squares GAN):使用最小二乘损失,使生成样本更接近决策边界
  4. StyleGAN:引入风格控制,实现了高质量、可控的图像生成
  5. CycleGAN:实现了无配对数据的图像到图像转换

总结

GAN开创了生成模型的新时代,通过对抗学习的思想,巧妙地避开了直接建模概率分布的困难。虽然训练过程存在诸多挑战,但其强大的生成能力使其在图像生成、风格迁移、超分辨率等多个领域取得了巨大成功。随着各种改进技术的出现,GAN已经成为深度学习中最重要的生成模型之一。

GAN的PyTorch实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
def train_gan(generator, discriminator, dataloader, num_epochs, 
noise_dim, device, k=1, lr=0.0002):
"""
GAN训练函数

Args:
generator: 生成器模型
discriminator: 判别器模型
dataloader: 数据加载器
num_epochs: 训练轮数
noise_dim: 噪声向量维度
device: 训练设备 (cpu/cuda)
k: 判别器更新次数(论文中使用k=1)
lr: 学习率
"""
# 优化器 - 使用动量优化
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))

# 损失函数
criterion = nn.BCELoss()

generator.to(device)
discriminator.to(device)

for epoch in range(num_epochs):
for batch_idx, real_data in enumerate(dataloader):
batch_size = real_data.size(0)
real_data = real_data.to(device)

# ========== 训练判别器 k 步 ==========
for _ in range(k):
# 1. 从噪声先验采样
z = torch.randn(batch_size, noise_dim).to(device)

# 2. 从数据分布采样(已通过dataloader获得)

# 3. 更新判别器(梯度上升)
optimizer_D.zero_grad()

# 真实数据的判别器输出
real_output = discriminator(real_data)
real_labels = torch.ones(batch_size, 1).to(device)
loss_real = criterion(real_output, real_labels)

# 生成数据的判别器输出
fake_data = generator(z).detach()
fake_output = discriminator(fake_data)
fake_labels = torch.zeros(batch_size, 1).to(device)
loss_fake = criterion(fake_output, fake_labels)

# 判别器总损失:最大化 log D(x) + log(1 - D(G(z)))
d_loss = loss_real + loss_fake
d_loss.backward()
optimizer_D.step()

# ========== 训练生成器 1 步 ==========
# 从噪声先验采样
z = torch.randn(batch_size, noise_dim).to(device)

# 更新生成器(梯度下降)
optimizer_G.zero_grad()

# 生成假数据
fake_data = generator(z)
fake_output = discriminator(fake_data)

# 生成器损失:最小化 log(1 - D(G(z)))
# 等价于最大化 log D(G(z))
g_labels = torch.ones(batch_size, 1).to(device)
g_loss = criterion(fake_output, g_labels)

g_loss.backward()
optimizer_G.step()

# 打印训练信息
if batch_idx % 100 == 0:
print(f'Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(dataloader)} '
f'Loss_D: {d_loss.item():.4f} Loss_G: {g_loss.item():.4f}')

return generator, discriminator