Pytorch上手使用
近期学习了另一个深度学习框架库Pytorch,对学习进行一些总结,方便自己回顾。
Pytorch是torch的python版本,是由Facebook开源的神经网络框架。与Tensorflow的静态计算图不同,pytorch的计算图是动态的,可以根据计算需要实时改变计算图。
1 安装
如果已经安装了cuda8,则使用pip来安装pytorch会十分简单。若使用其他版本的cuda,则需要下载官方释放出来对应的安装包。具体安装地址参见官网的首页。也就是先安装cuda8,再用pip安装,
-----------------------------------------------------------------------------------------------------------
目前最新稳定版本为0.4.0。上个版本0.3.0的文档有中文版,见中文文档。
pip install torch torchvision # for python2.7
pip3 install torch torchvision # for python3
1
2
-----------------------------------------------------------------------------------------------
wjc是这样安装的:pip install -i https://pypi.tuna.tsinghua.edu.cn/simple torch torchvision
没用pip3,加清华源,快的一批,
2 概述
理解pytorch的基础主要从以下三个方面
Numpy风格的Tensor操作。pytorch中tensor提供的API参考了Numpy的设计,因此熟悉Numpy的用户基本上可以无缝理解,并创建和操作tensor,同时torch中的数组和Numpy数组对象可以无缝的对接。
变量自动求导。在一序列计算过程形成的计算图中,参与的变量可以方便的计算自己对目标函数的梯度。这样就可以方便的实现神经网络的后向传播过程。
神经网络层与损失函数优化等高层封装。网络层的封装存在于torch.nn模块,损失函数由torch.nn.functional模块提供,优化函数由torch.optim模块提供。
因此下面的内容也主要围绕这三个方面来介绍。第3节介绍张量的操作,第4节介绍自动求导,第5节介绍神经网络层等的封装,第6,7节简单介绍损失函数与优化方法。这三部分相对重要。后续的第8节介绍介绍数据集及torchvision,第9节介绍训练过程可视的工具,第10节通过相对完整的示例代码展示pytorch中如何解决MNIST与CIFAR10的分类。
3 Tensor(张量)
Tensor是神经网络框架中重要的基础数据类型,可以简单理解为N维数组的容器对象。tensor之间的通过运算进行连接,从而形成计算图。
3.1 Tensor类型
Torch 定义了七种 CPU tensor 类型和八种 GPU tensor 类型:
Data type CPU tensor GPU tensor
32-bit floating point torch.FloatTensor torch.cuda.FloatTensor
64-bit floating point torch.DoubleTensor torch.cuda.DoubleTensor
16-bit floating point torch.HalfTensor torch.cuda.HalfTensor
8-bit integer (unsigned) torch.ByteTensor torch.cuda.ByteTensor
8-bit integer (signed) torch.CharTensor torch.cuda.CharTensor
16-bit integer (signed) torch.ShortTensor torch.cuda.ShortTensor
32-bit integer (signed) torch.IntTensor torch.cuda.IntTensor
64-bit integer (signed) torch.LongTensor torch.cuda.LongTensor
通常情况下使用Tensor类的构造函数返回的是FloatTensor类型对象,可通过在对象上调用cuda()返回一个新的cuda.FloatTensor类型的对象。
torch模块内提供了操作tensor的接口,而Tensor类型的对象上也设计了对应了接口。例如torch.add()与tensor.add()等价。需要注意的是这些接口都采用创建一个新对象返回的形式。如果想就地修改一个tensor对象,需要使用加后缀下划线的方法。例如x.add_(y),将修改x。Tensor类的构建函数支持从列表或ndarray等类型进行构建。默认tensor为FloatTensor。
下面的几节简单的描述重要的操作tensor的方法。
3.1 tensor的常见创建接口
方法名 说明
Tensor() 直接从参数构造一个的张量,参数支持list,numpy数组
eye(row, column) 创建指定行数,列数的二维单位tensor
linspace(start,end,count) 在区间[s,e]上创建c个tensor
logspace(s,e,c) 在区间[10^s, 10^e]上创建c个tensor
ones(*size) 返回指定shape的张量,元素初始为1
zeros(*size) 返回指定shape的张量,元素初始为0
ones_like(t) 返回与t的shape相同的张量,且元素初始为1
zeros_like(t) 返回与t的shape相同的张量,且元素初始为0
arange(s,e,sep) 在区间[s,e)上以间隔sep生成一个序列张量
3.2 随机采样
方法名 说明
rand(*size) 在区间[0,1)返回一个均匀分布的随机数张量
uniform(s,e) 在指定区间[s,e]上生成一个均匀分布的张量
randn(*size) 返回正态分布N(0,1)取样的随机数张量
normal(means, std) 返回一个正态分布N(means, std)
3.3 序列化
方法名 说明
save(obj, path) 张量对象的保存,通过pickle进行
load(path) 从文件中反序列化一个张量对象
3.4 数学操作
这些方法均为逐元素处理方法
方法名 说明
abs 绝对值
add 加法
addcdiv(t, v, t1, t2) t1与t2的按元素除后,乘v加t
addcmul(t, v, t1, t2) t1与t2的按元素乘后,乘v加t
ceil 向上取整
floor 向下取整
clamp(t, min, max) 将张量元素限制在指定区间
exp 指数
log 对数
pow 幂
mul 逐元素乘法
neg 取反
sigmoid
sign 取符号
sqrt 开根号
tanh
注:这些操作均创建新的tensor,如果需要就地操作,可以使用这些方法的下划线版本,例如abs_。
3.5 归约方法
方法名 说明
cumprod(t, axis) 在指定维度对t进行累积
cumsum 在指定维度对t进行累加
dist(a,b,p=2) 返回a,b之间的p阶范数
mean 均值
median 中位数
std 标准差
var 方差
norm(t,p=2) 返回t的p阶范数
prod(t) 返回t所有元素的积
sum(t) 返回t所有元素的和
3.6 比较方法
方法名 说明
eq 比较tensor是否相等,支持broadcast
equal 比较tensor是否有相同的shape与值
ge/le 大于/小于比较
gt/lt 大于等于/小于等于比较
max/min(t,axis) 返回最值,若指定axis,则额外返回下标
topk(t,k,axis) 在指定的axis维上取最高的K个值
3.7 其他操作
方法名 说明
cat(iterable, axis) 在指定的维度上拼接序列
chunk(tensor, c, axis) 在指定的维度上分割tensor
squeeze(input,dim) 将张量维度为1的dim进行压缩,不指定dim则压缩所有维度为1的维
unsqueeze(dim) squeeze操作的逆操作
transpose(t) 计算矩阵的转置换
cross(a, b, axis) 在指定维度上计算向量积
diag 返回对角线元素
hist(t, bins) 计算直方图
trace 返回迹
3.8 矩阵操作
方法名 说明
dot(t1, t2) 计算张量的内积
mm(t1, t2) 计算矩阵乘法
mv(t1, v1) 计算矩阵与向量乘法
qr(t) 计算t的QR分解
svd(t) 计算t的SVD分解
3.9 tensor对象的方法
方法名 作用
size() 返回张量的shape属性值
numel(input) 计算tensor的元素个数
view(*shape) 修改tensor的shape,与np.reshape类似,view返回的对象共享内存
resize 类似于view,但在size超出时会重新分配内存空间
item 若为单元素tensor,则返回pyton的scalar
from_numpy 从numpy数据填充
numpy 返回ndarray类型
3.10 tensor内部
tensor对象由两部分组成,tensor的信息与存储,storage封装了真正的data,可以由多个tensor共享。大多数操作只是修改tensor的信息,而不修改storage部分。这样达到效率与性能的提升。
3.11 使用pytorch进行线性回归
import torch
import torch.optim as optim
import matplotlib.pyplot as plt
def get_fake_data(batch_size=32):
''' y=x*2+3 '''
x = torch.randn(batch_size, 1) * 20
y = x * 2 + 3 + torch.randn(batch_size, 1)
return x, y
x, y = get_fake_data()
class LinerRegress(torch.nn.Module):
def __init__(self):
super(LinerRegress, self).__init__()
self.fc1 = torch.nn.Linear(1, 1)
def forward(self, x):
return self.fc1(x)
net = LinerRegress()
loss_func = torch.nn.MSELoss()
optimzer = optim.SGD(net.parameters())
for i in range(40000):
optimzer.zero_grad()
out = net(x)
loss = loss_func(out, y)
loss.backward()
optimzer.step()
w, b = [param.item() for param in net.parameters()]
print w, b # 2.01146, 3.184525
# 显示原始点与拟合直线
plt.scatter(x.squeeze().numpy(), y.squeeze().numpy())
plt.plot(x.squeeze().numpy(), (x*w + b).squeeze().numpy())
plt.show()
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
这里写图片描述
从这里的代码可以发现,pytorch需要我们自己实现各轮更新,并手动调用反向传播以及更新参数,此外也没有提供评估及预测功能。相对于Keras这种高层的封装,pytorch需要我们了解更多的低层细节。
4 自动求导
tensor对象通过一系列的运算可以组成动态图,对于每个tensor对象,有下面几个变量控制求导的属性。
变量 作用
requirs_grad 默认为False,表示变量是否需要计算导数
grad_fn 变量的梯度函数
grad 变量对应的梯度
在0.3.0版本中,自动求导还需要借助于Variable类来完成,在0.4.0版本中,Variable已经被废除了,tensor自身即可完成这一过程。
import torch
x = torch.randn((4,4), requires_grad=True)
y = 2*x
z = y.sum()
print z.requires_grad # True
z.backward()
print x.grad
'''
tensor([[ 2., 2., 2., 2.],
[ 2., 2., 2., 2.],
[ 2., 2., 2., 2.],
[ 2., 2., 2., 2.]])
'''
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
5 创建神经网络
5.1 神经网络层
torch.nn模块提供了创建神经网络的基础构件,这些层都继承自Module类。下面我们简单看下如何实现Liner层。
class Liner(torch.nn.Module):
def __init__(self,in_features, out_features, bias=True):
super(Liner, self).__init__()
self.weight = torch.nn.Parameter(torch.randn(out_features, in_features))
if bias:
self.bias = torch.nn.Parameter(torch.randn(out_features))
def forward(self, x):
x = x.mm(self.weight)
if self.bias:
x = x + self.bias.expand_as(x)
return x
1
2
3
4
5
6
7
8
9
10
11
12
下面表格中列出了比较重要的神经网络层组件。对应的在nn.functional模块中,提供这些层对应的函数实现。通常对于可训练参数的层使用module,而对于不需要训练参数的层如softmax这些,可以使用functional中的函数。
Layer对应的类 功能说明
Linear(in_dim, out_dim, bias=True) 提供了进行线性变换操作的功能
Dropout(p) Dropout层,有2D,3D的类型
Conv2d(in_c, out_c, filter_size, stride, padding) 二维卷积层,类似的有Conv1d,Conv3d
ConvTranspose2d()
MaxPool2d(filter_size, stride, padding) 二维最大池化层
MaxUnpool2d(filter, stride, padding) 逆过程
AvgPool2d(filter_size, stride, padding) 二维平均池化层
FractionalMaxPool2d 分数最大池化
AdaptiveMaxPool2d([h,w]) 自适应最大池化
AdaptiveAvgPool2d([h,w]) 自自应平均池化
ZeroPad2d(padding_size) 零填充边界
ConstantPad2d(padding_size,const) 常量填充边界
ReplicationPad2d(ps) 复制填充边界
BatchNorm1d() 对2维或3维小批量数据进行标准化操作
RNN(in_dim, hidden_dim, num_layers, activation, dropout, bidi, bias) 构建RNN层
RNNCell(in_dim, hidden_dim, bias, activation) RNN单元
LSTM(in_dim, hidden_dim, num_layers, activation, dropout, bidi, bias) 构建LSTM层
LSTMCell(in_dim, hidden_dim, bias, activation) LSTM单元
GRU(in_dim, hidden_dim, num_layers, activation, dropout, bidi, bias) 构建GRU层
GRUCell(in_dim, hidden_dim, bias, activation) GRU单元
5.2 非线性激活层
激活层类名 作用
ReLU(inplace=False) Relu激活层
Sigmoid Sigmoid激活层
Tanh Tanh激活层
Softmax Softmax激活层
Softmax2d
LogSoftmax LogSoftmax激活层
5.3 容器类型
容器类型 功能
Module 神经网络模块的基类
Sequential 序列模型,类似于keras,用于构建序列型神经网络
ModuleList 用于存储层,不接受输入
Parameters(t) 模块的属性,用于保存其训练参数
ParameterList 参数列表
下面的代码演示了使用容器型模块的方式。
# 方法1
model1 = nn.Sequential()
model.add_module('fc1', nn.Linear(3,4))
model.add_module('fc2', nn.Linear(4,2))
model.add_module('output', nn.Softmax(2))
# 方法2
model2 = nn.Sequential(
nn.Conv2d(1,20,5),
nn.ReLU(),
nn.Conv2d(20,64,5),
nn.ReLU()
)
# 方法3
model3 = nn.ModuleList([nn.Linear(3,4), nn.ReLU(), nn.Linear(4,2)])
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
5.4 其他层
容器类型 功能
Embedding(vocab_size, feature_dim) 词向量层
Embeddingbag
5.5 模型的保存
前面我们知道tensor可以通过save与load方法实现序列化与反序列化。由tensor组成的网络同样也可以方便的保存。不过通常没有必要完全保存网络模块对象,只需要保存各层的权重数据即可,这些数据保存在模块的state_dict字典中,因此只需要序列化这个词典。
# 模型的保存
torch.save(model.state_dict, 'path')
# 模型的加载
model.load_state_dict('path)
1
2
3
4
5.6 实现LeNet神经网络
torch.nn.Module提供了神经网络的基类,当实现神经网络时需要继承自此模块,并在初始化函数中创建网络需要包含的层,并实现forward函数完成前向计算,网络的反向计算会由自动求导机制处理。
下面的示例代码创建了LeNet的卷积神经网络。通常将需要训练的层写在init函数中,将参数不需要训练的层在forward方法里调用对应的函数来实现相应的层。
import torch.nn as nn
import torch.nn.functional as F
class LeNet(nn.Module):
def __init__(self):
super(LeNet, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = F.max_pool2d(F.relu(self.conv1(x)), 2)
x = F.max_pool2d(F.relu(self.conv2(x)), 2)
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
6 损失函数与优化方法
6.1 损失函数
torch.nn模块中提供了许多损失函数类,这里列出几种相对常见的。
类名 功能
MSELoss 均方差损失
CrossEntropyLoss 交叉熵损失
NLLLoss 负对数似然损失
PoissonNLLLoss 带泊松分布的负对数似然损失
6.2 优化方法
由torch.optim模块提供支持
类名 功能
SGD(params, lr=0.1, momentum=0, dampening=0, weight_decay=0, nesterov=False) 随机梯度下降法
Adam(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False) Adam
RMSprop(params, lr=0.01, alpha=0.99, eps=1e-08, weight_decay=0, momentum=0, centered=False) RMSprop
Adadelta(params, lr=1.0, rho=0.9, eps=1e-06, weight_decay=0) Adadelta
Adagrad(params, lr=0.01, lr_decay=0, weight_decay=0, initial_accumulator_value=0) Adagrad
lr_scheduler.ReduceLROnPlateau(optimizer, mode=’min’, factor=0.1, patience=10, verbose=False, threshold=0.0001, threshold_mode=’rel’, cooldown=0, min_lr=0, eps=1e-08) 学习率的控制
在神经网络的性能调优中,常见的作法是对不对层的网络设置不同的学习率。
class model(nn.Module):
def __init__():
super(model,self).__init__()
self.base = Sequencial()
# code for base sub module
self.classifier = Sequencial()
# code for classifier sub module
optim.SGD([
{'params': model.base.parameters()},
{'params': model.classifier.parameters(), 'lr': 1e-3}
], lr=1e-2, momentum=0.9)
1
2
3
4
5
6
7
8
9
10
11
12
6.3 参数初始化
良好的初始化可以让模型快速收敛,有时甚至可以决定模型是否能训练成功。Pytorch中的参数通常有默认的初始化策略,不需要我们自己指定,但框架仍然留有相应的接口供我们来调整初始化方法。
初始化方法 说明
xavier_uniform_
xavier_normal_
kaiming_uniform_
from torch.nn import init
# net的类定义
...
# 初始化各层权重
for name, params in net.named_parameters():
init.xavier_normal(param[0])
init.xavier_normal(param[1])
1
2
3
4
5
6
7
8
9
10
7 数据集与数据加载器
7.1 DataSet与DataLoader
torch.util.data模块提供了DataSet类用于描述一个数据集。定义自己的数据集需要继承自DataSet类,且实现__getitem__()与__len__()方法。__getitem__方法返回指定索引处的tensor与其对应的label。
为了支持数据的批量及随机化操作,可以使用data模块下的DataLoader类型来返回一个加载器:
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0)
7.2 torchvision简介
torchvision是配合pytorch的独立计算机视觉数据集的工具库,下面介绍其中常用的数据集类型。
torchvision.datasets.ImageFolder(dir, transform, label_map,loader)
提供了从一个目录初始化出来一个图片数据集的便捷方法。
要求目录下的图片分类存放,每一类的图片存储在以类名为目录名的目录下,方法会将每个类名映射到唯一的数字上,如果你对数字有要求,可以用label_map来定义目录名到数字的映射。
torchvision.datasets.DatasetFolder(dir,transform, label_map, loader, extensions)
提供了从一个目录初始化一般数据集的便捷方法。目录下的数据分类存放,每类数据存储在class_xxx命名的目录下。
此外torchvision.datasets下实现了常用的数据集,如CIFAR-10/100, ImageNet, COCO, MNIST, LSUN等。
除了数据集,torchvision的model模块提供了常见的模型实现,如Alex-Net, VGG,Inception, Resnet等。
7.3 torchvision提供的图像变换工具
torchvision的transforms模块提供了对PIL.Image对象和Tensor对象的常见操作。如果需要连续应用多个变换,可以使用Compose对象组装多个变换。
转换操作 说明
Scale PIL图片进行缩放
CenterCrop PIL图片从中心位置剪切
Pad PIL图片填充
ToTensor PIL图片转换为Tensor且归一化到[0,1]
Normalize Tensor标准化
ToPILImage 将Tensor转为PIL表示
import torchvision.tranforms as Trans
tranform = Trans.Compose([
T.Scale(28*28),
T.ToTensor(),
T.Normalize([0.5],[0.5])
])
1
2
3
4
5
6
8 训练过程可视化
8.1 使用Tensorboard
通过使用第三方库tensorboard_logger,将训练过程中的数据保存为日志,然后便可以通过Tensorboard来查看这些数据了。其功能支持相对有限,这里不做过多介绍。
8.2 使用visdom
visdom是facebook开源的一个可视工具,可以用来完成pytorch训练过程的可视化。
安装可以使用pip install visdom
启动类似于tb,在命令行上执行:python -m visdom.server
服务启动后可以使用浏览器打开http://127.0.0.1:8097/即可看到主面板。
visdom的绘图API类似于plot,通过API将绘图数据发送到基于tornado的web服务器上并显示在浏览器中。更详细内容参见visdom的github主页
9 GPU及并行支持
为了能在GPU上运行,Tensor与Module都需要转换到cuda模式下。
import torch
import torchvision
t = torch.Tensor(3,4)
print t.is_cuda #False
t = t.cuda(0)
print t.is_cuda #True
net = torchvision.model.AlexNet()
net.cuda(0)
1
2
3
4
5
6
7
8
9
10
如果有多块显卡,可以通过cuda(device_id)来将tensor分到不同的GPU上以达到负载的均衡。
另一种比较省事的做法是调用torch.set_default_tensor_type使程序默认使用某种cuda的tensor。或者使用torch.cuda.set_device(id)指定使用某个GPU。
10 示例:Pytorch实现CIFAR10与MNIST分类
关于cifar10与mnist数据集不再进行解释了。这里的Model类实现的二者的共同的任务,借鉴了keras的接口方式,Model类提供了train与evaluat方法,并没有实现序列模型的添加方法以及predict方法。此外设定损失函数与优化函数时,也只是简单的全部实例化出来,根据参数选择其中的一个,这里完全可以根据参数动态创建。
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
class Model:
def __init__(self, net, cost, optimist):
self.net = net
self.cost = self.create_cost(cost)
self.optimizer = self.create_optimizer(optimist)
pass
def create_cost(self, cost):
support_cost = {
'CROSS_ENTROPY': nn.CrossEntropyLoss(),
'MSE': nn.MSELoss()
}
return support_cost[cost]
def create_optimizer(self, optimist, **rests):
support_optim = {
'SGD': optim.SGD(self.net.parameters(), lr=0.1, **rests),
'ADAM': optim.Adam(self.net.parameters(), lr=0.01, **rests),
'RMSP':optim.RMSprop(self.net.parameters(), lr=0.001, **rest)
}
return support_optim[optimist]
def train(self, train_loader, epoches=3):
for epoch in range(epoches):
running_loss = 0.0
for i, data in enumerate(train_loader, 0):
inputs, labels = data
self.optimizer.zero_grad()
# forward + backward + optimize
outputs = self.net(inputs)
loss = self.cost(outputs, labels)
loss.backward()
self.optimizer.step()
running_loss += loss.item()
if i % 100 == 0:
print('[epoch %d, %.2f%%] loss: %.3f' %
(epoch + 1, (i + 1)*1./len(train_loader), running_loss / 100))
running_loss = 0.0
print('Finished Training')
def evaluate(self, test_loader):
print('Evaluating ...')
correct = 0
total = 0
with torch.no_grad(): # no grad when test and predict
for data in test_loader:
images, labels = data
outputs = self.net(images)
predicted = torch.argmax(outputs, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print('Accuracy of the network on the test images: %d %%' % (100 * correct / total))
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
def cifar_load_data():
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
shuffle=False, num_workers=2)
return trainloader, testloader
class LeNet(nn.Module):
def __init__(self):
super(LeNet, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = F.max_pool2d(F.relu(self.conv1(x)), 2)
x = F.max_pool2d(F.relu(self.conv2(x)), 2)
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
def mnist_load_data():
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize([0,], [1,])])
trainset = torchvision.datasets.MNIST(root='./data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32,
shuffle=True, num_workers=2)
testset = torchvision.datasets.MNIST(root='./data', train=False,
download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=32,shuffle=True, num_workers=2)
return trainloader, testloader
class MnistNet(torch.nn.Module):
def __init__(self):
super(MnistNet, self).__init__()
self.fc1 = torch.nn.Linear(28*28, 512)
self.fc2 = torch.nn.Linear(512, 512)
self.fc3 = torch.nn.Linear(512, 10)
def forward(self, x):
x = x.view(-1, 28*28)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = F.softmax(self.fc3(x), dim=1)
return x
if __name__ == '__main__':
# train for mnist
net = MnistNet()
model = Model(net, 'CROSS_ENTROPY', 'RMSP')
train_loader, test_loader = mnist_load_data()
model.train(train_loader)
model.evaluate(test_loader)
# train for cifar
net = LeNet()
model = Model(net, 'CROSS_ENTROPY', 'RMSP')
train_loader, test_loader = cifar_load_data()
model.train(train_loader)
model.evaluate(test_loader)
---------------------
作者:zzulp
来源:CSDN
原文:https://blog.csdn.net/zzulp/article/details/80573331
版权声明:本文为博主原创文章,转载请附上博文链接!