Pytorch网络结构可视化:torchsummary

一、单通道输入网络

单通道输入的情况大致有以下两种结构:

1、结构1:只有一条路可以走

在这里插入图片描述

2、结构2:输入为一条路,输出为多条路

Pytorch网络结构可视化:torchsummary_第1张图片
以上两种的输入只有一个input,这种是经常遇到的情况。

import torch
import torch.nn as nn
from torchsummary import summary

class Network(nn.Module): 
    def __init__(self, channels_img, features_d, num_classes, img_size): 
        super(Network, self).__init__()
        self.img_size = img_size
        self.disc = nn.Conv2d(
            in_channels = channels_img + 1, 
            out_channels = features_d, 
            kernel_size = (4,4)
        )

        # ConditionalGan: 
        self.embed = nn.Embedding(
            num_embeddings = num_classes, 
            embedding_dim = img_size * img_size
        )

   def forward(self, x, labels): 
        embedding = self.embed(labels).view(labels.shape[0], 1, self.img_size, self.img_size)
        x = torch.cat([x, embedding], dim = 1)
        return self.disc(x) 
    
# device: 
device = torch.device("cpu")

# hyperparameter: 
batch_size = 64

# Initialize model: 
model = Network(
    channels_img = 1, 
    features_d = 16, 
    num_classes = 10, 
    img_size = 28).to(device) 

# Print model summary: 
summary(
    model, 
    input_size = [(1, 28, 28), (1, 28, 28)], # MNIST
    batch_size = batch_size
)



参考资料:
PYTORCH网络结构可视化(TORCHSUMMARY+TENSORBOARDX+网络多输通道入情况+多分支网络)
在Python中嵌入层:如何在Torchsummary中正确使用?
pytorch 网络结构可视化方法汇总(三种实现方法详解)

你可能感兴趣的:(#,Pytorch,AI/模型训练,AI/模型调优,pytorch,深度学习,人工智能)