A PyTorch Implementation of SCAR

This post implements the network from SCAR: Spatial-/Channel-wise Attention Regression Networks for Crowd Counting (Neurocomputing 2019). The model uses the first ten convolutional layers of VGG-16 as the frontend, a stack of dilated convolutions as the backend, and two parallel attention modules — a spatial attention module (SAM) and a channel attention module (CAM) — whose outputs are concatenated and reduced to a single-channel density map.

import warnings

import torch
from torch import nn
from torchvision import models

warnings.filterwarnings("ignore")


def initialize_weights(modules):
    for module in modules:
        real_init_weights(module)
def real_init_weights(m):
    if isinstance(m, list):
        for mini_m in m:
            real_init_weights(mini_m)
    else:
        if isinstance(m, nn.Conv2d):
            nn.init.normal_(m.weight, std=0.01)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, std=0.01)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Module):
            for mini_m in m.children():
                real_init_weights(mini_m)
        else:
            print(m)  # unhandled module type
class SCAR(torch.nn.Module):
    def __init__(self, load_weights=False):
        super(SCAR, self).__init__()
        self.vgg10 = vgg10  # frontend: first 10 conv layers of VGG-16, defined below
        if not load_weights:
            mod = models.vgg16(pretrained=True)
            initialize_weights(self.modules())
            # features[0:23] of torchvision's VGG-16 covers exactly these 10 conv layers
            self.vgg10.load_state_dict(mod.features[0:23].state_dict())

        # backend: dilated convolutions enlarge the receptive field without
        # further reducing the resolution
        self.dconv1 = torch.nn.Conv2d(512, 512, 3, dilation=2, stride=1, padding=2)
        self.dconv2 = torch.nn.Conv2d(512, 512, 3, dilation=2, stride=1, padding=2)
        self.dconv3 = torch.nn.Conv2d(512, 512, 3, dilation=2, stride=1, padding=2)
        self.dconv4 = torch.nn.Conv2d(512, 256, 3, dilation=2, stride=1, padding=2)
        self.dconv5 = torch.nn.Conv2d(256, 128, 3, dilation=2, stride=1, padding=2)
        self.dconv6 = torch.nn.Conv2d(128, 64, 3, dilation=2, stride=1, padding=2)

        self.relu = torch.nn.functional.relu
        self.SAM = SAM()
        self.CAM = CAM()
        # SAM and CAM each output 64 channels; a 1x1 conv maps the
        # concatenated 128 channels to a single-channel density map
        self.finalconv = torch.nn.Conv2d(128, 1, 1)
        # torch.nn.functional.upsample is deprecated in favor of interpolate
        self.upsample = torch.nn.functional.interpolate
    def forward(self, x):
        y = self.vgg10(x)

        y = self.relu(self.dconv1(y))
        y = self.relu(self.dconv2(y))
        y = self.relu(self.dconv3(y))
        y = self.relu(self.dconv4(y))
        y = self.relu(self.dconv5(y))
        y = self.relu(self.dconv6(y))

        # the two attention branches run in parallel on the same feature map
        y_sa = self.SAM(y)
        y_ca = self.CAM(y)

        y = torch.cat((y_ca, y_sa), dim=1)
        y = self.finalconv(y)

        # the frontend pools three times, so upsample by 8x to recover the
        # input resolution
        y = self.upsample(y, scale_factor=8)
        return y


# frontend: the first 10 convolutional layers of VGG-16 (up to conv4_3)
vgg10 = torch.nn.Sequential(torch.nn.Conv2d(3, 64, 3, stride=1, padding=1),
                          torch.nn.ReLU(inplace=True),
                          torch.nn.Conv2d(64, 64, 3, stride=1, padding=1),
                          torch.nn.ReLU(inplace=True),
                          torch.nn.MaxPool2d(2, 2),

                          torch.nn.Conv2d(64, 128, 3, stride=1, padding=1),
                          torch.nn.ReLU(inplace=True),
                          torch.nn.Conv2d(128, 128, 3, stride=1, padding=1),
                          torch.nn.ReLU(inplace=True),
                          torch.nn.MaxPool2d(2, 2),

                          torch.nn.Conv2d(128, 256, 3, stride=1, padding=1),
                          torch.nn.ReLU(inplace=True),
                          torch.nn.Conv2d(256, 256, 3, stride=1, padding=1),
                          torch.nn.ReLU(inplace=True),
                          torch.nn.Conv2d(256, 256, 3, stride=1, padding=1),
                          torch.nn.ReLU(inplace=True),
                          torch.nn.MaxPool2d(2, 2),  # one could drop this pooling to avoid upsampling later

                          torch.nn.Conv2d(256, 512, 3, stride=1, padding=1),
                          torch.nn.ReLU(inplace=True),
                          torch.nn.Conv2d(512, 512, 3, stride=1, padding=1),
                          torch.nn.ReLU(inplace=True),
                          torch.nn.Conv2d(512, 512, 3, stride=1, padding=1),
                          torch.nn.ReLU(inplace=True),
                          # torch.nn.MaxPool2d(2),  # VGG-16's fourth pooling layer is omitted
)
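
To verify that this Sequential lines up module-for-module with the pretrained weights loaded in SCAR.__init__, you can print the corresponding slice of torchvision's VGG-16 (a quick sanity check, not part of the model):

# sanity check: this slice should show the same 23-module
# Conv2d/ReLU/MaxPool2d sequence as vgg10 above
print(models.vgg16(pretrained=True).features[0:23])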

class SAM(torch.nn.Module):
    def __init__(self):
        super(SAM, self).__init__()  # SAM does not change the shape of its input x
        # 1x1 convolutions produce the query, key and value maps
        self.q = torch.nn.Conv2d(64, 64, 1)
        self.k = torch.nn.Conv2d(64, 64, 1)
        self.v = torch.nn.Conv2d(64, 64, 1)

        self.lamda = torch.nn.Conv2d(64, 64, 1)


    def forward(self, x):
        N, C, H, W = x.size()

        q = self.q(x).view((N, -1, H * W)).permute(0, 2, 1)  # N x HW x C
        k = self.k(x).view((N, -1, H * W))                   # N x C x HW
        v = self.v(x).view((N, -1, H * W))                   # N x C x HW
        mid = torch.bmm(q, k)                                # N x HW x HW

        # each row holds one position's attention weights over all positions
        attention = torch.nn.functional.softmax(mid, dim=-1)

        y = torch.bmm(v, attention.permute(0, 2, 1))  # N x C x HW
        y = y.view((N, C, H, W))
        y = self.lamda(y) + x  # residual connection
        return y
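
Because SAM is shape-preserving, a quick standalone check is easy (the tensor sizes below are arbitrary, chosen only for illustration):

# illustrative shape check for SAM
sam = SAM()
t = torch.randn(2, 64, 32, 24)
assert sam(t).shape == t.shape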

class CAM(torch.nn.Module):
    def __init__(self):
        super(CAM, self).__init__()  # like SAM, CAM preserves the shape of its input
        self.conv1 = torch.nn.Conv2d(64, 64, 1)
        self.conv2 = torch.nn.Conv2d(64, 64, 1)


    def forward(self, x):
        N, C, H, W = x.size()
        q = self.conv1(x).view(N, C, -1)                   # N x C x HW
        # view alone does not transpose, so a permute is required here
        k = self.conv1(x).view(N, C, -1).permute(0, 2, 1)  # N x HW x C

        attention_pre = torch.bmm(q, k)  # N x C x C channel affinity matrix
        attention = torch.nn.functional.softmax(attention_pre, dim=-1)

        v = x.view(N, C, -1)
        cl2 = torch.bmm(attention, v).view((N, C, H, W))
        cfinal = self.conv2(cl2) + x  # residual connection

        return cfinal
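
Finally, a minimal end-to-end smoke test, assuming an input whose height and width are multiples of 8 so that the 8x upsampling restores the original size (the dummy tensor below is illustrative only):

if __name__ == "__main__":
    net = SCAR()
    img = torch.randn(1, 3, 256, 256)
    density = net(img)
    print(density.shape)  # expected: torch.Size([1, 1, 256, 256])
    # in crowd counting, the estimated count is the integral of the density map
    print("estimated count:", density.sum().item())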
