# resnet升级 (ResNet upgrade) — page title carried over from the source post

# -*- encoding: utf-8 -*-
"""
@File    : ResNet.py
@Time    : 2021-05-08 14:50
@Author  : XD
@Email   : [email protected]
@Software: PyCharm
"""
from __future__ import absolute_import

import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from IPython import embed

class ResNet50(nn.Module):
    """ResNet-50 backbone with a global-average-pooled feature head.

    The torchvision ResNet-50 is truncated before its own pooling and
    fully-connected layers. ``forward`` global-average-pools the last
    feature map into a feature vector and, depending on ``loss``, also
    feeds it through a linear classifier.

    Args:
        num_class: Number of outputs of the linear classifier.
        loss: Which training-mode outputs to produce: a set (or any
            iterable) containing ``'softmax'`` and/or ``'metric'``. A single
            string is also accepted, and comma-separated entries such as
            ``'softmax, metric'`` are split into separate names. ``None``
            (the default) means ``{'softmax', 'metric'}``.
        **kwargs: Accepted for interface compatibility; unused.
    """

    # Loss configuration used when the caller does not pass one.
    _DEFAULT_LOSS = frozenset({'softmax', 'metric'})

    def __init__(self, num_class, loss=None, **kwargs):
        super().__init__()
        resnet50 = torchvision.models.resnet50(pretrained=False)
        self.loss = self._resolve_loss(loss)
        # Drop the backbone's final avgpool and fc layers; pooling is done
        # explicitly in forward().
        self.base = nn.Sequential(*list(resnet50.children())[:-2])
        # A metric-only setup never uses the classifier, so don't build it.
        # ResNet-50's last stage outputs 2048 channels.
        if self.loss != {'metric'}:
            self.classifier = nn.Linear(2048, num_class)

    @staticmethod
    def _resolve_loss(loss):
        """Normalise a ``loss`` argument into a ``set`` of loss names.

        ``None`` yields the default ``{'softmax', 'metric'}``; a bare string
        is treated as one entry; every entry is split on commas and
        stripped, so the legacy value ``{'softmax, metric'}`` resolves to
        ``{'softmax', 'metric'}``. Empty fragments are dropped.
        """
        if loss is None:
            return set(ResNet50._DEFAULT_LOSS)
        if isinstance(loss, str):
            loss = [loss]
        names = set()
        for entry in loss:
            names.update(part.strip() for part in entry.split(',') if part.strip())
        return names

    def forward(self, x):
        """Run the backbone and return features and/or classifier outputs.

        In eval mode the pooled feature vector ``f`` (one row per input
        sample) is always returned. In training mode the return value
        depends on ``self.loss``:

        * ``{'softmax'}``            -> classifier output ``y``
        * ``{'metric'}``             -> features ``f``
        * ``{'softmax', 'metric'}``  -> ``(y, f)``

        Raises:
            ValueError: in training mode, if ``self.loss`` is none of the
                configurations above.
        """
        x = self.base(x)
        # Global average pooling over all spatial positions, then flatten
        # to (batch, channels).
        f = F.adaptive_avg_pool2d(x, 1).flatten(1)
        # NOTE: features are not L2-normalised here; apply
        # F.normalize(f, p=2, dim=1) downstream if a normalised embedding
        # is needed.

        # Check the metric-only case before touching self.classifier: it
        # does not exist when loss == {'metric'}.
        if not self.training or self.loss == {'metric'}:
            return f
        if self.loss == {'softmax'}:
            return self.classifier(f)
        if self.loss == {'softmax', 'metric'}:
            return self.classifier(f), f
        raise ValueError(
            'loss setting error: unsupported loss {!r}'.format(self.loss))






if __name__ == '__main__':
    # Smoke test: build the model, push one random batch through it, then
    # open an interactive shell to inspect the results.
    # NOTE(review): 751 is presumably the number of training identities of
    # the target dataset — confirm.
    model = ResNet50(num_class = 751)
    # Random batch of 32 three-channel 256x128 images in (N, C, H, W) layout,
    # the layout torchvision's ResNet expects.
    imgs = torch.rand(32, 3 , 256, 128)
    # A freshly constructed nn.Module is in training mode by default, so this
    # call takes the training-mode path of forward().
    f = model(imgs)
    # Drop into an IPython shell with `model`, `imgs` and `f` in scope.
    embed()

# 你可能感兴趣的:(resnet升级) — footer carried over from the source page ("you may also be interested in: resnet upgrade")