前面讲过的图片输入方式是从图片的文件夹来读取图片的一种方式。但是必须将类别单独放在一个文件夹。我们现在创建Dataset的子类来进行输入。
import torch
from torch.utils import data
from PIL import Image # pip install pillow
import numpy as np
from torchvision import transforms
import matplotlib.pyplot as plt
%matplotlib inline
import glob # 可以获取某一个条件下的所有的路径
# 自定义输入Dataset类
class MyDataset(data.Dataset):
def __init__(self,imgsPath):
self.imgs_path=imgsPath # 图片路径
def __getitem__(self,index):
return self.imgs_path[index]
def __len__(self):
return len(self.imgs_path)
# 获取所有图片的路径
all_imgs_path = glob.glob(r'E:\Codes\Python\PyTorch\dataset2\*.jpg') # 获取路径下,所以以.jpg结尾的图片路径。在该目录下,是四种天气的所有图片
#for i in range(5):
# print(all_imgs_path[i])
weather_dataset = MyDataset(all_imgs_path)
len(weather_dataset) # 1122。由于My_dataset内部实现了 __len__ 方法,所以可以使用len方法
创建Dataloader:
from torch.utils.data import DataLoader
wh_dl = DataLoader(weather_dataset,batch_size=4)
next(iter(wh_dl)) # 返回一个批次的数据(返回4张图片的路径)。列表形式
获取图片的标签:
我们获取了图片路径后,要获取它对应得标签值。
species = ['cloudy', 'rain', 'shine', 'sunrise']
# 将这4个类别使用数值型进行编码
species_to_idx = dict((c, i) for i, c in enumerate(species)) # cloudy:0,rain:1,....
print(species_to_idx)
idx_to_species = dict((v, k) for k, v in species_to_idx.items()) # 字典的items方法会以元祖形式返回key value对象
idx_to_species
# 提取所有图片的标签
all_labels = []
for img in all_imgs_path: # img就是一张图片的路径
for i, c in enumerate(species): # species = ['cloudy', 'rain', 'shine', 'sunrise']
if c in img:
all_labels.append(i) # 以数值形式代表标签
划分数据集:
index = np.random.permutation(len(all_imgs_path)) # 将所有图片的长度做一个乱序处理
index
all_imgs_path = np.array(all_imgs_path)[index] # 需要先将all_imgs_path转换array才能索引
all_labels = np.array(all_labels)[index]
s = int(len(all_imgs_path)*0.8) # 百分之八十作为训练数据集
train_imgs = all_imgs_path[:s]
train_labels = all_labels[:s]
test_imgs = all_imgs_path[s:]
test_labels = all_labels[s:]
transform = transforms.Compose([
transforms.Resize((96, 96)),
transforms.ToTensor(), # 将图片数据在0~1之间,0维度是channel
])
创建输入:下面要开始创建输入(上面只是演示)。
class Mydataset(data.Dataset):
def __init__(self, img_paths, labels, transform):
self.imgs = img_paths
self.labels = labels
self.transforms = transform
# 对于图片的读取和转换,我们也都放在__getitem__里面,当给一个索引时,返回的是图片对象,而不再是图片的路径了
# 所以要对图片进行读取和转换。
def __getitem__(self, index):
img = self.imgs[index]
label = self.labels[index]
pil_img = Image.open(img)
data = self.transforms(pil_img)
return data, label # 返回图片对象和对应的标签
def __len__(self):
return len(self.imgs)
BATCH_SIZE = 16
weather_dataset = Mydataset(all_imgs_path,all_labels,transform)
# 类型是 torch.utils.Dataset对象
weather_dl = data.DataLoader(weather_dataset,batch_size=BATCH_SIZE,shuffle=True,num_workers=0) # num_workers表示使用多少个进程读取,0表示不管它
# 取出一个批次的数据
imgs_batch,lables_batch = next(iter(weather_dl))
print(imgs_batch.shape) # imgs_batch的shape是:torch.Size([16,3,96,96])
print(lables_batch.shape) # torch.Size([16]))
plt.figure(figsize=(12, 8))
for i, (img, label) in enumerate(zip(imgs_batch[-6:], lables_batch[-6:])):
img = img.permute(1, 2, 0).numpy() # 交换每一维的数据,要将图片的channel(3)放在最后,再转换为ndarray的类型
plt.subplot(2, 3, i+1)
plt.title(idx_to_species.get(label.item())) # label是一个tensor,单个tensor获取标量使用item方法
plt.imshow(img)
创建Dataloader:
通过定义子类的方式创建图片的输入,这种方式不仅可以应用于图片,还可以应用于csv、ndarray数据都可以 。
train_ds = Mydataset(train_imgs,train_labels,transform)
test_ds = Mydataset(test_imgs,test_labels,transform)
train_dl = data.DataLoader(train_ds,batch_size=BATCH_SIZE,shuffle=True)
test_dl = data.DataLoader(test_ds,batch_size=BATCH_SIZE,shuffle=False)
imgs, labels = next(iter(train_dl))
imgs.shape,labels.shape
灵活的使用Dataset类构造输入: 比如train_dl已经创建好了,我要用在tf里面,通道数要放在后面,即将每个批次的数据变为[16, 96, 96,3],通过创建子类的方式可以实现。
class New_dataset(data.Dataset):
def __init__(self, some_dataset): # 输入的是已有的Dataset
self.ds = some_dataset
def __getitem__(self, index):
img, label = self.ds[index] # 返回的是对应的图片对应和标签。返回的是单个图片,没有批次大小,即[3,96,96]
img = img.permute(1, 2, 0) # 通道数变换。
return img, label
def __len__(self):
return len(self.ds)
train_new_dataset = New_dataset(train_ds)
img, label = train_new_dataset[2]
img.shape,label.shape # 现在是:h*w*c