ResNet-50 和 VGG16 快速实现

from functools import reduce
from operator import add

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet
from torchvision.models import vgg

from .base.feature import extract_feat_vgg, extract_feat_res
from .base.correlation import Correlation
from .learner import HPNLearner

import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F


class HypercorrSqueezeNetwork(nn.Module):
    """Correlate a query image with a masked support image using backbone features.

    A pretrained backbone (``pretrained=True``) provides intermediate feature
    maps; those of the query and the support image are extracted under
    ``torch.no_grad()`` and turned into multi-layer correlations in ``forward``.

    Args:
        backbone: One of ``'vgg16'``, ``'resnet50'`` or ``'resnet101'``.
        use_original_imgsize: Stored on the instance as
            ``self.use_original_imgsize`` (not read by the code in this class body).

    Raises:
        ValueError: If ``backbone`` is not one of the supported names.
    """

    def __init__(self, backbone, use_original_imgsize):
        super().__init__()
        # 1. Backbone network initialization
        self.backbone_type = backbone
        self.use_original_imgsize = use_original_imgsize
        if backbone == 'vgg16':
            self.backbone = vgg.vgg16(pretrained=True)
            # positions in backbone.features whose outputs are collected (7 maps)
            self.feat_ids = [17, 19, 21, 24, 26, 28, 30]
            self.extract_feats = extract_feat_vgg
            nbottlenecks = [2, 2, 3, 3, 3, 1]
        elif backbone == 'resnet50':
            self.backbone = resnet.resnet50(pretrained=True)
            # 1-based positions in the flattened block sequence (13 maps)
            self.feat_ids = list(range(4, 17))
            self.extract_feats = extract_feat_res
            nbottlenecks = [3, 4, 6, 3]
        elif backbone == 'resnet101':
            self.backbone = resnet.resnet101(pretrained=True)
            # 1-based positions in the flattened block sequence (30 maps)
            self.feat_ids = list(range(4, 34))
            self.extract_feats = extract_feat_res
            nbottlenecks = [3, 4, 23, 3]
        else:
            # ValueError subclasses Exception, so existing `except Exception` callers still work.
            raise ValueError('Unavailable backbone: %s' % backbone)

        # vgg16:  bottleneck_ids = [0, 1, 0, 1, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0]
        #         lids           = [1, 1, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6]
        #         stack_ids      = tensor([1, 4, 7])
        # res50:  bottleneck_ids = [0, 1, 2, 0, 1, 2, 3, 0, 1, 2, 3, 4, 5, 0, 1, 2]
        #         lids           = [1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 4, 4, 4]
        #         stack_ids      = tensor([3, 9, 13])
        self.bottleneck_ids, self.lids, self.stack_ids = self._build_layer_indices(nbottlenecks)

        # Switch the backbone to eval mode. NOTE(review): nn.Module.train() is
        # recursive, so a later `model.train()` puts the backbone back into
        # training mode -- confirm callers re-apply eval() if it must stay frozen.
        self.backbone.eval()
        # Only the last three block groups feed the learner, deepest group first.
        self.hpn_learner = HPNLearner(list(reversed(nbottlenecks[-3:])))
        self.cross_entropy_loss = nn.CrossEntropyLoss()

    @staticmethod
    def _build_layer_indices(nbottlenecks):
        """Expand per-group block counts into the index lists used for feature extraction.

        Args:
            nbottlenecks: Number of blocks in each group, e.g. ``[3, 4, 6, 3]``.

        Returns:
            ``(bottleneck_ids, lids, stack_ids)``:
            ``bottleneck_ids`` is the 0-based index of each block inside its group;
            ``lids`` is the 1-based group number of each block;
            ``stack_ids`` is a 1-D integer tensor with the cumulative block counts
            of the last three groups, accumulated from the deepest group upwards
            (``tensor([3, 9, 13])`` for ``[3, 4, 6, 3]``).
        """
        bottleneck_ids = [bid for count in nbottlenecks for bid in range(count)]
        lids = [lid for lid, count in enumerate(nbottlenecks, start=1) for _ in range(count)]
        # bincount()[g] is the number of blocks in group g (index 0 is always 0);
        # flipping then accumulating gives running totals starting at the deepest group.
        stack_ids = torch.tensor(lids).bincount().flip(0).cumsum(dim=0)[:3]
        return bottleneck_ids, lids, stack_ids

    def forward(self, query_img, support_img, support_mask):
        """Compute multi-layer correlations between query and masked support features.

        Features of both images are extracted without gradient tracking; the
        support features are passed to ``self.mask_feature`` along with a copy
        of ``support_mask``, and ``Correlation.multilayer_correlation`` combines
        them using ``self.stack_ids``.

        NOTE(review): as written, ``corr`` is computed but nothing is returned
        (the method returns ``None``), and ``mask_feature`` is not defined in
        this class body -- this looks like a truncated excerpt; confirm against
        the full implementation.
        """
        with torch.no_grad():
            query_feats = self.extract_feats(query_img, self.backbone, self.feat_ids, self.bottleneck_ids, self.lids)
            support_feats = self.extract_feats(support_img, self.backbone, self.feat_ids, self.bottleneck_ids,
                                               self.lids)
            support_feats = self.mask_feature(support_feats, support_mask.clone())
            corr = Correlation.multilayer_correlation(query_feats, support_feats, self.stack_ids)

ResNet-50 结构（图：ResNet-50 和 VGG16 快速实现 第 1 张图片）

VGG16 结构

（图：ResNet-50 和 VGG16 快速实现 第 2 张图片）

 

r""" Extracts intermediate features from given backbone network & layer ids """

# VGG16 index bookkeeping (not used by the extractor below, which only needs feat_ids):
#   bottleneck_ids = [0, 1, 0, 1, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0]
#   lids           = [1, 1, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6]
def extract_feat_vgg(img, backbone, feat_ids, bottleneck_ids=None, lids=None):
    r"""Run ``img`` through ``backbone.features`` and collect selected outputs.

    Args:
        img: Input handed to the first module of ``backbone.features``.
        backbone: Object exposing an iterable ``features`` of callables.
        feat_ids: 0-based positions in ``backbone.features`` whose outputs are kept.
        bottleneck_ids, lids: Accepted only so the signature matches
            ``extract_feat_res``; ignored here.

    Returns:
        List of cloned outputs, in network order, for every position of
        ``backbone.features`` that appears in ``feat_ids``.
    """
    collected = []
    out = img
    for idx, layer in enumerate(backbone.features):
        out = layer(out)
        if idx not in feat_ids:
            continue
        # clone so later (possibly in-place) layers cannot overwrite the stored map
        collected.append(out.clone())
    return collected


def extract_feat_res(img, backbone, feat_ids, bottleneck_ids, lids):
    r"""Extract intermediate features from a ResNet-style backbone.

    The stem (``conv1`` -> ``bn1`` -> ``relu`` -> ``maxpool``) runs first. Each
    bottleneck block is then executed step by step, so its output can be
    captured after the residual addition but *before* the block's final ReLU.

    Args:
        img: Input fed to ``backbone.conv1``.
        backbone: Model with ``conv1``, ``bn1``, ``relu``, ``maxpool`` and
            indexable ``layer1``, ``layer2``, ... containers of blocks. Each block
            has ``conv1``-``conv3``, ``bn1``-``bn3``, ``relu`` and a
            ``downsample`` attribute (a module, or ``None`` for an identity shortcut).
        feat_ids: 1-based positions in the flattened block sequence whose
            outputs are collected (e.g. ``range(4, 17)`` for ResNet-50).
        bottleneck_ids: Index of each block inside its layer.
        lids: 1-based layer number of each block (same length as ``bottleneck_ids``).

    Returns:
        List of cloned block outputs, in network order.
    """
    wanted = set(feat_ids)  # built once: O(1) membership per block
    feats = []

    # Layer 0 (stem). Modules are invoked via __call__ (not .forward) so that
    # registered hooks run, consistent with extract_feat_vgg.
    feat = backbone.conv1(img)
    feat = backbone.bn1(feat)
    feat = backbone.relu(feat)
    feat = backbone.maxpool(feat)

    # Layers 1-4. For ResNet-50, zip(bottleneck_ids, lids) walks
    # (0,1)..(2,1) | (0,2)..(3,2) | (0,3)..(5,3) | (0,4)..(2,4).
    for hid, (bid, lid) in enumerate(zip(bottleneck_ids, lids)):
        block = getattr(backbone, 'layer%d' % lid)[bid]  # look the block up once
        res = feat
        feat = block.conv1(feat)
        feat = block.bn1(feat)
        feat = block.relu(feat)
        feat = block.conv2(feat)
        feat = block.bn2(feat)
        feat = block.relu(feat)
        feat = block.conv3(feat)
        feat = block.bn3(feat)

        # Apply the shortcut projection whenever the block has one; blocks with
        # downsample=None use the identity shortcut.
        if block.downsample is not None:
            res = block.downsample(res)

        feat += res
        if hid + 1 in wanted:
            # clone before the (possibly in-place) ReLU below overwrites feat
            feats.append(feat.clone())

        feat = block.relu(feat)

    return feats

你可能感兴趣的：深度学习、机器学习、Python