孪生网络(Siamese Network)实现手写数字聚类

孪生网络(Siamese Network)通常用于小样本学习(few-shot learning),常被归入度量学习/元学习(meta learning)类方法。

Siamese Network 使用 CNN 作为特征提取器:一对输入样本共用同一个 CNN(权重共享),CNN 之后再接全连接层,用于判别输入的两个样本是否属于同一类别,也就是一个二分类问题。

这里实现的孪生网络输入是从相同类别或不同类别样本中随机采样一对数据,如果是相同类别,则标签为1 ,如果是不同类别,则标签为0。注意相同类别和不同类别样本对要平衡。具体实现还是看代码比较直接。

图1 特征提取

 

孪生网络(Siamese Network)实现手写数字聚类_第1张图片

图2 对比损失(contrastive loss)

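相同类别的样本对标签(相似度)$y=1$,不同类别的样本对 $y=0$。对应下文代码中的对比损失为 $L = \frac{1}{2}\left[y\,d^2 + (1-y)\max(0,\ m-d)^2\right]$,其中 $d$ 为两个嵌入向量之间的欧氏距离,$m$ 为 margin。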

孪生网络(Siamese Network)实现手写数字聚类_第2张图片

 图3 三元组损失(triplet loss)

$\alpha$:margin($\alpha>0$),超参数,表示期望正样本对与负样本对之间至少拉开的距离间隔。

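记 $f(\cdot)$ 为特征提取网络,$x^a$、$x^p$、$x^n$ 分别为锚点(anchor)、正样本和负样本。如果 $\|f(x^a)-f(x^p)\|_2^2 + \alpha \le \|f(x^a)-f(x^n)\|_2^2$,则没有 loss;否则,loss 为 $\|f(x^a)-f(x^p)\|_2^2 - \|f(x^a)-f(x^n)\|_2^2 + \alpha$。

写成数学表达式就是:

$$L(x^a, x^p, x^n) = \max\left(\|f(x^a)-f(x^p)\|_2^2 - \|f(x^a)-f(x^n)\|_2^2 + \alpha,\ 0\right)$$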

还有一种方法是把三个样本组成一组送入网络:先随机抽取一个样本作为锚点(anchor),再随机抽取一个相同类别的样本作为正样本、一个不同类别的样本作为负样本,组成三元组(triplet)。

相应的loss function使用triplet loss,这种方法可以取得更好的效果。

这里先给出 triplet loss 的实现,相应的三元组自定义数据集改日再补充。

class TripletLoss(nn.Module):
    """
    Hinge-style triplet loss on embedding vectors.

    For each row it penalises max(d(a, p) - d(a, n) + margin, 0), where d is the
    *squared* Euclidean distance (no square root is taken).

    Args:
        margin: minimum gap required between the negative and positive distances.
    """

    def __init__(self, margin):
        super().__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative, size_average=True):
        """Return the mean (``size_average=True``) or the sum of per-row hinge losses."""
        # Row-wise squared distances: anchor<->positive and anchor<->negative.
        pos_gap = (anchor - positive).pow(2).sum(dim=1)
        neg_gap = (anchor - negative).pow(2).sum(dim=1)
        hinge = F.relu(pos_gap - neg_gap + self.margin)
        if size_average:
            return hinge.mean()
        return hinge.sum()

这里给出完整代码,包含三个代码文件:自定义数据集(主程序中以 Siamese_minist 导入)、模型与损失函数(siamese_model)以及训练/验证主程序 main()。

下面代码是自定义的数据集

import torch
import torch.nn as nn
import torchvision
import torchvision.datasets as dataset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader,Dataset
import torchvision.utils
import numpy as np
import random
from torch.utils.data.sampler import BatchSampler
from PIL import Image





class SiameseMNIST(Dataset):
    """
    Pair-producing wrapper around a torchvision-style MNIST dataset.

    Each item is ``((img1, img2), target)`` where ``target`` is 1 when both
    images carry the same label and 0 otherwise.

    Train: every ``__getitem__`` call draws (global NumPy RNG) a positive or a
    negative partner for the requested sample with equal probability.
    Test: a fixed list of pairs is built once in ``__init__`` from a seeded
    ``RandomState``. Even indices get a positive partner and odd indices a
    negative one, so validation sees the same pairs on every run.

    Args:
        mnist_dataset: object exposing ``train``, ``transform``, ``targets`` (a
            tensor of labels), ``data`` (a tensor of 2-D uint8 images) and
            ``__len__`` -- e.g. ``torchvision.datasets.MNIST``.
        seed: seed for the test-pair generator (29 keeps the original value).
    """

    def __init__(self, mnist_dataset, seed=29):
        self.mnist_dataset = mnist_dataset

        self.train = self.mnist_dataset.train
        self.transform = self.mnist_dataset.transform

        labels = self.mnist_dataset.targets
        data = self.mnist_dataset.data
        labels_np = labels.numpy()
        self.labels_set = set(labels_np)
        self.label_to_indices = {label: np.where(labels_np == label)[0]
                                 for label in self.labels_set}

        if self.train:
            self.train_labels = labels
            self.train_data = data
        else:
            self.test_labels = labels
            self.test_data = data
            # Every draw goes through this seeded generator. Previously the
            # negative *label* came from the global np.random, so the
            # "fixed" test pairs differed from run to run.
            random_state = np.random.RandomState(seed)
            positive_pairs = [[i,
                               self._positive_partner(i, self.test_labels[i].item(), random_state),
                               1]
                              for i in range(0, len(self.test_data), 2)]
            negative_pairs = [[i,
                               self._negative_partner(self.test_labels[i].item(), random_state),
                               0]
                              for i in range(1, len(self.test_data), 2)]
            self.test_pairs = positive_pairs + negative_pairs

    def _positive_partner(self, index, label, rng):
        """Pick a different index with the same label; fall back to ``index`` if it is its class's only sample."""
        candidates = self.label_to_indices[label]
        if len(candidates) < 2:
            # Nothing else to pair with -- returning early avoids looping forever.
            return index
        partner = index
        while partner == index:
            partner = rng.choice(candidates)
        return partner

    def _negative_partner(self, label, rng):
        """Pick a random index whose label differs from ``label``."""
        other_label = rng.choice(list(self.labels_set - {label}))
        return rng.choice(self.label_to_indices[other_label])

    def __getitem__(self, index):
        if self.train:
            target = np.random.randint(0, 2)
            img1, label1 = self.train_data[index], self.train_labels[index].item()
            if target == 1:
                siamese_index = self._positive_partner(index, label1, np.random)
            else:
                siamese_index = self._negative_partner(label1, np.random)
            img2 = self.train_data[siamese_index]
        else:
            first, second, target = self.test_pairs[index]
            img1 = self.test_data[first]
            img2 = self.test_data[second]

        # Convert the uint8 tensors to 8-bit grayscale PIL images so the
        # torchvision transform of the wrapped dataset can be applied.
        img1 = Image.fromarray(img1.numpy(), mode='L')
        img2 = Image.fromarray(img2.numpy(), mode='L')
        if self.transform is not None:
            img1 = self.transform(img1)
            img2 = self.transform(img2)
        return (img1, img2), target

    def __len__(self):
        return len(self.mnist_dataset)

因为 MNIST 数据集比较简单,所以模型也比较简单。重点是 Contrastive loss 函数。

import torch
import torch.nn as nn
import torch.nn.functional as F

class SiameseNetwork(nn.Module):
    """
    Weight-sharing embedding network for single-channel images.

    Both inputs of a pair go through the same convolutional trunk (``cnn1``)
    and MLP head (``fc1``). The head emits a 2-D embedding, so the learned
    space can be scattered directly. The flatten size 64*4*4 corresponds to
    28x28 inputs:
        1x28x28 -conv5-> 32x24x24 -pool-> 32x12x12 -conv5-> 64x8x8 -pool-> 64x4x4
    """

    def __init__(self):
        super().__init__()
        conv_layers = [
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5),
            nn.PReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, kernel_size=5),
            nn.PReLU(),
            nn.MaxPool2d(2, 2),
        ]
        self.cnn1 = nn.Sequential(*conv_layers)

        head_layers = [
            nn.Linear(64 * 4 * 4, 256), nn.PReLU(),
            nn.Linear(256, 256), nn.PReLU(),
            nn.Linear(256, 2),
        ]
        self.fc1 = nn.Sequential(*head_layers)

        self._initialize_weights()

    def _initialize_weights(self):
        """Re-initialise weights: Kaiming for convs, N(0, 0.01) for linears, zero biases."""
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)
            elif isinstance(layer, nn.BatchNorm2d):
                # No BatchNorm layers exist in this model; kept for reuse.
                nn.init.ones_(layer.weight)
                nn.init.zeros_(layer.bias)
            elif isinstance(layer, nn.Linear):
                nn.init.normal_(layer.weight, mean=0.0, std=0.01)
                nn.init.zeros_(layer.bias)

    def forward_once(self, x):
        """Embed one batch: conv trunk, flatten each sample, MLP head."""
        conv_out = self.cnn1(x)
        flat = conv_out.view(conv_out.shape[0], -1)
        return self.fc1(flat)

    def forward(self, input1, input2):
        """Embed both inputs with the shared weights and return ``(emb1, emb2)``."""
        return self.forward_once(input1), self.forward_once(input2)




class ContrastiveLoss(nn.Module):
    """
    Contrastive loss for embedding pairs (target 1 = same class, 0 = different).

    Per pair: 0.5 * [t * d^2 + (1 - t) * max(0, margin - d)^2], where d is the
    Euclidean distance between the two embeddings.

    Args:
        margin: distance beyond which dissimilar pairs stop being penalised.
    """

    def __init__(self, margin=2.0):
        super().__init__()
        self.margin = margin
        # Added under the square root so sqrt never sees exactly 0
        # (its gradient would be infinite there).
        self.eps = 1e-9

    def forward(self, output1, output2, target, size_average=True):
        """Return the mean (``size_average=True``) or the sum of per-pair losses."""
        sq_dist = (output1 - output2).pow(2).sum(1)
        pull = target.float() * sq_dist
        dist = (sq_dist + self.eps).sqrt()
        push = (1 - target).float() * F.relu(self.margin - dist).pow(2)
        per_pair = 0.5 * (pull + push)
        if size_average:
            return per_pair.mean()
        return per_pair.sum()

训练和验证主程序,把文件路径改成自己的即可。运行前需先在终端启动 visdom 服务:python3 -m visdom.server。

import sys
import os
import torch
import torch.nn as nn
import torchvision
from torch.optim.lr_scheduler import ExponentialLR, MultiStepLR
import matplotlib
matplotlib.use('TkAgg')
import matplotlib.pyplot as plt
import torchvision.datasets as dataset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import numpy as np
import random
from torch.autograd import Variable
from torch import optim
import torch.nn.functional as F
import nibabel as nib
import argparse
from tqdm import tqdm
import visdom
from Siamese_minist import SiameseMNIST
from siamese_model import SiameseNetwork, ContrastiveLoss




# ---- command-line options -------------------------------------------------
# NOTE(review): --train_dir, --test_dir, --train_path and --test_path are
# parsed but not used by the MNIST pipeline visible in this script.
parser = argparse.ArgumentParser()
parser.add_argument('--train_dir', type=str, default='./data')
parser.add_argument('--test_dir', type=str, default=None)
parser.add_argument('--batch_size', type=int, default=128, help='batch size')
parser.add_argument('--epochs', type=int, default=20, help='number epoch to training')
parser.add_argument('--lr', type=float, default=1e-3, help='initial learning rate')
parser.add_argument('--nw', type=int, default=16, help='Dataloader num_works')
parser.add_argument('--save_path', type=str, default='/home/yang/cnn3d/mutipule_calssification/weight_path/Siamese_model.pth', help='model weight save path')
parser.add_argument('--train_path', type=str, default='/home/yang/cnn3d/mutipule_calssification/data_selected/category_10/train_data', help='training data path')
parser.add_argument('--test_path', type=str, default='/home/yang/cnn3d/mutipule_calssification/data_selected/category_10/test_data', help='test data path')
parser.add_argument('--margin', type=float, default=1.0, help='contrastive loss margin ')
parser.add_argument('--gamma', type=float, default=0.95, help='optimizer scheduler gamma')
# The MNIST root used to be hard-coded; the default keeps the old location.
parser.add_argument('--mnist_path', type=str,
                    default='/home/yang/cnn3d/mutipule_calssification/SiameseNetwork/mnist',
                    help='root directory of the torchvision MNIST dataset')
parser.add_argument('--download', action='store_true',
                    help='download MNIST into --mnist_path if it is not there yet')

torch.manual_seed(1)

opt = parser.parse_args()
print(opt)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# visdom visualisation: start the server first in a terminal with
# `python3 -m visdom.server`.
viz = visdom.Visdom()
train_dataset_path = opt.train_path
test_dataset_path = opt.test_path
# Standard MNIST pixel mean / std used for normalisation.
mean, std = 0.1307, 0.3081
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((mean,), (std,))])
minist_path = opt.mnist_path

minist_train = dataset.MNIST(minist_path, train=True, transform=transform, download=opt.download)
minist_test = dataset.MNIST(minist_path, train=False, transform=transform, download=opt.download)

# Pair datasets for the contrastive training / validation loss.
train_dataset = SiameseMNIST(minist_train)
test_dataset = SiameseMNIST(minist_test)

# Plain (single-image, labelled) loaders, used only to extract embeddings
# for the cluster plots.
train_loader = DataLoader(minist_train, batch_size=64)
test_loader = DataLoader(minist_test, batch_size=64)


train_dataloader = DataLoader(train_dataset,
                        shuffle=True,
                        num_workers=opt.nw,
                        batch_size=opt.batch_size)
# The test pairs are fixed, so shuffling only changes their order.
test_dataloader = DataLoader(test_dataset, shuffle=True, batch_size=opt.batch_size, num_workers=opt.nw)

net = SiameseNetwork().to(device)


criterion = ContrastiveLoss(margin=opt.margin)
optimizer = optim.Adam(net.parameters(), lr=opt.lr)
# NOTE(review): `scheduler` (ExponentialLR, --gamma) is created but the
# training loop only steps `scheduler1`, so --gamma has no effect.
scheduler = ExponentialLR(optimizer, gamma=opt.gamma)
# Divide the learning rate by 10 at epochs 10 and 20.
scheduler1 = MultiStepLR(optimizer, [10, 20], gamma=0.1)
def show_plot(iteration, loss):
    """Draw ``loss`` against ``iteration`` on the current axes and show the figure."""
    axes = plt.gca()
    axes.plot(iteration, loss)
    plt.show()


# Legend labels for the ten digit classes, in class-index order.
mnist_classes = [str(digit) for digit in range(10)]
# One matplotlib colour per digit; index i is used for class i.
colors = [
    '#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
    '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf',
]

def plot_embeddings(embeddings, targets, xlim=None, ylim=None):
    """
    Scatter 2-D embeddings in a new 10x10-inch figure, one colour per digit.

    Args:
        embeddings: array indexed as ``embeddings[rows, 0]`` / ``embeddings[rows, 1]``.
        targets: array of class labels aligned with ``embeddings``.
        xlim, ylim: optional ``(low, high)`` axis limits; skipped when falsy.
    """
    plt.figure(figsize=(10, 10))
    for digit, colour in zip(range(10), colors):
        rows = np.where(targets == digit)[0]
        plt.scatter(embeddings[rows, 0], embeddings[rows, 1], alpha=0.5, color=colour)
    for limits, set_limits in ((xlim, plt.xlim), (ylim, plt.ylim)):
        if limits:
            set_limits(limits[0], limits[1])
    plt.legend(mnist_classes)

def extract_embeddings(dataloader, model, cuda=True):
    """
    Embed every sample of ``dataloader`` with ``model.forward_once``.

    Args:
        dataloader: iterable of ``(images, target)`` batches (e.g. a DataLoader
            over torchvision MNIST).
        model: network exposing ``forward_once``.
        cuda: kept for backward compatibility only. Inputs are always moved to
            the device of the model's parameters. That matches the old
            ``cuda=True`` behaviour and avoids a CPU/GPU mismatch when
            ``cuda=False`` is used with a GPU model.

    Returns:
        ``(embeddings, labels)`` as float64 arrays of shape ``(N, D)`` and
        ``(N,)`` in loader order. D is taken from the model output (2 for
        SiameseNetwork). An empty loader yields shapes ``(0, 2)`` and ``(0,)``.

    The model is evaluated in eval mode and its previous train/eval mode is
    restored afterwards.
    """
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device('cpu')
    was_training = model.training
    embedding_chunks, label_chunks = [], []
    model.eval()
    try:
        with torch.no_grad():
            for images, target in dataloader:
                out = model.forward_once(images.to(model_device))
                embedding_chunks.append(out.cpu().numpy())
                label_chunks.append(target.cpu().numpy())
    finally:
        model.train(was_training)
    if not embedding_chunks:
        return np.zeros((0, 2)), np.zeros(0)
    embeddings = np.concatenate(embedding_chunks).astype(np.float64)
    labels = np.concatenate(label_chunks).astype(np.float64)
    return embeddings, labels


# Create the two visdom line windows that the training and validation
# loops later append to (window names must match those used there).
viz.line([0.], [0.], win='train_loss', opts={'title': 'training Loss'})
viz.line([0.], [0.], win='val_loss', opts={'title': 'valuation Loss'})

def main():
    """
    Train ``net`` on random Siamese pairs with the contrastive loss.

    Each batch loss is appended to the visdom ``train_loss`` window. The loss
    of every 10th batch in an epoch is also kept for the final matplotlib
    curve. Validation on the fixed test pairs runs every 5 epochs (starting
    with the first) and after the last epoch. The MultiStepLR scheduler
    ``scheduler1`` is stepped once per epoch.
    """
    counter = []
    loss_history = []
    iteration_number = 0
    global_step = 0.0
    val_step = 0.0
    for epoch in range(opt.epochs):
        # Validation switches the net to eval mode, so training mode is
        # re-enabled every epoch. It was previously set only once, before the loop.
        net.train()
        train_loss = 0.0
        train_bar = tqdm(train_dataloader, file=sys.stdout)
        for index, ((image0, image1), label) in enumerate(train_bar):
            image0, image1, label = image0.to(device), image1.to(device), label.to(device)
            optimizer.zero_grad()

            output1, output2 = net(image0, image1)
            loss_contrastive = criterion(output1, output2, label)
            loss_contrastive.backward()
            optimizer.step()

            batch_loss = loss_contrastive.item()
            train_loss += batch_loss
            global_step += 1
            viz.line([batch_loss], [global_step], win='train_loss', opts=dict(title='training Loss'),
                     update='append')

            if index % 10 == 0:
                iteration_number += 10
                counter.append(iteration_number)
                loss_history.append(batch_loss)

        print("Epoch number {} Current loss {}".format(epoch+1, train_loss/(len(train_dataloader))))
        print("第%d个epoch的学习率:%f" % (epoch + 1, optimizer.param_groups[0]['lr']))
        scheduler1.step()
        if epoch % 5 == 0 or epoch == opt.epochs - 1:
            val_loss, val_step = _validate(val_step)
            # Same 1-based epoch number as the training print above.
            print('epoch %d| valuation Loss:%.4f' % (epoch + 1, val_loss))
    # Uncomment to keep the trained weights:
    # torch.save(net.state_dict(), opt.save_path)
    show_plot(counter, loss_history)


def _validate(val_step):
    """Evaluate the contrastive loss on the fixed test pairs; return (mean loss, updated val_step)."""
    net.eval()
    total_loss = 0.0
    with torch.no_grad():
        val_bar = tqdm(test_dataloader, file=sys.stdout)
        for (val_image0, val_image1), val_label in val_bar:
            val_step += 1
            val_image0, val_image1, val_label = val_image0.to(device), val_image1.to(device), val_label.to(device)
            output1, output2 = net(val_image0, val_image1)
            batch_loss = criterion(output1, output2, val_label).item()
            total_loss += batch_loss
            viz.line([batch_loss], [val_step], win='val_loss', opts=dict(title='valuation loss'), update='append')
    return total_loss / len(test_dataloader), val_step

def valuation(query_index=0):
    """
    Nearest-neighbour sanity check in the learned embedding space.

    Embeds every image of the labelled MNIST test set (``test_loader``) and
    takes the sample at ``query_index`` as the query. The closest *other*
    sample by Euclidean distance provides the predicted digit label.

    The previous version could not run:
    - it unpacked three values from ``test_dataloader`` batches, which are
      ``((img0, img1), target)``;
    - it hard-coded ``.cuda()``;
    - it compared a whole batch of distances with a scalar;
    - it called ``next`` once per test sample although the loader yields
      batches.

    Args:
        query_index: index (in ``test_loader`` order) of the query image.

    Returns:
        ``(min_distance, predicted_label, true_label)``.

    Raises:
        ValueError: if the test set has fewer than two samples.
    """
    net.eval()
    embeddings, labels = extract_embeddings(test_loader, net)
    if len(embeddings) < 2:
        raise ValueError('valuation needs at least two test samples')
    distances = np.linalg.norm(embeddings - embeddings[query_index], axis=1)
    distances[query_index] = np.inf  # never match the query with itself
    nearest = int(np.argmin(distances))
    min_diatance = float(distances[nearest])
    predic_label = int(labels[nearest])
    true_label = int(labels[query_index])

    print('min diatance: ', min_diatance)
    print('predicted label', predic_label)
    return min_diatance, predic_label, true_label




if __name__ == '__main__':
    main()
    # Cluster visualisation of the learned 2-D embedding:
    # figure 1 = training set, figure 2 = test set.
    for split_loader in (train_loader, test_loader):
        split_embeddings, split_labels = extract_embeddings(split_loader, net)
        plot_embeddings(split_embeddings, split_labels)
    plt.show()









运行结果:

训练的loss曲线:

孪生网络(Siamese Network)实现手写数字聚类_第3张图片

训练集数据效果: 

孪生网络(Siamese Network)实现手写数字聚类_第4张图片

你可能感兴趣的:(聚类,深度学习,cnn,pytorch)