不得不看的图片搜索系统实现

原创:余晓龙

图片搜索系统主要分为特征提取和特征匹配两个部分,其中特征提取是深度学习模型中进行数据处理的主要环节,本文将通过一种基于无监督方式---最大化深度互信息(DIM)方法来进行特征提取,并利用提取出来的低维特征实现图片搜索系统。

1. DIM模型原理

DIM模型是通过计算输入样本与编码器输出的特征向量之间的互信息,利用最大化互信息来实现模型的训练。DIM模型在无监督训练中使用两种约束来表示学习。

(1)最大化输入信息和高级特征向量之间的互信息:如果模型输出的低维特征能够代表输入样本,那么该特征分布与输入样本分布的互信息一定是最大的。

(2)对抗匹配先验分布:编码器输出的高级特征要更接近高斯分布,判别器要将编码器生成的数据分布与高斯分布进行区分。

在实现的时候,DIM模型使用了3个判别器,分别从局部互信息的最大化、全局互信息的最大化和先验分布匹配的最小化3个角度来对编码器的输出结果进行约束。

2. 局部互信息和全局互信息最大化约束的原理

局部特征可以理解为进行卷积后得到的特征图,全局特征可以理解为对特征图进行编码得到的特征向量。对于图片,它的相关性更多的体现在局部。图像识别、分类是一个从局部到整体的过程、即全局特征更适用于重构,局部特征更适用于分类任务。DIM模型从局部和全局两个角度对输入和输出计算互信息,而先验匹配的目的是对编码器生成的向量形式进行约束,使其更接近高斯分布。

3. 先验分布匹配最小化约束的原理

DIM模型的编码器主要思想是对输入数据进行编码成特征向量的同时,还希望该特征向量服从于标准的高斯分布,这样做的主要作用是使的编码空间更加规范,有利于解藕特征以便后续学习。

4. 代码实现

本文通过使用Fashion-MNIST数据集来实现图片搜素器。Fashion-MNIST的单个样本大小为28*28像素的灰度图,其中包含训练集60000张图片、测试集10000张图片。样本的标签一共分为10类,包括T-shirt(T恤)、Trouser(裤子)、Pullover(套衫)、Dress(裙子)、Coat(外套)、Sandal(凉鞋)、Shirt(衬衫)、Sneaker(运动鞋)、Bag(包)、Ankle boot(踝靴)。

4.1 加载并显示Fashion-MNIST数据集

import torch
from torch import nn
import torch.nn.functional as F
import torchvision
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader

from torchvision.datasets.mnist import FashionMNIST
from torch.optim import Adam
from matplotlib import pyplot as plt
import numpy as np
from tqdm import tqdm
from pathlib import Path
from torchvision.transforms import ToPILImage
import os
os.environ["CUDA_VISIBLE_DEVICES"] = '1, 2, 3'

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(device)

batch_size = 256
data_dir = r'./fashon_mnist/'
train_dataset = FashionMNIST(data_dir, download=True, transform=ToTensor())
train_loader = DataLoader(train_dataset, batch_size=batch_size,
                          shuffle=True, drop_last=True,
                          pin_memory=torch.cuda.is_available())
print('train:', len(train_dataset))


def imshowrow(imgs, nrow):
    plt.figure(dpi=200)
    _img = ToPILImage()(torchvision.utils.make_grid(imgs, nrow=nrow))
    plt.axis('off')
    plt.imshow(_img)
    plt.show()



classes = ('T-shirt', 'Trouser', 'Pullover', 'Dress', 'Coat',
           'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle_Boot')

sample = iter(train_loader)
images, labels = sample.next()
print('sample shape:', np.shape(images))
print('sample label:', ','.join('%2d:%-5s' % (labels[j],
                                              classes[labels[j]])
                                for j in range(len(images[:10]))))
imshowrow(images[:10], nrow=10)

4.2 实现DIM模型

定义编码器模型类Encoder与判别器类DeepInfoMaxLoss

Encoder:通过多个卷积层对输入数据进行编码,生成64维特征向量,

DeepInfoMaxLoss:实现全局、局部、先验判别器三个模型结构,合并损失函数得到总损失。

class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.c0 = nn.Conv2d(1, 64, kernel_size=4, stride=1)
        self.c1 = nn.Conv2d(64, 128, kernel_size=4, stride=1)
        self.c2 = nn.Conv2d(128, 256, kernel_size=4, stride=1)
        self.c3 = nn.Conv2d(256, 512, kernel_size=4, stride=1)
    
        self.l1 = nn.Linear(512*16*16, 64)

        self.b1 = nn.BatchNorm2d(128)
        self.b2 = nn.BatchNorm2d(256)
        self.b3 = nn.BatchNorm2d(512)

    def forward(self, x):
        # print('x', x.shape)  # torch.Size([256, 1, 28, 28])
        h = F.relu(self.c0(x))
        # print('h1', h.size())  # torch.Size([256, 64, 25, 25])
        features = F.relu(self.b1(self.c1(h)))
        # print('features', features.size())  # torch.Size([256, 128, 22, 22])
        h = F.relu(self.b2(self.c2(features)))
        # print('h2', h.size())  # torch.Size([256, 256, 19, 19])
        h = F.relu(self.b3(self.c3(h)))
        # print('h3', h.size())  # torch.Size([256, 512, 16, 16])
        encoder = self.l1(h.view(x.shape[0], -1))
        return encoder, features


class DeepInfoMaxLoss(nn.Module):
    def __init__(self, alpha=0.5, beta=1.0, gamma=0.1):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma

        self.local_d = nn.Sequential(
            nn.Conv2d(192, 512, kernel_size=1),
            nn.ReLU(True),
            nn.Conv2d(512, 512, kernel_size=1),
            nn.ReLU(True),
            nn.Conv2d(512, 1, kernel_size=1)
        )

        self.prior_d = nn.Sequential(
            nn.Linear(64, 1000),
            nn.ReLU(True),
            nn.Linear(1000, 200),
            nn.ReLU(True),
            nn.Linear(200, 1),
            nn.Sigmoid()
        )

        self.global_d_M = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3),
            nn.ReLU(True),
            nn.Conv2d(64, 32, kernel_size=3),
            nn.Flatten()
        )

        self.global_d_fc = nn.Sequential(
            nn.Linear(32 * 18 * 18 + 64, 512),
            nn.ReLU(True),
            nn.Linear(512, 512),
            nn.ReLU(True),
            nn.Linear(512, 1)
        )

    def GlobalD(self, y, M):
        h = self.global_d_M(M)
        h = torch.cat((y, h), dim=1)
        return self.global_d_fc(h)

    def forward(self, y, M, M_prime):
        y_exp = y.unsqueeze(-1).unsqueeze(-1)
        # print('y_exp', y_exp.shape)
        # y_exp torch.Size([256, 64, 1, 1])
        y_exp = y_exp.expand(-1, -1, 22, 22)
        # print('y_exp', y_exp.shape)
        # y_exp torch.Size([256, 64, 22, 22])
        y_M = torch.cat((M, y_exp), dim=1)
        # print('y_M', y_M.shape)
        # y_M torch.Size([256, 192, 22, 22])
        y_M_prime = torch.cat((M_prime, y_exp), dim=1)
        # print('y_M_prime', y_M_prime.shape)
        # y_M_prime torch.Size([256, 192, 22, 22])

        Ej = -F.softplus(-self.local_d(y_M)).mean()
        Em = F.softplus(self.local_d(y_M_prime)).mean()
        Local = (Em - Ej) * self.beta

        Ej = -F.softplus(-self.GlobalD(y, M)).mean()
        Em = F.softplus(self.GlobalD(y, M_prime)).mean()
        Global = (Em - Ej) * self.alpha

        prior = torch.rand_like(y)
        term_a = torch.log(self.prior_d(prior)).mean()
        term_b = torch.log(1.0 - self.prior_d(y)).mean()
        Prior = -(term_a + term_b) * self.gamma

        return Local + Global + Prior

4.3 实例化模型并进行训练

totalepoch = 100
if __name__ == '__main__':
    encoder = Encoder().to(device)
    loss_fn = DeepInfoMaxLoss().to(device)
    optim = Adam(encoder.parameters(), lr=1e-4)
    loss_optim = Adam(loss_fn.parameters(), lr=1e-4)

    epoch_loss = []
    for epoch in range(totalepoch + 1):
        batch = tqdm(train_loader, total=len(train_dataset) // batch_size)
        train_loss = []
        for x, target in batch:
            x = x.to(device)
            optim.zero_grad()
            loss_optim.zero_grad()
            y, M = encoder(x)

            M_prime = torch.cat((M[1:], M[0].unsqueeze(0)), dim=0)
            loss = loss_fn(y, M, M_prime)
            train_loss.append(loss.item())
            batch.set_description(
                str(epoch) + ' Loss:%.4f' % np.mean(train_loss[-20:]
            ))
            loss.backward()
            optim.step()
            loss_optim.step()

        if epoch % 10 == 0:
            root = Path(r'./DIMmodel2/')
            enc_file = root / Path('encoder' + str(epoch) + '.pth')
            loss_file = root / Path('loss' + str(epoch) + '.pth')
            enc_file.parent.mkdir(parents=True, exist_ok=True)
            torch.save(encoder.state_dict(), str(enc_file))
            torch.save(loss_fn.state_dict(), str(loss_file))
       
        epoch_loss.append(np.mean(train_loss[-20:]))
    plt.plot(np.arange(len(epoch_loss)), epoch_loss, 'r')
    plt.show()

训练完成后得到模型文件,在DIMmodel2文件夹下生成encoder100.pth和loss.pth。

4.4 加载模型实现图像搜索

import random

model_path = r'./DIMmodel2/encoder%d.pth' % (totalepoch)
encoder = Encoder().to(device)
encoder.load_state_dict(torch.load(model_path, map_location=device))

batchesimg, batchesenc = [], []
batch = tqdm(train_loader, total=len(train_dataset) // batch_size)

for images, target in batch:
    images = images.to(device)
    with torch.no_grad():
        encoded, features = encoder(images)
    batchesimg.append(images)
    batchesenc.append(encoded)

batchesenc = torch.cat(batchesenc, axis=0)
batchesimg = torch.cat(batchesimg, axis=0)

index = random.randrange(0, len(batchesenc))
batchesenc[index].repeat(len(batchesenc), 1)

l2_dis = F.mse_loss(batchesenc[index].repeat(len(batchesenc), 1),
                    batchesenc, reduction='none').sum(1)

findnum = 5   # 设置需要查找图片的个数
_, indices = l2_dis.topk(findnum, largest=False) # 查找出5个最相似的图片

indices = torch.cat([torch.tensor([index]).to(device), indices])

rel = batchesimg[indices]
imshowrow(rel.cpu(), nrow=len(indices))

从结果图像可以看出,查找出的最相似的5张图片与查找的图像是一样的。通过最大化深度互信息模型实现的图像搜索是有效的。大家可以修改数据集,实现自己的图片搜素系统。

你可能感兴趣的:(不得不看的图片搜索系统实现)