timm库(CV利器)的入门教程(1)

省流:使用timm加载CNN进行图像分类,调整CNN使之更适合你的任务

问:使用timm搭建一个可以使用的CNN或ViT拢共需要几步?

答:4步

0.安装 timm

1.import timm

2.创建model

3.运行model

这一节很基础,会的兄弟们可以跳过看后面的

接下来具体讲一下如何使用,代码codebook会之后给出

安装、导入

准备自己python环境,建议用anconda来管理,然后安装上torch 和cuda

对于小白,可以先不安装cuda,只在cpu上跑代码

配置环境网上有很多教程,环境准备好了之后就可以在cmd安装timm包了,很简单:

conda imstall timm
#or
pip install timm
#or
git clone https://github.com/rwightman/pytorch-image-models
cd pytorch-image-models && pip install -e .

新建文件

在文件开头导入必要的包

import torch
import timm

创建、使用模型

创建模型的最简单方法是使用create_model;

这一个可用于在 timm 库中创建任何模型的工厂函数

这个函数各个参数有什么用,内部具体怎么实现的,怎么玩出花来的之后再讲,先只用它来创建一个CNN用来做分类任务

model_resnet34 = timm.create_model('resnet34', pretrained=True)

'resnet34'是模型架构的名字

pretrained=True则会自动从网上下载训练好的模型权重加载到resnet34上

然后模型就创建好了,可以直接使用了

这里我们使用随机张量表示图像

torch.randn:用来生成随机数字的tensor,这些随机数字满足标准正态分布(0~1)

x = torch.randn([1, 1, 224, 224])#创建一个tensor 代表 一张3x224x224的图片
out = model_resnet50(x)#out就是x所对应的表示类别的一个tensor
print(out.shape)# Results: torch.Size([1, 1000])代表1000个类别

我们可以看到模型已经处理了图像并返回了预期的输出形状。

查看模型信息

那么怎么知道timm都可以导入哪些模型来使用呢?

model_list = timm.list_models()#返回一个包含所有模型名称的list
print(len(model_list))#964
pretrain_model_list = timm.list_models(pretrained = True)#筛选出带预训练模型的
print(len(pretrain_model_list))#770
##使用通配符字符串来列出可用的不同 ResNet 变体
resnet_model_list = timm.list_models('*resnet*')
pretrain_resnet_model_list = timm.list_models('*resnet*' , pretrained = True)

调整模型-创建适合自己的模型

直接导入训练好的模型并不是万能的,经常会有维度不匹配的情况 比如说我的resnet34模型在cifar10和imagenet两个数据集上进行训练,分类类别不一样,输入的图片大小不一样,那我应该怎么创建合适的模型呢?

改变输出类别数目

分类类别数量:num_classes

model的主体提取特征,之后往往会接一个mlp层用作分类

如果设置num_classes,表示重设全连接层,num_classes设置为你需要分类的类别数量即可

import torch
x = torch.randn([1, 3, 224, 224])
​
model_resnet34_out10 = timm.create_model('resnet34', pretrained=True, num_classes=10)
out = model_resnet34_out10 (x)
print(out.shape)# Results: torch.Size([1, 10])

改变输入通道数

输入通道数:in_chans

对图片的大小,可以在输入model之前进行resize处理到统一大小

但是如果输入的图片不是传统rgb图片,通道不是3怎么办

当然,我们可以复制单通道像素来创建3通道图像,从而将其单通道输入图像转换为3通道图像。但是对于timm,他又一套申请的参数加载模式,我们可以直接改变in_chans 来指定输入图像的通道数

通道数改变后,对应的权重参数会进行相应的处理,此处不作详细说明 可参照:Models API and Pretrained weights | timmdocs或直接查看源代码

x = torch.randn([1, 1, 224, 224])
model_resnet34_in1 = timm.create_model('resnet50',pretrained=True, in_chans=1)

特性

所有model都有一个通用的默认配置接口和API

所有模型都支持通过create_model提取中间特征(vit除外)

所有型号都有一个预训练重量加载器,可调整最后一个线性层,也可调整3通道输入为1个通道输入

Learning rate schedulers/Optimizers/Augment

你可能感兴趣的:(timm教程,深度学习,计算机视觉,人工智能)