【Timm】搭建Vision Transformer系列实践,终于见面了,Timm库!
不久前,探究如何构建基于vision transformer的模型,发现,更多重点应该是放在如何有效利用现有库调用及构建模型,这篇就主要记录调用Timm:create_model相关应用!
目录
1.create_model 概念
2.create_model的直接使用规则
3.create_model参数:pretrained
4.create_model函数的源码探索【案例学习】
①registry
②Register model
③default config
④build model with config
总结
根据英文翻译,也不难发现就是:创建模型,具体来说就是创建网络模型。结合Timm库提供可供调用的模型,利用create_model能够很方便的辅助我们实现一些模型,当然,我们也可以将自己的模型实现并注册进timm,方便调用。
上篇文章的案例就是最为简单的,直接调用timm库所现有的模型:
import timm
model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)
补:list_models函数可以查看timm所提供的模型列表,即可直接创建、有预训练的模型列表:
all_pretrained_models_available = timm.list_models(pretrained=True)
print(all_pretrained_models_available)
print(len(all_pretrained_models_available))
##如果没有设置 pretrained=True 的话:
##将会输出612,即有预训练权重参数的模型有452个,没有预训练参数,只有模型结构的共有612个。
pretrained :True or False
model = timm.create_model('resnet50', pretrained=True)
【案例学习】timm 视觉库中的 create_model 函数详解
create_model主体只有50行左右的代码,那么,如何实现从模型到特征提取器的转换?已知timm.list_models()函数中的每一个模型名字(str)实际上都是一个函数。
###输入
import timm
import random
from timm.models import registry
m = timm.list_models()[-1]
print(m)
registry.is_model(m)
###输出
xception71
True
实际上,在 timm 内部,有一个字典称为 _model_entrypoints 包含了所有的模型名称和他们各自的函数。比如说,可以通过 model_entrypoint 函数从 _model_entrypoints 内部得到 xception71 模型的构造函数。
###输入
constuctor_fn = registry.model_entrypoint(m)
print(constuctor_fn)
###输出
or
在 timm.models.xception_aligned 模块中有一个函数称为 xception71 。类似的,timm 中的每一个模型都有着一个这样的构造函数。事实上,内部的 _model_entrypoints 字典大概长这个样子:
_model_entrypoints
> >
{
'cspresnet50':,'cspresnet50d': ,
'cspresnet50w': ,
'cspresnext50': ,
'cspresnext50_iabn': ,
'cspdarknet53': ,
'cspdarknet53_iabn': ,
'darknet53': ,
'densenet121': ,
'densenetblur121d': ,
'densenet121d': ,
'densenet169': ,
'densenet201': ,
'densenet161': ,
'densenet264': ,
}
所以说,在 timm 对应的模块中,每个模型都有一个构造器。比如说 ResNets 系列模型被定义在 timm.models.resnet 模块中。因此,实际上有两种方式来创建一个 resnet34 模型:
import timm
from timm.models.resnet import resnet34
# 使用 create_model
m = timm.create_model('resnet34')
# 直接调用构造函数
m = resnet34()
但使用上,无须调用构造函数。所用模型都可以通过create_model函数来将创建。
resnet34构造函数的源码如下:
@register_model
def resnet34(pretrained=False, **kwargs):
"""Constructs a ResNet-34 model.
"""
model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3], **kwargs)
return _create_resnet('resnet34', pretrained, **model_args)
会发现 timm 中的每个模型都有一个 register_model 装饰器。最开始, _model_entrypoints 是一个空字典。通过 register_model 装饰器来不断地像其中添加模型名称和它对应的构造函数。该装饰器的定义如下:
def register_model(fn):
# lookup containing module
mod = sys.modules[fn.__module__]
module_name_split = fn.__module__.split('.')
module_name = module_name_split[-1] if len(module_name_split) else ''
# add model to __all__ in module
model_name = fn.__name__
if hasattr(mod, '__all__'):
mod.__all__.append(model_name)
else:
mod.__all__ = [model_name]
# add entries to registry dict/sets
_model_entrypoints[model_name] = fn
_model_to_module[model_name] = module_name
_module_to_models[module_name].add(model_name)
has_pretrained = False # check if model has a pretrained url to allow filtering on this
if hasattr(mod, 'default_cfgs') and model_name in mod.default_cfgs:
# this will catch all models that have entrypoint matching cfg key, but miss any aliasing
# entrypoints or non-matching combos
has_pretrained = 'url' in mod.default_cfgs[model_name] and 'http' in mod.default_cfgs[model_name]['url']
if has_pretrained:
_model_has_pretrained.add(model_name)
return fn
可以看到, register_model 函数完成了一些比较基础的步骤,但这里需要指出的是这一句:
_model_entrypoints[model_name] = fn
它将给定的 fn 添加到 _model_entrypoints 其键名为 fn.__name__。所以说 resnet34 函数上的装饰器 @register_model 在 _model_entrypoints 中创建一个新的条目,像这样:
{’resnet34’: }
同样可以看到在 resnet34 构造函数的源码中,在设置完一些 model_args 之后,它会随后调用 _create_resnet 函数。再来看一下该函数的源码:
def _create_resnet(variant, pretrained=False, **kwargs):
return build_model_with_cfg(
ResNet, variant, default_cfg=default_cfgs[variant], pretrained=pretrained, **kwargs)
所以在 _create_resnet 函数之中,会再调用 build_model_with_cfg 函数并将一个构造器类 ResNet 、变量名 resnet34、一个 default_cfg 和一些 **kwargs 传入其中。
timm 中所有的模型都有一个默认的配置,包括指向它的预训练权重参数的URL、类别数、输入图像尺寸、池化尺寸等。resnet34 的默认配置如下:
{'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet34-43635321.pth',
'num_classes': 1000,
'input_size': (3, 224, 224),
'pool_size': (7, 7),
'crop_pct': 0.875,
'interpolation': 'bilinear',
'mean': (0.485, 0.456, 0.406),
'std': (0.229, 0.224, 0.225),
'first_conv': 'conv1',
'classifier': 'fc'}
这个 build_model_with_cfg 函数负责:
def build_model_with_cfg(
model_cls: Callable,
variant: str,
pretrained: bool,
default_cfg: dict,
model_cfg: dict = None,
feature_cfg: dict = None,
pretrained_strict: bool = True,
pretrained_filter_fn: Callable = None,
pretrained_custom_load: bool = False,
**kwargs):
pruned = kwargs.pop('pruned', False)
features = False
feature_cfg = feature_cfg or {}
if kwargs.pop('features_only', False):
features = True
feature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4))
if 'out_indices' in kwargs:
feature_cfg['out_indices'] = kwargs.pop('out_indices')
model = model_cls(**kwargs) if model_cfg is None else model_cls(cfg=model_cfg, **kwargs)
model.default_cfg = deepcopy(default_cfg)
if pruned:
model = adapt_model_from_file(model, variant)
# for classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats
num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000))
if pretrained:
if pretrained_custom_load:
load_custom_pretrained(model)
else:
load_pretrained(
model,
num_classes=num_classes_pretrained, in_chans=kwargs.get('in_chans', 3),
filter_fn=pretrained_filter_fn, strict=pretrained_strict)
if features:
feature_cls = FeatureListNet
if 'feature_cls' in feature_cfg:
feature_cls = feature_cfg.pop('feature_cls')
if isinstance(feature_cls, str):
feature_cls = feature_cls.lower()
if 'hook' in feature_cls:
feature_cls = FeatureHookNet
else:
assert False, f'Unknown feature class {feature_cls}'
model = feature_cls(model, **feature_cfg)
model.default_cfg = default_cfg_for_features(default_cfg) # add back default_cfg
return model
可以看到,模型在这一步被创建出来:model = model_cls(**kwargs)。本文将不再深入到 pruned 和 adapt_model_from_file 内部查看。
通过本文,已经完全了解了 create_model 函数,了解到:
实操后续整理,查看可行性