【开源计划】图像配准中常用损失函数的pytorch实现

前言

按照开源计划的预告,我们首先从基于深度学习的图像配准任务中常用的损失函数的代码实现开始。从我最开始的那一篇博客,即基于深度学习的医学图像配准综述,可以看出,目前基于无监督学习的图像非刚性配准模型成为了一个比较流行的研究方向。这是因为基于监督学习的方法过分依赖传统方法或者模拟变形的方法来提供监督信息,这样既吃力又不讨好。以我的探索经验来讲,以传统配准方法产生的变形场作为监督信息,对网络进行训练,很容易造成过拟合问题。因此,我综合文献综述的结论与探索的经验(有兴趣的话,我可以总结一下我的探索经历以及经验教训),最终选择了基于无监督学习的配准模型,则本文主要介绍这种模型框架下常用的损失函数。实际上,监督学习的损失函数也比较简单,只需要使用深度学习框架(如TensorFlow、PyTorch)提供的函数计算误差即可,本文使用PyTorch进行实现。

损失函数

基于无监督学习的图像非刚性配准模型的损失函数通常是由两部分组成,一个是参考图像与变形后的浮动图像的相似性测度,一个是网络预测变形场的空间正则化。以比较有名的VoxelMorph为例,它的GitHub仓库可以点击此链接。按照他最早的发表在CVPR上的论文,损失函数如下:

【开源计划】图像配准中常用损失函数的pytorch实现_第1张图片

其中第一项就是相似性测度,后面一项就是空间正则化项,用以约束变形场的空间平滑性。下面我们分别对其进行介绍。

相似性测度

常用于测量图像的相似性测度有三个,一个是图像灰度的均方差(mean squared voxel differece),一个是交叉互相关(cross-correlation),一个是互信息(mutual information)。前两个通常用于单模态的图像,而第一个的鲁棒性相比于交叉互相关更差一些,比较容易受图像灰度分布与对比度等的影响。互信息通常用于多模态的图像,在单模态图像的鲁棒性更好,但是到目前为止还没有发现它被用于深度学习网络训练的损失函数中,我的猜想是互信息的计算是基于统计的,不方便进行梯度计算,与反向传播原则相违背。(该看法亟待进一步的考证。好久没看Voxelmorph的开源代码,现在已经有互信息的实现了,有时间可以研究一下)

因此,主要是]使用交叉互相关作为图像配准的损失函数。交叉互相关的公式为:(摘自VoxelMorph)

【开源计划】图像配准中常用损失函数的pytorch实现_第2张图片

他们的代码实现-TensorFlow版请查看链接中的NCC。需要指出的是,在他们的实现版本当中,他们对于三维图像使用了一个9*9*9的窗口来计算相似性,因此成为local cross-correlation,即局部交叉互相关。(没想到现在voxelmorph还提供了pytorch版本的代码,真周到,见链接)

这里展示一下我自己参考开源代码,转写成pytorch实现的局部互相关,如下:

首先,是导入依赖库

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np

接着,是局部互相关:

class LCC(nn.Module):
    """
    local (over window) normalized cross correlation (square)
    """
    def __init__(self, win=[9, 9], eps=1e-5):
        super(LCC, self).__init__()
        self.win = win
        self.eps = eps
        
    def forward(self, I, J):
        I2 = I.pow(2)
        J2 = J.pow(2)
        IJ = I * J
        
        filters = Variable(torch.ones(1, 1, self.win[0], self.win[1]))
        if I.is_cuda:#gpu
            filters = filters.cuda()
        padding = (self.win[0]//2, self.win[1]//2)
        
        I_sum = F.conv2d(I, filters, stride=1, padding=padding)
        J_sum = F.conv2d(J, filters, stride=1, padding=padding)
        I2_sum = F.conv2d(I2, filters, stride=1, padding=padding)
        J2_sum = F.conv2d(J2, filters, stride=1, padding=padding)
        IJ_sum = F.conv2d(IJ, filters, stride=1, padding=padding)
        
        win_size = self.win[0]*self.win[1]
 
        u_I = I_sum / win_size
        u_J = J_sum / win_size
        
        cross = IJ_sum - u_J*I_sum - u_I*J_sum + u_I*u_J*win_size
        I_var = I2_sum - 2 * u_I * I_sum + u_I*u_I*win_size
        J_var = J2_sum - 2 * u_J * J_sum + u_J*u_J*win_size
 
        cc = cross*cross / (I_var*J_var + self.eps)#np.finfo(float).eps
        lcc = -1.0 * torch.mean(cc) + 1
        return lcc

除此之外,我还按照交叉互相关的定义,实现了一个全局的交叉互相关,即对整幅图计算,不依赖窗口。当图像尺寸较大时,全局交叉互相关的敏感度不如局部交叉互相关,实际效果不如局部的,其代码如下:

class GCC(nn.Module):
    """
    global normalized cross correlation (sqrt)
    """
    def __init__(self):
        super(GCC, self).__init__()
 
    def forward(self, I, J):
        I2 = I.pow(2)
        J2 = J.pow(2)
        IJ = I * J
        #average value
        I_ave, J_ave= I.mean(), J.mean()
        I2_ave, J2_ave = I2.mean(), J2.mean()
        IJ_ave = IJ.mean()
        
        cross = IJ_ave - I_ave * J_ave
        I_var = I2_ave - I_ave.pow(2)
        J_var = J2_ave - J_ave.pow(2)
 
#        cc = cross*cross / (I_var*J_var + np.finfo(float).eps)#1e-5
        cc = cross / (I_var.sqrt() * J_var.sqrt() + np.finfo(float).eps)#1e-5
 
        return -1.0 * cc + 1

空间正则化

在训练网络过程中,通过最大化图像的相似性测度,往往使网络产生不连续的变形场,通常要对预测的变形场施加一个空间平滑性的约束,即对变形场的空间梯度进行惩罚,如voxelmorph中的空间正则化,即计算变形场梯度的L2范数的平方:

【开源计划】图像配准中常用损失函数的pytorch实现_第3张图片

这里展示一下我自己参考开源代码,转写成pytorch实现的空间正则化,如下:

class Grad(nn.Module):
    """
    N-D gradient loss
    """
    def __init__(self, penalty='l2'):
        super(Grad, self).__init__()
        self.penalty = penalty
    
    def _diffs(self, y):#y shape(bs, nfeat, vol_shape)
        ndims = y.ndimension() - 2
        df = [None] * ndims
        for i in range(ndims):
            d = i + 2#y shape(bs, c, d, h, w)
            # permute dimensions to put the ith dimension first
#            r = [d, *range(d), *range(d + 1, ndims + 2)]
            y = y.permute(d, *range(d), *range(d + 1, ndims + 2))
            dfi = y[1:, ...] - y[:-1, ...]
            
            # permute back
            # note: this might not be necessary for this loss specifically,
            # since the results are just summed over anyway.
#            r = [*range(1, d + 1), 0, *range(d + 1, ndims + 2)]
            df[i] = dfi.permute(*range(1, d + 1), 0, *range(d + 1, ndims + 2))
        
        return df
    
    def forward(self, pred):
        ndims = pred.ndimension() - 2
        if pred.is_cuda:
            df = Variable(torch.zeros(1).cuda())
        else:
            df = Variable(torch.zeros(1))
        for f in self._diffs(pred):
            if self.penalty == 'l1':
                df += f.abs().mean() / ndims
            else:
                assert self.penalty == 'l2', 'penalty can only be l1 or l2. Got: %s' % self.penalty
                df += f.pow(2).mean() / ndims
        return df

实际上,按照这种空间正则化的思想,还有其他几种方法,比如,最近的一篇期刊论文使用了一种称为折叠惩罚(bending penalty)的正则化方法,实际上就是计算变形场的二阶梯度,按照字面意思,它的目的是对变形场中的折叠进行惩罚,其公式如下:

【开源计划】图像配准中常用损失函数的pytorch实现_第4张图片

我按照公式实现的二维的折叠惩罚项如下,有兴趣的可以自己实现一下三维版的。

class Bend_Penalty(nn.Module):
    """
    Bending Penalty of the spatial transformation (2D)
    """
    def __init__(self):
        super(Bend_Penalty, self).__init__()
    
    def _diffs(self, y, dim):#y shape(bs, nfeat, vol_shape)
        ndims = y.ndimension() - 2
        d = dim + 2
        # permute dimensions to put the ith dimension first
#       r = [d, *range(d), *range(d + 1, ndims + 2)]
        y = y.permute(d, *range(d), *range(d + 1, ndims + 2))
        dfi = y[1:, ...] - y[:-1, ...]
        
        # permute back
        # note: this might not be necessary for this loss specifically,
        # since the results are just summed over anyway.
#       r = [*range(1, d + 1), 0, *range(d + 1, ndims + 2)]
        df = dfi.permute(*range(1, d + 1), 0, *range(d + 1, ndims + 2))
        
        return df
    
    def forward(self, pred):#shape(B,C,H,W)
        Ty = self._diffs(pred, dim=0)
        Tx = self._diffs(pred, dim=1)
        Tyy = self._diffs(Ty, dim=0)
        Txx = self._diffs(Tx, dim=1)
        Txy = self._diffs(Tx, dim=0)
        p = Tyy.pow(2).mean() + Txx.pow(2).mean() + 2 * Txy.pow(2).mean()
        
        return p

另外,还有文章研究了将变形场的L1范数进行空间正则化,它的效果是尽可能减小形变的绝对值。

仿射变换的损失函数

最后,我还实现了一下仿射变换的损失函数,就是将仿射变换的参数与恒等变换的参数之间的差异求L1或L2范数。代码如下:

class IDloss(nn.Module):
    """
    loss between affine transformation and identity transf.
    """
    def __init__(self, penalty='l1'):
        super(IDloss, self).__init__()
        self.penalty = penalty
        self.id = torch.FloatTensor([1, 0, 0, 0, 1, 0])
 
    def forward(self, theta):
        if theta.is_cuda:
            ID = Variable(self.id.cuda())
        else:
            ID = Variable(self.id)
        ID = ID.repeat(theta.size(0), 1).view(theta.shape)
        if self.penalty == 'l1':
            loss = torch.mean(torch.abs(theta - ID))
        else:
            assert self.penalty == 'l2', 'penalty can only be l1 or l2. Got: %s' % self.penalty
            loss = torch.mean(torch.pow(theta - ID, 2))
        return loss

结束语

最后,以上仅供参考,欢迎各位网友批评指正与留言交流。

有兴趣的还可以关注一下我的B站账号:Timmy_毛毛,方便及时获取更新视频内容,谢谢~

你可能感兴趣的:(pytorch,python,配准)