学习资源:Dataset | timmdocs (fast.ai)
目录
1.数据集Dataset
2.ImageDataset
2.1Parser解析器
_getitem_(index: int) → Tuple[Any, Any]
2.2Usage使用
3.IterableImageDataset
__iter__
3.1Usage使用
4.AugMixDataset
4.1 _getitem_(index: int) → Tuple[Any, Any]
4.2Usage
timm库中,data的包括:
timm库中有三个主要的Dataset类:
在这篇文档中,将单独学习,以及查看这些Dataset类的各种用例。
imagedatset可以用来创建训练和验证数据集,它的功能与torchvision.datasets.ImageFolder非常相似,有一些不错的插件。
class ImageDataset(
root: str,
parser: Union[ParserImageInTar,
ParserImageFolder,
str] = None,
class_map: Dict[str, str] = '',
load_bytes: bool = False,
transform: List = None) -> Tuple[Any, Any]:
root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png
root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png
{'dog': 0, 'cat': 1, ..}
[('root/dog/xxx.png', 0), ('root/dog/xxy.png', 0), ..., ('root/cat/123.png', 1), ('root/cat/nsdf3.png', 1), ...]
这个parser对象是可下标访问的,在执行类似parser[index]的操作时,它会返回self.samples中特定索引处的一个样本。因此,执行类似parser[0]的操作将返回('root/dog/xxx.png', 0)。
设置了解析器,ImageDataset将根据索引从parser获取一个图像。然后它将图像读取为一个PIL。根据load_bytes参数将图像转换为RGB或以字节的形式读取图像。最后,它对图像进行变换并返回目标。如果target为None,则返回一个虚拟目标torch.tensor(-1)。
img, target = self.parser[index]
imagedatset也可以用来替换torchvision.datasets.ImageFolder。
【案例】考虑到现在有imagenette2-320数据集
①数据集结构
imagenette2-320
├── train
│ ├── n01440764
│ ├── n02102040
│ ├── n02979186
│ ├── n03000684
│ ├── n03028079
│ ├── n03394916
│ ├── n03417042
│ ├── n03425413
│ ├── n03445777
│ └── n03888257
└── val
├── n01440764
├── n02102040
├── n02979186
├── n03000684
├── n03028079
├── n03394916
├── n03417042
├── n03425413
├── n03445777
└── n03888257
②每个子文件夹包含一组属于该类的. jpeg文件。
wget https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-320.tgz
gunzip imagenette2-320.tgz
tar -xvf imagenette2-320.tar
③然后,可以像这样创建一个ImageDatset:
from timm.data.dataset import ImageDataset
dataset = ImageDataset('./imagenette2-320')
dataset[0]
(, 0)
④可以看到datad.parser是ParserImageFolder的一个实例:
dataset.parser
⑤最后,查看解析器中的class_to_idx字典映射:
dataset.parser.samples[:5]
[('./imagenette2-320/train/n01440764/ILSVRC2012_val_00000293.JPEG', 0),
('./imagenette2-320/train/n01440764/ILSVRC2012_val_00002138.JPEG', 0),
('./imagenette2-320/train/n01440764/ILSVRC2012_val_00003014.JPEG', 0),
('./imagenette2-320/train/n01440764/ILSVRC2012_val_00006697.JPEG', 0),
('./imagenette2-320/train/n01440764/ILSVRC2012_val_00007197.JPEG', 0)]
Timm将变换应用到图像上,并将目标设置为一个虚拟目标。当目标为None时,torch.tensor(-1, dtype=torch.long)。与上面的ImageDatset类似,IterableImageDatset首先创建一个parser,该parser获取基于根目录的样本元组。
IterableImageDataset中的__iter__方法:
注意:IterableImageDataset没有定义__getitem__方法,因此它是不可下标访问的。执行类似于数据集[0]的操作(其中数据集是iterableimagedatset的实例)将返回错误。
from timm.data import IterableImageDataset
from timm.data.parsers.parser_image_folder import ParserImageFolder
from timm.data.transforms_factory import create_transform
root = '../../imagenette2-320/'
parser = ParserImageFolder(root)
iterable_dataset = IterableImageDataset(root=root, parser=parser)
parser[0], next(iter(iterable_dataset))
IterableImageDataset是不可下标访问的。
iterable_dataset[0]
> >
---------------------------------------------------------------------------
NotImplementedError Traceback (most recent call last)
in
----> 1 iterable_dataset[0]
~/opt/anaconda3/lib/python3.8/site-packages/torch/utils/data/dataset.py in __getitem__(self, index)
30
31 def __getitem__(self, index) -> T_co:--->
32 raise NotImplementedError
33
34 def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
NotImplementedError:
AugmixDataset支持ImageDatset并将其转换为AugmixDataset。
class AugmixDataset(dataset: ImageDataset, num_splits: int = 2):
What's an Augmix Dataset and when would we need to do this?
Loss Output表示Xorig, Xaugmix1和Xaugmix2上标签和模型预测之间的分类损失和λ乘以Jensen-Shannon损失的总和。
注意:augmix1和augmix2是原始批处理的扩充版本,其中扩充是从操作列表中随机选择的。
对于这种情况,需要三个版本的批处理——original, augmix1和augmix2。使用AugmixDataset!
如果num_splits=3,那么x_list有三个项- original + augmented1 + augmented2。
from timm.data import ImageDataset, IterableImageDataset, AugMixDataset, create_loader
dataset = ImageDataset('../../imagenette2-320/')
dataset = AugMixDataset(dataset, num_splits=2)
loader_train = create_loader(
dataset,
input_size=(3, 224, 224),
batch_size=8,
is_training=True,
scale=[0.08, 1.],
ratio=[0.75, 1.33],
num_aug_splits=2
)
# Requires GPU to work
next(iter(loader_train))[0].shape
>> torch.Size([16, 3, 224, 224])
Q:传入batch_size=8,但批大小返回loader_train是16?为什么会这样呢?
A:num_aug_splits=2。在这种情况下,loader_train有前8张原始图像和后8张图像表示augmix1。
num_aug_splits=3,那么有效的batch_size将是24,其中前8张图像将是原始图像,下8张表示augmix1,最后8张表示augmix2。