torch.utils.data.Dataset
torch.utils.data.DataLoader
以猫狗为例实现分类
按照如下图所示建立文件和文件夹,我这里自己准备了20张猫狗图像。
test.txt文件是后面代码生成的,先不用管,cats和dogs里面放上自己的图片,然后通过脚本生成test.txt文件,text.txt的脚本 代码如下:
#!/usr/bin/python
# -*- coding:utf-8 -*-
import os
def generate(dir, label):
files = os.listdir(dir)
files.sort()
listText = open('data/test/test.txt', 'a')
for file in files:
fileType = os.path.split(file)
if fileType[1] == '.txt':
continue
name = "/test/cats/" + file + ' ' + str(int(label)) + '\n'
print(name)
listText.write(name)
listText.close()
def generate1(dir, label):
files = os.listdir(dir)
files.sort()
listText = open('data/test/test.txt', 'a')
for file in files:
fileType = os.path.split(file)
if fileType[1] == '.txt':
continue
name = "/test/dogs/" + file + ' ' + str(int(label)) + '\n'
print(name)
listText.write(name)
listText.close()
outer_path = 'data/test' # 这里是你的图片的目录
if __name__ == '__main__':
i = 0
folderlist = os.listdir(outer_path) # 列举文件夹
for folder in folderlist:
if i == 0:
generate(os.path.join(outer_path, folder), i)
if i == 1:
generate1(os.path.join(outer_path, folder), i)
i += 1
由于就两个文件,此处就直接用两个相同的代码生成(鄙人代码功底不好,凑合着看),生成后的text.txt文件如下样式:前面是路径,后面是对应的标签。
到这里,样本集的收集以及简单归类已经完成啦,下面我们将开始采用pytorch的数据集相关API和类,也就是我们以后要经常用到的dataset和dataloader。
Dataset类的使用: 是一个抽象类,所有的类都应该是此类的子类(也就是说应该继承该类)。 所有的子类都要重写 len(), getitem() 这两个魔术方法(魔术方法在执行程序的时候会自动执行)。
len() 此方法应该提供数据集的大小(容量)
getitem() 此方法应该提供支持下标索方式引访问数据集
dataloader:对dataset获取的数据可以进行打包,打乱,变换操作。
简言之:dataset是获取数据,dataloader是对获取的数据进行变换等操作。
定义mydataset:
class MyDataset(Dataset):
def __init__(self, root_dir, names_file, transform=None):
self.root_dir = root_dir
self.names_file = names_file
self.transform = transform
self.size = 0
self.names_list = []
if not os.path.isfile(self.names_file):
print(self.names_file + 'i does not exist!')
file = open(self.names_file)
for f in file:
self.names_list.append(f)
self.size += 1
def __len__(self):
return self.size
def __getitem__(self, idx):
image_path = self.root_dir + self.names_list[idx].split(' ')[0]
print(image_path)
# image_path = self.names_list[idx].split(' ')[0]
if not os.path.isfile(image_path):
print(image_path + 'you does not exist!')
return None
image = io.imread(image_path) # use skitimage
label = int(self.names_list[idx].split(' ')[1])
sample = {'image': image, 'label': label}
if self.transform:
sample = self.transform(sample)
return sample
定义一个实例化对象:
train_dataset = MyDataset(root_dir='./data/train',
names_file='./data/train/train.txt',
transform=None)
plt.figure()
for (cnt,i) in enumerate(train_dataset):
image = i['image']
label = i['label']
ax = plt.subplot(4, 4, cnt+1)
ax.axis('off')
ax.imshow(image)
ax.set_title('label {}'.format(label))
plt.pause(0.001)
if cnt == 15:
break
注意修改一下自己的路径。
以上并没有用到dataloader这个类,下面使用dataloader对dataset得到的数据集进行变换:
先对图像数据集进行resize和转变为tensor向量:
# 变换Resize
class Resize(object):
def __init__(self, output_size: tuple):
self.output_size = output_size
def __call__(self, sample):
# 图像
image = sample['image']
# 使用skitimage.transform对图像进行缩放
image_new = transform.resize(image, self.output_size)
return {'image': image_new, 'label': sample['label']}
# # 变换ToTensor
class ToTensor(object):
def __call__(self, sample):
image = sample['image']
image_new = np.transpose(image, (2, 0, 1))
return {'image': torch.from_numpy(image_new),
'label': sample['label']}
然后调用dataloader函数对数据集进行处理:
# 对原始的训练数据集进行变换
transformed_trainset = MyDataset(root_dir='./data',
names_file='./data/test/test.txt',
transform=transforms.Compose(
[Resize((512,512)),
ToTensor()]
))
dataloader函数可以完成数据集打乱shuffle,batch,numworks(多线程)
可视化使用dataloader后的代码:
def show_images_batch(sample_batched):
images_batch, labels_batch = \
sample_batched['image'], sample_batched['label']
grid = make_grid(images_batch)
plt.imshow(grid.numpy().transpose(1, 2, 0))
# sample_batch: Tensor , NxCxHxW
plt.figure()
for i_batch, sample_batch in enumerate(trainset_dataloader):
show_images_batch(sample_batch)
plt.axis('off')
plt.ioff()
plt.show()
plt.show()