THOP 是 PyTorch 非常实用的一个第三方库,可以统计模型的 FLOPs 和参数量。使用方法为:
from thop import clever_format
from thop import profile
class YourModule(nn.Module):
# your definition
def count_your_model(model, x, y):
# your rule here
input = torch.randn(1, 3, 224, 224)
flops, params = profile(model, inputs=(input, ),
custom_ops={YourModule: count_your_model})
flops, params = clever_format([flops, params], "%.3f")
profile 函数实现机制是利用了 PyTorch 的 torch.nn.Module.register_forward_hook。
handler_collection = []
if custom_ops is None:
custom_ops = {}
嵌套定义add_hooks
函数,仅作用于网络的叶子节点。
torch.nn.Module.register_buffer 向模块添加持久缓冲区。这通常用于注册模型参数之外的缓冲区。例如,BatchNorm 的running_mean
不是参数,而是持久状态的一部分。可以使用给定名称作为属性访问缓冲区。
torch.numel 返回input
张量中的元素总数。
全局变量 register_hooks 定义了每种 op 对应的钩子函数,具体定义在 count_hooks.py 中。
torch.nn.Module.register_forward_hook 注册模块上的前向挂钩。
每次在 forward() 计算输出后都会调用该钩子。它应该有以下签名:
hook(module, input, output) -> None
钩子不应该修改输入或输出。返回i类型为torch.utils.hooks.RemovableHandle
。
def add_hooks(m):
if len(list(m.children())) > 0:
return
if hasattr(m, "total_ops") or hasattr(m, "total_params"):
logger.warning("Either .total_ops or .total_params is already defined in %s."
"Be careful, it might change your code's behavior." % str(m))
m.register_buffer('total_ops', torch.zeros(1))
m.register_buffer('total_params', torch.zeros(1))
for p in m.parameters():
m.total_params += torch.Tensor([p.numel()])
m_type = type(m)
fn = None
if m_type in custom_ops: # if defined both op maps, use custom_ops to overwrite.
fn = custom_ops[m_type]
elif m_type in register_hooks:
fn = register_hooks[m_type]
if fn is None:
if verbose:
print("THOP has not implemented counting method for ", m)
else:
if verbose:
print("Register FLOP counter for module %s" % str(m))
handler = m.register_forward_hook(fn)
handler_collection.append(handler)
预先获取模型的模式,后面进行恢复。
torch.nn.Module.apply 将fn
递归地应用于每个子模块(由.children()
返回)以及自身。 典型用途包括初始化模型的参数(另请参见 torch-nn-init)。
运行网络。
# original_device = model.parameters().__next__().device
training = model.training
model.eval()
model.apply(add_hooks)
with torch.no_grad():
model(*inputs)
遍历叶子节点,即有效实体,统计计算量和参数量。
total_ops = 0
total_params = 0
for m in model.modules():
if len(list(m.children())) > 0: # skip for non-leaf module
continue
total_ops += m.total_ops
total_params += m.total_params
total_ops = total_ops.item()
total_params = total_params.item()
清空handler_collection
中的元素句柄。
# reset model to original status
model.train(training)
for handler in handler_collection:
handler.remove()
return total_ops, total_params