GAN
生成对抗网络(GAN)由两个神经网络组成,一个是生成器,另一个是鉴别器,它们被设计为竞争对手。鉴别器经过训练后,可将训练集中的数据分类为真实数据,将生成器产生的数据分类为伪造数据;生成器在训练后,能创建可以以假乱真的数据来欺骗鉴别器。
GAN基础概念
- 如果网络的输入是一幅猫咪的图像,输出值应该是1,对应真(true);如果图像中不是猫咪,输出值应该是0,对应伪(false)。
- 改动此项任务,如果网络的输入是一幅真猫咪的图像,则输出值应该是1,对应真(true);如果图像中是假猫咪,则输出值应该是0,对应伪(false)。
- 再次改动此项任务,如果网络的输入是一幅真猫咪的图像,输出值应该是1,对应真(true);如果图像中是由猫咪生成器生成的猫咪图像,输出值应该是0,对应伪(false)。
- 现在,假设我们用一个被训练用于生成图像的神经网络,取代之前只能生成低质量猫咪图像的组件。我们称它为生成器 (generator)。同时,我们把分类器称为鉴别器(discriminator)。如果图像通过了鉴别器的检验,我们奖励生成器。如果伪造的图像被识破,我们惩罚生成器。随着训练的进展,鉴别器的表现越来越好,生成器也必须不断进步,才能骗过更好的鉴别器。最终,生成器也变得非常出色,可以生成足以以假乱真的图像。鉴别器和生成器是竞争对手(adversary)关系,双方都试图超越对方,并在这个过程中逐步提高。我们称这种架构为生成对抗网络 (Generative Adversarial Network,GAN)。
GAN的训练
在GAN的架构中,生成器和鉴别器都需要训练。我们不希望先用所有的训练数据训练其中任何一方,再训练另一方。我们希望它们能一起学习,任何一方都不应该超过另一方太多。下面的三步训练循环是实现这一目标的一种方法。
- 第1步——向鉴别器展示一个真实的数据样本,告诉它该样本的分类应该是1.0。即我们向鉴别器展示一幅实际数据集中的图像,并让它对图像进行分类。输出应为1.0,我们再用损失来更新鉴别器。
- 第2步——向鉴别器显示一个生成器的输出,告诉它该样本的分类应该是0.0。即训练鉴别器,不过这一次我们向它展示生成器的图像。输出的结果应该是0.0。我们只用损失来更新鉴别器。在这一步中,我们不要更新生成器。因为我们不希望它因为被鉴别器识破而受到奖励。
- 第3步——向鉴别器显示一个生成器的输出,告诉生成器结果应该是1.0。即训练生成器。我们先用它生成一个图像,并将生成的图像输入给鉴别器进行分类。鉴别器的预期输出应该是1.0。换句话说,我们希望生成器能成功骗过鉴别器,让它误以为图像是真实的,而不是生成的。我们只用结果的损失来更新生成器,而不更新鉴别器。因为我们不希望因为错误分类而奖励鉴别器。
- 这就是大多数GAN训练方案的核心。
Comments | NOTHING