train.py
import os
import torch.optim as optim
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models
from torch.autograd import Variable
from PIL import ImageFile, Image
# Pillow process-wide settings applied before any image is decoded.
# Let Pillow load image files whose data is truncated rather than failing on them.
ImageFile.LOAD_TRUNCATED_IMAGES = True
# Remove Pillow's pixel-count limit so very large images can be opened.
# NOTE(review): this also switches off the decompression-bomb safeguard —
# confirm that only trusted images are ever read.
Image.MAX_IMAGE_PIXELS = None
# Dataset locations (each is later handed to torchvision's ImageFolder).
train_path = './data/doc_4_angles_dataset/train'
val_path = './data/doc_4_angles_dataset/val'
# Directory that receives one checkpoint file per epoch.
model_dir = './checkpoints'
# makedirs(exist_ok=True) is race-free and also works if the path is nested,
# unlike the previous "check exists, then mkdir" pair.
os.makedirs(model_dir, exist_ok=True)
# Training hyper-parameters.
gpu_id = [1, 2, 3, 4]  # CUDA device indices used for data-parallel training
cls_num = 4  # number of output classes
model_lr = 1e-4  # initial learning rate
EPOCHS = 100
# The global batch scales with the number of GPUs (512 samples per device).
BATCH_SIZE = 512 * len(gpu_id)
# Primary device: the first GPU in the list.
DEVICE = torch.device('cuda:{}'.format(gpu_id[0]))
# Training-set preprocessing and augmentation.
# NOTE(review): training sees random 50x50 crops while evaluation (transform_test)
# sees the full 128x128 image — confirm this mismatch is intended.
# NOTE(review): a vertical flip changes the apparent orientation of a page; for an
# angle-classification dataset, confirm it does not conflict with the labels.
train_steps = [
    transforms.Resize((128, 128)),  # scale every image to 128x128
    transforms.RandomVerticalFlip(),  # random vertical flip
    transforms.RandomCrop(50),  # random 50x50 crop
    # Jitter brightness, contrast and hue (saturation is left untouched).
    transforms.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),  # per-channel mean 0.5, std 0.5
]
transform = transforms.Compose(train_steps)

# Validation-set preprocessing: deterministic resize + normalisation only.
eval_steps = [
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
]
transform_test = transforms.Compose(eval_steps)
# Datasets: one ImageFolder per split, each with its own preprocessing pipeline.
dataset_train = datasets.ImageFolder(root=train_path, transform=transform)
dataset_test = datasets.ImageFolder(root=val_path, transform=transform_test)

# Batch loaders: training batches are reshuffled every epoch, validation keeps
# the dataset order.
train_loader = torch.utils.data.DataLoader(
    dataset=dataset_train,
    batch_size=BATCH_SIZE,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    dataset=dataset_test,
    batch_size=BATCH_SIZE,
    shuffle=False,
)
# Build the loss, the network and the optimizer, and move the network to the GPU.
criterion = nn.CrossEntropyLoss()  # cross-entropy loss for classification

# ResNet-18 without pretrained weights. The constructor is called with no
# arguments because that default holds on both old and new torchvision
# releases, whereas the explicit `pretrained=` keyword is deprecated in newer ones.
model = torchvision.models.resnet18()
# Replace the final fully-connected layer so the output size equals cls_num.
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, cls_num)
if len(gpu_id) > 1:
    # Multi-GPU case: replicate the model across the listed devices.
    model = nn.DataParallel(model, device_ids=gpu_id)
model.to(DEVICE)

# Adam optimizer with a deliberately small learning rate.
optimizer = optim.Adam(model.parameters(), lr=model_lr)
# Learning-rate schedule
def adjust_learning_rate(optimizer, epoch, base_lr=None, step=50, gamma=0.1):
    """Set a step-decayed learning rate on every parameter group.

    The rate is ``base_lr * gamma ** k`` where ``k`` is the number of complete
    ``step``-epoch periods that have finished before ``epoch``. ``epoch`` is
    1-based (the training loop counts from 1), so with the defaults epochs
    1-50 use the initial rate and epochs 51-100 use a tenth of it. The old
    ``epoch // 50`` form decayed one epoch early, at epochs 50 and 100.

    Args:
        optimizer: Object exposing ``param_groups``, a list of dicts with an
            ``'lr'`` entry (e.g. a ``torch.optim`` optimizer).
        epoch: 1-based index of the epoch about to run; values below 1 are
            treated as the first epoch.
        base_lr: Initial learning rate. Defaults to the module-level
            ``model_lr``.
        step: Number of epochs between two decays.
        gamma: Multiplicative decay factor applied at every step.

    Returns:
        The learning rate that was written into the optimizer.
    """
    if base_lr is None:
        base_lr = model_lr
    finished_epochs = max(epoch - 1, 0)
    model_lr_new = base_lr * (gamma ** (finished_epochs // step))
    print("lr:", model_lr_new)
    for param_group in optimizer.param_groups:
        param_group['lr'] = model_lr_new
    return model_lr_new
# Training loop for one epoch
def train(model, device, train_loader, optimizer, epoch, loss_fn=None):
    """Train ``model`` for one pass over ``train_loader``.

    Args:
        model: Network to optimise; it is switched to training mode.
        device: Device every batch is moved to before the forward pass.
        train_loader: Iterable of ``(data, target)`` batches that also
            provides ``len()`` and a ``dataset`` attribute.
        optimizer: Optimizer that updates the model parameters.
        epoch: Epoch number, used only in the progress messages.
        loss_fn: Loss callable ``loss_fn(output, target)``. Defaults to the
            module-level ``criterion``.

    Returns:
        The mean of the per-batch losses, or ``nan`` when the loader yields
        no batches.
    """
    if loss_fn is None:
        loss_fn = criterion
    model.train()
    total_num = len(train_loader.dataset)
    num_batches = len(train_loader)
    print("total train data: ", total_num)

    sum_loss = 0.0
    seen = 0  # samples processed so far; stays correct when the last batch is short
    for batch_idx, (data, target) in enumerate(train_loader):
        # Tensors are used directly; the legacy Variable wrapper is not needed.
        data, target = data.to(device), target.to(device)
        output = model(data)
        loss = loss_fn(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print_loss = loss.item()
        sum_loss += print_loss
        seen += len(data)
        if (batch_idx + 1) % 50 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, seen, total_num,
                100. * (batch_idx + 1) / num_batches, print_loss))

    # Guard against an empty loader instead of raising ZeroDivisionError.
    ave_loss = sum_loss / num_batches if num_batches else float('nan')
    print('epoch:{}, loss:{}'.format(epoch, ave_loss))
    return ave_loss
# Validation loop
def val(model, device, test_loader, loss_fn=None):
    """Evaluate ``model`` on ``test_loader`` and print loss and accuracy.

    Args:
        model: Network to evaluate; it is switched to evaluation mode.
        device: Device every batch is moved to before the forward pass.
        test_loader: Iterable of ``(data, target)`` batches that also
            provides ``len()`` and a ``dataset`` attribute.
        loss_fn: Loss callable ``loss_fn(output, target)``. Defaults to the
            module-level ``criterion``.

    Returns:
        The mean of the per-batch losses, or ``nan`` when the loader yields
        no batches.
    """
    if loss_fn is None:
        loss_fn = criterion
    model.eval()
    total_num = len(test_loader.dataset)
    num_batches = len(test_loader)
    print("\ntotal val data: ", total_num)

    test_loss = 0.0
    correct = 0  # plain int, so an empty loader no longer breaks the summary below
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += loss_fn(output, target).item()
            # Predicted class = index of the largest output along dim 1.
            _, pred = torch.max(output, 1)
            correct += (pred == target).sum().item()

    acc = correct / total_num if total_num else 0.0
    avgloss = test_loss / num_batches if num_batches else float('nan')
    print('Val set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        avgloss, correct, total_num, 100 * acc))
    return avgloss
if __name__ == '__main__':
    # Per epoch: set the learning rate, train, validate, then write a checkpoint
    # whose file name carries the epoch number and the validation loss.
    last_epoch = EPOCHS + 1
    for epoch in range(1, last_epoch):
        adjust_learning_rate(optimizer, epoch)
        train(model, DEVICE, train_loader, optimizer, epoch)
        avgloss = val(model, DEVICE, test_loader)

        model_name = "epoch-{}_loss-{}.pth".format(epoch, avgloss)
        model_path = os.path.join(model_dir, model_name)
        # With several GPUs the model is wrapped in DataParallel; save the
        # wrapped network itself so the checkpoint holds the bare model.
        net_to_save = model.module if len(gpu_id) > 1 else model
        torch.save(net_to_save, model_path)
test.py
import os
import numpy as np
from PIL import Image
import cv2
import torch.utils.data.distributed
import torchvision.transforms as transforms
from torch.autograd import Variable
import sys
# Put the current working directory at the front of sys.path so the
# project-local import that follows can be resolved. This form is for running
# from the code directory; when running from the tools directory use the
# parent directory instead:
# find_path = os.path.abspath(os.path.join(os.getcwd(), "../"))
find_path = os.path.abspath(os.getcwd())
sys.path[:0] = [find_path]
print(sys.path)
from stamp_angle_detect.rotate import rotate
# Settings
input_dir = './data/doc_4_angles_dataset/test/90'
output_dir = './output'
model_path = "./checkpoints/epoch-1_loss-1.4412968158721924.pth"
gpu_id = 5
DEVICE = torch.device('cuda:' + str(gpu_id))

# Class table: predicted class index -> rotation angle in degrees.
# NOTE(review): the index order must match the class order used at training
# time. If the training folders are named by angle, torchvision's ImageFolder
# sorts the names as strings ('180' before '90') — verify the mapping.
classes = list(range(-90, 181, 90))  # -90, 0, 90, 180
# Alternative angle grids:
# classes = list(range(-179, 181, 1))  # -179 ~ 180
# classes = list(range(-170, 181, 10))  # -170, -160, ..., 0, 10, ..., 180

# Create one output folder per angle; exist_ok avoids the check-then-create race.
for cls in classes:
    class_dir = os.path.join(output_dir, str(cls))
    os.makedirs(class_dir, exist_ok=True)
# Inference preprocessing: resize to 128x128, convert to a tensor, then
# normalise each channel with mean 0.5 and std 0.5.
eval_steps = [
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
]
transform_test = transforms.Compose(eval_steps)
# load model
# map_location places the checkpoint tensors directly on the inference device;
# without it torch.load tries to restore them onto the GPU they were saved
# from, which may be missing or busy on this machine.
# NOTE: torch.load unpickles arbitrary Python objects — only load checkpoints
# from a trusted source.
# NOTE(review): on PyTorch releases where torch.load defaults to
# weights_only=True, a whole-model checkpoint like this one presumably needs
# weights_only=False — confirm against the installed version.
model = torch.load(model_path, map_location=DEVICE)
model.eval()
model.to(DEVICE)
# test
# Classify every file in input_dir, undo the predicted rotation and write the
# result to output_dir/<predicted angle>/<file name>.
test_list = os.listdir(input_dir)
with torch.no_grad():  # inference only: do not record operations for autograd
    for file in test_list:
        input_path = os.path.join(input_dir, file)
        if not os.path.isfile(input_path):
            # Skip sub-directories and other non-file entries.
            continue

        # Read as 3-channel RGB; the context manager closes the file handle promptly.
        with Image.open(input_path) as src:
            img = src.convert('RGB')

        # Add a leading batch dimension and move the tensor to the inference device.
        img_tensor = transform_test(img).unsqueeze(0).to(DEVICE)
        out = model(img_tensor)  # infer

        # Predict: index of the largest output -> angle from the class table.
        _, pred = torch.max(out, 1)
        angle = classes[pred.item()]
        print('Image Name:{}, predict:{}'.format(file, angle))

        # rotate by the opposite angle, filling the border with white
        img_np = np.array(img)
        # An H x W x C array takes a per-channel fill colour, a 2-D array a
        # scalar. The image is always RGB here, so the tuple is always chosen.
        borderValue = (255, 255, 255) if img_np.ndim > 2 else 255
        rotate_img = rotate(img_np, -angle, borderValue=borderValue)

        # save: OpenCV expects BGR channel order
        output_path = os.path.join(output_dir, str(angle), file)
        rotate_img = cv2.cvtColor(rotate_img, cv2.COLOR_RGB2BGR)
        if not cv2.imwrite(output_path, rotate_img):
            # imwrite signals failure through its return value, not an exception.
            print('Failed to write:{}'.format(output_path))