A demo of feature extraction with VGG-19 in PyTorch

这里提取的是灰度图的特征，而 VGG-19 的输入需要三通道图像，因此把单通道图像复制堆叠三次变成三通道输入。VGG-19 的预训练模型（vgg19-dcbb9e9d.pth）可以从 PyTorch 官方提供的地址下载。

import torch,os
import scipy.io as sio
import numpy as np
import scipy.misc as sc
from torchvision import models
from torch.autograd import Variable

# ImageNet channel statistics that the pretrained VGG weights expect once the
# input has been scaled to [0, 1].
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def read_image_array(path):
    """Read the image at *path* into a NumPy array (H x W or H x W x C).

    ``scipy.misc.imread`` is used when the installed SciPy still provides it.
    It was removed in SciPy 1.2, so otherwise fall back to
    ``torchvision.io.read_image`` (JPEG/PNG at least), whose C x H x W uint8
    result is transposed to H x W x C.
    """
    imread = getattr(sc, "imread", None)
    if imread is not None:
        return np.asarray(imread(path))
    from torchvision.io import read_image
    return read_image(path).permute(1, 2, 0).numpy()


def preprocess_image(data, size=224):
    """Turn an image array into a normalized VGG input batch of one.

    Args:
        data: H x W (grayscale) or H x W x C array. C == 1 or 2 is treated as
            gray (+ alpha); C >= 3 uses the first three channels (alpha dropped).
            Integer arrays are scaled by their dtype's max value; float arrays
            are assumed to already be in [0, 1].
        size: side length of the square output (224 for VGG).

    Returns:
        float32 tensor of shape (1, 3, size, size), normalized with the
        ImageNet mean/std.

    Raises:
        ValueError: if *data* is not a 2-D or 3-D image array.
    """
    arr = np.asarray(data)
    if arr.ndim == 3 and arr.shape[2] in (1, 2):
        # Single channel, optionally with alpha: keep the intensity plane.
        arr = arr[:, :, 0]
    if arr.ndim == 2:
        # VGG expects three channels, so replicate the gray plane.
        arr = np.stack([arr, arr, arr], axis=-1)
    elif arr.ndim == 3 and arr.shape[2] >= 3:
        arr = arr[:, :, :3]
    else:
        raise ValueError("unsupported image shape: {}".format(arr.shape))

    if np.issubdtype(arr.dtype, np.integer):
        scale = float(np.iinfo(arr.dtype).max)
    else:
        scale = 1.0
    chw = np.ascontiguousarray(arr.transpose(2, 0, 1), dtype=np.float32) / scale
    batch = torch.from_numpy(chw).unsqueeze(0)
    # Real spatial resampling (np.resize only reshapes/repeats the raw buffer).
    batch = torch.nn.functional.interpolate(
        batch, size=(size, size), mode="bilinear", align_corners=False)
    mean = torch.tensor(IMAGENET_MEAN, dtype=torch.float32).view(1, 3, 1, 1)
    std = torch.tensor(IMAGENET_STD, dtype=torch.float32).view(1, 3, 1, 1)
    return (batch - mean) / std


def extract_classifier_features(model, batch):
    """Return the output of every ``nn.Linear`` in ``model.classifier``.

    The batch is run through ``model.features``, then ``model.avgpool`` if the
    model has one, then a flatten and the classifier layers in order. After
    each ``nn.Linear`` its (pre-activation) output is recorded. For VGG-19
    this gives 4096-, 4096- and 1000-wide outputs.

    The caller should put *model* in ``eval()`` mode first so Dropout is
    disabled; *batch* must be on the same device as *model*.

    Returns:
        list of NumPy arrays, one per Linear layer, each of shape (N, out_features).
    """
    outputs = []
    with torch.no_grad():
        x = model.features(batch)
        avgpool = getattr(model, "avgpool", None)
        if avgpool is not None:
            x = avgpool(x)
        x = torch.flatten(x, 1)
        for layer in model.classifier:
            x = layer(x)
            if isinstance(layer, torch.nn.Linear):
                # .copy(): on CPU .numpy() shares memory with the tensor, and
                # a following in-place ReLU would otherwise overwrite it.
                outputs.append(x.cpu().numpy().copy())
    return outputs


def main(img_dir="img", out_dir="feature",
         weights="models/vgg19-dcbb9e9d.pth"):
    """Extract VGG-19 classifier features for every image in *img_dir*.

    One ``<name>.mat`` file per image is written to *out_dir*, holding the
    keys ``feature_4096_first``, ``feature_4096_second`` and ``feature_1000``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = models.vgg19()
    model.load_state_dict(torch.load(weights, map_location=device))
    model.to(device)
    model.eval()  # disable Dropout so features are deterministic
    os.makedirs(out_dir, exist_ok=True)

    for pic in sorted(os.listdir(img_dir)):
        pic_path = os.path.join(img_dir, pic)
        if not os.path.isfile(pic_path):
            continue
        print(pic_path)
        batch = preprocess_image(read_image_array(pic_path)).to(device)
        first, second, last = extract_classifier_features(model, batch)
        stem = os.path.splitext(pic)[0]
        sio.savemat(os.path.join(out_dir, stem + '.mat'),
                    {'feature_4096_first': first[0],
                     'feature_4096_second': second[0],
                     'feature_1000': last[0]})


if __name__ == "__main__":
    main()

你可能感兴趣的:(A demo for feature extraction with vgg19 in pytorch)