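In most cases, data is divided into training, validation, and test sets. We train the model on the training set, use the validation set to track the model's performance during training, and use the test set to evaluate the final model. The test set is usually kept hidden. To train a model we need at least a training set and a validation set. When only a training set is available, we split it into either two parts (training and validation) or three parts (training, validation, and test).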
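As a minimal sketch of splitting a single training set into three parts, the snippet below uses `torch.utils.data.random_split` on synthetic tensors. The 80/10/10 proportions, the seed, and the variable names are assumptions chosen for illustration, not values from this text.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# synthetic stand-in for a labelled dataset: 1000 grayscale 28x28 images
x = torch.zeros(1000, 1, 28, 28)
y = torch.randint(0, 10, (1000,))
full_ds = TensorDataset(x, y)

# assumed 80/10/10 split; the lengths must sum to len(full_ds)
n_train, n_val = 800, 100
n_test = len(full_ds) - n_train - n_val

# a seeded generator makes the split reproducible
gen = torch.Generator().manual_seed(0)
train_ds, val_ds, test_ds = random_split(
    full_ds, [n_train, n_val, n_test], generator=gen
)

print(len(train_ds), len(val_ds), len(test_ds))
```

Each returned part is a `Subset` that indexes into the original dataset, so no image data is copied.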
PyTorch's torchvision package provides several popular datasets.
We import the MNIST dataset from torchvision:
from torchvision import datasets
# path to store data and/or load from
path2data = "./data"
# load training data
train_data = datasets.MNIST(path2data, train=True, download=True)
# extract data and targets
x_train, y_train = train_data.data, train_data.targets
print(x_train.shape)
print(y_train.shape)
# torch.Size([60000, 28, 28])
# torch.Size([60000])
# loading validation data
val_data = datasets.MNIST(path2data, train=False, download=True)
# extract data and targets
x_val, y_val = val_data.data, val_data.targets
print(x_val.shape)
print(y_val.shape)
# torch.Size([10000, 28, 28])
# torch.Size([10000])
# add a dimension to tensor to become B*C*H*W
if len(x_train.shape) == 3:
    x_train = x_train.unsqueeze(1)
print(x_train.shape)
if len(x_val.shape) == 3:
    x_val = x_val.unsqueeze(1)
print(x_val.shape)
# torch.Size([60000, 1, 28, 28])
# torch.Size([10000, 1, 28, 28])
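To make the effect of `unsqueeze(1)` concrete, here is a small sketch on a synthetic tensor (the tensor `x` and its size are made up for illustration). It inserts a channel axis of size 1 and shows that indexing with `None` is an equivalent spelling.

```python
import torch

# synthetic batch shaped like the raw MNIST tensor: B*H*W
x = torch.zeros(5, 28, 28, dtype=torch.uint8)

# insert a channel axis of size 1 at position 1 -> B*C*H*W
x4 = x.unsqueeze(1)
print(x4.shape)

# indexing with None inserts the same axis
x4_alt = x[:, None, :, :]
print(torch.equal(x4, x4_alt))
```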
Now, let's display a few samples:
from torchvision import utils
import matplotlib.pyplot as plt
import numpy as np
def show(img):
    # convert tensor to numpy array
    npimg = img.numpy()
    # convert to H*W*C shape
    npimg_tr = np.transpose(npimg, (1, 2, 0))
    plt.imshow(npimg_tr, interpolation="nearest")
    plt.show()
# make a grid of 40 images, 8 images per row
x_grid = utils.make_grid(x_train[:40], nrow=8, padding=2)
print(x_grid.shape)
# call helper function
show(x_grid)
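Image transformation, commonly used for data augmentation, is a technique for improving model performance. The torchvision package provides the common image transformations.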
from torchvision import transforms
# loading MNIST training dataset
train_data = datasets.MNIST(path2data, train=True, download=True)
# define transformations
data_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=1),
    transforms.RandomVerticalFlip(p=1),
    transforms.ToTensor(),
])
# get a sample image from training dataset
img = train_data[0][0]
# transform sample image
img_tr = data_transform(img)
# convert tensor to numpy array
img_tr_np = img_tr.numpy()
# show original and transformed images
plt.subplot(1,2,1)
plt.imshow(img, cmap="gray")
plt.title("original")
plt.subplot(1,2,2)
plt.imshow(img_tr_np[0], cmap="gray")
plt.title("transformed")
plt.show()
Running this displays the original image and the transformed image side by side.
We can also pass the transformation to the dataset class:
# define transformations
data_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=1),
    transforms.RandomVerticalFlip(p=1),
    transforms.ToTensor(),
])
# Loading MNIST training data with on-the-fly transformations
train_data = datasets.MNIST(path2data, train=True, download=True, transform=data_transform)
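"On-the-fly" means the transformation runs each time a sample is read, not once when the dataset is created. The sketch below illustrates the idea with a hypothetical `TransformedDataset` class written for this example; it is not torchvision's implementation, and the lambda merely stands in for a torchvision transform.

```python
import torch
from torch.utils.data import Dataset


class TransformedDataset(Dataset):
    """Hypothetical dataset that applies `transform` each time an item is read."""

    def __init__(self, images, labels, transform=None):
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img = self.images[idx]
        if self.transform is not None:
            img = self.transform(img)  # applied lazily, per access
        return img, self.labels[idx]


# synthetic data; the lambda stands in for a torchvision transform
images = torch.arange(8, dtype=torch.float32).reshape(2, 1, 2, 2)
labels = torch.tensor([3, 7])
ds = TransformedDataset(
    images, labels, transform=lambda t: torch.flip(t, dims=[-1])
)

img0, label0 = ds[0]
print(img0[0], label0.item())
print(images[0, 0])  # the stored tensor is left unchanged
```

Because the stored data is never modified, a random transform can produce a different variant of the same sample in every epoch.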
If the data is already available as tensors, you can wrap it as a PyTorch dataset using the TensorDataset class:
from torch.utils.data import TensorDataset
# wrap tensors into a dataset
train_ds = TensorDataset(x_train, y_train)
val_ds = TensorDataset(x_val, y_val)
for x, y in train_ds:
    print(x.shape, y.item())
    break
# torch.Size([1, 28, 28]) 5
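A `TensorDataset` can also be indexed directly. The short sketch below, on synthetic tensors chosen for illustration, shows that its length is the size of the first dimension and that indexing returns one (input, target) pair.

```python
import torch
from torch.utils.data import TensorDataset

# synthetic inputs and targets that share the same first dimension
x = torch.zeros(6, 1, 28, 28)
y = torch.arange(6)
ds = TensorDataset(x, y)

# the number of samples equals the size of the first dimension
print(len(ds))

# indexing returns a tuple with one slice from each wrapped tensor
xi, yi = ds[2]
print(xi.shape, yi.item())
```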
To iterate over the data conveniently during training, we can use the DataLoader class to create a data loader, as follows:
from torch.utils.data import DataLoader
# create a data loader from dataset
train_dl = DataLoader(train_ds, batch_size=8)
val_dl = DataLoader(val_ds, batch_size=8)
# iterate over batches
for xb, yb in train_dl:
    print(xb.shape)
    print(yb.shape)
    break
# torch.Size([8, 1, 28, 28])
# torch.Size([8])
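Two `DataLoader` options that often matter in training are `shuffle` and `drop_last`. The sketch below uses a synthetic dataset of 20 samples (an assumption chosen so that the last batch is incomplete) to show how they affect the number and size of batches.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# 20 synthetic samples, so the last batch is incomplete with batch_size=8
ds = TensorDataset(torch.zeros(20, 1, 28, 28), torch.arange(20))

# shuffle=True reshuffles the sample order at every epoch
dl = DataLoader(ds, batch_size=8, shuffle=True)
batch_sizes = [len(yb) for _, yb in dl]
print(len(dl), batch_sizes)

# drop_last=True discards the final incomplete batch
dl_drop = DataLoader(ds, batch_size=8, drop_last=True)
print(len(dl_drop))
```

Shuffling is typically enabled for the training loader only; the validation loader can keep a fixed order.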