Several Ways to View a Network Architecture in PyTorch

1. print(model)

import torch
from torch import nn

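# A small MLP: 4 inputs -> 8 hidden units (ReLU) -> 1 output
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
print(net)  # lists each submodule with its index and constructor arguments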

Output:
Sequential(
  (0): Linear(in_features=4, out_features=8, bias=True)
  (1): ReLU()
  (2): Linear(in_features=8, out_features=1, bias=True)
)
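print(net) shows the layer structure but not how many parameters each layer holds. A minimal sketch using only the standard nn.Module methods named_parameters() and parameters() lists each weight tensor's shape and totals the parameter count; for this network the total should come out to 49, the same figure torchsummary reports. The variable names are illustrative.

```python
from torch import nn

net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))

# named_parameters() yields (name, tensor) pairs; names follow the
# submodule indices shown by print(net), e.g. "0.weight".
param_shapes = {name: tuple(p.shape) for name, p in net.named_parameters()}
for name, shape in param_shapes.items():
    print(f"{name:10s} {shape}")

# Total and trainable parameter counts via numel().
total = sum(p.numel() for p in net.parameters())
trainable = sum(p.numel() for p in net.parameters() if p.requires_grad)
print("Total params:", total)
print("Trainable params:", trainable)
```

ReLU has no parameters, so index 1 does not appear in the listing.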

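2. torchsummary

torchsummary is a third-party package. Unlike print, it runs a forward pass on a dummy input, so it also reports each layer's output shape and parameter count.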

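import torch
from torch import nn
from torchsummary import summary  # third-party: pip install torchsummary

net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))

# input_size excludes the batch dimension; the batch is shown as -1 in the table.
# summary() prints the table itself and returns None, so it is not wrapped in print().
# device="cpu" is needed when the model is on the CPU (the default is "cuda").
summary(net, input_size=(2, 4), device="cpu")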

Output:
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Linear-1                 [-1, 2, 8]              40
              ReLU-2                 [-1, 2, 8]               0
            Linear-3                 [-1, 2, 1]               9
================================================================
Total params: 49
Trainable params: 49
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.00
----------------------------------------------------------------
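To see where numbers like these come from, here is a minimal hand-rolled sketch of the same idea built on PyTorch's standard register_forward_hook: hook every leaf module, run one dummy forward pass, and record each module's output shape and its own parameter count. hook_summary is a hypothetical helper written for illustration; it is not part of torchsummary or PyTorch and is not torchsummary's actual implementation. It reports the real batch size of the dummy input where torchsummary prints -1.

```python
import torch
from torch import nn


def hook_summary(model, input_size, batch_size=2):
    """Hypothetical helper: records (name, output shape, own param count)
    for each leaf module during one forward pass on a zero tensor."""
    rows = []
    handles = []

    def make_hook(name):
        def hook(module, inputs, output):
            # Count only parameters registered directly on this module.
            n_params = sum(p.numel() for p in module.parameters(recurse=False))
            rows.append((name, tuple(output.shape), n_params))
        return hook

    # Leaf modules are those with no children; containers like Sequential are skipped.
    leaves = [m for m in model.modules() if not list(m.children())]
    for idx, module in enumerate(leaves):
        name = f"{type(module).__name__}-{idx + 1}"
        handles.append(module.register_forward_hook(make_hook(name)))

    try:
        with torch.no_grad():
            model(torch.zeros(batch_size, *input_size))
    finally:
        # Always detach the hooks so the model is left unchanged.
        for h in handles:
            h.remove()
    return rows


net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
rows = hook_summary(net, input_size=(2, 4))
for name, shape, n in rows:
    print(f"{name:>12} {str(list(shape)):>15} {n:>8}")
print("Total params:", sum(n for _, _, n in rows))
```

Hooking only leaf modules keeps the table flat, as torchsummary's is; parameters registered directly on a container module would not be counted by this sketch.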
