目录
评估函数,计算 图片多分类的准确率 topK
保存准确率信息
完整代码
## topk的准确率计算
def accuracy(output, label, topk=(1,)):
    """Compute top-k accuracy, as a percentage, for each k in ``topk``.

    A sample counts as correct for a given k when its true label is among
    the k highest-scoring classes of its row in ``output``.

    Args:
        output: score tensor indexed as (sample, class); ``topk`` is taken
            along dimension 1.
        label: tensor of class indices whose first dimension is the batch.
        topk: iterable of k values to evaluate.

    Returns:
        A list of 0-dim float tensors, one per entry of ``topk``, each
        holding the accuracy in percent.
    """
    # Clamp to the number of classes: Tensor.topk raises when asked for more
    # entries than dimension 1 holds, and top-k with k >= classes is 100%.
    maxk = min(max(topk), output.size(1))
    batch_size = label.size(0)
    # Indices of the maxk largest scores per sample, sorted descending.
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()  # rows are ranks, columns are samples
    # Broadcast the labels over every rank row and compare element-wise,
    # yielding a boolean hit matrix.
    correct = pred.eq(label.view(1, -1).expand_as(pred))
    rtn = []
    for k in topk:
        # reshape instead of view: the slice of the transposed comparison may
        # be non-contiguous, in which case view(-1) raises a RuntimeError.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        rtn.append(correct_k.mul_(100.0 / batch_size))
    return rtn
class AvgrageMeter(object):
    """Keep a running, count-weighted average of the values passed to update().

    Attributes are plain numbers: ``sum`` is the weighted total, ``cnt`` the
    accumulated weight and ``avg`` their ratio.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated state back to zero."""
        self.avg = 0
        self.sum = 0
        self.cnt = 0

    def update(self, val, n=1):
        """Add ``val`` with weight ``n`` and refresh the average.

        Args:
            val: the value to accumulate (e.g. a per-batch mean).
            n: the weight of ``val`` (e.g. the number of samples it covers).
        """
        self.sum += val * n
        self.cnt += n
        # A zero total count (e.g. a first update with n=0) would otherwise
        # raise ZeroDivisionError; report 0 until something was counted.
        self.avg = self.sum / self.cnt if self.cnt else 0
import torch
import torchvision.transforms as transform
import cv2
from torch.utils.data import Dataset, DataLoader, random_split
import os
from PIL import Image
from tqdm import tqdm
# Directory that holds the training images; it is handed to myData when the
# dataset is built.
data_path = r'dogs-vs-cats/train'
# The string literal below is disabled exploratory code (listing a directory
# with os.listdir); it is never executed.
"""
# os.listdir 会将 本目录下所有的文件的名字 添加到列表中
path_list = os.listdir(path)
print(path_list)
"""
# Dataset definition: loads images from a directory and labels them by name.
class myData(Dataset):
    """Image dataset over a flat directory, labelled from the file names.

    Each item is a ``(image, label)`` pair: the image is resized to 224x224,
    converted to a tensor and normalized per channel with mean 0.5 / std 0.5;
    the label is 1 when the file name's first dot-separated field is ``dog``
    and 0 otherwise.
    """

    def __init__(self, data_path):
        """Record the directory and index the files it contains.

        Args:
            data_path: directory whose entries are all treated as images.
        """
        self.data_path = data_path
        self.transform = transform.Compose(
            [
                transform.Resize(size=(224, 224)),
                transform.ToTensor(),
                transform.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ]
        )
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> file mapping stable from run to run.
        self.path_list = sorted(os.listdir(data_path))

    def __getitem__(self, idx):
        """Return the transformed image tensor and its int64 label tensor."""
        img_name = self.path_list[idx]
        abs_img_path = os.path.join(self.data_path, img_name)
        # The context manager closes the underlying file once the pixels have
        # been read. convert('RGB') guarantees three channels, which the
        # 3-value Normalize above requires, even for grayscale/RGBA files.
        with Image.open(abs_img_path) as raw:
            img = self.transform(raw.convert('RGB'))
        # Label from the file-name prefix: dog -> 1, anything else -> 0.
        label = 1 if img_name.split('.')[0] == 'dog' else 0
        label = torch.as_tensor(label, dtype=torch.int64)
        return img, label

    def __len__(self):
        """Return the number of indexed files."""
        return len(self.path_list)
# The string literal below is a disabled smoke test that iterates a myData
# instance and prints each label; it is never executed.
"""
# 测试
train_data = myData(data_path)
for i, item in enumerate(tqdm(train_data)):
print(item[1])
"""
# Split a dataset into training and validation subsets.
def data_sp(data_set, train_ratio=0.8):
    """Randomly split ``data_set`` into a training and a validation subset.

    Args:
        data_set: any object supporting ``len()`` that ``random_split``
            accepts.
        train_ratio: fraction of the samples assigned to the training subset
            (truncated to a whole number of samples); the rest go to
            validation. Defaults to 0.8.

    Returns:
        A ``(train_subset, validation_subset)`` tuple.

    Raises:
        ValueError: if ``train_ratio`` is outside the closed range [0, 1].
    """
    if not 0.0 <= train_ratio <= 1.0:
        raise ValueError("train_ratio must be between 0 and 1, got %r" % (train_ratio,))
    total = len(data_set)
    n_train = int(total * train_ratio)
    # The validation size is the remainder so both sizes always sum to total.
    train_d, val_d = random_split(data_set, [n_train, total - n_train])
    return train_d, val_d
# Build the full dataset from data_path and split it into training and
# validation subsets.
data = myData(data_path)
train_d, val_d = data_sp(data)
# Wrap both subsets in DataLoaders that shuffle and yield batches of 128.
trainloader = DataLoader(train_d, shuffle=True, batch_size=128)
valloader = DataLoader(val_d, shuffle=True, batch_size=128)
# The string literal below is disabled code that prints the first batch of
# trainloader; it is never executed.
"""
# 这样 每轮输出 128个 batch_size 为128, 每个循环 训练18个数据
for dat in trainloader:
print(dat)
break
"""
# 搭建网络 两种方法 先使用 残差网络, 再自己定义一个试一试
from torchvision import models
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
# The string literal below is a disabled alternative model: a pretrained
# torchvision ResNet-18 whose final layer is replaced by a 2-output Linear.
# It is never executed.
"""
#加载 模型库中的残差网络
resnet18 = models.resnet18(pretrained=True)
resnet18.fc = nn.Linear(512, 2)
model = resnet18
"""
# A small custom network:
# conv -> pool -> conv -> pool -> three fully connected layers, ReLU between.
class myNet(nn.Module):
    """Two conv/pool stages followed by a three-layer classifier head.

    The first fully connected layer is sized for 16 channels of 56x56
    features, i.e. a 224x224 input halved by each of the two 2x2 poolings.
    The network emits two raw scores (no softmax) per sample.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(8, 16, kernel_size=3, stride=1, padding=1)
        # Flattened feature size after two poolings: 224 / 2 / 2 = 56.
        self.fc1 = nn.Linear(16 * 56 * 56, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 2)  # two outputs, one per label

    def forward(self, x):
        """Run the network on ``x`` and return one row of 2 scores per sample."""
        x = self.pool(F.relu(self.conv1(x)))
        # Second conv + pool stage.
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten everything except the batch dimension. Unlike
        # view(-1, 16*56*56), this never folds samples into each other: an
        # input of the wrong spatial size fails loudly in fc1 instead of
        # silently changing the batch size.
        x = x.flatten(1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
# Instantiate the custom network; this is the model that gets trained.
model = myNet()
# Stores and updates a running accuracy (or any other averaged value).
class AverageMeter(object):
    """Keep a running, count-weighted average of the values passed to update().

    Attributes are plain numbers: ``sum`` is the weighted total, ``cnt`` the
    accumulated weight and ``avg`` their ratio.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated state back to zero."""
        self.avg = 0
        self.sum = 0
        self.cnt = 0

    def update(self, val, n=1):
        """Add ``val`` with weight ``n`` and refresh the average.

        Args:
            val: the value to accumulate (e.g. a per-batch accuracy).
            n: the weight of ``val`` (e.g. the number of samples in the batch).
        """
        self.sum += val * n
        self.cnt += n
        # A zero total count (e.g. a first update with n=0) would otherwise
        # raise ZeroDivisionError; report 0 until something was counted.
        self.avg = self.sum / self.cnt if self.cnt else 0
# In multi-class classification, a prediction is top-k correct when the true
# label is among the k classes with the highest scores.
def accuracy(output, label, topk=(1,)):
    """Compute top-k accuracy, as a percentage, for each k in ``topk``.

    Args:
        output: score tensor indexed as (sample, class); ``topk`` is taken
            along dimension 1.
        label: tensor of class indices whose first dimension is the batch.
        topk: iterable of k values to evaluate.

    Returns:
        A list of 0-dim float tensors, one per entry of ``topk``, each
        holding the accuracy in percent.
    """
    # Clamp to the number of classes: Tensor.topk raises when asked for more
    # entries than dimension 1 holds, and top-k with k >= classes is 100%.
    maxk = min(max(topk), output.size(1))
    batch_size = label.size(0)
    # Indices of the maxk largest scores per sample, sorted descending.
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()  # rows are ranks, columns are samples
    # Boolean hit matrix: the labels broadcast over every rank row.
    correct = pred.eq(label.view(1, -1).expand_as(pred))
    rtn = []
    for k in topk:
        # reshape copies only when the slice is non-contiguous, so it covers
        # the same case the contiguous().view(-1) pair was written for.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        rtn.append(correct_k.mul_(100.0 / batch_size))
    return rtn
# Loss function, optimizer and target device.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
epochs = 10
model = model.to(device)  # move the parameters to the selected device

# Training: each epoch runs one pass over the training loader followed by one
# pass over the validation loader.
for epoch in range(epochs):
    model.train()
    ac = AverageMeter()
    train_loss = 0.0
    # Wrap the loader in a fresh progress bar under its own name. Rebinding
    # `trainloader` itself would wrap the previous epoch's bar again, nesting
    # one more progress bar per epoch.
    train_bar = tqdm(trainloader)
    train_bar.set_description('[%s%04d/%04d]' % ('Epoch:', epoch + 1, epochs))
    for i, data in enumerate(train_bar, 0):
        x, labels = data[0].to(device), data[1].to(device)
        # Clear the gradients left over from the previous batch.
        optimizer.zero_grad()
        pred = model(x)
        loss = criterion(pred, labels)
        loss.backward()
        optimizer.step()
        # Top-1 accuracy of this batch, weighted by the actual batch size so
        # a smaller final batch does not skew the running average.
        acc1 = accuracy(pred, labels, topk=(1,))[0]
        n = x.size(0)
        ac.update(acc1.item(), n)
        train_loss += loss.item()
        # Show the running mean loss and accuracy on the progress bar.
        postfix = {'train_loss': '%.6f' % (train_loss / (i + 1)), 'train_acc': '%.6f' % ac.avg}
        train_bar.set_postfix(log=postfix)
    # Validation pass: evaluation mode, no gradient tracking.
    model.eval()
    with torch.no_grad():
        val_ac = AverageMeter()
        val_bar = tqdm(valloader)  # separate name, for the same reason as above
        val_bar.set_description('[%s%04d/%04d]' % ('Epoch:', epoch + 1, epochs))
        val_loss = 0.0
        for i, data in enumerate(val_bar, 0):
            val_x, val_labels = data[0].to(device), data[1].to(device)
            val_pred = model(val_x)
            loss = criterion(val_pred, val_labels)
            prec1 = accuracy(val_pred, val_labels, topk=(1,))[0]
            n = val_x.size(0)  # number of samples in this batch
            val_ac.update(prec1.item(), n)
            val_loss += loss.item()
            postfix = {'validation_loss': '%.6f' % (val_loss / (i + 1)), 'validation_acc': '%.6f' % val_ac.avg}
            val_bar.set_postfix(log=postfix)
# Saving the weights is left disabled, as in the original:
# torch.save(model.state_dict(), 'model.pth')
print("训练完成")
# The string literal below is a disabled single-image inference demo: it
# preprocesses one test image, loads weights from 'model.pth', classifies the
# image as "Cat" or "Dog" and displays it with OpenCV. It is never executed.
# NOTE(review): the demo passes `img` to the model without adding a batch
# dimension -- confirm and fix before re-enabling.
"""
#测试 成果
test_path = r'dogs-vs-cats/test1/148.jpg'
image = Image.open(test_path)
Tr = transform.Resize([224, 224])
img = Tr(image) #调整大小
img = transform.ToTensor()(img) #转化为tensor
img = transform.Normalize(mean=[0.5, 0.5 ,0.5], std=[0.5, 0.5, 0.5])(img) #标准化
mo = myNet()
state_dict = torch.load('model.pth')
mo.load_state_dict(state_dict)
oo = mo(img)
softmax_func = nn.Softmax(dim=1)
res = softmax_func(oo)
if res[:, 0] > 0.5:
animal = "Cat"
else:
animal = "Dog"
i = cv2.imread(test_path)
cv2.putText(i, animal, (30, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 0), 3)
cv2.imshow('cats vs dogs', i)
cv2.waitKey(0)
cv2.destroyAllWindows()
"""