# PyTorch 学习笔记 — 池化层 (PyTorch study notes: pooling layers)

import torchvision
from tensorboardX import SummaryWriter
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
import torch

# CIFAR-10 *test* split, stored under ./datasets2 and downloaded on first use.
# Each image is converted to a tensor by transforms.ToTensor().
test_set = torchvision.datasets.CIFAR10(
    root="./datasets2",
    train=False,
    transform=transforms.ToTensor(),
    download=True,
)

# Serve the test split in mini-batches of 64 images.
test_loader = DataLoader(test_set, batch_size=64)


class nn_maxpool(nn.Module):
    """Minimal module wrapping a single ``nn.MaxPool2d`` layer for demos.

    The defaults reproduce the original hard-coded layer: a 3x3 window,
    stride 3, and ``ceil_mode=True``. They are exposed as keyword parameters
    so other pooling configurations can be tried without editing the class.

    Args:
        kernel_size: Size of the pooling window.
        stride: Step between successive pooling windows.
        ceil_mode: If True, the output size is computed with ceil instead of
            floor. A partial window at the right/bottom border (smaller than
            ``kernel_size``) then still yields an output value instead of
            being dropped.
    """

    def __init__(self, kernel_size=3, stride=3, ceil_mode=True) -> None:
        super().__init__()
        self.my_maxpool = nn.MaxPool2d(
            kernel_size=kernel_size,
            stride=stride,
            ceil_mode=ceil_mode,
        )

    def forward(self, x):
        """Return ``x`` max-pooled by the wrapped ``nn.MaxPool2d`` layer.

        Args:
            x: Input tensor accepted by ``nn.MaxPool2d``, i.e. shaped
                (N, C, H, W) or (C, H, W) per the PyTorch docs. It is
                presumably floating point; integer tensors may not be
                supported by max pooling (TODO confirm for your build).
        """
        return self.my_maxpool(x)

# Pooling module with its default configuration (3x3 window, stride 3, ceil_mode=True).
nn_maxpool_true = nn_maxpool()

# Small hand-checkable demo, kept for reference:
# input = torch.tensor([[1, 0, 0],
#                       [1, 1, 1],
#                       [1, 1, 1]], dtype=torch.float)
# # This tensor is only 2-D, but MaxPool2d expects a batched 4-D input
# # (N, C, H, W), so reshape it to a 1x1x3x3 batch first. dtype=float is
# # used presumably because max pooling may reject integer tensors.
# input = torch.reshape(input, (-1, 1, 3, 3))
#
# output = nn_maxpool_true(input)
# print(output)

# Log every test batch and its max-pooled counterpart to TensorBoard under
# ./logs, using the batch index as the global step, so both image grids can
# be compared.
writer = SummaryWriter("logs")
try:
    for step, (imgs, targets) in enumerate(test_loader):
        output = nn_maxpool_true(imgs)
        writer.add_images("input", imgs, step)
        writer.add_images("output", output, step)
finally:
    # Close the writer even if the loop raises part-way through.
    writer.close()

# 你可能感兴趣的 (related tags): python, pytorch, 深度学习 (deep learning)