torchvision:torchvision
是Pytorch的一个图形库,主要用来构建计算机视觉模型,torchvision
由以下四个部分构成
torchvision.datasets
:包括一些加载数据的函数和常用的数据集接口torchvision.models
:包含常用的模型结构(含预训练模型),例如AlexNet、ResNet等等torchvision.transforms
:包含一些常见的图片变换,例如裁剪、旋转等等torchvision.utils
:其他用法torchvision.datasets:该模块下既有官方提供的数据集,也有自定义数据集的类,两者都是torch.utils.data.Dataset
的子类,因此可以直接输入到torch.utils.data.DataLoader
中去
torchvision.datasets
中提供的官方数据如下,这些数据集详细介绍见此文:数据集介绍
MNIST
Fashion-MNIST
KMNIST
EMNIST
FakeData
COCO
Captions
Detection
LSUN
ImageFolder
DatasetFolder
Imagenet-12
CIFAR
STL10
SVHN
PhotoTour
SBU
Flickr
VOC
Cityscapes
...
这里我们以MNIST数据集为例,演示一下这些官方数据集如何加载,其余数据集的加载和MNIST一致
如下,使用torchvision.datasets.MNIST
加载MNIST数据集
train_data = dataset.MNIST(root='./mnist/',
train=True,
transform=transforms.ToTensor(),
download=True)
test_data = dataset.MNIST(root='./mnist/',
train=False,
transform=transforms.ToTensor(),
download=False)
root
:表示数据集待存放的目录train
:如果为true
将会使用训练集的数据集(training.pt
),如果为false
将会使用测试集数据集(test.pt
)download
:如果为true
将会从网络上下载并放入root
中,如果数据集已下载则不会再次下载transform
:接受PIL图片并返回转换后的图片,常用的就是转换为tensor
(这里便会调用torchvision.transform
)数据集加载成功后,文件布局如下
这里的自定义数据集类指的主要是torchvision.datasets.ImageFolder()
,它继承自 torchvision.datasets.DatasetFolder()
,后者又继承自 torchvision.datasets.VisionDataset()
,而VisionDataset
则是 torch.utils.data.Dataset
的子类
以torchvision.datasets.CIFAR
数据集为例说明如何使用torchvision.datasets.ImageFolder()
,这里的torchvision.datasets.CIFAR
我已经将其转换为png格式存储
图片文件布局如下,torchvision.datasets.ImageFolder()
要求你的图片数据必须按照以下方式进行组织
torchvision.datasets.ImageFoler
参数说明
root
:图片存储的根目录,即各类别文件夹所在目录的上一级目录transform
:对图片进行预处理的操作(函数),原始图片作为输入,返回一个转换后的图片target_transform
:对图片类别进行预处理的操作,输入为 target
,输出对其的转换。如果不传该参数,即对 target
不做任何转换,返回的顺序索引 0,1, 2…loader
:表示数据集加载方式,通常默认加载方式即可torchvision.datasets.ImageFolder(root,transform,target_transform,loader)
如下,使用torchvision.datasets.ImageFoler
对前面的图片进行加载
transforms
部分可暂时忽略train_transforms = transforms.Compose([
transforms.RandomResizedCrop((28, 28)),
transforms.RandomHorizontalFlip(),
transforms.RandomVerticalFlip(),
transforms.RandomRotation(90),
transforms.RandomGrayscale(0.1),
transforms.ColorJitter(0.3, 0.3, 0.3, 0.3),
transforms.ToTensor()
])
train_dataset = torchvision.datasets.ImageFolder(root='./data/train', transform=train_transforms)
test_dataset = torchvision.datasets.ImageFolder(root='./data/test', transform=transforms.ToTensor)
print(len(train_dataset))
print(len(test_dataset))
同时,通过torchvision.datasets.ImageFolder
生成的train_dataset
和test_dataset
还有如下3个成员变量
self.classes
:使用一个list
保存类别名称self.class_to_idx
:类别对应的索引self.imgs
:是一个list
,每个元素是一个tuple
,每个tuple
保存的是(img-path, class)
print(train_dataset.classes[: 5])
print("-"*30)
print(train_dataset.class_to_idx)
print("-"*30)
print(train_dataset.imgs[: 5])
仍然以上述CIFAR10数据集为例,我们手动实现一下ImageFolder
,这对你理解它大有帮助
import torchvision.datasets
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import os
from PIL import Image
import numpy as np
import glob
# 类别名字
label_name = [
"airplane",
"automobile",
"bird",
"cat",
"deer",
"dog",
"frog",
"horse",
"ship",
"truck"
]
# 类比名字映射索引
label_dict = {}
for idx, name in enumerate(label_name):
label_dict[name] = idx
def default_loader(path):
return Image.open(path).convert("RGB")
train_transforms = transforms.Compose([
transforms.RandomResizedCrop((28, 28)),
transforms.RandomHorizontalFlip(),
transforms.RandomVerticalFlip(),
transforms.RandomRotation(90),
transforms.RandomGrayscale(0.1),
transforms.ColorJitter(0.3, 0.3, 0.3, 0.3),
transforms.ToTensor()
])
class MyDataset(Dataset):
"""
im_list:是一个列表,每一个元素是图片路径
transform:对图片进行增强
loader:使用PIL对图片进行加载
"""
def __init__(self, im_list, transform=None, loader=default_loader):
super(MyDataset, self).__init__()
# imgs为二维列表,每一个子列表中第一个元素存储im_list,第二个通过label_dict映射为索引
imgs = []
for im_item in im_list:
# 路径'./data/test/airplane/aeroplane_s_000002.png'中倒数第二个是标签名
im_label_name = im_item.split("\\")[-2]
imgs.append([im_item, label_dict[im_label_name]])
self.imgs = imgs
self.transform = transform
self.loader = loader
def __getitem__(self, index):
im__path, im_label = self.imgs[index]
# 会调用PIL加载图片数据
im_data = self.loader(im__path)
# 如果给了transoform那么就对图片进行增强
if self.transform is not None:
im_data = self.transform(im_data)
return im_data, im_label
def __len__(self):
return len(self.imgs)
if __name__ == '__main__':
im_train_list = glob.glob(r'./data/train/*/*.png')
im_test_list = glob.glob(r'./data/test/*/*.png')
train_dataset = MyDataset(im_train_list, transform=train_transforms)
test_dataset = MyDataset(im_test_list, transform=transforms.ToTensor())
print(len(train_dataset))
print(len(test_dataset))
train_loader = DataLoader(dataset=train_dataset, batch_size=6, shuffle=True, num_workers=0)
test_loader = DataLoader(dataset=test_dataset, batch_size=6, shuffle=False, num_workers=0)
torchvision.transforms:该模块是Pytorch中的图像预处理包,包含了一些常用的图像变换,主要实现对数据集的预处理、数据增强,转化为tensor
等操作
使用时如果有很多变换,那么一般会使用Compose将这些步骤给整合到一起
train_transforms = transforms.Compose([
transforms.RandomResizedCrop((28, 28)),
transforms.RandomHorizontalFlip(),
transforms.RandomVerticalFlip(),
transforms.RandomRotation(90),
transforms.RandomGrayscale(0.1),
transforms.ColorJitter(0.3, 0.3, 0.3, 0.3),
transforms.ToTensor()
])
train_dataset = torchvision.datasets.ImageFolder(root='./data/train', transform=train_transforms
如果变换时只有一种,那么一般会直接给到形参,比如最常使用到的ToTensor
test_dataset = torchvision.datasets.ImageFolder(root='./data/test', transform=transforms.ToTensor)
torchvision.transforms
涉及变换主要有以下4类
裁剪:
transforms.CenterCrop
transforms.RandomCrop
transforms.RandomResizedCrop
transforms.FiveCrop
transforms.TenCrop
翻转和旋转
transforms.RandomHorizontalFlip(p=0.5)
transforms.RandomVerticalFlip(p=0.5)
transforms.RandomRotation
图像变换和转换
transforms.Resize
transforms.Normalize
tensor
并归一化:transforms.ToTensor
transforms.Pad
transforms.ColorJitter
transforms.Grayscale
transforms.LinearTransformation
transforms.RandomAffine
transforms.RandomGrayscale
transforms.ToPILImage
其他操作
transforms
操作使数据增强更灵活:transforms.RandomChoice(transforms)
transforms
中选定一个操作:transforms.RandomApply(transforms, p=0.5)
transform
加上概率进行操作:transforms.RandomOrder
torchvision.models:该模块提供了很多图像处理中的常用模型,并且提供了预训练版本,使用时导入如下包
from torchvision import models
在models
中定义有如下模型
from .alexnet import *
from .convnext import *
from .densenet import *
from .efficientnet import *
from .googlenet import *
from .inception import *
from .mnasnet import *
from .mobilenet import *
from .regnet import *
from .resnet import *
from .shufflenetv2 import *
from .squeezenet import *
from .vgg import *
from .vision_transformer import *
from .swin_transformer import *
from .maxvit import *
from . import detection, optical_flow, quantization, segmentation, video
from ._api import get_model, get_model_builder, get_model_weights, get_weight, list_models
如何使用这些模型,以及如何修改,请参照下面的这批文章
总之我们对模型的修改可能主要集中在全连接层,比如原始模型最后全连接格式为
(fc): Linear(in_features=512, out_features=1000, bias=True)
如果处理的是cifar10数据集,那么out_features
应该是10,所以可以这样改
resnet = models.resnet18()
resnet.fc = nn.Linear(resnet.fc.in_features, 10)