PyTorch网络及参数查看

1. 查看网络

# 1. 直接打印网络
print(model)
# 2. 打印网络及按名称查看子网络
for name, module in net.named_modules():
    print(name, module)
# 3. 按名称查看子网络
for name, module in net.named_children():
    print(name, module)
# 4. 使用第三方工具库
# 使用 pip install torchsummary
import torch
import torchvision.models as models
from torchsummary import summary

device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 
model = models.resnet18().to(device) 
summary(model, (3,256,256))

2. 查看参数

# 1. 查看所有参数
for param in net.parameters():
    print(param)
# 2. 按名称查看所有参数
for name,parameters in net.named_parameters():
    print(name,':',parameters.size())
# 2. 查看个别网络参数
print(net.fc.weight.item())
print(net.fc.bias.item())

3. 查看计算图

使用tensorboard

你可能感兴趣的:(PyTorch)