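TensorFlow is less convenient than PyTorch in many respects. Even basic information such as a model's FLOPs (the number of floating-point operations) and parameter count can take a long time to find. When people compare models, they usually look at efficiency as well as accuracy or precision. Because everyone's hardware is different, raw runtimes are hard to compare, so the common practice is to compare FLOPs and parameter counts instead.

I put the year in the title to save you time. Many posts describe TensorFlow 1.x approaches, which are clumsy and often no longer work; most of them are three or four years old. I tested the methods below myself recently, and they work.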
First, the model. I'll use a fairly common one: a small AlexNet-style network for 32×32 single-channel inputs.
```python
from tensorflow.keras.layers import (Activation, BatchNormalization, Conv2D,
                                     Dense, Dropout, Flatten, Input,
                                     MaxPooling2D)
from tensorflow.keras.models import Model


def Alexnet32():
    inputs1 = Input(shape=(32, 32, 1))
    conv1 = Conv2D(filters=16, kernel_size=3)(inputs1)
    BN1 = BatchNormalization()(conv1)
    act1 = Activation('relu')(BN1)
    pool1 = MaxPooling2D(pool_size=3, strides=1)(act1)
    conv4 = Conv2D(filters=32, kernel_size=3, padding='same')(pool1)
    BN2 = BatchNormalization()(conv4)
    act2 = Activation('relu')(BN2)
    pool2 = MaxPooling2D(pool_size=3, strides=1)(act2)
    conv5 = Conv2D(filters=128, kernel_size=3, padding='same',
                   activation='relu')(pool2)
    conv6 = Conv2D(filters=128, kernel_size=3, padding='same',
                   activation='relu')(conv5)
    conv7 = Conv2D(filters=128, kernel_size=3, strides=2,
                   activation='relu')(conv6)
    BN3 = BatchNormalization()(conv7)
    act3 = Activation('relu')(BN3)
    pool3 = MaxPooling2D(pool_size=3, strides=1)(act3)
    flat1 = Flatten()(pool3)
    dense1 = Dense(300)(flat1)
    BN4 = BatchNormalization()(dense1)
    drop1 = Dropout(0.2)(BN4)
    outputs = Dense(10, activation='softmax')(drop1)
    model = Model(inputs=inputs1, outputs=outputs)
    # model.summary()  # print the model structure
    return model
```
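The output shapes that `model.summary()` reports for this network follow from two rules that Keras applies to Conv2D and MaxPooling2D. With `padding='valid'` (the default), the size along one axis is floor((n - k) / s) + 1. With `padding='same'`, it is ceil(n / s).

The sketch below traces the 32-pixel side of the input through the spatial layers of `Alexnet32()`. BatchNormalization, Activation and Dropout do not change shapes, so they are left out. `conv_out` and `trace_shapes` are hypothetical helpers written for this illustration, not Keras APIs.

```python
import math
from typing import Dict


def conv_out(size: int, kernel: int, stride: int = 1, padding: str = "valid") -> int:
    """Output length along one spatial axis for Conv2D / MaxPooling2D (hypothetical helper)."""
    if padding == "same":
        # 'same' pads the input so that only the stride shrinks the output
        return math.ceil(size / stride)
    # 'valid' (Keras default): the window must fit entirely inside the input
    return (size - kernel) // stride + 1


# (layer name, kernel size, stride, padding) for the spatial layers of Alexnet32()
SPATIAL_LAYERS = [
    ("conv2d", 3, 1, "valid"),
    ("max_pooling2d", 3, 1, "valid"),
    ("conv2d_1", 3, 1, "same"),
    ("max_pooling2d_1", 3, 1, "valid"),
    ("conv2d_2", 3, 1, "same"),
    ("conv2d_3", 3, 1, "same"),
    ("conv2d_4", 3, 2, "valid"),
    ("max_pooling2d_2", 3, 1, "valid"),
]


def trace_shapes(size: int = 32) -> Dict[str, int]:
    """Return the output side length after each spatial layer (hypothetical helper)."""
    sides = {}
    for name, kernel, stride, padding in SPATIAL_LAYERS:
        size = conv_out(size, kernel, stride, padding)
        sides[name] = size
    return sides


sides = trace_shapes()
flatten_units = sides["max_pooling2d_2"] ** 2 * 128  # the last block has 128 channels
print(sides)
print(flatten_units)
```

The 12,800 flattened units all feed `dense`. That is why this single layer ends up holding most of the model's parameters.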
The first way to get FLOPs is the keras-flops package. Install it with `pip install keras-flops`, then import it and call it directly:
```python
from keras_flops import get_flops

# Count FLOPs for a single input sample (batch_size=1)
flops = get_flops(Alexnet32(), batch_size=1)
print(f"FLOPS: {flops / 10 ** 6:.03} M")
```
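The output is fairly comprehensive and includes numbers for every operation in each layer. In the report, "m" means million and "k" means thousand: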
```text
==================Model Analysis Report======================
Doc:
scope: The nodes in the model graph are organized by their names, which is hierarchical like filesystem.
flops: Number of float operations. Note: Please read the implementation for the math behind it.
Profile:
node name | # float_ops
_TFProfRoot (--/307.61m flops)
model/conv2d_3/Conv2D (199.36m/199.36m flops)
model/conv2d_2/Conv2D (49.84m/49.84m flops)
model/conv2d_4/Conv2D (42.47m/42.47m flops)
model/dense/MatMul (7.68m/7.68m flops)
model/conv2d_1/Conv2D (7.23m/7.23m flops)
model/conv2d/Conv2D (259.20k/259.20k flops)
model/max_pooling2d_1/MaxPool (194.69k/194.69k flops)
model/max_pooling2d_2/MaxPool (115.20k/115.20k flops)
model/max_pooling2d/MaxPool (112.90k/112.90k flops)
model/conv2d_2/BiasAdd (86.53k/86.53k flops)
model/conv2d_3/BiasAdd (86.53k/86.53k flops)
model/batch_normalization_1/FusedBatchNormV3 (50.37k/50.37k flops)
model/batch_normalization_2/FusedBatchNormV3 (37.63k/37.63k flops)
model/batch_normalization/FusedBatchNormV3 (28.90k/28.90k flops)
model/conv2d_1/BiasAdd (25.09k/25.09k flops)
model/conv2d_4/BiasAdd (18.43k/18.43k flops)
model/conv2d/BiasAdd (14.40k/14.40k flops)
model/dense_1/MatMul (6.00k/6.00k flops)
model/batch_normalization_3/batchnorm/Rsqrt (600/600 flops)
model/batch_normalization_3/batchnorm/add (300/300 flops)
model/batch_normalization_3/batchnorm/add_1 (300/300 flops)
model/batch_normalization_3/batchnorm/mul (300/300 flops)
model/batch_normalization_3/batchnorm/mul_1 (300/300 flops)
model/batch_normalization_3/batchnorm/mul_2 (300/300 flops)
model/batch_normalization_3/batchnorm/sub (300/300 flops)
model/dense/BiasAdd (300/300 flops)
model/dense_1/Softmax (50/50 flops)
model/dense_1/BiasAdd (10/10 flops)
======================End of Report==========================
FLOPS: 3.08e+02 M
```
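You can reproduce the per-op numbers in this report by hand. The figures themselves show that TensorFlow's profiler counts one multiply-add as 2 FLOPs. The costs then work out as follows:

- **Conv2D:** a k×k kernel producing an H_out×W_out×C_out output from C_in input channels costs 2·H_out·W_out·C_out·k²·C_in.
- **Dense (MatMul):** 2·n_in·n_out.
- **MaxPool:** counted as H_out·W_out·C·k².

The sketch below plugs in the output shapes of `Alexnet32()` for batch size 1. The function names are hypothetical helpers, not part of TensorFlow.

```python
def conv2d_flops(h_out: int, w_out: int, c_in: int, c_out: int, k: int) -> int:
    # Each output element needs k*k*c_in multiply-adds, counted as 2 FLOPs each.
    return 2 * h_out * w_out * c_out * k * k * c_in


def dense_flops(n_in: int, n_out: int) -> int:
    # MatMul of a 1 x n_in input with an n_in x n_out kernel (batch size 1).
    return 2 * n_in * n_out


def maxpool_flops(h_out: int, w_out: int, c: int, k: int) -> int:
    # Counted as one op per element of every k x k window, as in the report.
    return h_out * w_out * c * k * k


# Output shapes of Alexnet32() layers, batch size 1.
conv_and_matmul = {
    "conv2d/Conv2D": conv2d_flops(30, 30, 1, 16, 3),
    "conv2d_1/Conv2D": conv2d_flops(28, 28, 16, 32, 3),
    "conv2d_2/Conv2D": conv2d_flops(26, 26, 32, 128, 3),
    "conv2d_3/Conv2D": conv2d_flops(26, 26, 128, 128, 3),
    "conv2d_4/Conv2D": conv2d_flops(12, 12, 128, 128, 3),
    "dense/MatMul": dense_flops(12800, 300),
    "dense_1/MatMul": dense_flops(300, 10),
}
pooling = {
    "max_pooling2d/MaxPool": maxpool_flops(28, 28, 16, 3),
    "max_pooling2d_1/MaxPool": maxpool_flops(26, 26, 32, 3),
    "max_pooling2d_2/MaxPool": maxpool_flops(10, 10, 128, 3),
}

for name, count in {**conv_and_matmul, **pooling}.items():
    print(f"{name:26s} {count:>12,d}")
print("Conv2D + MatMul total:", sum(conv_and_matmul.values()))
```

This gives 199,360,512 for `conv2d_3` (the report's 199.36m) and 259,200 for `conv2d` (259.20k). Together, the Conv2D and MatMul ops account for 306,838,512 of the 307.61m total. The BiasAdd, BatchNorm, pooling and Softmax ops contribute well under 1% for this model.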
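The second method reportedly takes a FLOPs-counting function from a newer TensorFlow code base and uses it as is, which is a clever trick: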
```python
import logging
from typing import Any, Dict, Optional, Union

import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import \
    convert_variables_to_constants_v2_as_graph


def try_count_flops(model: Union[tf.Module, tf.keras.Model],
                    inputs_kwargs: Optional[Dict[str, Any]] = None,
                    output_path: Optional[str] = None):
    """Counts and returns model FLOPs.

    Args:
      model: A model instance.
      inputs_kwargs: An optional dictionary of argument pairs specifying inputs'
        shape specifications to getting corresponding concrete function.
      output_path: A file path to write the profiling results to.

    Returns:
      The model's FLOPs.
    """
    if hasattr(model, 'inputs'):
        try:
            # Get input shape and set batch size to 1.
            if model.inputs:
                inputs = [
                    tf.TensorSpec([1] + inp.shape[1:], inp.dtype)
                    for inp in model.inputs
                ]
                concrete_func = tf.function(model).get_concrete_function(inputs)
            # If model.inputs is invalid, try to use the input to get concrete
            # function for model.call (subclass model).
            else:
                concrete_func = tf.function(model.call).get_concrete_function(
                    **(inputs_kwargs or {}))
            # Freeze variables into constants so the profiler sees a static graph.
            frozen_func, _ = convert_variables_to_constants_v2_as_graph(concrete_func)

            # Calculate FLOPs.
            run_meta = tf.compat.v1.RunMetadata()
            opts = tf.compat.v1.profiler.ProfileOptionBuilder.float_operation()
            if output_path is not None:
                opts['output'] = f'file:outfile={output_path}'
            else:
                opts['output'] = 'none'  # suppress the per-op report
            flops = tf.compat.v1.profiler.profile(
                graph=frozen_func.graph, run_meta=run_meta, options=opts)
            return flops.total_float_ops
        except Exception as e:  # pylint: disable=broad-except
            logging.info(
                'Failed to count model FLOPs with error %s, because the build() '
                'methods in keras layers were not called. This is probably because '
                'the model was not fed any input, e.g., the max train step already '
                'reached before this run.', e)
            return None
    return None
```
```python
flops = try_count_flops(Alexnet32())
print(flops / 1000000, "M Flops")
```
This version prints only the total FLOPs:
```text
307.611928 M Flops
```
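The two methods print 3.08e+02 M and 307.611928 M, but the numbers only look different. Assume both read the same profiler total of 307,611,928 FLOPs, the value behind the report's 307.61m and the 307.611928 M above. Then the gap is purely formatting.

The `:.03` spec in the keras-flops print has no presentation type, so Python keeps 3 significant digits. Because 307.6 already has as many integer digits as that precision, Python switches to scientific notation. A fixed-point spec such as `:.2f` is easier to read. The variable names below are illustrative only.

```python
total_float_ops = 307_611_928  # total reported by the profiler for Alexnet32()

# Format used in the keras-flops example: no presentation type, precision 3.
method1 = f"FLOPS: {total_float_ops / 10 ** 6:.03} M"
# Same text as print(flops / 1000000, "M Flops") in the try_count_flops example.
method2 = f"{total_float_ops / 1000000} M Flops"
# Fixed-point alternative with two decimals.
readable = f"{total_float_ops / 1e6:.2f} M FLOPs"

print(method1)
print(method2)
print(readable)
```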
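For the parameter count, most of what I found was written for TensorFlow 1.x. Those approaches would need compatibility workarounds, which could easily cause other problems. I tried a few and couldn't get them working quickly, and it isn't worth spending much time on. Keras already has a built-in way that also shows more detail: `model.summary()`. You can either uncomment the `model.summary()` line inside `Alexnet32()` and call `Alexnet32()`, or call `Alexnet32().summary()` directly: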
```text
Model: "model"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_1 (InputLayer) [(None, 32, 32, 1)] 0
_________________________________________________________________
conv2d (Conv2D) (None, 30, 30, 16) 160
_________________________________________________________________
batch_normalization (BatchNo (None, 30, 30, 16) 64
_________________________________________________________________
activation (Activation) (None, 30, 30, 16) 0
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 28, 28, 16) 0
_________________________________________________________________
conv2d_1 (Conv2D) (None, 28, 28, 32) 4640
_________________________________________________________________
batch_normalization_1 (Batch (None, 28, 28, 32) 128
_________________________________________________________________
activation_1 (Activation) (None, 28, 28, 32) 0
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 26, 26, 32) 0
_________________________________________________________________
conv2d_2 (Conv2D) (None, 26, 26, 128) 36992
_________________________________________________________________
conv2d_3 (Conv2D) (None, 26, 26, 128) 147584
_________________________________________________________________
conv2d_4 (Conv2D) (None, 12, 12, 128) 147584
_________________________________________________________________
batch_normalization_2 (Batch (None, 12, 12, 128) 512
_________________________________________________________________
activation_2 (Activation) (None, 12, 12, 128) 0
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 10, 10, 128) 0
_________________________________________________________________
flatten (Flatten) (None, 12800) 0
_________________________________________________________________
dense (Dense) (None, 300) 3840300
_________________________________________________________________
batch_normalization_3 (Batch (None, 300) 1200
_________________________________________________________________
dropout (Dropout) (None, 300) 0
_________________________________________________________________
dense_1 (Dense) (None, 10) 3010
=================================================================
Total params: 4,182,174
Trainable params: 4,181,222
Non-trainable params: 952
```
The total parameter count is 4,182,174: 4,181,222 trainable plus 952 non-trainable. The method is simple and intuitive.
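You can also check these totals by hand. The per-layer counts are:

- **Conv2D:** (k·k·C_in + 1)·C_out weights, including one bias per filter.
- **Dense:** (n_in + 1)·n_out.
- **BatchNormalization:** 4·C. Gamma and beta are trainable, while the moving mean and moving variance are not. These moving statistics are where the 952 non-trainable parameters come from.

The helper names below are hypothetical. In Keras itself, `model.count_params()` returns the total without printing the table.

```python
def conv2d_params(k: int, c_in: int, c_out: int) -> int:
    return (k * k * c_in + 1) * c_out  # kernel weights + one bias per filter


def dense_params(n_in: int, n_out: int) -> int:
    return (n_in + 1) * n_out  # weight matrix + bias vector


# (kernel size, input channels, filters) for the five Conv2D layers of Alexnet32()
conv_layers = [(3, 1, 16), (3, 16, 32), (3, 32, 128), (3, 128, 128), (3, 128, 128)]
dense_layers = [(12800, 300), (300, 10)]
bn_channels = [16, 32, 128, 300]  # features seen by each BatchNormalization layer

weights = (sum(conv2d_params(*c) for c in conv_layers)
           + sum(dense_params(*d) for d in dense_layers))
bn_trainable = sum(2 * c for c in bn_channels)      # gamma, beta
bn_non_trainable = sum(2 * c for c in bn_channels)  # moving_mean, moving_variance

total = weights + bn_trainable + bn_non_trainable
trainable = weights + bn_trainable
print(f"Total params: {total:,}")
print(f"Trainable params: {trainable:,}")
print(f"Non-trainable params: {bn_non_trainable:,}")
```

The `dense` layer alone contributes 3,840,300 of the 4,182,174 parameters. Its 12,800 flattened inputs make it the first place to look if the model needs to be smaller.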