tqdm is a fast, extensible progress bar for Python: it adds a progress indicator to long-running loops, and all the user has to do is wrap any iterable with tqdm(iterator).
import time
from tqdm import tqdm  # import tqdm
a = 10
b = range(a)  # the integers 0 through 9
c = tqdm(b, total=a)  # wrapping the range in tqdm produces a progress bar that estimates the remaining time in real time
for i in c:
    time.sleep(1)  # iterating with the for loop drives and prints the progress bar
Output:
0%| | 0/10 [00:00, ?it/s]
10%|█ | 1/10 [00:01<00:09, 1.01s/it]
20%|██ | 2/10 [00:02<00:08, 1.01s/it]
30%|███ | 3/10 [00:03<00:07, 1.01s/it]
40%|████ | 4/10 [00:04<00:06, 1.01s/it]
50%|█████ | 5/10 [00:05<00:05, 1.01s/it]
60%|██████ | 6/10 [00:06<00:04, 1.01s/it]
70%|███████ | 7/10 [00:07<00:03, 1.01s/it]
80%|████████ | 8/10 [00:08<00:02, 1.01s/it]
90%|█████████ | 9/10 [00:09<00:01, 1.01s/it]
100%|██████████| 10/10 [00:10<00:00, 1.01s/it]
100%|██████████| 10/10 [00:10<00:00, 1.01s/it]
We use the officially quoted MNIST statistics: mean = 0.1307, std = 0.3081.
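As a sanity check, these statistics can be recomputed from the raw training images; a minimal sketch (it relies on torchvision's MNIST dataset exposing its images through the .data attribute, and the exact values may differ slightly in the last digits):
from torchvision.datasets import MNIST
# load the raw training split; no transform is needed just to inspect pixel statistics
raw_train = MNIST(root="../MNIST_data", train=True, download=True)
pixels = raw_train.data.float() / 255.0           # (60000, 28, 28) uint8 images scaled to [0, 1]
print(pixels.mean().item(), pixels.std().item())  # approximately 0.1307 and 0.3081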
torch.save(model.state_dict(), path)  # save the model weights
torch.save(optimizer.state_dict(), path)  # save the optimizer state
# loading the model and optimizer back
model.load_state_dict(torch.load(path))
optimizer.load_state_dict(torch.load(path))
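One practical detail worth noting (an assumption on top of the snippet above, not something it shows): torch.save does not create missing directories, so the target folder has to exist first. A small sketch, where model and optimizer are the ones built in the training code below:
import os
import torch
save_dir = 'models'                   # hypothetical checkpoint directory
os.makedirs(save_dir, exist_ok=True)  # torch.save will not create it for us
torch.save(model.state_dict(), os.path.join(save_dir, 'model.pkl'))
torch.save(optimizer.state_dict(), os.path.join(save_dir, 'optimizer.pkl'))
# map_location='cpu' keeps the load working even on a machine without a GPU
model.load_state_dict(torch.load(os.path.join(save_dir, 'model.pkl'), map_location='cpu'))
optimizer.load_state_dict(torch.load(os.path.join(save_dir, 'optimizer.pkl'), map_location='cpu'))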
# Forward pass: the original image, after going through the transforms, is fed into the model, which returns the predicted scores; comparing those scores with the labels gives the loss value.
# The loss is then used for backpropagation, followed by a single optimization step.
# Zeroing the gradient means setting the derivative of the loss with respect to the weights to zero, so that it does not get mixed with the gradients of other mini-batches.
When the gradients flow back to the network parameters, they are accumulated rather than replaced: the gradient of every batch would be added on top of the previous one. We do not want that, because we want each batch to be handled independently; accumulating would effectively be the same as enlarging the batch size. That is why the gradients have to be zeroed before every training step, not just once per epoch.
# Zero the gradients: optimizer.zero_grad()
# Backpropagation: loss.backward()
# Single optimization step: optimizer.step()
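A minimal self-contained sketch of this behaviour (the toy tensor w below is purely illustrative): calling backward() twice without zeroing adds the second gradient onto the first, which is exactly what optimizer.zero_grad() prevents.
import torch
w = torch.tensor([2.0], requires_grad=True)  # a single trainable weight
loss = (3 * w).sum()
loss.backward()
print(w.grad)   # tensor([3.]) - gradient from the first backward pass
loss = (3 * w).sum()
loss.backward()
print(w.grad)   # tensor([6.]) - without zeroing, the gradients accumulate
w.grad.zero_()  # this is what optimizer.zero_grad() does for every parameter
print(w.grad)   # tensor([0.])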
http://zhuanlan.zhihu.com/p/35709485
from torchvision.datasets import MNIST
from torchvision import transforms
import torch
# image preprocessing
my_transforms = transforms.Compose(
    [transforms.ToTensor(),  # converts the PIL image to a float tensor in [0, 1], which Normalize expects (PILToTensor would keep uint8)
     transforms.Normalize(mean=(0.1307,), std=(0.3081,))  # standardize the images
     ]
)
# get the data source
mnist_train = MNIST(root="../MNIST_data", train=True, download=True, transform=my_transforms)
# data loading
from torch.utils.data import DataLoader  # import the data loader
dataloader = DataLoader(mnist_train, batch_size=8, shuffle=True)
for (images, labels) in dataloader:
    pass
from torch import nn
class MnistModel(nn.Module):
    def __init__(self):  # override the __init__ method
        super(MnistModel, self).__init__()
        self.fc1 = nn.Linear(1*28*28, 10)  # each input image has 1*28*28 pixels, and we map them to 10 class scores
    def forward(self, image):  # override the forward method
        image_viewed = image.view(-1, 1*28*28)  # the image has to be flattened here
        out = self.fc1(image_viewed)
        return out
# instantiate the model
model = MnistModel()
from torch import optim  # import the optimizer module
optimizer = optim.Adam(model.parameters(), lr=1e-4)  # keep the optimizer in a variable so it can be used below
# loss function
loss_function = nn.CrossEntropyLoss()
dataloader = tqdm(dataloader, total=len(dataloader))
# loop over the iterable
for (images, labels) in dataloader:
    # zero the gradients
    optimizer.zero_grad()
    # forward pass
    output = model(images)
    # compute the loss
    loss = loss_function(output, labels)
    # backpropagation
    loss.backward()
    # optimizer update
    optimizer.step()
# save the model (and the optimizer state) once training is done
torch.save(model.state_dict(), 'models/model.pkl')
torch.save(optimizer.state_dict(), 'models/optimizer.pkl')
# the second argument of save is the destination path, e.g. 'models/model.pkl'
Output:
0%| | 0/7500 [00:00, ?it/s]
0%| | 1/7500 [00:00<54:58, 2.27it/s]
1%| | 56/7500 [00:00<00:54, 136.43it/s]
2%|▏ | 117/7500 [00:00<00:28, 259.16it/s]
2%|▏ | 183/7500 [00:00<00:20, 358.13it/s]
3%|▎ | 247/7500 [00:00<00:16, 429.47it/s]
...
95%|█████████▍| 7123/7500 [00:12<00:00, 607.90it/s]
96%|█████████▌| 7184/7500 [00:12<00:00, 608.17it/s]
97%|█████████▋| 7245/7500 [00:12<00:00, 608.48it/s]
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader
from torch import nn
from tqdm import tqdm
# from MNIST_train import test
import os  # module used to check whether a path exists
# import numpy as np
import torch
class MnistModel(nn.Module):
    def __init__(self):  # override the __init__ method
        super(MnistModel, self).__init__()
        self.fc1 = nn.Linear(1*28*28, 10)  # each input image has 1*28*28 pixels, and we map them to 10 class scores
    def forward(self, image):  # override the forward method
        image_viewed = image.view(-1, 1*28*28)  # the image has to be flattened here
        out = self.fc1(image_viewed)
        return out
# instantiate the model
model = MnistModel()
if os.path.exists('models/model.pkl'):
    model.load_state_dict(torch.load('models/model.pkl'))
# loss function
loss_function = nn.CrossEntropyLoss()
# image preprocessing
my_transforms = transforms.Compose(
    [transforms.ToTensor(),
     # transforms.PILToTensor(),
     transforms.Normalize(mean=(0.1307,), std=(0.3081,))  # standardize the images
     ]
)
# get the data source (the test split this time)
mnist_test = MNIST(root="../MNIST_data", train=False, download=True, transform=my_transforms)
# data loading
# from torch.utils.data import DataLoader  # import the data loader
dataloader = DataLoader(mnist_test, batch_size=8, shuffle=True)
dataloader = tqdm(dataloader, total=len(dataloader))
model.eval()
with torch.no_grad():
    for images, labels in dataloader:
        # get the predictions
        output = model(images)
        print(output)
        exit()
        # compute the loss
        loss = loss_function(output, labels)
Output:
0%| | 0/1250 [00:00, ?it/s]
0%| | 0/1250 [00:00, ?it/s]
tensor([[-1.3393e+01, 5.8203e+00, 1.7052e+01, 9.0464e+00, -1.4631e+01,
-1.9053e+00, 2.6430e+00, -1.4275e+01, 2.6435e+00, -1.2792e+01],
[-3.2997e+00, -8.3555e+00, 2.9035e+00, -2.1400e+00, -6.2862e+00,
-5.0284e+00, 1.0334e+01, -6.2676e+00, 1.5082e+00, -9.5994e+00],
[ 1.3473e+00, -2.4987e+01, -5.6099e+00, -1.0971e+01, 3.0164e+00,
-6.1351e-01, -2.2419e+00, -3.3190e+00, -6.0956e-02, -2.2503e+00],
[-8.2029e+00, 6.8706e+00, 4.3379e+00, 8.0266e-01, 1.7146e+00,
-1.5904e+00, -5.4480e-03, 3.8092e+00, -1.7199e-01, 3.7706e-02],
[-1.1637e+01, -8.7083e+00, 5.0389e-01, -1.3896e+00, -6.7360e+00,
3.2284e+00, 5.3744e+00, -2.3499e+01, 6.6444e+00, -1.0760e+01],
[-1.4884e+01, 9.0861e+00, 2.7453e+00, 3.7568e+00, -4.2774e+00,
-3.0478e+00, -3.6093e-01, -7.9317e-01, 2.8089e+00, -6.7414e-01],
[-1.8389e+00, -9.0608e+00, 8.2760e-01, 7.3417e+00, -9.3287e+00,
3.0612e+00, -1.0551e+01, -7.5760e+00, -2.0199e-01, -1.3060e+00],
[-4.0707e+00, -7.2295e+00, 4.1510e+00, 1.0424e+01, -6.3570e+00,
3.8209e-01, -8.3809e+00, -1.0495e+01, -2.5747e+00, -5.9327e+00]])
Process finished with exit code 0
#The result is the model's score for each of the 10 classes, i.e. how strongly it favours each digit.
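These raw outputs are unnormalized logits. If actual probabilities are wanted, softmax can be applied; a short sketch (not required just for picking the predicted class; output below refers to the batch of logits printed above):
import torch
probabilities = torch.softmax(output, dim=1)  # each row of the (8, 10) tensor now sums to 1
predicted = probabilities.argmax(dim=1)       # same classes as output.max(dim=1).indices
print(probabilities[0], predicted[0])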
model.eval()
with torch.no_grad():
    for images, labels in dataloader:
        # get the predictions
        output = model(images)
        result = output.max(dim=1)  # along dim=1, compare the ten values in each row
        print(result)
        exit()
Output:
0%| | 0/1250 [00:00, ?it/s]
0%| | 0/1250 [00:00, ?it/s]
torch.return_types.max(
values=tensor([ 8.3219, 7.4667, 14.7933, 8.5420, 5.7937, 3.0753, 12.5724, 5.9098]),
indices=tensor([1, 2, 6, 1, 4, 5, 6, 3]))
Process finished with exit code 0
for images, labels in dataloader:
    # get the predictions
    output = model(images)
    result = output.max(dim=1).indices  # take the indices, i.e. the predicted classes
    print(result)  # print the predictions
    print(labels)  # print the labels
    exit()
    # compute the loss
    loss = loss_function(output, labels)
Output:
0%| | 0/1250 [00:00, ?it/s]
0%| | 0/1250 [00:00, ?it/s]
tensor([2, 8, 3, 4, 6, 2, 2, 3])
tensor([2, 8, 3, 4, 6, 2, 2, 3])
Process finished with exit code 0
#From the output we can see that the predictions match the labels exactly, so every image in this batch was recognized correctly!
print(result.eq(labels))  # element-wise check whether result equals labels
Output:
0%| | 0/1250 [00:00, ?it/s]tensor([1, 6, 5, 3, 8, 0, 4, 1])
tensor([1, 6, 5, 3, 8, 0, 4, 1])
0%| | 0/1250 [00:00, ?it/s]
tensor([True, True, True, True, True, True, True, True])
Process finished with exit code 0
#This means the recognition accuracy for this batch is 100%.
A boolean tensor cannot be averaged directly; it first has to be cast to a floating-point type.
print(result.eq(labels).float().mean())  # cast to float first, then take the mean
Output:
0%| | 0/1250 [00:00, ?it/s]
tensor(1.)
Process finished with exit code 0
#accuracy: 1.0, i.e. 100%
print(result.eq(labels).float().mean().item())  # extract the accuracy as a plain Python number
Output:
1.0
0%| | 0/1250 [00:00, ?it/s]
Process finished with exit code 0
# from torchvision.datasets import MNIST
# from torchvision import transforms
# from torch.utils.data import DataLoader
# from torch import nn
# from torch import optim  # import the optimizer module
# from tqdm import tqdm  # import tqdm
# from torch import save
# from torch import save, load
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader
from torch import nn
from tqdm import tqdm
# from MNIST_train import test
import os  # module used to check whether a path exists
import numpy as np
import torch
class MnistModel(nn.Module):
    def __init__(self):  # override the __init__ method
        super(MnistModel, self).__init__()
        self.fc1 = nn.Linear(1*28*28, 10)  # each input image has 1*28*28 pixels, and we map them to 10 class scores
    def forward(self, image):  # override the forward method
        image_viewed = image.view(-1, 1*28*28)  # the image has to be flattened here
        out = self.fc1(image_viewed)
        return out
# instantiate the model
model = MnistModel()
if os.path.exists('models/model.pkl'):
    model.load_state_dict(torch.load('models/model.pkl'))
# loss function
loss_function = nn.CrossEntropyLoss()
# image preprocessing
my_transforms = transforms.Compose(
    [transforms.ToTensor(),
     # transforms.PILToTensor(),
     transforms.Normalize(mean=(0.1307,), std=(0.3081,))  # standardize the images
     ]
)
# get the data source (the test split)
mnist_test = MNIST(root="../MNIST_data", train=False, download=True, transform=my_transforms)
# data loading
# from torch.utils.data import DataLoader  # import the data loader
dataloader = DataLoader(mnist_test, batch_size=8, shuffle=True)
dataloader = tqdm(dataloader, total=len(dataloader))
succeed = []  # list that collects the per-batch accuracy
model.eval()
with torch.no_grad():
    for images, labels in dataloader:
        # get the predictions
        output = model(images)
        result = output.max(dim=1).indices  # take the indices, i.e. the predicted classes
        # print(result)  # print the predictions
        # print(labels)  # print the labels
        # print(result.eq(labels))  # element-wise check whether result equals labels
        # print(result.eq(labels).float().mean())
        # print(result.eq(labels).float().mean().item())  # extract the accuracy as a plain number
        succeed.append(result.eq(labels).float().mean().item())
        # exit()
        # compute the loss
        loss = loss_function(output, labels)
print('Accuracy over one epoch:', np.mean(succeed))
Output:
0%| | 0/1250 [00:00, ?it/s]
5%|▍ | 57/1250 [00:00<00:02, 542.92it/s]
9%|▉ | 112/1250 [00:00<00:02, 546.49it/s]
14%|█▍ | 172/1250 [00:00<00:01, 558.39it/s]
19%|█▉ | 237/1250 [00:00<00:01, 582.55it/s]
24%|██▎ | 296/1250 [00:00<00:01, 574.66it/s]
28%|██▊ | 354/1250 [00:00<00:01, 567.53it/s]
33%|███▎ | 411/1250 [00:00<00:01, 559.71it/s]
37%|███▋ | 467/1250 [00:00<00:01, 526.87it/s]
42%|████▏ | 520/1250 [00:00<00:01, 527.60it/s]
46%|████▋ | 580/1250 [00:01<00:01, 540.98it/s]
51%|█████ | 640/1250 [00:01<00:01, 550.30it/s]
56%|█████▌ | 696/1250 [00:01<00:01, 545.38it/s]
60%|██████ | 751/1250 [00:01<00:00, 538.69it/s]
64%|██████▍ | 805/1250 [00:01<00:00, 531.24it/s]
69%|██████▉ | 860/1250 [00:01<00:00, 528.92it/s]
73%|███████▎ | 916/1250 [00:01<00:00, 530.67it/s]
78%|███████▊ | 971/1250 [00:01<00:00, 535.86it/s]
83%|████████▎ | 1032/1250 [00:01<00:00, 549.47it/s]
87%|████████▋ | 1087/1250 [00:02<00:00, 534.16it/s]
91%|█████████▏| 1141/1250 [00:02<00:00, 528.62it/s]
96%|█████████▌| 1200/1250 [00:02<00:00, 546.09it/s]
100%|██████████| 1250/1250 [00:02<00:00, 542.33it/s]
Accuracy over one epoch: 0.9141
That completes the training and evaluation of the handwritten-digit recognizer for a single pass over the data. So how do we train for multiple epochs?
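A common pattern is simply to wrap the code above in functions and loop over epochs; a sketch under the assumption that train() runs one pass over the training DataLoader and test() runs the evaluation loop shown above (both are hypothetical wrappers, not functions defined in this post):
epochs = 3  # illustrative number of epochs
for epoch in range(epochs):
    print('epoch', epoch + 1)
    train()  # one full pass over the 60000 training images, saving the model at the end
    test()   # report the accuracy on the 10000 test images after each epoch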