import os
from pathlib import Path
from PIL import Image
from torch.utils.data import DataLoader,Dataset
import numpy as np
from torchvision import transforms, models
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import torch
from torch import nn
import torch.optim as optim
import torch.nn.functional as F
import random
import time
%matplotlib inline
# Maps the class-name prefix of a training file name to its integer label.
label_dict = {'cat': 0, 'dog': 1}

# Raw string: the original literal contained the invalid escape sequences
# '\图' and '\k', which newer Python versions warn about. The resulting path
# value is unchanged.
path = Path(r'D:\图像数据集\kaggle_cat_vs_dog')

# The text before the first '.' of each file name is looked up in label_dict.
# Entries whose prefix is not a known class (stray or hidden files) are
# skipped instead of raising KeyError during the label lookup below.
all_file_name = [
    name
    for name in os.listdir(path / 'train')
    if name.split('.')[0] in label_dict
]
random.shuffle(all_file_name)
all_labels = [label_dict[name.split('.')[0]] for name in all_file_name]

# Hold out 30% of the labelled files for validation.
Xtrain, Xvalid, Ytrain, Yvalid = train_test_split(
    all_file_name,
    all_labels,
    test_size=0.3,
)
# Per-channel mean / std passed to Normalize in both pipelines below.
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]

# Training pipeline: random rotation, random resized crop to 224 and random
# horizontal flip, followed by tensor conversion and normalisation.
train_transforms = transforms.Compose(
    [
        transforms.RandomRotation(30),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(_CHANNEL_MEAN, _CHANNEL_STD),
    ]
)

# Deterministic pipeline: resize to 255, centre-crop to 224, then the same
# tensor conversion and normalisation as above.
test_transforms = transforms.Compose(
    [
        transforms.Resize(255),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(_CHANNEL_MEAN, _CHANNEL_STD),
    ]
)
class CatvsDogDataset(Dataset):
    """Image dataset reading files from ``path / mode``.

    Args:
        path: Root directory of the dataset (a ``pathlib.Path``).
        mode: Sub-directory name. ``'train'`` yields ``(image, label)``
            pairs (used for both the training and the validation split);
            any other value (e.g. ``'test1'``) yields the image only.
        file_names: File names located inside ``path / mode``.
        labels: Integer labels aligned with ``file_names``. Only read when
            ``mode == 'train'``, so it may be ``None`` otherwise.
        transform: Optional callable applied to the PIL image. When omitted,
            ``train_transforms`` is used for ``mode == 'train'`` and
            ``test_transforms`` for every other mode.
    """

    def __init__(self, path, mode, file_names, labels, transform=None):
        self.data_list = [path / mode / file for file in file_names]
        self.labels = labels
        self.mode = mode
        self.transform = transform

    def __getitem__(self, index):
        """Return the transformed image at ``index`` (plus label in train mode)."""
        # The context manager releases the file handle; convert('RGB')
        # returns a loaded copy with three channels, which the 3-channel
        # normalisation in the transform pipelines requires.
        with Image.open(self.data_list[index]) as raw:
            img = raw.convert('RGB')

        if self.mode == 'train':
            transform = train_transforms if self.transform is None else self.transform
            # Label is wrapped in a length-1 LongTensor, as before.
            return transform(img), torch.LongTensor([self.labels[index]])

        transform = test_transforms if self.transform is None else self.transform
        return transform(img)

    def __len__(self):
        """Return the number of files in the dataset."""
        # Counting files (not labels) keeps this valid when no labels exist.
        return len(self.data_list)
# Both splits are read from the 'train' directory in 'train' mode so that
# every sample comes with its label.
# NOTE(review): in 'train' mode the dataset applies train_transforms by
# default, so validation images are randomly augmented too — confirm this is
# intended.
_train_split = CatvsDogDataset(path, 'train', Xtrain, Ytrain)
_valid_split = CatvsDogDataset(path, 'train', Xvalid, Yvalid)

train_loader = DataLoader(_train_split, batch_size=32, shuffle=True)
valid_loader = DataLoader(_valid_split, batch_size=32, shuffle=True)
class Feature_net(nn.Module):
    """Feature extractor built from a pretrained torchvision model.

    The selected model's child modules are wrapped in ``nn.Sequential`` with
    the trailing module(s) removed; ``forward`` flattens the result to
    ``(batch, features)``.

    Args:
        model: Backbone name, one of ``'vgg'``, ``'inception_v3'`` or
            ``'resnet152'``.

    Raises:
        ValueError: If ``model`` is not one of the supported names.
    """

    def __init__(self, model):
        super().__init__()
        if model == 'vgg':
            vgg = models.vgg19(pretrained=True)
            # Keep every child module except the last two.
            self.feature = nn.Sequential(*list(vgg.children())[:-2])
        elif model == 'inception_v3':
            inception = models.inception_v3(pretrained=True)
            # Keep every child module except the last one.
            self.feature = nn.Sequential(*list(inception.children())[:-1])
            # NOTE(review): child '13' is presumably the auxiliary classifier
            # and the pooling window of 35 presumably matches the final
            # feature-map size; both depend on the torchvision version —
            # TODO confirm.
            self.feature._modules.pop('13')
            self.feature.add_module('global average', nn.AvgPool2d(35))
        elif model == 'resnet152':
            resnet = models.resnet152(pretrained=True)
            # Keep every child module except the last one.
            self.feature = nn.Sequential(*list(resnet.children())[:-1])
        else:
            # Fail fast instead of leaving ``self.feature`` undefined and
            # crashing later inside ``forward``.
            raise ValueError(
                f"unsupported model {model!r}; expected 'vgg', "
                "'inception_v3' or 'resnet152'"
            )

    def forward(self, x):
        """Run the backbone on ``x`` and flatten each sample to a vector."""
        x = self.feature(x)
        return x.view(x.size(0), -1)
class Classifier(nn.Module):
    """Two-layer fully connected head applied to flattened features.

    Args:
        dim: Size of each input feature vector.
        n_class: Number of output scores per sample.
    """

    def __init__(self, dim, n_class):
        super().__init__()
        hidden_units = 1000
        # Linear -> ReLU -> Dropout(0.5) -> Linear, kept under ``self.fc``.
        stages = [
            nn.Linear(dim, hidden_units),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(hidden_units, n_class),
        ]
        self.fc = nn.Sequential(*stages)

    def forward(self, x):
        """Return the ``n_class`` scores computed by the head for ``x``."""
        return self.fc(x)
# Use the GPU when one is present and fall back to the CPU otherwise, in line
# with the torch.cuda.is_available() guards used elsewhere in this script.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Frozen feature extractor: none of its parameters receive gradients.
model_fe = Feature_net('vgg')
for param in model_fe.parameters():
    param.requires_grad = False

# Trainable head producing 2 scores from 25088-dimensional inputs.
# NOTE(review): 25088 presumably equals the flattened 'vgg' feature size for
# 224x224 inputs (512 * 7 * 7) — confirm if the backbone or crop size changes.
model_clf = Classifier(25088, 2)

model_fe = model_fe.to(device)
model_clf = model_clf.to(device)

criterion = torch.nn.CrossEntropyLoss()
# Only the classifier's parameters are handed to the optimizer.
optimizer = torch.optim.Adam(model_clf.parameters())
def now(x=None):
    """Return the current wall-clock time in seconds; ``x`` is ignored.

    The argument is optional so that the no-argument calls below work (the
    previous one-argument lambda raised TypeError when called as ``now()``).
    """
    return time.time()


# Train on whichever device the classifier's parameters already live on.
device = next(model_clf.parameters()).device

begin_time = now()
# The extractor is frozen, so keep it in inference mode; only the classifier
# head is put in training mode.
model_fe.eval()
model_clf.train()
for epoch in range(1):
    sum_loss = 0.0
    for i, (imgs, labels) in enumerate(train_loader):
        imgs, labels = imgs.to(device), labels.to(device)

        # No gradients are needed through the frozen extractor.
        with torch.no_grad():
            features = model_fe(imgs)
        logits = model_clf(features)

        # view(-1) flattens the (batch, 1) labels to (batch,) and, unlike
        # squeeze(), still yields a 1-D target when the batch has one sample.
        loss = criterion(logits, labels.view(-1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        sum_loss += loss.item()
        # Report the mean loss of every 100 mini-batches, then reset.
        if i % 100 == 99:
            print('[%d,%d] loss:%.03f' %
                  (epoch + 1, i + 1, sum_loss / 100))
            sum_loss = 0.0
print(f'耗时:{now()-begin_time}s')
# Spot check: classify a single validation batch and print its accuracy.
tmpx, tmpy = next(iter(valid_loader))
print(tmpx.shape)

# Feed the batch on whichever device the classifier's parameters live on.
device = next(model_clf.parameters()).device
model_fe.eval()
model_clf.eval()
with torch.no_grad():
    out = model_clf(model_fe(tmpx.to(device)))

# Index of the highest score per sample, computed once and reused.
tmp_pred = torch.max(out, 1)[1]
# Flatten the (batch, 1) labels to (batch,).
tmp_target = tmpy.view(-1)
print(tmp_pred)
print(tmp_target)
# Divide by the real batch size rather than a hard-coded 32, so the figure
# stays correct for a short batch.
print((tmp_pred.cpu() == tmp_target).sum().item() / tmp_target.size(0))
# Accuracy over the whole validation loader, with both networks in eval mode
# and gradient tracking disabled.
model_fe.eval()
model_clf.eval()
with torch.no_grad():
    eval_acc = 0
    for img, label in valid_loader:
        if torch.cuda.is_available():
            img, label = img.cuda(), label.cuda()
        scores = model_clf(model_fe(img))
        # Index of the highest score per sample.
        pred = torch.max(scores, 1)[1]
        eval_acc += (pred == label.squeeze()).sum().item()
print(f'Acc: {eval_acc/len(Yvalid)}')