基于pytorch的自定义网络手写数字识别

基于pytorch的手写数字识别

  • 手写数字识别
    • 数据获取
    • 初始化w,b取值,并对取值进行kaiming初始化
    • 定义网络结构
    • 学习率以及损失函数定义,随机梯度下降求最优解
    • 模型训练以及验证
    • 交叉熵损失函数说明

手写数字识别

本文基于MNIST数据集,使用torch基于自定义网络结构实现对数据集的识别,并且对交叉熵损失函数进行详细说明。

数据获取

import torch
from torch import nn
from torch import utils
from torch.nn import functional as F
from torch import optim
import torchvision
from visdom import Visdom
batch_size = 512

train_loader = utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data', train=True, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.tra

你可能感兴趣的:(机器学习,深度学习,pytorch,python,手写数字识别)