怎么在python里面加载数据集_Python数据集加载,pytorch

# -*- coding: utf-8 -*-

"""

Created on Sat Jul 18 12:27:15 2020

@author: 陨星落云

"""

#%%

from torchvision import datasets

import torch

#%% 下载数据并加载训练集

path2data = "./data"

train_data = datasets.MNIST(path2data,train=True,download=False)

#%% 抽取训练集数据与标签

x_train,y_train = train_data.data,train_data.targets

print("x_train:",x_train.shape)

print("y_train:",y_train.shape)

#%% 加载验证集

val_data = datasets.MNIST(path2data,train=False,download=False)

#%% 抽取验证集数据与标签

x_val,y_val = val_data.data,val_data.targets

print("x_val:",x_val.shape)

print("y_val:",y_val.shape)

#%% 在张量中增加一个维度

if len(x_train.shape)==3:

x_train = x_train.unsqueeze(1)

print(x_train.shape)

if len(x_val.shape)==3:

x_val = x_val.unsqueeze(1)

print(x_val.shape)

#%% 导入需要的包

from torchvision import utils

import matplotlib.pylab as plt

import numpy as np

#%% 显示图像函数

def show(img):

# tensor转numpy

npimg = img.numpy()

# 转H*W*C

npimg_tr = np.transpose(npimg,(1,2,0))

plt.imshow(npimg_tr,interpolation="nearest")

plt.show()

#%% 批量显示图像

# make a grid of 40 images, 8 images per row

x_grid = utils.make_grid(x_train[:40],nrow=8,padding=2)

print(x_grid.shape)

show(x_grid)

结果:

x_train: torch.Size([60000, 28, 28])

y_train: torch.Size([60000])

x_val: torch.Size([10000, 28, 28])

y_val: torch.Size([10000])

torch.Size([60000, 1, 28, 28])

torch.Size([10000, 1, 28, 28])

torch.Size([3, 152, 242])

怎么在python里面加载数据集_Python数据集加载,pytorch_第1张图片

你可能感兴趣的:(怎么在python里面加载数据集_Python数据集加载,pytorch)