PyTorch实现多项式拟合

利用梯度下降的方法对原函数进行拟合

import torch
import math
import matplotlib.pyplot as plt

class Fitting_polynomial(torch.nn.Module):
    """Cubic polynomial y = a + b*x + c*x^2 + d*x^3 with learnable coefficients.

    The four coefficients are registered as scalar ``torch.nn.Parameter``s
    drawn from a standard normal distribution, so ``model.parameters()``
    exposes them to a ``torch.optim`` optimizer.
    """

    def __init__(self):
        super().__init__()
        # One 0-dim (scalar) parameter per polynomial coefficient.
        self.a = torch.nn.Parameter(torch.randn(()))
        self.b = torch.nn.Parameter(torch.randn(()))
        self.c = torch.nn.Parameter(torch.randn(()))
        self.d = torch.nn.Parameter(torch.randn(()))

    def forward(self, x):
        """Evaluate the polynomial element-wise at ``x``.

        Args:
            x: Tensor of sample points. The scalar coefficients broadcast
                against it, so the result has the same shape as ``x``.

        Returns:
            Tensor of polynomial values a + b*x + c*x^2 + d*x^3.
        """
        return self.a + self.b * x + self.c * x ** 2 + self.d * x ** 3

    def string(self):
        """Return the current polynomial as a human-readable string.

        Just like any class in Python, you can also define custom methods on
        PyTorch modules.
        """
        return f'y = {self.a.item()} + {self.b.item()} x + {self.c.item()} x^2 + {self.d.item()} x^3'

    def plot_poly(self, x, ax=None):
        """Plot the current polynomial over the sample points ``x``.

        Args:
            x: 1-D tensor of sample points.
            ax: Matplotlib axes to draw on. When omitted, a new 14x8-inch
                figure is created and its axes are used, which was the
                original behaviour.

        Returns:
            The axes the "fitting" curve was drawn on. Callers can add more
            curves, e.g. the target function, to the same plot.
        """
        if ax is None:
            ax = plt.figure(figsize=(14, 8)).gca()
        # Reuse forward() so the plotted curve cannot drift from the model's
        # formula; no_grad avoids building an autograd graph just to plot.
        with torch.no_grad():
            y = self(x)
        ax.plot(x.detach().cpu().numpy(), y.cpu().numpy(), label="fitting")
        ax.legend()
        return ax
        

定义原函数(sin);新建模型,并使用MSE(均方误差)作为损失函数。注意:这里设置了 reduction='sum',计算的是所有采样点上的误差平方和而非平均值,因此下方日志中打印的 "mse" 实际上是 1000 个点的误差平方和。


# Target data: sin(x) sampled at 1000 evenly spaced points on [-pi, pi].
x = torch.linspace(start=-math.pi, end=math.pi, steps=1000)
y = x.sin()

# The learnable cubic polynomial defined above.
model = Fitting_polynomial()

# Plain SGD with momentum over the model's four coefficients. The loss is the
# *sum* (not the mean) of squared errors over all sample points.
optimizer = torch.optim.SGD(params=model.parameters(), lr=1e-8, momentum=0.9)
criterion = torch.nn.MSELoss(reduction='sum')

训练

for t in range(30000):
    # Forward pass: compute the predicted y by passing x through the model.
    y_pred = model(x)

    # Loss for the current parameters, measured before this step's update.
    loss = criterion(y_pred, y)

    # Every 2000 iterations, log progress and snapshot the current fit.
    if t % 2000 == 1999:
        print("epoch:{},mse:{}".format(t + 1, loss.item()))
        print(f'Result: {model.string()}')
        # plot_poly opens a fresh figure, so draw it first and then add the
        # target curve to that same figure. The reverse order put "raw" on the
        # previous snapshot's figure and left the final fit without it.
        model.plot_poly(x)
        plt.plot(x, y, label="raw")
        plt.legend()

    # Zero gradients, perform a backward pass, and update the weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    


epoch:2000,mse:1447.431396484375
Result: y = 0.11133185774087906 + -0.7943404316902161 x + -0.01912527345120907 x^2 + 0.14130832254886627 x^3
epoch:4000,mse:958.48291015625
Result: y = 0.09364601969718933 + -0.4856458902359009 x + -0.016138650476932526 x^2 + 0.09744332730770111 x^3
epoch:6000,mse:635.2223510742188
Result: y = 0.07877693325281143 + -0.23467518389225006 x + -0.013576172292232513 x^2 + 0.06178080663084984 x^3
epoch:8000,mse:421.49969482421875
Result: y = 0.06626877188682556 + -0.03063386306166649 x + -0.01142055168747902 x^2 + 0.03278686851263046 x^3
epoch:10000,mse:280.1954040527344
Result: y = 0.05574660003185272 + 0.13525323569774628 x + -0.009607195854187012 x^2 + 0.009214584715664387 x^3
epoch:12000,mse:186.76930236816406
Result: y = 0.046895161271095276 + 0.27012065052986145 x + -0.008081765845417976 x^2 + -0.009949849918484688 x^3
epoch:14000,mse:124.99755096435547
Result: y = 0.03944912552833557 + 0.37976884841918945 x + -0.0067985402420163155 x^2 + -0.025530679151415825 x^3
epoch:16000,mse:84.1541748046875
Result: y = 0.03318541869521141 + 0.4689137637615204 x + -0.005719069391489029 x^2 + -0.03819802775979042 x^3
epoch:18000,mse:57.14808654785156
Result: y = 0.027916258201003075 + 0.5413891077041626 x + -0.004811000544577837 x^2 + -0.04849664866924286 x^3
epoch:20000,mse:39.29086685180664
Result: y = 0.023483725264668465 + 0.6003121137619019 x + -0.004047111142426729 x^2 + -0.056869521737098694 x^3
epoch:22000,mse:27.48279571533203
Result: y = 0.019754987210035324 + 0.6482172012329102 x + -0.0034045118372887373 x^2 + -0.06367673724889755 x^3
epoch:24000,mse:19.674556732177734
Result: y = 0.016618283465504646 + 0.687164306640625 x + -0.002863941714167595 x^2 + -0.06921107321977615 x^3
epoch:26000,mse:14.511096000671387
Result: y = 0.013979638926684856 + 0.718828558921814 x + -0.0024092060048133135 x^2 + -0.07371050119400024 x^3
epoch:28000,mse:11.096450805664062
Result: y = 0.011759957298636436 + 0.7445719838142395 x + -0.0020266727078706026 x^2 + -0.07736856490373611 x^3
epoch:30000,mse:8.838287353515625
Result: y = 0.009892717935144901 + 0.7655012011528015 x + -0.0017048786394298077 x^2 + -0.0803426131606102 x^3

拟合过程可视化(部分):

PyTorch实现多项式拟合_第1张图片

PyTorch实现多项式拟合_第2张图片
PyTorch实现多项式拟合_第3张图片
PyTorch实现多项式拟合_第4张图片
PyTorch实现多项式拟合_第5张图片

你可能感兴趣的:(pytorch,环境搭建与代码笔记,python,pytorch,回归)