ImageNet数据由3通道RGB图像组成。因此,为了能够在大多数库中使用预先训练的权值,模型期望一个3通道的输入图像。
比如对于resnet34,如果我们使用1个channel的输入的话:
import torch import torchvision m = torchvision.models.resnet34(pretrained=True) x = torch.randn(1, 1, 224, 224) try: m(x).shape except Exception as e: print(e) ''' Given groups=1, weight of size [64, 3, 7, 7], expected input[1, 1, 224, 224] to have 3 channels, but got 1 channels instead '''
是会报错的
此时的一种方法是将1维的channel复制两次,成为三维的channel
import torch import torchvision m = torchvision.models.resnet34(pretrained=True) x = torch.randn(1, 1, 224, 224) x=torch.cat((x,x,x),1)# 新增了这一行 try: print(m(x).shape) except Exception as e: print(e) #torch.Size([1, 1000])
然而,如果维度比3多的话,可能就没有办法删去某个维度,然后使用预训练模型。它们可以做的只是随机初始化权重,自己训练。
输入channel是1或者25都ok了
import timm
m = timm.create_model('resnet34', pretrained=True, in_chans=1)
x = torch.randn(1, 1, 224, 224)
m(x).shape
#torch.Size([1, 1000])
m = timm.create_model('resnet34', pretrained=True, in_chans=25)
# 25-channel image
x = torch.randn(1, 25, 224, 224)
m(x).shape
#torch.Size([1, 1000])
timm数据库中,有三种主要的数据集类:
ImageDataset
IterableImageDataset
AugMixDataset
2.1 ImageDataset
与
torchvision.datasets.ImageFolder 类似,ImageDataset的作用是创建训练集和验证集
通过使用create_parser函数,我们可以自动设置解析器
解析器找到所有root路径上的图片和目标
root路径结构如下所示
解析器创建一个class_to_idx字典:
同时有一个叫samples的元组列表:
解析器是可以下标访问的, parser[index]将返回一个self.samples中标签是index的样本(比如parser[0],会返回一个('root/dog/xxx.png', 0)
__getitem__(index: int) → Tuple[Any, Any]
一旦解析器创建完毕,那我们可以用以下方式获得图片和标签
img, target = self.parser[index]
然后将图像识别成PIL.Image,然后转换成RGB图像,还是读取成二进制,这取决于load_bytes语句
如果图片没有target,那么我们将target设置为-1
ImageDataset也可以作为torchvision.datasets.ImageFolder的一个代替
假设我们有imagenette2-320数据集,他的文件架构如下所示
数据集来源:
wget https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-320.tgz
每一个 n****都是一个文件夹,里面是属于这个类的JPEG文件
创建 ImageDataset:
from timm.data.dataset import ImageDataset dataset = ImageDataset('./imagenette2-320') dataset[0] #(
, 0)
dataset.parser
from timm.data.dataset import ImageDataset dataset = ImageDataset('./imagenette2-320') dataset.parser #
class_to_idx
from timm.data.dataset import ImageDataset dataset = ImageDataset('./imagenette2-320') dataset.parser.class_to_idx ''' {'n01440764': 0, 'n02102040': 1, 'n02979186': 2, 'n03000684': 3, 'n03028079': 4, 'n03394916': 5, 'n03417042': 6, 'n03425413': 7, 'n03445777': 8, 'n03888257': 9} '''
paser的sample
from timm.data.dataset import ImageDataset dataset = ImageDataset('./imagenette2-320') dataset.parser.samples[:5] ''' [('./imagenette2-320\\train\\n01440764\\ILSVRC2012_val_00000293.JPEG', 0), ('./imagenette2-320\\train\\n01440764\\ILSVRC2012_val_00002138.JPEG', 0), ('./imagenette2-320\\train\\n01440764\\ILSVRC2012_val_00003014.JPEG', 0), ('./imagenette2-320\\train\\n01440764\\ILSVRC2012_val_00006697.JPEG', 0), ('./imagenette2-320\\train\\n01440764\\ILSVRC2012_val_00007197.JPEG', 0)] '''
可视化一张数据的图片
import matplotlib.pyplot as plt # plt 用于显示图片 import matplotlib.image as mpimg # mpimg 用于读取图片 import numpy as np lena = mpimg.imread(dataset.parser.samples[0][0]) # 读取和代码处于同一目录下的 lena.png # 此时 lena 就已经是一个 np.array 了,可以对它进行任意处理 lena.shape #(512, 512, 3) plt.imshow(lena) # 显示图片 plt.axis('off') # 不显示坐标轴 plt.show()
和pytorch的 IterableDataset 类似,timm提供了 IterableImageDataset。
和ImageDataset相似,
IterableImageDataset首先创建一个解析器,他也基于根目录创建一组样本。
和ImageDataset相似,解析器也返回一组图像,图像的target也是图像所在的文件夹名称
***但有一点需要注意,IterableImageDataset并没有__getitem__方法,因此他不可以用下标访问。
dataset[0]会报错
从IterableImageDataset的解析器中得到图片和对应的标签
from timm.data import IterableImageDataset
from timm.data.parsers.parser_image_folder import ParserImageFolder
from timm.data.transforms_factory import create_transform
root = './imagenette2-320/'
parser = ParserImageFolder(root)
iterable_dataset = IterableImageDataset(root=root, parser=parser)
parser[0]
# (<_io.BufferedReader name='./imagenette2-320/train\\n01440764\\ILSVRC2012_val_00000293.JPEG'>,0)
next(iter(iterable_dataset))
# (<_io.BufferedReader name='./imagenette2-320/train\\n01440764\\ILSVRC2012_val_00000293.JPEG'>,0)
augmix 是一种数据增强的方法
class AugmixDataset(
dataset: ImageDataset,
num_splits: int = 2)
最后的返回结果是 original data 和num_splits-1 轮的增强数据(每一轮增强数据都是原始数据的基础上获得的)
__getitem__(index: int) -> Tuple[Any, Any]
2.3.2 使用方法
这个需要GPU,所以我在服务器上跑的
>>> from timm.data import ImageDataset, IterableImageDataset, AugMixDataset, create_loader
>>>
>>> dataset = ImageDataset('./imagenette2-320/')
>>> dataset = AugMixDataset(dataset, num_splits=2)
>>> loader_train = create_loader(
... dataset,
... input_size=(3, 224, 224),
... batch_size=8,
... is_training=True,
... scale=[0.08, 1.],
... ratio=[0.75, 1.33],
... num_aug_splits=2
... )
>>> next(iter(loader_train))[0].shape
torch.Size([16, 3, 224, 224])
注意看这里,我们的batch_size是8,返回的是16维,因为original是8,这里augmix又是8维
timm的 Dataloader比`torch.utils.data.DataLoader`快,且略有不同
创建timm的dataloader的最基本的方法就是调用timm.data.loader中的create_loader。它需要一个dataset对象,一个input_size和一个batch_size
创建 ImageDataset:
from timm.data.dataset import ImageDataset dataset = ImageDataset('./imagenette2-320') dataset[0] #(
, 0)
from timm.data.loader import create_loader try: # only works if gpu present on machine train_loader = create_loader(dataset, (3, 224, 224), 4) except: train_loader = create_loader(dataset, (3, 224, 224), 4, use_prefetcher=False)
那么,这里为什么要用异常处理语句呢?
timm 有一个类PrefetchLoader。我们默认用这个DataLoader来创建我们的DataLoader。但是它只工作在GPU上。
我本地的train_loader:
服务器(有GPU)的train_loader: