# 如何用 ResNet50 提取图片特征【咨询大厂大佬版】

import torch
import torch.nn as nn
import os
from torchvision import models, transforms
from torch.autograd import Variable
import numpy as np
from PIL import Image
import torchvision.models as models
import pandas as pd
 
 
# Extract one feature vector per image found in `file_path` with a pretrained
# ResNet-50 whose final fc layer is removed, and save each vector as
# "<image stem>.pth" inside `save_path` (an empty string means the current
# working directory).
file_path = './images/'
save_path = ''

# Preprocessing applied to every image before it is fed to the network.
transform1 = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    # torchvision's ImageNet-pretrained weights are documented to expect
    # inputs normalized with these per-channel statistics; without this step
    # the extracted features are degraded.
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# Sorted so the processing order is deterministic across runs/platforms.
names = sorted(os.listdir(file_path))

resnet50 = models.resnet50(pretrained=True)
# Keep every child module except the last one (the fc classification layer).
feature_extractor = torch.nn.Sequential(*list(resnet50.children())[:-1])
# Inference mode: BatchNorm must use its stored running statistics instead of
# the statistics of the current single-image batch (and must not update them).
feature_extractor.eval()

if save_path:
    os.makedirs(save_path, exist_ok=True)

# No gradients are needed for feature extraction; disabling autograd avoids
# building a graph per image and saves plain tensors.
with torch.no_grad():
    for name in names:
        pic = os.path.join(file_path, name)
        if not os.path.isfile(pic):
            # Sub-directories and other non-regular entries cannot be opened
            # as images.
            continue
        # The context manager closes the file handle; convert('RGB') makes
        # grayscale / RGBA / palette images 3-channel as the network expects.
        with Image.open(pic) as img:
            img1 = transform1(img.convert('RGB'))
        x = torch.unsqueeze(img1, dim=0)  # add the batch dimension
        y = feature_extractor(x).squeeze().cpu()  # drop singleton dimensions
        # splitext handles extensions of any length (".jpeg", ".tiff", ...),
        # unlike slicing off the last four characters.
        stem = os.path.splitext(name)[0]
        torch.save(y, os.path.join(save_path, stem + ".pth"))

# 你可能感兴趣的:(深度学习,pytorch,人工智能)