Pytorch手写线性回归

原文链接: http://www.cnblogs.com/LiuXinyu12378/p/11374748.html

pytorch手写线性回归

 

import torch
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

LEARN_RATE = 0.1
#1.准备数据
x = torch.randn([500,1])
y_true = x*0.8+3

#2.计算预测值 t_tred = x*w + b

w = torch.rand([],requires_grad=True)
b = torch.tensor(0.,requires_grad=True)

plt.figure()
plt.grid(True)

#开启交互模式
plt.ion()
for i in range(50):

    plt.cla()

    for j in [w,b]:
        if j.grad is not None:
            j.grad.zero_()
    y_predict = x*w+b

    #3.计算损失,把参数的梯度置为0,进行反向传播

    loss = (y_predict-y_true).pow(2).mean()

    loss.backward()

    #4.更新参数,grad表示导数

    w.data = w.data - LEARN_RATE*w.grad
    b.data = b.data - LEARN_RATE*b.grad


    plt.scatter(x.numpy(),y_true.numpy())
    plt.plot(x.numpy(),y_predict.detach().numpy(),color="g")

    plt.pause(0.1)


    if i %50 ==0:
        print( "第{}次,损失{},权重w={},偏执b={}".format(i,loss.data,w.data,b.data))

#关闭交互模式
plt.ioff()
plt.show()

  

转载于:https://www.cnblogs.com/LiuXinyu12378/p/11374748.html

你可能感兴趣的:(Pytorch手写线性回归)