import importlib
from models.base_model import BaseModel
def find_model_using_name(model_name):
"""Import the module "models/[model_name]_model.py".
In the file, the class called DatasetNameModel() will
be instantiated. It has to be a subclass of BaseModel,
and it is case-insensitive.
"""
model_filename = "models." + model_name + "_model"
modellib = importlib.import_module(model_filename)
model = None
target_model_name = model_name.replace('_', '') + 'model'
for name, cls in modellib.__dict__.items():
if name.lower() == target_model_name.lower() \
and issubclass(cls, BaseModel):
model = cls
if model is None:
print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
exit(0)
return model
def get_option_setter(model_name):
"""Return the static method of the model class."""
model_class = find_model_using_name(model_name)
return model_class.modify_commandline_options
def create_model(opt):
"""Create a model given the option.
    This function wraps the model class found by find_model_using_name.
    This is the main interface between this package and 'train.py'/'test.py'.
Example:
>>> from models import create_model
>>> model = create_model(opt)
"""
model = find_model_using_name(opt.model)
instance = model(opt)
print("model [%s] was created" % type(instance).__name__)
return instance
This code implements the model-creation entry point create_model and consists of the following parts:
find_model_using_name: finds the model class for a given model name. It first builds the module name "models.[model_name]_model" from the model name, dynamically imports that module, and then walks the module's attributes looking for a BaseModel subclass whose name matches the model name (case-insensitively).
get_option_setter: returns the model class's static method modify_commandline_options, which is used to modify the command-line arguments.
create_model: creates a model instance from the command-line options opt. It first calls find_model_using_name to locate the model class, then instantiates it as instance and returns it.
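As a concrete illustration of the naming convention, here is a minimal, self-contained sketch (using the CDF0 model name that appears later in this walkthrough) of how a model name is mapped to a module and a class:
# Sketch of the lookup in find_model_using_name: the name "cdf0" resolves to
# module "models.cdf0_model" and, case-insensitively, to class CDF0Model.
model_name = "cdf0"
model_filename = "models." + model_name + "_model"          # "models.cdf0_model"
target_model_name = model_name.replace('_', '') + 'model'   # "cdf0model"
print("CDF0Model".lower() == target_model_name.lower())     # True -> matched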
# coding: utf-8
import torch.nn as nn
import torch
from .mynet3 import F_mynet3
from .BAM import BAM
from .PAM2 import PAM as PAM
def define_F(in_c, f_c, type='unet'):
if type == 'mynet3':
print("using mynet3 backbone")
return F_mynet3(backbone='resnet18', in_c=in_c,f_c=f_c, output_stride=32)
else:
        raise NotImplementedError('no such F type!')
def weights_init(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
nn.init.normal_(m.weight.data, 0.0, 0.02)
elif classname.find('BatchNorm') != -1:
nn.init.normal_(m.weight.data, 1.0, 0.02)
nn.init.constant_(m.bias.data, 0)
class CDSA(nn.Module):
"""self attention module for change detection
"""
def __init__(self, in_c, ds=1, mode='BAM'):
super(CDSA, self).__init__()
self.in_C = in_c
self.ds = ds
print('ds: ',self.ds)
self.mode = mode
if self.mode == 'BAM':
self.Self_Att = BAM(self.in_C, ds=self.ds)
elif self.mode == 'PAM':
self.Self_Att = PAM(in_channels=self.in_C, out_channels=self.in_C, sizes=[1,2,4,8],ds=self.ds)
self.apply(weights_init)
    def forward(self, x1, x2):
        # concatenate the two temporal feature maps along the width axis (dim 3)
        width = x1.shape[3]
        x = torch.cat((x1, x2), 3)
        x = self.Self_Att(x)
        # split the attended map back into the two temporal halves
        return x[:, :, :, 0:width], x[:, :, :, width:]
This code implements CDSA (Change Detection Self Attention), a self-attention module for change detection. It consists of the following parts:
define_F: defines the feature-extraction network F. The backbone is selected via the type argument; if type is 'mynet3', the F_mynet3 network is used.
weights_init: a weight-initialization function that sets the initial values of the network's convolution and batch-normalization layers (convolution weights are drawn from a normal distribution with mean 0 and std 0.02; batch-norm weights are drawn around 1 and biases are set to 0).
__init__: the constructor. It stores the input channel count in_c, the downsampling factor ds, and the attention type mode, then picks the attention module accordingly: BAM when mode is 'BAM', PAM when mode is 'PAM'. Finally it initializes the weights with weights_init.
forward: the forward pass. The two input feature maps x1 and x2 are concatenated along the width dimension (dim 3), the concatenated map is passed through the Self_Att module, and the attended output x is split back along the width dimension into two feature maps, which are returned.
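A quick smoke test of CDSA with toy tensors (assuming the BAM/PAM imports above resolve); it only checks that the two attended outputs keep the input shape:
import torch
x1 = torch.randn(2, 64, 32, 32)
x2 = torch.randn(2, 64, 32, 32)
att = CDSA(in_c=64, ds=1, mode='BAM')
y1, y2 = att(x1, x2)
print(y1.shape, y2.shape)  # both torch.Size([2, 64, 32, 32])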
import torch
import torch.nn.functional as F
from torch import nn
class BAM(nn.Module):
""" Basic self-attention module
"""
def __init__(self, in_dim, ds=8, activation=nn.ReLU):
super(BAM, self).__init__()
        self.channel_in = in_dim
        self.key_channel = self.channel_in // 8
self.activation = activation
        self.ds = ds  # downsampling factor applied before attention
self.pool = nn.AvgPool2d(self.ds)
print('ds: ',ds)
self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))  # learnable residual weight (defined but unused in this forward)
        self.softmax = nn.Softmax(dim=-1)
def forward(self, input):
"""
inputs :
x : input feature maps( B X C X W X H)
returns :
out : self attention value + input feature
attention: B X N X N (N is Width*Height)
"""
        x = self.pool(input)
        m_batchsize, C, height, width = x.size()
        proj_query = self.query_conv(x).view(m_batchsize, -1, width * height).permute(0, 2, 1)  # B x N x C', N = H*W after pooling
        proj_key = self.key_conv(x).view(m_batchsize, -1, width * height)  # B x C' x N
        energy = torch.bmm(proj_query, proj_key)  # B x N x N
        energy = (self.key_channel ** -.5) * energy  # scale by 1/sqrt(key_channel)
        attention = self.softmax(energy)  # B x N x N, rows sum to 1 over key positions
        proj_value = self.value_conv(x).view(m_batchsize, -1, width * height)  # B x C x N
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(m_batchsize, C, height, width)
        out = F.interpolate(out, [height * self.ds, width * self.ds])  # restore the pre-pooling resolution
        out = out + input
        return out
This code implements a basic self-attention module (BAM). It consists of the following parts:
__init__: the constructor. It stores the input channel count in_dim, the downsampling factor ds, and the activation function, then defines three 1x1 convolutions for Query, Key, and Value; Query and Key project to in_dim/8 channels while Value keeps in_dim channels. It also defines a learnable parameter gamma and a softmax over the last dimension.
forward: the forward pass. The input is first downsampled with average pooling, and the pooled feature map x is passed through the Query, Key, and Value convolutions to obtain proj_query, proj_key, and proj_value. The energy matrix is the batch matrix product of proj_query and proj_key, scaled by one over the square root of key_channel, and softmax turns it into the attention matrix. Finally, proj_value is multiplied with the attention matrix to obtain out, which is interpolated back to the original resolution and added to the input tensor to produce the final output.
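The core of forward is a standard scaled dot-product attention; the following standalone recap (toy sizes, no pooling) mirrors the bmm/softmax/bmm sequence above:
# Recap of BAM's scaled dot-product core with toy sizes
# (B=1, C=16, H=W=8, so N=64 and key_channel=C//8=2).
import torch
B, C, H, W = 1, 16, 8, 8
N, Ck = H * W, C // 8
proj_query = torch.randn(B, N, Ck)                          # query after view+permute
proj_key = torch.randn(B, Ck, N)                            # key after view
energy = (Ck ** -0.5) * torch.bmm(proj_query, proj_key)     # B x N x N
attention = torch.softmax(energy, dim=-1)                   # rows sum to 1
proj_value = torch.randn(B, C, N)                           # value after view
out = torch.bmm(proj_value, attention.permute(0, 2, 1)).view(B, C, H, W)
print(out.shape)  # torch.Size([1, 16, 8, 8])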
import torch
import itertools
from .base_model import BaseModel
from . import backbone
import torch.nn.functional as F
from . import loss
class CDF0Model(BaseModel):
"""
change detection module:
feature extractor
contrastive loss
"""
@staticmethod
def modify_commandline_options(parser, is_train=True):
return parser
def __init__(self, opt):
BaseModel.__init__(self, opt)
self.istest = opt.istest
        # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
self.loss_names = ['f']
        # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
self.visual_names = ['A', 'B', 'L', 'pred_L_show'] # visualizations for A and B
if self.istest:
self.visual_names = ['A', 'B', 'pred_L_show']
self.visual_features = ['feat_A', 'feat_B']
        # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>
if self.isTrain:
self.model_names = ['F']
        else:  # during test time, only load F
self.model_names = ['F']
self.ds=1
# define networks (both Generators and discriminators)
self.n_class = 2
self.netF = backbone.define_F(in_c=3, f_c=opt.f_c, type=opt.arch).to(self.device)
if self.isTrain:
# define loss functions
self.criterionF = loss.BCL()
            # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>
self.optimizer_G = torch.optim.Adam(itertools.chain(self.netF.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999))
self.optimizers.append(self.optimizer_G)
def set_input(self, input):
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
Parameters:
input (dict): include the data itself and its metadata information.
The option 'direction' can be used to swap domain A and domain B.
"""
self.A = input['A'].to(self.device)
self.B = input['B'].to(self.device)
if not self.istest:
self.L = input['L'].to(self.device).long()
self.image_paths = input['A_paths']
if self.isTrain:
self.L_s = self.L.float()
self.L_s = F.interpolate(self.L_s, size=torch.Size([self.A.shape[2]//self.ds, self.A.shape[3]//self.ds]),mode='nearest')
self.L_s[self.L_s == 1] = -1 # change
self.L_s[self.L_s == 0] = 1 # no change
def test(self, val=False):
"""Forward function used in test time.
        This function wraps <forward> in no_grad() so we don't save intermediate steps for backprop.
        It also calls <compute_visuals> to produce additional visualization results.
"""
with torch.no_grad():
self.forward()
self.compute_visuals()
if val: # score
from util.metrics import RunningMetrics
metrics = RunningMetrics(self.n_class)
pred = self.pred_L.long()
metrics.update(self.L.detach().cpu().numpy(), pred.detach().cpu().numpy())
scores = metrics.get_cm()
return scores
def forward(self):
"""Run forward pass; called by both functions and ."""
self.feat_A = self.netF(self.A) # f(A)
self.feat_B = self.netF(self.B) # f(B)
self.dist = F.pairwise_distance(self.feat_A, self.feat_B, keepdim=True)
# print(self.dist.shape)
self.dist = F.interpolate(self.dist, size=self.A.shape[2:], mode='bilinear',align_corners=True)
self.pred_L = (self.dist > 1).float()
self.pred_L_show = self.pred_L.long()
return self.pred_L
def backward(self):
"""Calculate the loss for generators F and L"""
# print(self.weight)
self.loss_f = self.criterionF(self.dist, self.L_s)
self.loss = self.loss_f
if torch.isnan(self.loss):
print(self.image_paths)
self.loss.backward()
def optimize_parameters(self):
"""Calculate losses, gradients, and update network weights; called in every training iteration"""
# forward
self.forward() # compute feat and dist
        self.optimizer_G.zero_grad()  # set netF's gradients to zero
        self.backward()               # calculate gradients for netF
        self.optimizer_G.step()       # update netF's weights
This code implements a model for image change detection. It consists of the following parts:
modify_commandline_options: modifies command-line options; here nothing is changed and the parser is returned as-is.
__init__: the constructor. It initializes bookkeeping such as loss_names, visual_names, and model_names, then builds the feature-extraction network netF from the configured backbone and defines the loss function and optimizer.
set_input: moves the input tensors A, B, and L to the GPU, converts L to long type, and, during training, resizes L by nearest-neighbor interpolation to 1/self.ds of A's size and remaps it to the contrastive-loss convention (see the toy check below).
test: the evaluation function; it returns either the predicted change map or the change-detection metrics.
forward: the forward pass. A and B are fed through netF to obtain two feature maps, the per-pixel feature distance dist is computed, dist is interpolated up to A's size, and thresholding dist at 1 yields the binary change prediction pred_L.
backward: computes the loss and backpropagates it.
optimize_parameters: zeroes the gradients, computes the loss, backpropagates, and updates the network parameters.
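The label remapping in set_input follows the sign convention of the BCL loss defined later; a toy check with made-up labels:
# Toy check of the remapping in set_input: ground-truth 1 (change) -> -1,
# 0 (no change) -> +1, matching the BCL convention.
import torch
L_s = torch.tensor([[0., 1., 1., 0.]])
L_s[L_s == 1] = -1   # change
L_s[L_s == 0] = 1    # no change
print(L_s)           # tensor([[ 1., -1., -1.,  1.]])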
import torch
import itertools
from .base_model import BaseModel
from . import backbone
import torch.nn.functional as F
from . import loss
class CDFAModel(BaseModel):
"""
change detection module:
feature extractor+ spatial-temporal-self-attention
contrastive loss
"""
@staticmethod
def modify_commandline_options(parser, is_train=True):
return parser
def __init__(self, opt):
BaseModel.__init__(self, opt)
        # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
self.loss_names = ['f']
        # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
        self.istest = (opt.phase == 'test')
        self.visual_names = ['A', 'B', 'L', 'pred_L_show']  # visualizations for A and B
        if self.istest:
            self.visual_names = ['A', 'B', 'pred_L_show']  # no ground-truth label at test time
            self.visual_features = ['feat_A', 'feat_B']
        # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>
if self.isTrain:
self.model_names = ['F','A']
        else:  # during test time, only load F and A
            self.model_names = ['F', 'A']
self.ds = 1
self.n_class =2
self.netF = backbone.define_F(in_c=3, f_c=opt.f_c, type=opt.arch).to(self.device)
self.netA = backbone.CDSA(in_c=opt.f_c, ds=opt.ds, mode=opt.SA_mode).to(self.device)
if self.isTrain:
# define loss functions
self.criterionF = loss.BCL()
            # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>
self.optimizer_G = torch.optim.Adam(itertools.chain(
self.netF.parameters(),
), lr=opt.lr*opt.lr_decay, betas=(opt.beta1, 0.999))
self.optimizer_A = torch.optim.Adam(self.netA.parameters(), lr=opt.lr*1, betas=(opt.beta1, 0.999))
self.optimizers.append(self.optimizer_G)
self.optimizers.append(self.optimizer_A)
def set_input(self, input):
self.A = input['A'].to(self.device)
self.B = input['B'].to(self.device)
if self.istest is False:
if 'L' in input.keys():
self.L = input['L'].to(self.device).long()
self.image_paths = input['A_paths']
if self.isTrain:
self.L_s = self.L.float()
self.L_s = F.interpolate(self.L_s, size=torch.Size([self.A.shape[2]//self.ds, self.A.shape[3]//self.ds]),mode='nearest')
self.L_s[self.L_s == 1] = -1 # change
self.L_s[self.L_s == 0] = 1 # no change
def test(self, val=False):
with torch.no_grad():
self.forward()
self.compute_visuals()
        if val:  # return evaluation scores
from util.metrics import RunningMetrics
metrics = RunningMetrics(self.n_class)
pred = self.pred_L.long()
metrics.update(self.L.detach().cpu().numpy(), pred.detach().cpu().numpy())
scores = metrics.get_cm()
return scores
else:
return self.pred_L.long()
def forward(self):
"""Run forward pass; called by both functions and ."""
self.feat_A = self.netF(self.A) # f(A)
self.feat_B = self.netF(self.B) # f(B)
self.feat_A, self.feat_B = self.netA(self.feat_A,self.feat_B)
        self.dist = F.pairwise_distance(self.feat_A, self.feat_B, keepdim=True)  # feature distance
self.dist = F.interpolate(self.dist, size=self.A.shape[2:], mode='bilinear',align_corners=True)
self.pred_L = (self.dist > 1).float()
# self.pred_L = F.interpolate(self.pred_L, size=self.A.shape[2:], mode='nearest')
self.pred_L_show = self.pred_L.long()
return self.pred_L
def backward(self):
self.loss_f = self.criterionF(self.dist, self.L_s)
self.loss = self.loss_f
# print(self.loss)
self.loss.backward()
def optimize_parameters(self):
"""Calculate losses, gradients, and update network weights; called in every training iteration"""
# forward
self.forward() # compute feat and dist
self.set_requires_grad([self.netF, self.netA], True)
        self.optimizer_G.zero_grad()  # set netF's gradients to zero
        self.optimizer_A.zero_grad()  # set netA's gradients to zero
        self.backward()               # calculate gradients
        self.optimizer_G.step()       # update netF's weights
        self.optimizer_A.step()       # update netA's weights
This code implements a change-detection model with self-attention. It consists of the following parts:
modify_commandline_options: modifies command-line options; here nothing is changed and the parser is returned as-is.
__init__: the constructor. It initializes bookkeeping such as loss_names, visual_names, and model_names, then defines two networks: the feature extractor netF and the spatial-temporal self-attention network netA. Finally it defines the loss function and one optimizer per network.
set_input: moves the input tensors A, B, and L to the GPU, converts L to long type, and, during training, resizes L to 1/self.ds of A's size and remaps it for the contrastive loss.
test: the evaluation function; it returns either the predicted change map or the change-detection metrics.
forward: the forward pass. A and B are fed through netF and then jointly through netA, producing two attended feature maps. The per-pixel feature distance dist is computed, interpolated up to A's size, and thresholded to produce the change prediction pred_L.
backward: computes the loss and backpropagates it.
optimize_parameters: zeroes both optimizers' gradients, computes the loss, backpropagates, and updates both networks' parameters (a minimal sketch of this two-optimizer pattern follows below).
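A minimal, self-contained sketch of the two-optimizer update in optimize_parameters; the Linear layers here are stand-ins for netF and netA, and the learning rates are toy values:
import torch
netF, netA = torch.nn.Linear(4, 4), torch.nn.Linear(4, 4)
opt_G = torch.optim.Adam(netF.parameters(), lr=1e-3, betas=(0.5, 0.999))
opt_A = torch.optim.Adam(netA.parameters(), lr=1e-3, betas=(0.5, 0.999))
loss = netA(netF(torch.randn(2, 4))).pow(2).mean()
opt_G.zero_grad(); opt_A.zero_grad()
loss.backward()
opt_G.step(); opt_A.step()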
import torch.nn as nn
import torch
class BCL(nn.Module):
"""
batch-balanced contrastive loss
no-change,1
change,-1
"""
def __init__(self, margin=2.0):
super(BCL, self).__init__()
self.margin = margin
    def forward(self, distance, label):
        # build the valid-pixel mask before remapping, so pixels labeled 255
        # (ignore) are excluded from the distance terms
        mask = (label != 255).float()
        label[label == 255] = 1
        distance = distance * mask
pos_num = torch.sum((label==1).float())+0.0001
neg_num = torch.sum((label==-1).float())+0.0001
loss_1 = torch.sum((1+label) / 2 * torch.pow(distance, 2)) /pos_num
loss_2 = torch.sum((1-label) / 2 * mask *
torch.pow(torch.clamp(self.margin - distance, min=0.0), 2)
) / neg_num
loss = loss_1 + loss_2
return loss
This code implements a batch-balanced contrastive loss (BCL) module. It consists of the following parts:
__init__: the constructor; it stores the hyperparameter margin.
forward: the forward pass. It takes the feature distances distance and the labels label. A mask of valid pixels is built first (pixels labeled 255 are invalid), the invalid pixels are remapped to 1, and their distances are zeroed out via the mask. The numbers of positive samples (pos_num) and negative samples (neg_num) are then counted, and the loss is the sum of two terms: loss_1, the mean squared distance over no-change pixels (label +1), which pulls matching features together, and loss_2, the mean squared hinge over change pixels (label -1), which pushes differing features at least margin apart; pairs whose distance already exceeds margin contribute zero to loss_2. The total loss is returned.
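A smoke test for BCL with random distances and a random -1/+1 label map (toy shapes only):
import torch
criterion = BCL(margin=2.0)
distance = torch.rand(1, 1, 8, 8) * 3
label = torch.randint(0, 2, (1, 1, 8, 8)).float() * 2 - 1   # -1 or +1
print(float(criterion(distance, label)))                    # a non-negative scalar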
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
import math
class F_mynet3(nn.Module):
def __init__(self, backbone='resnet18',in_c=3, f_c=64, output_stride=8):
self.in_c = in_c
super(F_mynet3, self).__init__()
self.module = mynet3(backbone=backbone, output_stride=output_stride, f_c=f_c, in_c=self.in_c)
def forward(self, input):
return self.module(input)
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=dilation, groups=groups, bias=False, dilation=dilation)
model_urls = {
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}
def ResNet34(output_stride, BatchNorm=nn.BatchNorm2d, pretrained=True, in_c=3):
"""
output, low_level_feat:
512, 64
"""
print(in_c)
model = ResNet(BasicBlock, [3, 4, 6, 3], output_stride, BatchNorm, in_c=in_c)
if in_c != 3:
pretrained = False
if pretrained:
model._load_pretrained_model(model_urls['resnet34'])
return model
def ResNet18(output_stride, BatchNorm=nn.BatchNorm2d, pretrained=True, in_c=3):
"""
output, low_level_feat:
512, 256, 128, 64, 64
"""
model = ResNet(BasicBlock, [2, 2, 2, 2], output_stride, BatchNorm, in_c=in_c)
if in_c !=3:
pretrained=False
if pretrained:
model._load_pretrained_model(model_urls['resnet18'])
return model
def ResNet50(output_stride, BatchNorm=nn.BatchNorm2d, pretrained=True, in_c=3):
"""
output, low_level_feat:
2048, 256
"""
model = ResNet(Bottleneck, [3, 4, 6, 3], output_stride, BatchNorm, in_c=in_c)
if in_c !=3:
pretrained=False
if pretrained:
model._load_pretrained_model(model_urls['resnet50'])
return model
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, BatchNorm=None):
super(BasicBlock, self).__init__()
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
dilation=dilation, padding=dilation, bias=False)
self.bn1 = BatchNorm(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = BatchNorm(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, BatchNorm=None):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
self.bn1 = BatchNorm(planes)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
dilation=dilation, padding=dilation, bias=False)
self.bn2 = BatchNorm(planes)
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
self.bn3 = BatchNorm(planes * 4)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
self.dilation = dilation
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
This code contains three factory functions, ResNet18, ResNet34, and ResNet50, plus two classes, BasicBlock and Bottleneck.
ResNet18, ResNet34, and ResNet50 build ResNet models of three different depths (18, 34, and 50 layers respectively), as used in image classification, object detection, and semantic segmentation. Each function takes four parameters: output_stride (the stride of the output feature map), BatchNorm (the batch-normalization layer type), pretrained (whether to load pretrained weights; disabled automatically when in_c != 3), and in_c (the number of input channels). Each returns a ResNet model; a usage sketch follows below.
BasicBlock and Bottleneck are the basic blocks of the ResNet models. BasicBlock, used by ResNet18 and ResNet34, contains two 3x3 convolutions and a skip connection, and extracts image features. Bottleneck, used by ResNet50, contains three convolutions (1x1, 3x3, then 1x1) and a skip connection, and is used to extract deeper features.
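A hedged usage sketch of the ResNet-18 factory (with pretrained weights disabled so it runs offline), checking the multi-scale feature shapes returned by the ResNet class defined next:
import torch
net = ResNet18(output_stride=32, pretrained=False)
x, f2, f3, f4 = net(torch.randn(1, 3, 256, 256))
print(f2.shape, f3.shape, f4.shape, x.shape)
# expected: (1,64,64,64) (1,128,32,32) (1,256,16,16) (1,512,8,8)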
class ResNet(nn.Module):
    def __init__(self, block, layers, output_stride, BatchNorm, pretrained=True, in_c=3):
        self.inplanes = 64
        self.in_c = in_c
        print('in_c: ', self.in_c)
        super(ResNet, self).__init__()
blocks = [1, 2, 4]
if output_stride == 32:
strides = [1, 2, 2, 2]
dilations = [1, 1, 1, 1]
elif output_stride == 16:
strides = [1, 2, 2, 1]
dilations = [1, 1, 1, 2]
elif output_stride == 8:
strides = [1, 2, 1, 1]
dilations = [1, 1, 2, 4]
elif output_stride == 4:
strides = [1, 1, 1, 1]
dilations = [1, 2, 4, 8]
else:
raise NotImplementedError
# Modules
self.conv1 = nn.Conv2d(self.in_c, 64, kernel_size=7, stride=2, padding=3,
bias=False)
self.bn1 = BatchNorm(64)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[0], dilation=dilations[0], BatchNorm=BatchNorm)
self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], dilation=dilations[1], BatchNorm=BatchNorm)
self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], dilation=dilations[2], BatchNorm=BatchNorm)
self.layer4 = self._make_MG_unit(block, 512, blocks=blocks, stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm)
# self.layer4 = self._make_layer(block, 512, layers[3], stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm)
self._init_weight()
    def _make_layer(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(self.inplanes, planes * block.expansion,
kernel_size=1, stride=stride, bias=False),
BatchNorm(planes * block.expansion),
)
layers = []
layers.append(block(self.inplanes, planes, stride, dilation, downsample, BatchNorm))
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(self.inplanes, planes, dilation=dilation, BatchNorm=BatchNorm))
return nn.Sequential(*layers)
    def _make_MG_unit(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(self.inplanes, planes * block.expansion,
kernel_size=1, stride=stride, bias=False),
BatchNorm(planes * block.expansion),
)
layers = []
layers.append(block(self.inplanes, planes, stride, dilation=blocks[0]*dilation,
downsample=downsample, BatchNorm=BatchNorm))
self.inplanes = planes * block.expansion
for i in range(1, len(blocks)):
layers.append(block(self.inplanes, planes, stride=1,
dilation=blocks[i]*dilation, BatchNorm=BatchNorm))
return nn.Sequential(*layers)
def forward(self, input):
x = self.conv1(input)
x = self.bn1(x)
x = self.relu(x)
        x = self.maxpool(x)   # stride 4 w.r.t. the input
        x = self.layer1(x)    # stride 4
        low_level_feat2 = x   # stride 4
        x = self.layer2(x)    # stride 8
        low_level_feat3 = x
        x = self.layer3(x)    # stride 16
        low_level_feat4 = x
        x = self.layer4(x)    # stride 32 (or less, depending on output_stride)
return x, low_level_feat2, low_level_feat3, low_level_feat4
    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
def _load_pretrained_model(self, model_path):
pretrain_dict = model_zoo.load_url(model_path)
model_dict = {}
state_dict = self.state_dict()
for k, v in pretrain_dict.items():
if k in state_dict:
model_dict[k] = v
state_dict.update(model_dict)
self.load_state_dict(state_dict)
def build_backbone(backbone, output_stride, BatchNorm, in_c=3):
if backbone == 'resnet50':
return ResNet50(output_stride, BatchNorm, in_c=in_c)
elif backbone == 'resnet34':
return ResNet34(output_stride, BatchNorm, in_c=in_c)
elif backbone == 'resnet18':
return ResNet18(output_stride, BatchNorm, in_c=in_c)
else:
raise NotImplementedError
This code builds the ResNet backbone, covering the ResNet50, ResNet34, and ResNet18 variants, which can be chosen as needed.
The ResNet body is assembled by the _make_layer and _make_MG_unit methods, which construct the requested stages (convolutions, normalization, ReLU activations) from the passed-in parameters. Both methods stack residual blocks; a residual block is ResNet's basic unit (convolution, normalization, ReLU) wrapped with a skip connection so that information flows through unimpeded.
Besides the input stem and the output, the network consists of four stages, each containing several residual blocks; the first block of a stage may downsample by using a stride of 2, halving the spatial size of the feature map.
Finally, build_backbone constructs the requested ResNet from its arguments: backbone selects the ResNet variant, output_stride the output stride, BatchNorm the normalization layer, and in_c the number of input channels. A dispatch example follows below.
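A hedged dispatch example for build_backbone; passing in_c=4 (instead of 3) makes the factory skip the pretrained-weight download, so the sketch runs without network access:
import torch
import torch.nn as nn
net = build_backbone('resnet18', output_stride=8, BatchNorm=nn.BatchNorm2d, in_c=4)
x, _, _, _ = net(torch.randn(1, 4, 128, 128))
print(x.shape)  # torch.Size([1, 512, 16, 16]); stride 8 thanks to output_stride=8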
class DR(nn.Module):
    def __init__(self, in_d, out_d):
        super(DR, self).__init__()
self.in_d = in_d
self.out_d = out_d
self.conv1 = nn.Conv2d(self.in_d, self.out_d, 1, bias=False)
self.bn1 = nn.BatchNorm2d(self.out_d)
self.relu = nn.ReLU()
def forward(self, input):
x = self.conv1(input)
x = self.bn1(x)
x = self.relu(x)
return x
class Decoder(nn.Module):
    def __init__(self, fc, BatchNorm):
        super(Decoder, self).__init__()
self.fc = fc
self.dr2 = DR(64, 96)
self.dr3 = DR(128, 96)
self.dr4 = DR(256, 96)
self.dr5 = DR(512, 96)
self.last_conv = nn.Sequential(nn.Conv2d(384, 256, kernel_size=3, stride=1, padding=1, bias=False),
BatchNorm(256),
nn.ReLU(),
nn.Dropout(0.5),
nn.Conv2d(256, self.fc, kernel_size=1, stride=1, padding=0, bias=False),
BatchNorm(self.fc),
nn.ReLU(),
)
        self._init_weight()
def forward(self, x,low_level_feat2, low_level_feat3, low_level_feat4):
x2 = self.dr2(low_level_feat2)
x3 = self.dr3(low_level_feat3)
x4 = self.dr4(low_level_feat4)
x = self.dr5(x)
x = F.interpolate(x, size=x2.size()[2:], mode='bilinear', align_corners=True)
x3 = F.interpolate(x3, size=x2.size()[2:], mode='bilinear', align_corners=True)
x4 = F.interpolate(x4, size=x2.size()[2:], mode='bilinear', align_corners=True)
x = torch.cat((x, x2, x3, x4), dim=1)
x = self.last_conv(x)
return x
    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
m.bias.data.zero_()
def build_decoder(fc, backbone, BatchNorm):
return Decoder(fc, BatchNorm)
class mynet3(nn.Module):
    def __init__(self, backbone='resnet18', output_stride=16, f_c=64, freeze_bn=False, in_c=3):
        super(mynet3, self).__init__()
print('arch: mynet3')
BatchNorm = nn.BatchNorm2d
self.backbone = build_backbone(backbone, output_stride, BatchNorm, in_c)
self.decoder = build_decoder(f_c, backbone, BatchNorm)
if freeze_bn:
self.freeze_bn()
def forward(self, input):
x, f2, f3, f4 = self.backbone(input)
x = self.decoder(x, f2, f3, f4)
return x
def freeze_bn(self):
for m in self.modules():
if isinstance(m, nn.BatchNorm2d):
m.eval()
This code contains three classes: DR, Decoder, and mynet3.
The DR class defines a dimension-reduction layer consisting of a 1x1 convolution, a batch normalization, and a ReLU activation; it projects a feature map to a target channel count.
The Decoder class defines a decoder containing four DR layers plus a final convolution stack. It reduces the four backbone feature maps to a common channel width, upsamples them to a common resolution, concatenates them, and fuses them into the output feature map (a toy shape check follows below).
The mynet3 class is the overall model, composed of a backbone and a decoder. The backbone, built by build_backbone, extracts multi-scale features from the input image, and the decoder, a Decoder instance, fuses them into the output features. A freeze_bn method is also provided to put all batch-normalization layers in eval mode, e.g. when fine-tuning with small batches.
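A toy shape check for Decoder, under the assumption of ResNet-18 channel widths (64/128/256/512) and output_stride=32 spatial sizes:
import torch
import torch.nn as nn
dec = Decoder(fc=64, BatchNorm=nn.BatchNorm2d)
x  = torch.randn(1, 512, 8, 8)     # layer4 output
f2 = torch.randn(1, 64, 64, 64)    # layer1 output
f3 = torch.randn(1, 128, 32, 32)   # layer2 output
f4 = torch.randn(1, 256, 16, 16)   # layer3 output
print(dec(x, f2, f3, f4).shape)    # torch.Size([1, 64, 64, 64])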
import torch
import torch.nn.functional as F
from torch import nn
class _PAMBlock(nn.Module):
'''
The basic implementation for self-attention block/non-local block
Input/Output:
N * C * H * (2*W)
Parameters:
in_channels : the dimension of the input feature map
key_channels : the dimension after the key/query transform
value_channels : the dimension after the value transform
scale : choose the scale to partition the input feature maps
ds : downsampling scale
'''
def __init__(self, in_channels, key_channels, value_channels, scale=1, ds=1):
super(_PAMBlock, self).__init__()
self.scale = scale
self.ds = ds
self.pool = nn.AvgPool2d(self.ds)
self.in_channels = in_channels
self.key_channels = key_channels
self.value_channels = value_channels
self.f_key = nn.Sequential(
nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
kernel_size=1, stride=1, padding=0),
nn.BatchNorm2d(self.key_channels)
)
self.f_query = nn.Sequential(
nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
kernel_size=1, stride=1, padding=0),
nn.BatchNorm2d(self.key_channels)
)
self.f_value = nn.Conv2d(in_channels=self.in_channels, out_channels=self.value_channels,
kernel_size=1, stride=1, padding=0)
def forward(self, input):
x = input
if self.ds != 1:
x = self.pool(input)
# input shape: b,c,h,2w
batch_size, c, h, w = x.size(0), x.size(1), x.size(2), x.size(3)//2
local_y = []
local_x = []
step_h, step_w = h//self.scale, w//self.scale
for i in range(0, self.scale):
for j in range(0, self.scale):
start_x, start_y = i*step_h, j*step_w
end_x, end_y = min(start_x+step_h, h), min(start_y+step_w, w)
if i == (self.scale-1):
end_x = h
if j == (self.scale-1):
end_y = w
local_x += [start_x, end_x]
local_y += [start_y, end_y]
value = self.f_value(x)
query = self.f_query(x)
key = self.f_key(x)
        value = torch.stack([value[:, :, :, :w], value[:, :, :, w:]], 4)  # B x C_v x H x W x 2
        query = torch.stack([query[:, :, :, :w], query[:, :, :, w:]], 4)  # B x C_k x H x W x 2
        key = torch.stack([key[:, :, :, :w], key[:, :, :, w:]], 4)        # B x C_k x H x W x 2
local_block_cnt = 2*self.scale*self.scale
# self-attention func
def func(value_local, query_local, key_local):
batch_size_new = value_local.size(0)
h_local, w_local = value_local.size(2), value_local.size(3)
value_local = value_local.contiguous().view(batch_size_new, self.value_channels, -1)
query_local = query_local.contiguous().view(batch_size_new, self.key_channels, -1)
query_local = query_local.permute(0, 2, 1)
key_local = key_local.contiguous().view(batch_size_new, self.key_channels, -1)
sim_map = torch.bmm(query_local, key_local) # batch matrix multiplication
sim_map = (self.key_channels**-.5) * sim_map
sim_map = F.softmax(sim_map, dim=-1)
context_local = torch.bmm(value_local, sim_map.permute(0,2,1))
# context_local = context_local.permute(0, 2, 1).contiguous()
context_local = context_local.view(batch_size_new, self.value_channels, h_local, w_local, 2)
return context_local
# Parallel Computing to speed up
# reshape value_local, q, k
v_list = [value[:,:,local_x[i]:local_x[i+1],local_y[i]:local_y[i+1]] for i in range(0, local_block_cnt, 2)]
v_locals = torch.cat(v_list,dim=0)
q_list = [query[:,:,local_x[i]:local_x[i+1],local_y[i]:local_y[i+1]] for i in range(0, local_block_cnt, 2)]
q_locals = torch.cat(q_list,dim=0)
k_list = [key[:,:,local_x[i]:local_x[i+1],local_y[i]:local_y[i+1]] for i in range(0, local_block_cnt, 2)]
k_locals = torch.cat(k_list,dim=0)
# print(v_locals.shape)
context_locals = func(v_locals,q_locals,k_locals)
context_list = []
for i in range(0, self.scale):
row_tmp = []
for j in range(0, self.scale):
left = batch_size*(j+i*self.scale)
right = batch_size*(j+i*self.scale) + batch_size
tmp = context_locals[left:right]
row_tmp.append(tmp)
context_list.append(torch.cat(row_tmp, 3))
context = torch.cat(context_list, 2)
context = torch.cat([context[:,:,:,:,0],context[:,:,:,:,1]],3)
if self.ds !=1:
context = F.interpolate(context, [h*self.ds, 2*w*self.ds])
return context
class PAMBlock(_PAMBlock):
def __init__(self, in_channels, key_channels=None, value_channels=None, scale=1, ds=1):
        if key_channels is None:
            key_channels = in_channels // 8
        if value_channels is None:
            value_channels = in_channels
super(PAMBlock, self).__init__(in_channels,key_channels,value_channels,scale,ds)
class PAM(nn.Module):
"""
PAM module
"""
def __init__(self, in_channels, out_channels, sizes=([1]), ds=1):
super(PAM, self).__init__()
self.group = len(sizes)
self.stages = []
        self.ds = ds  # downsampling scale applied inside each block
self.value_channels = out_channels
self.key_channels = out_channels // 8
self.stages = nn.ModuleList(
[self._make_stage(in_channels, self.key_channels, self.value_channels, size, self.ds)
for size in sizes])
self.conv_bn = nn.Sequential(
nn.Conv2d(in_channels * self.group, out_channels, kernel_size=1, padding=0,bias=False),
# nn.BatchNorm2d(out_channels),
)
def _make_stage(self, in_channels, key_channels, value_channels, size, ds):
return PAMBlock(in_channels,key_channels,value_channels,size,ds)
    def forward(self, feats):
        priors = [stage(feats) for stage in self.stages]
        # concatenate the per-scale attention outputs along the channel axis
        output = self.conv_bn(torch.cat(priors, 1))
        return output
This is an implementation of a PAM (Position Attention Module) for processing image features, improving the model's performance on the task at hand.
Concretely, the module consists of a _PAMBlock class and a PAM class. _PAMBlock is the basic building block: it is parameterized by the input channel count, key channels, value channels, partition scale, and downsampling scale, and implements a forward function. In forward, the input tensor is projected to query/key/value, partitioned into multiple local blocks (sub-tensors), and each block is processed by the inner function func to produce a local context tensor context_local. The local contexts are then stitched back together into the output tensor context. This class implements the core mechanism of the PAM module.
The PAM class wraps _PAMBlock: given the input and output channel counts, it builds one _PAMBlock per partition size in sizes. The features extracted by these stages are concatenated and passed through a 1x1 convolution to produce the final output.
Overall, the module partitions the input tensor into local blocks and computes contextual features within each block. This improves the model's performance on the task while strengthening its perception of local regions (a smoke test follows below).
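A smoke test for PAM on a bi-temporal feature map laid out as (B, C, H, 2W), matching the CDSA wrapper earlier; the sizes and ds here are toy values:
import torch
pam = PAM(in_channels=64, out_channels=64, sizes=[1, 2], ds=1)
feats = torch.randn(2, 64, 16, 32)  # two 16x16 maps concatenated on width
print(pam(feats).shape)             # torch.Size([2, 64, 16, 32])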