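使用孪生网络(SiameseNet)计算肺部图像之间的相似度,同样的方法也可用于人脸识别等场景。
特征提取网络(骨干网络)结构为ResNet,可选用ResNet34、ResNet50等。
数据组织结构如下图所示(原图缺失):数据按 torchvision ImageFolder 的格式存放,每个类别一个子文件夹,例如 ./data/lung_mask/training/<类别名>/xxx.png 与 ./data/lung_mask/testing/<类别名>/xxx.png。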
#!/usr/bin/python
# -*- coding: UTF-8 -*-
"""
@author:uncle德鲁
@file:siamesenet.py
@time:2023/07/29
"""
import os
import torchvision
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
import torchvision.utils
import numpy as np
import random
from PIL import Image
import torch
from torch.autograd import Variable
import PIL.ImageOps
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from torch.hub import load_state_dict_from_url
import sys
import datetime
from torchsummary import summary
# Debugging aid: per the PyTorch docs, anomaly detection makes backward()
# raise on NaN results and report the traceback of the forward op behind a
# failing backward function, at the cost of noticeably slower execution.
# Consider disabling it for regular training runs.
torch.autograd.set_detect_anomaly(True)
class Logger(object):
    """Tee every write to a stream (by default stdout) and to a log file.

    Meant to be installed as ``sys.stdout`` so that all printed output is also
    appended to ``filename``.

    Args:
        filename: path of the log file, opened in append mode with UTF-8
            encoding. Missing parent directories are created first.
        stream: stream that keeps receiving every message (default: the
            ``sys.stdout`` in effect when this class was defined).
    """

    def __init__(self, filename, stream=sys.stdout):
        self.terminal = stream
        # The log may live in a directory no run has created yet (e.g. ./result/);
        # open() would raise FileNotFoundError in that case.
        log_dir = os.path.dirname(filename)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)
        # Explicit encoding so the file content does not depend on the platform locale.
        self.log = open(filename, 'a', encoding='utf-8')

    def write(self, message):
        """Write ``message`` to both the stream and the log file."""
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        """Flush both destinations so the log file is current on disk."""
        self.terminal.flush()
        # Stay usable as sys.stdout even after close() (e.g. the flush at exit).
        if not self.log.closed:
            self.log.flush()

    def close(self):
        """Close the log file; the wrapped stream is left open."""
        self.log.close()
MY_DATA = "lung_mask"

# Timestamp of this run (minute resolution); it names the log file below.
now = datetime.datetime.now()
formatted_time = now.strftime("%Y-%m-%d_%H-%M")

# From here on, everything printed goes both to the console and to a
# per-run log file.
sys.stdout = Logger(
    "./result/train_loss_{}.log".format(formatted_time),
    stream=sys.stdout,
)
def imshow(img, img_name, text=None, title=None):
    """Draw a channels-first image tensor with matplotlib and save it.

    Args:
        img: CPU tensor, assumed (C, H, W) as produced by ``make_grid``.
            It is converted with ``.numpy()``.
        img_name: output path handed to ``plt.savefig``.
        text: optional caption, drawn in a white box at data coords (75, 8).
        title: optional figure title.
    """
    # matplotlib wants channels last; move C from axis 0 to axis 2.
    hwc = np.transpose(img.numpy(), (1, 2, 0))
    plt.axis("off")
    if text:
        caption_box = {'facecolor': 'white', 'alpha': 0.8, 'pad': 10}
        plt.text(75, 8, text, style='italic', fontweight='bold', bbox=caption_box)
    if title:
        plt.title(title)
    plt.imshow(hwc)
    plt.savefig(img_name)
    # Clear the current figure so the next call starts on a blank canvas.
    plt.clf()
def show_plot(iteration, loss, img_name):
    """Save a line plot of ``loss`` against ``iteration`` to ``img_name``.

    Args:
        iteration: x values (sequence).
        loss: y values, same length as ``iteration``.
        img_name: output path handed to ``plt.savefig``.
    """
    plt.plot(iteration, loss)
    plt.savefig(img_name)
    plt.clf()  # reset the current figure for whoever plots next
class Config:
    """Static training configuration: dataset locations and hyper-parameters."""

    my_data = MY_DATA
    # ImageFolder roots (one sub-directory per class) for training / testing.
    training_dir = f"./data/{my_data}/training/"
    testing_dir = f"./data/{my_data}/testing/"
    train_batch_size = 4
    train_number_epochs = 10
class SiameseNetworkDataset(Dataset):
    """Dataset of randomly drawn image pairs for siamese (contrastive) training.

    Each ``__getitem__`` call ignores ``index`` and draws a fresh pair:
    - The first image is drawn uniformly from all images.
    - With probability 1/2, the second image is drawn uniformly from the same class.
    - Otherwise, it is drawn uniformly from the other classes.

    ``__len__`` only decides how many pairs make up one epoch.

    Args:
        imageFolderDataset: object with an ``imgs`` list of
            ``(path, class_index)`` tuples, e.g. ``torchvision.datasets.ImageFolder``.
        transform: optional callable applied to each grayscale PIL image.
        should_invert: when True, invert both images before ``transform``.

    Each item is ``(img0, img1, label)``. ``label`` is a float32 tensor of
    shape ``(1,)``: ``1.0`` if the images come from different classes, else ``0.0``.

    Raises (from ``__getitem__``):
        IndexError: the underlying dataset is empty.
        ValueError: fewer than two classes, so no "different" pair can exist.
    """

    def __init__(self, imageFolderDataset, transform=None, should_invert=True):
        self.imageFolderDataset = imageFolderDataset
        self.transform = transform
        self.should_invert = should_invert
        # Cached once: needed to detect datasets where a different-class pair
        # can never be found.
        self._num_classes = len({label for _, label in imageFolderDataset.imgs})

    def __getitem__(self, index):
        imgs = self.imageFolderDataset.imgs
        img0_tuple = random.choice(imgs)  # IndexError on an empty dataset
        if self._num_classes < 2:
            # The different-class rejection loop below would otherwise spin forever.
            raise ValueError(
                "SiameseNetworkDataset needs images from at least two classes")
        # ~50% of pairs share a class; rejection-sample until the class relation matches.
        want_same_class = bool(random.randint(0, 1))
        while True:
            img1_tuple = random.choice(imgs)
            if (img0_tuple[1] == img1_tuple[1]) == want_same_class:
                break
        img0 = self._load(img0_tuple[0])
        img1 = self._load(img1_tuple[0])
        is_different = img1_tuple[1] != img0_tuple[1]
        return img0, img1, torch.tensor([float(is_different)], dtype=torch.float32)

    def _load(self, path):
        """Open ``path`` as grayscale, optionally invert it, then apply ``transform``."""
        # The context manager closes the file handle. convert() returns an
        # independent, fully loaded image, so it outlives the ``with``.
        with Image.open(path) as img:
            img = img.convert("L")
        if self.should_invert:
            img = PIL.ImageOps.invert(img)
        if self.transform is not None:
            img = self.transform(img)
        return img

    def __len__(self):
        return len(self.imageFolderDataset.imgs)
class BasicBlock(nn.Module):
    """Residual block of ResNet-18/34: two 3x3 convolutions plus a shortcut.

    Args:
        in_channel: channels of the input tensor.
        out_channel: channels produced by both convolutions (``expansion == 1``).
        stride: stride of the first 3x3 convolution.
        downsample: optional module applied to the input on the shortcut path
            (the "dashed" shortcut). It is needed so the identity matches the
            main branch when ``stride != 1`` or the channel count changes.
        **kwargs: accepted so ``ResNet._make_layer`` can pass ``groups`` and
            ``width_per_group`` to every block type. Only the defaults
            (``groups=1``, ``width_per_group=64``) are supported here.

    Raises:
        ValueError: if a grouped or widened block is requested.
    """

    # Ratio of output channels to ``out_channel``; 1 means the width is unchanged.
    expansion = 1

    def __init__(self, in_channel, out_channel, stride=1, downsample=None, **kwargs):
        super(BasicBlock, self).__init__()
        # Grouped/widened (ResNeXt-style) variants are not implemented.
        # Fail loudly instead of silently building a plain block.
        if kwargs.get("groups", 1) != 1 or kwargs.get("width_per_group", 64) != 64:
            raise ValueError(
                "BasicBlock only supports groups=1 and width_per_group=64")
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                               kernel_size=(3, 3), stride=(stride, stride), padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
                               kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            # Project the shortcut so it can be added to the main branch.
            identity = self.downsample(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += identity
        return self.relu(out)
class Bottleneck(nn.Module):
    """Residual block of ResNet-50/101/152 and ResNeXt: 1x1 -> 3x3 -> 1x1 convs.

    The stride is placed as in the PyTorch implementation ("ResNet v1.5"):
    - The 3x3 convolution carries ``stride``; the first 1x1 convolution uses stride 1.
    - The original paper put the stride on the first 1x1 convolution instead.
    - The v1.5 placement is reported to gain about 0.5% top-1 accuracy
      (https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch).

    Args:
        in_channel: channels of the input tensor.
        out_channel: base width; the block outputs ``out_channel * 4`` channels.
        stride: stride of the 3x3 convolution.
        downsample: optional module on the shortcut path so the identity can be added.
        groups: number of groups of the 3x3 convolution.
        width_per_group: per-group width. With the defaults (1, 64) the inner
            width equals ``out_channel``.
    """

    # The last 1x1 convolution widens the inner width by this factor.
    expansion = 4

    def __init__(self, in_channel, out_channel, stride=1,
                 downsample=None, groups=1, width_per_group=64):
        super(Bottleneck, self).__init__()
        inner = int(out_channel * (width_per_group / 64.)) * groups
        expanded = out_channel * self.expansion
        self.conv1 = nn.Conv2d(in_channel, inner, kernel_size=(1, 1),
                               stride=(1, 1), bias=False)
        self.bn1 = nn.BatchNorm2d(inner)
        self.conv2 = nn.Conv2d(inner, inner, kernel_size=(3, 3),
                               stride=(stride, stride), padding=1,
                               groups=groups, bias=False)
        self.bn2 = nn.BatchNorm2d(inner)
        self.conv3 = nn.Conv2d(inner, expanded, kernel_size=(1, 1),
                               stride=(1, 1), bias=False)
        self.bn3 = nn.BatchNorm2d(expanded)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        shortcut = x if self.downsample is None else self.downsample(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += shortcut
        return self.relu(out)
class ResNet(nn.Module):
    """ResNet built from ``BasicBlock`` or ``Bottleneck`` stages.

    Args:
        block: residual block class; must expose an ``expansion`` attribute and
            accept ``(in_channel, channel, downsample=, stride=, groups=,
            width_per_group=)``.
        blocks_num: four ints, the number of blocks in stages conv2_x..conv5_x.
        num_classes: output size of the final linear layer (only when ``include_top``).
        include_top: if True, finish with global average pooling and ``fc``.
            The output is then ``(batch, num_classes)``. Otherwise the conv5_x
            feature map is returned.
        groups, width_per_group: forwarded to every block.
        in_channels: channels of the input image. Defaults to 1 (grayscale);
            pass 3 for RGB input.
    """

    def __init__(self, block, blocks_num, num_classes=1000, include_top=True,
                 groups=1, width_per_group=64, in_channels=1):
        super(ResNet, self).__init__()
        self.include_top = include_top
        # Channels entering the next stage; advanced by _make_layer.
        self.in_channel = 64
        self.groups = groups
        self.width_per_group = width_per_group
        # Stem: 7x7 stride-2 conv then 3x3 stride-2 max-pool.
        self.conv1 = nn.Conv2d(in_channels,
                               self.in_channel,
                               kernel_size=(7, 7),
                               stride=(2, 2),
                               padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channel)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, blocks_num[0])
        self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)
        self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)
        self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)
        if self.include_top:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # output size = (1, 1)
            self.fc = nn.Linear(512 * block.expansion, num_classes)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def _make_layer(self, block, channel, block_num, stride=1):
        """Build one stage of ``block_num`` blocks; ``channel`` is the first conv's width."""
        downsample = None
        # Any block type needs a projection shortcut when the spatial size
        # changes (stride != 1) or the channel count changes (e.g. expansion=4).
        if stride != 1 or self.in_channel != channel * block.expansion:
            downsample = nn.Sequential(nn.Conv2d(self.in_channel,
                                                 channel * block.expansion,
                                                 kernel_size=(1, 1),
                                                 stride=(stride, stride),
                                                 bias=False),
                                       nn.BatchNorm2d(channel * block.expansion))
        layers = [block(self.in_channel,
                        channel,
                        downsample=downsample,
                        stride=stride,
                        groups=self.groups,
                        width_per_group=self.width_per_group)]
        self.in_channel = channel * block.expansion
        for _ in range(1, block_num):
            layers.append(block(self.in_channel,
                                channel,
                                groups=self.groups,
                                width_per_group=self.width_per_group))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        if self.include_top:
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.fc(x)
        return x
def resnet34(num_classes=1000, include_top=True):
    """Build a ResNet-34: ``BasicBlock`` stages with 3, 4, 6 and 3 blocks.

    torchvision weights: https://download.pytorch.org/models/resnet34-333f7ec4.pth
    NOTE(review): this file's ResNet stem takes single-channel input, so a
    3-channel checkpoint's ``conv1`` weights need adapting before loading.
    """
    stage_depths = [3, 4, 6, 3]
    return ResNet(block=BasicBlock, blocks_num=stage_depths,
                  num_classes=num_classes, include_top=include_top)
def resnet50(num_classes=1000, include_top=True):
    """Build a ResNet-50: ``Bottleneck`` stages with 3, 4, 6 and 3 blocks.

    torchvision weights: https://download.pytorch.org/models/resnet50-19c8e357.pth
    NOTE(review): this file's ResNet stem takes single-channel input, so a
    3-channel checkpoint's ``conv1`` weights need adapting before loading.
    """
    stage_depths = [3, 4, 6, 3]
    return ResNet(block=Bottleneck, blocks_num=stage_depths,
                  num_classes=num_classes, include_top=include_top)
def resnet101(num_classes=1000, include_top=True):
    """Build a ResNet-101: ``Bottleneck`` stages with 3, 4, 23 and 3 blocks.

    torchvision weights: https://download.pytorch.org/models/resnet101-5d3b4d8f.pth
    NOTE(review): this file's ResNet stem takes single-channel input, so a
    3-channel checkpoint's ``conv1`` weights need adapting before loading.
    """
    stage_depths = [3, 4, 23, 3]
    return ResNet(block=Bottleneck, blocks_num=stage_depths,
                  num_classes=num_classes, include_top=include_top)
def resnext50_32x4d(num_classes=1000, include_top=True):
    """Build a ResNeXt-50 (32x4d).

    The model uses ``Bottleneck`` stages with 3, 4, 6 and 3 blocks. Each 3x3
    convolution has 32 groups of width 4.

    torchvision weights: https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth
    """
    cardinality, group_width = 32, 4
    return ResNet(block=Bottleneck,
                  blocks_num=[3, 4, 6, 3],
                  num_classes=num_classes,
                  include_top=include_top,
                  groups=cardinality,
                  width_per_group=group_width)
def resnext101_32x8d(num_classes=1000, include_top=True):
    """Build a ResNeXt-101 (32x8d).

    The model uses ``Bottleneck`` stages with 3, 4, 23 and 3 blocks. Each 3x3
    convolution has 32 groups of width 8.

    torchvision weights: https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth
    """
    cardinality, group_width = 32, 8
    return ResNet(block=Bottleneck,
                  blocks_num=[3, 4, 23, 3],
                  num_classes=num_classes,
                  include_top=include_top,
                  groups=cardinality,
                  width_per_group=group_width)
class SiameseNetwork(nn.Module):
    """Base siamese model: one ResNet-34 encoder shared by every input branch.

    Subclasses implement ``forward`` for a specific number of inputs. All of
    them encode through the same ``self.resnet``, so the weights are shared.

    Args:
        num_classes: size of the encoder's final linear layer, i.e. the
            embedding dimension.
    """

    def __init__(self, num_classes=1000):
        super().__init__()
        # Alternative, deeper encoder: resnet50(num_classes=num_classes, include_top=True)
        self.resnet = resnet34(num_classes=num_classes, include_top=True)

    def initialize_weights(self):
        """Re-initialize every layer of the model.

        - Conv/linear weights get Xavier-uniform init and their biases are zeroed.
        - BatchNorm weights are set to one and biases to zero.

        Not called from ``__init__``; call it explicitly to replace the Kaiming
        init that ``ResNet`` applies to its conv layers.
        """
        for layer in self.modules():
            if isinstance(layer, (nn.Conv2d, nn.Linear)):
                nn.init.xavier_uniform_(layer.weight)
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)
            elif isinstance(layer, nn.BatchNorm2d):
                nn.init.ones_(layer.weight)
                nn.init.zeros_(layer.bias)

    def forward(self, x):
        # Abstract: each subclass defines how many inputs it embeds.
        raise NotImplementedError
class SiameseNetworkQuadret(SiameseNetwork):
    """Four-input siamese network: every input goes through the shared encoder.

    ``forward`` takes a 4-sequence ``(x1, x2, x3, x4)`` of image batches and
    returns the four embeddings in the same order, like the triplet variant.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward(self, x):
        x1, x2, x3, x4 = x
        # self.resnet returns a single tensor per input. Tuple-unpacking it
        # would split along the batch dimension, so keep each result whole.
        return self.resnet(x1), self.resnet(x2), self.resnet(x3), self.resnet(x4)
class SiameseNetworkTriplet(SiameseNetwork):
    """Three-input siamese network sharing one encoder across branches."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward(self, x):
        """Embed each of the three image batches in ``x`` with ``self.resnet``.

        Returns a 3-tuple of embeddings in input order.
        """
        first, second, third = x
        return self.resnet(first), self.resnet(second), self.resnet(third)
class SiameseNetworkDouble(SiameseNetwork):
    """Two-branch siamese network; this is the model trained in ``run``."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward(self, x1, x2):
        """Encode both batches with the shared ``self.resnet`` and return the pair."""
        embedding1 = self.resnet(x1)
        embedding2 = self.resnet(x2)
        return embedding1, embedding2
# Loss Function
class ContrastiveLoss(torch.nn.Module):
    """Contrastive loss of Hadsell, Chopra & LeCun (2006).

    Based on: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf

    For embedding distance ``d`` and label ``y``, the per-pair loss is
    ``(1 - y) * d**2 + y * max(margin - d, 0)**2``, and the batch mean is returned.
    - ``y = 0`` marks a similar pair, which is pulled together.
    - ``y = 1`` marks a dissimilar pair, which is pushed at least ``margin`` apart.

    Args:
        margin: minimum distance enforced between dissimilar pairs.
    """

    def __init__(self, margin=2.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        """Return the scalar contrastive loss.

        Args:
            output1, output2: embeddings; assumes 2-D ``(batch, dim)``.
            label: one value per pair, shaped ``(batch,)`` or ``(batch, 1)``.
        """
        # keepdim=True keeps one distance per pair as a column: (batch, 1).
        euclidean_distance = F.pairwise_distance(output1, output2, keepdim=True)
        # Align the label with the distance. A (batch,) label would otherwise
        # broadcast against (batch, 1) into a (batch, batch) matrix and mix
        # unrelated pairs.
        label = label.reshape(euclidean_distance.shape)
        similar_term = (1 - label) * torch.pow(euclidean_distance, 2)
        dissimilar_term = label * torch.pow(
            torch.clamp(self.margin - euclidean_distance, min=0.0), 2)
        return torch.mean(similar_term + dissimilar_term)
def _pair_dataset(root):
    """Wrap the ImageFolder at ``root`` as a random-pair dataset of 100x100 tensors."""
    return SiameseNetworkDataset(imageFolderDataset=dset.ImageFolder(root=root),
                                 transform=transforms.Compose([transforms.Resize((100, 100)),
                                                               transforms.ToTensor()]),
                                 should_invert=False)


def _train(net, train_dataloader, device, base_dir):
    """Optimize ``net`` with contrastive loss and save the loss curve to ``base_dir``."""
    criterion = ContrastiveLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0005)
    counter = []
    loss_history = []
    iteration_number = 0
    net.train()
    for epoch in range(Config.train_number_epochs):
        for i, (img0, img1, label) in enumerate(train_dataloader):
            img0, img1, label = img0.to(device), img1.to(device), label.to(device)
            optimizer.zero_grad()
            output1, output2 = net(img0, img1)
            loss_contrastive = criterion(output1, output2, label)
            loss_contrastive.backward()
            optimizer.step()
            if i % 20 == 0:
                print("Epoch {}/{}: Current batch loss = {:.4f}\n".format(epoch,
                                                                          Config.train_number_epochs,
                                                                          loss_contrastive.item()))
                # Loss-curve x axis: one sample every 20 batches.
                iteration_number += 20
                counter.append(iteration_number)
                loss_history.append(loss_contrastive.item())
    show_plot(counter, loss_history, img_name="{}/train_loss.jpg".format(base_dir))


def _visualize_test_pairs(net, device, base_dir, num_pairs=10):
    """Save images comparing one fixed test image against up to ``num_pairs`` others."""
    test_dataloader = DataLoader(
        _pair_dataset(Config.testing_dir),
        num_workers=4,
        batch_size=1,
        shuffle=True)
    # Inference: BatchNorm must use its running statistics (batches of 1 here,
    # and they must not be updated), and no autograd graph is needed.
    net.eval()
    dataiter = iter(test_dataloader)
    # The first image of the first pair is the reference for every comparison.
    x0, _, _ = next(dataiter)
    with torch.no_grad():
        for i in range(num_pairs):
            try:
                _, x1, _ = next(dataiter)
            except StopIteration:
                break  # the test set holds fewer pairs than requested
            concatenated = torch.cat((x0, x1), 0)
            output1, output2 = net(x0.to(device), x1.to(device))
            euclidean_distance = F.pairwise_distance(output1, output2)
            imshow(img=torchvision.utils.make_grid(concatenated),
                   img_name="{}/img_{}.png".format(base_dir, i + 1),
                   text='Dissimilarity: {:.2f}'.format(euclidean_distance.item()))


def run():
    """Train the two-branch siamese ResNet, then visualize test-pair distances.

    Side effects:
    - Creates ./result/<MY_DATA>/.
    - Writes train_loss.jpg and up to ten img_<k>.png comparison images there.
    - Prints the model and the training progress.

    Uses CUDA when available, otherwise the CPU.
    """
    base_dir = "./result/{}/".format(MY_DATA)
    os.makedirs(base_dir, exist_ok=True)
    train_dataloader = DataLoader(_pair_dataset(Config.training_dir),
                                  shuffle=True,
                                  num_workers=4,
                                  batch_size=Config.train_batch_size)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = SiameseNetworkDouble().to(device)
    print(net)
    print("-" * 200)
    _train(net, train_dataloader, device, base_dir)
    _visualize_test_pairs(net, device, base_dir)
if __name__ == "__main__":
    # Script entry point: train, then write the evaluation images.
    run()