训练集5000张图片,每类500张;验证集1000张,每类100张。图片命名格式如下图所示,即“名称_类别编号.扩展名”(例如 xxx_3.png,下划线后面的数字就是该图片的类别标签)。
训练集、验证集分为两个文件夹存放。
class AlexNet(nn.Module):
    """AlexNet-style image classifier for 3 x 227 x 227 inputs.

    The network is five convolutional layers followed by three fully
    connected layers. No local response normalization is applied.

    Args:
        num_classes: Number of output scores produced by the last layer.
            Defaults to 10.
    """

    def __init__(self, num_classes=10):
        super(AlexNet, self).__init__()
        # Input size: [3 x 227 x 227].
        self.conv1 = nn.Conv2d(3, 96, 11, stride=4)
        self.conv2 = nn.Conv2d(96, 256, 5, padding=2)
        self.conv3 = nn.Conv2d(256, 384, 3, padding=1)
        self.conv4 = nn.Conv2d(384, 384, 3, padding=1)
        self.conv5 = nn.Conv2d(384, 256, 3, padding=1)
        self.fc6 = nn.Linear(256 * 6 * 6, 4096)
        self.fc7 = nn.Linear(4096, 4096)
        self.fc8 = nn.Linear(4096, num_classes)

    def forward(self, x):
        """Compute the raw class scores for a batch of images.

        Args:
            x: Float tensor of shape (batch, 3, 227, 227).

        Returns:
            Tensor of shape (batch, num_classes) holding unnormalized scores
            (no softmax is applied).
        """
        # 227 -> 55 (conv1, stride 4) -> 27 (3x3 pool, stride 2)
        c1 = self.conv1(x)
        r1 = F.relu(c1)
        p1 = F.max_pool2d(r1, (3, 3), stride=2)

        # 27 -> 27 (conv2, padded) -> 13 (pool)
        c2 = self.conv2(p1)
        r2 = F.relu(c2)
        p2 = F.max_pool2d(r2, (3, 3), stride=2)

        # conv3..conv5 keep the 13 x 13 resolution; the last pool gives 6 x 6.
        c3 = self.conv3(p2)
        r3 = F.relu(c3)
        c4 = self.conv4(r3)
        r4 = F.relu(c4)
        c5 = self.conv5(r4)
        r5 = F.relu(c5)
        p5 = F.max_pool2d(r5, (3, 3), stride=2)

        flatten = p5.view(-1, 256 * 6 * 6)

        # F.dropout defaults to training=True, which would keep dropout
        # active even after net.eval(). Tie it to the module's mode so that
        # inference is deterministic.
        f6 = self.fc6(flatten)
        r6 = F.relu(f6)
        d6 = F.dropout(r6, training=self.training)
        f7 = self.fc7(d6)
        r7 = F.relu(f7)
        d7 = F.dropout(r7, training=self.training)
        f8 = self.fc8(d7)
        return f8
PyTorch其实提供了LRN层(nn.LocalResponseNorm),这里为了简化没有使用。也没有做Crop,直接把图片缩放到227*227大小输进去。
class MyDataset_Cifar10(Dataset):
    """Image-folder dataset whose labels are encoded in the file names.

    Every file in ``image_dir`` must be named ``<name>_<label>.<ext>``, where
    ``<label>`` is an integer class index (the text between the last
    underscore and the extension).

    Args:
        image_dir: Directory containing the image files. A trailing path
            separator is optional.

    Raises:
        ValueError: If a file name does not end in ``_<integer>.<ext>``.
    """

    def __init__(self, image_dir):
        self.root_dir = image_dir
        # os.listdir order is arbitrary; sort so that the index -> sample
        # mapping is reproducible between runs.
        self.name_list = sorted(os.listdir(image_dir))
        self.label_list = []
        for file_name in self.name_list:
            stem, _ext = os.path.splitext(file_name)
            # Split on the LAST underscore so names such as 'big_cat_3.png'
            # are handled, and do not assume a three-letter extension.
            _prefix, sep, label = stem.rpartition('_')
            try:
                if not sep:
                    raise ValueError('missing underscore')
                int(label)
            except ValueError:
                raise ValueError(
                    'cannot parse an integer label from file name: {!r}'
                    .format(file_name))
            # Labels are stored as strings and converted in __getitem__.
            self.label_list.append(label)

    def __len__(self):
        """Return the number of files found in the directory."""
        return len(self.name_list)

    def __getitem__(self, item):
        """Load one sample.

        Args:
            item: Integer index into the file list.

        Returns:
            Tuple ``(img, label)``: ``img`` is a float32 tensor of shape
            (3, 227, 227) in RGB order scaled to [0, 1]; ``label`` is a
            scalar integer tensor.

        Raises:
            OSError: If the image file cannot be read.
        """
        path = os.path.join(self.root_dir, self.name_list[item])
        img = cv.imread(path)
        if img is None:
            # cv.imread signals failure by returning None instead of raising.
            raise OSError('failed to read image: {}'.format(path))
        img = cv.cvtColor(img, cv.COLOR_BGR2RGB).astype(np.float32) / 255.0
        img = cv.resize(img, (227, 227))
        # HWC -> CHW, the layout expected by nn.Conv2d.
        img = img.transpose((2, 0, 1))
        img = torch.tensor(img)
        label = torch.tensor(int(self.label_list[item]))
        return img, label
首先需要告诉DataLoader你的数据在哪,所以需要传入image_dir。其次,建立数据和标签的对应关系表,从而在__len__函数中返回数据的总量。最后,根据__getitem__函数的item索引,返回一个数据和它对应的标签。注意,这里返回的数据和标签最好是能直接拿来训练的形式,而不是原始的RGB图像,所以上面的代码依次做了转浮点并归一化到[0,1]、缩放到227*227、通道转换(HWC转CHW)和张量化处理。
def train(train_dir='/home/dl/DeepHashing/CIFAR10/train/',
          val_dir='/home/dl/DeepHashing/CIFAR10/query/',
          save_path='./torch_test.pkl',
          max_epoch=50):
    """Train AlexNet with SGD and save the best-validating weights.

    After every epoch the network is evaluated on the validation set; each
    time the validation accuracy improves, the current state dict is written
    to ``save_path``.

    Args:
        train_dir: Directory with the training images.
        val_dir: Directory with the validation images.
        save_path: File the best state dict is saved to.
        max_epoch: Number of passes over the training set.
    """
    test_epoch = 1          # validate every `test_epoch` epochs
    display = 10            # print the loss every `display` iterations
    train_batch_size = 128
    val_batch_size = 64

    # Use the GPU when there is one, otherwise fall back to the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net = AlexNet().to(device)
    best_acc = 0.0

    cross_entropy_loss = nn.CrossEntropyLoss()
    # NOTE: no learning-rate scheduler is used; the rate stays constant.
    optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9,
                                weight_decay=0.0005)

    train_set = MyDataset_Cifar10(train_dir)
    val_set = MyDataset_Cifar10(val_dir)
    trainloader = torch.utils.data.DataLoader(
        train_set, batch_size=train_batch_size, shuffle=True)
    valloader = torch.utils.data.DataLoader(
        val_set, batch_size=val_batch_size, shuffle=False)

    for e in range(max_epoch):
        print('Epoch {}/{}'.format(e, max_epoch))
        print('-' * 10)

        net.train()
        for i, (inputs, labels) in enumerate(trainloader, 0):
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = net(inputs)
            loss = cross_entropy_loss(outputs, labels)
            loss.backward()
            optimizer.step()

            if i % display == 0:
                print('{} train loss:{} learning rate:{}'.
                      format(i * train_batch_size, loss.item(),
                             optimizer.param_groups[0]['lr']))

        if e % test_epoch == 0:
            print('testing...')
            net.eval()
            # Plain int counter: stays valid even if the loader is empty.
            correct = 0
            with torch.no_grad():
                for inputs, labels in valloader:
                    inputs = inputs.to(device)
                    labels = labels.to(device)
                    outputs = net(inputs)
                    _, preds = torch.max(outputs, 1)
                    correct += torch.sum(preds == labels).item()
            # Divide by the real validation-set size, not a fixed 1000.
            acc = correct / len(val_set) if len(val_set) else 0.0
            print('val acc:{}'.format(acc))
            if acc > best_acc:
                best_acc = acc
                # Save right away: state_dict() returns references to the
                # live parameters, so holding on to it would not preserve
                # a snapshot of the best weights.
                torch.save(net.state_dict(), save_path)
这段代码跑出来的验证集准确率是54.8%。网上其他AlexNet的准确率在60%-70%左右,考虑到我们只用了5000个训练样本,这个结果应该是合理的。
def test(image_path='/home/dl/test.jpg', weights_path='./torch_test.pkl'):
    """Classify a single image with saved AlexNet weights and print the class.

    The image goes through the same preprocessing as the training data
    (BGR -> RGB, scale to [0, 1], resize to 227 x 227, HWC -> CHW).

    Args:
        image_path: Image file to classify.
        weights_path: State-dict file to load into the network.

    Raises:
        OSError: If the image cannot be read.
    """
    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    img = cv.imread(image_path)
    if img is None:
        # cv.imread signals failure by returning None instead of raising.
        raise OSError('failed to read image: {}'.format(image_path))
    img = cv.cvtColor(img, cv.COLOR_BGR2RGB).astype(np.float32) / 255.0
    img = cv.resize(img, (227, 227))
    img = img.transpose((2, 0, 1))
    # Add the batch dimension: (3, 227, 227) -> (1, 3, 227, 227).
    img = torch.tensor(img).unsqueeze(0)

    # Use the GPU when there is one, otherwise fall back to the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    img = img.to(device)

    net = AlexNet().to(device)
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source. map_location lets GPU-saved weights load on CPU.
    net.load_state_dict(torch.load(weights_path, map_location=device))
    net.eval()

    # No gradients are needed for inference.
    with torch.no_grad():
        outputs = net(img)
    _, preds = torch.max(outputs, 1)
    print(classes[preds.item()])
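输出frog,为什么输出frog?-_- 可能的原因:模型在验证集上的准确率只有54.8%,分错是很常见的;另外要注意F.dropout默认training=True,如果在forward里没有传training=self.training,即使调用了net.eval(),dropout仍然会生效,预测结果会带有随机性。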
补充:头文件和主函数
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim
import torch.optim.lr_scheduler
from torch.utils.data import DataLoader,Dataset
from torch.autograd import Variable
import os
import cv2 as cv
import numpy as np
if __name__ == '__main__':
    # Script entry point: only inference runs by default.
    # Enable the call below to train the network first.
    # train()
    test()
补充:测试的模型(模型效果很拉胯,熟悉个代码流程)
链接:https://pan.baidu.com/s/1YSDNUbwytFhw7X9_mQWciA
提取码:dux9