Visualizing a PyTorch Network Structure

1. Install pytorchviz

pip install git+https://github.com/szagoruyko/pytorchviz
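torchviz builds its graphs with the Python `graphviz` package, which pip should pull in as a dependency. Drawing an image also needs the Graphviz `dot` executable, and pip does not install that. The helper below is a small sketch for checking both before you continue. The function name `graphviz_ready` is hypothetical and uses only the standard library.

```python
import importlib.util
import shutil


def graphviz_ready() -> dict:
    """Report whether the pieces torchviz needs for rendering are present."""
    return {
        # torchviz itself
        "torchviz": importlib.util.find_spec("torchviz") is not None,
        # Python bindings, normally installed as a torchviz dependency
        "graphviz_python": importlib.util.find_spec("graphviz") is not None,
        # The Graphviz 'dot' executable, which must be installed separately
        "dot_binary": shutil.which("dot") is not None,
    }


if __name__ == "__main__":
    for name, ok in graphviz_ready().items():
        print(f"{name}: {'OK' if ok else 'missing'}")
```

If `dot_binary` reports missing, install Graphviz with your OS package manager (for example `apt install graphviz` or `brew install graphviz`).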

2. Import the torchviz package

import torch
from torchviz import make_dot  # make_dot_from_trace is not needed for this example

3. Visualize

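model = FusionGenerator(3, 3, 16)  # the author's own network; any nn.Module can be used here
x = torch.randn(16, 3, 256, 256)  # dim 1 (3) must match the network's input channels; Variable is unnecessary since PyTorch 0.4
y = model(x)
# make_dot returns a graphviz.Digraph; params labels the weight nodes with their parameter names
dot = make_dot(y, params=dict(model.named_parameters()))
# In Jupyter, evaluating `dot` displays the graph inline. In a script, write it to a file
# (rendering needs the Graphviz "dot" executable):
dot.render("fusion_generator", format="png")
# dot.view()  # alternatively, open the rendered graph in the system viewer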

 
