本项目使用Pytroch
,并基于ResNet50
模型,实现了对天气图片的识别,过程详细,十分适合基础阶段的同学阅读。
核心步骤:
DataSet
及Dataloader
本项目数据来源:
https://www.heywhale.com/mw/dataset/60d9bd7c056f570017c305ee/file
http://vcc.szu.edu.cn/research/2017/RSCM.html
由于数据是直接下载,且目录分的很规整,本项目的数据处理部分较为简单,直接手动复制,合并两个数据集即可。
总数据量约7万张。
配置文件的主要存储一些各个模块通用的一些全局变量,如各种文件的存放位置等等(本人Java程序员出身,一些Python的代码规范不太熟悉,望见谅)。
config.py
:
import time
import torch
# 项目配置文件
class Common:
'''
通用配置
'''
basePath = "D:/Data/weather/source/all/" # 图片文件基本路径
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # 设备配置
imageSize = (224,224) # 图片大小
labels = ["cloudy","haze","rainy","shine","snow","sunny","sunrise","thunder"] # 标签名称/文件夹名称
class Train:
'''
训练相关配置
'''
batch_size = 128
num_workers = 0 # 对于Windows用户,这里应设置为0,否则会出现多线程错误
lr = 0.001
epochs = 40
logDir = "./log/" + time.strftime('%Y-%m-%d-%H-%M-%S',time.gmtime()) # 日志存放位置
modelDir = "./model/" # 模型存放位置
dada_loader.py
# 自定义数据加载器
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from config import Common
from config import Train
import os
from PIL import Image
import torch.utils.data as Data
import numpy
# 定义数据处理transform
transform = transforms.Compose([
transforms.Resize(Common.imageSize),
transforms.ToTensor()
])
def loadDataFromDir():
'''
从文件夹中获取数据
'''
images = []
labels = []
# 1. 获取根文件夹下所有分类文件夹
for d in os.listdir(Common.basePath):
for imagePath in os.listdir(Common.basePath + d): # 2. 获取某一类型下所有的图片名称
# 3. 读取文件
image = Image.open(Common.basePath + d + "/" + imagePath).convert('RGB')
print("加载数据" + str(len(images)) + "条")
# 4. 添加到图片列表中
images.append(transform(image))
# 5. 构造label
categoryIndex = Common.labels.index(d) # 获取分类下标
label = [0] * 8 # 初始化label
label[categoryIndex] = 1 # 根据下标确定目标值
label = torch.tensor(label,dtype=torch.float) # 转为tensor张量
# 6. 添加到目标值列表
labels.append(label)
# 7. 关闭资源
image.close()
# 返回图片列表和目标值列表
return images, labels
class WeatherDataSet(Dataset):
'''
自定义DataSet
'''
def __init__(self):
'''
初始化DataSet
:param transform: 自定义转换器
'''
images, labels = loadDataFromDir() # 在文件夹中加载图片
self.images = images
self.labels = labels
def __len__(self):
'''
返回数据总长度
:return:
'''
return len(self.images)
def __getitem__(self, idx):
image = self.images[idx]
label = self.labels[idx]
return image, label
def splitData(dataset):
'''
分割数据集
:param dataset:
:return:
'''
# 求解一下数据的总量
total_length = len(dataset)
# 确认一下将80%的数据作为训练集, 剩下的20%的数据作为测试集
train_length = int(total_length * 0.8)
validation_length = total_length - train_length
# 利用Data.random_split()直接切分数据集, 按照80%, 20%的比例进行切分
train_dataset,validation_dataset = Data.random_split(dataset=dataset, lengths=[train_length, validation_length])
return train_dataset, validation_dataset
# 1. 分割数据集
train_dataset, validation_dataset = splitData(WeatherDataSet())
# 2. 训练数据集加载器
trainLoader = DataLoader(train_dataset, batch_size=Train.batch_size, shuffle=True, num_workers=Train.num_workers)
# 3. 验证集数据加载器
valLoader = DataLoader(validation_dataset, batch_size=Train.batch_size, shuffle=False,
num_workers=Train.num_workers)
主要步骤:
PIL
库
PIL
教程:https://blog.csdn.net/weixin_43790276/article/details/108478270
3*224*224
,故需要使用Pytroch的transforms
工具进行处理
transforms
教程:https://blog.csdn.net/qq_38410428/article/details/94719553
DataSet
(继承DataSet类,并实现重写三个核心方法)model.py
import torch
from torch import nn
import torchvision.models as models
from config import Common, Train
# 引入rest50模型
net = models.resnet50()
net.load_state_dict(torch.load("./model/resnet50-11ad3fa6.pth"))
class WeatherModel(nn.Module):
def __init__(self, net):
super(WeatherModel, self).__init__()
# resnet50
self.net = net
self.relu = nn.ReLU()
self.dropout = nn.Dropout(0.1)
self.fc = nn.Linear(1000, 8)
self.output = nn.Softmax(dim=1)
def forward(self, x):
x = self.net(x)
x = self.relu(x)
x = self.dropout(x)
x = self.fc(x)
x = self.output(x)
return x
model = WeatherModel(net)
主要步骤:
Pytorch
官方的残差网络预训练模型关于新版本的引入方法:https://blog.csdn.net/Sihang_Xie/article/details/125646287
train.py
# 训练部分
import time
import torch
from torch import nn
import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter
from config import Common, Train
from model import model as weatherModel
from data_loader import trainLoader, valLoader
from torch import optim
# 1. 获取模型
model = weatherModel
model.to(Common.device)
# 2. 定义损失函数
criterion = nn.CrossEntropyLoss()
# 3. 定义优化器
optimizer = optim.Adam(model.parameters(), lr=0.001)
# 4. 创建writer
writer = SummaryWriter(log_dir=Train.logDir, flush_secs=500)
def train(epoch):
'''
训练函数
'''
# 1. 获取dataLoader
loader = trainLoader
# 2. 调整为训练状态
model.train()
print()
print('========== Train Epoch:{} Start =========='.format(epoch))
epochLoss = 0 # 每个epoch的损失
epochAcc = 0 # 每个epoch的准确率
correctNum = 0 # 正确预测的数量
for data, label in loader:
data, label = data.to(Common.device), label.to(Common.device) # 加载到对应设备
batchAcc = 0 # 单批次正确率
batchCorrectNum = 0 # 单批次正确个数
optimizer.zero_grad() # 清空梯度
output = model(data) # 获取模型输出
loss = criterion(output, label) # 计算损失
loss.backward() # 反向传播梯度
optimizer.step() # 更新参数
epochLoss += loss.item() * data.size(0) # 计算损失之和
# 计算正确预测的个数
labels = torch.argmax(label, dim=1)
outputs = torch.argmax(output, dim=1)
for i in range(0, len(labels)):
if labels[i] == outputs[i]:
correctNum += 1
batchCorrectNum += 1
batchAcc = batchCorrectNum / data.size(0)
print("Epoch:{}\t TrainBatchAcc:{}".format(epoch, batchAcc))
epochLoss = epochLoss / len(trainLoader.dataset) # 平均损失
epochAcc = correctNum / len(trainLoader.dataset) # 正确率
print("Epoch:{}\t Loss:{} \t Acc:{}".format(epoch, epochLoss, epochAcc))
writer.add_scalar("train_loss", epochLoss, epoch) # 写入日志
writer.add_scalar("train_acc", epochAcc, epoch) # 写入日志
return epochAcc
def val(epoch):
'''
验证函数
:param epoch: 轮次
:return:
'''
# 1. 获取dataLoader
loader = valLoader
# 2. 初始化损失、准确率列表
valLoss = []
valAcc = []
# 3. 调整为验证状态
model.eval()
print()
print('========== Val Epoch:{} Start =========='.format(epoch))
epochLoss = 0 # 每个epoch的损失
epochAcc = 0 # 每个epoch的准确率
correctNum = 0 # 正确预测的数量
with torch.no_grad():
for data, label in loader:
data, label = data.to(Common.device), label.to(Common.device) # 加载到对应设备
batchAcc = 0 # 单批次正确率
batchCorrectNum = 0 # 单批次正确个数
output = model(data) # 获取模型输出
loss = criterion(output, label) # 计算损失
epochLoss += loss.item() * data.size(0) # 计算损失之和
# 计算正确预测的个数
labels = torch.argmax(label, dim=1)
outputs = torch.argmax(output, dim=1)
for i in range(0, len(labels)):
if labels[i] == outputs[i]:
correctNum += 1
batchCorrectNum += 1
batchAcc = batchCorrectNum / data.size(0)
print("Epoch:{}\t ValBatchAcc:{}".format(epoch, batchAcc))
epochLoss = epochLoss / len(valLoader.dataset) # 平均损失
epochAcc = correctNum / len(valLoader.dataset) # 正确率
print("Epoch:{}\t Loss:{} \t Acc:{}".format(epoch, epochLoss, epochAcc))
writer.add_scalar("val_loss", epochLoss, epoch) # 写入日志
writer.add_scalar("val_acc", epochAcc, epoch) # 写入日志
return epochAcc
if __name__ == '__main__':
maxAcc = 0.75
for epoch in range(1,Train.epochs + 1):
trainAcc = train(epoch)
valAcc = val(epoch)
if valAcc > maxAcc:
maxAcc = valAcc
# 保存最大模型
torch.save(model, Train.modelDir + "weather-" + time.strftime('%Y-%m-%d-%H-%M-%S', time.gmtime()) + ".pth")
# 保存模型
torch.save(model,Train.modelDir+"weather-"+time.strftime('%Y-%m-%d-%H-%M-%S',time.gmtime())+".pth")
主要步骤:
tensorboard
的writer关于
tensorboard
的使用:https://blog.csdn.net/weixin_43637851/article/details/116003280
验证函数和训练函数的区别就是是否需要更新参数
epochs
次,不断保存正确率最大的模型,以及最后一次的训练模型训练过程中电脑的状态:
查看训练日志(tensorboard):
保存的模型:
pridect.py
import torch
import torchvision.transforms as transforms
from PIL import Image
from config import Common
def pridect(imagePath, modelPath):
'''
预测函数
:param imagePath: 图片路径
:param modelPath: 模型路径
:return:
'''
# 1. 读取图片
image = Image.open(imagePath)
# 2. 进行缩放
image = image.resize(Common.imageSize)
image.show()
# 3. 加载模型
model = torch.load(modelPath)
model = model.to(Common.device)
# 4. 转为tensor张量
transform = transforms.ToTensor()
x = transform(image)
x = torch.unsqueeze(x, 0) # 升维
x = x.to(Common.device)
# 5. 传入模型
output = model(x)
# 6. 使用argmax选出最有可能的结果
output = torch.argmax(output)
print("预测结果:",Common.labels[output.item()])
if __name__ == '__main__':
pridect("D:/Download/76ee4c5e833499949eac41561dcb487d.jpeg","./model/weather-2022-10-14-07-36-57.pth")
去网上随便找的图片:
https://github.com/mengxianglong123/weather-recognition
欢迎交流学习