计算 tensorflow 和 pytorch 模型的浮点运算数

本文主要讨论如何计算 tensorflowpytorch 模型的 FLOPs。如有表述不当之处欢迎批评指正。欢迎任何形式的转载,但请务必注明出处。

目录

  • 1. 引言
  • 2. 模型结构
  • 3. 计算模型的 FLOPs
    • 3.1. tensorflow 1.12.0
    • 3.2. tensorflow 2.3.1
    • 3.3. pytorch 1.10.1+cu102
    • 3.4. 结果对比
  • 4. 总结

1. 引言

FLOPsfloating point operations 的缩写,指浮点运算数,可以用来衡量模型/算法的计算复杂度。本文主要讨论如何在 tensorflow 1.x, tensorflow 2.x 以及 pytorch 中利用相关工具计算对应模型的 FLOPs

2. 模型结构

为了说明方便,先搭建一个简单的神经网络模型,其模型结构以及主要参数如表1 所示。

1 模型结构及主要参数
Layers channels Kernels Strides Units Activation
Conv2D 32 (4,4) (1,2) \ relu
GRU \ \ \ 96 \
Dense \ \ \ 256 sigmoid

tensorflow(实际使用 tensorflow 中的 keras 模块)实现该模型的代码为:

from tensorflow.keras.layers import *
from tensorflow.keras.models import load_model, Model

def test_model_tf(Input_shape):
    # shape: [B, C, T, F]
    main_input = Input(batch_shape=Input_shape, name='main_inputs')
    
    conv = Conv2D(32, kernel_size=(4, 4), strides=(1, 2), activation='relu', data_format='channels_first', name='conv')(main_input)
    
    # shape: [B, T, FC]
    gru = Reshape((conv.shape[2], conv.shape[1] * conv.shape[3]))(conv)
    gru = GRU(units=96, reset_after=True, return_sequences=True, name='gru')(gru)
    
    output = Dense(256, activation='sigmoid', name='output')(gru)
    
    model = Model(inputs=[main_input], outputs=[output])
    
    return model

pytorch 实现该模型的代码为:

import torch
import torch.nn as nn

class test_model_torch(nn.Module):
    def __init__(self):
        super(test_model_torch, self).__init__()

        self.conv2d = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=(4,4), stride=(1,2))
        self.relu = nn.ReLU()

        self.gru = nn.GRU(input_size=4064, hidden_size=96)

        self.fc = nn.Linear(96, 256)
        self.sigmoid = nn.Sigmoid()

    def forward(self, inputs):
        # shape: [B, C, T, F]
        out = self.conv2d(inputs)
        out = self.relu(out)
        
        # shape: [B, T, FC]
        batch, channel, frame, freq = out.size()
        out = torch.reshape(out, (batch, frame, freq*channel))
        out, _ = self.gru(out)
        
        out = self.fc(out)
        out = self.sigmoid(out)

        return out

3. 计算模型的 FLOPs

本节讨论的版本具体为:tensorflow 1.12.0, tensorflow 2.3.1 以及 pytorch 1.10.1+cu102

3.1. tensorflow 1.12.0

tensorflow 1.12.0 环境中,可以使用以下代码计算模型的 FLOPs

import tensorflow as tf
import tensorflow.keras.backend as K

def get_flops(model):
    run_meta = tf.RunMetadata()
    opts = tf.profiler.ProfileOptionBuilder.float_operation()

    flops = tf.profiler.profile(graph=K.get_session().graph,
                                run_meta=run_meta, cmd='op', options=opts)
 
    return flops.total_float_ops

if __name__ == "__main__":
    x = K.random_normal(shape=(1, 1, 100, 256))
    model = test_model_tf(x.shape)
    print('FLOPs of tensorflow 1.12.0:', get_flops(model))

3.2. tensorflow 2.3.1

tensorflow 2.3.1 环境中,可以使用以下代码计算模型的 FLOPs

import tensorflow.compat.v1 as tf
import tensorflow.compat.v1.keras.backend as K
tf.disable_eager_execution()

def get_flops(model):
    run_meta = tf.RunMetadata()
    opts = tf.profiler.ProfileOptionBuilder.float_operation()

    flops = tf.profiler.profile(graph=K.get_session().graph,
                                run_meta=run_meta, cmd='op', options=opts)
 
    return flops.total_float_ops

if __name__ == "__main__":
    x = K.random_normal(shape=(1, 1, 100, 256))
    model = test_model_tf(x.shape)
    print('FLOPs of tensorflow 2.3.1:', get_flops(model))

3.3. pytorch 1.10.1+cu102

pytorch 1.10.1+cu102 环境中,可以使用以下代码计算模型的 FLOPs(需要安装 thop):

import thop

x = torch.randn(1, 1, 100, 256)
model = test_model_torch()
flops, _ = thop.profile(model, inputs=(x,))
print('FLOPs of pytorch 1.10.1:', flops * 2)

需要注意的是,thop 返回的是 MACs (Multiply–Accumulate Operations),其等于 2 2 2 倍的 FLOPs,所以上述代码有乘 2 2 2 操作。

3.4. 结果对比

三者计算出的 FLOPs 分别为:
tensorflow 1.12.0
tf1.12.0
tensorflow 2.3.1
tf2.3.1
pytorch 1.10.1
pytorch 1.10.1
可以看到 tensorflow 1.12.0tensorflow 2.3.1 的结果基本在同一个量级,而与 pytorch 1.10.1 计算出来的相差甚远。但如果将上述模型结构改为只包含第一层 Conv2D,三者计算出来的 FLOPs 却又是一致的。所以推断差异主要来自于 GRUFLOPs。如读者知道其中详情,还请不吝赐教。

4. 总结

本文给出了在 tensorflow 1.x, tensorflow 2.x 以及 pytorch 中利用相关工具计算模型 FLOPs 的方法,但从本文所使用的测试模型来看, tensorflowpytorch 统计出的结果相差甚远。当然,也可以根据网络层的类型及其对应的参数,推导计算出每个网络层所需的 FLOPs

你可能感兴趣的:(神经网络,FLOPs,模型计算复杂度,浮点运算数,tensorflow,2.x,tensorflow,2.0)