CCNet YOLOv5目标检测实现

目前尚未进行训练，在 YOLOv5 中验证无误。

基于此文章对 CCNet 进行修改。

class CCBottleneck(nn.Module):
    """YOLOv5-style bottleneck followed by CCNet criss-cross attention.

    The input goes through ``cv1`` (``Conv(c1, c_, 1, 1)``) and ``cv2``
    (``Conv(c_, c2, 3, 1, g=g)``). The result is then refined ``recurrence``
    times by criss-cross attention: every position attends to all positions in
    its own column and its own row. The weighted sum, scaled by the learnable
    ``gamma`` (initialised to 0), is added back as a residual. When
    ``shortcut`` is set and ``c1 == c2``, the block input is added to the
    output.

    Args:
        c1: input channels.
        c2: output channels.
        shortcut: add the block input to the output (only used when c1 == c2).
        g: groups for the 3x3 ``cv2`` convolution.
        e: hidden-channel ratio, ``c_ = int(c2 * e)``.
        recurrence: number of criss-cross attention passes.
    """

    def __init__(self, c1, c2, shortcut=True, g=1, e=0.5, recurrence=2):
        super().__init__()
        self.recurrence = recurrence
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c_, c2, 3, 1, g=g)
        self.add = shortcut and c1 == c2

        # The attention runs on the output of cv2, which has c2 channels.
        # Building these convs with c1 crashed whenever c1 != c2.
        # Query/key are reduced 8x but kept at >= 1 channel, so small widths
        # never produce a zero-channel projection.
        self.in_channels = c2
        self.channels = max(c2 // 8, 1)
        self.ConvQuery = nn.Conv2d(self.in_channels, self.channels, kernel_size=1)
        self.ConvKey = nn.Conv2d(self.in_channels, self.channels, kernel_size=1)
        self.ConvValue = nn.Conv2d(self.in_channels, self.in_channels, kernel_size=1)

        self.SoftMax = nn.Softmax(dim=3)  # over the concatenated (h + w) energies
        # Zero-init: the attention residual contributes nothing at the start.
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """Apply cv1 -> cv2, then ``recurrence`` criss-cross attention passes.

        Args:
            x: feature map indexed as [b, c1, h, w].

        Returns:
            Feature map [b, c2, h, w] in the dtype of cv2's output. The input
            ``x`` is added when the shortcut is active.
        """
        x1 = self.cv2(self.cv1(x))
        for _ in range(self.recurrence):
            x1 = x1 + self._criss_cross(x1)
        return x + x1 if self.add else x1

    def _criss_cross(self, x1):
        """Return ``gamma * (column attention + row attention)`` for ``x1`` [b, c, h, w]."""
        b, _, h, w = x1.size()

        query = self.ConvQuery(x1)  # [b, c', h, w]
        key = self.ConvKey(x1)  # [b, c', h, w]
        value = self.ConvValue(x1)  # [b, c, h, w]

        # Column (H) direction: one length-h sequence per (batch, column).
        # [b, c', h, w] -> [b, w, c', h] -> [b*w, c', h] -> [b*w, h, c']
        query_H = query.permute(0, 3, 1, 2).contiguous().view(b * w, -1, h).permute(0, 2, 1)
        key_H = key.permute(0, 3, 1, 2).contiguous().view(b * w, -1, h)  # [b*w, c', h]
        value_H = value.permute(0, 3, 1, 2).contiguous().view(b * w, -1, h).float()  # [b*w, c, h]

        # Row (W) direction: one length-w sequence per (batch, row).
        # [b, c', h, w] -> [b, h, c', w] -> [b*h, c', w] -> [b*h, w, c']
        query_W = query.permute(0, 2, 1, 3).contiguous().view(b * h, -1, w).permute(0, 2, 1)
        key_W = key.permute(0, 2, 1, 3).contiguous().view(b * h, -1, w)  # [b*h, c', w]
        value_W = value.permute(0, 2, 1, 3).contiguous().view(b * h, -1, w).float()  # [b*h, c, w]

        # The column energies get -inf on the diagonal, so each position's own
        # location is scored only once (by the row branch).
        # Build the mask on the input's device rather than calling .cuda(),
        # and broadcast it over b*w instead of materialising b*w copies.
        mask = torch.diag(torch.full((h,), float("-inf"), device=x1.device))

        # Energies in float32 to match value_*.
        # [b*w, h, h] -> [b, w, h, h] -> [b, h, w, h]
        energy_H = (torch.bmm(query_H, key_H).float() + mask).view(b, w, h, h).permute(0, 2, 1, 3)
        # [b*h, w, w] -> [b, h, w, w]
        energy_W = torch.bmm(query_W, key_W).float().view(b, h, w, w)
        concate = self.SoftMax(torch.cat([energy_H, energy_W], 3))  # [b, h, w, h + w]

        # [b, h, w, h] -> [b, w, h, h] -> [b*w, h, h]
        attention_H = concate[:, :, :, 0:h].permute(0, 2, 1, 3).contiguous().view(b * w, h, h)
        # [b, h, w, w] -> [b*h, w, w]
        attention_W = concate[:, :, :, h:h + w].contiguous().view(b * h, w, w)

        # Weighted sums, reshaped back to [b, c, h, w].
        out_H = torch.bmm(value_H, attention_H.permute(0, 2, 1)).view(b, w, -1, h).permute(0, 2, 3, 1)
        out_W = torch.bmm(value_W, attention_W.permute(0, 2, 1)).view(b, h, -1, w).permute(0, 2, 1, 3)

        # Cast back so the residual keeps x1's dtype. Otherwise the float32
        # attention promotes half-precision features to float32.
        return (self.gamma * (out_H + out_W)).to(x1.dtype)

class C3CC(C3):
    """C3 module whose bottleneck stack is made of CCBottleneck blocks.

    The first six arguments are forwarded unchanged to ``C3.__init__``.
    ``recurrence`` is the number of criss-cross attention passes inside each
    CCBottleneck; it defaults to 2, the value previously hard-coded via
    CCBottleneck's default.
    """

    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, recurrence=2):
        super().__init__(c1, c2, n, shortcut, g, e)
        # NOTE(review): assumes C3 uses the same hidden width int(c2 * e) for
        # its inner blocks; confirm against C3's definition.
        c_ = int(c2 * e)  # hidden channels
        # Replace C3's bottleneck stack with criss-cross attention bottlenecks.
        # `g` is forwarded; previously it was dropped, so a non-default group
        # count was silently ignored.
        self.m = nn.Sequential(
            *(CCBottleneck(c_, c_, shortcut, g, recurrence=recurrence) for _ in range(n))
        )

你可能感兴趣的：python、深度学习、pytorch