1. 导入包
import os
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.utils.data as Data
import torchvision
import torch.nn.functional as F
import numpy as np
from torchvision import datasets, transforms
# ---- Hyper-parameters -------------------------------------------------------
learning_rate = 1e-4     # Adam step size
keep_prob_rate = 0.7     # probability of KEEPING a hidden unit in the dropout layer
max_epoch = 3            # passes over the training set
BATCH_SIZE = 50          # mini-batch size for both data loaders

# Fetch MNIST only when the local 'MNIST' directory is absent or empty.
DOWNLOAD_MNIST = not os.path.exists('MNIST') or not os.listdir('MNIST')
2. 导入数据
# Training set; ToTensor converts each image to a float tensor scaled to [0, 1].
train_data = torchvision.datasets.MNIST(
    root='./',
    train=True,
    download=DOWNLOAD_MNIST,
    transform=transforms.ToTensor(),
)
train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)

test_data = torchvision.datasets.MNIST(
    root='./',
    train=False,
    transform=transforms.ToTensor(),
)
# Evaluation order does not matter, so the test loader is not shuffled.
test_loader = Data.DataLoader(dataset=test_data, batch_size=BATCH_SIZE, shuffle=False)

# Fixed 500-image evaluation sample, built directly from the raw uint8 images:
# slice first (instead of converting all 10,000 images), add the channel axis,
# and scale to [0, 1] to match what ToTensor produces for the loaders.
# `.data` / `.targets` replace the deprecated `.test_data` / `.test_labels`;
# `Variable(..., volatile=True)` is a no-op in current PyTorch, and inference
# code should use `torch.no_grad()` instead.
test_x = test_data.data[:500].unsqueeze(1).float() / 255.
test_y = test_data.targets[:500].numpy()
print(test_x.shape)
print(test_y.shape)
# torch.Size([500, 1, 28, 28])
# (500,)
3. CNN模型
class CNN(nn.Module):
    """Two-stage convolutional classifier for 28x28 single-channel digit images.

    Layer stack (shapes for an (N, 1, 28, 28) input):
        conv1: Conv2d(1 -> 32, 7x7, stride 1, "same" padding) + ReLU + MaxPool(2) -> (N, 32, 14, 14)
        conv2: Conv2d(32 -> 64, 5x5, stride 1, "same" padding) + ReLU + MaxPool(2) -> (N, 64, 7, 7)
        out1:  Linear(7*7*64 -> 1024) + ReLU + Dropout
        out2:  Linear(1024 -> 10)

    ``forward`` returns raw, unnormalised logits. ``nn.CrossEntropyLoss``
    applies log-softmax itself, and ``argmax`` over logits equals ``argmax``
    over probabilities. Apply ``F.softmax(..., dim=1)`` only when actual
    probabilities are needed.
    """

    def __init__(self, keep_prob=None):
        """Build the layers.

        Args:
            keep_prob: probability of *keeping* a hidden unit in the dropout
                layer. Defaults to the module-level ``keep_prob_rate``.
        """
        super(CNN, self).__init__()
        if keep_prob is None:
            keep_prob = keep_prob_rate
        # For an odd kernel k at stride 1, "same" padding is (k - 1) // 2, so each
        # convolution preserves the spatial size and each max-pool halves it.
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels=1,
                out_channels=32,
                kernel_size=7,
                stride=1,
                padding=3,      # "same": 28x28 -> 28x28
            ),
            nn.ReLU(),
            nn.MaxPool2d(2),    # 28x28 -> 14x14
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(
                in_channels=32,
                out_channels=64,
                kernel_size=5,
                stride=1,
                padding=2,      # "same": 14x14 -> 14x14
            ),
            nn.ReLU(),
            nn.MaxPool2d(2),    # 14x14 -> 7x7
        )
        self.out1 = nn.Linear(7 * 7 * 64, 1024, bias=True)  # fully connected layer one
        # nn.Dropout's argument is the probability of DROPPING a unit, so the
        # keep probability has to be converted.
        self.dropout = nn.Dropout(1.0 - keep_prob)
        self.out2 = nn.Linear(1024, 10, bias=True)

    def forward(self, x):
        """Map an image batch to class logits.

        Args:
            x: image batch; flattening to 7*7*64 features assumes (N, 1, 28, 28).

        Returns:
            Tensor of shape (N, 10) with unnormalised class scores (logits).
        """
        x = self.conv1(x)
        x = self.conv2(x)
        # Keep the batch axis explicit so a wrongly sized input fails loudly in
        # the Linear layer instead of being silently regrouped into other samples.
        x = x.view(x.size(0), -1)
        out1 = F.relu(self.out1(x))
        out1 = self.dropout(out1)
        return self.out2(out1)
4. 训练与测试
def test(cnn):
    """Return the accuracy of ``cnn`` on the module-level ``test_x`` / ``test_y`` sample.

    The model runs in evaluation mode, so dropout is disabled, and no autograd
    graph is built. Its previous train/eval mode is restored afterwards, so the
    function can be called in the middle of training. The predicted labels are
    also published in the module-level ``prediction``.

    Args:
        cnn: a model mapping ``test_x`` to per-class scores of shape (N, 10).

    Returns:
        Fraction of correct predictions over the ``len(test_y)`` samples.
    """
    global prediction
    was_training = cnn.training
    cnn.eval()
    try:
        with torch.no_grad():
            y_pre = cnn(test_x)
    finally:
        cnn.train(was_training)
    _, pre_index = torch.max(y_pre, 1)
    prediction = pre_index.view(-1).numpy()
    correct = np.sum(prediction == test_y)
    # Divide by the actual sample size instead of a hard-coded 500.
    return correct / float(len(test_y))
def train(cnn):
    """Train ``cnn`` with Adam and cross-entropy for ``max_epoch`` epochs.

    Batches come from the module-level ``train_loader``. Every 20 steps within
    an epoch (excluding step 0) the current accuracy from ``test(cnn)`` is printed.

    Args:
        cnn: the model to optimise in place.
    """
    optimizer = torch.optim.Adam(cnn.parameters(), lr=learning_rate)
    loss_func = nn.CrossEntropyLoss()
    # Make sure dropout is active while optimising, even if the caller passed
    # in a model that was left in eval mode.
    cnn.train()
    for _epoch in range(max_epoch):
        # Loader batches are already tensors; wrapping them in Variable is a
        # deprecated no-op in current PyTorch.
        for step, (images, labels) in enumerate(train_loader):
            output = cnn(images)
            loss = loss_func(output, labels)  # scalar
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if step != 0 and step % 20 == 0:
                print("=" * 10, step, "=" * 5, "=" * 5, "test accuracy is ", test(cnn), "=" * 10)
4.1 训练
# Build a fresh model and train it; train() prints the test accuracy every 20 steps.
cnn = CNN()
train(cnn)
========== 20 ===== ===== test accuracy is 0.25 ==========
========== 40 ===== ===== test accuracy is 0.458 ==========
========== 60 ===== ===== test accuracy is 0.572 ==========
========== 80 ===== ===== test accuracy is 0.624 ==========
========== 100 ===== ===== test accuracy is 0.638 ==========
========== 120 ===== ===== test accuracy is 0.718 ==========
========== 140 ===== ===== test accuracy is 0.744 ==========
========== 160 ===== ===== test accuracy is 0.796 ==========
========== 180 ===== ===== test accuracy is 0.806 ==========
========== 200 ===== ===== test accuracy is 0.808 ==========
========== 220 ===== ===== test accuracy is 0.844 ==========
========== 240 ===== ===== test accuracy is 0.84 ==========
========== 260 ===== ===== test accuracy is 0.864 ==========
========== 280 ===== ===== test accuracy is 0.86 ==========
========== 300 ===== ===== test accuracy is 0.878 ==========
========== 320 ===== ===== test accuracy is 0.868 ==========
========== 340 ===== ===== test accuracy is 0.876 ==========
========== 360 ===== ===== test accuracy is 0.862 ==========
========== 380 ===== ===== test accuracy is 0.86 ==========
========== 400 ===== ===== test accuracy is 0.892 ==========
========== 420 ===== ===== test accuracy is 0.87 ==========
========== 440 ===== ===== test accuracy is 0.882 ==========
========== 460 ===== ===== test accuracy is 0.898 ==========
========== 480 ===== ===== test accuracy is 0.892 ==========
========== 500 ===== ===== test accuracy is 0.884 ==========
========== 520 ===== ===== test accuracy is 0.892 ==========
========== 540 ===== ===== test accuracy is 0.892 ==========
========== 560 ===== ===== test accuracy is 0.902 ==========
========== 580 ===== ===== test accuracy is 0.902 ==========
========== 600 ===== ===== test accuracy is 0.904 ==========
========== 620 ===== ===== test accuracy is 0.902 ==========
========== 640 ===== ===== test accuracy is 0.904 ==========
========== 660 ===== ===== test accuracy is 0.906 ==========
========== 680 ===== ===== test accuracy is 0.908 ==========
========== 700 ===== ===== test accuracy is 0.922 ==========
========== 720 ===== ===== test accuracy is 0.916 ==========
========== 740 ===== ===== test accuracy is 0.918 ==========
========== 760 ===== ===== test accuracy is 0.906 ==========
========== 780 ===== ===== test accuracy is 0.924 ==========
========== 800 ===== ===== test accuracy is 0.928 ==========
========== 820 ===== ===== test accuracy is 0.918 ==========
========== 840 ===== ===== test accuracy is 0.922 ==========
========== 860 ===== ===== test accuracy is 0.918 ==========
========== 880 ===== ===== test accuracy is 0.93 ==========
========== 900 ===== ===== test accuracy is 0.924 ==========
========== 920 ===== ===== test accuracy is 0.932 ==========
========== 940 ===== ===== test accuracy is 0.934 ==========
========== 960 ===== ===== test accuracy is 0.926 ==========
========== 980 ===== ===== test accuracy is 0.932 ==========
========== 1000 ===== ===== test accuracy is 0.934 ==========
========== 1020 ===== ===== test accuracy is 0.926 ==========
========== 1040 ===== ===== test accuracy is 0.924 ==========
========== 1060 ===== ===== test accuracy is 0.934 ==========
========== 1080 ===== ===== test accuracy is 0.932 ==========
========== 1100 ===== ===== test accuracy is 0.934 ==========
========== 1120 ===== ===== test accuracy is 0.936 ==========
========== 1140 ===== ===== test accuracy is 0.936 ==========
========== 1160 ===== ===== test accuracy is 0.93 ==========
========== 1180 ===== ===== test accuracy is 0.932 ==========
========== 20 ===== ===== test accuracy is 0.934 ==========
========== 40 ===== ===== test accuracy is 0.946 ==========
========== 60 ===== ===== test accuracy is 0.94 ==========
========== 80 ===== ===== test accuracy is 0.946 ==========
========== 100 ===== ===== test accuracy is 0.946 ==========
========== 120 ===== ===== test accuracy is 0.944 ==========
========== 140 ===== ===== test accuracy is 0.946 ==========
========== 160 ===== ===== test accuracy is 0.956 ==========
========== 180 ===== ===== test accuracy is 0.936 ==========
========== 200 ===== ===== test accuracy is 0.95 ==========
========== 220 ===== ===== test accuracy is 0.956 ==========
========== 240 ===== ===== test accuracy is 0.946 ==========
========== 260 ===== ===== test accuracy is 0.944 ==========
========== 280 ===== ===== test accuracy is 0.944 ==========
========== 300 ===== ===== test accuracy is 0.954 ==========
========== 320 ===== ===== test accuracy is 0.964 ==========
========== 340 ===== ===== test accuracy is 0.95 ==========
========== 360 ===== ===== test accuracy is 0.962 ==========
========== 380 ===== ===== test accuracy is 0.948 ==========
========== 400 ===== ===== test accuracy is 0.96 ==========
========== 420 ===== ===== test accuracy is 0.946 ==========
========== 440 ===== ===== test accuracy is 0.96 ==========
========== 460 ===== ===== test accuracy is 0.948 ==========
========== 480 ===== ===== test accuracy is 0.95 ==========
========== 500 ===== ===== test accuracy is 0.958 ==========
========== 520 ===== ===== test accuracy is 0.954 ==========
========== 540 ===== ===== test accuracy is 0.948 ==========
========== 560 ===== ===== test accuracy is 0.958 ==========
========== 580 ===== ===== test accuracy is 0.948 ==========
========== 600 ===== ===== test accuracy is 0.96 ==========
========== 620 ===== ===== test accuracy is 0.96 ==========
========== 640 ===== ===== test accuracy is 0.96 ==========
========== 660 ===== ===== test accuracy is 0.95 ==========
========== 680 ===== ===== test accuracy is 0.962 ==========
========== 700 ===== ===== test accuracy is 0.964 ==========
========== 720 ===== ===== test accuracy is 0.962 ==========
========== 740 ===== ===== test accuracy is 0.96 ==========
========== 760 ===== ===== test accuracy is 0.954 ==========
========== 780 ===== ===== test accuracy is 0.956 ==========
========== 800 ===== ===== test accuracy is 0.962 ==========
========== 820 ===== ===== test accuracy is 0.962 ==========
========== 840 ===== ===== test accuracy is 0.968 ==========
========== 860 ===== ===== test accuracy is 0.962 ==========
========== 880 ===== ===== test accuracy is 0.972 ==========
========== 900 ===== ===== test accuracy is 0.96 ==========
========== 920 ===== ===== test accuracy is 0.958 ==========
========== 940 ===== ===== test accuracy is 0.966 ==========
========== 960 ===== ===== test accuracy is 0.972 ==========
========== 980 ===== ===== test accuracy is 0.964 ==========
========== 1000 ===== ===== test accuracy is 0.968 ==========
========== 1020 ===== ===== test accuracy is 0.968 ==========
========== 1040 ===== ===== test accuracy is 0.956 ==========
========== 1060 ===== ===== test accuracy is 0.96 ==========
========== 1080 ===== ===== test accuracy is 0.97 ==========
========== 1100 ===== ===== test accuracy is 0.968 ==========
========== 1120 ===== ===== test accuracy is 0.964 ==========
========== 1140 ===== ===== test accuracy is 0.97 ==========
========== 1160 ===== ===== test accuracy is 0.97 ==========
========== 1180 ===== ===== test accuracy is 0.96 ==========
========== 20 ===== ===== test accuracy is 0.96 ==========
========== 40 ===== ===== test accuracy is 0.962 ==========
========== 60 ===== ===== test accuracy is 0.97 ==========
========== 80 ===== ===== test accuracy is 0.958 ==========
========== 100 ===== ===== test accuracy is 0.966 ==========
========== 120 ===== ===== test accuracy is 0.962 ==========
========== 140 ===== ===== test accuracy is 0.968 ==========
========== 160 ===== ===== test accuracy is 0.972 ==========
========== 180 ===== ===== test accuracy is 0.972 ==========
========== 200 ===== ===== test accuracy is 0.978 ==========
========== 220 ===== ===== test accuracy is 0.968 ==========
========== 240 ===== ===== test accuracy is 0.956 ==========
========== 260 ===== ===== test accuracy is 0.97 ==========
========== 280 ===== ===== test accuracy is 0.964 ==========
========== 300 ===== ===== test accuracy is 0.97 ==========
========== 320 ===== ===== test accuracy is 0.972 ==========
========== 340 ===== ===== test accuracy is 0.976 ==========
========== 360 ===== ===== test accuracy is 0.968 ==========
========== 380 ===== ===== test accuracy is 0.97 ==========
========== 400 ===== ===== test accuracy is 0.974 ==========
========== 420 ===== ===== test accuracy is 0.974 ==========
========== 440 ===== ===== test accuracy is 0.968 ==========
========== 460 ===== ===== test accuracy is 0.976 ==========
========== 480 ===== ===== test accuracy is 0.97 ==========
========== 500 ===== ===== test accuracy is 0.96 ==========
========== 520 ===== ===== test accuracy is 0.966 ==========
========== 540 ===== ===== test accuracy is 0.974 ==========
========== 560 ===== ===== test accuracy is 0.974 ==========
========== 580 ===== ===== test accuracy is 0.972 ==========
========== 600 ===== ===== test accuracy is 0.974 ==========
========== 620 ===== ===== test accuracy is 0.97 ==========
========== 640 ===== ===== test accuracy is 0.974 ==========
========== 660 ===== ===== test accuracy is 0.976 ==========
========== 680 ===== ===== test accuracy is 0.97 ==========
========== 700 ===== ===== test accuracy is 0.974 ==========
========== 720 ===== ===== test accuracy is 0.962 ==========
========== 740 ===== ===== test accuracy is 0.98 ==========
========== 760 ===== ===== test accuracy is 0.976 ==========
========== 780 ===== ===== test accuracy is 0.966 ==========
========== 800 ===== ===== test accuracy is 0.968 ==========
========== 820 ===== ===== test accuracy is 0.974 ==========
========== 840 ===== ===== test accuracy is 0.964 ==========
========== 860 ===== ===== test accuracy is 0.974 ==========
========== 880 ===== ===== test accuracy is 0.974 ==========
========== 900 ===== ===== test accuracy is 0.982 ==========
========== 920 ===== ===== test accuracy is 0.972 ==========
========== 940 ===== ===== test accuracy is 0.974 ==========
========== 960 ===== ===== test accuracy is 0.976 ==========
========== 980 ===== ===== test accuracy is 0.976 ==========
========== 1000 ===== ===== test accuracy is 0.984 ==========
========== 1020 ===== ===== test accuracy is 0.976 ==========
========== 1040 ===== ===== test accuracy is 0.976 ==========
========== 1060 ===== ===== test accuracy is 0.982 ==========
========== 1080 ===== ===== test accuracy is 0.974 ==========
========== 1100 ===== ===== test accuracy is 0.976 ==========
========== 1120 ===== ===== test accuracy is 0.974 ==========
========== 1140 ===== ===== test accuracy is 0.98 ==========
========== 1160 ===== ===== test accuracy is 0.98 ==========
========== 1180 ===== ===== test accuracy is 0.978 ==========
4.2 测试
def predict(test_x, idx):
    """Run the module-level ``cnn`` on ``test_x`` and print its predictions.

    Prints the output shape, the predicted label for every image in
    ``test_x``, and the true label ``test_y[idx]`` so it can be compared with
    ``prediction[idx]``. Inference runs in eval mode (dropout off) without
    gradient tracking, and the model's previous mode is restored afterwards.

    Args:
        test_x: image batch to classify.
        idx: index of the sample whose true label is printed.
    """
    was_training = cnn.training
    cnn.eval()
    try:
        with torch.no_grad():
            y_pre = cnn(test_x)
    finally:
        cnn.train(was_training)
    print(y_pre.shape)
    _, pre_index = torch.max(y_pre, 1)
    prediction = pre_index.numpy()
    print(prediction)
    # test_y was created with `.numpy()`, so indexing gives a NumPy scalar. Its
    # `.data` is a raw memoryview with no `.numpy()`, which made the old call
    # crash; int() prints the plain label value.
    print("img: ", int(test_y[idx]))
import matplotlib.pylab as plt
def showTorchImage(image, cmap="gray"):
    """Display one image tensor with matplotlib.

    Args:
        image: tensor accepted by ``transforms.ToPILImage``, e.g. a
            (1, 28, 28) float slice such as ``test_x[idx]``.
        cmap: colormap for single-channel images. Defaults to "gray" so
            digits are drawn in grayscale instead of matplotlib's default
            colour map. Matplotlib ignores it for RGB(A) images.
    """
    pil_image = transforms.ToPILImage()(image)
    plt.imshow(pil_image, cmap=cmap)
    plt.show()
# Show one held-out digit, then print the model's predictions next to its true label.
idx = 18
showTorchImage(test_x[idx])  # test_x[idx] is the full (1, 28, 28) slice
predict(test_x, idx)
torch.Size([50, 10])
[7 8 8 3 1 5 1 6 9 4 3 5 8 1 7 1 6 9 2 7 6 2 3 9 5 1 4 7 0 0 5 0 9 2 9 2 6
3 6 1 9 2 5 7 2 0 5 6 2 6]
img: 2