【使用ImageFolder加载数据】

【使用ImageFolder加载数据】

  • 1 使用ImageFolder的前提条件
  • 2 批量加载数据
  • 3 反转类别序号和关键字,绘制样例图

1 使用ImageFolder的前提条件

诸如图片的两分类问题,训练和测试的图片是分别存放好的,如下目录树:

+---test
|   +---airplane
			airplane_561.jpg
			...
			airplane_700.jpg
|   \---lake
			lake_561.jpg
			...
			lake_700.jpg
\---train
    +---airplane
			airplane_001.jpg
			...
			airplane_560.jpg
    \---lake
			lake_001.jpg
			...
			lake_560.jpg

分别存放好后,使用如下语句读取加载:

import torchvision

train_dir = r'2_class/train'
test_dir = r'2_class/test'

from torchvision import transforms

transform = transforms.Compose([
                  transforms.ToTensor(),
                  transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                       std=[0.5, 0.5, 0.5])
])

train_ds = torchvision.datasets.ImageFolder(
               train_dir,
               transform=transform
)


test_ds = torchvision.datasets.ImageFolder(
               test_dir,
               transform=transform
)

print(train_ds.classes)
print(train_ds.class_to_idx)
print(len(train_ds), len(test_ds))

输出如下:

['airplane', 'lake']
{'airplane': 0, 'lake': 1}
1120 280

如果是其它分类问题,也可以按照这种方法加载数据

2 批量加载数据

BATCHSIZE = 16
train_dl = torch.utils.data.DataLoader(
                                       train_ds,
                                       batch_size=BATCHSIZE,
                                       shuffle=True
)
test_dl = torch.utils.data.DataLoader(
                                       test_ds,
                                       batch_size=BATCHSIZE,
)

imgs, labels = next(iter(train_dl))
print(imgs.shape)   #一批次形状
print(imgs[0].shape)#一张图形状

im = imgs[0].permute(1, 2, 0)   #设置通道数为最后一维
print(im.shape)

输出如下:

torch.Size([16, 3, 256, 256])
torch.Size([3, 256, 256])
torch.Size([256, 256, 3])

3 反转类别序号和关键字,绘制样例图

id_to_class = dict((v, k) for k, v in train_ds.class_to_idx.items())
print(id_to_class)

输出如下:

{0: 'airplane', 1: 'lake'}

绘制样例图:

plt.figure(figsize=(12, 8))
for i, (img, label) in enumerate(zip(imgs[:6], labels[:6])):
    img = (img.permute(1, 2, 0).numpy() + 1)/2
    plt.subplot(2, 3, i+1)
    plt.title(id_to_class.get(label.item()))
    plt.xticks([])
    plt.yticks([])
    plt.imshow(img)
    plt.savefig('pics/4-2.jpg', dpi=400)

【使用ImageFolder加载数据】_第1张图片

你可能感兴趣的:(深度学习,pytorch,深度学习)