对于按类别分文件夹存储的图片,PyTorch 可以用 ImageFolder 直接读取,非常方便;但是当需要把训练集进一步划分为训练集和验证集时,ImageFolder 本身就不太能胜任了。
参考《将分类存储的图片切分为训练集、验证集和测试集(PyTorch实现)》一文,可以把数据集划分为训练集和验证集;我根据自己的数据集和需求对代码做了小改动。
原文是针对所有类别样本数目都一样写的,我改成了当每个类别样本数目不一样的时候怎么按比例划分。
from torchvision.datasets import ImageFolder
from PIL import Image
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision import transforms
# Per-channel mean/std normalization shared by the training and validation
# pipelines (the values documented by torchvision for ImageNet-pretrained models).
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# Training pipeline: random crop + horizontal flip for augmentation.
train_transformer_ImageNet = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    # Normalize here as well so that training inputs have the same value
    # distribution as the validation inputs below.
    normalize,
])

# Validation pipeline: deterministic resize + center crop, no augmentation.
val_transformer_ImageNet = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize,
])
class MyDataset(Dataset):
    """Dataset that loads images from a list of file paths with matching labels.

    Args:
        filenames: Sequence of image file paths.
        labels: Sequence of labels; ``labels[i]`` belongs to ``filenames[i]``.
        transform: Optional callable applied to each loaded RGB PIL image.
            When ``None`` the converted PIL image is returned unchanged.

    Raises:
        ValueError: If ``filenames`` and ``labels`` differ in length.
    """

    def __init__(self, filenames, labels, transform=None):
        # Fail early instead of raising IndexError (or silently ignoring
        # surplus labels) in the middle of an epoch.
        if len(filenames) != len(labels):
            raise ValueError(
                "filenames and labels must have the same length, got "
                f"{len(filenames)} and {len(labels)}"
            )
        self.filenames = filenames
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of samples."""
        return len(self.filenames)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the sample at position ``idx``."""
        # convert() loads the pixel data into a new image, so the file can be
        # closed right away instead of staying open until garbage collection.
        with Image.open(self.filenames[idx]) as img:
            image = img.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image, self.labels[idx]
def split_Train_Val_Data(data_dir, ratio, batch_size=8):
    """Split a class-per-folder image directory into train/val DataLoaders.

    The split is done separately for each class, so every class contributes
    to both subsets in roughly the requested proportion even when the classes
    contain different numbers of images. Within a class, samples are taken in
    the order ``ImageFolder`` lists them: the leading part goes to training,
    the part after it to validation.

    Args:
        data_dir: Directory one level above the per-class folders.
        ratio: Sequence whose first two entries are the train and validation
            fractions. When those two sum to 1, the validation subset
            receives every sample not used for training, so none are lost to
            integer truncation. Otherwise each count is truncated
            independently and the remaining samples are left unused.
        batch_size: Batch size used for both loaders.

    Returns:
        Tuple ``(train_dataloader, val_dataloader)``. The training loader
        shuffles its samples; the validation loader does not.
    """
    dataset = ImageFolder(data_dir)

    # Group the file paths by class index.
    paths_by_class = [[] for _ in dataset.classes]
    for path, label in dataset.samples:
        paths_by_class[label].append(path)

    # Tolerate float noise such as 0.7 + 0.3 != 1.0 exactly.
    covers_everything = abs(ratio[0] + ratio[1] - 1.0) < 1e-9

    train_inputs, val_inputs = [], []
    train_labels, val_labels = [], []
    for label, paths in enumerate(paths_by_class):
        num_train = int(len(paths) * ratio[0])
        if covers_everything:
            # Truncating both counts separately would drop samples
            # (e.g. 7 images at 0.8/0.2 -> 5 + 1); give the rest to val.
            val_end = len(paths)
        else:
            val_end = num_train + int(len(paths) * ratio[1])

        train_paths = paths[:num_train]
        val_paths = paths[num_train:val_end]

        train_inputs.extend(str(p) for p in train_paths)
        train_labels.extend([label] * len(train_paths))
        val_inputs.extend(str(p) for p in val_paths)
        val_labels.extend([label] * len(val_paths))

    train_dataloader = DataLoader(
        MyDataset(train_inputs, train_labels, train_transformer_ImageNet),
        batch_size=batch_size, shuffle=True)
    val_dataloader = DataLoader(
        MyDataset(val_inputs, val_labels, val_transformer_ImageNet),
        batch_size=batch_size, shuffle=False)
    return train_dataloader, val_dataloader
def data_loader(dataset_dir, batch_size):
    """Build a shuffling DataLoader over the images found under ``dataset_dir``.

    Images are read through ``ImageFolder``; each one is resized with
    ``Resize(256)``, center-cropped to 224 and converted to a tensor.
    No normalization is applied.

    Args:
        dataset_dir: Directory passed to ``ImageFolder``.
        batch_size: Number of samples per batch.

    Returns:
        A ``DataLoader`` with ``shuffle=True``.
    """
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor()])
    folder_dataset = ImageFolder(dataset_dir, transform=preprocess)
    return DataLoader(folder_dataset, batch_size=batch_size, shuffle=True)
'''
if __name__ == '__main__':
    data_dir = 'D:\\c\\graduation\\train_data\\data_ex'
    train_dataloader, val_dataloader = split_Train_Val_Data(data_dir, [0.8, 0.2])
    for x, y in train_dataloader:
        #print(x)
        print(len(y))
'''