生成手写数字
通过这个任务,我们可以了解GAN的基本代码框架,并实践如何观察训练进程。完成这个简单的任务有助于我们为接下来生成人脸的任务做好准备。
数据集
数据集使用MNIST手写数据集。MNIST数据集共有训练数据60000项、测试数据10000项。每张图像的大小为28*28(像素),每张图像都为灰度图像,位深度为8(灰度图像是0-255)。
import os import torch from torch import nn from torch.utils.data import Dataset import numpy as np import matplotlib.pyplot as plt
class MnistDataset(Dataset): def __init__(self, csv_file): # 打开文件,读取所有的数据 self.data_df = [] with open(csv_file, 'r') as f: while True: data = f.readline() if not data: break data = np.asarray(data.split(','), dtype=np.uint8) self.data_df.append(data) def __len__(self): return len(self.data_df) def __getitem__(self, index): label = self.data_df[index][0] target = torch.zeros(10) target[label] = 1.0 image_values = torch.FloatTensor(self.data_df[index][1:]) / 255.0 # 返回标签、图像数据张量以及目标张量 return label, image_values, target def plot_image(self, index): img = np.asarray(self.data_df[index][1:], dtype=np.uint8).reshape((28, 28)) plt.imshow(img) plt.show()
生成器
生成器是一个神经网络,有28*28个输出值,我们希望训练它输出手写数字。
class Generator(nn.Module): def __init__(self): super(Generator, self).__init__() self.model = nn.Sequential( nn.Linear(1, 200), nn.Sigmoid(), nn.Linear(200, 28*28), nn.Sigmoid() ) # 创建优化器,使用随机梯度下降 self.optimiser = torch.optim.SGD(self.parameters(), lr=0.01) # 计数器和进程记录 self.counter = 0 self.progress = [] def forward(self, inputs): return self.model(inputs)
分类器
鉴别器根据这28*28个值,试图判断它是来自真实数据源还是来自生成器。它是一个继承自nn.Module的神经网络。我们按照PyTorch所需要的方式初始化网络,并创建一个forward()和train()函数。
class Discriminator(nn.Module): def __init__(self): super(Classifier, self).__init__() self.model = nn.Sequential( nn.Linear(28 * 28, 200), nn.LeakyReLU(0.02), nn.Linear(200, 10), nn.LeakyReLU(0.02), # nn.Sigmoid() ) self.loss_func = nn.MSELoss() # self.loss_func.to(device) self.optimiser = torch.optim.SGD(self.parameters(), lr=0.01) # 记录训练进展的计数器和列表 self.counter = 0 self.progress = [] def forward(self, inputs): return self.model(inputs) def train(self, inputs, targets): outputs = self.forward(inputs) loss = self.loss_func(outputs, targets) # 梯度归零,反向传播,并更新权重 self.optimiser.zero_grad() loss.backward() self.optimiser.step() # 每隔10个训练样本增加一次计数器的值,并将损失值添加进列表的末尾 self.counter += 1 if self.counter % 10 == 0: self.progress.append(loss.item()) if self.counter % 10000 == 0: print("counter = ", self.counter) def plot_progress(self): plt.scatter([i for i in range(len(self.progress))], self.progress, s=1) plt.show()
训练GAN
- 第1步,我们用真实的数据训练鉴别器。
- 第2步,我们使用一组生成数据来训练鉴别器。对于生成器输出,detach()的作用是将其从计算图中分离出来。通常,对鉴别器损失直接调用backwards()函数会计算整个计算图路径的所有误差梯度。这个路径从鉴别器损失开始,经过鉴别器本身,最后返回生成器。由于我们只希望训练鉴别器,因此不需要计算生成器的梯度。生成器的detach()可以在该点切断计算图。
第3步,我们输入鉴别器对象和某个随机数训练生成器。这里没有使用detach(),是因为我们希望误差梯度从鉴别器损失传回生成器。生成器的train()函数只更新生成器的链接权重,因此我们不需要防止鉴别器被更新。
def generate_random(size=4): random_data = torch.rand(size) return random_data D = Discriminator() G = Generator() # 训练鉴别器和生成器 for label, image_data_tensor, target_tensor in mnist_dataset: # 使用真实数据训练鉴别器 D.train(image_data_tensor, torch.FloatTensor([1.0])) # 用生成样本训练鉴别器,使用detach()以避免计算生成器G中的梯度 D.train(G.forward(generate_random(1)).detach(), torch.FloatTensor([0.0])) # 训练生成器 G.train(D, generate_random(1), torch.FloatTensor([1.0])) # 绘图 f, axarr = plt.subplots(2,3, figsize=(16,8)) for i in range(2): for j in range(3): output = G.forward(generate_random(1)) img = output.detach().numpy().reshape(28,28) axarr[i,j].imshow(img, interpolation='none', cmap='Blues')
Comments | NOTHING