首先介绍自定义的数据集类 Mydataset(文件 mydataset.py):
import os
import glob
import csv
import random
from PIL import Image
import torch
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
class Mydataset(Dataset):
    """Image-classification dataset read from ``root/<class_name>/*``.

    Every sub-directory of ``root`` is one class. Class indices are assigned
    in sorted directory-name order. The list of (image path, label) pairs is
    cached in a CSV file under ``root`` and split 60% / 20% / 20% into
    train / val / test.
    """

    # Per-channel normalization constants, shared by the transform pipeline
    # and by denormalize() so the two can never drift apart.
    MEAN = (0.485, 0.456, 0.406)
    STD = (0.229, 0.224, 0.225)

    def __init__(self, root, resize, mode):
        """
        Args:
            root: dataset root directory; each sub-directory is one class.
            resize: side length in pixels of the square output images.
            mode: 'train' selects the first 60% of the cached list and
                'val' the next 20%. Any other value selects the last 20%
                (test).
        """
        super(Mydataset, self).__init__()
        self.root = root
        self.resize = resize
        self.mode = mode

        self.name2label = {}  # class-directory name -> 0, 1, 2, ...
        for name in sorted(os.listdir(root)):
            if not os.path.isdir(os.path.join(root, name)):
                continue
            self.name2label[name] = len(self.name2label)
        print(self.name2label)

        # The cache file name (spelling included) is kept unchanged so that
        # existing CSV caches, and therefore existing splits, stay valid.
        self.images, self.labels = self.load_csv('imagess.csv')

        n = len(self.images)
        if mode == 'train':  # 60%: 0% -> 60%
            lo, hi = 0, int(0.6 * n)
        elif mode == 'val':  # 20%: 60% -> 80%
            lo, hi = int(0.6 * n), int(0.8 * n)
        else:  # 20%: 80% -> 100%
            lo, hi = int(0.8 * n), n
        self.images = self.images[lo:hi]
        self.labels = self.labels[lo:hi]

        # Build the pipeline once instead of on every __getitem__ call.
        self.transform = self._build_transform()

    def _build_transform(self):
        """Return the tensor pipeline; random rotation is applied only for 'train'."""
        enlarged = int(self.resize * 1.25)
        steps = [transforms.Resize((enlarged, enlarged))]
        if self.mode == 'train':
            # Augment training data only; val/test must be deterministic so
            # that accuracies are reproducible and comparable.
            steps.append(transforms.RandomRotation(15))
        steps += [
            transforms.CenterCrop(self.resize),
            transforms.ToTensor(),
            transforms.Normalize(mean=self.MEAN, std=self.STD),
        ]
        return transforms.Compose(steps)

    def load_csv(self, filename):
        """Return ``(image_paths, labels)``, creating the CSV cache on first use.

        If ``root/filename`` does not exist, all .png/.jpg/.jpeg files in each
        class directory are collected, shuffled once, and written as
        ``path,label`` rows. Later calls read that file back, so the order
        (and the train/val/test split derived from it) stays fixed.
        """
        csv_path = os.path.join(self.root, filename)
        if not os.path.exists(csv_path):
            images = []
            for name in self.name2label:
                for pattern in ('*.png', '*.jpg', '*.jpeg'):
                    images += glob.glob(os.path.join(self.root, name, pattern))
            print(len(images), images)
            random.shuffle(images)
            with open(csv_path, mode='w', newline='') as f:
                writer = csv.writer(f)
                for img in images:
                    # The image's parent directory name is its class name.
                    name = os.path.basename(os.path.dirname(img))
                    writer.writerow([img, self.name2label[name]])
            print('writen into csv file:', filename)

        # Read the cache back (newline='' as the csv module requires).
        images, labels = [], []
        with open(csv_path, newline='') as f:
            for row in csv.reader(f):
                if not row:  # tolerate blank lines, e.g. a hand-edited file
                    continue
                img, label = row
                images.append(img)
                labels.append(int(label))
        return images, labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image_tensor, label_tensor)`` for index ``idx``."""
        img_path, label = self.images[idx], self.labels[idx]
        # Open the file here rather than via a lambda inside the pipeline, so
        # the dataset stays picklable for multi-process DataLoader workers.
        # The context manager closes the file handle right away.
        with Image.open(img_path) as im:
            img = im.convert('RGB')
        return self.transform(img), torch.tensor(label)

    def denormalize(self, x_hat):
        """Undo ``Normalize``: ``x = x_hat * std + mean``, per channel.

        ``x_hat`` must have its 3 channels at dim -3 (e.g. [c, h, w] or
        [b, c, h, w]). mean/std are reshaped to [3, 1, 1] so they broadcast.
        """
        mean = torch.tensor(self.MEAN, dtype=x_hat.dtype, device=x_hat.device).view(3, 1, 1)
        std = torch.tensor(self.STD, dtype=x_hat.dtype, device=x_hat.device).view(3, 1, 1)
        return x_hat * std + mean
def main():
    """Debug preview: push batches of the ``dataset`` folder to visdom, one per 10 s.

    Uses torchvision's ``ImageFolder`` (64x64 resize, no normalization), so
    this only checks the raw images on disk and does not go through Mydataset.
    """
    import visdom
    import time
    import torchvision

    board = visdom.Visdom()
    preview_pipeline = transforms.Compose(
        [transforms.Resize((64, 64)), transforms.ToTensor()]
    )
    folder_ds = torchvision.datasets.ImageFolder(root='dataset', transform=preview_pipeline)
    batches = DataLoader(folder_ds, batch_size=32, shuffle=True)

    for images, targets in batches:
        board.images(images, nrow=8, win='batch', opts={'title': 'batch'})
        board.text(str(targets.numpy()), win='label', opts={'title': 'batch-y'})
        time.sleep(10)  # leave time to inspect each batch


if __name__ == "__main__":
    main()
基于 resnet18 进行迁移学习
如何加载数据进行训练:首先完成辅助模块 Flatten.py,
其中包含 Flatten 层和 plot_image 函数
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
class Flatten(nn.Module):
    """Flatten all dimensions except the first (batch): [b, d1, d2, ...] -> [b, d1*d2*...].

    For example, it turns pooled [b, 512, 1, 1] features into [b, 512] so they
    can feed an ``nn.Linear`` layer.
    """

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        # torch.flatten accepts non-contiguous inputs (where .view() raises)
        # and needs no "-1" size inference, so zero-sized dims work too.
        return torch.flatten(x, 1)
def plot_image(img, label, name):
    """Show up to the first 6 images of a batch in a 2x3 grid.

    Args:
        img: batch where ``img[i][0]`` is a 2-D single-channel image
            (presumably a [b, 1, h, w] tensor -- TODO confirm).
        label: per-image labels; ``label[i].item()`` is shown in each title.
        name: prefix for every subplot title.
    """
    plt.figure()
    # Batches with fewer than 6 images used to raise IndexError.
    for i in range(min(6, len(img))):
        plt.subplot(2, 3, i + 1)
        # NOTE(review): 0.3081 / 0.1307 look like MNIST std / mean; confirm
        # they match the normalization of the data being plotted.
        plt.imshow(img[i][0] * 0.3081 + 0.1307, cmap='gray', interpolation='none')
        plt.title('{}: {}'.format(name, label[i].item()))
        plt.xticks([])
        plt.yticks([])
    # Lay out once, after all subplots exist.
    plt.tight_layout()
    plt.show()
完成训练程序 train_resnet18.py
import torch
import visdom
import torch.nn as nn
import torch.optim
from mydataset import Mydataset
from torch.utils.data import Dataset, DataLoader
from Flatten import Flatten
from torchvision.models.resnet import resnet18
# Training hyper-parameters.
batchsize = 32
learning_rate = 1e-5
epoches = 1  # number of passes over the training split
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Three disjoint 60/20/20 splits of the same 'datasets' folder. The second
# argument is Mydataset's `resize`, so images come out as 32x32 crops.
# NOTE(review): an ImageNet-pretrained resnet18 is usually fed larger inputs
# (e.g. 224) -- confirm 32 is intentional.
train_db = Mydataset('datasets', 32, mode='train')
val_db = Mydataset('datasets', 32, mode='val')
test_db = Mydataset('datasets', 32, mode='test')
# Only the training loader shuffles; val/test keep the cached CSV order.
train_loader = DataLoader(train_db, batch_size=batchsize, shuffle=True, num_workers=4)
val_loader = DataLoader(val_db, batch_size=batchsize, num_workers=2)
test_loader = DataLoader(test_db, batch_size=batchsize, num_workers=2)
# Train the model; progress is visualized with visdom.
# NOTE(review): this runs at import time, as do the dataset constructions above.
viz = visdom.Visdom()
def evaluate(model, loader):
    """Return the classification accuracy of ``model`` on ``loader`` (0.0-1.0).

    Runs in eval mode under ``torch.no_grad()``, so BatchNorm uses its running
    statistics and no gradients are tracked. The model's previous train/eval
    mode is restored afterwards. Returns 0.0 for an empty dataset. Inputs are
    moved to the module-level ``device``.
    """
    total = len(loader.dataset)
    if total == 0:
        return 0.0

    was_training = model.training
    model.eval()  # without this, BatchNorm would use and update batch stats
    correct = 0
    try:
        with torch.no_grad():
            for x, y in loader:
                x, y = x.to(device), y.to(device)
                pred = model(x).argmax(dim=1)
                correct += torch.eq(pred, y).sum().item()
    finally:
        model.train(was_training)
    return correct / total
def main():
    """Fine-tune a pretrained ResNet-18 on ``train_db`` and report val/test accuracy.

    All children of torchvision's ``resnet18`` except the last are kept. A
    ``Flatten`` plus a new ``nn.Linear(512, num_classes)`` head is added on
    top, and everything is trained with Adam and cross-entropy. Loss,
    validation accuracy and each training batch are shown in visdom.
    Whenever validation accuracy improves, the weights are saved to
    ``resnet18-circle25-50.pkl``. Those best weights are reloaded before the
    final test pass.
    """
    ckpt_path = 'resnet18-circle25-50.pkl'
    # One output per class directory found by Mydataset (2 for a two-class set).
    num_classes = len(train_db.name2label)

    backbone = resnet18(pretrained=True)  # ImageNet-pretrained weights
    model = nn.Sequential(*list(backbone.children())[:-1],  # -> [b, 512, 1, 1]
                          Flatten(),                         # [b, 512, 1, 1] -> [b, 512]
                          nn.Linear(512, num_classes)).to(device)  # new classifier head

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    best_acc, best_epoch = 0, 0
    saved = False  # whether ckpt_path holds weights written by this run
    global_step = 0
    viz.line([0], [-1], win='loss', opts=dict(title='loss'))
    viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))

    for epoch in range(epoches):
        model.train()  # evaluate() may have left the model in eval mode
        for x, y in train_loader:
            viz.images(train_db.denormalize(x), nrow=8, win='batch', opts=dict(title='batch'))
            viz.text(str(y.numpy()), win='label', opts=dict(title='batch-y'))
            x, y = x.to(device), y.to(device)
            logits = model(x)
            loss = criterion(logits, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            viz.line([loss.item()], [global_step], win='loss', update='append')
            global_step += 1

        # Validate once per epoch. Checkpoint on improvement, so the file
        # holds the best weights seen so far rather than the last ones.
        val_acc = evaluate(model, val_loader)
        if not saved or val_acc > best_acc:
            best_epoch, best_acc = epoch, val_acc
            torch.save(model.state_dict(), ckpt_path)
            saved = True
        viz.line([val_acc], [global_step], win='val_acc', update='append')

    print("best acc:", best_acc, "best epoch:", best_epoch)
    if saved:
        # Test the best-validation weights, not whatever the last epoch left.
        model.load_state_dict(torch.load(ckpt_path, map_location=device))
        print("loaded from ckpt!")
    test_acc = evaluate(model, test_loader)
    print("test acc:", test_acc)


if __name__ == "__main__":
    main()
训练过程使用 visdom 进行可视化,最终完成物体识别。