Go to the PyTorch website (a VPN may be needed in some regions): https://pytorch.org
Docs -> Torchvision
(1) Click Models and pre-trained weights to download pre-trained neural networks.
(2) Click Datasets to download datasets. The categories include Image classification, Image detection or segmentation, Optical Flow, Stereo Matching, Image pairs, Image captioning, Video classification, Video prediction, Base classes for custom datasets, and Transforms v2.
Link: https://pytorch.org/vision/stable/datasets.html
CIFAR10
Dataset link: https://www.cs.toronto.edu/~kriz/cifar.html
Download with a Python script, where root is the directory the dataset is saved to:
import torchvision

train_set = torchvision.datasets.CIFAR10(root="G:\\dataset_CIFAR10",
                                         train=True,
                                         download=True)
test_set = torchvision.datasets.CIFAR10(root="G:\\dataset_CIFAR10",
                                        train=False,
                                        download=True)
Python code (the dataset is already on disk, so download=False):
import torchvision

train_set = torchvision.datasets.CIFAR10(root="G:\\dataset_CIFAR10",
                                         train=True,
                                         download=False)
test_set = torchvision.datasets.CIFAR10(root="G:\\dataset_CIFAR10",
                                        train=False,
                                        download=False)

# Each sample is an (image, target) pair
print("test_set[0]: " + str(test_set[0]))
print("test_set.classes: " + str(test_set.classes))

img, target = test_set[0]
print("img: " + str(img))
print("target: " + str(target))
# target is an integer index into test_set.classes
print("test_set.classes[target]: " + str(test_set.classes[target]))
# Open the image in the default image viewer
img.show()
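Output:
test_set[0]: (<PIL.Image.Image image mode=RGB size=32x32 at 0x...>, 3)
test_set.classes: ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
img: <PIL.Image.Image image mode=RGB size=32x32 at 0x...>
target: 3
test_set.classes[target]: cat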
Process finished with exit code 0
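The output above shows that target is just an integer index into the class list. As a minimal sketch of that lookup in both directions, the snippet below copies the class names from the output into a plain list; the class_to_idx dictionary is built by hand here for illustration. torchvision's CIFAR10 appears to expose a similar mapping under the same attribute name, but treat that as an assumption to check against your installed version.

```python
# Class names copied from the test_set.classes output above.
classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck']

# Forward lookup: integer target -> class name.
target = 3
label = classes[target]

# Reverse lookup: class name -> integer target, built with enumerate.
class_to_idx = {name: idx for idx, name in enumerate(classes)}

print(label)                # cat
print(class_to_idx['cat'])  # 3
```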
Python script:
import torchvision
from torch.utils.tensorboard import SummaryWriter

# Convert each PIL image to a tensor when it is loaded
dataset_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])

train_set = torchvision.datasets.CIFAR10(root="G:\\dataset_CIFAR10",
                                         train=True,
                                         transform=dataset_transform,
                                         download=False)
test_set = torchvision.datasets.CIFAR10(root="G:\\dataset_CIFAR10",
                                        train=False,
                                        transform=dataset_transform,
                                        download=False)

# Log the first 10 test images to TensorBoard
writer = SummaryWriter("p10")
for i in range(10):
    img, target = test_set[i]
    writer.add_image("test_set", img, i)
writer.close()
After the script has run, open the Terminal and enter:
tensorboard --logdir="p10"
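The script above relies on ToTensor to turn each image into something add_image can log. As a rough illustration of what that conversion does, the sketch below uses plain nested lists rather than real tensors: it rearranges a height x width x channel image with values 0 to 255 into channel x height x width with floats in [0, 1]. The function name to_chw_float and the 2 x 2 image are hypothetical and exist only for this example; this is a simplified model, not torchvision's implementation.

```python
def to_chw_float(image_hwc):
    """Convert an H x W x C nested list of 0-255 ints to C x H x W floats in [0, 1]."""
    height = len(image_hwc)
    width = len(image_hwc[0])
    channels = len(image_hwc[0][0])
    return [
        [
            [image_hwc[y][x][c] / 255.0 for x in range(width)]
            for y in range(height)
        ]
        for c in range(channels)
    ]


# A hypothetical 2 x 2 RGB image: each pixel is [R, G, B].
tiny_image = [
    [[255, 0, 0], [0, 255, 0]],
    [[0, 0, 255], [255, 255, 255]],
]

chw = to_chw_float(tiny_image)
shape = (len(chw), len(chw[0]), len(chw[0][0]))
print(shape)   # (3, 2, 2)
print(chw[0])  # red channel: [[1.0, 0.0], [0.0, 1.0]]
```

For a CIFAR10 image the same layout change would give a shape of 3 x 32 x 32, which matches the channel-first layout that add_image expects by default.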
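A Dataset is like a deck of cards, and a DataLoader is like shuffling the deck and drawing hands from it. With shuffle=True the images are reshuffled, so their order differs on every pass; with shuffle=False the order is the same on every pass. num_workers sets how many subprocesses load the data; it is usually left at 0, which loads the data in the main process. On Windows, values greater than 0 can cause errors in some setups, for example when the script's entry point is not guarded by if __name__ == "__main__":.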
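# Example adapted from the torchvision documentation
import torch
import torchvision

imagenet_data = torchvision.datasets.ImageNet('path/to/imagenet_root/')
data_loader = torch.utils.data.DataLoader(imagenet_data,
                                          batch_size=4,
                                          shuffle=True,
                                          num_workers=args.nThreads)  # args: parsed command-line options, not defined here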
Link: https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader
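The card-deck analogy above can be modeled in a few lines of plain Python. The sketch below is a simplified stand-in, not how torch.utils.data.DataLoader is implemented: iterate_batches, deck, and the rng parameter are hypothetical names introduced for illustration. It shows how shuffle changes the order and how drop_last treats an incomplete final batch.

```python
import random


def iterate_batches(dataset, batch_size, shuffle, drop_last=False, rng=None):
    """Yield lists of samples, loosely mimicking one pass (epoch) of a data loader."""
    indices = list(range(len(dataset)))
    if shuffle:
        # Shuffle the "deck" before dealing; each call produces a new order.
        (rng or random).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        batch_indices = indices[start:start + batch_size]
        if drop_last and len(batch_indices) < batch_size:
            break  # discard the final, incomplete batch
        yield [dataset[i] for i in batch_indices]


deck = list(range(10))  # hypothetical stand-in for a dataset of 10 samples

# shuffle=False: every epoch deals the samples in the same order.
epoch_a = list(iterate_batches(deck, batch_size=4, shuffle=False))
epoch_b = list(iterate_batches(deck, batch_size=4, shuffle=False))
print(epoch_a)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]

# shuffle=True: the same samples, but the order generally differs per epoch.
rng = random.Random(0)
shuffled = list(iterate_batches(deck, batch_size=4, shuffle=True, rng=rng))

# drop_last=True: the incomplete last batch [8, 9] is discarded.
dropped = list(iterate_batches(deck, batch_size=4, shuffle=False, drop_last=True))
print(dropped)  # [[0, 1, 2, 3], [4, 5, 6, 7]]
```

A real DataLoader additionally stacks the samples of each batch into tensors and can load them in worker processes; both are left out here to keep the ordering logic visible.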
Python script:
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# Prepare the test dataset
test_data = torchvision.datasets.CIFAR10(root="G:\\Anaconda\\pycharm_pytorch\\learning_project\\dataset_CIFAR10",
                                         train=False,
                                         transform=torchvision.transforms.ToTensor(),
                                         download=False)
test_loader = DataLoader(dataset=test_data, batch_size=4, shuffle=True, num_workers=0, drop_last=False)

# First image in the test dataset and its target
img, target = test_data[0]
print(img.shape)
print(target)

writer = SummaryWriter("dataloader")
step = 1
for data in test_loader:
    # Each batch holds batch_size images and their targets
    imgs, targets = data
    print("imgs.shape: " + str(imgs.shape))
    print("targets: " + str(targets))
    writer.add_images("test_data", imgs, step)
    step = step + 1
writer.close()
Note: the function for logging a batch of several images is add_images(), not add_image().
After the script has run, open the Terminal and enter:
tensorboard --logdir="dataloader"
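The loop above logs one TensorBoard step per batch, so it helps to know how many batches to expect. The sketch below works this out for the 10000 images of the CIFAR10 test split, using the batch size of 4 from the script and 64 as another common choice. The helper batches_per_epoch is a hypothetical name introduced for this example.

```python
import math


def batches_per_epoch(num_samples, batch_size, drop_last=False):
    """Number of batches that one pass over the data produces."""
    if drop_last:
        # Only complete batches are kept.
        return num_samples // batch_size
    # An incomplete final batch still counts as one batch.
    return math.ceil(num_samples / batch_size)


num_test_images = 10000  # size of the CIFAR10 test split

print(batches_per_epoch(num_test_images, 4))                   # 2500
print(batches_per_epoch(num_test_images, 64))                  # 157
print(batches_per_epoch(num_test_images, 64, drop_last=True))  # 156
print(num_test_images % 64)                                    # 16 images in the last batch
```

With batch_size=64 and drop_last=False, the final step therefore shows only 16 images instead of 64.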
If shuffle is set to False, the dataset is not shuffled, so both epochs see the images in the same order. The Python code is as follows:
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# Prepare the test dataset
test_data = torchvision.datasets.CIFAR10(root="G:\\dataset_CIFAR10",
                                         train=False,
                                         transform=torchvision.transforms.ToTensor(),
                                         download=False)
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=False, num_workers=0, drop_last=False)

writer = SummaryWriter("dataloader")
for epoch in range(2):
    # Restart the step counter for each epoch
    step_new = 1
    for data in test_loader:
        imgs, targets = data
        writer.add_images("Epoch-{}".format(epoch), imgs, step_new)
        step_new = step_new + 1
writer.close()
The images are displayed as follows: