The general workflow of deep learning (a minimal sketch of these steps is given after the list):
1. Build the network model structure
2. Choose a loss function
3. Choose an optimizer and train
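As a quick orientation, here is a minimal sketch of the three steps in PyTorch. The layer sizes and the learning rate are arbitrary illustrative choices, not values from this article:
import torch

# 1. build the network model structure
model = torch.nn.Sequential(
    torch.nn.Linear(1, 10),   # 1 input feature -> 10 hidden units
    torch.nn.ReLU(),
    torch.nn.Linear(10, 1)    # 10 hidden units -> 1 output
)
# 2. choose a loss function
loss_func = torch.nn.MSELoss()
# 3. choose an optimizer and train
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)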
Gradient descent is an optimization algorithm widely used in machine learning and artificial intelligence to iteratively approach a minimum-error model.
torch.optim.SGD is PyTorch's stochastic gradient descent optimizer.
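Besides the learning rate, torch.optim.SGD also accepts momentum, weight_decay and nesterov arguments. A typical construction might look like the following (the hyperparameter values are only placeholders):
optimizer = torch.optim.SGD(
    model.parameters(),   # parameters to update
    lr=0.01,              # learning rate
    momentum=0.9,         # momentum factor
    weight_decay=1e-4     # L2 regularization strength
)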
Variants of gradient descent:
(Figure from https://towardsdatascience.com/gradient-descent-algrithm-and-its-variants-10f652806a3)
Training code for a regression problem
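The snippet below calls a model named mynet that is not defined in this excerpt. A minimal sketch of a regression network that would fit it is shown here; the hidden size is an assumption:
import torch

# assumed definition of mynet: a small fully connected regression net
# (1 input feature -> 10 hidden units -> 1 output); sizes are illustrative
mynet = torch.nn.Sequential(
    torch.nn.Linear(1, 10),
    torch.nn.ReLU(),
    torch.nn.Linear(10, 1)
)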
#-------------- data --------------------
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)  # 100 inputs in [-1, 1], shape (100, 1)
y = x.pow(2)                                            # target: y = x^2
#-------------- train -------------------
optimizer = torch.optim.SGD(mynet.parameters(), lr=0.1)  # optimizer
loss_func = torch.nn.MSELoss()                           # loss function
for epoch in range(1000):
    optimizer.zero_grad()
    # forward + backward + optimize
    pred = mynet(x)
    loss = loss_func(pred, y)
    loss.backward()
    optimizer.step()
#---------------- prediction ---------------
test_data = torch.tensor([-1.0])
pred = mynet(test_data)
print(test_data, pred.data)
Training code for a classification problem
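Again, mynet is assumed to exist. For this snippet it needs two output logits, because CrossEntropyLoss expects one score per class; a possible definition (sizes are illustrative) is:
# assumed definition of mynet for two-class classification:
# 1 input feature -> 10 hidden units -> 2 class logits
mynet = torch.nn.Sequential(
    torch.nn.Linear(1, 10),
    torch.nn.ReLU(),
    torch.nn.Linear(10, 2)
)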
#---------------- data ------------------
data_num = 100
x = torch.unsqueeze(torch.linspace(-1, 1, data_num), dim=1)  # inputs, shape (100, 1)
y0 = torch.zeros(50)                                         # class 0 labels for the first half
y1 = torch.ones(50)                                          # class 1 labels for the second half
y = torch.cat((y0, y1)).type(torch.LongTensor)               # class indices for CrossEntropyLoss
#---------------- train ------------------
optimizer = torch.optim.SGD(mynet.parameters(), lr=0.1)  # optimizer
loss_func = torch.nn.CrossEntropyLoss()                  # loss function
for epoch in range(1000):
    optimizer.zero_grad()
    # forward + backward + optimize
    pred = mynet(x)
    loss = loss_func(pred, y)
    loss.backward()
    optimizer.step()
#---------------- prediction ---------------
test_data = torch.tensor([-1.0])
pred = mynet(test_data)
print(test_data, pred.data)
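Here pred holds the raw logits for the two classes. To read off a class label, one would typically apply softmax and argmax, roughly as follows (a usage sketch, not part of the original code):
prob = torch.softmax(pred, dim=0)   # class probabilities for the single test sample
label = torch.argmax(prob).item()   # predicted class index (0 or 1)
print(prob, label)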
Note: optimizer.zero_grad() must be called to zero the gradients before every backward pass; otherwise the gradients from successive iterations accumulate and training will not converge.
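A small demonstration of why this matters, using a hypothetical toy tensor (not from the original): calling backward() twice without zeroing doubles the stored gradient.
w = torch.ones(1, requires_grad=True)
loss = (w * 2).sum()
loss.backward()
print(w.grad)          # tensor([2.])
loss = (w * 2).sum()
loss.backward()
print(w.grad)          # tensor([4.])  gradients accumulated
w.grad.zero_()         # what optimizer.zero_grad() does for every parameter
print(w.grad)          # tensor([0.])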
Mini-batch training code:
Data:
import torch
import torch.utils.data as Data
import matplotlib.pyplot as plt
x = torch.linspace(1, 10, 10)              # 10 evenly spaced inputs
y = torch.linspace(10, 1, 10)              # 10 matching targets (in reverse order)
torch_database = Data.TensorDataset(x, y)  # pack inputs and targets into a dataset
plt.scatter(x.data, y.data, s=10)          # visualize the data
plt.show()
# wrap the dataset in a DataLoader to get mini-batches
BATCH_SIZE = 5
loader = Data.DataLoader(
    dataset=torch_database,
    batch_size=BATCH_SIZE
)
for epoch in range(3):
    for step, (batch_x, batch_y) in enumerate(loader):
        print("Epoch:", epoch, " step:", step, " batch x:", batch_x.numpy(), " batch y:", batch_y.numpy())
Advanced gradient descent methods:
For more details, see: https://blog.csdn.net/fishmai/article/details/52510826
and https://blog.csdn.net/tsyccnh/article/details/76769232
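For illustration, several of these methods (momentum SGD, RMSprop, Adam) are constructed the same way in torch.optim; the hyperparameter values below are common defaults chosen for the example, not values from the linked posts:
# SGD with momentum
opt_momentum = torch.optim.SGD(mynet.parameters(), lr=0.01, momentum=0.9)
# RMSprop
opt_rmsprop = torch.optim.RMSprop(mynet.parameters(), lr=0.01, alpha=0.9)
# Adam
opt_adam = torch.optim.Adam(mynet.parameters(), lr=0.001, betas=(0.9, 0.999))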