Pytorch实现自编码器

原文地址

分类目录——Pytorch

  • 什么是编码器

有一中数据压缩的、降维的意思

举个例子来说明,同一张图片,高清的和标清的我们都能识别出图片中的内容(这里就考虑识别这一个需求,其他需求暂不考虑),这是因为即使是标清的图片,也保留了进行识别的关键特征。但是高清的在无论是在保存,还是在提取上都会更费工夫。深度学习处理起来亦是如此,深度学习会包含很多层,每层节点也很多,这种情况下,如果输入数据的规模太大,神经网络也很难训练出结果。那么,能在保留关键特征的基础上对数据尽心降维,就是一项一劳永逸的活动。

这个编码器要怎么用呢

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-1nZbFk6U-1582679206189)(https://morvanzhou.github.io/static/results/ML-intro/auto3.png)]

图片引自 什么是自编码 (Autoencoder)

编码器的构造跟自己的应用(或者分类,或者回归)上两套体系。编码器也是一个完整的训练流程,虽说叫编码器,其实其内部包括编码(上图中的压缩)和解码(上图中的解压)两部分,编码用来降维,解码用来将维度回复,通过维度恢复的数据(上图中的黑色X)与原始数据(上图的白色X)的误差来训练编码器参数,训练完成后编码部分将能压缩到原始数据的关键特征,极大地加速训练过程。

另外我觉得,翻译是一个很好的例子,自己有中思路可不可以做一种压缩(编码)一种万国语,存放在计算机中,计算机能识别的;甚至可以跨越表达方式,比如‘狗’、‘dog’ 另外还有一张狗的图片。他们在计算机中的表现形式是一样的,通过不同的模型可以翻译成‘狗’、‘dog’ 和狗的图片。

下面用一个例子来说明

这个程序的数据是手写数字识别的图片,分辨率为28*28,通过编码器将28*28维度的像素维度降维到3维;然后用3维数据在三维坐标平面内进行了可视化;最后用svm就编码之后的3维数据进行分类,因为压缩之后只有3个维度,为了节约时间只用了1000个训练数据,所以最终的准确率并没有很高。

  • 导入支持包与设置超参数

    import torch
    import torch.nn as nn
    import torch.utils.data as Data
    import torchvision
    import matplotlib.pyplot as plt
    from matplotlib import cm
    from mpl_toolkits.mplot3d import Axes3D
    import os
    import numpy as np
    from sklearn import svm
    from sklearn.model_selection import GridSearchCV
    
    # 超参数
    EPOCH = 10
    BATCH_SIZE = 64
    LR = 0.005
    if os.path.exists('mnist/'):  # 如果已经存在(下载)了就不用下载了
        DOWNLOAD_MNIST = False
    else:
        DOWNLOAD_MNIST = True   # 下过数据的话, 就可以设置成 False
    N_TEST_IMG = 5          # 到时候显示 5张图片看效果, 如上图一
    
  • 获得手写数字图片数据

    ####################################### 获取手写数字图片数据
    train_data = torchvision.datasets.MNIST(
        root='./mnist/',
        train=True,                                     # this is training data
        transform=torchvision.transforms.ToTensor(),    # Converts a PIL.Image or numpy.ndarray to
                                                        # torch.FloatTensor of shape (C x H x W) and normalize in the range [0.0, 1.0]
        download=DOWNLOAD_MNIST,                        # download it if you don't have it
    )
    
    test_data = torchvision.datasets.MNIST(root='./mnist/', train=False)
    
    # 批训练 50samples, 1 channel, 28x28 (50, 1, 28, 28)
    train_loader = Data.DataLoader(
        dataset=train_data,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=0
    )
    
  • 构造编码器

    ################################# 构造编码器
    class AutoEncoder(nn.Module):
        def __init__(self):
            super(AutoEncoder, self).__init__()
    
            # 编码网络
            self.encoder = nn.Sequential(
                nn.Linear(28*28, 128),
                nn.Tanh(),
                nn.Linear(128, 64),
                nn.Tanh(),
                nn.Linear(64, 12),
                nn.Tanh(),
                nn.Linear(12, 3),   # 压缩成3个特征, 是为了寿面好进行 3D 图像可视化
                # 当然也可以压缩到5个特征,选其中的三个来作图
            )
            # 解码网络
            self.decoder = nn.Sequential(
                nn.Linear(3, 12),
                nn.Tanh(),
                nn.Linear(12, 64),
                nn.Tanh(),
                nn.Linear(64, 128),
                nn.Tanh(),
                nn.Linear(128, 28*28),
                nn.Sigmoid(),       # 激励函数让输出值在 (0, 1)
            )
    
        def forward(self, x):
            encoded = self.encoder(x)
            decoded = self.decoder(encoded)
            return encoded, decoded
    # 定义一个编码器对象
    autoencoder = AutoEncoder()
    
  • 训练编码器

    ############################## 训练编码器
    optimizer = torch.optim.Adam(autoencoder.parameters(), lr=LR)
    loss_func = nn.MSELoss()
    
    for epoch in range(EPOCH):
        for step, (x, b_label) in enumerate(train_loader):
            b_x = x.view(-1, 28*28)   # batch x, shape (batch, 28*28)
            # b_y跟b_x是一样
    
            encoded_x, decoded_x = autoencoder(b_x)
    
            loss = loss_func(decoded_x, b_x)    # 这里如果写成b_x会更容易裂解
            optimizer.zero_grad()               # clear gradients for this training step
            loss.backward()                     # backpropagation, compute gradients
            optimizer.step()                    # apply gradients
    
  • 可视化

    ########################### 画图的部分
    # 取200个数据来作图
    view_data = train_data.train_data[:200].view(-1, 28 * 28).type(torch.FloatTensor) / 255.
    encoded_data, _ = autoencoder(view_data)  # 提取压缩的特征值
    fig = plt.figure(2)
    ax = Axes3D(fig)  # 3D 图
    # x, y, z 的数据值
    X = encoded_data.data[:, 0].numpy()
    Y = encoded_data.data[:, 1].numpy()
    Z = encoded_data.data[:, 2].numpy()
    values = train_data.train_labels[:200].numpy()  # 标签值
    for x, y, z, s in zip(X, Y, Z, values):
        c = cm.rainbow(int(255 * s / 9))  # 上色
        ax.text(x, y, z, s, backgroundcolor=c)  # 标位子
    ax.set_xlim(X.min(), X.max())
    ax.set_ylim(Y.min(), Y.max())
    ax.set_zlim(Z.min(), Z.max())
    plt.show()
    # 注意这里进行了plt.show(),程序会停在这里,需要把图片关闭之后下面的程序才能进行,也可以调换一下跟下面svm分类部分替换位置
    

    注意这里进行了plt.show(),程序会停在这里,需要把图片关闭之后下面的程序才能进行,也可以调换一下跟下面svm分类部分替换位置

  • 用SVM对编码(压缩)后的数据进行数字识别

    ################################### 用SVM分类
    # 取1000个训练数据来训练svm
    svm_train = train_data.train_data[:1000].view(-1, 28 * 28).type(torch.FloatTensor) / 255.
    s_t_x_afterencoder = autoencoder(svm_train)[0].data.numpy()
    print(s_t_x_afterencoder.shape())
    s_t_y = train_data.train_labels[:1000].numpy()  # 标签值
    print(s_t_y.shape())
    # 取1000个训练数据来测试
    svm_test = test_data.test_data[:1000].view(-1, 28 * 28).type(torch.FloatTensor) / 255.
    s_te_x_afterencoder = autoencoder(svm_test)[0].data.numpy()
    s_te_y = test_data.test_labels[:1000].numpy()  # 标签值
    
    c_can = np.logspace(-3, 2, 10)
    gamma_can = np.logspace(-3, 2, 10)
    
    model = svm.SVC(kernel='rbf', decision_function_shape='ovr', random_state=1)
    clf = GridSearchCV(model, param_grid={'C': c_can, 'gamma': gamma_can}, cv=5, n_jobs=5)
    clf.fit(s_t_x_afterencoder, s_t_y)
    
    print('测试集准确率:\t', clf.score(s_te_x_afterencoder, s_te_y))  # 因为压缩到了三个特征,准确率并不是很高
    # 测试集准确率:	 0.764
    
  • 参考文献

什么是自编码 (Autoencoder)

AutoEncoder (自编码/非监督学习)

分类目录——Matplotlib

你可能感兴趣的:(Python,#,Pytorch)