Pytorch-OpCounter: Pytorch平台计算模型#Parameters和FLOPS的工具包

Pytorch-OpCounter: Pytorch平台计算模型#Parameters和FLOPS的工具包

  • OpCounter的安装
  • 使用示例
  • 实现原理

OpCounter (Github地址:https://github.com/Lyken17/pytorch-OpCounter)除了能够统计各种模型结构的参数以及FLOPS, 还能为那些特殊的运算定制化统计规则,非常好用。

OpCounter的安装

方式1: pip install thop
方式2: pip install --upgrade git+https://github.com/Lyken17/pytorch-OpCounter.git

使用示例

from torchvision.models import resnet50
from thop import profile, clevar_style

model = resnet50()
input = torch.randn(1, 3, 224, 224) # (batch_size, num_channel, Height, Width)
flops, params = profile(model, inputs=(input, )) 
print('flops: {}, params: {}'.format(flops, params))

输出结果如下:

flops: 2914598912.0, params: 7978856.0

如果模型中有自定义的特殊运算类:ModuleName,为其定义的运算统计规则为count_model, 如下:

class ModuleName(nn.Module):
    # your definition
def count_model(model, x, y):
    # your rule here

则调用时,可以通过参数custom_ops来定制:

flops, params = profile(model, inputs=(input, ), 
                        custom_ops={ModuleName: count_model})

此外,clevar_style可以对输出结果进行简单处理,以更好的展示:

flops, params = clever_format([flops, params], "%.3f")
print('flops: {}, params: {}'.format(flops, params))

实现原理

该工具为每一种基本操作都定义了参数统计和运算量计算,目前主要包含视觉方面的运算,包括各种卷积、激活函数、池化、批归一化等。例如最常见的二维卷积运算,它的统计代码如下所示:

def count_conv2d(m, x, y):
    x = x[0]
    cin = m.in_channels
    cout = m.out_channels
    kh, kw = m.kernel_size
    batch_size = x.size()[0]
    out_h = y.size(2)
    out_w = y.size(3)
    # ops per output element
    # kernel_mul = kh * kw * cin
    # kernel_add = kh * kw * cin - 1
    kernel_ops = multiply_adds * kh * kw
    bias_ops = 1 if m.bias is not None else 0
    ops_per_element = kernel_ops + bias_ops
    # total ops
    # num_out_elements = y.numel()
    output_elements = batch_size * out_w * out_h * cout
    total_ops = output_elements * ops_per_element * cin // m.groups
    m.total_ops = torch.Tensor([int(total_ops)])

你可能感兴趣的:(工具,Deep,learning,pytorch,深度学习)