The network implemented in this post comes from SCAR: Spatial-/Channel-wise Attention Regression Networks for Crowd Counting (Neurocomputing 2019).
import warnings

import torch
from torch import nn
from torchvision import models

warnings.filterwarnings("ignore")
def initialize_weights(modules):
    # Apply real_init_weights to every module in an iterable.
    for m in modules:
        real_init_weights(m)
def real_init_weights(m):
    # Recursive initializer: conv/linear weights ~ N(0, 0.01) with zero bias,
    # BatchNorm set to identity; containers are walked recursively.
    if isinstance(m, list):
        for mini_m in m:
            real_init_weights(mini_m)
    elif isinstance(m, nn.Conv2d):
        nn.init.normal_(m.weight, std=0.01)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Linear):
        nn.init.normal_(m.weight, std=0.01)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Module):
        for mini_m in m.children():
            real_init_weights(mini_m)
    else:
        print(m)
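A quick way to check the initializer is to run it on a single layer; a minimal sketch:

conv = nn.Conv2d(3, 16, 3)
real_init_weights(conv)
print(conv.bias.abs().max().item())  # 0.0 -- bias is zeroed, weights ~ N(0, 0.01)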
class SCAR(torch.nn.Module):
    def __init__(self, load_weights=False):
        super(SCAR, self).__init__()
        # Front-end: the first 10 conv layers of VGG-16 (vgg10 is defined below).
        self.vgg10 = vgg10
        # Back-end: dilated convolutions enlarge the receptive field
        # without any further downsampling.
        self.dconv1 = torch.nn.Conv2d(512, 512, 3, dilation=2, stride=1, padding=2)
        self.dconv2 = torch.nn.Conv2d(512, 512, 3, dilation=2, stride=1, padding=2)
        self.dconv3 = torch.nn.Conv2d(512, 512, 3, dilation=2, stride=1, padding=2)
        self.dconv4 = torch.nn.Conv2d(512, 256, 3, dilation=2, stride=1, padding=2)
        self.dconv5 = torch.nn.Conv2d(256, 128, 3, dilation=2, stride=1, padding=2)
        self.dconv6 = torch.nn.Conv2d(128, 64, 3, dilation=2, stride=1, padding=2)
        self.relu = torch.nn.functional.relu
        self.SAM = SAM()
        self.CAM = CAM()
        self.finalconv = torch.nn.Conv2d(128, 1, 1)
        if not load_weights:
            # Initialize all layers first, then overwrite the front-end with
            # pretrained VGG-16 weights; features[0:23] is exactly the
            # conv/ReLU/pool stack that vgg10 mirrors.
            mod = models.vgg16(pretrained=True)
            initialize_weights(self.modules())
            self.vgg10.load_state_dict(mod.features[0:23].state_dict())
    def forward(self, x):
        y = self.vgg10(x)
        y = self.relu(self.dconv1(y))
        y = self.relu(self.dconv2(y))
        y = self.relu(self.dconv3(y))
        y = self.relu(self.dconv4(y))
        y = self.relu(self.dconv5(y))
        y = self.relu(self.dconv6(y))
        # Spatial and channel attention branches run in parallel and are fused
        # by concatenation (64 + 64 = 128 channels).
        y_sa = self.SAM(y)
        y_ca = self.CAM(y)
        y = torch.cat((y_ca, y_sa), dim=1)
        y = self.finalconv(y)
        # The front-end pools three times, so upsample by 8x to restore the
        # input resolution (F.interpolate replaces the deprecated F.upsample).
        y = torch.nn.functional.interpolate(y, scale_factor=8)
        return y
# Truncated VGG-16 front-end: the first 10 conv layers, with three 2x2
# max-poolings, i.e. an overall stride of 8.
vgg10 = torch.nn.Sequential(
    torch.nn.Conv2d(3, 64, 3, stride=1, padding=1),
    torch.nn.ReLU(inplace=True),
    torch.nn.Conv2d(64, 64, 3, stride=1, padding=1),
    torch.nn.ReLU(inplace=True),
    torch.nn.MaxPool2d(2, 2),
    torch.nn.Conv2d(64, 128, 3, stride=1, padding=1),
    torch.nn.ReLU(inplace=True),
    torch.nn.Conv2d(128, 128, 3, stride=1, padding=1),
    torch.nn.ReLU(inplace=True),
    torch.nn.MaxPool2d(2, 2),
    torch.nn.Conv2d(128, 256, 3, stride=1, padding=1),
    torch.nn.ReLU(inplace=True),
    torch.nn.Conv2d(256, 256, 3, stride=1, padding=1),
    torch.nn.ReLU(inplace=True),
    torch.nn.Conv2d(256, 256, 3, stride=1, padding=1),
    torch.nn.ReLU(inplace=True),
    torch.nn.MaxPool2d(2, 2),  # one could try dropping this pooling so no upsampling is needed later
    torch.nn.Conv2d(256, 512, 3, stride=1, padding=1),
    torch.nn.ReLU(inplace=True),
    torch.nn.Conv2d(512, 512, 3, stride=1, padding=1),
    torch.nn.ReLU(inplace=True),
    torch.nn.Conv2d(512, 512, 3, stride=1, padding=1),
    torch.nn.ReLU(inplace=True),
    # torch.nn.MaxPool2d(2),  # the fourth pooling of VGG-16 is deliberately omitted
)
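With its three 2x2 max-poolings, this front-end has an overall stride of 8, which the final interpolate call in SCAR.forward compensates for. A quick shape check (any input size divisible by 8 works):

x = torch.randn(1, 3, 224, 224)
print(vgg10(x).shape)  # torch.Size([1, 512, 28, 28]) -- 224 / 8 = 28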
class SAM(torch.nn.Module):
    # Spatial-wise attention module; it does not change the shape of its input.
    def __init__(self):
        super(SAM, self).__init__()
        self.q = torch.nn.Conv2d(64, 64, 1)
        self.k = torch.nn.Conv2d(64, 64, 1)
        self.v = torch.nn.Conv2d(64, 64, 1)
        # A 1x1 conv stands in for the learnable scale on the residual branch.
        self.lamda = torch.nn.Conv2d(64, 64, 1)

    def forward(self, x):
        N, C, H, W = x.size()
        q = self.q(x).view((N, -1, H * W)).permute(0, 2, 1)  # N x HW x C
        k = self.k(x).view((N, -1, H * W))                   # N x C x HW
        v = self.v(x).view((N, -1, H * W))                   # N x C x HW
        mid = torch.bmm(q, k)                                # N x HW x HW
        attention = torch.nn.functional.softmax(mid, dim=-1)
        # Transpose so that each output position is a convex combination of
        # all positions (each row of the softmax sums to 1).
        y = torch.bmm(v, attention.permute(0, 2, 1))
        y = y.view((N, C, H, W))
        y = self.lamda(y) + x
        return y
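As the comment above notes, SAM preserves the shape of its input; a minimal smoke test:

sam = SAM()
t = torch.randn(2, 64, 16, 16)
print(sam(t).shape)  # torch.Size([2, 64, 16, 16])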
class CAM(torch.nn.Module):
    # Channel-wise attention module; it also preserves the input shape.
    def __init__(self):
        super(CAM, self).__init__()
        self.conv1 = torch.nn.Conv2d(64, 64, 1)
        self.conv2 = torch.nn.Conv2d(64, 64, 1)

    def forward(self, x):
        N, C, H, W = x.size()
        # Query and key share the same 1x1 projection; note that the transpose
        # needs permute, since view alone only reinterprets memory.
        q = self.conv1(x).view(N, C, -1)                   # N x C x HW
        k = self.conv1(x).view(N, C, -1).permute(0, 2, 1)  # N x HW x C
        attention_pre = torch.bmm(q, k)                    # N x C x C channel affinity
        attention = torch.nn.functional.softmax(attention_pre, dim=-1)
        v = x.view(N, C, -1)
        cl2 = torch.bmm(attention, v).view((N, C, H, W))
        cfinal = self.conv2(cl2) + x                       # residual connection
        return cfinal
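Finally, an end-to-end smoke test (this downloads the pretrained VGG-16 weights on first use; height and width should be divisible by 8 so the 8x upsampling restores them exactly):

if __name__ == "__main__":
    model = SCAR()
    x = torch.randn(1, 3, 256, 256)
    density = model(x)
    print(density.shape)           # torch.Size([1, 1, 256, 256])
    print(density.sum().item())    # the predicted count is the integral of the density map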