Computing the Parameter Count and Floating-Point Operations of a Neural Network in PyTorch

Reference: https://github.com/ShichenLiu/CondenseNet/blob/master/utils.py

Computing the parameter count

Iterate over the model's parameters and sum the number of elements in each parameter tensor:

import torch

def get_layer_param(model):
    return sum(torch.numel(param) for param in model.parameters())
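For example, applied to torchvision's ResNet-18 (assuming torchvision is installed), this should report roughly 11.7M parameters:

from torchvision.models import resnet18

print(get_layer_param(resnet18()))  # ~11.7M parameters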

Computing floating-point operations

We measure compute in multiply-adds: multiplying a weight by an input and accumulating the product counts as one operation, so a dot product of length $n$ costs $n$ multiply-adds. Below we work out the multiply-add counts of common PyTorch modules.
2D convolution layer [Conv2d]
Let the input feature map have dimensions $(C_{in}, H, W)$, where $H$, $W$, and $C_{in}$ are its height, width, and channel count. Let the convolution weight have dimensions $(K, C_{in}/g, kH, kW)$, where $K$ is the number of filters, $C_{in}/g$ the kernel depth, and $kH$, $kW$ the kernel height and width; the padding is $(P_h, P_w)$, the stride is $(S_h, S_w)$, and the number of groups is $g$ ($g = 1$ for an ordinary convolution). The output feature map then has dimensions $(K, H', W')$, where

$$H' = \frac{H + 2P_h - kH}{S_h} + 1, \qquad W' = \frac{W + 2P_w - kW}{S_w} + 1$$

$$\text{multi-add} = H' \times W' \times C_{in} \times K \times kH \times kW / g$$
Note that convolution layers in popular models no longer use a bias term (the batch normalization that follows makes it redundant), so only the weight multiply-adds are counted.
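As a quick sanity check of the formula, the snippet below plugs in the stem convolution of a standard ResNet ($7 \times 7$ kernel, stride 2, padding 3, $3 \to 64$ channels, $224 \times 224$ input); the layer is just an illustrative example, not part of the original utility:

import torch
import torch.nn as nn

# Illustrative example: ResNet's stem convolution.
conv = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
x = torch.zeros(1, 3, 224, 224)

out_h = (224 + 2 * 3 - 7) // 2 + 1          # H' = 112
out_w = (224 + 2 * 3 - 7) // 2 + 1          # W' = 112
multi_add = out_h * out_w * 3 * 64 * 7 * 7  # g = 1, so no division needed
print(multi_add)                            # 118013952, i.e. ~118M multi-adds

# Cross-check: output elements times the per-element dot-product length.
out = conv(x)
assert multi_add == out.numel() * (conv.in_channels // conv.groups) * 7 * 7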
Average pooling layers [AvgPool2d, AdaptiveAvgPool2d]
AvgPool2d
Let the input feature map have dimensions $(C_{in}, H, W)$, the pooling kernel size be $(kH, kW)$, the padding $(P_h, P_w)$, and the stride $(S_h, S_w)$. The output feature map has dimensions $(C_{in}, H', W')$, where $H'$ and $W'$ follow the same rule as for the 2D convolution.

$$\text{multi-add} = C_{in} \times kH \times kW \times H' \times W'$$
AdaptiveAvgPool2d
$$\text{multi-add} = C_{in} \times H \times W$$
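For example, the global average pooling in ResNet-18 (an AdaptiveAvgPool2d with $1 \times 1$ output) reduces the final $512 \times 7 \times 7$ feature map for $512 \times 7 \times 7 = 25{,}088$ multi-adds; each input element is accumulated exactly once, regardless of the output size.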
Linear layer [Linear]
Let the weight have dimensions $O \times I$; the bias then has dimension $O$:

$$\text{multi-add} = O \times I + O = O(I + 1)$$
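For instance, ResNet-18's final classifier maps $I = 512$ features to $O = 1000$ classes, which costs $1000 \times (512 + 1) = 513{,}000$ multi-adds.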

Recursively traversing leaf modules in PyTorch to accumulate the multi-adds

count_ops = 0

def measure_layer(layer, x, multi_add=1):
    # e.g. "Conv2d(3, 64, kernel_size=(7, 7), ...)" -> "Conv2d"
    type_name = str(layer)[:str(layer).find('(')].strip()
    print(type_name)
    if type_name in ['Conv2d']:
        out_h = int((x.size()[2] + 2 * layer.padding[0] - layer.kernel_size[0]) //
                    layer.stride[0] + 1)
        out_w = int((x.size()[3] + 2 * layer.padding[1] - layer.kernel_size[1]) //
                    layer.stride[1] + 1)
        delta_ops = layer.in_channels * layer.out_channels * layer.kernel_size[0] *  \
                layer.kernel_size[1] * out_h * out_w // layer.groups * multi_add

    ### ops_nonlinearity
    elif type_name in ['ReLU']:
        delta_ops = x.numel()

    ### ops_pooling
    elif type_name in ['AvgPool2d']:
        # assumes kernel_size, padding, and stride were given as single ints
        in_h, in_w = x.size()[2], x.size()[3]
        kernel_ops = layer.kernel_size * layer.kernel_size
        out_h = int((in_h + 2 * layer.padding - layer.kernel_size) // layer.stride + 1)
        out_w = int((in_w + 2 * layer.padding - layer.kernel_size) // layer.stride + 1)
        delta_ops = x.size()[1] * out_w * out_h * kernel_ops

    elif type_name in ['AdaptiveAvgPool2d']:
        delta_ops = x.numel()

    ### ops_linear
    elif type_name in ['Linear']:
        weight_ops = layer.weight.numel() * multi_add
        bias_ops = layer.bias.numel() if layer.bias is not None else 0
        delta_ops = weight_ops + bias_ops

    elif type_name in ['BatchNorm2d']:
        # one op per element to normalize, one more to scale and shift
        normalize_ops = x.numel()
        scale_shift = normalize_ops
        delta_ops = normalize_ops + scale_shift

    ### ops_nothing
    elif type_name in ['Dropout2d', 'DropChannel', 'Dropout']:
        delta_ops = 0

    ### unknown layer type
    else:
        raise TypeError('unknown layer type: %s' % type_name)

    global count_ops
    count_ops += delta_ops
    return

def is_leaf(module):
    return sum(1 for x in module.children()) == 0

# Decide whether a module is a node whose FLOPs should be measured
def should_measure(module):
    # residual blocks in the code may define an empty Sequential as a shortcut
    if str(module).startswith('Sequential'):
        return False
    if is_leaf(module):
        return True
    return False

def measure_model(model, shape=(1, 3, 224, 224)):
    global count_ops
    count_ops = 0  # reset so that repeated calls do not accumulate
    data = torch.zeros(shape)

    # fold the FLOPs computation into each module's forward function
    def new_forward(m):
        def lambda_forward(x):
            measure_layer(m, x)
            return m.old_forward(x)
        return lambda_forward

    def modify_forward(model):
        for child in model.children():
            if should_measure(child):
                # stash the original forward in an old_forward attribute so
                # it can be restored once the FLOPs measurement is done
                child.old_forward = child.forward
                child.forward = new_forward(child)
            else:
                modify_forward(child)

    def restore_forward(model):
        for child in model.children():
            # restore the modified forward functions
            if is_leaf(child) and hasattr(child, 'old_forward'):
                child.forward = child.old_forward
                child.old_forward = None
            else:
                restore_forward(child)

    modify_forward(model)
    # the forward pass updates the global variable count_ops
    model.forward(data)
    restore_forward(model)

    return count_ops

if __name__ == '__main__':
    net = ResNet18()  # assumes a ResNet-18 implementation is defined elsewhere
    print(measure_model(net))
    # ≈1.8G multi-adds, consistent with the original paper
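If no ResNet18 implementation is at hand, a minimal sketch like the hypothetical toy model below (invented here purely for illustration) exercises the Conv2d, ReLU, AvgPool2d, and Linear branches end to end:

import torch
import torch.nn as nn

# Hypothetical toy network; it only contains layer types that
# measure_layer knows how to count.
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False)
        self.relu = nn.ReLU()
        self.pool = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
        self.fc = nn.Linear(8 * 16 * 16, 10)

    def forward(self, x):
        x = self.pool(self.relu(self.conv(x)))
        return self.fc(x.view(x.size(0), -1))

print(measure_model(TinyNet(), shape=(1, 3, 32, 32)))
# 221184 (conv) + 8192 (relu) + 8192 (pool) + 20490 (fc) = 258058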
