Environment:
We use the models that ship with torchvision for transfer learning and build a custom model on top of them. The code is as follows:
# -*- coding: utf-8 -*-
"""
Created on Tue Feb 25 12:58:25 2020
@author: zhou-wenqing
Image classification task
"""
import torch
from torchvision import models, transforms, datasets
from torch import nn
from torchsummary import summary
import matplotlib.pyplot as plt
plt.rcParams['font.family'] = ['sans-serif']
plt.rcParams['font.sans-serif'] = ['Times New Roman']
plt.rcParams['font.size'] = 12
import time
import copy
import argparse
#%% Transfer learning based on the models shipped with torchvision. Two strategies:
# 1) Freeze parameters: the convolutional layers are used only to extract image
#    features, keeping the ImageNet pretrained weights and excluding them from
#    gradient updates.
# 2) Replace the output fully connected layer so it matches the number of
#    classes in the custom dataset.
def create_model(model_name,            # model name
                 num_classes,           # number of classes
                 feature_extract: bool, # feature extraction only?
                 use_pretrained=True,   # load pretrained weights?
                 ):
    model_ft = None
    input_size = 0

    def set_parameter_requires_grad(model, feature_extracting):
        if feature_extracting:
            for param in model.parameters():
                param.requires_grad = False

    if model_name == "resnet":
        """ Resnet18 """
        model_ft = models.resnet18(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)
        input_size = 224
    elif model_name == "alexnet":
        """ Alexnet """
        model_ft = models.alexnet(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)
        input_size = 224
    elif model_name == "vgg":
        """ VGG11_bn """
        model_ft = models.vgg11_bn(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)
        input_size = 224
    elif model_name == "squeezenet":
        """ Squeezenet """
        model_ft = models.squeezenet1_0(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))
        model_ft.num_classes = num_classes
        input_size = 224
    elif model_name == "densenet":
        """ Densenet """
        model_ft = models.densenet121(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier.in_features
        model_ft.classifier = nn.Linear(num_ftrs, num_classes)
        input_size = 224
    elif model_name == "inception":
        """
        Inception v3
        Be careful, expects (299,299) sized images and has an auxiliary output
        """
        model_ft = models.inception_v3(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        # Handle the auxiliary net
        num_ftrs = model_ft.AuxLogits.fc.in_features
        model_ft.AuxLogits.fc = nn.Linear(num_ftrs, num_classes)
        # Handle the primary net
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)
        input_size = 299
    else:
        print("Invalid model name, exiting...")
        exit()
    return model_ft, input_size
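# Example usage of create_model (a minimal sketch; the model name and class
# count below are arbitrary illustrations, not values required by the tutorial):
#   model, size = create_model('resnet', num_classes=5, feature_extract=True)
#   print(size)  # 224 for resnet; only 'inception' returns 299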
#%%
def train_model(model, dataloaders, criterion, optimizer, num_epochs=25, is_inception=False, plot=False):
    since = time.time()
    val_acc_history = []
    train_loss = []
    train_acc = []
    iters = []
    if plot:
        plt.ion()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        # Each epoch has a training and a validation phase.
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode
            running_loss = 0.0
            running_corrects = 0
            # Iterate over data.
            for idx, (inputs, labels) in enumerate(dataloaders[phase]):
                inputs = inputs.to(device)
                labels = labels.to(device)
                # zero the parameter gradients
                optimizer.zero_grad()
                # forward
                # track history only in train
                with torch.set_grad_enabled(phase == 'train'):
                    # Compute the model outputs and the loss.
                    # During training the inception model also has an auxiliary
                    # loss; the final loss is the sum of the auxiliary loss and
                    # the loss of the final output. At test time only the final
                    # output's loss is considered.
                    if is_inception and phase == 'train':
                        outputs, aux_outputs = model(inputs)
                        loss1 = criterion(outputs, labels)
                        loss2 = criterion(aux_outputs, labels)
                        loss = loss1 + 0.4 * loss2
                    else:
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)
                    _, preds = torch.max(outputs, 1)
                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                # statistics: accumulate each batch's loss and number of correct predictions
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)
            if plot:
                if phase == 'train':
                    train_loss.append(epoch_loss)
                    # .item() so the value can be plotted even when training on GPU
                    train_acc.append(epoch_acc.item())
                    iters.append(epoch)
                    # plot every epoch
                    plt.cla()
                    plt.subplot(121)
                    plt.plot(iters, train_loss, 'r', label='train loss')
                    plt.title('Train Loss')
                    plt.xlabel('epochs')
                    plt.legend()
                    plt.subplot(122)
                    plt.plot(iters, train_acc, 'b', label='train acc')
                    plt.xlabel('epochs')
                    plt.title('Train Acc')
                    plt.legend()
                    plt.ioff()
                    plt.show()
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
            if phase == 'val':
                val_acc_history.append(epoch_acc)
        print()
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))
    # load best model weights
    model.load_state_dict(best_model_wts)
    return model, val_acc_history
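# train_model also returns the validation-accuracy history. A minimal sketch
# for plotting it after training (the entries are 0-dim tensors, hence .item()):
#   accs = [a.item() for a in hist]
#   plt.plot(range(1, len(accs) + 1), accs)
#   plt.xlabel('epochs'); plt.ylabel('val acc'); plt.show()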
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, default='alexnet',
                        help='model name: alexnet, resnet, vgg, squeezenet, densenet, inception')
    parser.add_argument('--data', type=str, default=r"D:\Datasets\flower_photos", help='dataset root path')
    parser.add_argument('--epochs', type=int, default=2)
    # Note: argparse's type=bool treats any non-empty string as True, so boolean
    # options are expressed as store_false switches that turn the defaults off.
    parser.add_argument('--no-extract', dest='extract', action='store_false',
                        help='fine-tune all layers instead of feature extraction')
    parser.add_argument('--no-pretrained', dest='pretrained', action='store_false',
                        help='train from scratch instead of loading pretrained weights')
    parser.set_defaults(extract=True, pretrained=True)
    parser.add_argument('--batch-size', type=int, default=8)
    parser.add_argument('--num-classes', type=int, default=5)
    parser.add_argument('--train-ratio', type=float, default=0.8,
                        help='fraction of the data used for training')
    opt = parser.parse_args()
    print(opt)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Initialize the model
    model_ft, input_size = create_model(model_name=opt.model,
                                        num_classes=opt.num_classes,
                                        feature_extract=opt.extract,
                                        use_pretrained=opt.pretrained)
    #%% Loading image dataset
    data_transforms = {
        'train': transforms.Compose([
            transforms.RandomResizedCrop((input_size, input_size)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            transforms.Resize((input_size, input_size)),
            # transforms.CenterCrop(input_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }
    full_dataset = datasets.ImageFolder(root=opt.data,
                                        transform=data_transforms['train'])
    print('Total dataset size:', len(full_dataset))
    # Split the dataset. Caveat: both splits index into the same underlying
    # ImageFolder, so the validation split also gets the training transform
    # here (see the sketch after this script for one way around that).
    train_size = int(opt.train_ratio * len(full_dataset))
    test_size = len(full_dataset) - train_size
    train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])
    print('Training set size:', len(train_dataset))
    print('Validation set size:', len(test_dataset))
    image_datasets = {'train': train_dataset,
                      'val': test_dataset}
    # Create training and validation dataloaders
    dataloaders_dict = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=opt.batch_size, shuffle=True, num_workers=0) for x in ['train', 'val']}
    # Move the model to the device
    model_ft = model_ft.to(device)
    # Print a summary of the instantiated model (summary() prints the table itself)
    summary(model_ft, (3, input_size, input_size))
    # Gather the parameters to be optimized/updated.
    # For fine-tuning, all network parameters are updated;
    # for feature extraction, only the parameters with requires_grad=True.
    params_to_update = model_ft.parameters()
    print("Params to learn:")
    if opt.extract:
        params_to_update = []
        for name, param in model_ft.named_parameters():
            if param.requires_grad:
                params_to_update.append(param)
                print("\t", name)
    else:
        # All parameters are optimized in this case.
        for name, param in model_ft.named_parameters():
            if param.requires_grad:
                print("\t", name)
    optimizer_ft = torch.optim.SGD(params_to_update, lr=0.001, momentum=0.9)
    # Set up the loss function
    criterion = nn.CrossEntropyLoss()
    # Train and evaluate
    model_ft, hist = train_model(model_ft,
                                 dataloaders_dict,
                                 criterion,
                                 optimizer_ft,
                                 num_epochs=opt.epochs,
                                 is_inception=(opt.model == "inception"))
    torch.save(model_ft, 'custom_model.pt')
    ## Convert to TorchScript via tracing
    # Trace on the CPU in eval mode so dropout/batchnorm behave deterministically
    # and the traced module can be loaded by a CPU-only C++ build.
    model_ft = model_ft.cpu().eval()
    # An example input you would normally provide to your model's forward() method.
    example = torch.rand(1, 3, input_size, input_size)
    # Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
    traced_script_module = torch.jit.trace(model_ft, example)
    traced_script_module.save('custom_traced_model.pt')
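As noted in the comments above, random_split only wraps the same underlying ImageFolder, so both splits end up with the training transform. A minimal sketch of one way to give each split its own transform; TransformedSubset is a hypothetical helper, not part of torchvision, and the lines below would replace the corresponding dataset-splitting lines in the script:

from torch.utils.data import Dataset

class TransformedSubset(Dataset):
    """Apply a per-split transform on top of a Subset (hypothetical helper)."""
    def __init__(self, subset, transform):
        self.subset = subset
        self.transform = transform
    def __len__(self):
        return len(self.subset)
    def __getitem__(self, idx):
        img, label = self.subset[idx]  # a PIL image, since the base dataset has no transform
        return self.transform(img), label

raw_dataset = datasets.ImageFolder(root=opt.data)  # note: no transform here
train_sub, val_sub = torch.utils.data.random_split(raw_dataset, [train_size, test_size])
train_dataset = TransformedSubset(train_sub, data_transforms['train'])
test_dataset = TransformedSubset(val_sub, data_transforms['val'])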
This conversion step is already carried out in the last lines of the script above; to reiterate, the official tutorial on exporting a Torch Script model is at https://pytorch.org/tutorials/advanced/cpp_export.html#a-minimal-c-application.
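Before moving on to C++, it is worth a quick sanity check from Python that the traced module loads and produces a prediction; a minimal sketch, using the file name saved above (224 assumes a non-inception model):

import torch

loaded = torch.jit.load('custom_traced_model.pt')
loaded.eval()
example = torch.rand(1, 3, 224, 224)
with torch.no_grad():
    out = loaded(example)
print(out.argmax(1))  # predicted class index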
The official PyTorch documentation is still not very detailed and does not cover how to load images, and most tutorials online work in a Linux environment, where setup is more convenient than on Windows. The OpenCV library path has to be specified manually in the CMakeLists.txt file; the complete file is as follows:
cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(example-app)
set(OpenCV_DIR "E:\\ScientificComputing\\opencv-4.2.0\\build\\install")
find_package(Torch REQUIRED) # 查找libtorch
find_package(OpenCV REQUIRED) # 查找OpenCV
if(NOT Torch_FOUND)
message(FATAL_ERROR "Pytorch Not Found!")
endif(NOT Torch_FOUND)
message(STATUS "Pytorch status:")
message(STATUS " libraries: ${TORCH_LIBRARIES}")
message(STATUS "OpenCV library status:")
message(STATUS " version: ${OpenCV_VERSION}")
message(STATUS " libraries: ${OpenCV_LIBS}")
message(STATUS " include path: ${OpenCV_INCLUDE_DIRS}")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
add_executable(example-app example-app.cpp)
target_link_libraries(example-app "${TORCH_LIBRARIES}" "${OpenCV_LIBS}")
set_property(TARGET example-app PROPERTY CXX_STANDARD 14)
# The following code block is suggested to be used on Windows.
# According to https://github.com/pytorch/pytorch/issues/25457,
# the DLLs need to be copied to avoid memory errors.
if (MSVC)
file(GLOB TORCH_DLLS "${TORCH_INSTALL_PREFIX}/lib/*.dll")
add_custom_command(TARGET example-app
POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy_if_different
${TORCH_DLLS}
$<TARGET_FILE_DIR:example-app>)
endif (MSVC)
The program does two things: 1) loads the custom script module; 2) runs inference on an image and prints the index of the maximum output.
#include <torch/script.h> // One-stop header.
#include <opencv2/opencv.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <iostream>
#include <memory>
#include <string>
#include <vector>
using namespace std;
using namespace cv;

int main(int argc, const char *argv[])
{
    if (argc != 3) // we need exactly two arguments
    {
        std::cerr << "usage: example-app <image-path> <path-to-exported-script-module>\n";
        return -1;
    }
    torch::jit::script::Module module;
    try
    {
        // Deserialize the ScriptModule from a file using torch::jit::load().
        module = torch::jit::load(argv[2]);
    }
    catch (const c10::Error &e)
    {
        std::cerr << "error loading the model\n";
        return -1;
    }
    std::cout << "Loaded model successfully...\n";
    // Load the input image
    auto image = cv::imread(argv[1], cv::ImreadModes::IMREAD_COLOR);
    cv::Mat image_transformed;
    cv::resize(image, image_transformed, cv::Size(224, 224));
    cv::cvtColor(image_transformed, image_transformed, cv::COLOR_BGR2RGB);
    // Convert cv::Mat to at::Tensor (see https://pytorch.org/cppdocs/api/namespace_at.html#namespace-at)
    torch::Tensor tensor_image = torch::from_blob(image_transformed.data,
                                                  {image_transformed.rows, image_transformed.cols, 3},
                                                  torch::kByte);
    tensor_image = tensor_image.permute({2, 0, 1});
    tensor_image = tensor_image.toType(torch::kFloat);
    tensor_image = tensor_image.div(255);
    tensor_image = tensor_image.unsqueeze(0);
    // Normalize with the same ImageNet mean/std used by the Python transforms above.
    auto mean = torch::tensor({0.485f, 0.456f, 0.406f}).view({1, 3, 1, 1});
    auto stddev = torch::tensor({0.229f, 0.224f, 0.225f}).view({1, 3, 1, 1});
    tensor_image = (tensor_image - mean) / stddev;
    // Execute the model and turn its output into a tensor.
    at::Tensor output = module.forward({tensor_image}).toTensor();
    // std::cout << output.slice(/*dim=*/1, /*start=*/0, /*end=*/5) << '\n';
    auto max_result = output.max(/*dim=*/1, /*keepdim=*/true);
    auto max_index = std::get<1>(max_result).item<int64_t>();
    cout << "max index predicted: " << max_index << endl;
    return 0;
}
Then configure and build from the command line:
cd example-app
mkdir build
cd build
cmake -DCMAKE_PREFIX_PATH="F:\\libtorch" ..
cmake --build . --config Release
If everything went well, the compiled program, along with some DLLs, is now in the Release directory.
Run the following command from the build directory:
./Release/example-app.exe "C:\Users\zhou-\Pictures\sunflower.jpg" custom_traced_model.pt
At runtime Windows may complain that the OpenCV DLLs cannot be found; copying the missing DLLs into the directory containing example-app.exe fixes this.
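To double-check the C++ result, the same preprocessing can be replayed in Python against the traced model; a minimal sketch, assuming opencv-python is installed, using the image path from the command above (224 matches the hard-coded size in example-app.cpp):

import cv2
import torch

img = cv2.imread(r"C:\Users\zhou-\Pictures\sunflower.jpg", cv2.IMREAD_COLOR)
img = cv2.resize(img, (224, 224))
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
x = torch.from_numpy(img).permute(2, 0, 1).float().div(255).unsqueeze(0)
mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
x = (x - mean) / std
model = torch.jit.load('custom_traced_model.pt')
model.eval()
with torch.no_grad():
    print(model(x).argmax(1))  # should match the index printed by example-app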