mxnet,pytorch,keras训练套路对比

使用代码对比下各个框架写模型的一般套路。

mxnet

import mxnet
from mxnet.gluon import nn,loss,Trainer

# 定义模型
class XXModel(nn.Block):
	def __init__(self, **kwargs):
		# 设计模型结构
	def forward(self, x):
		# 计算一次前向结果

# 初始化模型,数据,损失函数,优化器(trainer)
model = XXModel()  # 初始化模型
model.initialize()  # 初始化模型权重

data = nd.random.uniform(shape=(1, 3, 224, 224))  # 输入数据
label = nd.random.randint(0, 2, shape=(1,))  # 标签

loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()  # 损失函数

trainer = gluon.Trainer(model.collect_params(), 'sgd', {'learning_rate': 0.1}) # 优化器

# 训练
with autograd.record(): 
       output = net(data)  # 计算输出
       loss = loss_fn(output, label)  # 计算损失
loss.backward()  # 计算梯度

trainer.step(batch_size=1)  # 根据梯度更新参数

mxnet代码清晰,实际部署应该比pytorch坑更少

pytorch

from torch.optim import optim
from torch import nn

class XXModel(nn.Module):
	def __init__(self,...):
		# 设计,初始化模型结构
	def forward(self, x):
		# 计算一次前向传播结果

# 初始化优化器,损失函数,模型,数据,学习率调节器
optimizer = optim.XXX()  # 实例化优化器,如optim.SGD()

scheduler = optim.lr_scheduler.XXX(optimizer)  # 实例化学习率调节器,如optim.lr_scheduler.StepLR

loss_fn = nn.XXXLoss()  # 实例化损失函数,如nn.CrossEntropyLoss

model = XXModel()  # 实例化网络模型

input = torch.Tensor(data, dtype=torch.float32)  # 实例化数据

target = torch.Tensor(target, dtype=xxx)  # 实例化标签



# 训练
model = model(input)  # 前向传播

loss = loss_fn(output, target)  # 计算损失

optimizer.zero_grad()  # 梯度清0

loss.backward()  # 计算梯度

optimizer.step()  # 一次梯度下降

pytorch和mxnet很相似,上手了pytorch应该也很容易上手mxnet。

keras

from keras.models import Sequential
from keras.layers import Dense
import numpy as np

# 定义数据
x_train = np.ones(shape=xxx)
y_train = np.ones(shape=xxx)

# 定义模型
model = Sequential()
model.add(Dense(units=64, activation='relu', input_dim=100))
model.add(Dense(units=10, activation='softmax'))

# 定义训练需要的参数,compile集成了loss,optimizer
model.compile(loss='categorical_crossentropy',  # 定义损失函数
              optimizer='sgd',  # 定义优化器
              metrics=['accuracy'])  # 定义评估指标

model.fit(x_train, y_train, epochs=5, batch_size=32)  # fit开始训练

keras代码最简洁,也最容易上手。

你可能感兴趣的:(#,计算机视觉)