Printing a network's structure, its parameter count, and its FLOPs (floating-point operation count).

print_model_parm_flops — estimates the model's computational cost (FLOPs).

#####  print_model_parm_flops模型计算量
def print_model_parm_flops():
    """Estimate and print the FLOPs of net.dehaze_net() for a 3x640x480 input.

    Registers a forward hook on every leaf Conv2d/Linear/BatchNorm2d/ReLU/
    pooling module, runs a single dummy forward pass so each hook records
    the layer's operation count, then sums and prints the total in MFLOPs.
    """
    # When True, count a multiply-accumulate as 2 ops instead of 1.
    multiply_adds = False

    list_conv = []
    def conv_hook(self, input, output):
        batch_size, input_channels, input_height, input_width = input[0].size()
        # output[0] drops the batch dimension, so it unpacks as (C, H, W).
        output_channels, output_height, output_width = output[0].size()

        # Ops per output element: kernel window * input channels per group.
        # Integer division: in_channels is always a multiple of groups.
        kernel_ops = (self.kernel_size[0] * self.kernel_size[1]
                      * (self.in_channels // self.groups)
                      * (2 if multiply_adds else 1))
        bias_ops = 1 if self.bias is not None else 0

        params = output_channels * (kernel_ops + bias_ops)
        flops = batch_size * params * output_height * output_width
        list_conv.append(flops)

    list_linear = []
    def linear_hook(self, input, output):
        batch_size = input[0].size(0) if input[0].dim() == 2 else 1
        weight_ops = self.weight.nelement() * (2 if multiply_adds else 1)
        bias_ops = self.bias.nelement()
        list_linear.append(batch_size * (weight_ops + bias_ops))

    list_bn = []
    def bn_hook(self, input, output):
        # One op per input element (scale + shift counted as one).
        list_bn.append(input[0].nelement())

    list_relu = []
    def relu_hook(self, input, output):
        list_relu.append(input[0].nelement())

    list_pooling = []
    def pooling_hook(self, input, output):
        batch_size, input_channels, input_height, input_width = input[0].size()
        output_channels, output_height, output_width = output[0].size()

        # kernel_size may be an int or an (h, w) tuple depending on how the
        # pooling layer was constructed — the original code assumed int only.
        if isinstance(self.kernel_size, (tuple, list)):
            kernel_ops = self.kernel_size[0] * self.kernel_size[1]
        else:
            kernel_ops = self.kernel_size * self.kernel_size
        params = output_channels * kernel_ops
        flops = batch_size * params * output_height * output_width
        list_pooling.append(flops)

    def attach_hooks(module):
        # Recurse to leaf modules and attach the matching counting hook.
        children = list(module.children())
        if not children:
            if isinstance(module, torch.nn.Conv2d):
                module.register_forward_hook(conv_hook)
            if isinstance(module, torch.nn.Linear):
                module.register_forward_hook(linear_hook)
            if isinstance(module, torch.nn.BatchNorm2d):
                module.register_forward_hook(bn_hook)
            if isinstance(module, torch.nn.ReLU):
                module.register_forward_hook(relu_hook)
            if isinstance(module, (torch.nn.MaxPool2d, torch.nn.AvgPool2d)):
                module.register_forward_hook(pooling_hook)
            return
        for child in children:
            attach_hooks(child)

    dehaze_net = net.dehaze_net()
    attach_hooks(dehaze_net)
    # One dummy forward pass with a (1, 3, 640, 480) input fires all hooks.
    dummy_input = Variable(torch.rand(3, 640, 480).unsqueeze(0),
                           requires_grad=True)
    dehaze_net(dummy_input)

    total_flops = (sum(list_conv) + sum(list_linear) + sum(list_bn)
                   + sum(list_relu) + sum(list_pooling))
    print('  + Number of FLOPs: %.2fM' % (total_flops / 1e6))

print_model_parm_flops()

print_model_parm_nums() — prints the total number of model parameters.

####   打印参数量
####   print the total number of parameters
def print_model_parm_nums():
    """Print the total number of parameters in net.dehaze_net()."""
    # No .cuda() here: counting parameters needs no GPU, and the original
    # .cuda() call crashed on machines without CUDA.
    dehaze_net = net.dehaze_net()
#    dehaze_net.load_state_dict(torch.load('./snapshots/dehazer.pth'))
    total = sum(param.nelement() for param in dehaze_net.parameters())
#    print(' + Number of params: %.2fM' % (total / 1e6))
    print(' + Number of params: %.2f' % total)
print_model_parm_nums()

Printing the network's layer-by-layer structure.

# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torchvision
import torch.backends.cudnn as cudnn
import torch.optim
import os
import sys
import argparse
import time
import dataloader
import net
import numpy as np
from torchvision import transforms
from PIL import Image
import glob
import time
from torch.autograd import Variable
#########################   显示 trainable parameters and weights    #############################
#########################   show trainable parameters and weights    #############################
def show_summary():
    """Summarize net.dehaze_net() as a pandas DataFrame.

    Runs one dummy forward pass with hooks attached so each row records a
    module's dotted name, class, input/output shape and parameter counts.
    Prints the DataFrame and returns it.
    """
    from collections import OrderedDict
    import pandas as pd
    import numpy as np

    import torch
    from torch.autograd import Variable
    import torch.nn.functional as F
    from torch import nn

    def get_names_dict(model):
        """Recursive walk to get names including path."""
        names = {}
        def _get_names(parent, parent_name=''):
            # Note: don't shadow the outer argument (original shadowed `module`).
            for key, child in parent.named_children():
                name = parent_name + '.' + key if parent_name else key
                names[name] = child
                _get_names(child, parent_name=name)
        _get_names(model)
        return names

    def torch_summarize_df(input_size, model, weights=False, input_shape=True, nb_trainable=True):
        """
        Summarizes torch model by showing trainable parameters and weights.

        author: wassname
        url: https://gist.github.com/wassname/0fb8f95e4272e6bdd27bd7df386716b7
        license: MIT

        Modified from:
        - https://github.com/pytorch/pytorch/issues/2001#issuecomment-313735757
        - https://gist.github.com/wassname/0fb8f95e4272e6bdd27bd7df386716b7/

        Usage:
            import torchvision.models as models
            model = models.alexnet()
            df = torch_summarize_df(input_size=(3, 224,224), model=model)
            print(df)

            #              name class_name        input_shape       output_shape  nb_params
            # 1     features=>0     Conv2d  (-1, 3, 224, 224)   (-1, 64, 55, 55)      23296#(3*11*11+1)*64
            # 2     features=>1       ReLU   (-1, 64, 55, 55)   (-1, 64, 55, 55)          0
            # ...
        """

        def register_hook(module):
            def hook(module, input, output):
                # Recover the dotted path name recorded in `names`.
                name = ''
                for key, item in names.items():
                    if item is module:
                        name = key
                class_name = str(module.__class__).split('.')[-1].split("'")[0]
                # Rows are 1-indexed in insertion (execution) order.
                m_key = len(summary) + 1

                summary[m_key] = OrderedDict()
                summary[m_key]['name'] = name
                summary[m_key]['class_name'] = class_name
                if input_shape:
                    summary[m_key][
                        'input_shape'] = (1, ) + tuple(input[0].size())[1:]
                # Leading 1 is the dummy batch dimension.
                summary[m_key]['output_shape'] = (1, ) + tuple(output.size())[1:]

                if weights:
                    summary[m_key]['weights'] = list(
                        [tuple(p.size()) for p in module.parameters()])

                # Plain Python ints instead of 0-dim LongTensors so the
                # DataFrame prints "12" rather than "tensor(12)".
                if nb_trainable:
                    summary[m_key]['nb_trainable'] = sum(
                        p.numel() for p in module.parameters() if p.requires_grad)
                summary[m_key]['nb_params'] = sum(
                    p.numel() for p in module.parameters())

            # Skip containers and the top-level model itself.
            if  not isinstance(module, nn.Sequential) and \
                not isinstance(module, nn.ModuleList) and \
                not (module == model):
                hooks.append(module.register_forward_hook(hook))

        # Names are stored in parent and path+name is unique not the name
        names = get_names_dict(model)

        # check if there are multiple inputs to the network
        if isinstance(input_size[0], (list, tuple)):
            x = [Variable(torch.rand(1, *in_size)) for in_size in input_size]
        else:
            x = Variable(torch.rand(1, *input_size))

        if next(model.parameters()).is_cuda:
            # Move each tensor individually: the original x.cuda() crashed
            # in the multi-input case because a list has no .cuda().
            if isinstance(x, list):
                x = [t.cuda() for t in x]
            else:
                x = x.cuda()

        # create properties
        summary = OrderedDict()
        hooks = []

        # register hook
        model.apply(register_hook)

        # make a forward pass
        model(x)

        # remove these hooks
        for h in hooks:
            h.remove()

        # make dataframe
        return pd.DataFrame.from_dict(summary, orient='index')


    # Test on alexnet
#    import torchvision.models as models
#    model = models.alexnet()
    model = net.dehaze_net()
#    model.load_state_dict(torch.load('snapshots/Epoch10_enddehazer.pth'))
    df = torch_summarize_df(input_size=(3, 480, 640), model=model)
    print(df)
    return df

summ=show_summary()

out:

       name class_name        input_shape      output_shape    nb_params  \
1   e_conv1     Conv2d   (1, 3, 480, 640)  (1, 3, 480, 640)   tensor(12)   
2      relu       ReLU   (1, 3, 480, 640)  (1, 3, 480, 640)            0   
3   e_conv2     Conv2d   (1, 3, 480, 640)  (1, 3, 480, 640)   tensor(84)   
4      relu       ReLU   (1, 3, 480, 640)  (1, 3, 480, 640)            0   
5   e_conv3     Conv2d   (1, 6, 480, 640)  (1, 3, 480, 640)  tensor(453)   
6      relu       ReLU   (1, 3, 480, 640)  (1, 3, 480, 640)            0   
7   e_conv4     Conv2d   (1, 6, 480, 640)  (1, 3, 480, 640)  tensor(885)   
8      relu       ReLU   (1, 3, 480, 640)  (1, 3, 480, 640)            0   
9   e_conv5     Conv2d  (1, 12, 480, 640)  (1, 3, 480, 640)  tensor(327)   
10     relu       ReLU   (1, 3, 480, 640)  (1, 3, 480, 640)            0   
11     relu       ReLU   (1, 3, 480, 640)  (1, 3, 480, 640)            0  

You may also be interested in: (torch)