torch_vision(二):模型和预训练weight模块 torchvision.models

torchvision.models简单介绍

介绍

torchvision.models模块提供了很多模型架构,以及对应的预先训练好的权重。
最新的版本的特性是相比于旧版本

  1. 一个模型架构可以加载多种不同的权重。
  2. 可以获取到预处理方法,这些转换中的任何细微差异(例如插值、调整大小/裁剪大小等)都可能导致准确性大幅降低或模型无法使用。
  3. 提供元数据,包括类别标签,准确度等指标。

以一个分类模型为例:

from torchvision.io import read_image
from torchvision.models import resnet50, ResNet50_Weights

img = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")

# Step 1: Initialize model with the best available weights
# ResNet50_Weights.IMAGENET1K_V1  ResNet50_Weights.IMAGENET1K_V2是其他可以选择的版本,DEFAULT一般是最优的版本
weights = ResNet50_Weights.DEFAULT 
model = resnet50(weights=weights)
model.eval()

# Step 2: Initialize the inference transforms
preprocess = weights.transforms()
# Step 3: Apply inference preprocessing transforms
batch = preprocess(img).unsqueeze(0)

# Step 4: Use the model and print the predicted category
prediction = model(batch).squeeze(0).softmax(0)
class_id = prediction.argmax().item()
score = prediction[class_id].item()
category_name = weights.meta["categories"][class_id]
print(f"{category_name}: {100 * score:.1f}%")

目标检测

from torchvision.io.image import read_image
from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2, FasterRCNN_ResNet50_FPN_V2_Weights
from torchvision.utils import draw_bounding_boxes
from torchvision.transforms.functional import to_pil_image

img = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")

# Step 1: Initialize model with the best available weights
weights = FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
model = fasterrcnn_resnet50_fpn_v2(weights=weights, box_score_thresh=0.9)
model.eval()

# Step 2: Initialize the inference transforms
preprocess = weights.transforms()

# Step 3: Apply inference preprocessing transforms
batch = [preprocess(img)]

# Step 4: Use the model and visualize the prediction
prediction = model(batch)[0]
labels = [weights.meta["categories"][i] for i in prediction["labels"]]
box = draw_bounding_boxes(img, boxes=prediction["boxes"],
                          labels=labels,
                          colors="red",
                          width=4, font_size=30)
im = to_pil_image(box.detach())
im.show()

语义分割

from torchvision.io.image import read_image
from torchvision.models.segmentation import fcn_resnet50, FCN_ResNet50_Weights
from torchvision.transforms.functional import to_pil_image

img = read_image("gallery/assets/dog1.jpg")

# Step 1: Initialize model with the best available weights
weights = FCN_ResNet50_Weights.DEFAULT
model = fcn_resnet50(weights=weights)
model.eval()

# Step 2: Initialize the inference transforms
preprocess = weights.transforms()

# Step 3: Apply inference preprocessing transforms
batch = preprocess(img).unsqueeze(0)

# Step 4: Use the model and visualize the prediction
prediction = model(batch)["out"]
normalized_masks = prediction.softmax(dim=1)
class_to_idx = {cls: idx for (idx, cls) in enumerate(weights.meta["categories"])}
mask = normalized_masks[0, class_to_idx["dog"]]
to_pil_image(mask).show()

可以选择的模型和weight

torchvision.models包含很多模型和预先训练好的weight, 能够处理多种任务,图像分类,语义分割,目标检测,关键点检测,视频分类,光流估计等。

The torchvision.models subpackage contains definitions of models for addressing different tasks, 
including:image classification, pixelwise semantic segmentation, object detection, instance segmentation, 
person keypoint detection, video classification, and optical flow.

具体各个任务有哪些可以在torchvision.models可以获取到的模型,请查看

MODELS AND PRE-TRAINED WEIGHTS

其实覆盖的模型不算多,超分,生成模型,图像修复,图像增强等多种任务并没有相关模型在torchvision.models中。

你可能感兴趣的:(图像处理算法,深度学习,目标检测,计算机视觉)