ResNet-101 特征提取（2048 维特征向量）

import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image

# Extract a global feature vector for a single image with an
# ImageNet-pretrained ResNet-101 whose classification head has been removed.

# Path of the image to embed; edit this to process a different file.
IMAGE_PATH = "/home/a/Downloads/Laysan_Albat1_545.jpg"

# Load pre-trained ResNet-101 model.
# `pretrained=True` is deprecated since torchvision 0.13 in favour of the
# `weights=` enum. IMAGENET1K_V1 is exactly the checkpoint `pretrained=True`
# loads (DEFAULT would pick the different IMAGENET1K_V2 weights), so the
# extracted features are unchanged. Older torchvision releases without the
# enum fall back to the legacy flag.
if hasattr(models, "ResNet101_Weights"):
    resnet = models.resnet101(weights=models.ResNet101_Weights.IMAGENET1K_V1)
else:
    resnet = models.resnet101(pretrained=True)

# Remove the final fully connected layer (`fc`, the last child module).
# The global average pool before it is kept, so the network outputs a
# (N, 2048, 1, 1) feature map instead of class logits.
feature_extractor = torch.nn.Sequential(*list(resnet.children())[:-1])

# Set model to evaluation mode so BatchNorm layers use their stored running
# statistics rather than statistics of the (single-image) batch.
feature_extractor.eval()

# Standard ImageNet preprocessing: resize the shorter side to 256, take the
# central 224x224 crop, convert to a float tensor in [0, 1], then normalise
# with the per-channel ImageNet mean/std the weights were trained with.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load image. The context manager closes the underlying file handle
# deterministically; convert() returns an independent in-memory RGB copy
# (so grayscale / RGBA / palette images also become 3-channel).
with Image.open(IMAGE_PATH) as raw_img:
    img = raw_img.convert('RGB')

# Apply transforms, then add a batch dimension:
# (3, 224, 224) -> (1, 3, 224, 224).
img_tensor = transform(img)
img_tensor = img_tensor.unsqueeze(0)

# Pass image through the truncated ResNet-101 without autograd bookkeeping.
with torch.no_grad():
    features = feature_extractor(img_tensor)

# Flatten everything after the batch axis: (1, 2048, 1, 1) -> (1, 2048).
feature_vector = torch.flatten(features, start_dim=1)

# Print shape of feature vector (expected: torch.Size([1, 2048])).
print(feature_vector.shape)

你可能感兴趣的:(python,开发语言)