[PyTorch] model

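Print the network structure (without module names). model.modules() walks the model recursively, starting with the model itself, so nested submodules appear both inside their parent's printout and again on their own: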

for ele in model.modules():
    print(ele)
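
If the names are needed as well, named_modules() yields (name, module) pairs for the same modules. A minimal sketch, assuming a small hypothetical nn.Sequential model standing in for `model` (the layer sizes are arbitrary):

```python
import torch.nn as nn

# Hypothetical toy model, used only for illustration.
model = nn.Sequential(
    nn.Linear(4, 8),
    nn.ReLU(),
    nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2)),
)

# modules() yields the model itself first, then every submodule recursively.
unnamed = list(model.modules())

# named_modules() yields the same modules together with their dotted path;
# the root model gets the empty name "".
names = [name for name, _ in model.named_modules()]
for name, module in model.named_modules():
    print(repr(name), type(module).__name__)
```

The nested block shows up as "2", and its children as "2.0" and "2.1", which is the same prefix scheme used for parameter names.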

Print named_parameters() (the loop below only prints frozen parameters, i.e. those with requires_grad=False):

for (name, param) in model.named_parameters():
    if not param.requires_grad:
        print(name, param.data)
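
Parameters have requires_grad=True by default, so on a freshly built model the loop above prints nothing. A minimal sketch of a case where it does print something, assuming a hypothetical three-module model whose first layer is frozen by hand:

```python
import torch.nn as nn

# Hypothetical toy model, used only for illustration.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

# Freeze the first Linear layer: its weight and bias no longer receive gradients.
for param in model[0].parameters():
    param.requires_grad_(False)

# Split parameter names by whether they are still trainable.
frozen = [name for name, param in model.named_parameters() if not param.requires_grad]
trainable = [name for name, param in model.named_parameters() if param.requires_grad]
print("frozen:", frozen)
print("trainable:", trainable)
```

The ReLU at index 1 has no parameters, so only indices 0 and 2 appear in the names.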

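Differences between the main APIs: model.named_parameters(), model.parameters(), and model.state_dict().items(). named_parameters() yields (name, parameter) pairs, parameters() yields the parameters alone, and state_dict() maps names to tensors and also includes buffers (for example the running statistics of BatchNorm), which are not parameters.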
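
A minimal sketch of the contrast, assuming a BatchNorm2d layer as the example module (picked here only because it has buffers as well as parameters):

```python
import torch.nn as nn

model = nn.BatchNorm2d(3)

# named_parameters(): (name, Parameter) pairs, learnable tensors only.
param_names = [name for name, _ in model.named_parameters()]

# parameters(): the same Parameter objects, without names.
params = list(model.parameters())

# state_dict(): name -> tensor, covering parameters AND buffers.
state_keys = list(model.state_dict().keys())

print(param_names)
print(len(params))
print(state_keys)

# By default state_dict() returns detached tensors, so they do not track gradients.
weight_in_state = model.state_dict()["weight"]
print(weight_in_state.requires_grad)
```

In practice parameters() is what gets passed to an optimizer, while state_dict() is what gets saved to and loaded from a checkpoint.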

Print the model's mode (training vs. evaluation):

import torch

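model = torch.nn.BatchNorm2d(3)  # num_features is the channel count C of an (N, C, H, W) input, e.g. (10, 3, 112, 112)
print(model)             # BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
print(model.training)    # True (modules start in training mode)
model.train()            # switch to training mode (already the default)
print(model.training)    # True
model.eval()             # switch to evaluation mode
print(model.training)    # False
model.train()            # back to training mode
print(model.training)    # True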
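
The flag matters because some layers behave differently depending on it. A minimal sketch, assuming a random input batch shifted to a mean of about 5 (the shift and the seed are arbitrary choices for illustration), showing that BatchNorm2d updates its running statistics only in training mode:

```python
import torch

torch.manual_seed(0)
bn = torch.nn.BatchNorm2d(3)
x = torch.randn(10, 3, 112, 112) + 5.0  # per-channel batch mean is about 5

bn.train()
bn(x)  # training mode: running_mean moves from 0 toward the batch mean
mean_after_train = bn.running_mean.clone()

bn.eval()
bn(x)  # evaluation mode: running statistics are used but not updated
mean_after_eval = bn.running_mean.clone()

print(mean_after_train)
print(torch.equal(mean_after_train, mean_after_eval))

# train() / eval() also propagate to every submodule of a container.
net = torch.nn.Sequential(torch.nn.BatchNorm2d(3), torch.nn.Dropout())
net.eval()
all_eval = all(not m.training for m in net.modules())
print(all_eval)
```

With the default momentum of 0.1, one training step moves running_mean from 0 to roughly 0.1 times the batch mean, so about 0.5 per channel here.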
