腳本可以選擇的網絡有:Mobilenet_v2, Googlenet, Inception_v3, resnet50, Densnet121。 當然也可以添加你自己的網絡。
a. 數據路徑下包含train和val兩個文件夾,文件夾下面存放所有類別的數據,一個類別一個文件夾。
b. resize是圖片壓縮尺寸,crop-size是圖片中心剪裁后輸入網絡的尺寸。
c. 如果要加載訓練過的模型,開啟pre,并設置model-path模型路徑。
d. 如果要使用Focal-loss,需要新建FocalLoss.py(文件名不能含連字符,否則無法 import),并拷貝代碼,代碼我會在後面列出。
訓練代碼:
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
import copy
import FocalLoss
import argparse, os
import time, datetime
# import tqdm
# Use GPU 0 when CUDA is available, otherwise fall back to CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
def parse_args():
    """Parse command-line options for the training script and return the namespace."""

    def str2bool(v):
        # Accept 'True/False', '1/0', 'yes/no' (any case) for boolean options.
        if isinstance(v, bool):
            return v
        return v.lower() in ('true', '1', 'yes', 'y')

    parser = argparse.ArgumentParser()
    parser.add_argument('--net-name', dest='net_name', type=str, default='resnet50')
    parser.add_argument('--data', dest='data', type=str, default='./data/class', help='The data path!')
    parser.add_argument('--resize', dest='resize', type=int, default=300, help='img resize')
    parser.add_argument('--crop-size', dest='crop_size', type=int, default=224, help='crop resized img enter net!')
    parser.add_argument('--batch-size', dest='batch_size', type=int, default=16)
    parser.add_argument('--epochs', dest='epochs', type=int, default=100)
    parser.add_argument('--classes', dest='classes', type=int, default=3, help='class number')
    parser.add_argument('--save-path', dest='save_path', type=str, default='./model', help='save model path!')
    parser.add_argument('--pre', dest='pre_training', action='store_true', help='pre_training or not!')
    parser.add_argument('--model-path', dest='model_path', type=str, default='./model/epoch0_acc0.3676_loss1.0442.pt',
                        help='pre_training model path!')
    parser.add_argument('--focal-loss', dest='focal_loss', action='store_true', help='use focal loss!')
    # FIX: the original declared no type/action, so any value passed to --fe arrived
    # as a non-empty (hence always truthy) string; `--fe False` could never disable
    # feature extraction. str2bool parses the value properly; the default stays True.
    parser.add_argument('--fe', dest='feature_extract', type=str2bool, default=True,
                        help='Flag for feature extracting. When False, we finetune the whole model; '
                             'when True, we only update the reshaped layer params!')
    args = parser.parse_args()
    return args
# Parse CLI options once at module level so every function below can read them.
args = parse_args()

# Preprocessing pipelines, identical for train and val: resize the short side,
# centre-crop to the network input size, convert to tensor, then normalise
# with the standard ImageNet channel means/stds.
transform = {
    'train': transforms.Compose([transforms.Resize(args.resize),
                                 transforms.CenterCrop(args.crop_size),
                                 transforms.ToTensor(),
                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
    'val': transforms.Compose([transforms.Resize(args.resize),
                               transforms.CenterCrop(args.crop_size),
                               transforms.ToTensor(),
                               transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
}
def set_parameter_requires_grad(model, feature_extracting):
    """Freeze every parameter of *model* when *feature_extracting* is truthy.

    With requires_grad=False, backward() skips gradient computation for the
    frozen (pretrained) weights, so only newly added layers get trained.
    """
    if not feature_extracting:
        return
    for p in model.parameters():
        p.requires_grad = False
def Net(feature_extract, net_name):
    """Build an ImageNet-pretrained backbone with its classifier head replaced
    so it outputs ``args.classes`` logits, then move it to ``device``.

    feature_extract: when truthy, freeze the pretrained weights so only the
        freshly created head remains trainable.
    net_name: one of mobilenet_v2 / googlenet / inception_v3 / resnet50 /
        densenet121 (capitalised variants accepted, as below).
    Raises ValueError for an unknown net_name.
    """
    net = None
    if net_name in ['Mobilenet_v2', 'mobilenet_v2']:
        net = torchvision.models.mobilenet_v2(pretrained=True)
        set_parameter_requires_grad(net, feature_extract)
        # New head: dropout + linear layer on the backbone's final channel count.
        classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(net.last_channel, args.classes),
        )
        net.classifier = classifier
    elif net_name in ['googlenet', 'Googlenet']:
        net = torchvision.models.googlenet(pretrained=True)
        set_parameter_requires_grad(net, feature_extract)
        net.fc = nn.Linear(net.fc.in_features, args.classes, bias=True)
    elif net_name in ['inception_v3', 'Inception_v3']:
        net = torchvision.models.inception_v3(pretrained=True)
        set_parameter_requires_grad(net, feature_extract)
        # Inception v3 has an auxiliary classifier used during training;
        # both heads must be resized to the new class count.
        net.AuxLogits.fc = nn.Linear(net.AuxLogits.fc.in_features, args.classes)
        net.fc = nn.Linear(net.fc.in_features, args.classes)
    elif net_name in ['resnet50', 'Resnet50']:
        net = torchvision.models.resnet50(pretrained=True)
        set_parameter_requires_grad(net, feature_extract)
        net.fc = nn.Linear(net.fc.in_features, args.classes)
    elif net_name in ['densenet121', 'Densnet121', 'Densenet121']:  # also accept the correct spelling
        net = torchvision.models.densenet121(pretrained=True)
        set_parameter_requires_grad(net, feature_extract)
        net.classifier = nn.Linear(net.classifier.in_features, args.classes)
    else:
        # FIX: the original used `assert net, ...` for input validation; assert
        # statements are stripped under `python -O`, so raise explicitly instead.
        raise ValueError('please add yourself net or input right net name! (got %r)' % net_name)
    print(list(net.children()))
    net = net.to(device)
    return net
def train(net, data_loader, optim, criterion, exp_lr_scheduler, epochs, net_name):
    """Run train/val phases for *epochs* epochs, checkpointing the best-val-acc model.

    NOTE(review): relies on module-level globals `data_size` (sample counts per
    phase), `device`, and `args.save_path` (checkpoint directory).
    """
    best_acc = 0
    best_model_wts = copy.deepcopy(net.state_dict())
    for epoch in range(epochs):
        print('Epoch{}/{}'.format(epoch, epochs - 1))
        for phase in ['train', 'val']:
            if phase == 'train':
                net.train()
            if phase == 'val':
                net.eval()
            running_loss = 0
            running_corrects = 0
            step = 1
            steps = len(data_loader[phase])
            for inputs, labels in data_loader[phase]:
                start_time = time.time()
                inputs = inputs.to(device)
                labels = labels.to(device)
                optim.zero_grad()
                # Compute gradients only during the training phase.
                with torch.set_grad_enabled(phase == 'train'):
                    if net_name in ['inception_v3', 'Inception_v3'] and phase == 'train':
                        # Inception v3 returns (main, aux) logits in train mode;
                        # the auxiliary loss is weighted by 0.4.
                        outputs, aux_outputs = net(inputs)
                        loss1 = criterion(outputs, labels)
                        loss2 = criterion(aux_outputs, labels)
                        loss = loss1 + 0.4 * loss2
                    else:
                        outputs = net(inputs)
                        loss = criterion(outputs, labels)
                    _, predict = torch.max(outputs, 1)
                    if phase == 'train':
                        loss.backward()
                        optim.step()
                end_time = time.time()
                # Crude ETA extrapolated from this batch's duration.
                residual_time = str(datetime.timedelta(seconds=(steps - step) * (end_time - start_time)))[:-7]
                print("\r%d/%d [%s>%s] -ETA: %s - loss: %4f\n" % (
                    step, steps, '=' * int(29 * step / steps), '.' * (29 - int(29 * step / steps)), residual_time, loss), end='', flush=True)
                step += 1
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(predict == labels.data)
            if phase == 'train':
                # FIX: step the LR scheduler AFTER the optimizer updates (required
                # since PyTorch 1.1). The original stepped at the start of the train
                # phase, which skips the initial LR and decays one epoch early.
                exp_lr_scheduler.step()
            epoch_loss = running_loss / data_size[phase]
            epoch_acc = running_corrects.double() / data_size[phase]
            print('{} Loss:{:.4f} acc:{:.4f}'.format(phase, epoch_loss, epoch_acc))
            if epoch_acc > best_acc and phase == 'val':
                best_acc = epoch_acc
                torch.save(net.state_dict(), os.path.join(args.save_path,
                           'epoch{}_acc{:.4f}_loss{:.4f}.pt'.format(epoch, epoch_acc, epoch_loss)))
                best_model_wts = copy.deepcopy(net.state_dict())
    print('Best val acc', best_acc)
    net.load_state_dict(best_model_wts)
    torch.save(net.state_dict(), os.path.join(args.save_path, 'best_model_acc{:.4f}.pt'.format(best_acc)))
if __name__ == '__main__':
    # Build train/val datasets, loaders, and bookkeeping maps from the data dir
    # (expects args.data to contain `train/` and `val/`, one sub-folder per class).
    imgs_datasets = {x: torchvision.datasets.ImageFolder(os.path.join(args.data, x), transform=transform[x])
                     for x in ['train', 'val']}
    data_loader = {x: torch.utils.data.DataLoader(imgs_datasets[x], batch_size=args.batch_size,
                                                  shuffle=(x == 'train'), num_workers=0)
                   for x in ['train', 'val']}
    data_size = {x: len(imgs_datasets[x]) for x in ['train', 'val']}
    img_class = imgs_datasets['train'].classes
    net = Net(args.feature_extract, args.net_name)
    # Optionally resume from a previously saved state_dict.
    if args.pre_training:
        print('load model:{}'.format(args.model_path))
        net.load_state_dict(torch.load(args.model_path))
    # With feature extraction only the freshly replaced head has requires_grad=True,
    # so collect just those parameters for the optimizer.
    if args.feature_extract:
        params_to_update = [p for _, p in net.named_parameters() if p.requires_grad]
    else:
        params_to_update = net.parameters()
    optim = torch.optim.Adam(params=params_to_update, lr=0.001)
    # optim = torch.optim.SGD(params=net.parameters(), lr=0.001, momentum=0.9)
    # Loss function. FIX: `import FocalLoss` binds the MODULE, so the original
    # `FocalLoss()` tried to call a module object (TypeError at runtime);
    # the class inside the module must be instantiated.
    criterion = FocalLoss.FocalLoss() if args.focal_loss else nn.CrossEntropyLoss()
    print('criterion: ', criterion)
    # Decay LR by a factor of 0.1 every 7 epochs
    exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optim, step_size=7, gamma=0.1)
    train(net, data_loader, optim, criterion, exp_lr_scheduler, args.epochs, args.net_name)
Focal-loss代碼:
import torch
import torch.nn as nn
import torch.nn.functional as F
class FocalLoss(nn.Module):  # works for 1d and 2d logits
    """Focal loss (Lin et al., 2017): scales cross-entropy by (1 - p)^gamma so
    well-classified (easy) examples contribute less to the loss.

    gamma: focusing parameter (2 by default).
    size_average: if True, return the batch mean; else the per-sample losses.
    """

    def __init__(self, gamma=2, size_average=True):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.size_average = size_average

    def forward(self, logit, target, class_weight=None, type='softmax'):
        # `type` ('softmax' or 'sigmoid') is kept for backward compatibility
        # even though it shadows the builtin.
        # FIX: build helper tensors on logit's device — the original hard-coded
        # .cpu(), which crashes with a device mismatch when logit/target are on GPU.
        dev = logit.device
        target = target.view(-1, 1).long()
        if type == 'sigmoid':
            if class_weight is None:
                class_weight = [1] * 2  # [0.5, 0.5]
            # Binary case: expand sigmoid output into two-column [1-p, p] form.
            prob = torch.sigmoid(logit)
            prob = prob.view(-1, 1)
            prob = torch.cat((1 - prob, prob), 1)
            select = torch.zeros(len(prob), 2, device=dev)
            select.scatter_(1, target, 1.)
        elif type == 'softmax':
            B, C = logit.size()
            if class_weight is None:
                class_weight = [1] * C  # [1/C]*C
            prob = F.softmax(logit, 1)
            select = torch.zeros(len(prob), C, device=dev)
            select.scatter_(1, target, 1.)
        class_weight = torch.tensor(class_weight, dtype=torch.float32, device=dev).view(-1, 1)
        class_weight = torch.gather(class_weight, 0, target)
        # Probability of the true class, clamped away from 0/1 for log stability.
        prob = (prob * select).sum(1).view(-1, 1)
        prob = torch.clamp(prob, 1e-8, 1 - 1e-8)
        batch_loss = - class_weight * (torch.pow((1 - prob), self.gamma)) * prob.log()
        if self.size_average:
            loss = batch_loss.mean()
        else:
            loss = batch_loss
        return loss
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from collections import Counter
# Use GPU 0 when CUDA is available, otherwise fall back to CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
def Load_Data(img_size, imgs_path):
    """Build the test dataset/loader pair from an ImageFolder directory.

    img_size: side length images are resized to before normalisation.
    imgs_path: root folder with one sub-directory per class.
    Returns (dataset, loader) with batch size 8 and no shuffling.
    """
    preprocessing = transforms.Compose([
        transforms.Resize(img_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    dataset = torchvision.datasets.ImageFolder(root=imgs_path, transform=preprocessing)
    loader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=False, num_workers=0)
    return (dataset, loader)
def Load_Model(model_path):
    """Load a model checkpoint that was saved as a WHOLE model (torch.save(model, ...)).

    If training saved only a state_dict instead, rebuild the network first and
    call model.load_state_dict(torch.load(model_path)) on it.

    NOTE(review): torch.load unpickles arbitrary objects — only load trusted
    checkpoint files.
    """
    # FIX: map_location lets a checkpoint saved on GPU load on a CPU-only
    # machine (without it torch.load raises when CUDA is unavailable).
    model = torch.load(model_path, map_location=device)
    return model
def test(data, model):
    """Run the model over the test loader and print per-class accuracy.

    data: (dataset, loader) as returned by Load_Data; the dataset must expose
        `.targets` (label list) and `.classes` (class-name list).
    model: a callable mapping a batch of inputs to logits.
    """
    testset, testloader = data
    labels_list = testset.targets
    classes_list = testset.classes
    # Count the labels ONCE (the original rebuilt Counter(labels_list) per class).
    label_counts = Counter(labels_list)
    dict_right = {i: 0 for i in range(len(classes_list))}
    dict_count = {i: label_counts[i] for i in range(len(classes_list))}
    with torch.no_grad():
        for inputs, labels in testloader:
            # FIX: Tensor.to(device) returns a NEW tensor; the original discarded
            # the result, so inputs/labels never actually moved to the device.
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            print('predicted:', predicted, 'labels:', labels)
            for i in range(len(labels)):
                if predicted[i] == labels[i]:
                    dict_right[int(labels[i])] += 1
    for i in range(len(classes_list)):
        # Guard against a class with zero samples (avoids ZeroDivisionError).
        acc = dict_right[i] / dict_count[i] if dict_count[i] else 0.0
        print('Class {} Acc: {:.4f}'.format(classes_list[i], acc))
if __name__ == '__main__':
    ckpt_path = r'./model/epoch0_acc0.9044_loss0.2180.pt'  # checkpoint to evaluate
    val_dir = './data/class/val'                           # test-set directory
    input_size = 224                                       # network input size
    classifier = Load_Model(ckpt_path)
    eval_data = Load_Data(img_size=input_size, imgs_path=val_dir)
    test(eval_data, classifier)
import os
import torch
from torchvision import transforms
from PIL import Image
import time
# Use GPU 0 when CUDA is available, otherwise fall back to CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
img_size = 224  # input side length expected by the network
# FIX: map_location lets a GPU-saved checkpoint load on a CPU-only machine.
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted files.
model = torch.load('./model/epoch0_acc0.9044_loss0.2180.pt', map_location=device)
# FIX: switch to eval mode so dropout/batch-norm behave deterministically at inference.
model.eval()
# Same preprocessing as training (minus the crop): resize, tensor, ImageNet normalise.
data_transforms = transforms.Compose([
    transforms.Resize(img_size),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
def predict_image(img_path):
    """Classify one image file, printing the elapsed time and predicted class index."""
    start_time = time.time()
    # convert('RGB') so grayscale/RGBA images don't break the 3-channel Normalize.
    img = Image.open(img_path).convert('RGB')
    # Renamed from `input` — the original shadowed the builtin used by the REPL below.
    img_tensor = data_transforms(img).unsqueeze(0)
    # FIX: Tensor.to(device) returns a NEW tensor; the original discarded the result,
    # so the input never moved to the device the model lives on.
    img_tensor = img_tensor.to(device)
    with torch.no_grad():  # inference needs no autograd bookkeeping
        output = model(img_tensor)
    _, predict = torch.max(output, 1)
    end_time = time.time()
    print('use_time', end_time - start_time)
    # .cpu() before .numpy(): numpy conversion fails on CUDA tensors.
    # (Also fixes the 'precicted' typo in the output message.)
    print('predicted classes: ', predict.cpu().numpy()[0])
# Simple interactive loop: keep prompting for an image path until 'q' is entered.
while True:
    img_path = input('please input img_path:')
    # Reject non-existent paths, but let the quit sentinel 'q' fall through.
    if not os.path.exists(img_path) and img_path != 'q':
        print("The path error, Try again!")
        continue
    if img_path == 'q':
        break
    predict_image(img_path)
參考文章:https://blog.csdn.net/weixin_40123108/article/details/85714030