本文使用 PyTorch 提供的预训练模型（AlexNet），在牛津大学提供的 17 类花卉数据集（Oxford 17 Category Flower Dataset）上进行训练。该数据集共 1360 张图片，分为 17 类，每类 80 张。为方便起见，我将图片顺序打乱后取出其中 1000 张作为训练集，其余 360 张作为验证集。数据集下载链接：（原文链接已丢失）。
import torch.nn as nn
from torch.utils import data
from torchvision import transforms
from PIL import Image
from torchvision import models as MD
import torch
class Make_data(data.Dataset):
    """Image-classification dataset driven by a plain-text index file.

    Each non-blank line of ``txt`` has the form ``<image path> <integer label>``
    (whitespace-separated; any further fields are ignored). The image is read
    from ``img + <image path>``, converted to RGB, and passed through
    ``tensform``.

    Args:
        txt: Path to the index file.
        img: Prefix prepended verbatim to every image path (e.g. a directory
            ending in ``/``).
        tensform: Callable applied to each loaded PIL image, typically a
            torchvision transform pipeline.
    """

    def __init__(self, txt, img, tensform):
        self.image = []  # one stripped "<path> <label>" string per sample
        self.txt = txt
        self.img = img
        self.tensform = tensform
        # `with` guarantees the index file is closed (the original leaked it).
        with open(self.txt) as index_file:
            for line in index_file:
                entry = line.strip()
                # Skip blank lines (e.g. a trailing newline at EOF); they would
                # otherwise become samples with no label and crash __getitem__.
                if entry:
                    self.image.append(entry)

    def __getitem__(self, item):
        """Return ``(transformed_image, label)`` for sample ``item``."""
        fields = self.image[item].split()  # tolerant of repeated spaces/tabs
        img_path = fields[0]
        targe = int(fields[1])
        # convert("RGB") forces three channels so grayscale/RGBA/CMYK files
        # produce the same tensor layout as ordinary RGB JPEGs; the `with`
        # closes the underlying file handle once the converted copy exists.
        with Image.open(self.img + img_path) as raw:
            img = self.tensform(raw.convert("RGB"))
        return img, targe

    def __len__(self):
        """Number of samples listed in the index file."""
        return len(self.image)
def train(dada_loader, epochs=200, lr=0.001, num_classes=17, save_every=20,
          weights_path="../models/alexnet-owt.pth", save_dir="../models/"):
    """Fine-tune a pretrained AlexNet on ``dada_loader`` and write checkpoints.

    The last layer of AlexNet's classifier is replaced by a freshly initialised
    ``nn.Linear`` with ``num_classes`` outputs. The whole network is then
    trained with SGD and cross-entropy loss. The summed loss is printed after
    every epoch.

    Args:
        dada_loader: Iterable of ``(images, labels)`` batches.
        epochs: Number of passes over ``dada_loader``.
        lr: SGD learning rate.
        num_classes: Output size of the replacement classification layer.
        save_every: A checkpoint is written every ``save_every`` epochs and
            always after the final epoch.
        weights_path: File holding the pretrained AlexNet ``state_dict``.
        save_dir: Prefix for checkpoints, written as ``<save_dir><epoch>.pkl``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = MD.alexnet(pretrained=False)
    # map_location="cpu" makes loading independent of the device the weights
    # were saved from; the model is moved to `device` afterwards.
    model.load_state_dict(torch.load(weights_path, map_location="cpu"))
    num_input = model.classifier[6].in_features
    layers = list(model.classifier.children())
    layers[-1] = nn.Linear(num_input, num_classes)  # new task-specific head
    model.classifier = nn.Sequential(*layers)
    model = model.to(device)
    model.train()

    critersion = nn.CrossEntropyLoss()
    opt = torch.optim.SGD(model.parameters(), lr=lr)

    for epoch in range(epochs):
        epoch_loss = 0.0
        for img, targe in dada_loader:
            img = img.to(device)
            targe = targe.to(device)
            output = model(img)
            loss = critersion(output, targe)
            opt.zero_grad()
            loss.backward()
            opt.step()
            # .item() yields a plain float; summing the loss tensor itself
            # (as before) accumulated autograd history across the epoch.
            epoch_loss += loss.item()
        print(epoch, epoch_loss)
        # Also save after the last epoch; `epoch % 20 == 0` alone never
        # persisted the fully trained model.
        if epoch % save_every == 0 or epoch == epochs - 1:
            torch.save(model.state_dict(), save_dir + str(epoch) + ".pkl")
def test(dada_loader, checkpoint_path="", num_classes=17):
    """Evaluate a fine-tuned AlexNet checkpoint on ``dada_loader``.

    Rebuilds the same architecture that ``train`` produces (AlexNet with a
    ``num_classes``-way final layer) and loads ``checkpoint_path`` into it. It
    then prints the number of correct predictions per batch, followed by the
    overall accuracy.

    Args:
        dada_loader: Iterable of ``(images, labels)`` batches.
        checkpoint_path: ``state_dict`` file written by ``train``.
        num_classes: Output size of the final classification layer; must match
            the checkpoint.

    Returns:
        Fraction of correctly classified samples (0.0 for an empty loader).

    Raises:
        ValueError: If ``checkpoint_path`` is empty.
    """
    if not checkpoint_path:
        raise ValueError("checkpoint_path must point to a checkpoint saved by train()")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = MD.alexnet(pretrained=False)
    num_input = model.classifier[6].in_features
    layers = list(model.classifier.children())
    layers[-1] = nn.Linear(num_input, num_classes)
    model.classifier = nn.Sequential(*layers)
    # Load the trained weights (checkpoints may have been saved from a GPU).
    model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
    model = model.to(device)
    # eval() turns off the Dropout layers in AlexNet's classifier; without it
    # predictions were stochastic and accuracy was understated.
    model.eval()

    correct = 0
    total = 0
    with torch.no_grad():  # inference only: no autograd graph needed
        for img, targe in dada_loader:
            img = img.to(device)
            targe = targe.to(device)
            pred = model(img).argmax(dim=1)
            batch_correct = int((pred == targe).sum())
            print(batch_correct)
            correct += batch_correct
            total += targe.size(0)

    accuracy = correct / total if total else 0.0
    print(accuracy)
    return accuracy
if __name__ == '__main__':
    tensform = transforms.Compose([
        # Resize replaces transforms.Scale, a deprecated alias that newer
        # torchvision releases no longer ship.
        transforms.Resize([224, 224]),
        transforms.ToTensor(),
        # The pretrained AlexNet weights expect ImageNet-normalized input;
        # raw [0, 1] tensors shift the statistics the features were learned on.
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    traindata = Make_data(txt="../Data/train.txt", img="../Data/flower/", tensform=tensform)
    testdata = Make_data(txt="../Data/test.txt", img="../Data/flower/", tensform=tensform)
    # Reshuffle training batches every epoch; evaluation order is irrelevant.
    train_loader = torch.utils.data.DataLoader(traindata, batch_size=50, shuffle=True)
    test_loader = torch.utils.data.DataLoader(testdata, batch_size=50)
    train(train_loader)
    # test(test_loader)