
上接题目: register_hook(...) / module.register_forward_hook(hook) / ...




目的:统计模型的参数信息。流程:注册 hook(为 model 的前向传播挂上 forward hooks)→ 构造输入 input 并执行一次正向传播 → 移除(remove)hook → 输出网络结构和参数信息。

torchsummary.py : summary(...)

import torch
import torch.nn as nn
from torch.autograd import Variable

from collections import OrderedDict
import numpy as np

def summary(model, input_size, batch_size=-1, device="cuda"):
    """Print a per-layer table of output shapes and parameter counts.

    A forward hook is attached to every sub-module of ``model`` (except
    ``nn.Sequential`` / ``nn.ModuleList`` containers and ``model`` itself),
    one forward pass is run on random input, and the hooks are removed again.
    The recorded shapes and parameter counts are then printed together with
    rough memory estimates.

    Args:
        model: the ``nn.Module`` to inspect.  It is called as ``model(*x)``
            with one random tensor per entry of ``input_size``.
        input_size: shape of one sample without the batch dimension, either a
            single tuple or a list of tuples for multi-input networks.
        batch_size: value written into the batch dimension of the reported
            shapes and used for the input-size estimate (default ``-1``).
        device: ``"cuda"`` or ``"cpu"`` (case-insensitive).  CUDA tensors are
            only used when ``torch.cuda.is_available()``.

    Returns:
        None.  The summary is written to stdout.

    Raises:
        AssertionError: if ``device`` is neither ``"cuda"`` nor ``"cpu"``.
    """

    def register_hook(module):
        """Attach the recording forward hook to ``module`` if it is a leaf-like layer."""

        def hook(module, input, output):
            # "<class 'torch.nn.modules.conv.Conv2d'>" -> "Conv2d"
            class_name = str(module.__class__).split(".")[-1].split("'")[0]
            module_idx = len(layer_stats)

            # Keys are numbered in the order the layers are executed.
            m_key = "%s-%i" % (class_name, module_idx + 1)
            layer_stats[m_key] = OrderedDict()
            layer_stats[m_key]["input_shape"] = list(input[0].size())
            layer_stats[m_key]["input_shape"][0] = batch_size
            if isinstance(output, (list, tuple)):
                layer_stats[m_key]["output_shape"] = [
                    [-1] + list(o.size())[1:] for o in output
                ]
            else:
                layer_stats[m_key]["output_shape"] = list(output.size())
                layer_stats[m_key]["output_shape"][0] = batch_size

            # Count parameters as plain ints so the totals below never depend
            # on tensor-only methods (a parameter-free model keeps total == 0).
            params = 0
            if hasattr(module, "weight") and hasattr(module.weight, "size"):
                params += module.weight.numel()
                layer_stats[m_key]["trainable"] = module.weight.requires_grad
            if hasattr(module, "bias") and hasattr(module.bias, "size"):
                params += module.bias.numel()
            layer_stats[m_key]["nb_params"] = params

        if (
            not isinstance(module, nn.Sequential)
            and not isinstance(module, nn.ModuleList)
            and not (module == model)
        ):
            hooks.append(module.register_forward_hook(hook))

    device = device.lower()
    assert device in [
        "cuda",
        "cpu",
    ], "Input device is not valid, please specify 'cuda' or 'cpu'"

    if device == "cuda" and torch.cuda.is_available():
        dtype = torch.cuda.FloatTensor
    else:
        dtype = torch.FloatTensor

    # multiple inputs to the network
    if isinstance(input_size, tuple):
        input_size = [input_size]

    # batch_size of 2 for batchnorm
    x = [torch.rand(2, *in_size).type(dtype) for in_size in input_size]

    # create properties
    layer_stats = OrderedDict()
    hooks = []

    # register hook on every sub-module (Module.apply walks the tree recursively)
    model.apply(register_hook)

    try:
        # make a forward pass; the hooks fill ``layer_stats`` as a side effect
        model(*x)
    finally:
        # remove these hooks even if the forward pass raised, so the model is
        # never left with stale hooks attached
        for h in hooks:
            h.remove()

    print("----------------------------------------------------------------")
    line_new = "{:>20}  {:>25} {:>15}".format("Layer (type)", "Output Shape", "Param #")
    print(line_new)
    print("================================================================")
    total_params = 0
    total_output = 0
    trainable_params = 0
    for layer in layer_stats:
        # input_shape, output_shape, trainable, nb_params
        line_new = "{:>20}  {:>25} {:>15}".format(
            layer,
            str(layer_stats[layer]["output_shape"]),
            "{0:,}".format(layer_stats[layer]["nb_params"]),
        )
        total_params += layer_stats[layer]["nb_params"]
        total_output += np.prod(layer_stats[layer]["output_shape"])
        if layer_stats[layer].get("trainable"):
            trainable_params += layer_stats[layer]["nb_params"]
        print(line_new)

    # assume 4 bytes/number (float on cuda).
    # Element counts of separate inputs add up; they must not be multiplied.
    input_elements = sum(np.prod(in_size) for in_size in input_size)
    total_input_size = abs(input_elements * batch_size * 4. / (1024 ** 2.))
    total_output_size = abs(2. * total_output * 4. / (1024 ** 2.))  # x2 for gradients
    total_params_size = abs(total_params * 4. / (1024 ** 2.))
    total_size = total_params_size + total_output_size + total_input_size

    print("================================================================")
    print("Total params: {0:,}".format(total_params))
    print("Trainable params: {0:,}".format(trainable_params))
    print("Non-trainable params: {0:,}".format(total_params - trainable_params))
    print("----------------------------------------------------------------")
    print("Input size (MB): %0.2f" % total_input_size)
    print("Forward/backward pass size (MB): %0.2f" % total_output_size)
    print("Params size (MB): %0.2f" % total_params_size)
    print("Estimated Total Size (MB): %0.2f" % total_size)
    print("----------------------------------------------------------------")

module.py : model.apply(...)

    def apply(self: T, fn: Callable[['Module'], None]) -> T:
        r"""Applies ``fn`` recursively to every submodule (as returned by ``.children()``)
        as well as self. Typical use includes initializing the parameters of a model
        (see also :ref:`nn-init-doc`).

        Args:
            fn (:class:`Module` -> None): function to be applied to each submodule

        Returns:
            Module: self

        Example::

            >>> @torch.no_grad()
            >>> def init_weights(m):
            >>>     print(m)
            >>>     if type(m) == nn.Linear:
            >>>         m.weight.fill_(1.0)
            >>>         print(m.weight)
            >>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
            >>> net.apply(init_weights)
            Linear(in_features=2, out_features=2, bias=True)
            Parameter containing:
            tensor([[ 1.,  1.],
                    [ 1.,  1.]])
            Linear(in_features=2, out_features=2, bias=True)
            Parameter containing:
            tensor([[ 1.,  1.],
                    [ 1.,  1.]])
            Sequential(
              (0): Linear(in_features=2, out_features=2, bias=True)
              (1): Linear(in_features=2, out_features=2, bias=True)
            )
            Sequential(
              (0): Linear(in_features=2, out_features=2, bias=True)
              (1): Linear(in_features=2, out_features=2, bias=True)
            )
        """
        # Depth-first, children before parent: each child recurses into its
        # own children, and ``fn`` is called on ``self`` last.
        for module in self.children():
            module.apply(fn)
        fn(self)
        return self


        for module in self.children():
            module.apply(fn)    # 对每个子模块递归调用 apply,使 fn(此处即 register_hook)作用到所有层




接下来:执行 model(*x),进入 module.py 的 _call_impl(...):

    def _call_impl(self, *input, **kwargs):
        """Run the module: forward pre-hooks, ``forward``, forward hooks, backward hooks.

        NOTE(review): the hook-registry arguments of the three
        ``itertools.chain`` calls, the ``else`` branches and the final
        ``grad_fn.register_hook`` call were restored following the upstream
        PyTorch layout of this method -- confirm against the installed
        ``torch/nn/modules/module.py``.

        Args:
            *input: positional arguments forwarded to ``forward``; a forward
                pre-hook may replace them.
            **kwargs: keyword arguments forwarded to ``forward`` unchanged.

        Returns:
            The result of ``forward`` (or ``_slow_forward`` while tracing),
            possibly replaced by the return value of a forward hook.
        """
        # 1) Forward pre-hooks: a non-None return value replaces the inputs.
        for hook in itertools.chain(
                _global_forward_pre_hooks.values(),
                self._forward_pre_hooks.values()):
            result = hook(self, input)
            if result is not None:
                # Normalise a single returned value to a tuple so it can be
                # unpacked with ``*input`` below.
                if not isinstance(result, tuple):
                    result = (result,)
                input = result
        # 2) The actual forward computation.
        if torch._C._get_tracing_state():
            result = self._slow_forward(*input, **kwargs)
        else:
            result = self.forward(*input, **kwargs)
        # 3) Forward hooks: called as hook(module, input, output); a non-None
        #    return value replaces the output.
        for hook in itertools.chain(
                _global_forward_hooks.values(),
                self._forward_hooks.values()):
            hook_result = hook(self, input, result)
            if hook_result is not None:
                result = hook_result
        # 4) Backward hooks: attach them to the grad_fn of the first tensor
        #    found inside the (possibly nested) result.
        if (len(self._backward_hooks) > 0) or (len(_global_backward_hooks) > 0):
            var = result
            while not isinstance(var, torch.Tensor):
                if isinstance(var, dict):
                    var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
                else:
                    var = var[0]
            grad_fn = var.grad_fn
            if grad_fn is not None:
                for hook in itertools.chain(
                        _global_backward_hooks.values(),
                        self._backward_hooks.values()):
                    # Bind the module as first argument and keep the hook's
                    # metadata on the wrapper.
                    wrapper = functools.partial(hook, self)
                    functools.update_wrapper(wrapper, hook)
                    grad_fn.register_hook(wrapper)
        return result

/home/wangbin/anaconda3/envs/deep_learning/bin/python3.7 /media/wangbin/F/深度学习_程序/yolo_practice/resnet_yolo.py
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 209, 209]           9,408
       BatchNorm2d-2         [-1, 64, 209, 209]             128
              ReLU-3         [-1, 64, 209, 209]               0
         MaxPool2d-4         [-1, 64, 105, 105]               0
            Conv2d-5         [-1, 64, 105, 105]           4,096
       BatchNorm2d-6         [-1, 64, 105, 105]             128
              ReLU-7         [-1, 64, 105, 105]               0
            Conv2d-8         [-1, 64, 105, 105]          36,864
       BatchNorm2d-9         [-1, 64, 105, 105]             128
             ReLU-10         [-1, 64, 105, 105]               0
           Conv2d-11        [-1, 256, 105, 105]          16,384
      BatchNorm2d-12        [-1, 256, 105, 105]             512
           Conv2d-13        [-1, 256, 105, 105]          16,384
      BatchNorm2d-14        [-1, 256, 105, 105]             512
             ReLU-15        [-1, 256, 105, 105]               0
       Bottleneck-16        [-1, 256, 105, 105]               0
           Conv2d-17         [-1, 64, 105, 105]          16,384
      BatchNorm2d-18         [-1, 64, 105, 105]             128
             ReLU-19         [-1, 64, 105, 105]               0
           Conv2d-20         [-1, 64, 105, 105]          36,864
      BatchNorm2d-21         [-1, 64, 105, 105]             128
             ReLU-22         [-1, 64, 105, 105]               0
           Conv2d-23        [-1, 256, 105, 105]          16,384
      BatchNorm2d-24        [-1, 256, 105, 105]             512
             ReLU-25        [-1, 256, 105, 105]               0
       Bottleneck-26        [-1, 256, 105, 105]               0
           Conv2d-27         [-1, 64, 105, 105]          16,384
      BatchNorm2d-28         [-1, 64, 105, 105]             128
             ReLU-29         [-1, 64, 105, 105]               0
           Conv2d-30         [-1, 64, 105, 105]          36,864
      BatchNorm2d-31         [-1, 64, 105, 105]             128
             ReLU-32         [-1, 64, 105, 105]               0
           Conv2d-33        [-1, 256, 105, 105]          16,384
      BatchNorm2d-34        [-1, 256, 105, 105]             512
             ReLU-35        [-1, 256, 105, 105]               0
       Bottleneck-36        [-1, 256, 105, 105]               0
           Conv2d-37        [-1, 128, 105, 105]          32,768
      BatchNorm2d-38        [-1, 128, 105, 105]             256
             ReLU-39        [-1, 128, 105, 105]               0
           Conv2d-40          [-1, 128, 53, 53]         147,456
      BatchNorm2d-41          [-1, 128, 53, 53]             256
             ReLU-42          [-1, 128, 53, 53]               0
           Conv2d-43          [-1, 512, 53, 53]          65,536
      BatchNorm2d-44          [-1, 512, 53, 53]           1,024
           Conv2d-45          [-1, 512, 53, 53]         131,072
      BatchNorm2d-46          [-1, 512, 53, 53]           1,024
             ReLU-47          [-1, 512, 53, 53]               0
       Bottleneck-48          [-1, 512, 53, 53]               0
           Conv2d-49          [-1, 128, 53, 53]          65,536
      BatchNorm2d-50          [-1, 128, 53, 53]             256
             ReLU-51          [-1, 128, 53, 53]               0
           Conv2d-52          [-1, 128, 53, 53]         147,456
      BatchNorm2d-53          [-1, 128, 53, 53]             256
             ReLU-54          [-1, 128, 53, 53]               0
           Conv2d-55          [-1, 512, 53, 53]          65,536
      BatchNorm2d-56          [-1, 512, 53, 53]           1,024
             ReLU-57          [-1, 512, 53, 53]               0
       Bottleneck-58          [-1, 512, 53, 53]               0
           Conv2d-59          [-1, 128, 53, 53]          65,536
      BatchNorm2d-60          [-1, 128, 53, 53]             256
             ReLU-61          [-1, 128, 53, 53]               0
           Conv2d-62          [-1, 128, 53, 53]         147,456
      BatchNorm2d-63          [-1, 128, 53, 53]             256
             ReLU-64          [-1, 128, 53, 53]               0
           Conv2d-65          [-1, 512, 53, 53]          65,536
      BatchNorm2d-66          [-1, 512, 53, 53]           1,024
             ReLU-67          [-1, 512, 53, 53]               0
       Bottleneck-68          [-1, 512, 53, 53]               0
           Conv2d-69          [-1, 128, 53, 53]          65,536
      BatchNorm2d-70          [-1, 128, 53, 53]             256
             ReLU-71          [-1, 128, 53, 53]               0
           Conv2d-72          [-1, 128, 53, 53]         147,456
      BatchNorm2d-73          [-1, 128, 53, 53]             256
             ReLU-74          [-1, 128, 53, 53]               0
           Conv2d-75          [-1, 512, 53, 53]          65,536
      BatchNorm2d-76          [-1, 512, 53, 53]           1,024
             ReLU-77          [-1, 512, 53, 53]               0
       Bottleneck-78          [-1, 512, 53, 53]               0
           Conv2d-79          [-1, 256, 53, 53]         131,072
      BatchNorm2d-80          [-1, 256, 53, 53]             512
             ReLU-81          [-1, 256, 53, 53]               0
           Conv2d-82          [-1, 256, 27, 27]         589,824
      BatchNorm2d-83          [-1, 256, 27, 27]             512
             ReLU-84          [-1, 256, 27, 27]               0
           Conv2d-85         [-1, 1024, 27, 27]         262,144
      BatchNorm2d-86         [-1, 1024, 27, 27]           2,048
           Conv2d-87         [-1, 1024, 27, 27]         524,288
      BatchNorm2d-88         [-1, 1024, 27, 27]           2,048
             ReLU-89         [-1, 1024, 27, 27]               0
       Bottleneck-90         [-1, 1024, 27, 27]               0
           Conv2d-91          [-1, 256, 27, 27]         262,144
      BatchNorm2d-92          [-1, 256, 27, 27]             512
             ReLU-93          [-1, 256, 27, 27]               0
           Conv2d-94          [-1, 256, 27, 27]         589,824
      BatchNorm2d-95          [-1, 256, 27, 27]             512
             ReLU-96          [-1, 256, 27, 27]               0
           Conv2d-97         [-1, 1024, 27, 27]         262,144
      BatchNorm2d-98         [-1, 1024, 27, 27]           2,048
             ReLU-99         [-1, 1024, 27, 27]               0
      Bottleneck-100         [-1, 1024, 27, 27]               0
          Conv2d-101          [-1, 256, 27, 27]         262,144
     BatchNorm2d-102          [-1, 256, 27, 27]             512
            ReLU-103          [-1, 256, 27, 27]               0
          Conv2d-104          [-1, 256, 27, 27]         589,824
     BatchNorm2d-105          [-1, 256, 27, 27]             512
            ReLU-106          [-1, 256, 27, 27]               0
          Conv2d-107         [-1, 1024, 27, 27]         262,144
     BatchNorm2d-108         [-1, 1024, 27, 27]           2,048
            ReLU-109         [-1, 1024, 27, 27]               0
      Bottleneck-110         [-1, 1024, 27, 27]               0
          Conv2d-111          [-1, 256, 27, 27]         262,144
     BatchNorm2d-112          [-1, 256, 27, 27]             512
            ReLU-113          [-1, 256, 27, 27]               0
          Conv2d-114          [-1, 256, 27, 27]         589,824
     BatchNorm2d-115          [-1, 256, 27, 27]             512
            ReLU-116          [-1, 256, 27, 27]               0
          Conv2d-117         [-1, 1024, 27, 27]         262,144
     BatchNorm2d-118         [-1, 1024, 27, 27]           2,048
            ReLU-119         [-1, 1024, 27, 27]               0
      Bottleneck-120         [-1, 1024, 27, 27]               0
          Conv2d-121          [-1, 256, 27, 27]         262,144
     BatchNorm2d-122          [-1, 256, 27, 27]             512
            ReLU-123          [-1, 256, 27, 27]               0
          Conv2d-124          [-1, 256, 27, 27]         589,824
     BatchNorm2d-125          [-1, 256, 27, 27]             512
            ReLU-126          [-1, 256, 27, 27]               0
          Conv2d-127         [-1, 1024, 27, 27]         262,144
     BatchNorm2d-128         [-1, 1024, 27, 27]           2,048
            ReLU-129         [-1, 1024, 27, 27]               0
      Bottleneck-130         [-1, 1024, 27, 27]               0
          Conv2d-131          [-1, 256, 27, 27]         262,144
     BatchNorm2d-132          [-1, 256, 27, 27]             512
            ReLU-133          [-1, 256, 27, 27]               0
          Conv2d-134          [-1, 256, 27, 27]         589,824
     BatchNorm2d-135          [-1, 256, 27, 27]             512
            ReLU-136          [-1, 256, 27, 27]               0
          Conv2d-137         [-1, 1024, 27, 27]         262,144
     BatchNorm2d-138         [-1, 1024, 27, 27]           2,048
            ReLU-139         [-1, 1024, 27, 27]               0
      Bottleneck-140         [-1, 1024, 27, 27]               0
          Conv2d-141          [-1, 512, 27, 27]         524,288
     BatchNorm2d-142          [-1, 512, 27, 27]           1,024
            ReLU-143          [-1, 512, 27, 27]               0
          Conv2d-144          [-1, 512, 14, 14]       2,359,296
     BatchNorm2d-145          [-1, 512, 14, 14]           1,024
            ReLU-146          [-1, 512, 14, 14]               0
          Conv2d-147         [-1, 2048, 14, 14]       1,048,576
     BatchNorm2d-148         [-1, 2048, 14, 14]           4,096
          Conv2d-149         [-1, 2048, 14, 14]       2,097,152
     BatchNorm2d-150         [-1, 2048, 14, 14]           4,096
            ReLU-151         [-1, 2048, 14, 14]               0
      Bottleneck-152         [-1, 2048, 14, 14]               0
          Conv2d-153          [-1, 512, 14, 14]       1,048,576
     BatchNorm2d-154          [-1, 512, 14, 14]           1,024
            ReLU-155          [-1, 512, 14, 14]               0
          Conv2d-156          [-1, 512, 14, 14]       2,359,296
     BatchNorm2d-157          [-1, 512, 14, 14]           1,024
            ReLU-158          [-1, 512, 14, 14]               0
          Conv2d-159         [-1, 2048, 14, 14]       1,048,576
     BatchNorm2d-160         [-1, 2048, 14, 14]           4,096
            ReLU-161         [-1, 2048, 14, 14]               0
      Bottleneck-162         [-1, 2048, 14, 14]               0
          Conv2d-163          [-1, 512, 14, 14]       1,048,576
     BatchNorm2d-164          [-1, 512, 14, 14]           1,024
            ReLU-165          [-1, 512, 14, 14]               0
          Conv2d-166          [-1, 512, 14, 14]       2,359,296
     BatchNorm2d-167          [-1, 512, 14, 14]           1,024
            ReLU-168          [-1, 512, 14, 14]               0
          Conv2d-169         [-1, 2048, 14, 14]       1,048,576
     BatchNorm2d-170         [-1, 2048, 14, 14]           4,096
            ReLU-171         [-1, 2048, 14, 14]               0
      Bottleneck-172         [-1, 2048, 14, 14]               0
          Conv2d-173          [-1, 256, 14, 14]         524,288
     BatchNorm2d-174          [-1, 256, 14, 14]             512
          Conv2d-175          [-1, 256, 14, 14]         589,824
     BatchNorm2d-176          [-1, 256, 14, 14]             512
          Conv2d-177          [-1, 256, 14, 14]          65,536
     BatchNorm2d-178          [-1, 256, 14, 14]             512
          Conv2d-179          [-1, 256, 14, 14]         524,288
     BatchNorm2d-180          [-1, 256, 14, 14]             512
detnet_bottleneck-181          [-1, 256, 14, 14]               0
          Conv2d-182          [-1, 256, 14, 14]          65,536
     BatchNorm2d-183          [-1, 256, 14, 14]             512
          Conv2d-184          [-1, 256, 14, 14]         589,824
     BatchNorm2d-185          [-1, 256, 14, 14]             512
          Conv2d-186          [-1, 256, 14, 14]          65,536
     BatchNorm2d-187          [-1, 256, 14, 14]             512
detnet_bottleneck-188          [-1, 256, 14, 14]               0
          Conv2d-189          [-1, 256, 14, 14]          65,536
     BatchNorm2d-190          [-1, 256, 14, 14]             512
          Conv2d-191          [-1, 256, 14, 14]         589,824
     BatchNorm2d-192          [-1, 256, 14, 14]             512
          Conv2d-193          [-1, 256, 14, 14]          65,536
     BatchNorm2d-194          [-1, 256, 14, 14]             512
detnet_bottleneck-195          [-1, 256, 14, 14]               0
          Conv2d-196           [-1, 30, 14, 14]          69,120
     BatchNorm2d-197           [-1, 30, 14, 14]              60
Total params: 26,728,060
Trainable params: 26,728,060
Non-trainable params: 0
Input size (MB): 2.00
Forward/backward pass size (MB): 1038.47
Params size (MB): 101.96
Estimated Total Size (MB): 1142.43
Process finished with exit code 0
