Getting Started with PyTorch: The MNIST Dataset as an Example

First, download the MNIST dataset, unpack it, and put the extracted idx files in ./
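If you would rather not fetch the files by hand, torchvision can download the same data for you. This is just an optional shortcut, not used in the rest of this post, and it assumes torchvision is installed; note the raw files it saves use hyphenated names (train-images-idx3-ubyte), so you would still need to move and rename them to the dot-style names read below.

from torchvision import datasets

# downloads the raw idx files into ./MNIST/raw/
mnist_train = datasets.MNIST(root='./', train=True, download=True)
mnist_test = datasets.MNIST(root='./', train=False, download=True)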

import numpy as np
import struct
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

# Data loading
def readfile(tgt):
    times = {'train': 60000, 't10k': 10000}

    def get_image(buf1):
        # skip the 16-byte header (magic, count, rows, cols), then read
        # one 784-byte image at a time and reshape it to 28x28
        image_index = 0
        image_index += struct.calcsize('>IIII')
        im = []
        for i in range(times[tgt]):
            temp = struct.unpack_from('>784B', buf1, image_index)
            im.append(np.reshape(temp, (28, 28)))
            image_index += struct.calcsize('>784B')
        return im

    def get_label(buf2):
        # skip the 8-byte header (magic, count), then read one label byte at a time
        label_index = 0
        label_index += struct.calcsize('>II')
        labels = []
        for i in range(times[tgt]):
            label = struct.unpack_from('>1B', buf2, label_index)
            labels.append(label[0])
            label_index += struct.calcsize('>1B')
        return labels

    with open(f'./{tgt}-images.idx3-ubyte', 'rb') as f1:
        buf1 = f1.read()
        im = get_image(buf1)
    with open(f'./{tgt}-labels.idx1-ubyte', 'rb') as f2:
        buf2 = f2.read()
        label = get_label(buf2)
    return im, label

train = 'train'
test = 't10k'
a = readfile(train)
b = readfile(test)

X_train = np.stack(a[0])              # (60000, 28, 28)
X_train = torch.tensor(X_train).float()
X_train.unsqueeze_(1)                 # add a channel dimension: (60000, 1, 28, 28)
y = a[1]
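The 16-byte header that get_image skips encodes, in big-endian order, a magic number, the image count, and the row and column sizes (the label header holds just a magic number and a count). A quick sanity check on the training images, assuming the files sit in ./ as above:

with open('./train-images.idx3-ubyte', 'rb') as f:
    magic, n_images, n_rows, n_cols = struct.unpack('>IIII', f.read(16))
print(magic, n_images, n_rows, n_cols)  # expected: 2051 60000 28 28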

torch.manual_seed(0)

# Define the network and give it an sklearn-style interface
class Net(nn.Module):
    def __init__(self, *args, **kwargs):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 1, (3, 3))              # 28x28 -> 26x26
        self.maxpool1 = nn.MaxPool2d(kernel_size=(2, 2))  # 26x26 -> 13x13
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(169, 10)                     # 13 * 13 = 169
        self.fc2 = nn.Linear(10, 10)

    def forward(self, x, *args, **kwargs):
        x = self.conv1(x)
        x = self.maxpool1(x)
        x = self.flatten(x)
        x = self.fc1(x)
        for i in range(5):
            x = self.fc2(x)  # the same fc2 weights are applied five times
        # note: nn.CrossEntropyLoss below already applies log-softmax internally,
        # so this explicit softmax squashes the logits and slows convergence
        x = F.softmax(x, dim=1)
        return x



    def fit(self, X, y, epochs=10, batchsize=10, need_categorize=True):
        if need_categorize:
            y = self.categorize(y).long()
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.Adam(self.parameters())
        lss = []
        self.train()
        for i in tqdm(range(epochs)):
            ls = 0
            start = 0
            end = batchsize
            cnt = 0
            while start < X.shape[0]:
                end = min(end, X.shape[0])
                y_pred = self.forward(X[start:end, :, :, :])
                # CrossEntropyLoss wants class indices, so undo the one-hot encoding
                self.loss = self.criterion(y_pred, y[start:end, :].argmax(1))
                ls += self.loss
                self.optimizer.zero_grad()
                self.loss.backward()
                self.optimizer.step()
                start += batchsize
                end += batchsize
                cnt += 1
            ls /= cnt
            print(ls)
            lss.append(ls)
        self.eval()
        return lss



    def predict(self, X, need_categorize=True):
        with torch.no_grad():  # inference only, no need to track gradients
            y = self.forward(X)
        return y.argmax(1)

    def score(self, X, y, need_categorize=True):
        if need_categorize:
            y = self.categorize(y)
        y1 = self.predict(X)
        y1 = y1.numpy().astype('int').reshape((-1, 1))
        y = y.numpy().argmax(1).astype('int').reshape((-1, 1))
        metric = (y == y1).sum() / y.shape[0]
        return metric



    def categorize(self, y):
        # one-hot encode integer labels into an (N, num_classes) array

        def func(x, i):
            i = int(i)
            x[i] = 1
            return x

        y = torch.tensor(y)
        self.dim = int(y.max() + 1)
        if len(y.shape) != 2:
            y = y.reshape((-1, 1))
        return torch.tensor(np.stack([func(np.zeros(self.dim), i[0]) for i in y]))
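As an aside, newer PyTorch releases have a built-in that does what categorize does by hand; a one-line sketch, assuming your PyTorch version (roughly 1.1 and later) ships F.one_hot:

labels = torch.tensor([3, 0, 9])
print(F.one_hot(labels, num_classes=10))  # (3, 10) tensor of 0s and 1s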

net = Net()
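Before training, it is worth checking where the 169 in fc1 comes from: a 3x3 convolution without padding shrinks 28x28 to 26x26, and 2x2 max pooling halves that to 13x13, so each image flattens to 13 * 13 = 169 features. A quick trace with a dummy batch:

with torch.no_grad():
    dummy = torch.zeros(1, 1, 28, 28)
    x = net.conv1(dummy)
    print(x.shape)               # torch.Size([1, 1, 26, 26])
    x = net.maxpool1(x)
    print(x.shape)               # torch.Size([1, 1, 13, 13])
    print(net.flatten(x).shape)  # torch.Size([1, 169])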

lss = net.fit(X_train[:10000, :, :, :], y[:10000], epochs=100, batchsize=16)  # train the network

'''
Training log (abridged; the per-epoch average loss plateaus around 1.60 almost immediately):
  1%|          | 1/100   [00:02<04:14, 2.57s/it] tensor(1.6657, grad_fn=<DivBackward0>)
  2%|▏         | 2/100   [00:05<04:13, 2.58s/it] tensor(1.6453, grad_fn=<DivBackward0>)
  3%|▎         | 3/100   [00:07<04:16, 2.64s/it] tensor(1.6301, grad_fn=<DivBackward0>)
  4%|▍         | 4/100   [00:10<04:21, 2.73s/it] tensor(1.6245, grad_fn=<DivBackward0>)
  5%|▌         | 5/100   [00:14<04:35, 2.90s/it] tensor(1.6239, grad_fn=<DivBackward0>)
...
 98%|█████████▊| 98/100  [06:10<00:06, 3.33s/it] tensor(1.6037, grad_fn=<DivBackward0>)
 99%|█████████▉| 99/100  [06:13<00:03, 3.31s/it] tensor(1.6162, grad_fn=<DivBackward0>)
100%|██████████| 100/100 [06:17<00:00, 3.78s/it] tensor(1.6157, grad_fn=<DivBackward0>)
'''
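Incidentally, the manual start/end bookkeeping in fit is exactly what torch.utils.data.DataLoader automates, with shuffling thrown in. A minimal sketch of the equivalent inner loop, reusing the tensors and the criterion that fit has already set up:

from torch.utils.data import TensorDataset, DataLoader

dataset = TensorDataset(X_train[:10000], net.categorize(y[:10000]).long())
loader = DataLoader(dataset, batch_size=16, shuffle=True)
for xb, yb in loader:
    y_pred = net(xb)
    loss = net.criterion(y_pred, yb.argmax(1))
    # ...zero_grad / backward / step exactly as in fit()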

plt.plot(list(range(len(lss))), [l.item() for l in lss])  # visualize the loss curve

# Because we defined a fully connected layer (self.fc2) whose weights are reused
# five times, the network overfits noticeably.
net.score(X_train[20000:60000 - 1, :, :, :], y[20000:60000 - 1])  # evaluate the network
# 0.8310457761444036
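Note that the t10k split loaded into b earlier is never actually used. As a sketch, it can be prepared and scored exactly like the training tensors:

X_test = torch.tensor(np.stack(b[0])).float()
X_test.unsqueeze_(1)            # (10000, 1, 28, 28)
print(net.score(X_test, b[1]))  # accuracy on the held-out test set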

# Randomly pick a few samples and visualize the predictions
for i in range(9):
    item = np.random.randint(0, 60000 - 1)
    plt.subplot(3, 3, i + 1)
    title = f"label: {str(y[item])} pred: {net.predict(X_train[[item], :, :, :])[0].item()}"
    plt.title(title)
    plt.imshow(X_train[item, :, :, :].squeeze(0), cmap='gray')
plt.show()

# Visualize the weights of self.fc1: each of the 10 rows of fc1.weight
# reshapes back onto the 13x13 post-pooling grid
for i in range(10):
    plt.subplot(3, 4, i + 1)
    plt.title(f'{i}')
    plt.imshow(list(net.fc1.parameters())[0][i, :].view((13, 13)).detach().numpy(), cmap='gray')
plt.show()
