# First download the MNIST dataset and extract it into ./
import numpy as np
import struct
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
# Data loading section
def readfile(tgt):
    """Load one MNIST split from IDX files in the current directory.

    Parameters
    ----------
    tgt : str
        Split prefix, e.g. 'train' or 't10k'; reads
        './{tgt}-images.idx3-ubyte' and './{tgt}-labels.idx1-ubyte'.

    Returns
    -------
    (list[np.ndarray], list[int])
        Per-image (rows, cols) pixel arrays and their integer labels.

    The sample count, rows, and cols are parsed from the IDX headers
    instead of being hard-coded, so any well-formed IDX pair works.
    """
    def get_image(buf):
        # Image header: magic, count, rows, cols as big-endian uint32.
        _, count, rows, cols = struct.unpack_from('>IIII', buf, 0)
        offset = struct.calcsize('>IIII')
        record = struct.Struct(f'>{rows * cols}B')  # one image per record
        images = []
        for _ in range(count):
            pixels = record.unpack_from(buf, offset)
            images.append(np.reshape(pixels, (rows, cols)))
            offset += record.size
        return images

    def get_label(buf):
        # Label header: magic, count as big-endian uint32; one byte per label.
        _, count = struct.unpack_from('>II', buf, 0)
        offset = struct.calcsize('>II')
        labels = []
        for _ in range(count):
            labels.append(struct.unpack_from('>B', buf, offset)[0])
            offset += 1
        return labels

    with open(f'./{tgt}-images.idx3-ubyte', 'rb') as f1:
        im = get_image(f1.read())
    with open(f'./{tgt}-labels.idx1-ubyte', 'rb') as f2:
        label = get_label(f2.read())
    return im, label
# NOTE: the original used Unicode smart quotes (‘train’), which is a
# syntax error in Python; replaced with plain ASCII string literals.
train = 'train'
test = 't10k'
a = readfile(train)
b = readfile(test)
X_train = a[0]
X_train = np.stack(X_train)                 # (60000, 28, 28)
X_train = torch.tensor(X_train).float()
X_train.unsqueeze_(1)                       # add channel dim -> (60000, 1, 28, 28)
y = a[1]
torch.manual_seed(0)                        # reproducible weight init
# Define the network, with a sklearn-style fit/predict/score interface
class Net(nn.Module):
    """Tiny CNN for MNIST with a sklearn-style fit/predict/score API.

    Architecture: 1-channel 3x3 conv -> 2x2 max-pool -> flatten ->
    fc1 (169 -> 10) -> fc2 (10 -> 10) applied five times with shared
    weights (deliberate weight reuse, kept from the original design).
    """

    def __init__(self, *args, **kwargs):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 1, (3, 3))              # (N,1,28,28) -> (N,1,26,26)
        # Name kept as 'maxpoll1' (sic) for state_dict compatibility.
        self.maxpoll1 = nn.MaxPool2d(kernel_size=(2, 2))  # -> (N,1,13,13)
        self.flatten = nn.Flatten()                       # -> (N,169)
        self.fc1 = nn.Linear(169, 10)
        self.fc2 = nn.Linear(10, 10)

    def forward(self, x, *args, **kwargs):
        """Return raw class logits of shape (N, 10)."""
        x = self.conv1(x)
        x = self.maxpoll1(x)
        x = self.flatten(x)
        x = self.fc1(x)
        # The same fc2 weights are intentionally applied five times.
        for _ in range(5):
            x = self.fc2(x)
        # BUG FIX: the original applied F.softmax here before feeding the
        # result to nn.CrossEntropyLoss, which applies log-softmax itself.
        # The double normalization crushes the gradients (loss plateaued
        # near 1.6). Return logits; argmax in predict() is unaffected.
        return x

    def fit(self, X, y, epochs=10, batchsize=10, need_categorize=True):
        """Train with Adam + cross-entropy.

        X: float tensor (N, 1, 28, 28); y: integer labels (one-hot encoded
        here when need_categorize is True). Returns the list of per-epoch
        mean losses as plain floats.
        """
        if need_categorize:
            y = self.categorize(y).long()
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.Adam(self.parameters())
        # Progress bar is cosmetic; degrade gracefully if tqdm is absent.
        try:
            from tqdm import tqdm as _progress
        except ImportError:
            def _progress(it):
                return it
        lss = []
        self.train()
        n = X.shape[0]
        for _ in _progress(range(epochs)):
            epoch_loss = 0.0
            cnt = 0
            # BUG FIX: the original loop bounds used X.shape[0] - 1 and
            # silently dropped the last sample of every epoch.
            for start in range(0, n, batchsize):
                end = min(start + batchsize, n)
                y_pred = self.forward(X[start:end])
                self.loss = self.criterion(y_pred, y[start:end].argmax(1))
                # .item() detaches: accumulating loss tensors kept every
                # batch's autograd graph alive and grew memory each epoch.
                epoch_loss += self.loss.item()
                self.optimizer.zero_grad()
                self.loss.backward()
                self.optimizer.step()
                cnt += 1
            epoch_loss /= cnt
            print(epoch_loss)
            lss.append(epoch_loss)
        self.eval()
        return lss

    def predict(self, X, need_categorize=True):
        """Return predicted class indices (long tensor of shape (N,))."""
        with torch.no_grad():  # inference only; no autograd bookkeeping
            return self.forward(X).argmax(1)

    def score(self, X, y, need_categorize=True):
        """Return classification accuracy of the model on (X, y)."""
        if need_categorize:
            y = self.categorize(y)
        y_hat = self.predict(X).numpy().astype('int').reshape((-1, 1))
        y_true = y.numpy().argmax(1).astype('int').reshape((-1, 1))
        return (y_true == y_hat).sum() / y_true.shape[0]

    def categorize(self, y):
        """One-hot encode integer labels; sets self.dim = number of classes.

        Returns a float64 tensor of shape (len(y), self.dim).
        """
        y = torch.tensor(y)
        self.dim = int(y.max() + 1)
        if len(y.shape) != 2:
            y = y.reshape((-1, 1))
        onehot = np.zeros((y.shape[0], self.dim))
        onehot[np.arange(y.shape[0]), y[:, 0].numpy().astype(int)] = 1
        return torch.tensor(onehot)
# Build the model and train it on the first 10k training samples.
net = Net()
lss = net.fit(X_train[:10000], y[:10000], epochs=100, batchsize=16)
'''
训练过程
1%| | 1/100 [00:02<04:14, 2.57s/it]tensor(1.6657, grad_fn=
2%|▏ | 2/100 [00:05<04:13, 2.58s/it]tensor(1.6453, grad_fn=
3%|▎ | 3/100 [00:07<04:16, 2.64s/it]tensor(1.6301, grad_fn=
4%|▍ | 4/100 [00:10<04:21, 2.73s/it]tensor(1.6245, grad_fn=
5%|▌ | 5/100 [00:14<04:35, 2.90s/it]tensor(1.6239, grad_fn=
6%|▌ | 6/100 [00:17<04:47, 3.06s/it]tensor(1.6188, grad_fn=
7%|▋ | 7/100 [00:21<04:58, 3.21s/it]tensor(1.6146, grad_fn=
8%|▊ | 8/100 [00:24<04:56, 3.22s/it]tensor(1.6073, grad_fn=
9%|▉ | 9/100 [00:27<05:00, 3.30s/it]tensor(1.6112, grad_fn=
10%|█ | 10/100 [00:31<05:01, 3.35s/it]tensor(1.6198, grad_fn=
11%|█ | 11/100 [00:34<04:58, 3.35s/it]tensor(1.6143, grad_fn=
12%|█▏ | 12/100 [00:38<04:58, 3.39s/it]tensor(1.6156, grad_fn=
13%|█▎ | 13/100 [00:41<05:00, 3.45s/it]tensor(1.6145, grad_fn=
14%|█▍ | 14/100 [00:45<05:00, 3.50s/it]tensor(1.6058, grad_fn=
15%|█▌ | 15/100 [00:49<05:01, 3.55s/it]tensor(1.6014, grad_fn=
16%|█▌ | 16/100 [00:52<04:58, 3.55s/it]tensor(1.5980, grad_fn=
17%|█▋ | 17/100 [00:56<05:00, 3.62s/it]tensor(1.5993, grad_fn=
18%|█▊ | 18/100 [00:59<04:55, 3.60s/it]tensor(1.6017, grad_fn=
19%|█▉ | 19/100 [01:03<04:52, 3.61s/it]tensor(1.5985, grad_fn=
20%|██ | 20/100 [01:07<04:49, 3.62s/it]tensor(1.6086, grad_fn=
21%|██ | 21/100 [01:10<04:46, 3.63s/it]tensor(1.5997, grad_fn=
22%|██▏ | 22/100 [01:14<04:46, 3.68s/it]tensor(1.5968, grad_fn=
23%|██▎ | 23/100 [01:18<04:43, 3.68s/it]tensor(1.6002, grad_fn=
24%|██▍ | 24/100 [01:22<04:42, 3.72s/it]tensor(1.5980, grad_fn=
25%|██▌ | 25/100 [01:26<04:52, 3.90s/it]tensor(1.5981, grad_fn=
26%|██▌ | 26/100 [01:31<05:02, 4.08s/it]tensor(1.5945, grad_fn=
27%|██▋ | 27/100 [01:36<05:27, 4.48s/it]tensor(1.5995, grad_fn=
28%|██▊ | 28/100 [01:41<05:33, 4.63s/it]tensor(1.6033, grad_fn=
29%|██▉ | 29/100 [01:45<05:15, 4.44s/it]tensor(1.5984, grad_fn=
30%|███ | 30/100 [01:49<05:03, 4.33s/it]tensor(1.6073, grad_fn=
31%|███ | 31/100 [01:53<04:51, 4.22s/it]tensor(1.6063, grad_fn=
32%|███▏ | 32/100 [01:57<04:40, 4.12s/it]tensor(1.5978, grad_fn=
33%|███▎ | 33/100 [02:01<04:31, 4.06s/it]tensor(1.6016, grad_fn=
34%|███▍ | 34/100 [02:05<04:24, 4.00s/it]tensor(1.6025, grad_fn=
35%|███▌ | 35/100 [02:09<04:18, 3.98s/it]tensor(1.6040, grad_fn=
36%|███▌ | 36/100 [02:13<04:23, 4.11s/it]tensor(1.5952, grad_fn=
37%|███▋ | 37/100 [02:17<04:21, 4.15s/it]tensor(1.6000, grad_fn=
38%|███▊ | 38/100 [02:22<04:23, 4.25s/it]tensor(1.6040, grad_fn=
39%|███▉ | 39/100 [02:26<04:22, 4.31s/it]tensor(1.6014, grad_fn=
40%|████ | 40/100 [02:30<04:15, 4.27s/it]tensor(1.6053, grad_fn=
41%|████ | 41/100 [02:34<04:05, 4.16s/it]tensor(1.6016, grad_fn=
42%|████▏ | 42/100 [02:39<04:03, 4.20s/it]tensor(1.5973, grad_fn=
43%|████▎ | 43/100 [02:43<03:57, 4.16s/it]tensor(1.5943, grad_fn=
44%|████▍ | 44/100 [02:47<04:02, 4.33s/it]tensor(1.6000, grad_fn=
45%|████▌ | 45/100 [02:52<03:58, 4.34s/it]tensor(1.5918, grad_fn=
46%|████▌ | 46/100 [02:56<03:54, 4.35s/it]tensor(1.6148, grad_fn=
47%|████▋ | 47/100 [03:00<03:49, 4.33s/it]tensor(1.6056, grad_fn=
48%|████▊ | 48/100 [03:04<03:41, 4.25s/it]tensor(1.5986, grad_fn=
49%|████▉ | 49/100 [03:08<03:32, 4.16s/it]tensor(1.6038, grad_fn=
50%|█████ | 50/100 [03:12<03:25, 4.11s/it]tensor(1.6008, grad_fn=
51%|█████ | 51/100 [03:16<03:19, 4.07s/it]tensor(1.5996, grad_fn=
52%|█████▏ | 52/100 [03:20<03:13, 4.03s/it]tensor(1.5983, grad_fn=
53%|█████▎ | 53/100 [03:24<03:06, 3.98s/it]tensor(1.6088, grad_fn=
54%|█████▍ | 54/100 [03:28<03:02, 3.96s/it]tensor(1.6084, grad_fn=
55%|█████▌ | 55/100 [03:32<02:56, 3.92s/it]tensor(1.5990, grad_fn=
56%|█████▌ | 56/100 [03:36<02:51, 3.89s/it]tensor(1.5984, grad_fn=
57%|█████▋ | 57/100 [03:39<02:44, 3.84s/it]tensor(1.5975, grad_fn=
58%|█████▊ | 58/100 [03:43<02:40, 3.81s/it]tensor(1.5939, grad_fn=
59%|█████▉ | 59/100 [03:47<02:35, 3.79s/it]tensor(1.5983, grad_fn=
60%|██████ | 60/100 [03:51<02:29, 3.75s/it]tensor(1.6144, grad_fn=
61%|██████ | 61/100 [03:54<02:26, 3.77s/it]tensor(1.5967, grad_fn=
62%|██████▏ | 62/100 [03:58<02:24, 3.81s/it]tensor(1.6061, grad_fn=
63%|██████▎ | 63/100 [04:02<02:23, 3.87s/it]tensor(1.6089, grad_fn=
64%|██████▍ | 64/100 [04:06<02:20, 3.90s/it]tensor(1.6301, grad_fn=
65%|██████▌ | 65/100 [04:10<02:16, 3.91s/it]tensor(1.6260, grad_fn=
66%|██████▌ | 66/100 [04:14<02:12, 3.90s/it]tensor(1.6090, grad_fn=
67%|██████▋ | 67/100 [04:18<02:08, 3.89s/it]tensor(1.6034, grad_fn=
68%|██████▊ | 68/100 [04:22<02:07, 3.97s/it]tensor(1.6016, grad_fn=
69%|██████▉ | 69/100 [04:26<02:06, 4.07s/it]tensor(1.6103, grad_fn=
70%|███████ | 70/100 [04:30<02:00, 4.03s/it]tensor(1.6197, grad_fn=
71%|███████ | 71/100 [04:34<01:56, 4.02s/it]tensor(1.6116, grad_fn=
72%|███████▏ | 72/100 [04:38<01:51, 4.00s/it]tensor(1.6138, grad_fn=
73%|███████▎ | 73/100 [04:42<01:46, 3.96s/it]tensor(1.6181, grad_fn=
74%|███████▍ | 74/100 [04:46<01:42, 3.93s/it]tensor(1.6069, grad_fn=
75%|███████▌ | 75/100 [04:50<01:37, 3.89s/it]tensor(1.6059, grad_fn=
76%|███████▌ | 76/100 [04:54<01:37, 4.05s/it]tensor(1.6237, grad_fn=
77%|███████▋ | 77/100 [04:58<01:34, 4.09s/it]tensor(1.6089, grad_fn=
78%|███████▊ | 78/100 [05:02<01:25, 3.91s/it]tensor(1.6178, grad_fn=
79%|███████▉ | 79/100 [05:05<01:18, 3.76s/it]tensor(1.6188, grad_fn=
80%|████████ | 80/100 [05:09<01:13, 3.69s/it]tensor(1.6343, grad_fn=
81%|████████ | 81/100 [05:12<01:08, 3.61s/it]tensor(1.6117, grad_fn=
82%|████████▏ | 82/100 [05:16<01:05, 3.66s/it]tensor(1.6036, grad_fn=
83%|████████▎ | 83/100 [05:19<01:00, 3.57s/it]tensor(1.6103, grad_fn=
84%|████████▍ | 84/100 [05:23<00:56, 3.52s/it]tensor(1.6077, grad_fn=
85%|████████▌ | 85/100 [05:26<00:52, 3.52s/it]tensor(1.6200, grad_fn=
86%|████████▌ | 86/100 [05:30<00:48, 3.50s/it]tensor(1.6193, grad_fn=
87%|████████▋ | 87/100 [05:33<00:45, 3.49s/it]tensor(1.6079, grad_fn=
88%|████████▊ | 88/100 [05:37<00:41, 3.48s/it]tensor(1.6000, grad_fn=
89%|████████▉ | 89/100 [05:40<00:37, 3.44s/it]tensor(1.6084, grad_fn=
90%|█████████ | 90/100 [05:43<00:34, 3.42s/it]tensor(1.6139, grad_fn=
91%|█████████ | 91/100 [05:47<00:30, 3.40s/it]tensor(1.6149, grad_fn=
92%|█████████▏| 92/100 [05:50<00:27, 3.40s/it]tensor(1.6075, grad_fn=
93%|█████████▎| 93/100 [05:53<00:23, 3.38s/it]tensor(1.6105, grad_fn=
94%|█████████▍| 94/100 [05:57<00:20, 3.38s/it]tensor(1.6165, grad_fn=
95%|█████████▌| 95/100 [06:00<00:16, 3.38s/it]tensor(1.6063, grad_fn=
96%|█████████▌| 96/100 [06:04<00:13, 3.38s/it]tensor(1.6069, grad_fn=
97%|█████████▋| 97/100 [06:07<00:10, 3.35s/it]tensor(1.6090, grad_fn=
98%|█████████▊| 98/100 [06:10<00:06, 3.33s/it]tensor(1.6037, grad_fn=
99%|█████████▉| 99/100 [06:13<00:03, 3.31s/it]tensor(1.6162, grad_fn=
100%|██████████| 100/100 [06:17<00:00, 3.78s/it]tensor(1.6157, grad_fn=
'''
# Visualize the loss curve. Convert each entry to float first: the
# original lss held 0-d tensors with grad_fn (see the pasted log), and
# matplotlib cannot convert a tensor that requires grad to an array.
plt.plot(list(range(len(lss))), [float(l) for l in lss])
# The apparent loss plateau was attributed to overfitting from the
# weight-reused fc2 layer; the softmax-before-CrossEntropyLoss is the
# other likely culprit.
net.score(X_train[20000:60000 - 1, :, :, :], y[20000:60000 - 1])  # evaluate on held-out samples
# 0.8310457761444036
# Randomly pick a few samples and visualize label vs. prediction
# Show a 3x3 grid of random digits with their true label and prediction.
for cell in range(9):
    item = np.random.randint(0, 60000 - 1)
    plt.subplot(3, 3, cell + 1)
    caption = f"标签:{str(y[item])} 预测:{net.predict(X_train[[item],:,:,:])[0]}"
    plt.title(caption, fontproperties='SimHei')  # CJK-capable font for the title
    plt.imshow(X_train[item].squeeze(0), cmap='gray')
plt.show()
# Visualize the weights of self.fc1
# Render each of fc1's 10 output units as a 13x13 weight map.
fc1_weight = list(net.fc1.parameters())[0]
for unit in range(10):
    plt.subplot(3, 4, unit + 1)
    plt.title(f'{unit}')
    plt.imshow(fc1_weight[unit, :].view((13, 13)).detach().numpy(), cmap='gray')
plt.show()