PyTorch之—Siamese网络

文章目录

          • 1. 数据集处理
          • 2. 网络与损失函数
          • 3. 代码如下:

本来是想做检测图像的相似度的,偶然见到这篇文章。于是写下了这篇博文。
本文参考于: github and PyTorch 中文网人脸相似度对比

关于Siamese网络的详细介绍,请参考上面列出的参考资料。简单来说,它就是两个共享参数的CNN。每次的输入是一对图像+1个label,共3个值。注意label=0或1(又称正负样本):label=0表示输入的两张图片match(匹配、同一个人),label=1表示no-match(不匹配、非同一人)。下图是Siamese基本结构:
PyTorch之—Siamese网络_第1张图片

1. 数据集处理

数据采用的是AT&T人脸数据。共40个人,每个人有10张脸。
数据集下载:https://files.cnblogs.com/files/king-lps/att_faces.zip
解压后文件夹下共40个文件夹,每个文件夹里有10张pgm图片。

2. 网络与损失函数

显然该网络前向传播是两张图同时输入进行。
损失函数公式:$m$ 为容忍度(margin),$D_w$ 为两张图之间的欧式距离,$Y$ 为标签。

$$\mathrm{loss} = (1-Y)\,\frac{1}{2}D_w^2 + Y\,\frac{1}{2}\{\max(0,\ m-D_w)\}^2$$

$$D_w = \sqrt{\{G_w(X_1) - G_w(X_2)\}^2}$$

3. 代码如下:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import torch
from torch.autograd import Variable
import os
import random
import linecache
import numpy as np
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import PIL.ImageOps
import matplotlib.pyplot as plt


class Config:
    """Paths and hyper-parameters used by the rest of this script."""

    # Dataset folder and the image index file written inside it.
    root = 'E:/pytorch_AI_learning/att_faces'
    txt_root = 'E:/pytorch_AI_learning/att_faces/train.txt'

    # Training schedule.
    train_batch_size = 32
    train_number_epochs = 30


# Helper functions
def imshow(img, text=None, should_save=False):
    """Display a tensor image with matplotlib, optionally with a caption.

    Args:
        img: object exposing ``.numpy()`` whose array is laid out as
            (channels, height, width); it is transposed to
            (height, width, channels) before display.
        text: optional caption drawn at position (75, 8); skipped when falsy.
        should_save: accepted for interface compatibility; not used here.
    """
    pixels = img.numpy()
    plt.axis("off")
    if text:
        caption_box = {'facecolor': 'white', 'alpha': 0.8, 'pad': 10}
        plt.text(75, 8, text, style='italic', fontweight='bold', bbox=caption_box)
    # Move the channel axis last, which is the layout plt.imshow expects.
    channels_last = np.transpose(pixels, (1, 2, 0))
    plt.imshow(channels_last)
    plt.show()


def show_plot(iteration, loss):
    """Draw ``loss`` against ``iteration`` as a line plot and show the figure."""
    x_values, y_values = iteration, loss
    plt.plot(x_values, y_values)
    plt.show()


def convert(train=True):
    """Write the image index file ``Config.txt_root``.

    Each line has the form ``<image path> <label>``. Paths are
    ``<Config.root>/s<k>/<n>.pgm`` for k in 1..40 and n in 1..10, and the
    label is the zero-based subject index ``k - 1``.

    Args:
        train: when falsy the function does nothing.
    """
    if not train:
        return

    data_path = Config.root + '/'
    # Make sure the folder exists *before* opening the index file for
    # writing; opening first would fail whenever the folder is missing.
    os.makedirs(data_path, exist_ok=True)

    # The context manager closes the file even if a write fails.
    with open(Config.txt_root, 'w') as f:
        for i in range(40):
            for j in range(10):
                img_path = data_path + 's' + str(i + 1) + '/' + str(j + 1) + '.pgm'
                f.write(img_path + ' ' + str(i) + '\n')


if __name__ == '__main__':
    # Generate train.txt only in the launching process. When DataLoader
    # workers are started with the "spawn" method (e.g. on Windows) this
    # script is re-imported in every worker; an unguarded call would truncate
    # and rewrite the index file while other processes are reading it.
    convert(True)

# ready the dataset, Not use ImageFolder as the author did
class MyDataset(Dataset):
    """Random image-pair dataset driven by an index text file.

    Every non-blank line of ``txt`` must be ``<image path> <label>``; the
    label is the last whitespace-separated token, so paths may contain
    spaces. The file is parsed once, on first use, and kept in memory.

    ``__getitem__`` ignores its ``index`` argument and draws a random pair:
    with probability 1/2 the second image is drawn from the same label as
    the first, otherwise it is drawn from the whole file. The returned
    target is always computed from the two labels actually drawn.

    Args:
        txt: path of the index file.
        transform: optional callable applied to both images.
        target_transform: stored on the instance but not used by this class.
        should_invert: when true, both images are inverted with
            ``PIL.ImageOps.invert`` before ``transform`` is applied.
    """

    def __init__(self, txt, transform=None, target_transform=None, should_invert=False):
        self.transform = transform
        self.target_transform = target_transform
        self.should_invert = should_invert
        self.txt = txt
        # Filled lazily by _load(): list of (path, label) and label -> entries.
        self._entries = None
        self._by_label = None

    def _load(self):
        """Parse the index file once and return the list of (path, label)."""
        if self._entries is None:
            entries = []
            by_label = {}
            with open(self.txt, 'r') as fh:
                for line_no, line in enumerate(fh, 1):
                    if not line.strip():
                        continue  # tolerate blank / trailing empty lines
                    parts = line.rsplit(None, 1)
                    if len(parts) != 2:
                        raise ValueError(
                            "{}:{}: expected '<image path> <label>', got {!r}".format(
                                self.txt, line_no, line))
                    entry = (parts[0], parts[1])
                    entries.append(entry)
                    by_label.setdefault(entry[1], []).append(entry)
            self._entries = entries
            self._by_label = by_label
        return self._entries

    @staticmethod
    def _open_gray(path):
        """Open ``path`` and return an 8-bit grayscale copy, closing the file."""
        with Image.open(path) as img:
            return img.convert("L")

    def __getitem__(self, index):
        entries = self._load()

        path0, label0 = random.choice(entries)
        should_get_same_class = random.randint(0, 1)
        if should_get_same_class:
            # Pick directly among entries sharing the label instead of
            # re-drawing from the whole file until a match turns up.
            path1, label1 = random.choice(self._by_label[label0])
        else:
            path1, label1 = random.choice(entries)

        img0 = self._open_gray(path0)
        img1 = self._open_gray(path1)

        if self.should_invert:
            img0 = PIL.ImageOps.invert(img0)
            img1 = PIL.ImageOps.invert(img1)

        if self.transform is not None:
            img0 = self.transform(img0)
            img1 = self.transform(img1)

        # Target is 1.0 when the two labels differ and 0.0 when they match.
        target = torch.from_numpy(np.array([int(label1 != label0)], dtype=np.float32))
        return img0, img1, target

    def __len__(self):
        """Number of non-blank entries in the index file."""
        return len(self._load())



""" =======Visualising some of the data==========
train_data = MyDataset(txt = Config.txt_root, transform=transforms.ToTensor(),
                       target_transform=transforms.Compose([transforms.Resize((100,100)),transforms.ToTensor()],))
train_loader = DataLoader(dataset=train_data, batch_size=8, shuffle=True)
it = iter(train_loader)
p1, p2, label = next(it)
example_batch = next(it)
concatenated = torch.cat((example_batch[0],example_batch[1]),0)
imshow(torchvision.utils.make_grid(concatenated))
print(example_batch[2].numpy())
"""



# Neural Net Definition, Standard CNNs
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


class SiameseNetwork(nn.Module):
    """Twin network: one shared CNN + MLP applied to each image of a pair.

    The convolutional stack takes 1-channel input and keeps the spatial size
    (reflection padding of 1 before each 3x3 convolution). The first linear
    layer is sized for 8 channels of 100x100, and the last one outputs a
    5-dimensional vector per image.
    """

    def __init__(self):
        super().__init__()

        # Three identical stages, differing only in their channel counts.
        conv_layers = []
        for in_channels, out_channels in ((1, 4), (4, 8), (8, 8)):
            conv_layers += [
                nn.ReflectionPad2d(1),
                nn.Conv2d(in_channels, out_channels, kernel_size=3),
                nn.ReLU(inplace=True),
                nn.BatchNorm2d(out_channels),
                nn.Dropout2d(p=0.2),
            ]
        self.cnn1 = nn.Sequential(*conv_layers)

        flat_features = 8 * 100 * 100
        self.fc1 = nn.Sequential(
            nn.Linear(flat_features, 500),
            nn.ReLU(inplace=True),
            nn.Linear(500, 500),
            nn.ReLU(inplace=True),
            nn.Linear(500, 5),
        )

    def forward_once(self, x):
        """Run a single batch of images through the shared branch."""
        feature_maps = self.cnn1(x)
        # Flatten everything except the batch dimension.
        flattened = feature_maps.view(feature_maps.size(0), -1)
        return self.fc1(flattened)

    def forward(self, input1, input2):
        """Return the branch outputs for ``input1`` and ``input2`` as a tuple."""
        first = self.forward_once(input1)
        second = self.forward_once(input2)
        return first, second


# Custom Contrastive Loss
class ContrastiveLoss(torch.nn.Module):
    """
    Contrastive loss function.
    Based on: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf

    For each pair, with Euclidean distance ``d`` between the two outputs and
    label ``y``, the per-pair loss is
    ``(1 - y) * d**2 + y * max(margin - d, 0)**2``; the result is the mean
    over the batch. No constant 1/2 factor is applied.

    Args:
        margin: distance beyond which pairs with ``y = 1`` contribute zero.
    """

    def __init__(self, margin=2.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        """Compute the mean contrastive loss.

        Args:
            output1, output2: tensors accepted by ``F.pairwise_distance``.
            label: one value per pair; any shape with as many elements as
                there are pairs (e.g. ``(N,)`` or ``(N, 1)``).

        Returns:
            Scalar tensor holding the mean loss.

        Raises:
            RuntimeError: if ``label`` cannot be reshaped to one value per pair.
        """
        euclidean_distance = F.pairwise_distance(output1, output2)
        # Align the label with the distance element-wise. Without this, an
        # (N, 1) label times an (N,) distance broadcasts to (N, N) and every
        # label gets combined with every pair's distance.
        label = label.reshape(euclidean_distance.shape)

        same_term = (1 - label) * torch.pow(euclidean_distance, 2)
        different_term = label * torch.pow(
            torch.clamp(self.margin - euclidean_distance, min=0.0), 2)
        loss_contrastive = torch.mean(same_term + different_term)

        return loss_contrastive


# Training
# All set-up lives under the __main__ guard: with num_workers > 0 the
# DataLoader starts worker processes, and with the "spawn" start method
# (e.g. on Windows) each worker re-imports this script. Unguarded set-up
# would rebuild the dataset and the model in every worker.
if __name__ == '__main__':
    # Use the GPU when one is available, otherwise fall back to the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    train_data = MyDataset(txt=Config.txt_root, transform=transforms.Compose(
        [transforms.Resize((100, 100)), transforms.ToTensor()]), should_invert=False)
    train_dataloader = DataLoader(dataset=train_data, shuffle=True, num_workers=2,
                                  batch_size=Config.train_batch_size)

    net = SiameseNetwork().to(device)
    criterion = ContrastiveLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0005)

    counter = []
    loss_history = []
    iteration_number = 0

    for epoch in range(Config.train_number_epochs):
        for i, (img0, img1, label) in enumerate(train_dataloader):
            img0, img1, label = img0.to(device), img1.to(device), label.to(device)

            optimizer.zero_grad()
            output1, output2 = net(img0, img1)
            loss_contrastive = criterion(output1, output2, label)
            loss_contrastive.backward()
            optimizer.step()

            # Log and record the loss every 10th batch of the epoch.
            if i % 10 == 0:
                loss_value = loss_contrastive.item()
                print("Epoch:{},  Current loss {}\n".format(epoch, loss_value))
                iteration_number += 10
                counter.append(iteration_number)
                loss_history.append(loss_value)

    show_plot(counter, loss_history)
    torch.save(net.state_dict(), './SiameseNet.pth')

PyTorch之—Siamese网络_第2张图片

你可能感兴趣的:(pytorch)