Siamese+Resnet进行相似度计算

Siamese+Resnet进行相似度计算

  • 基本介绍
  • 效果
    • 肺部+resnet34
    • 肺部+Resnet50
    • 人脸+自定义网络
  • 完整代码

基本介绍

使用SiameseNet进行肺部图像相似度计算,同样可以用于人脸识别等场景。
特征提取网络(backbone)的结构为ResNet,可以选用ResNet34、ResNet50等。
数据组织结构如下图所示:

  • lung:下面包含训练集training和测试集testing。training和testing下面均按类别分文件夹存放图片(每个类别一个子文件夹)。
  • model_data: 为resnet预训练模型存放地址
  • result:保存测试结果和训练的日志。
  • Train_Siamese_with_Resnet.py为训练脚本。主要需要根据情况修改如下参数配置:
    - MY_DATA:指定用作训练数据的数据集。直接填写data文件夹下某个子文件夹的名字即可,如MY_DATA=“lung”
    - Config类:主要配置batchsize和epoch

Siamese+Resnet进行相似度计算_第1张图片

效果

肺部+resnet34

Siamese+Resnet进行相似度计算_第2张图片

Siamese+Resnet进行相似度计算_第3张图片
Siamese+Resnet进行相似度计算_第4张图片
Siamese+Resnet进行相似度计算_第5张图片

肺部+Resnet50

Siamese+Resnet进行相似度计算_第6张图片

Siamese+Resnet进行相似度计算_第7张图片Siamese+Resnet进行相似度计算_第8张图片
Siamese+Resnet进行相似度计算_第9张图片

人脸+自定义网络

Siamese+Resnet进行相似度计算_第10张图片
Siamese+Resnet进行相似度计算_第11张图片
Siamese+Resnet进行相似度计算_第12张图片
Siamese+Resnet进行相似度计算_第13张图片

完整代码

#!/usr/bin/python
# -*- coding: UTF-8 -*-
"""
@author:uncle德鲁
@file:siamesenet.py
@time:2023/07/29
"""
import os
import torchvision
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
import torchvision.utils
import numpy as np
import random
from PIL import Image
import torch
from torch.autograd import Variable
import PIL.ImageOps
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from torch.hub import load_state_dict_from_url
import sys
import datetime
from torchsummary import summary
# Debugging aid: with anomaly detection on, autograd records forward-pass
# traces so that a backward computation producing NaN values raises an error
# pointing at the forward op responsible. PyTorch documents this mode as
# slowing execution; consider switching it off once training is stable.
torch.autograd.set_detect_anomaly(True)


class Logger(object):
    """Tee stream: forward every write to ``stream`` and append it to a log file.

    Intended to replace ``sys.stdout`` so ``print`` output goes both to the
    console and to ``filename``.

    Args:
        filename: Log file path, opened in append mode with UTF-8 encoding.
            Missing parent directories are created.
        stream: Stream that keeps receiving the output. The default is the
            ``sys.stdout`` object current when this class was defined.
    """

    def __init__(self, filename, stream=sys.stdout):
        self.terminal = stream
        # Create e.g. ./result on a fresh checkout; otherwise open() fails
        # when stdout is redirected at import time.
        parent = os.path.dirname(filename)
        if parent:
            os.makedirs(parent, exist_ok=True)
        # Fixed encoding so non-ASCII text is written identically on every platform.
        self.log = open(filename, 'a', encoding='utf-8')

    def write(self, message):
        """Write ``message`` to the terminal stream and the log file."""
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        """Flush both streams so logged lines reach disk even if the process dies."""
        self.terminal.flush()
        self.log.flush()

    def close(self):
        """Close the log file (the terminal stream is left open). Safe to call twice."""
        if not self.log.closed:
            self.log.close()


MY_DATA = "lung_mask"  # name of the sub-folder of ./data/ to train and test on

# Start time of this run (minute resolution); it is embedded in the log file name.
now = datetime.datetime.now()
formatted_time = f"{now:%Y-%m-%d_%H-%M}"
# From here on, everything printed is also appended to the per-run log file.
sys.stdout = Logger(f"./result/train_loss_{formatted_time}.log", stream=sys.stdout)


def imshow(img, img_name, text=None, title=None):
    """Render a (C, H, W) CPU tensor with pyplot and save it to ``img_name``.

    Args:
        img: image tensor in channel-first layout; converted with ``.numpy()``,
            so it must live on the CPU.
        img_name: output file path passed to ``plt.savefig``.
        text: optional caption drawn in a white box at data position (75, 8).
        title: optional figure title.
    """
    hwc_source = img.numpy()
    plt.axis("off")
    if text:
        caption_box = {'facecolor': 'white', 'alpha': 0.8, 'pad': 10}
        plt.text(75, 8, text, style='italic', fontweight='bold', bbox=caption_box)
    if title:
        plt.title(title)

    # pyplot expects channel-last (H, W, C) arrays.
    plt.imshow(hwc_source.transpose(1, 2, 0))
    plt.savefig(img_name)
    # Clear the shared current figure so the next plot starts empty.
    plt.clf()


def show_plot(iteration, loss, img_name):
    """Plot ``loss`` against ``iteration`` and save the figure to ``img_name``.

    The current pyplot figure is cleared afterwards so later plots start fresh.
    """
    x_values, y_values = iteration, loss
    plt.plot(x_values, y_values)
    plt.savefig(img_name)
    plt.clf()


class Config:
    """Static run configuration: dataset locations and training hyper-parameters."""

    my_data = MY_DATA
    # Both splits follow the ./data/<name>/{training,testing}/ layout.
    training_dir = f"./data/{my_data}/training/"
    testing_dir = f"./data/{my_data}/testing/"
    train_batch_size = 4       # image pairs per optimisation step
    train_number_epochs = 10   # full passes over the training set


class SiameseNetworkDataset(Dataset):
    """Dataset of random image pairs for contrastive training.

    Each item is ``(img0, img1, label)``; ``label`` is a float32 tensor of
    shape ``(1,)`` holding ``0.0`` when both images share a class and ``1.0``
    when they do not. About half of the pairs are same-class. ``index`` is
    ignored: every access draws a fresh random pair, and ``len()`` is the
    number of underlying images.

    Args:
        imageFolderDataset: object with an ``imgs`` list of ``(path, class_index)``
            tuples, e.g. ``torchvision.datasets.ImageFolder``.
        transform: optional callable applied to each grayscale PIL image.
        should_invert: if True, invert pixel intensities before ``transform``.

    Raises:
        ValueError: if ``imageFolderDataset.imgs`` contains fewer than two
            distinct classes. Dissimilar pairs could then never be drawn and
            sampling would loop forever.
    """

    def __init__(self, imageFolderDataset, transform=None, should_invert=True):
        self.imageFolderDataset = imageFolderDataset
        self.transform = transform
        self.should_invert = should_invert
        num_classes = len({label for _, label in imageFolderDataset.imgs})
        if num_classes < 2:
            raise ValueError(
                "SiameseNetworkDataset needs at least 2 classes to build "
                "dissimilar pairs, got {}".format(num_classes))

    def _load(self, path):
        """Open ``path`` as grayscale, then apply optional inversion and ``transform``."""
        with Image.open(path) as raw:
            img = raw.convert("L")  # convert() returns a loaded copy, so closing raw is safe
        if self.should_invert:
            img = PIL.ImageOps.invert(img)
        if self.transform is not None:
            img = self.transform(img)
        return img

    def __getitem__(self, index):
        imgs = self.imageFolderDataset.imgs
        img0_tuple = random.choice(imgs)
        # Fair coin: approximately 50% of the pairs should share a class.
        should_get_same_class = bool(random.randint(0, 1))
        # Rejection-sample the partner image. A match always exists: img0 itself
        # satisfies "same class", and __init__ guarantees a second class exists.
        while True:
            img1_tuple = random.choice(imgs)
            if (img0_tuple[1] == img1_tuple[1]) == should_get_same_class:
                break

        img0 = self._load(img0_tuple[0])
        img1 = self._load(img1_tuple[0])
        # 1.0 = different classes, 0.0 = same class (the convention ContrastiveLoss expects).
        label = np.array([int(img1_tuple[1] != img0_tuple[1])], dtype=np.float32)
        return img0, img1, torch.from_numpy(label)

    def __len__(self):
        return len(self.imageFolderDataset.imgs)


class BasicBlock(nn.Module):
    """ResNet-18/34 residual block: two 3x3 convolutions plus an identity shortcut.

    ``downsample`` is an optional module applied to the shortcut when the
    main branch changes the spatial size or channel count (the "dashed"
    shortcut). Extra keyword arguments (e.g. ``groups``, ``width_per_group``)
    are accepted and ignored so that ``ResNet._make_layer`` can build this
    block with the same call it uses for ``Bottleneck``.
    """
    # Output channels = out_channel * expansion; 1 because the width never changes.
    expansion = 1

    def __init__(self, in_channel, out_channel, stride=1, downsample=None, **kwargs):
        super().__init__()
        # Only the first conv carries the stride, so spatial downsampling happens once.
        self.conv1 = nn.Conv2d(in_channel, out_channel, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channel, out_channel, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.downsample = downsample

    def forward(self, x):
        shortcut = x if self.downsample is None else self.downsample(x)

        y = self.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))

        y += shortcut
        return self.relu(y)


class Bottleneck(nn.Module):
    """ResNet-50/101/152 residual block: 1x1 -> 3x3 -> 1x1 convolutions plus a shortcut.

    Follows the torchvision ("ResNet v1.5") variant: on a downsampling block
    the stride sits on the 3x3 convolution, not on the first 1x1 as in the
    original paper; NVIDIA reports roughly +0.5% top-1 for this choice
    (https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch).

    ``groups`` / ``width_per_group`` widen the inner path for ResNeXt variants.
    """
    # The last 1x1 conv outputs 4x the channels of the first two convolutions.
    expansion = 4

    def __init__(self, in_channel, out_channel, stride=1,
                 downsample=None, groups=1, width_per_group=64):
        super().__init__()
        # Inner width; equals out_channel for plain ResNet (groups=1, width_per_group=64).
        width = int(out_channel * (width_per_group / 64.)) * groups
        expanded = out_channel * self.expansion

        self.conv1 = nn.Conv2d(in_channel, width, kernel_size=1, stride=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width)
        self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride,
                               padding=1, groups=groups, bias=False)
        self.bn2 = nn.BatchNorm2d(width)
        self.conv3 = nn.Conv2d(width, expanded, kernel_size=1, stride=1, bias=False)
        self.bn3 = nn.BatchNorm2d(expanded)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        shortcut = x if self.downsample is None else self.downsample(x)

        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))

        y += shortcut
        return self.relu(y)


class ResNet(nn.Module):
    """ResNet / ResNeXt backbone assembled from residual ``block`` modules.

    Args:
        block: residual block class (``BasicBlock`` or ``Bottleneck`` in this
            file). It must define ``expansion`` and accept ``in_channel``,
            ``channel``, ``stride``, ``downsample``, ``groups`` and
            ``width_per_group``.
        blocks_num: four ints, the number of blocks in stages conv2_x..conv5_x.
        num_classes: output size of the final fully connected layer.
        include_top: if True, apply global average pooling and ``fc`` and
            return a ``(batch, num_classes)`` tensor; otherwise return the
            layer4 feature map.
        groups, width_per_group: grouped-convolution settings forwarded to
            every block (used by the ResNeXt factories).
        in_channels: number of channels of the input images. Defaults to 1
            (grayscale), matching the ``convert("L")`` preprocessing used in
            training; pass 3 for RGB input.
    """

    def __init__(self, block, blocks_num, num_classes=1000, include_top=True, groups=1, width_per_group=64,
                 in_channels=1):
        super(ResNet, self).__init__()
        self.include_top = include_top
        # Running channel count consumed and updated by _make_layer (not the image channels).
        self.in_channel = 64
        self.groups = groups
        self.width_per_group = width_per_group

        # Stem: 7x7 stride-2 conv followed by 3x3 stride-2 max-pool.
        self.conv1 = nn.Conv2d(in_channels,
                               self.in_channel,
                               kernel_size=(7, 7),
                               stride=(2, 2),
                               padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channel)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, blocks_num[0])
        self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)
        self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)
        self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)
        if self.include_top:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # output size = (1, 1)
            self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def _make_layer(self, block, channel, block_num, stride=1):
        """Build one stage of ``block_num`` blocks.

        ``channel`` is the width of each block's first convolution; the stage
        outputs ``channel * block.expansion`` channels. Only the first block
        gets ``stride`` and, when the shape changes, a 1x1 conv + BN shortcut.
        """
        downsample = None
        # A projection shortcut is needed when the spatial size or channel count changes.
        if stride != 1 or self.in_channel != channel * block.expansion:
            downsample = nn.Sequential(nn.Conv2d(self.in_channel,
                                                 channel *
                                                 block.expansion,
                                                 kernel_size=(1, 1),
                                                 stride=(stride, stride),
                                                 bias=False),
                                       nn.BatchNorm2d(channel * block.expansion))

        layers = []
        layers.append(block(self.in_channel,
                            channel,
                            downsample=downsample,
                            stride=stride,
                            groups=self.groups,
                            width_per_group=self.width_per_group,
                            ))
        self.in_channel = channel * block.expansion

        for _ in range(1, block_num):
            layers.append(block(self.in_channel,
                                channel,
                                groups=self.groups,
                                width_per_group=self.width_per_group,
                                ))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        if self.include_top:
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.fc(x)

        return x


def resnet34(num_classes=1000, include_top=True):
    """Build a ResNet-34: BasicBlock stages of depth [3, 4, 6, 3].

    Pretrained-weights URL noted for this architecture:
    https://download.pytorch.org/models/resnet34-333f7ec4.pth
    """
    stage_depths = [3, 4, 6, 3]
    return ResNet(BasicBlock, stage_depths, num_classes=num_classes, include_top=include_top)


def resnet50(num_classes=1000, include_top=True):
    """Build a ResNet-50: Bottleneck stages of depth [3, 4, 6, 3].

    Pretrained-weights URL noted for this architecture:
    https://download.pytorch.org/models/resnet50-19c8e357.pth
    """
    stage_depths = [3, 4, 6, 3]
    return ResNet(Bottleneck, stage_depths, num_classes=num_classes, include_top=include_top)


def resnet101(num_classes=1000, include_top=True):
    """Build a ResNet-101: Bottleneck stages of depth [3, 4, 23, 3].

    Pretrained-weights URL noted for this architecture:
    https://download.pytorch.org/models/resnet101-5d3b4d8f.pth
    """
    stage_depths = [3, 4, 23, 3]
    return ResNet(Bottleneck, stage_depths, num_classes=num_classes, include_top=include_top)


def resnext50_32x4d(num_classes=1000, include_top=True):
    """Build a ResNeXt-50 (32x4d): Bottleneck stages [3, 4, 6, 3], 32 groups, base width 4.

    Pretrained-weights URL noted for this architecture:
    https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth
    """
    return ResNet(Bottleneck, [3, 4, 6, 3],
                  num_classes=num_classes,
                  include_top=include_top,
                  groups=32,           # number of groups in each 3x3 convolution
                  width_per_group=4)   # base width per group


def resnext101_32x8d(num_classes=1000, include_top=True):
    """Build a ResNeXt-101 (32x8d): Bottleneck stages [3, 4, 23, 3], 32 groups, base width 8.

    Pretrained-weights URL noted for this architecture:
    https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth
    """
    return ResNet(Bottleneck, [3, 4, 23, 3],
                  num_classes=num_classes,
                  include_top=include_top,
                  groups=32,           # number of groups in each 3x3 convolution
                  width_per_group=8)   # base width per group


class SiameseNetwork(nn.Module):
    """Base class for Siamese models built around one shared ResNet-34 backbone.

    Every input branch is embedded by the same ``self.resnet``, so all
    branches share weights. Subclasses implement ``forward`` for their
    number of inputs.

    Args:
        num_classes: size of the embedding produced by the backbone's final
            fully connected layer.
    """

    def __init__(self, num_classes=1000):
        super().__init__()
        # For a deeper backbone use: resnet50(num_classes=num_classes, include_top=True)
        self.resnet = resnet34(num_classes=num_classes, include_top=True)

    def initialize_weights(self):
        """Re-initialise all layers in place.

        Conv2d/Linear weights: Xavier-uniform, biases (if any) zero.
        BatchNorm2d: weight 1, bias 0.
        """
        for layer in self.modules():
            if isinstance(layer, (nn.Conv2d, nn.Linear)):
                nn.init.xavier_uniform_(layer.weight)
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)
            elif isinstance(layer, nn.BatchNorm2d):
                nn.init.ones_(layer.weight)
                nn.init.zeros_(layer.bias)

    def forward(self, x):
        raise NotImplementedError


class SiameseNetworkQuadret(SiameseNetwork):
    """Siamese model for four input batches embedded by the shared backbone.

    ``forward`` takes a 4-item sequence ``(x1, x2, x3, x4)`` of image batches
    and returns the four embeddings in the same order.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward(self, x):
        x1, x2, x3, x4 = x
        # The backbone returns one (batch, num_classes) tensor per call, so the
        # result must not be tuple-unpacked (that would split the batch).
        e1 = self.resnet(x1)
        e2 = self.resnet(x2)
        e3 = self.resnet(x3)
        e4 = self.resnet(x4)
        return e1, e2, e3, e4


class SiameseNetworkTriplet(SiameseNetwork):
    """Siamese model for three input batches embedded by the shared backbone.

    ``forward`` takes a 3-item sequence and returns the three embeddings in order.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward(self, x):
        first, second, third = x
        embed = self.resnet
        return embed(first), embed(second), embed(third)


class SiameseNetworkDouble(SiameseNetwork):
    """Classic two-branch Siamese model: embeds two image batches with shared weights."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward(self, x1, x2):
        embed = self.resnet
        return embed(x1), embed(x2)


# Loss Function


class ContrastiveLoss(torch.nn.Module):
    """Contrastive loss of Hadsell, Chopra & LeCun (2006).

    Reference: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf

    For Euclidean distance ``d`` between the two embeddings, each pair adds
    ``(1 - label) * d**2 + label * max(margin - d, 0)**2``, averaged over the
    batch. ``label`` is 0 for similar pairs (pulled together) and 1 for
    dissimilar pairs (pushed apart until at least ``margin`` away).

    Args:
        margin: distance beyond which dissimilar pairs stop contributing.
    """

    def __init__(self, margin=2.0):
        super().__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        # keepdim=True gives shape (batch, 1), matching a (batch, 1) label tensor.
        dist = F.pairwise_distance(output1, output2, keepdim=True)
        pull_term = (1 - label) * dist.pow(2)
        push_term = label * (self.margin - dist).clamp(min=0.0).pow(2)
        return (pull_term + push_term).mean()


def run():
    """Train a Siamese ResNet on ``Config.training_dir`` and visualise test pairs.

    Writes into ``./result/<MY_DATA>/``:
      * ``train_loss.jpg``: contrastive loss sampled every 20 batches;
      * up to ten ``img_<k>.png``: one fixed test image next to another test
        image, captioned with the Euclidean distance of their embeddings.
    Uses CUDA when available, otherwise the CPU.
    """
    base_dir = "./result/{}/".format(MY_DATA)
    os.makedirs(base_dir, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Same preprocessing for both splits: 100x100 tensors from ToTensor().
    transform = transforms.Compose([transforms.Resize((100, 100)),
                                    transforms.ToTensor()])

    net = _train_siamese(base_dir, transform, device)
    _visualise_test_pairs(net, base_dir, transform, device)


def _train_siamese(base_dir, transform, device):
    """Train SiameseNetworkDouble with ContrastiveLoss, save the loss curve, return the net."""
    folder_dataset = dset.ImageFolder(root=Config.training_dir)
    siamese_dataset = SiameseNetworkDataset(imageFolderDataset=folder_dataset,
                                            transform=transform,
                                            should_invert=False)
    train_dataloader = DataLoader(siamese_dataset,
                                  shuffle=True,
                                  num_workers=4,
                                  batch_size=Config.train_batch_size)
    net = SiameseNetworkDouble().to(device)
    print(net)
    print("-" * 200)
    criterion = ContrastiveLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0005)

    counter = []
    loss_history = []
    batches_per_epoch = len(train_dataloader)
    net.train()
    for epoch in range(Config.train_number_epochs):
        for i, (img0, img1, label) in enumerate(train_dataloader):
            img0, img1, label = img0.to(device), img1.to(device), label.to(device)
            optimizer.zero_grad()
            output1, output2 = net(img0, img1)
            loss_contrastive = criterion(output1, output2, label)
            loss_contrastive.backward()
            optimizer.step()
            if i % 20 == 0:
                loss_value = loss_contrastive.item()
                print("Epoch {}/{}: Current batch loss = {:.4f}\n".format(
                    epoch + 1, Config.train_number_epochs, loss_value))
                # x-axis = global batch index, correct for any batches-per-epoch.
                counter.append(epoch * batches_per_epoch + i)
                loss_history.append(loss_value)

    show_plot(counter, loss_history, img_name="{}/train_loss.jpg".format(base_dir))
    return net


def _visualise_test_pairs(net, base_dir, transform, device, num_pairs=10):
    """Save up to ``num_pairs`` images comparing one reference test image with others."""
    folder_dataset_test = dset.ImageFolder(root=Config.testing_dir)
    siamese_dataset = SiameseNetworkDataset(imageFolderDataset=folder_dataset_test,
                                            transform=transform,
                                            should_invert=False)
    test_dataloader = DataLoader(
        siamese_dataset,
        num_workers=4,
        batch_size=1,
        shuffle=True)
    dataiter = iter(test_dataloader)
    x0, _, _ = next(dataiter)

    # eval(): BatchNorm must use its running statistics. In train mode every
    # single-image batch would be normalised by its own statistics and the
    # running averages would be updated with test data.
    net.eval()
    with torch.no_grad():
        # zip() stops cleanly when the test set has fewer than num_pairs + 1 images.
        for i, (_, x1, _) in zip(range(num_pairs), dataiter):
            concatenated = torch.cat((x0, x1), 0)
            output1, output2 = net(x0.to(device), x1.to(device))
            euclidean_distance = F.pairwise_distance(output1, output2)
            imshow(img=torchvision.utils.make_grid(concatenated),
                   img_name="{}/img_{}.png".format(base_dir, i + 1),
                   text='Dissimilarity: {:.2f}'.format(euclidean_distance.item()))


if __name__ == '__main__':
    # Backbone shape smoke test, kept for reference. The ResNet stem takes
    # single-channel input:
    #   net = resnet34(num_classes=10, include_top=True).cuda()
    #   print(net(torch.rand(1, 1, 224, 224).cuda()).shape)
    run()

你可能感兴趣的:(计算机视觉,cnn,pytorch)