点击下方卡片,关注“CVer”公众号
AI/CV重磅干货,第一时间送达
转载自:机器之心 | 编辑:杜伟
这个项目登上了今天的GitHub Trending。
近一两年,Transformer 跨界 CV 任务不再是什么新鲜事了。
自 2020 年 10 月谷歌提出 Vision Transformer (ViT) 以来,各式各样视觉 Transformer 开始在图像合成、点云处理、视觉 - 语言建模等领域大显身手。
之后,在 PyTorch 中实现 Vision Transformer 成为了研究热点。GitHub 中也出现了很多优秀的项目,今天要介绍的就是其中之一。
该项目名为「vit-pytorch」,它是一个 Vision Transformer 实现,展示了一种在 PyTorch 中仅使用单个 transformer 编码器来实现视觉分类 SOTA 结果的简单方法。
项目当前的 star 量已经达到了 7.5k,创建者为 Phil Wang,ta 在 GitHub 上有 147 个资源库。
项目地址:https://github.com/lucidrains/vit-pytorch
项目作者还提供了一段动图展示:
项目介绍
首先来看 Vision Transformer-PyTorch 的安装、使用、参数、蒸馏等步骤。
第一步是安装:
$ pip install vit-pytorch
第二步是使用:
import torch
from vit_pytorch import ViT
v = ViT(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 16,
mlp_dim = 2048,
dropout = 0.1,
emb_dropout = 0.1
)
img = torch.randn(1, 3, 256, 256)
preds = v(img) # (1, 1000)
第三步是所需参数,包括如下:
image_size:图像大小
patch_size:patch 数量
num_classes:分类类别的数量
dim:线性变换 nn.Linear(..., dim) 后输出张量的最后维
depth:Transformer 块的数量
heads:多头注意力层中头的数量
mlp_dim:MLP(前馈)层的维数
channels:图像通道的数量
dropout:Dropout rate
emb_dropout:嵌入 dropout rate
……
最后是蒸馏,采用的流程出自 Facebook AI 和索邦大学的论文《Training data-efficient image transformers & distillation through attention》。
论文地址:https://arxiv.org/pdf/2012.12877.pdf
从 ResNet50(或任何教师网络)蒸馏到 vision transformer 的代码如下:
import torchfrom torchvision.models import resnet50from vit_pytorch.distill import DistillableViT, DistillWrapperteacher = resnet50(pretrained = True)
v = DistillableViT(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 8,
mlp_dim = 2048,
dropout = 0.1,
emb_dropout = 0.1
)
distiller = DistillWrapper(
student = v,
teacher = teacher,
temperature = 3, # temperature of distillationalpha = 0.5, # trade between main loss and distillation losshard = False # whether to use soft or hard distillation
)
img = torch.randn(2, 3, 256, 256)labels = torch.randint(0, 1000, (2,))
loss = distiller(img, labels)loss.backward()
# after lots of training above ...pred = v(img) # (2, 1000)
除了 Vision Transformer 之外,该项目还提供了 Deep ViT、CaiT、Token-to-Token ViT、PiT 等其他 ViT 变体模型的 PyTorch 实现。
对 ViT 模型 PyTorch 实现感兴趣的读者可以参阅原项目。
上述代码资源下载
后台回复:Trans复现,即可下载代码资源
ICCV和CVPR 2021论文和代码下载
后台回复:CVPR2021,即可下载CVPR 2021论文和代码开源的论文合集
后台回复:ICCV2021,即可下载ICCV 2021论文和代码开源的论文合集
后台回复:Transformer综述,即可下载最新的3篇Transformer综述PDF
CVer-Transformer交流群成立
扫码添加CVer助手,可申请加入CVer-Transformer 微信交流群,方向已涵盖:目标检测、图像分割、目标跟踪、人脸检测&识别、OCR、姿态估计、超分辨率、SLAM、医疗影像、Re-ID、GAN、NAS、深度估计、自动驾驶、强化学习、车道线检测、模型剪枝&压缩、去噪、去雾、去雨、风格迁移、遥感图像、行为识别、视频理解、图像融合、图像检索、论文投稿&交流、PyTorch和TensorFlow等群。
一定要备注:研究方向+地点+学校/公司+昵称(如Transformer+上海+上交+卡卡),根据格式备注,可更快被通过且邀请进群
▲长按加小助手微信,进交流群
▲点击上方卡片,关注CVer公众号