STANet Code Walkthrough: the models Directory

models

  • __init__.py
  • backbone.py
  • BAM.py
  • CDF0.py
  • CDFA.py
  • loss.py
  • mynet3.py
  • PAM2.py

__init__.py

import importlib
from models.base_model import BaseModel


def find_model_using_name(model_name):
    """Import the module "models/[model_name]_model.py".

    In the file, the class called DatasetNameModel() will
    be instantiated. It has to be a subclass of BaseModel,
    and it is case-insensitive.
    """
    model_filename = "models." + model_name + "_model"
    modellib = importlib.import_module(model_filename)
    model = None
    target_model_name = model_name.replace('_', '') + 'model'
    for name, cls in modellib.__dict__.items():
        if name.lower() == target_model_name.lower() \
           and issubclass(cls, BaseModel):
            model = cls

    if model is None:
        print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
        exit(0)

    return model


def get_option_setter(model_name):
    """Return the static method  of the model class."""
    model_class = find_model_using_name(model_name)
    return model_class.modify_commandline_options


def create_model(opt):
    """Create a model given the option.

    This function wraps the model class found by <find_model_using_name>.
    This is the main interface between this package and 'train.py'/'test.py'.

    Example:
        >>> from models import create_model
        >>> model = create_model(opt)
    """
    model = find_model_using_name(opt.model)
    instance = model(opt)
    print("model [%s] was created" % type(instance).__name__)
    return instance

This file implements the model factory create_model and consists of three functions:

find_model_using_name: resolves a model name to its class. It builds the module name "models/[model_name]_model.py", imports that module dynamically, and iterates over the classes defined in it, keeping the one whose name matches [model_name] + "model" case-insensitively and which is a subclass of BaseModel.

get_option_setter: returns the model class's static method modify_commandline_options, which is used to add or modify command-line options.

create_model: creates a model instance from the parsed options opt. It calls find_model_using_name to resolve the model class, instantiates it as instance, and returns it.
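For example, opt.model = 'CDFA' causes models.CDFA_model to be imported and the class CDFAModel (lowercased: 'cdfamodel') to be matched. A minimal usage sketch, assuming the pix2pix-style driver this repo follows (model.setup is inherited from BaseModel):

from models import create_model

model = create_model(opt)   # prints: model [CDFAModel] was created
model.setup(opt)            # load/print networks and create schedulers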

backbone.py

# coding: utf-8
import torch.nn as nn
import torch
from .mynet3 import F_mynet3
from .BAM import BAM
from .PAM2 import PAM as PAM



def define_F(in_c, f_c, type='unet'):
    if type == 'mynet3':
        print("using mynet3 backbone")
        return F_mynet3(backbone='resnet18', in_c=in_c, f_c=f_c, output_stride=32)
    else:
        raise NotImplementedError('no such F type!')

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)



class CDSA(nn.Module):
    """Self-attention module for change detection."""
    def __init__(self, in_c, ds=1, mode='BAM'):
        super(CDSA, self).__init__()
        self.in_C = in_c
        self.ds = ds
        print('ds: ',self.ds)
        self.mode = mode
        if self.mode == 'BAM':
            self.Self_Att = BAM(self.in_C, ds=self.ds)
        elif self.mode == 'PAM':
            self.Self_Att = PAM(in_channels=self.in_C, out_channels=self.in_C, sizes=[1,2,4,8],ds=self.ds)
        self.apply(weights_init)

    def forward(self, x1, x2):
        width = x1.shape[3]
        x = torch.cat((x1, x2), 3)  # concatenate the two feature maps along the width
        x = self.Self_Att(x)
        return x[:, :, :, 0:width], x[:, :, :, width:]

This code implements CDSA (Change Detection Self-Attention), a self-attention module for change detection, together with two helpers:

define_F: builds the feature extractor F. The type argument selects the backbone; if type is 'mynet3', an F_mynet3 network is returned, otherwise a NotImplementedError is raised.

weights_init: weight initialization for the network's convolution and batch-normalization layers (convolution weights are drawn from a normal distribution with mean 0 and standard deviation 0.02; batch-norm weights are drawn from a normal distribution with mean 1 and standard deviation 0.02, and biases are set to 0).

__init__: the constructor. It stores the input channel count in_c, the downsampling factor ds, and the attention type mode, then instantiates the corresponding attention module: BAM when mode is 'BAM', PAM when mode is 'PAM'. Finally it initializes the weights with weights_init.

forward: the forward pass. The two feature maps x1 and x2 are concatenated along the width dimension (dim 3, not the channel dimension), the concatenated map is passed through the Self_Att module, and the result is split back along the width into the two refined feature maps, which are returned.
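A quick shape sketch with hypothetical sizes, showing that the two feature maps travel through the attention module joined along the width:

import torch

x1 = torch.randn(4, 64, 32, 32)           # features of image A: B x C x H x W
x2 = torch.randn(4, 64, 32, 32)           # features of image B
att = CDSA(in_c=64, ds=1, mode='BAM')
y1, y2 = att(x1, x2)                      # internally (4, 64, 32, 64), then split back
assert y1.shape == x1.shape and y2.shape == x2.shape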

BAM.py

import torch
import torch.nn.functional as F
from torch import nn


class BAM(nn.Module):
    """ Basic self-attention module
    """

    def __init__(self, in_dim, ds=8, activation=nn.ReLU):
        super(BAM, self).__init__()
        self.channel_in = in_dim
        self.key_channel = self.channel_in // 8
        self.activation = activation
        self.ds = ds
        self.pool = nn.AvgPool2d(self.ds)
        print('ds: ', ds)
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))  # defined but unused in forward below

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, input):
        """
            inputs :
                x : input feature maps( B X C X W X H)
            returns :
                out : self attention value + input feature
                attention: B X N X N (N is Width*Height)
        """
        x = self.pool(input)
        m_batchsize, C, width, height = x.size()  # note: "width"/"height" actually hold H/W; the naming is consistent below
        proj_query = self.query_conv(x).view(m_batchsize, -1, width * height).permute(0, 2, 1)  # B x N x C', N = H*W after pooling
        proj_key = self.key_conv(x).view(m_batchsize, -1, width * height)  # B x C' x N
        energy = torch.bmm(proj_query, proj_key)  # transpose check
        energy = (self.key_channel**-.5) * energy

        attention = self.softmax(energy)  # B x N x N

        proj_value = self.value_conv(x).view(m_batchsize, -1, width * height)  # B X C X N

        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(m_batchsize, C, width, height)

        out = F.interpolate(out, [width*self.ds,height*self.ds])
        out = out + input

        return out

This code implements the basic self-attention module BAM. Its parts:

__init__: the constructor. It stores the input channel count in_dim, the downsampling factor ds, and the activation function, then defines three 1x1 convolutions producing the Query, Key, and Value tensors; the Query and Key convolutions output in_dim/8 channels, while the Value convolution keeps all in_dim channels. It also defines a learnable scalar gamma (unused in this forward pass) and a softmax over the last dimension.

forward: the forward pass. The input is first downsampled by the average-pooling layer; the pooled map x goes through the Query, Key, and Value convolutions to give proj_query, proj_key, and proj_value. The energy matrix energy is the batch matrix product of proj_query (transposed) and proj_key, scaled by 1/sqrt(key_channel); a softmax over its last dimension yields the attention matrix attention. Finally, proj_value is multiplied by the transposed attention to give out, which is reshaped, interpolated back to the original spatial size, and added to the original input to form the output.
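To make the tensor shapes concrete, a trace for a hypothetical input of size (2, 64, 32, 64) with ds=8 (pooled map 4x8, so N = 32 and key_channel = 64 // 8 = 8):

# x = pool(input)          -> (2, 64, 4, 8)
# proj_query               -> (2, 32, 8)    B x N x C' (after view + permute)
# proj_key                 -> (2, 8, 32)    B x C' x N
# energy = Q @ K           -> (2, 32, 32)   scaled by 8 ** -0.5
# attention = softmax      -> (2, 32, 32)   softmax over the last dimension
# proj_value               -> (2, 64, 32)   B x C x N
# out = V @ attention^T    -> (2, 64, 32)   -> view (2, 64, 4, 8)
# F.interpolate            -> (2, 64, 32, 64), then out + input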

CDF0.py

import torch
import itertools
from .base_model import BaseModel
from . import backbone
import torch.nn.functional as F
from . import loss


class CDF0Model(BaseModel):
    """
    change detection module:
    feature extractor
    contrastive loss
    """
    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        return parser

    def __init__(self, opt):
        BaseModel.__init__(self, opt)
        self.istest = opt.istest
        # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
        self.loss_names = ['f']
        # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
        self.visual_names = ['A', 'B', 'L', 'pred_L_show']  # visualizations for A and B
        if self.istest:
            self.visual_names = ['A', 'B', 'pred_L_show']
        self.visual_features = ['feat_A', 'feat_B']
        # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>
        self.model_names = ['F']  # the same network is saved/loaded in both train and test
        self.ds = 1
        # define networks
        self.n_class = 2
        self.netF = backbone.define_F(in_c=3, f_c=opt.f_c, type=opt.arch).to(self.device)

        if self.isTrain:
            # define loss functions
            self.criterionF = loss.BCL()
            # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
            self.optimizer_G = torch.optim.Adam(itertools.chain(self.netF.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)

    def set_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Parameters:
            input (dict): include the data itself and its metadata information.

        The option 'direction' can be used to swap domain A and domain B.
        """
        self.A = input['A'].to(self.device)
        self.B = input['B'].to(self.device)
        if not self.istest:
            self.L = input['L'].to(self.device).long()
        self.image_paths = input['A_paths']
        if self.isTrain:
            self.L_s = self.L.float()
            self.L_s = F.interpolate(self.L_s, size=torch.Size([self.A.shape[2]//self.ds, self.A.shape[3]//self.ds]),mode='nearest')
            self.L_s[self.L_s == 1] = -1  # change
            self.L_s[self.L_s == 0] = 1  # no change


    def test(self, val=False):
        """Forward function used in test time.
        This function wraps <forward> in no_grad() so we don't save intermediate steps for backprop.
        It also calls <compute_visuals> to produce additional visualization results.
        """
        with torch.no_grad():
            self.forward()
            self.compute_visuals()
            if val:  # score
                from util.metrics import RunningMetrics
                metrics = RunningMetrics(self.n_class)
                pred = self.pred_L.long()

                metrics.update(self.L.detach().cpu().numpy(), pred.detach().cpu().numpy())
                scores = metrics.get_cm()
                return scores


    def forward(self):
        """Run forward pass; called by both functions  and ."""
        self.feat_A = self.netF(self.A)  # f(A)
        self.feat_B = self.netF(self.B)   # f(B)

        self.dist = F.pairwise_distance(self.feat_A, self.feat_B, keepdim=True)
        # print(self.dist.shape)
        self.dist = F.interpolate(self.dist, size=self.A.shape[2:], mode='bilinear',align_corners=True)
        self.pred_L = (self.dist > 1).float()
        self.pred_L_show = self.pred_L.long()
        return self.pred_L

    def backward(self):
        """Calculate the loss for generators F and L"""
        # print(self.weight)
        self.loss_f = self.criterionF(self.dist, self.L_s)

        self.loss = self.loss_f
        if torch.isnan(self.loss):
           print(self.image_paths)

        self.loss.backward()

    def optimize_parameters(self):
        """Calculate losses, gradients, and update network weights; called in every training iteration"""
        # forward
        self.forward()      # compute feat and dist

        self.optimizer_G.zero_grad()        # set gradients to zero
        self.backward()                     # calculate gradients
        self.optimizer_G.step()             # update the network weights

This model implements the baseline change detector (feature extractor + contrastive loss). Its parts:

modify_commandline_options: hook for adding or modifying command-line options; nothing is changed here, so the parser is returned as-is.

__init__: the constructor. It sets up bookkeeping such as loss_names, visual_names, and model_names, builds the feature extractor netF with the configured backbone, and, in training mode, defines the BCL loss criterion and an Adam optimizer.

set_input: moves the input tensors A, B, and (when available) L onto the device and converts L to long. In training mode it also builds the supervision map L_s: L is resized to 1/self.ds of A's size with nearest-neighbor interpolation, and its values are remapped to the BCL convention, changed pixels (1) become -1 and unchanged pixels (0) become 1.

test: test-time entry point; runs forward under no_grad and, when val=True, returns the scores computed by RunningMetrics.

forward: the forward pass. A and B are passed through the shared feature extractor netF; the pixel-wise distance dist between the two feature maps is computed with F.pairwise_distance, bilinearly interpolated back to A's spatial size, and thresholded at 1 (half the contrastive margin of 2) to produce the binary change map pred_L.

backward: computes the contrastive loss on dist against L_s and backpropagates.

optimize_parameters: one training iteration: forward pass, zero the gradients, backward pass, optimizer step.
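A minimal training-step sketch, assuming a dataloader that yields dicts with the 'A', 'B', 'L', and 'A_paths' keys consumed by set_input (this mirrors the pix2pix-style loop in train.py):

for data in dataset:                 # hypothetical dataloader
    model.set_input(data)            # move tensors to the device, build L_s
    model.optimize_parameters()      # forward -> backward -> optimizer step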

CDFA.py

import torch
import itertools
from .base_model import BaseModel
from . import backbone
import torch.nn.functional as F
from . import loss


class CDFAModel(BaseModel):
    """
    change detection module:
    feature extractor+ spatial-temporal-self-attention
    contrastive loss
    """
    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        return parser
    def __init__(self, opt):

        BaseModel.__init__(self, opt)
        # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
        self.loss_names = ['f']
        # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
        self.istest = (opt.phase == 'test')
        self.visual_names = ['A', 'B', 'L', 'pred_L_show']  # visualizations for A and B
        if self.istest:
            self.visual_names = ['A', 'B', 'pred_L_show']  # no label L at test time

        self.visual_features = ['feat_A', 'feat_B']
        # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>
        self.model_names = ['F', 'A']  # the same networks are saved/loaded in both train and test
        self.ds = 1
        self.n_class = 2
        self.netF = backbone.define_F(in_c=3, f_c=opt.f_c, type=opt.arch).to(self.device)
        self.netA = backbone.CDSA(in_c=opt.f_c, ds=opt.ds, mode=opt.SA_mode).to(self.device)

        if self.isTrain:
            # define loss functions
            self.criterionF = loss.BCL()

            # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
            self.optimizer_G = torch.optim.Adam(itertools.chain(
                self.netF.parameters(),
            ), lr=opt.lr*opt.lr_decay, betas=(opt.beta1, 0.999))
            self.optimizer_A = torch.optim.Adam(self.netA.parameters(), lr=opt.lr*1, betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_A)


    def set_input(self, input):
        self.A = input['A'].to(self.device)
        self.B = input['B'].to(self.device)
        if self.istest is False:
            if 'L' in input.keys():
                self.L = input['L'].to(self.device).long()

        self.image_paths = input['A_paths']
        if self.isTrain:
            self.L_s = self.L.float()
            self.L_s = F.interpolate(self.L_s, size=torch.Size([self.A.shape[2]//self.ds, self.A.shape[3]//self.ds]),mode='nearest')
            self.L_s[self.L_s == 1] = -1  # change
            self.L_s[self.L_s == 0] = 1  # no change


    def test(self, val=False):
        with torch.no_grad():
            self.forward()
            self.compute_visuals()
            if val:  # return evaluation scores
                from util.metrics import RunningMetrics
                metrics = RunningMetrics(self.n_class)
                pred = self.pred_L.long()

                metrics.update(self.L.detach().cpu().numpy(), pred.detach().cpu().numpy())
                scores = metrics.get_cm()
                return scores
            else:
                return self.pred_L.long()

    def forward(self):
        """Run forward pass; called by both functions  and ."""
        self.feat_A = self.netF(self.A)  # f(A)
        self.feat_B = self.netF(self.B)   # f(B)

        self.feat_A, self.feat_B = self.netA(self.feat_A,self.feat_B)

        self.dist = F.pairwise_distance(self.feat_A, self.feat_B, keepdim=True)  # feature distance

        self.dist = F.interpolate(self.dist, size=self.A.shape[2:], mode='bilinear',align_corners=True)

        self.pred_L = (self.dist > 1).float()
        # self.pred_L = F.interpolate(self.pred_L, size=self.A.shape[2:], mode='nearest')
        self.pred_L_show = self.pred_L.long()

        return self.pred_L

    def backward(self):
        self.loss_f = self.criterionF(self.dist, self.L_s)

        self.loss = self.loss_f
        # print(self.loss)
        self.loss.backward()

    def optimize_parameters(self):
        """Calculate losses, gradients, and update network weights; called in every training iteration"""
        # forward
        self.forward()      # compute feat and dist

        self.set_requires_grad([self.netF, self.netA], True)
        self.optimizer_G.zero_grad()        # set gradients to zero
        self.optimizer_A.zero_grad()
        self.backward()                     # calculate gradients
        self.optimizer_G.step()             # update netF's weights
        self.optimizer_A.step()             # update netA's weights

This model implements the attention-based change detector (feature extractor + spatial-temporal self-attention + contrastive loss). Its parts:

modify_commandline_options: hook for adding or modifying command-line options; nothing is changed here, so the parser is returned as-is.

__init__: the constructor. It sets up loss_names, visual_names, model_names, and related bookkeeping, then builds two networks: the feature extractor netF and the spatial-temporal self-attention module netA (a CDSA whose attention type is chosen by opt.SA_mode). In training mode it defines the BCL loss and one Adam optimizer per network.

set_input: moves the input tensors A, B, and (when available) L onto the device and converts L to long. In training mode it builds L_s by resizing L to 1/self.ds of A's size and remapping changed pixels (1) to -1 and unchanged pixels (0) to 1.

test: test-time entry point; returns the predicted change map, or, when val=True, the evaluation scores.

forward: the forward pass. A and B are passed through the shared feature extractor netF, the two feature maps are refined jointly by the self-attention module netA, and the pixel-wise distance dist between the refined features is computed, interpolated to A's size, and thresholded at 1 to produce the change map pred_L.

backward: computes the contrastive loss on dist against L_s and backpropagates.

optimize_parameters: one training iteration: forward pass, zero both optimizers' gradients, backward pass, then step both optimizers.
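And a hypothetical validation sketch using the test(val=True) path (assumes val_dataset yields the same dict format, including the label L):

for data in val_dataset:             # hypothetical validation dataloader
    model.set_input(data)
    scores = model.test(val=True)    # metric dict from util.metrics.RunningMetrics.get_cm()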

loss.py

import torch.nn as nn
import torch


class BCL(nn.Module):
    """
    batch-balanced contrastive loss
    no-change,1
    change,-1
    """

    def __init__(self, margin=2.0):
        super(BCL, self).__init__()
        self.margin = margin

    def forward(self, distance, label):
        label[label == 255] = 1  # map ignore pixels (255) to the no-change label
        mask = (label != 255).float()  # note: all ones after the remap above
        distance = distance * mask
        pos_num = torch.sum((label==1).float())+0.0001
        neg_num = torch.sum((label==-1).float())+0.0001

        loss_1 = torch.sum((1+label) / 2 * torch.pow(distance, 2)) /pos_num
        loss_2 = torch.sum((1-label) / 2 * mask *
            torch.pow(torch.clamp(self.margin - distance, min=0.0), 2)
        ) / neg_num
        loss = loss_1 + loss_2
        return loss

This code implements the batch-balanced contrastive loss (BCL). Its parts:

__init__: the constructor; it stores the margin hyperparameter.

forward: the forward pass, taking the feature-distance map distance and the label map label. Pixels labeled 255 are first remapped to 1, which makes the subsequent mask effectively all ones. The numbers of no-change pixels (label 1, pos_num) and change pixels (label -1, neg_num) are then counted, with a small constant added to avoid division by zero. loss_1 averages the squared distance over no-change pixels, pulling unchanged pairs together; loss_2 averages the squared hinge term max(margin - distance, 0)^2 over change pixels, pushing changed pairs apart until their distance exceeds margin, beyond which they contribute nothing. Normalizing each term by its own pixel count balances the two classes within the batch; the returned loss is the sum of the two terms.
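A small numeric sanity check with hypothetical values (margin = 2.0), illustrating both terms:

import torch

crit = BCL(margin=2.0)
d = torch.tensor([[[[0.5, 3.0]]]])     # distances for two pixels
y = torch.tensor([[[[1.0, -1.0]]]])    # 1 = no change, -1 = change
# loss_1 = 0.5**2 / 1 = 0.25             (no-change pixel pulled toward 0 distance)
# loss_2 = max(2.0 - 3.0, 0)**2 / 1 = 0  (change pixel already beyond the margin)
loss = crit(d, y)                      # ~0.25 (up to the 1e-4 count stabilizers)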

mynet3.py

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
import math
class F_mynet3(nn.Module):
    def __init__(self, backbone='resnet18',in_c=3, f_c=64, output_stride=8):
        self.in_c = in_c
        super(F_mynet3, self).__init__()
        self.module = mynet3(backbone=backbone, output_stride=output_stride, f_c=f_c, in_c=self.in_c)
    def forward(self, input):
        return self.module(input)
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)
model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}

def ResNet34(output_stride, BatchNorm=nn.BatchNorm2d, pretrained=True, in_c=3):
    """
    output, low_level_feat:
    512, 64
    """
    print(in_c)
    model = ResNet(BasicBlock, [3, 4, 6, 3], output_stride, BatchNorm, in_c=in_c)
    if in_c != 3:
        pretrained = False
    if pretrained:
        model._load_pretrained_model(model_urls['resnet34'])
    return model
def ResNet18(output_stride, BatchNorm=nn.BatchNorm2d, pretrained=True, in_c=3):
    """
    output, low_level_feat:
    512, 256, 128, 64, 64
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], output_stride, BatchNorm, in_c=in_c)
    if in_c !=3:
        pretrained=False
    if pretrained:
        model._load_pretrained_model(model_urls['resnet18'])
    return model
def ResNet50(output_stride, BatchNorm=nn.BatchNorm2d, pretrained=True, in_c=3):
    """
    output, low_level_feat:
    2048, 256
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], output_stride, BatchNorm, in_c=in_c)
    if in_c !=3:
        pretrained=False
    if pretrained:
        model._load_pretrained_model(model_urls['resnet50'])
    return model
class BasicBlock(nn.Module):
    expansion = 1
    def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, BatchNorm=None):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
                               dilation=dilation, padding=dilation, bias=False)
        self.bn1 = BatchNorm(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = BatchNorm(planes)
        self.downsample = downsample
        self.stride = stride
    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        out = self.relu(out)
        return out
class Bottleneck(nn.Module):
    expansion = 4
    def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, BatchNorm=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = BatchNorm(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               dilation=dilation, padding=dilation, bias=False)
        self.bn2 = BatchNorm(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = BatchNorm(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation
    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out

This part of the file defines three constructor functions, ResNet18, ResNet34, and ResNet50, and two block classes, BasicBlock and Bottleneck (the full ResNet class and a shape check follow below).
ResNet18, ResNet34, and ResNet50 build ResNet backbones of the corresponding depths. Each takes four parameters: output_stride (the stride of the output feature map), BatchNorm (the normalization layer type), pretrained (whether to load ImageNet-pretrained weights; forced off when in_c != 3, since the first convolution then no longer matches the checkpoint), and in_c (the number of input channels). Each returns a ResNet instance.
BasicBlock and Bottleneck are the residual blocks used inside ResNet. BasicBlock, used by ResNet18 and ResNet34, stacks two 3x3 convolutions with a skip connection. Bottleneck, used by ResNet50, stacks three convolutions (1x1, 3x3, 1x1, expanding the channel count 4x) with a skip connection, which allows deeper networks at manageable cost.

class ResNet(nn.Module):
    def __init__(self, block, layers, output_stride, BatchNorm, pretrained=True, in_c=3):
        self.inplanes = 64
        self.in_c = in_c
        print('in_c: ', self.in_c)
        super(ResNet, self).__init__()
        blocks = [1, 2, 4]
        if output_stride == 32:
            strides = [1, 2, 2, 2]
            dilations = [1, 1, 1, 1]
        elif output_stride == 16:
            strides = [1, 2, 2, 1]
            dilations = [1, 1, 1, 2]
        elif output_stride == 8:
            strides = [1, 2, 1, 1]
            dilations = [1, 1, 2, 4]
        elif output_stride == 4:
            strides = [1, 1, 1, 1]
            dilations = [1, 2, 4, 8]
        else:
            raise NotImplementedError
        # Modules
        self.conv1 = nn.Conv2d(self.in_c, 64, kernel_size=7, stride=2, padding=3,
                                bias=False)
        self.bn1 = BatchNorm(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[0], dilation=dilations[0], BatchNorm=BatchNorm)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], dilation=dilations[1], BatchNorm=BatchNorm)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], dilation=dilations[2], BatchNorm=BatchNorm)
        self.layer4 = self._make_MG_unit(block, 512, blocks=blocks, stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm)
        # self.layer4 = self._make_layer(block, 512, layers[3], stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm)
        self._init_weight()

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                BatchNorm(planes * block.expansion),
            )
        layers = []
        layers.append(block(self.inplanes, planes, stride, dilation, downsample, BatchNorm))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, dilation=dilation, BatchNorm=BatchNorm))
        return nn.Sequential(*layers)

    def _make_MG_unit(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                BatchNorm(planes * block.expansion),
            )
        layers = []
        layers.append(block(self.inplanes, planes, stride, dilation=blocks[0]*dilation,
                            downsample=downsample, BatchNorm=BatchNorm))
        self.inplanes = planes * block.expansion
        for i in range(1, len(blocks)):
            layers.append(block(self.inplanes, planes, stride=1,
                                dilation=blocks[i]*dilation, BatchNorm=BatchNorm))
        return nn.Sequential(*layers)

    def forward(self, input):
        x = self.conv1(input)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)   # stride 4
        x = self.layer1(x)    # stride 4
        low_level_feat2 = x
        x = self.layer2(x)    # stride 8
        low_level_feat3 = x
        x = self.layer3(x)    # stride 16
        low_level_feat4 = x
        x = self.layer4(x)    # stride 32 (when output_stride == 32)
        return x, low_level_feat2, low_level_feat3, low_level_feat4

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _load_pretrained_model(self, model_path):
        pretrain_dict = model_zoo.load_url(model_path)
        model_dict = {}
        state_dict = self.state_dict()
        for k, v in pretrain_dict.items():
            if k in state_dict:
                model_dict[k] = v
        state_dict.update(model_dict)
        self.load_state_dict(state_dict)
def build_backbone(backbone, output_stride, BatchNorm, in_c=3):
    if backbone == 'resnet50':
        return ResNet50(output_stride, BatchNorm, in_c=in_c)
    elif backbone == 'resnet34':
        return ResNet34(output_stride, BatchNorm, in_c=in_c)
    elif backbone == 'resnet18':
        return ResNet18(output_stride, BatchNorm, in_c=in_c)
    else:
        raise NotImplementedError

This code builds the ResNet backbone, offering ResNet50, ResNet34, and ResNet18 variants.
The network body is assembled by the _make_layer and _make_MG_unit methods, which stack residual blocks (convolution, normalization, ReLU, plus a skip connection) according to the given parameters; _make_MG_unit builds the last stage with multi-grid dilation rates (blocks = [1, 2, 4]).
Besides the stem, the network contains four residual stages, each made of several residual units; the first unit of a stage performs the downsampling (via its stride together with a 1x1 projection on the shortcut), halving the spatial size whenever the stride is 2.
build_backbone constructs the requested variant: backbone selects the ResNet version, output_stride the output feature stride, BatchNorm the normalization layer, and in_c the number of input channels.
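A hypothetical shape check for the ResNet18 backbone at output_stride=32, confirming the stride annotations in forward:

import torch

net = ResNet18(output_stride=32, pretrained=False)
out, f2, f3, f4 = net(torch.randn(1, 3, 256, 256))
# f2:  (1, 64, 64, 64)    stride 4
# f3:  (1, 128, 32, 32)   stride 8
# f4:  (1, 256, 16, 16)   stride 16
# out: (1, 512, 8, 8)     stride 32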

class DR(nn.Module):
    def __init__(self, in_d, out_d):
        super(DR, self).__init__()
        self.in_d = in_d
        self.out_d = out_d
        self.conv1 = nn.Conv2d(self.in_d, self.out_d, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.out_d)
        self.relu = nn.ReLU()

    def forward(self, input):
        x = self.conv1(input)
        x = self.bn1(x)
        x = self.relu(x)
        return x


class Decoder(nn.Module):
    def __init__(self, fc, BatchNorm):
        super(Decoder, self).__init__()
        self.fc = fc
        self.dr2 = DR(64, 96)
        self.dr3 = DR(128, 96)
        self.dr4 = DR(256, 96)
        self.dr5 = DR(512, 96)
        self.last_conv = nn.Sequential(nn.Conv2d(384, 256, kernel_size=3, stride=1, padding=1, bias=False),
                                       BatchNorm(256),
                                       nn.ReLU(),
                                       nn.Dropout(0.5),
                                       nn.Conv2d(256, self.fc, kernel_size=1, stride=1, padding=0, bias=False),
                                       BatchNorm(self.fc),
                                       nn.ReLU(),
                                       )
        self._init_weight()

    def forward(self, x, low_level_feat2, low_level_feat3, low_level_feat4):
        x2 = self.dr2(low_level_feat2)
        x3 = self.dr3(low_level_feat3)
        x4 = self.dr4(low_level_feat4)
        x = self.dr5(x)
        x = F.interpolate(x, size=x2.size()[2:], mode='bilinear', align_corners=True)
        x3 = F.interpolate(x3, size=x2.size()[2:], mode='bilinear', align_corners=True)
        x4 = F.interpolate(x4, size=x2.size()[2:], mode='bilinear', align_corners=True)
        x = torch.cat((x, x2, x3, x4), dim=1)
        x = self.last_conv(x)
        return x

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()


def build_decoder(fc, backbone, BatchNorm):
    return Decoder(fc, BatchNorm)


class mynet3(nn.Module):
    def __init__(self, backbone='resnet18', output_stride=16, f_c=64, freeze_bn=False, in_c=3):
        super(mynet3, self).__init__()
        print('arch: mynet3')
        BatchNorm = nn.BatchNorm2d
        self.backbone = build_backbone(backbone, output_stride, BatchNorm, in_c)
        self.decoder = build_decoder(f_c, backbone, BatchNorm)
        if freeze_bn:
            self.freeze_bn()

    def forward(self, input):
        x, f2, f3, f4 = self.backbone(input)
        x = self.decoder(x, f2, f3, f4)
        return x

    def freeze_bn(self):
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

This part defines three classes: DR, Decoder, and mynet3.
DR ("dimension reduction") is a small block consisting of a 1x1 convolution, batch normalization, and ReLU; the decoder uses it to project each backbone feature map to 96 channels.
Decoder fuses the multi-level backbone features: four DR blocks reduce the four feature maps to 96 channels each, the three deeper maps are bilinearly upsampled to the resolution of the shallowest one, all four are concatenated (384 channels), and last_conv (a 3x3 convolution, dropout, and a 1x1 convolution) produces the final f_c-channel feature map.
mynet3 is the complete feature extractor: the backbone built by build_backbone extracts multi-level features from the input image, and the decoder fuses them into the output embedding. The freeze_bn method puts every batch-normalization layer into eval mode, keeping its running statistics fixed during training.
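Putting the backbone and the decoder together, a hypothetical end-to-end shape check for the feature extractor, matching the define_F call in backbone.py (ResNet18 downloads ImageNet weights on first use, since pretrained defaults to True):

import torch

f = F_mynet3(backbone='resnet18', in_c=3, f_c=64, output_stride=32)
feat = f(torch.randn(1, 3, 256, 256))   # (1, 64, 64, 64): f_c channels at stride 4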

PAM2.py

import torch
import torch.nn.functional as F
from torch import nn


class _PAMBlock(nn.Module):
    '''
    The basic implementation for self-attention block/non-local block
    Input/Output:
        N * C  *  H  *  (2*W)
    Parameters:
        in_channels       : the dimension of the input feature map
        key_channels      : the dimension after the key/query transform
        value_channels    : the dimension after the value transform
        scale             : choose the scale to partition the input feature maps
        ds                : downsampling scale
    '''
    def __init__(self, in_channels, key_channels, value_channels, scale=1, ds=1):
        super(_PAMBlock, self).__init__()
        self.scale = scale
        self.ds = ds
        self.pool = nn.AvgPool2d(self.ds)
        self.in_channels = in_channels
        self.key_channels = key_channels
        self.value_channels = value_channels

        self.f_key = nn.Sequential(
            nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
                kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(self.key_channels)
        )
        self.f_query = nn.Sequential(
            nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
                kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(self.key_channels)
        )
        self.f_value = nn.Conv2d(in_channels=self.in_channels, out_channels=self.value_channels,
            kernel_size=1, stride=1, padding=0)



    def forward(self, input):
        x = input
        if self.ds != 1:
            x = self.pool(input)
        # input shape: b,c,h,2w
        batch_size, c, h, w = x.size(0), x.size(1), x.size(2), x.size(3)//2


        local_y = []
        local_x = []
        step_h, step_w = h//self.scale, w//self.scale
        for i in range(0, self.scale):
            for j in range(0, self.scale):
                start_x, start_y = i*step_h, j*step_w
                end_x, end_y = min(start_x+step_h, h), min(start_y+step_w, w)
                if i == (self.scale-1):
                    end_x = h
                if j == (self.scale-1):
                    end_y = w
                local_x += [start_x, end_x]
                local_y += [start_y, end_y]

        value = self.f_value(x)
        query = self.f_query(x)
        key = self.f_key(x)

        value = torch.stack([value[:, :, :, :w], value[:, :, :, w:]], 4)  # B x C_v x H x W x 2
        query = torch.stack([query[:, :, :, :w], query[:, :, :, w:]], 4)  # B x C_k x H x W x 2
        key = torch.stack([key[:, :, :, :w], key[:, :, :, w:]], 4)        # B x C_k x H x W x 2


        local_block_cnt = 2*self.scale*self.scale

        #  self-attention func
        def func(value_local, query_local, key_local):
            batch_size_new = value_local.size(0)
            h_local, w_local = value_local.size(2), value_local.size(3)
            value_local = value_local.contiguous().view(batch_size_new, self.value_channels, -1)

            query_local = query_local.contiguous().view(batch_size_new, self.key_channels, -1)
            query_local = query_local.permute(0, 2, 1)
            key_local = key_local.contiguous().view(batch_size_new, self.key_channels, -1)

            sim_map = torch.bmm(query_local, key_local)  # batch matrix multiplication
            sim_map = (self.key_channels**-.5) * sim_map
            sim_map = F.softmax(sim_map, dim=-1)

            context_local = torch.bmm(value_local, sim_map.permute(0,2,1))
            # context_local = context_local.permute(0, 2, 1).contiguous()
            context_local = context_local.view(batch_size_new, self.value_channels, h_local, w_local, 2)
            return context_local

        # process all local blocks in parallel: stack them along the batch dimension
        v_list = [value[:,:,local_x[i]:local_x[i+1],local_y[i]:local_y[i+1]] for i in range(0, local_block_cnt, 2)]
        v_locals = torch.cat(v_list,dim=0)
        q_list = [query[:,:,local_x[i]:local_x[i+1],local_y[i]:local_y[i+1]] for i in range(0, local_block_cnt, 2)]
        q_locals = torch.cat(q_list,dim=0)
        k_list = [key[:,:,local_x[i]:local_x[i+1],local_y[i]:local_y[i+1]] for i in range(0, local_block_cnt, 2)]
        k_locals = torch.cat(k_list,dim=0)
        # print(v_locals.shape)
        context_locals = func(v_locals,q_locals,k_locals)

        context_list = []
        for i in range(0, self.scale):
            row_tmp = []
            for j in range(0, self.scale):
                left = batch_size*(j+i*self.scale)
                right = batch_size*(j+i*self.scale) + batch_size
                tmp = context_locals[left:right]
                row_tmp.append(tmp)
            context_list.append(torch.cat(row_tmp, 3))

        context = torch.cat(context_list, 2)
        context = torch.cat([context[:,:,:,:,0],context[:,:,:,:,1]],3)


        if self.ds !=1:
            context = F.interpolate(context, [h*self.ds, 2*w*self.ds])

        return context


class PAMBlock(_PAMBlock):
    def __init__(self, in_channels, key_channels=None, value_channels=None, scale=1, ds=1):
        if key_channels is None:
            key_channels = in_channels // 8
        if value_channels is None:
            value_channels = in_channels
        super(PAMBlock, self).__init__(in_channels, key_channels, value_channels, scale, ds)


class PAM(nn.Module):
    """
        PAM module
    """

    def __init__(self, in_channels, out_channels, sizes=([1]), ds=1):
        super(PAM, self).__init__()
        self.group = len(sizes)
        self.ds = ds  # downsampling scale passed to each PAMBlock
        self.value_channels = out_channels
        self.key_channels = out_channels // 8

        self.stages = nn.ModuleList(
            [self._make_stage(in_channels, self.key_channels, self.value_channels, size, self.ds)
             for size in sizes])
        self.conv_bn = nn.Sequential(
            nn.Conv2d(in_channels * self.group, out_channels, kernel_size=1, padding=0,bias=False),
            # nn.BatchNorm2d(out_channels),
        )

    def _make_stage(self, in_channels, key_channels, value_channels, size, ds):
        return PAMBlock(in_channels,key_channels,value_channels,size,ds)

    def forward(self, feats):
        priors = [stage(feats) for stage in self.stages]

        # concatenate the contexts from all scales and fuse them with a 1x1 conv
        output = self.conv_bn(torch.cat(priors, 1))

        return output

This is the implementation of the PAM (Pyramid Attention Module), which applies self-attention over a pyramid of region partitions of the bi-temporal feature map.
Concretely, the module consists of a _PAMBlock class and a PAM class. _PAMBlock is the basic building block: it is parameterized by the input channel count, the key and value channel counts, the partition scale, and the downsampling scale ds, and implements the forward pass. The input tensor (the two feature maps concatenated along the width, B x C x H x 2W) is transformed into query, key, and value tensors, partitioned into scale x scale local regions, and each region is processed by the inner function func, which computes scaled dot-product self-attention and returns a local context tensor context_local. For speed, all regions are stacked along the batch dimension so that a single batched matrix multiplication handles them at once; the local contexts are then stitched back together into the output tensor context.
PAM wraps _PAMBlock: it builds one PAMBlock per entry of sizes (partition scales 1, 2, 4, and 8 when called from CDSA), concatenates the contexts produced at all scales along the channel dimension, and fuses them with a 1x1 convolution into the final output.
Overall, the module computes self-attention within local regions at several scales, capturing global dependencies at scale 1 and progressively finer local structure at larger scales.
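A hypothetical shape check, matching how CDSA feeds PAM the two feature maps concatenated along the width (B x C x H x 2W in, same shape out):

import torch

pam = PAM(in_channels=64, out_channels=64, sizes=[1, 2, 4, 8], ds=1)
y = pam(torch.randn(2, 64, 32, 64))   # (2, 64, 32, 64)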
