# MNIST 数据展示代码：可视化 MNIST（深入浅出 PyTorch）

import torch
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim
import torch.nn as nn

import numpy as np
batch_size = 64

# Preprocessing pipeline (Compose chains the listed transforms in order):
# convert each image to a tensor, then normalize it with the given
# mean (0.1307) and standard deviation (0.3081).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
])

# Both splits share the same root directory and are downloaded if missing.
train_dataset = datasets.MNIST(
    root='./dataset/mnist/', train=True, download=True, transform=transform,
)
test_dataset = datasets.MNIST(
    root='./dataset/mnist/', train=False, download=True, transform=transform,
)

# Only the training loader shuffles; the test loader keeps a fixed order,
# so the first test batch is the same on every run.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

import matplotlib.pyplot as plt

# Show the first `num_of_images` images of the first test batch in a grid.
figure = plt.figure()
num_of_images = 60
n_cols = 10  # images per row of the grid

# Take exactly one batch. Unlike `for ...: break`, next(iter(...)) raises
# StopIteration on an empty loader instead of leaving `imgs` undefined.
imgs, targets = next(iter(test_loader))

# The last batch (or a small batch_size) may hold fewer images than
# requested; never index past the end of the batch.
num_of_images = min(num_of_images, len(imgs))
# Ceiling division: enough rows to fit every image (60 images -> 6 x 10).
n_rows = (num_of_images + n_cols - 1) // n_cols

for index in range(num_of_images):
    plt.subplot(n_rows, n_cols, index + 1)  # subplot indices are 1-based
    plt.axis('off')
    img = imgs[index, ...]
    # squeeze() drops size-1 dimensions (presumably the single grayscale
    # channel) so imshow receives a 2-D array; 'gray_r' is reversed gray.
    plt.imshow(img.numpy().squeeze(), cmap='gray_r')
plt.show()

# MNIST 数据展示代码可视化 MNIST（深入浅出 PyTorch）_第1张图片
# MNIST 数据展示代码可视化 MNIST（深入浅出 PyTorch）_第2张图片

# 你可能感兴趣的：(pytorch学习)