【Timm】timm.data 数据集全面详实概念理解

学习资源:Dataset | timmdocs (fast.ai)

目录

1.数据集Dataset 

2.ImageDataset 

2.1Parser解析器

 _getitem_(index: int) → Tuple[Any, Any]

2.2Usage使用

3.IterableImageDataset

__iter__

3.1Usage使用

4.AugMixDataset

4.1  _getitem_(index: int) → Tuple[Any, Any]

4.2Usage


timm库中,data的包括:

  • timm.data.Dataset
  • timm.data.Prefetch Loader

1.数据集Dataset 

timm库中有三个主要的Dataset类:

  • ImageDataset
  • IterableImageDataset
  • AugMixDataset

在这篇文档中,将单独学习,以及查看这些Dataset类的各种用例。 


2.ImageDataset 

imagedatset可以用来创建训练和验证数据集,它的功能与torchvision.datasets.ImageFolder非常相似,有一些不错的插件。

class ImageDataset(
    root: str, 
    parser: Union[ParserImageInTar, 
    ParserImageFolder, 
    str] = None, 
    class_map: Dict[str, str] = '', 
    load_bytes: bool = False, 
    transform: List = None) -> Tuple[Any, Any]:

2.1Parser解析器

  • parser根据create_parser方法自动设置。
  • parser在根目录中查找所有图像和目标,根目录的结构如下:
root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png

root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png
  •  parser设置一个class_to_idx字典,从类映射到整数,如下所示:
{'dog': 0, 'cat': 1, ..}
  • parser还有一个名为samples的属性,它是一个元组列表,类似于: 
[('root/dog/xxx.png', 0), ('root/dog/xxy.png', 0), ..., ('root/cat/123.png', 1), ('root/cat/nsdf3.png', 1), ...]

这个parser对象是可下标访问的,在执行类似parser[index]的操作时,它会返回self.samples中特定索引处的一个样本。因此,执行类似parser[0]的操作将返回('root/dog/xxx.png', 0)。 

 _getitem_(index: int) → Tuple[Any, Any]

设置了解析器,ImageDataset将根据索引从parser获取一个图像。然后它将图像读取为一个PIL。根据load_bytes参数将图像转换为RGB或以字节的形式读取图像。最后,它对图像进行变换并返回目标。如果target为None,则返回一个虚拟目标torch.tensor(-1)

img, target = self.parser[index]

2.2Usage使用

imagedatset也可以用来替换torchvision.datasets.ImageFolder。

【案例】考虑到现在有imagenette2-320数据集

①数据集结构

imagenette2-320
├── train
│   ├── n01440764
│   ├── n02102040
│   ├── n02979186
│   ├── n03000684
│   ├── n03028079
│   ├── n03394916
│   ├── n03417042
│   ├── n03425413
│   ├── n03445777
│   └── n03888257
└── val
    ├── n01440764
    ├── n02102040
    ├── n02979186
    ├── n03000684
    ├── n03028079
    ├── n03394916
    ├── n03417042
    ├── n03425413
    ├── n03445777
    └── n03888257

②每个子文件夹包含一组属于该类的. jpeg文件。

wget https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-320.tgz
gunzip imagenette2-320.tgz
tar -xvf imagenette2-320.tar

③然后,可以像这样创建一个ImageDatset: 

from timm.data.dataset import ImageDataset

dataset = ImageDataset('./imagenette2-320')
dataset[0]

(, 0)

④可以看到datad.parser是ParserImageFolder的一个实例:

dataset.parser

⑤最后,查看解析器中的class_to_idx字典映射: 

dataset.parser.samples[:5]

[('./imagenette2-320/train/n01440764/ILSVRC2012_val_00000293.JPEG', 0),
 ('./imagenette2-320/train/n01440764/ILSVRC2012_val_00002138.JPEG', 0),
 ('./imagenette2-320/train/n01440764/ILSVRC2012_val_00003014.JPEG', 0),
 ('./imagenette2-320/train/n01440764/ILSVRC2012_val_00006697.JPEG', 0),
 ('./imagenette2-320/train/n01440764/ILSVRC2012_val_00007197.JPEG', 0)]

3.IterableImageDataset

  • Timm提供了一个类似于PyTorch的IterableDatasetlterableImageDataset
  • 但有一个关键的区别——IterableImageDatset在生成图像和目标之前将转换应用于图像。
  • 当数据来自流或数据长度未知时,这种形式的数据集特别有用。

Timm将变换应用到图像上,并将目标设置为一个虚拟目标。当目标为None时,torch.tensor(-1, dtype=torch.long)。与上面的ImageDatset类似,IterableImageDatset首先创建一个parser,该parser获取基于根目录的样本元组。

__iter__

IterableImageDataset中的__iter__方法:

  • 首先,从self.parser获取一个图像和一个目标。
  • 然后,转换应用于图像。
  • 另外,在返回两者之前将目标设置为一个假值。

注意:IterableImageDataset没有定义__getitem__方法,因此它是不可下标访问的。执行类似于数据集[0]的操作(其中数据集是iterableimagedatset的实例)将返回错误。

3.1Usage使用

from timm.data import IterableImageDataset
from timm.data.parsers.parser_image_folder import ParserImageFolder
from timm.data.transforms_factory import create_transform 

root = '../../imagenette2-320/'
parser = ParserImageFolder(root)
iterable_dataset = IterableImageDataset(root=root, parser=parser)
parser[0], next(iter(iterable_dataset))

IterableImageDataset是不可下标访问的。 

iterable_dataset[0]
> > 
---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
 in 
----> 1 iterable_dataset[0]

~/opt/anaconda3/lib/python3.8/site-packages/torch/utils/data/dataset.py in __getitem__(self, index)
     30 
     31     def __getitem__(self, index) -> T_co:---> 
     32         raise NotImplementedError     
     33 
     34     def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':

NotImplementedError:

4.AugMixDataset

AugmixDataset支持ImageDatset并将其转换为AugmixDataset。

class AugmixDataset(dataset: ImageDataset, num_splits: int = 2):

What's an Augmix Dataset and when would we need to do this?

Loss Output表示Xorig, Xaugmix1和Xaugmix2上标签和模型预测之间的分类损失和λ乘以Jensen-Shannon损失的总和。

注意:augmix1和augmix2是原始批处理的扩充版本,其中扩充是从操作列表中随机选择的。

对于这种情况,需要三个版本的批处理——original, augmix1和augmix2。使用AugmixDataset!

4.1  _getitem_(index: int) → Tuple[Any, Any]

  • 从self.dataset得到一个X和对应的标签y,它是传递给AugmixDataset构造函数的数据集。
  • 对图像X进行规范化,并将其添加到一个名为x_list的变量中。
  • 基于num_splits(默认值为0),对X应用扩展,规范化扩展输出并将其附加到x_list。
    • 如果num_splits=2,那么x_list有两个项- original + augmented。
    • 如果num_splits=3,那么x_list有三个项- original + augmented1 + augmented2。

4.2Usage

from timm.data import ImageDataset, IterableImageDataset, AugMixDataset, create_loader

dataset = ImageDataset('../../imagenette2-320/')
dataset = AugMixDataset(dataset, num_splits=2)
loader_train = create_loader(
    dataset, 
    input_size=(3, 224, 224), 
    batch_size=8, 
    is_training=True, 
    scale=[0.08, 1.], 
    ratio=[0.75, 1.33], 
    num_aug_splits=2
)
# Requires GPU to work

next(iter(loader_train))[0].shape

>> torch.Size([16, 3, 224, 224])

Q:传入batch_size=8,但批大小返回loader_train是16?为什么会这样呢?

A:num_aug_splits=2。在这种情况下,loader_train有前8张原始图像和后8张图像表示augmix1。
num_aug_splits=3,那么有效的batch_size将是24,其中前8张图像将是原始图像,下8张表示augmix1,最后8张表示augmix2。

你可能感兴趣的:(【PyTorch】,深度学习,人工智能)