Building a convolutional autoencoder with PyTorch for image compression and image retrieval (search by image)

  1. Build a convolutional autoencoder with PyTorch and use it to compress the MNIST dataset.
  2. Use this autoencoder to implement "search by image" for handwritten digits.
    Deliverables: the code for building and training the autoencoder, the training process, and the retrieval results (randomly pick 10 test images of different classes and, for each of them, find its 5-10 closest images). A minimal sketch of the ranking idea follows this list.
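
The retrieval step works by running every image through the trained encoder and ranking the candidates by how far their latent codes are from the query's code. The following is a minimal sketch of that ranking step with placeholder tensors (queryCode, dbCodes); it is not part of the full implementation below:

import torch

def rank_by_latent_distance(queryCode,dbCodes,K=5):
    #queryCode: (D,) latent code of the query image
    #dbCodes: (N,D) latent codes of the candidate images
    dists=((dbCodes-queryCode)**2).mean(dim=1)#per-image MSE in latent space
    return torch.topk(dists,K,largest=False).indices#indices of the K closest images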

Code implementation:

import numpy as np
import torch
import torchvision
from torch.utils.data import Dataset,DataLoader
import torch.nn as nn
from torch import optim
from torchvision.datasets import MNIST
from PIL import Image
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import StepLR
import os
from torchvision import transforms as T

os.environ["CUDA_VISIBLE_DEVICES"]="0,1,2"

transform=T.Compose([
        T.Resize((28,28)),
        T.Grayscale(),
        T.ToTensor(),
    ])

#Dataset wrapper: keeps only the image tensors (labels are not needed for reconstruction or retrieval) and applies the transform on access
class MNISTDataset(Dataset):
    def __init__(self,images,transform):
        self.dataset=images.data.to(torch.float32)
        self.transform=transform
    def __getitem__(self,index):
        data=self.dataset[index]
        data=np.asarray(data)
        data=Image.fromarray(data)
        data=self.transform(data)
        return data
    def __len__(self):
        return len(self.dataset)
    
##Load the MNIST data (root='' assumes the raw MNIST files are already available locally)
def load_mnist():
    trainMnist=MNIST('',download=False,train=True)
    testMnist=MNIST('',download=False,train=False)
    trainMnist=MNISTDataset(trainMnist,transform=transform)
    testMnist=MNISTDataset(testMnist,transform=transform)
    trainLoader=DataLoader(trainMnist,shuffle=True,batch_size=256)
    testLoader=DataLoader(testMnist,shuffle=False,batch_size=24)
    return (trainLoader,testLoader,trainMnist,testMnist)

#Network definition: convolutional autoencoder
class EncodeModel(nn.Module):
    def __init__(self):
        super(EncodeModel,self).__init__()
        self.encode=nn.Sequential(
                nn.Conv2d(1,16,kernel_size=7),#16*22*22
                nn.BatchNorm2d(16),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=2,stride=2),#16*11*11
                
                nn.Conv2d(16,4,kernel_size=3),#4*9*9
                nn.BatchNorm2d(4),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=2,stride=2),#4*4*4
                
                nn.Conv2d(4,4,kernel_size=2),#4*3*3
                nn.BatchNorm2d(4),
                nn.ReLU(inplace=True),
            )
        self.decode=nn.Sequential(
                nn.ConvTranspose2d(4,1,kernel_size=5,stride=2),#1*9*9
                nn.BatchNorm2d(1),
                nn.ReLU(inplace=True),
                
                nn.ConvTranspose2d(1,1,kernel_size=7,stride=2),#1*23*23
                nn.BatchNorm2d(1),
                nn.ReLU(inplace=True),
                
                nn.ConvTranspose2d(1,1,kernel_size=6,stride=1),#1*28*28
                nn.BatchNorm2d(1),
                nn.ReLU(inplace=True),
            )
        
    def forward(self,x,judge):
        #judge=True: return the reconstruction; judge=False: return the encoder output (the compressed code)
        enOutputs=self.encode(x)
        outputs=self.decode(enOutputs)
        if judge:
            return outputs
        else:
            return enOutputs
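
#Shape check (an optional sketch, not part of the script): a dummy forward pass
#confirms the sizes annotated above. The encoder compresses each 1*28*28 image
#(784 values) into a 4*3*3 code (36 values), roughly a 22x reduction:
#    net=EncodeModel()
#    dummy=torch.randn(2,1,28,28)
#    net(dummy,False).shape   #torch.Size([2, 4, 3, 3])  compressed code
#    net(dummy,True).shape    #torch.Size([2, 1, 28, 28]) reconstruction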

#Training: learn to reconstruct the input with an MSE loss
def train(trainLoader,testLoader):
    model=EncodeModel().cuda()
    optimizer=optim.Adam(model.parameters(),lr=0.001)
    scheduler=StepLR(optimizer,step_size=5,gamma=0.8)
    criterion=nn.MSELoss().cuda()
    epochs=50
    for epoch in range(epochs):
        for (i,trainData) in enumerate(trainLoader):
            trainData=trainData.cuda()
            outputs=model(trainData,True).cuda()
            optimizer.zero_grad()
            loss=criterion(outputs,trainData)
            loss.backward()
            optimizer.step()
        torch.save(model.state_dict(),'EncodeModel.pth')#checkpoint once per epoch instead of after every batch
        print('epoch:{} loss:{:.6f}'.format(epoch,loss.item()))
        scheduler.step()
        model.train(False)#eval mode: visualize one reconstructed test batch per epoch
        for (i,testData) in enumerate(testLoader):
            testData=testData.cuda()
            outputs=model(testData,True)
            plt.figure(1)
            testData=testData.to('cpu')
            outputs=outputs.to('cpu')
            plt.imshow((torchvision.utils.make_grid(outputs).permute((1,2,0))).detach().numpy())
            plt.show()
            break
        model.train(True)
    return model
#Search by image: rank the test set by the MSE distance between latent codes
def search_by_image(testMnist,inputImage,K=5):
    model=EncodeModel()
    model.load_state_dict(torch.load('EncodeModel.pth',map_location='cpu'))#the search runs the model on CPU
    model.train(False)
    criterion=nn.MSELoss()
    testLoader=DataLoader(testMnist,batch_size=1,shuffle=False)
    inputImage=inputImage.unsqueeze(0)
    lossList=[]
    with torch.no_grad():#no gradients are needed for retrieval
        inputEncode=model(inputImage,False)
        for (i,testImage) in enumerate(testLoader):
            testEncode=model(testImage,False)
            enLoss=criterion(inputEncode,testEncode)#MSE between the two latent codes
            lossList.append((i,enLoss.item()))
    lossList=sorted(lossList,key=lambda x:x[1],reverse=False)[:K]#smallest distance first; the query itself ranks first because it is part of the test set
    plt.figure(1)
    trueImage=inputImage.squeeze(0).squeeze(0)
    plt.imshow(trueImage,cmap='gray')
    plt.title('true')
    plt.show()
    for j in range(K):
        showImage=testMnist[lossList[j][0]]
        showImage=showImage.squeeze(0)
        showImage=np.array(showImage)
        plt.subplot(1,K,j+1)
        plt.imshow(showImage,cmap='gray')
    plt.title('sim')
    plt.show()
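
#Optional alternative (a sketch, not used in __main__ below): search_by_image
#reloads the model and re-encodes the whole test set for every query, which is
#slow. Encoding the test set once and reusing the codes avoids the repeated work:
def build_code_index(model,testMnist):
    #encode every test image once; returns an (N,36) tensor of flattened codes
    model.train(False)
    codes=[]
    with torch.no_grad():
        for images in DataLoader(testMnist,batch_size=256,shuffle=False):
            codes.append(model(images,False).flatten(1))
    return torch.cat(codes,dim=0)

def search_with_index(codes,queryImage,model,K=5):
    #rank the stored codes by MSE distance to the query's code (smallest first)
    with torch.no_grad():
        queryCode=model(queryImage.unsqueeze(0),False).flatten(1)
    dists=((codes-queryCode)**2).mean(dim=1)
    return torch.topk(dists,K,largest=False).indices.tolist()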
    
if __name__=='__main__':
    trainLoader,testLoader,trainMnist,testMnist=load_mnist()
#    model=train(trainLoader,testLoader)
    i=0
    for inputImage in testMnist:#walks the test set in order; see the per-class sampling sketch below
        search_by_image(testMnist,inputImage)
        i+=1
        if i>200:
            break
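
The requirement above asks for 10 randomly chosen test images with different labels, while the __main__ block simply walks through the test set in order. One way to pick one random test image per digit is sketched below; it reads the labels from the raw MNIST object because MNISTDataset keeps only the pixel data (this helper is an assumption, not part of the original script):

import random

def pick_one_per_class(testMnist,seed=0):
    targets=MNIST('',download=False,train=False).targets#labels of the test set
    random.seed(seed)
    picks=[]
    for digit in range(10):
        candidates=(targets==digit).nonzero(as_tuple=True)[0].tolist()
        picks.append(random.choice(candidates))
    return [testMnist[i] for i in picks]

#usage inside __main__:
#    for inputImage in pick_one_per_class(testMnist):
#        search_by_image(testMnist,inputImage,K=5)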

Training results: [figure 1] [figure 2]
