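Solving any machine learning problem takes considerable effort in preparing the data. PyTorch provides many tools that simplify data loading and make the code more readable. This tutorial shows how to load a custom dataset and apply preprocessing and data augmentation to it.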
Make sure the following packages are installed:
scikit-image: for image I/O and transforms
pandas: for easier CSV parsing
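The torchvision package provides some common datasets and transforms, so in some cases we do not even need to write a custom class. One of the most commonly used datasets in torchvision is ImageFolder. It expects the following layout: one folder per class, containing all images of that class.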
root/ants/xxx.png
root/ants/xxy.jpeg
root/ants/xxz.png
.
.
.
root/bees/123.jpg
root/bees/nsdf3.png
root/bees/asd932_.png
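As a rough illustration of the layout above, the sketch below builds a throwaway directory tree and derives a class-to-index mapping from the sorted subfolder names, which is how ImageFolder assigns labels. The helper names (find_classes_sketch, list_samples_sketch, build_demo_tree) and the empty placeholder files are assumptions made for this example; this is not torchvision's own code.

```python
import os
import tempfile


def find_classes_sketch(root):
    """Map each subfolder of root to an integer label, in sorted name order."""
    classes = sorted(e.name for e in os.scandir(root) if e.is_dir())
    class_to_idx = {name: idx for idx, name in enumerate(classes)}
    return classes, class_to_idx


def list_samples_sketch(root, class_to_idx):
    """Collect (file path, class index) pairs, one per file in each class folder."""
    samples = []
    for name, idx in sorted(class_to_idx.items()):
        folder = os.path.join(root, name)
        for fname in sorted(os.listdir(folder)):
            samples.append((os.path.join(folder, fname), idx))
    return samples


def build_demo_tree():
    """Create a temporary root/ants, root/bees tree with empty placeholder files."""
    root = tempfile.mkdtemp()
    for name, files in {"bees": ["123.jpg"], "ants": ["xxx.png", "xxy.jpeg"]}.items():
        os.makedirs(os.path.join(root, name))
        for fname in files:
            open(os.path.join(root, name, fname), "wb").close()
    return root


demo_root = build_demo_tree()
classes, class_to_idx = find_classes_sketch(demo_root)
samples = list_samples_sketch(demo_root, class_to_idx)
print(classes, class_to_idx, len(samples))
```

Because the names are sorted before indices are assigned, the labels do not depend on the order in which the folders were created.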
The example below uses data_transform, datasets.ImageFolder, and DataLoader:
import torch
from torchvision import transforms, datasets

# Augmentation and normalization applied to every training image
data_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])
# Each subfolder of ./train is treated as one class
train_dataset = datasets.ImageFolder(root='./train',
                                     transform=data_transform)
dataset_loader = torch.utils.data.DataLoader(train_dataset,
                                             batch_size=4, shuffle=True,
                                             num_workers=4)
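To make the ToTensor and Normalize steps concrete, here is the per-channel arithmetic in plain Python: ToTensor scales 8-bit pixel values to the range [0, 1], and Normalize then computes (x - mean) / std for each channel. The function name normalize_pixel is hypothetical, and the sketch handles a single RGB pixel, not a whole tensor.

```python
MEAN = [0.485, 0.456, 0.406]  # per-channel means used in the transform above
STD = [0.229, 0.224, 0.225]   # per-channel standard deviations


def normalize_pixel(rgb_uint8, mean=MEAN, std=STD):
    """Scale an 8-bit RGB pixel to [0, 1], then standardize each channel."""
    scaled = [v / 255.0 for v in rgb_uint8]
    return [(x - m) / s for x, m, s in zip(scaled, mean, std)]


white = normalize_pixel((255, 255, 255))
black = normalize_pixel((0, 0, 0))
print([round(v, 3) for v in white])
print([round(v, 3) for v in black])
```

After normalization the values are no longer confined to [0, 1]; a pure white pixel ends up above 2 in every channel and a pure black pixel below -1.8.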
import os

import torch
from torchvision import transforms, datasets

# Separate pipelines: random augmentation for training, deterministic for validation
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

data_dir = 'data/hymenoptera_data'
# One ImageFolder and one DataLoader per split
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                              shuffle=True, num_workers=4)
               for x in ['train', 'val']}
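A dataset class derives from torch.utils.data.Dataset and mainly implements three methods: __init__, __getitem__, and __len__. The ImageFolder source follows this pattern: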
class ImageFolder(data.Dataset):
    """A generic data loader where the images are arranged in this way: ::

        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/xxz.png

        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/asd932_.png

    Args:
        root (string): Root directory path.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.

    Attributes:
        classes (list): List of the class names.
        class_to_idx (dict): Dict with items (class_name, class_index).
        imgs (list): List of (image path, class_index) tuples
    """

    def __init__(self, root, transform=None, target_transform=None,
                 loader=default_loader):
        classes, class_to_idx = find_classes(root)
        imgs = make_dataset(root, class_to_idx)
        if len(imgs) == 0:
            raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n"
                               "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))

        self.root = root
        self.imgs = imgs
        self.classes = classes
        self.class_to_idx = class_to_idx
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is class_index of the target class.
        """
        path, target = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        return len(self.imgs)
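The loader parameter of the class defaults to the default_loader function, which calls one of two functions depending on the image backend. In most cases pil_loader is the one used.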
from PIL import Image


def pil_loader(path):
    # Open the file explicitly so the handle is closed once the image is decoded
    with open(path, 'rb') as f:
        with Image.open(f) as img:
            return img.convert('RGB')


def accimage_loader(path):
    import accimage
    try:
        return accimage.Image(path)
    except IOError:
        # Potentially a decoding problem, fall back to PIL.Image
        return pil_loader(path)


def default_loader(path):
    from torchvision import get_image_backend
    if get_image_backend() == 'accimage':
        return accimage_loader(path)
    else:
        return pil_loader(path)
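The three-method pattern can be illustrated without PyTorch. The sketch below is a plain Python class (the names ListDataset and upper_transform are made up for this example) that stores its samples in __init__, returns one (item, label) pair from __getitem__, and reports its size from __len__. A real dataset would subclass torch.utils.data.Dataset and load images instead of strings.

```python
class ListDataset:
    """Toy stand-in for a map-style dataset; does not depend on PyTorch."""

    def __init__(self, samples, transform=None):
        # samples: list of (item, label) pairs held in memory
        self.samples = list(samples)
        self.transform = transform

    def __getitem__(self, index):
        item, label = self.samples[index]
        if self.transform is not None:
            item = self.transform(item)
        return item, label

    def __len__(self):
        return len(self.samples)


def upper_transform(text):
    """Placeholder transform; an image pipeline would crop or normalize here."""
    return text.upper()


toy = ListDataset([("ant", 0), ("bee", 1), ("ant2", 0)], transform=upper_transform)
print(len(toy), toy[1])
```

The transform is applied lazily inside __getitem__, so each sample is processed only when it is requested, which is also what lets random augmentations differ from one epoch to the next.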
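Suppose your own dataset is laid out as follows: img_path is the image folder and holds all images (both train and val), while txt_path holds two files, train.txt and val.txt. Each line of a txt file contains an image path, a tab character, and a label.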
|-img_path
|-- xxx.png
|-- xxy.png
|-- xqy.png
|-txt_path
|-- train.txt
|-- val.txt
Example train.txt:
img_path/xxx.png 0
img_path/xxy.png 1
img_path/xqy.png 0
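A small sketch of reading such a file, assuming each line holds a path and an integer label separated by a single tab. The helper name read_annotations and the temporary file are illustrative only.

```python
import os
import tempfile


def read_annotations(txt_path):
    """Return (paths, labels) parsed from a tab-separated annotation file."""
    paths, labels = [], []
    with open(txt_path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:  # skip blank lines
                continue
            fields = line.split("\t")
            paths.append(fields[0])
            labels.append(int(fields[-1]))
    return paths, labels


# Write a throwaway train.txt that mirrors the example above.
tmp_dir = tempfile.mkdtemp()
txt_file = os.path.join(tmp_dir, "train.txt")
with open(txt_file, "w", encoding="utf-8") as f:
    f.write("img_path/xxx.png\t0\nimg_path/xxy.png\t1\nimg_path/xqy.png\t0\n")

paths, labels = read_annotations(txt_file)
print(paths, labels)
```

Skipping blank lines is a choice made in this sketch; without it, a trailing empty line in the file would fail at the int() conversion.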
Define your own data-loading interface:
import os

from torch.utils.data import Dataset
from torchvision.datasets.folder import default_loader


class My_Data(Dataset):
    def __init__(self, img_path, txt_path, dataset='', data_transforms=None, loader=default_loader):
        # Each line of the txt file is "<image path>\t<label>"
        with open(txt_path) as input_file:
            lines = input_file.readlines()
        self.img_name = [os.path.join(img_path, line.strip().split('\t')[0]) for line in lines]
        self.img_label = [int(line.strip().split('\t')[-1]) for line in lines]
        self.data_transforms = data_transforms  # dict of transforms keyed by split name
        self.dataset = dataset                  # split name, e.g. 'train' or 'val'
        self.loader = loader

    def __getitem__(self, item):
        img_name = self.img_name[item]
        label = self.img_label[item]
        img = self.loader(img_name)
        if self.data_transforms is not None:
            try:
                img = self.data_transforms[self.dataset](img)
            except Exception:
                # Report the offending file, then re-raise so that an
                # untransformed image never reaches the DataLoader
                print("Cannot transform image: {}".format(img_name))
                raise
        return img, label

    def __len__(self):
        return len(self.img_name)
Full calling code:
import torch
from torchvision import transforms, datasets

batch_size = 4  # assumed value, matching the earlier examples

data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

# My_Data receives the whole dict and picks the pipeline by split name
image_datasets = {x: My_Data(img_path='/ImagePath',
                             txt_path=('/TxtFile/' + x + '.txt'),
                             data_transforms=data_transforms,
                             dataset=x) for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x],
                                              batch_size=batch_size,
                                              shuffle=True) for x in ['train', 'val']}
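Conceptually, each loader shuffles the sample indices and groups them into batches of batch_size. The following PyTorch-free sketch mimics that behaviour for illustration; iter_batches is a hypothetical helper, not the DataLoader implementation, and it leaves out collation into tensors and worker processes.

```python
import random


def iter_batches(dataset, batch_size, shuffle=True, seed=None):
    """Yield lists of (sample, label) pairs, batch_size at a time."""
    indices = list(range(len(dataset)))
    if shuffle:
        # Seeded only so that this demo is repeatable
        random.Random(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        yield [dataset[i] for i in indices[start:start + batch_size]]


# A list of (name, label) pairs supports len() and indexing, like a dataset does.
toy_dataset = [("img%d.png" % i, i % 2) for i in range(10)]
batches = list(iter_batches(toy_dataset, batch_size=4, shuffle=True, seed=0))
print([len(b) for b in batches])
```

With 10 samples and a batch size of 4, the last batch is smaller than the others. DataLoader keeps such a partial batch by default as well, unless drop_last=True is passed.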
Full code: Github
References:
https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html
https://blog.csdn.net/u014380165/article/details/78634829