PyTorch MNIST Tutorial

https://pytorch-cn.readthedocs.io/zh/latest/package_references/torch/#blas-and-lapack-operations

https://github.com/pytorch/examples/blob/master/mnist/main.py

Two resource links: the first is the Chinese documentation, the second is the official example on GitHub; both are quite good.

1. Import the libraries

from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

2. Define the training arguments

parser = argparse.ArgumentParser()
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--test_batch_size', type=int, default=1000)
parser.add_argument('--epochs', type=int, default=10)
parser.add_argument('--lr', type=float, default=0.01)
parser.add_argument('--momentum', type=float, default=0.5)
parser.add_argument('--no_cuda', action='store_true', default=False)
parser.add_argument('--seed', type=int, default=1)
parser.add_argument('--log_interval', type=int, default=10)

args = parser.parse_args()

This defines batch_size, test_batch_size, epochs, lr, momentum, no_cuda, seed, and log_interval.
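
Any of these defaults can be overridden on the command line. Assuming the script is saved as main.py (the filename is my assumption, not given in the tutorial), a run might look like:

    python main.py --batch_size 128 --lr 0.05 --epochs 5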


use_cuda = not args.no_cuda and torch.cuda.is_available()

torch.manual_seed(args.seed)

device = torch.device('cuda' if use_cuda else 'cpu')

kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

3. Load the data

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=args.batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=args.test_batch_size, shuffle=True, **kwargs)
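
The Normalize values 0.1307 and 0.3081 are the mean and standard deviation of the MNIST training pixels. A quick sanity check (my own addition, not part of the tutorial) shows the batch shapes the network below will receive:

    images, labels = next(iter(train_loader))
    print(images.shape)   # torch.Size([64, 1, 28, 28]) with the default batch_size
    print(labels.shape)   # torch.Size([64])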

4. Define the network

class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

The layers defined here are:

conv1(1, 10, 5): 1 input channel, 10 output channels, kernel size 5.

max_pool2d(2): a pooling kernel of size 2, applied after each convolution, as in:

F.relu(F.max_pool2d(self.conv1(x), 2))

Then comes:

conv2(10, 20, 5), together with conv2_drop = Dropout2d(), which randomly zeroes whole feature maps of conv2's output during training.

Finally there are two fully connected layers, with an output dimension of 10 (one score per digit). The 320 input size of fc1 follows from the shapes: a 28×28 image becomes 24×24 after conv1, 12×12 after the first pooling, 8×8 after conv2, and 4×4 after the second pooling; with 20 channels that gives 20 × 4 × 4 = 320.
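
A short sketch (my own addition) confirms that arithmetic by pushing a dummy image through the network:

    model = Net()
    dummy = torch.randn(1, 1, 28, 28)  # one fake single-channel 28x28 image
    out = model(dummy)
    print(out.shape)                   # torch.Size([1, 10]): one log-probability per digit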

4.1 Define the training loop

def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()

        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

In short: for each batch, move the data and targets to the right device, clear the accumulated gradients with optimizer.zero_grad(), run the forward pass, compute the negative log-likelihood loss, backpropagate with loss.backward(), and update the weights with optimizer.step().
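
The zero_grad() call matters because PyTorch accumulates gradients across backward() calls. A tiny standalone sketch (toy tensors of my own, not part of the tutorial) illustrates this:

    w = torch.ones(1, requires_grad=True)
    (w * 2).sum().backward()
    print(w.grad)     # tensor([2.])
    (w * 2).sum().backward()
    print(w.grad)     # tensor([4.]) -- gradients accumulated, not replaced
    w.grad.zero_()    # what optimizer.zero_grad() does for every parameter
    print(w.grad)     # tensor([0.])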

5. Test

def test(args, model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # sum the per-sample losses so the average below is over the whole set
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.max(1, keepdim=True)[1]  # index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
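
The tutorial stops here, but the model and optimizer still have to be created and the two functions called. The linked GitHub example wires everything together along these lines:

    model = Net().to(device)
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)

    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        test(args, model, device, test_loader)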

 
