今天学习了刘二大人B站上面的pytorch实践第四讲,这节主要讲的是反向传播算法,关于反向传播在视频中有如下例子的讲解:
黑色(位于上方)的线是前馈传播,红色(位于下方)的线是反向传播,同样的刘二大人为了使小伙伴们能够熟悉反向传播的过程同样留了两个小作业
上图我所做的结果为:-8
上图我所做的结果为:2和2。感兴趣的小伙伴可以自己做一下,我的记过不一定正确。这是为了帮助大家加深对反向传播过程的理解。
然后就是反向传播的代码:
import torch
import matplotlib.pyplot as plt
x_data = [1,2,3]
y_data = [2,4,6]
w = torch.tensor([1.0]) #将权重w设置为tensor类型
w.requires_grad = True #将梯度计算设置为True
#前馈传播函数
def forward(x):
return w*x
#计算loss的函数
def loss(x,y):
y_pred = forward(x)
return (y_pred-y)**2
loss_list = []
print("predict (before training)", 4, forward(4).item())
for epoch in range(100):
for x,y in zip(x_data,y_data):
loss_val = loss(x,y)
loss_val.backward()
print('\t grad',x,y,w.grad.item())
w.data = w.data - 0.01*w.grad.data
w.grad.data.zero_()
loss_list.append(loss_val.item())
print("progress:", epoch, loss_val.item())
print("predict (after training)", 4, forward(4).item())
plt.plot(loss_list,color='blue')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.show()
针对上面的例子老师同样留了作业:是个二次函数的模型
针对课后作业我写了如下代码:
import torch
import matplotlib.pyplot as plt
import numpy as np
x_data = [1, 2, 3, 4]
y_data = [2, 4, 6, 8]
w1 = torch.tensor([np.random.random()])
w2 = torch.tensor([np.random.random()])
b = torch.tensor([np.random.random()])
w1.requires_grad = True
w2.requires_grad = True
b.requires_grad = True
def forward(x):
return w1 * (x ** 2) + w2 * x + b
def loss(x, y):
y_pred = forward(x)
return (y_pred - y) ** 2
loss_list = []
print('predict (before training)',5,forward(5).item())
for epoch in range(1000):
for x, y in zip(x_data, y_data):
loss_val = loss(x, y)
loss_val.backward()
print('\t grad',x,y,w1.grad.item(),w2.grad.item(),b.grad.item())
w1.data = w1.data - 0.0001 * w1.grad.data
w2.data = w2.data - 0.0001 * w2.grad.data
b.data = b.data - 0.0001 * b.grad.data
w1.grad.data.zero_()
w2.grad.data.zero_()
b.grad.data.zero_()
loss_list.append(loss_val.item())
print("progress:", epoch, loss_val.item())
print("predict (after training)", 5, forward(5).item())
plt.plot(loss_list,color='blue')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.show()
可视化结果为:
在我的代码中我把学习率改成了0.0001,原因是当我设置成0.01时,得到的结果都为Nan,上网查阅了看到有的人说是因为学习率高导致,我改成0.0001效果有了改善。会继续跟着刘二大人继续努力!!!