THOP: 统计 PyTorch 模型的 FLOPs 和参数量

THOP 是 PyTorch 非常实用的一个第三方库,可以统计模型的 FLOPs 和参数量。使用方法为:

from thop import clever_format
from thop import profile

class YourModule(nn.Module):
    # your definition
def count_your_model(model, x, y):
    # your rule here

input = torch.randn(1, 3, 224, 224)
flops, params = profile(model, inputs=(input, ), 
                        custom_ops={YourModule: count_your_model})
flops, params = clever_format([flops, params], "%.3f")

profile 函数实现机制是利用了 PyTorch 的 torch.nn.Module.register_forward_hook。

profile

    handler_collection = []
    if custom_ops is None:
        custom_ops = {}

嵌套定义add_hooks函数,仅作用于网络的叶子节点。
torch.nn.Module.register_buffer 向模块添加持久缓冲区。这通常用于注册模型参数之外的缓冲区。例如,BatchNorm 的running_mean不是参数,而是持久状态的一部分。可以使用给定名称作为属性访问缓冲区。
torch.numel 返回input张量中的元素总数。
全局变量 register_hooks 定义了每种 op 对应的钩子函数,具体定义在 count_hooks.py 中。
torch.nn.Module.register_forward_hook 注册模块上的前向挂钩。

每次在 forward() 计算输出后都会调用该钩子。它应该有以下签名:

hook(module, input, output) -> None

钩子不应该修改输入或输出。返回i类型为torch.utils.hooks.RemovableHandle

    def add_hooks(m):
        if len(list(m.children())) > 0:
            return

        if hasattr(m, "total_ops") or hasattr(m, "total_params"):
            logger.warning("Either .total_ops or .total_params is already defined in %s." 
                           "Be careful, it might change your code's behavior." % str(m))

        m.register_buffer('total_ops', torch.zeros(1))
        m.register_buffer('total_params', torch.zeros(1))

        for p in m.parameters():
            m.total_params += torch.Tensor([p.numel()])

        m_type = type(m)
        fn = None
        if m_type in custom_ops:  # if defined both op maps, use custom_ops to overwrite.
            fn = custom_ops[m_type]
        elif m_type in register_hooks:
            fn = register_hooks[m_type]

        if fn is None:
            if verbose:
                print("THOP has not implemented counting method for ", m)
        else:
            if verbose:
                print("Register FLOP counter for module %s" % str(m))
            handler = m.register_forward_hook(fn)
            handler_collection.append(handler)

预先获取模型的模式,后面进行恢复。
torch.nn.Module.apply 将fn递归地应用于每个子模块(由.children()返回)以及自身。 典型用途包括初始化模型的参数(另请参见 torch-nn-init)。
运行网络。

    # original_device = model.parameters().__next__().device
    training = model.training

    model.eval()
    model.apply(add_hooks)

    with torch.no_grad():
        model(*inputs)

遍历叶子节点,即有效实体,统计计算量和参数量。

    total_ops = 0
    total_params = 0
    for m in model.modules():
        if len(list(m.children())) > 0:  # skip for non-leaf module
            continue
        total_ops += m.total_ops
        total_params += m.total_params

    total_ops = total_ops.item()
    total_params = total_params.item()

清空handler_collection中的元素句柄。

    # reset model to original status
    model.train(training)
    for handler in handler_collection:
        handler.remove()

    return total_ops, total_params

你可能感兴趣的:(DeepLearning,PyTorch)