PyTorch model visualization

1. Visualizing a model with torchinfo

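import torch
from torch import nn
from torchinfo import summary

# Run on the GPU when one is available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

# A simple MLP for flattened 28x28 images (e.g. MNIST) with 10 output classes
net = nn.Sequential(
    nn.Linear(28 * 28, 400),
    nn.ReLU(),
    nn.Linear(400, 200),
    nn.ReLU(),
    nn.Linear(200, 100),
    nn.ReLU(),
    nn.Linear(100, 10)
).to(device)

# torchinfo's input_size includes the batch dimension: (batch_size, features)
# col_names chooses which columns appear in the printed table
summary(net, input_size=(1, 784), device=device,
        col_names=["input_size",
                   "output_size",
                   "num_params",
                   "kernel_size",
                   "mult_adds",
                   "trainable"])
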
[Figure 1: torchinfo summary output for the model]
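As a quick sanity check on the Param # column: an nn.Linear(in, out) layer holds in * out weights plus out biases, and ReLU has no parameters. The plain-Python sketch below reproduces the per-layer counts for the MLP above. It assumes bias=True (the nn.Linear default); `LAYER_SIZES` and `linear_param_counts` are hypothetical names introduced only for illustration.

```python
# Layer widths of the MLP above: 784 -> 400 -> 200 -> 100 -> 10
LAYER_SIZES = [28 * 28, 400, 200, 100, 10]


def linear_param_counts(sizes):
    """Parameter count of each nn.Linear(in, out), assuming bias=True:
    in * out weights plus out biases. ReLU layers contribute nothing."""
    return [n_in * n_out + n_out for n_in, n_out in zip(sizes, sizes[1:])]


counts = linear_param_counts(LAYER_SIZES)
total = sum(counts)

for i, c in enumerate(counts, start=1):
    print(f"Linear {i}: {c:,} params")
print(f"Total: {total:,} params")
```

The total printed here, 415,310, should agree with the "Total params" line in both the torchinfo and the torchsummary output for this model.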

2. Visualizing a model with torchsummary

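import torch
from torch import nn
from torchsummary import summary

# torchsummary's summary() defaults to device="cuda"; pass "cpu" when no GPU is available
device = "cuda" if torch.cuda.is_available() else "cpu"

net = nn.Sequential(
    nn.Linear(28 * 28, 400),
    nn.ReLU(),
    nn.Linear(400, 200),
    nn.ReLU(),
    nn.Linear(200, 100),
    nn.ReLU(),
    nn.Linear(100, 10)
).to(device)

# Unlike torchinfo, torchsummary's input_size omits the batch dimension
summary(net, input_size=(784,), device=device)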

[Figure 2: torchsummary summary output for the model]
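torchsummary also prints memory estimates below its table. As far as we know, its "Params size (MB)" line is the parameter count times 4 bytes (one float32 per parameter) divided by 1024 ** 2. The sketch below reproduces that arithmetic for this model under the same float32 assumption, which does not hold for half-precision models. `TOTAL_PARAMS` and `params_size_mb` are hypothetical names used only for illustration.

```python
# Parameter count of the MLP above: weights + biases of its four Linear layers
TOTAL_PARAMS = (784 * 400 + 400) + (400 * 200 + 200) + (200 * 100 + 100) + (100 * 10 + 10)

BYTES_PER_FLOAT32 = 4
MB = 1024 ** 2  # sizes are reported in units of 1024 * 1024 bytes


def params_size_mb(num_params, bytes_per_param=BYTES_PER_FLOAT32):
    """Approximate parameter memory, assuming every parameter is stored as float32."""
    return num_params * bytes_per_param / MB


size = params_size_mb(TOTAL_PARAMS)
print(f"Total params: {TOTAL_PARAMS:,}")
print(f"Params size (MB): {size:.2f}")
```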

 
