生成式对抗网络-GAN
关于 生成式对抗网络的基础介绍 + 一种扩展的生成对抗网络🧐
对抗生成网络
模型结构
-
生成模型 G: 捕捉数据分布
-
判别模型 D: 估计样本来自训练数据还是 G的概率(判别是真的还是假的的概率)
训练方式
数据要求: 以张量的形式传入[M, N, K]
-
基于反向传播机制
-
判别器 D 的目标
-
生成器 G 的目标就是最小化 D 的目标价值,
-
在实际G的模型训练中,由于一开始的生成网络与实际网络差异大,
更趋向于 0,所以 G难以从 判别器D中获得梯度信号,难以改进生成能力,所以训练G时,可以最大化(logD(G(z)))
训练机制
交替进行 k步 D优化, 一步G优化
-
训练判别器 D
-
批量训练的平均损失函数
-
训练生成器 G
**原始目标:最小化**
-
批量训练的平均损失函数
-
梯度计算
-
参数更新
注意: 进行 G更新的时候,判别器 D的参数是冻结的
CycleGAN
模型结构
-
两个生成模型 G 和 F:
- 生成器 G: 学习将图像从域 X 转换到域 Y 的映射,即
- 生成器 F: 学习将图像从域 Y 转换到域 X 的映射,即
- 生成器 G: 学习将图像从域 X 转换到域 Y 的映射,即
-
两个判别模型
和 : - 判别器
: 区分图像是来自目标域 Y 的真实图像还是由 G 生成的假图像 - 判别器
: 区分图像是来自目标域 X 的真实图像还是由 F 生成的假图像
- 判别器
训练方式
CycleGAN的训练目标是学习两个映射
-
的图像与域 Y 中的图像无法区分 -
的图像与域 X 中的图像无法区分
为了实现无监督训练,提出了循环一致性损失(Cycle Consistency Loss)
-
前向循环一致性:
-
后向循环一致性:
整个优化目标函数是对抗性损失(Adversarial Loss)和循环一致性损失的加权和
联合损失函数
1. 对抗性损失 (Adversarial Loss)
CycleGAN 使用标准的 GAN 损失来提高生成的图像与目标域中的图像分布匹配准确率
-
生成器 G 和 判别器
的对抗损失: 鼓励 看起来像真实图像 Y
-
相关概念
- 判别器
的目标是最大化 - 生成器 G 的目标是最小化
- G 通常会最大化
以避免早期训练不稳定。
- 判别器
-
生成器 F 和 判别器
的对抗损失: 鼓励 看起来像真实图像 X
-
相关概念
- 判别器
的目标是最大化 - 生成器 F 的目标是最小化
- 在实际训练中,F 通常会最大化
- 判别器
2. 循环一致性损失 (Cycle Consistency Loss)
循环一致性损失确保了转换的可逆性,防止生成器学习到将源域中的所有图像映射到目标域中的相同图像,从而避免“模式崩溃”,就是降低泛化能力
-
前向循环一致性损失:
-
后向循环一致性损失:
训练机制
CycleGAN 的训练也采用交替优化策略,同时优化两个生成器和两个判别器
-
训练判别器
和 : -
对于
: 使用来自域 Y 的真实图像 和由 G 生成的假图像 来计算损失并更新参数 - 批量训练的平均损失函数(以
为例,目标是最大化):
- 梯度计算:
-
参数更新:
- 批量训练的平均损失函数(以
-
对于
: 类似地,使用来自域 X 的真实图像 和由 F 生成的假图像 来计算损失并更新参数
-
-
训练生成器 G 和 F:
-
对于 G 和 F:
- 计算包括对抗损失和循环一致性损失在内的总损失
- 批量训练的平均损失函数(以 G 为例,目标是最小化其对抗部分,并最小化循环一致性部分):
- 梯度计算:
-
参数更新:
-
注意: 进行 G 和 F 更新时,判别器
和 的参数是冻结的
-
生成式对抗网络-GAN