Pytorch 计算模型复杂度 (Params 和 FLOPs)

from torchscan.crawler import crawl_module
from fvcore.nn import FlopCountAnalysis
import torch.nn as nn
import torch


def parse_shapes(input):
    if isinstance(input, list) or isinstance(input,tuple):
        out_shapes = [item.shape[1:] for item in input]
    else:
        out_shapes = input.shape[1:]

    return out_shapes

def flop_counter(model,input):
    try:
        module_info = crawl_module(model, parse_shapes(input))
        flops = sum(layer["flops"] for layer in module_info["layers"])
    except Exception as e:
        print(f'\nflops counter came across error: {e} \n')
        try:
            print('try another counter...\n')
            if isinstance(input, list):
                input = tuple(input)
            flops = FlopCountAnalysis(model, input).total()
        except Exception as e:
            print(e)
            raise e
        else:
            flops = flops / 1e9
            print(f'FLOPs : {flops:.5f} G')
            return flops

    else:
        flops = flops / 1e9
        print(f'FLOPs : {flops:.5f} G')
        return flops

def print_network_params(model,model_name):
    num_params = 0
    if isinstance(model,list):
        for m in model:
            for param in m.parameters():
                num_params += param.numel()
        print('[Network %s] Total number of parameters : %.5f M' % (model_name, num_params / 1e6))

    else:
        for param in model.parameters():
            num_params += param.numel()
        print('[Network %s] Total number of parameters : %.5f M' % (model_name, num_params / 1e6))


#SpatialGroupEnhance
class SGE(nn.Module):

    def __init__(self, groups):
        super().__init__()
        self.groups = groups
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.weight = nn.Parameter(torch.zeros(1, groups, 1, 1))
        self.bias = nn.Parameter(torch.zeros(1, groups, 1, 1))
        self.sig = nn.Sigmoid()

    def forward(self, x):
        b, c, h, w = x.shape
        x = x.view(b * self.groups, -1, h, w)  # bs*g,dim//g,h,w
        xn = x * self.avg_pool(x)  # bs*g,dim//g,h,w
        xn = xn.sum(dim=1, keepdim=True)  # bs*g,1,h,w
        t = xn.view(b * self.groups, -1)  # bs*g,h*w

        t = t - t.mean(dim=1, keepdim=True)  # bs*g,h*w
        std = t.std(dim=1, keepdim=True) + 1e-5
        t = t / std  # bs*g,h*w
        t = t.view(b, self.groups, h, w)  # bs,g,h*w

        t = t * self.weight + self.bias  # bs,g,h*w
        t = t.view(b * self.groups, 1, h, w)  # bs*g,1,h*w
        x = x * self.sig(t)
        x = x.view(b, c, h, w)

        return x




if __name__ == '__main__':
    x = torch.randn(1,256,32,32)
    model = SGE(groups=4)
    out = model(x)
    print_network_params(model,'SGE')
    flop_counter(model,x)   #support multiple input


注: 一般地,只要是nn.Module 的子类上述代码都能正常跑,如果你定义的完整的模型跑不通,那建议试试把模型拆开,按module 来计算。

你可能感兴趣的:(深度学习,Pytorch,pytorch,python,深度学习)