After learning PyTorch last time, I set it aside for quite a while and had almost forgotten everything. This time, after working through 《Dive into DL PyTorch》 again, I tried the Digit Recognizer competition on Kaggle.
https://tangshusen.me/Dive-into-DL-PyTorch/#/
https://www.kaggle.com/kanncaa1/pytorch-tutorial-for-deep-learning-lovers
https://blog.csdn.net/oliver233/article/details/83274285
Code link: https://www.kaggle.com/yannnnnnnnnnnn/kernel5d66c76231?scriptVersionId=28190914 version2
The experimental results are shown below. They seem decent for now; I will keep tuning later.
The default Kaggle Kernel environment
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load in
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
# Input data files are available in the "../input/" directory.
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory
import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))
# Any results you write to the current directory are saved as output.
Output
/kaggle/input/digit-recognizer/test.csv
/kaggle/input/digit-recognizer/train.csv
/kaggle/input/digit-recognizer/sample_submission.csv
Reading the data
digit_recon_tran_csv = pd.read_csv('/kaggle/input/digit-recognizer/train.csv',dtype = np.float32)
digit_recon_test_csv = pd.read_csv('/kaggle/input/digit-recognizer/test.csv',dtype = np.float32)
print('train dataset size: ',digit_recon_tran_csv.size,'\n')
print('test dataset size: ',digit_recon_test_csv.size,'\n')
Output
train dataset size: 32970000
test dataset size: 21952000
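The `.size` numbers above are total element counts (rows × columns), so they can be sanity-checked by hand: the train csv has 785 columns (784 pixel columns plus the label), while the test csv has only the 784 pixel columns. A quick check:

```python
# pandas DataFrame.size = rows * columns (total number of cells)
train_rows, train_cols = 42000, 785   # 28*28 pixel columns + 1 label column
test_rows, test_cols = 28000, 784     # pixel columns only, no label

print(train_rows * train_cols)  # 32970000, matches the printed train size
print(test_rows * test_cols)    # 21952000, matches the printed test size
```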
Converting the pandas data to NumPy
tran_label = digit_recon_tran_csv.label.values
tran_image = digit_recon_tran_csv.loc[:,digit_recon_tran_csv.columns != "label"].values/255 # normalization
test_image = digit_recon_test_csv.values/255
print('train label size: ',tran_label.shape)
print('train image size: ',tran_image.shape)
print('test image size: ',test_image.shape)
Output
train label size: (42000,)
train image size: (42000, 784)
test image size: (28000, 784)
Splitting train into train and valid sets with sklearn
from sklearn.model_selection import train_test_split
train_image, valid_image, train_label, valid_label = train_test_split(tran_image,
                                                                      tran_label,
                                                                      test_size = 0.2,
                                                                      random_state = 42)
print('train size: ',train_image.shape)
print('valid size: ',valid_image.shape)
Output
train size: (33600, 784)
valid size: (8400, 784)
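One optional refinement here: `train_test_split` shuffles purely at random, so the class balance in valid can drift slightly from train. Passing `stratify` keeps the per-digit proportions identical in both splits. A minimal sketch with toy arrays standing in for `tran_image`/`tran_label`:

```python
import numpy as np
from sklearn.model_selection import train_test_split

# toy stand-ins: 100 flat "images", 10 samples per digit class
images = np.random.rand(100, 784).astype(np.float32)
labels = np.repeat(np.arange(10), 10)

tr_x, va_x, tr_y, va_y = train_test_split(
    images, labels,
    test_size=0.2,
    random_state=42,
    stratify=labels)  # preserve per-class proportions in both splits

# each class contributes the same share of samples to valid
print([list(va_y).count(c) for c in range(10)])
```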
Visualizing one of the training images
# visual
import matplotlib.pyplot as plt
plt.imshow(train_image[10].reshape(28,28))
plt.axis("off")
plt.title(str(train_label[10]))
plt.show()
Building the data loaders with PyTorch
import torch
import torch.nn as nn
import numpy as np
train_image = torch.from_numpy(train_image)
train_label = torch.from_numpy(train_label).type(torch.LongTensor) # data type is long
valid_image = torch.from_numpy(valid_image)
valid_label = torch.from_numpy(valid_label).type(torch.LongTensor) # data type is long
# form dataset
train_dataset = torch.utils.data.TensorDataset(train_image,train_label)
valid_dataset = torch.utils.data.TensorDataset(valid_image,valid_label)
# form loader
batch_size = 64 # 2^6=64
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size = batch_size, shuffle = True)
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size = batch_size, shuffle = True)
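Note that the images stay flattened inside the `TensorDataset`, so each batch comes out of the loader as `(batch, 784)` and has to be reshaped to `(batch, 1, 28, 28)` before entering the conv layers — that is what the `.view(...)` calls in the training loop below do. A small sketch with toy tensors:

```python
import torch

# toy stand-ins with the same layout as the real tensors
x = torch.rand(100, 784)
y = torch.randint(0, 10, (100,))

ds = torch.utils.data.TensorDataset(x, y)
dl = torch.utils.data.DataLoader(ds, batch_size=64, shuffle=True)

# the loader yields flat (batch, 784) tensors
xb, yb = next(iter(dl))
print(xb.shape)                       # torch.Size([64, 784])
print(xb.view(-1, 1, 28, 28).shape)  # torch.Size([64, 1, 28, 28])
```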
Building the model with PyTorch; this is one I designed casually myself, mainly referencing AlexNet
import torchvision
from torchvision import transforms
from torchvision import models
class YANNet(nn.Module):
    def __init__(self):
        super(YANNet,self).__init__()
        self.conv = nn.Sequential(
            # size: 28*28
            nn.Conv2d(1,8,3,1,1), # in_channels out_channels kernel_size stride padding
            nn.ReLU(),
            nn.Conv2d(8,16,3,1,1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            # size: 14*14
            nn.Conv2d(16,16,3,1,1),
            nn.ReLU(),
            nn.Conv2d(16,8,3,1,1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.fc = nn.Sequential(
            # size: 7*7
            nn.Linear(8*7*7,256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256,256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256,10)
        )

    def forward(self, img):
        x = self.conv(img)
        o = self.fc(x.view(x.shape[0],-1))
        return o
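The `8*7*7` input size of the first `Linear` layer follows from the standard conv/pool output formula, floor((n + 2p − k)/s) + 1: every conv here uses kernel 3, stride 1, padding 1, so it preserves the spatial size, and each `MaxPool2d(2)` halves it. Tracing the size through the stack:

```python
def out_size(n, k, s=1, p=0):
    """Spatial output size of a conv/pool layer: floor((n + 2p - k) / s) + 1."""
    return (n + 2 * p - k) // s + 1

n = 28
n = out_size(n, k=3, s=1, p=1)  # conv 3x3, pad 1 -> 28
n = out_size(n, k=3, s=1, p=1)  # conv 3x3, pad 1 -> 28
n = out_size(n, k=2, s=2)       # max pool 2x2   -> 14
n = out_size(n, k=3, s=1, p=1)  # conv 3x3, pad 1 -> 14
n = out_size(n, k=3, s=1, p=1)  # conv 3x3, pad 1 -> 14
n = out_size(n, k=2, s=2)       # max pool 2x2   -> 7
print(n, 8 * n * n)  # 7 392, matching nn.Linear(8*7*7, 256)
```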
Instantiating the model and starting training
model = YANNet()
error = nn.CrossEntropyLoss()
optim = torch.optim.SGD(model.parameters(),lr=0.1)
num_epoc = 7
for epoch in range(num_epoc):
    epoc_train_loss = 0.0
    epoc_train_corr = 0.0
    epoc_valid_corr = 0.0
    print('Epoch:{}/{}'.format(epoch,num_epoc))
    model.train()  # enable dropout during training
    for data in train_loader:
        images,labels = data
        images = images.view(images.size(0),1,28,28)  # size(0) also handles a short final batch
        outputs = model(images)
        optim.zero_grad()
        loss = error(outputs,labels)
        loss.backward()
        optim.step()
        epoc_train_loss += loss.item()
        outputs = torch.max(outputs.data,1)[1]
        epoc_train_corr += torch.sum(outputs==labels.data)
    model.eval()  # disable dropout when evaluating
    with torch.no_grad():
        for data in valid_loader:
            images,labels = data
            images = images.view(images.size(0),1,28,28)
            outputs = model(images)
            outputs = torch.max(outputs.data,1)[1]
            epoc_valid_corr += torch.sum(outputs==labels.data)
    print("loss is :{:.4f},Train Accuracy is:{:.4f}%,Test Accuracy is:{:.4f}".format(epoc_train_loss/len(train_dataset),100*epoc_train_corr/len(train_dataset),100*epoc_valid_corr/len(valid_dataset)))
Output
Epoch:0/7
loss is :0.0322,Train Accuracy is:22.7262%,Test Accuracy is:73.0119
Epoch:1/7
loss is :0.0047,Train Accuracy is:90.8244%,Test Accuracy is:94.4167
Epoch:2/7
loss is :0.0024,Train Accuracy is:95.4881%,Test Accuracy is:96.2143
Epoch:3/7
loss is :0.0019,Train Accuracy is:96.4226%,Test Accuracy is:96.6667
Epoch:4/7
loss is :0.0016,Train Accuracy is:97.0804%,Test Accuracy is:96.3095
Epoch:5/7
loss is :0.0013,Train Accuracy is:97.5833%,Test Accuracy is:97.1310
Epoch:6/7
loss is :0.0012,Train Accuracy is:97.8155%,Test Accuracy is:97.5119
Predicting on the test data and saving a csv file
model.eval()  # important: turn off dropout before inference
test_results = np.zeros((test_image.shape[0],2),dtype='int32')
with torch.no_grad():  # no gradients needed for prediction
    for i in range(test_image.shape[0]):
        one_image = torch.from_numpy(test_image[i]).view(1,1,28,28)
        one_output = model(one_image)
        test_results[i,0] = i+1
        test_results[i,1] = torch.max(one_output.data,1)[1].numpy()
Data = {'ImageId': test_results[:, 0], 'Label': test_results[:, 1]}
DataFrame = pd.DataFrame(Data)
DataFrame.to_csv('submission.csv', index=False, sep=',')
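Predicting one image at a time works, but pushing the test set through a `DataLoader` in batches is much faster. A sketch of the batched version, with a tiny throwaway classifier standing in for the trained YANNet (only the shapes are taken from the original):

```python
import numpy as np
import torch
import torch.nn as nn

# toy stand-ins: a trivial classifier and fake flat test images
model = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))
test_image = np.random.rand(100, 784).astype(np.float32)

loader = torch.utils.data.DataLoader(torch.from_numpy(test_image), batch_size=64)

model.eval()           # disable dropout for inference
preds = []
with torch.no_grad():  # skip gradient tracking
    for batch in loader:
        out = model(batch.view(-1, 1, 28, 28))
        preds.append(out.argmax(dim=1))
preds = torch.cat(preds).numpy()
print(preds.shape)  # (100,) -> one predicted label per test image
```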
That is the entire pipeline. That said, the accuracy is still not particularly high, only around 94%, possibly for the following reasons: