GAN 生成手写数字


生成手写数字

通过这个任务,我们可以了解GAN的基本代码框架,并实践如何观察训练进程。完成这个简单的任务有助于我们为接下来生成人脸的任务做好准备。

数据集

  • 数据集使用MNIST手写数据集。MNIST数据集共有训练数据60000项、测试数据10000项。每张图像的大小为28*28(像素),每张图像都为灰度图像,位深度为8(灰度图像是0-255)。

    import os
    import torch
    from torch import nn
    from torch.utils.data import Dataset
    import numpy as np
    import matplotlib.pyplot as plt
    class MnistDataset(Dataset):
        def __init__(self, csv_file):
            # 打开文件,读取所有的数据
            self.data_df = []
            with open(csv_file, 'r') as f:
                while True:
                    data = f.readline()
                    if not data:
                        break
                    data = np.asarray(data.split(','), dtype=np.uint8)
                    self.data_df.append(data)
    
        def __len__(self):
            return len(self.data_df)
    
        def __getitem__(self, index):
            label = self.data_df[index][0]
            target = torch.zeros(10)
            target[label] = 1.0
            image_values = torch.FloatTensor(self.data_df[index][1:]) / 255.0
            # 返回标签、图像数据张量以及目标张量
            return label, image_values, target
    
        def plot_image(self, index):
            img = np.asarray(self.data_df[index][1:], dtype=np.uint8).reshape((28, 28))
            plt.imshow(img)
            plt.show()

生成器

  • 生成器是一个神经网络,有28*28个输出值,我们希望训练它输出手写数字。

    class Generator(nn.Module):
        def __init__(self):
            super(Generator, self).__init__()
            self.model = nn.Sequential(
                nn.Linear(1, 200),
                nn.Sigmoid(),
                nn.Linear(200, 28*28),
                nn.Sigmoid()
            )
            # 创建优化器,使用随机梯度下降
            self.optimiser = torch.optim.SGD(self.parameters(), lr=0.01)
            # 计数器和进程记录
            self.counter = 0
            self.progress = []
    
    
        def forward(self, inputs):
            return self.model(inputs)

分类器

  • 鉴别器根据这28*28个值,试图判断它是来自真实数据源还是来自生成器。它是一个继承自nn.Module的神经网络。我们按照PyTorch所需要的方式初始化网络,并创建一个forward()和train()函数。

    class Discriminator(nn.Module):
         def __init__(self):
             super(Classifier, self).__init__()
             self.model = nn.Sequential(
                 nn.Linear(28 * 28, 200),
                 nn.LeakyReLU(0.02),
                 nn.Linear(200, 10),
                 nn.LeakyReLU(0.02),
               # nn.Sigmoid()
             )
             self.loss_func = nn.MSELoss()
             # self.loss_func.to(device)
             self.optimiser = torch.optim.SGD(self.parameters(), lr=0.01)
             # 记录训练进展的计数器和列表
             self.counter = 0
             self.progress = []
    
         def forward(self, inputs):
             return self.model(inputs)
    
         def train(self, inputs, targets):
             outputs = self.forward(inputs)
             loss = self.loss_func(outputs, targets)
             # 梯度归零,反向传播,并更新权重
             self.optimiser.zero_grad()
             loss.backward()
             self.optimiser.step()
    
             # 每隔10个训练样本增加一次计数器的值,并将损失值添加进列表的末尾
             self.counter += 1
    
             if self.counter % 10 == 0:
                 self.progress.append(loss.item())
    
             if self.counter % 10000 == 0:
                 print("counter = ", self.counter)
    
         def plot_progress(self):
             plt.scatter([i for i in range(len(self.progress))], self.progress, s=1)
             plt.show()
    

训练GAN

  • 第1步,我们用真实的数据训练鉴别器。
  • 第2步,我们使用一组生成数据来训练鉴别器。对于生成器输出,detach()的作用是将其从计算图中分离出来。通常,对鉴别器损失直接调用backwards()函数会计算整个计算图路径的所有误差梯度。这个路径从鉴别器损失开始,经过鉴别器本身,最后返回生成器。由于我们只希望训练鉴别器,因此不需要计算生成器的梯度。生成器的detach()可以在该点切断计算图。
  • 第3步,我们输入鉴别器对象和某个随机数训练生成器。这里没有使用detach(),是因为我们希望误差梯度从鉴别器损失传回生成器。生成器的train()函数只更新生成器的链接权重,因此我们不需要防止鉴别器被更新。

    
    
    def generate_random(size=4):
        random_data = torch.rand(size)
        return random_data
    
    D = Discriminator()
    G = Generator()
    
    # 训练鉴别器和生成器
    for label, image_data_tensor, target_tensor in mnist_dataset:
        # 使用真实数据训练鉴别器
        D.train(image_data_tensor, torch.FloatTensor([1.0]))
        # 用生成样本训练鉴别器,使用detach()以避免计算生成器G中的梯度
        D.train(G.forward(generate_random(1)).detach(), torch.FloatTensor([0.0]))
        # 训练生成器
        G.train(D, generate_random(1), torch.FloatTensor([1.0]))
    
    # 绘图
    f, axarr = plt.subplots(2,3, figsize=(16,8))
    for i in range(2):
        for j in range(3):
            output = G.forward(generate_random(1))
            img = output.detach().numpy().reshape(28,28)
            axarr[i,j].imshow(img, interpolation='none', cmap='Blues')

声明:Hello World|版权所有,违者必究|如未注明,均为原创|本网站采用BY-NC-SA协议进行授权

转载:转载请注明原文链接 - GAN 生成手写数字


我的朋友,理论是灰色的,而生活之树是常青的!