使用pytorch建立LSTM神经网络训练识别手写数字

import torch

from torch import nn
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt


# torch.manual_seed(1)    # reproducible
# 设置基本参数
# Hyper Parameters
EPOCH = 1               # train the training data n times, to save time, we just train 1 epoch
BATCH_SIZE 

你可能感兴趣的:(机器学习(深度学习))