# 2021-11-14 泉搭建的model (blog post title: "Model built by Quan", 2021-11-14)

# [image placeholder from the original post: image.png]
# -*- coding: utf-8 -*-
"""
Created on 2021-11-14

@Author : QuanWang
          [email protected]
"""

import torch 
import torch.nn as nn
import torchvision
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
# CIFAR-10 test split (train=False), read from ../data without downloading.
# Each image is converted to a tensor and served in batches of 64.
dataset = torchvision.datasets.CIFAR10(
    root="../data",
    train=False,
    transform=torchvision.transforms.ToTensor(),
    download=False,
)
dataloader = DataLoader(dataset, batch_size=64)


class Quan_CIFAR(nn.Module):
    """Small convolutional classifier for 3-channel images such as CIFAR-10.

    The network is three ``Conv2d(kernel_size=5, padding=2) -> MaxPool2d(2)``
    stages followed by ``Flatten`` and two ``Linear`` layers. No activation
    functions are applied between the layers.

    The first ``Linear`` layer expects 1024 flattened features per sample.
    With a 32x32 input, the padded convolutions keep the spatial size and
    each pooling stage halves it (32 -> 16 -> 8 -> 4), which gives
    64 channels * 4 * 4 = 1024.

    Args:
        num_classes: Number of scores produced per sample, i.e. the width of
            the final ``Linear`` layer. Defaults to 10.
    """

    def __init__(self, num_classes: int = 10) -> None:
        super().__init__()
        # All layers live in one Sequential called ``model1``, so parameter
        # names have the form ``model1.<layer index>.weight`` / ``.bias``.
        self.model1 = nn.Sequential(
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the layer stack and return the raw scores.

        Args:
            x: Batch of 3-channel images; 32x32 inputs flatten to the 1024
                features the first ``Linear`` layer requires.

        Returns:
            Tensor with one row per sample and ``num_classes`` columns.
        """
        return self.model1(x)
    
# Smoke test: build the network, print its layers, push a dummy batch through
# it, and record the computation graph for TensorBoard.
model = Quan_CIFAR()
print(model)

# A batch of 64 all-ones "images" (3 channels, 32x32). It is named
# ``dummy_input`` rather than ``input`` so the builtin ``input`` is not shadowed.
dummy_input = torch.ones((64, 3, 32, 32))
output = model(dummy_input)
print(output.shape)

# The context manager closes the writer even if ``add_graph`` raises.
with SummaryWriter('../logs_seq') as writer:
    writer.add_graph(model, dummy_input)

# 你可能感兴趣的:(2021-11-14 泉搭建的model)  -- blog footer: "You may also be interested in"