TopK

@TOPK

from __future__ import print_function, division
import torch
import torch.nn as nn

import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt

import os
#import copy
#import math
import torch.nn.functional as F
# Switch matplotlib to interactive mode.
# NOTE(review): no plotting call is visible in this file, so this looks like a
# leftover from a training/visualisation script — confirm before removing.
plt.ion()   # interactive mode

# ---- Run configuration ------------------------------------------------------
# Root folder holding one sub-directory per dataset split.
data_dir = '/home/cc/Desktop/keji/sucai/'

# Architecture key understood by initialize_model: resnet, resnet34, alexnet,
# vgg, squeezenet, resnet101, densenet, densenet169, inception.
model_name = "resnet"
num_classes = 5                 # number of classes in the dataset

# Training hyper-parameters (none of these is read by the visible prediction code).
batch_size = 4                  # batch size; lower it if memory is tight
EPOCH = 300                     # number of epochs to train for
pre_epoch = 0                   # epochs already completed over the dataset
LR = 1e-3                       # learning rate

# False: fine-tune every layer; True: freeze the backbone so only the
# replaced classification head is trainable.
feature_extract = False

# Prefer the first GPU when one exists, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
def set_parameter_requires_grad(model, feature_extracting):
    """Freeze every parameter of *model* when *feature_extracting* is truthy.

    Used before a new head is attached: layers created afterwards keep
    ``requires_grad=True``, so only they are updated during training.
    When *feature_extracting* is falsy the model is left untouched.
    """
    if not feature_extracting:
        return
    for weight in model.parameters():
        weight.requires_grad = False

def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True,
                     checkpoint=None):
    """Build a torchvision classifier whose output layer has *num_classes* units.

    Args:
        model_name: one of "resnet" (ResNet-18), "resnet34", "alexnet",
            "vgg" (VGG11-BN), "squeezenet", "resnet101", "densenet"
            (DenseNet-121), "densenet169", "inception" (Inception v3).
        num_classes: number of outputs of the replaced classification head.
        feature_extract: if True, freeze the backbone before the new head is
            attached, so only the head remains trainable.
        use_pretrained: forwarded to the torchvision constructor as
            ``pretrained=``.
        checkpoint: optional path to a ``state_dict`` saved for this exact
            architecture (with the resized head). When None, the per-model
            default below is used: "resnet" and "resnet101" have one, every
            other model loads nothing.

    Returns:
        ``(net, input_size)``: the model (on the CPU) and the square input
        resolution it is configured for.

    An unknown *model_name* prints a message and raises ``SystemExit``.
    """
    # Checkpoints this script historically loaded for each model.
    # 'densenet169' used to reference an undefined local `pre` and always
    # failed with UnboundLocalError; it now loads only an explicit `checkpoint`.
    default_checkpoints = {
        "resnet": 'net_060.pth',
        "resnet101": '/home/cc/Desktop/dj/123/train3_resnet/train_resnet101/net_018.pth',
    }

    if model_name == "resnet":
        # ResNet-18
        net = models.resnet18(pretrained=use_pretrained)
        set_parameter_requires_grad(net, feature_extract)
        net.fc = nn.Linear(net.fc.in_features, num_classes)
        input_size = 224
    elif model_name == "resnet34":
        # ResNet-34. The head must match the backbone's feature width
        # (net.fc.in_features); a hard-coded 8192 caused a shape mismatch
        # on the first forward pass.
        net = models.resnet34(pretrained=use_pretrained)
        set_parameter_requires_grad(net, feature_extract)
        net.fc = nn.Linear(net.fc.in_features, num_classes)
        input_size = 299
    elif model_name == "alexnet":
        net = models.alexnet(pretrained=use_pretrained)
        set_parameter_requires_grad(net, feature_extract)
        net.classifier[6] = nn.Linear(net.classifier[6].in_features, num_classes)
        input_size = 224
    elif model_name == "vgg":
        # VGG11 with batch normalisation
        net = models.vgg11_bn(pretrained=use_pretrained)
        set_parameter_requires_grad(net, feature_extract)
        net.classifier[6] = nn.Linear(net.classifier[6].in_features, num_classes)
        input_size = 224
    elif model_name == "squeezenet":
        # SqueezeNet classifies with a 1x1 convolution rather than a Linear layer.
        net = models.squeezenet1_0(pretrained=use_pretrained)
        set_parameter_requires_grad(net, feature_extract)
        net.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))
        net.num_classes = num_classes
        input_size = 224
    elif model_name == "resnet101":
        net = models.resnet101(pretrained=use_pretrained)
        set_parameter_requires_grad(net, feature_extract)
        net.fc = nn.Linear(net.fc.in_features, num_classes)
        input_size = 224
    elif model_name == "densenet":
        # DenseNet-121
        net = models.densenet121(pretrained=use_pretrained)
        set_parameter_requires_grad(net, feature_extract)
        net.classifier = nn.Linear(net.classifier.in_features, num_classes)
        input_size = 224
    elif model_name == "densenet169":
        net = models.densenet169(pretrained=use_pretrained)
        set_parameter_requires_grad(net, feature_extract)
        net.classifier = nn.Linear(net.classifier.in_features, num_classes)
        input_size = 224
    elif model_name == "inception":
        # Inception v3 expects 299x299 inputs and has an auxiliary classifier
        # whose head must be resized as well.
        net = models.inception_v3(pretrained=use_pretrained)
        set_parameter_requires_grad(net, feature_extract)
        net.AuxLogits.fc = nn.Linear(net.AuxLogits.fc.in_features, num_classes)
        net.fc = nn.Linear(net.fc.in_features, num_classes)
        input_size = 299
    else:
        print("Invalid model name, exiting...")
        # Same effect as exit(), without depending on the interactive `site` helper.
        raise SystemExit

    if checkpoint is None:
        checkpoint = default_checkpoints.get(model_name)
    if checkpoint is not None:
        # The model is still on the CPU here; mapping the tensors to the CPU lets
        # GPU-saved checkpoints load on machines without CUDA.
        net.load_state_dict(torch.load(checkpoint, map_location='cpu'))

    return net, input_size


# Build the network for this run and obtain the square input resolution it expects.
model_ft, input_size = initialize_model(
    model_name,
    num_classes,
    feature_extract,
    use_pretrained=True,
)

def _eval_transform(size):
    """Return a deterministic inference pipeline for *size* x *size* inputs.

    Resize so the shorter image side equals *size*, take the central
    *size* x *size* crop, convert to a tensor, and normalise with the
    ImageNet channel mean/std.
    """
    return transforms.Compose([
        transforms.Resize(size),
        transforms.CenterCrop(size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])


# Preprocessing per dataset split. The 'test' split feeds the top-k prediction
# loop, so it must not use random augmentation: RandomResizedCrop /
# RandomHorizontalFlip at inference time make the reported predictions differ
# from run to run.
data_transforms = {
    'test': _eval_transform(input_size),
    'val': _eval_transform(input_size),
}

# Only the 'test' split is loaded; ImageFolder reads data_dir/test/<class>/<image>.
image_datasets = {
    split: datasets.ImageFolder(os.path.join(data_dir, split), data_transforms['test'])
    for split in ('test',)
}

# Deterministic order, one image per batch (the configured batch_size is not
# used here; the prediction loop reads row 0 of each batch).
# NOTE(review): num_workers=2 with no `if __name__ == "__main__":` guard only
# works where worker processes are forked (e.g. Linux).
dataloaders_dict = {
    split: torch.utils.data.DataLoader(
        image_datasets[split],
        batch_size=1,
        shuffle=False,
        num_workers=2,
    )
    for split in ('test',)
}


# Class index -> label used to name predictions. Hard-coded; presumably it must
# match the class order the loaded checkpoint was trained with — verify.
# (ImageFolder's class_to_idx maps label -> index, i.e. the inverse of this.)
a = {
    0: '140112199012187729',
    1: '140112199012197730',
    2: '140112199012197731',
    3: '140112199012197732',
    4: '140112199012197733',
}

# Counters initialised to zero; nothing visible in this file reads them.
o = oi = 0

# How many of the highest-probability classes to report per image.
topk = 3
# Run the model over the test split and print, for every image, its top-`topk`
# class probabilities together with the matching labels from `a`.
for phase in ['test']:
    model_ft.eval()                      # inference behaviour for dropout / batch-norm
    model_ft = model_ft.to(device)
    with torch.no_grad():                # prediction only: skip building an autograd graph
        for images, _labels in dataloaders_dict[phase]:
            images = images.to(device)
            outputs = model_ft(images)
            # Explicit dim=1 (the axis topk ranks over); the previous implicit-dim
            # nn.Softmax() is deprecated.
            ps = torch.softmax(outputs, dim=1)
            # Use the configured `topk` (without overwriting it) and never ask
            # for more entries than there are classes, which topk would reject.
            k = min(topk, ps.size(1))
            top_p_all, topclass = ps.topk(k, dim=1)
            top_p_all = top_p_all.cpu().numpy()
            topclass = topclass.cpu().numpy()
            # One line per image in the batch; slicing row:row + 1 keeps the
            # (1, k) array shape that was printed before.
            for row in range(top_p_all.shape[0]):
                top_p = top_p_all[row:row + 1]
                top_classes = [a[int(class_)] for class_ in topclass[row]]
                print(top_p, top_classes)

你可能感兴趣的:(pytorch)