视觉transformer图片处理思路

我们知道,transformer要求将图片分为patch,然后输入网络进行计算,那么我们就需要将二维的图片处理成一维的embeding形式,今天我来给大家介绍一下图片处理的思路。

我们演示一下处理下面这张图片

视觉transformer图片处理思路_第1张图片

我们将图片按照16*16的大小进行分片,得到的结果如下图所示

视觉transformer图片处理思路_第2张图片

接下来我们需要将patch变成tensor。在此之前先介绍一下传统CNN图片处理和transformer图片处理之间的区别

视觉transformer图片处理思路_第3张图片

我们可以看到,传统CNN图片处理得到的向量是三维的,而transformer图片处理得到的向量是二维的,其中num表示一张图片分片数量(也就是分成多少个patch),第二个维度中patch*patch表示每个patch的面积,channel表示通道数。

当我们训练网络的时候,通常需要将数据加载成batch的形式,一个batch里面通常包含多张图片,所以数据格式如下所示

视觉transformer图片处理思路_第4张图片

也就是说,transformer送入网路进行计算的数据是三维的,而传统CNN送入网络进行计算的数据是四维的,这也是CNN和transformer数据加载的主要区别。

下面就贴一段数据处理的演示代码,你可以按照这个代码的思路去写数据加载器。

import torch
from PIL import Image
import torchvision.transforms as tfs
import matplotlib.pyplot as plt


class ImgFactory(object):
    def __init__(self, patch=16):
        super(ImgFactory, self).__init__()
        self.patch = patch
        self.im_tfs = tfs.Compose([
            tfs.ToTensor(),
            tfs.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

    def getImagePatch(self, filename):
        img = Image.open(filename)
        width, height = img.size
        num_patch_w = width // self.patch
        num_patch_h = height // self.patch

        patch_list = []

        num = 1
        for i in range(num_patch_h):
            for j in range(num_patch_w):
                s_y = i*self.patch
                s_x = j*self.patch
                box = (s_x, s_y, self.patch+s_x, self.patch+s_y)
                region = img.crop(box)
                patch_list.append(region)
                plt.subplot(num_patch_h, num_patch_w, num), plt.imshow(region), plt.axis("off")
                num = num + 1

        plt.savefig("patch.png")

        for i in range(len(patch_list)):
            patch_list[i] = self.im_tfs(patch_list[i])
            patch_list[i] = patch_list[i].view(1,-1)

        seq = torch.cat(patch_list, dim=0)
        return seq






if __name__ == "__main__":
    factory = ImgFactory()
    seq = factory.getImagePatch("a.png")
    print(seq.shape)

输出结果是一张图片加载成tensor的格式

视觉transformer图片处理思路_第5张图片

 

你可能感兴趣的:(计算机视觉)