孪生网络入门Siamese Network

学习一个模型之前我们先了解一下该模型的作用,孪生网络解决的问题:对于两个图片或者其他数据源,进行change detection、验证等。

模型思路:每次需要输入两个样本作为一个样本对计算距离。将输入映射为一个特征向量,使用两个向量之间的“距离”(L2 Norm)来表示输入之间的差异(图像语义上的差距)。孪生网络可以用以判断两个样本之间的相似度。

孪生网络特点:一类包含两个或更多个相同子网络的神经网络架构。 这里相同是指它们具有相同的配置即具有相同的参数和权重。 参数更新在两个子网上共同进行。

  • 子网共享权重意味着训练需要更少的参数,也就意味着需要更少的数据并且不容易过拟合。
  • 每个子网本质上产生其输入的表示。

网络结构:每个分支网络结构相同,参数共享。即共用一个网络主体。

孪生网络入门Siamese Network_第1张图片

损失函数:Contrastive Loss损失函数

提供两个输入,一个是否相同的标签。

其中

注意这里设置了一个阈值m,表示我们只考虑不相似特征欧式距离在0~margin之间的,当距离超过margin的,则把其loss看做为0。

loss对比:softmax只需要输入一个样本;Triplet Loss需要输入三个样本;Contrastive Loss 两个样本

实现代码:

class SiameseNetwork(nn.Module):
	    def __init__(self):
	        super(SiameseNetwork, self).__init__()
	        self.cnn1 = nn.Sequential(
	            nn.ReflectionPad2d(1),
	            nn.Conv2d(1, 4, kernel_size=3),
	            nn.ReLU(inplace=True),
	            nn.BatchNorm2d(4),
	            nn.Dropout2d(p=.2),
	            
	            nn.ReflectionPad2d(1),
	            nn.Conv2d(4, 8, kernel_size=3),
	            nn.ReLU(inplace=True),
	            nn.BatchNorm2d(8),
	            nn.Dropout2d(p=.2),
	
	            nn.ReflectionPad2d(1),
	            nn.Conv2d(8, 8, kernel_size=3),
	            nn.ReLU(inplace=True),
	            nn.BatchNorm2d(8),
	            nn.Dropout2d(p=.2),
	        )
	
	        self.fc1 = nn.Sequential(
	            nn.Linear(8*100*100, 500),
	            nn.ReLU(inplace=True),
	
	            nn.Linear(500, 500),
	            nn.ReLU(inplace=True),
	
	            nn.Linear(500, 5)
	        )
	
	    def forward_once(self, x):
	        output = self.cnn1(x)
	        output = output.view(output.size()[0], -1)
	        output = self.fc1(output)
	        return output
	
	    def forward(self, input1, input2):
	        output1 = self.forward_once(input1)
	        output2 = self.forward_once(input2)
	        return output1, output2

class ContrastiveLoss(torch.nn.Module):
	    """
	    Contrastive loss function.
	    Based on: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
	    """
	
	    def __init__(self, margin=2.0):
	        super(ContrastiveLoss, self).__init__()
	        self.margin = margin
	
	    def forward(self, output1, output2, label):
	        euclidean_distance = F.pairwise_distance(output1, output2)
	        loss_contrastive = torch.mean((1-label) * torch.pow(euclidean_distance, 2)  
	                                      (label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))
	
	        return loss_contrastive

 

你可能感兴趣的:(深度学习)