Nonlinear Regression: Principles and Implementation

1. Activation functions: an activation function is what lets a neural network fit complex nonlinear functions, e.g. torch.nn.functional.relu(); a short ReLU sketch follows this list.

2. An artificial neural network is a structure built from multiple layers of artificial neurons: an input layer, one or more hidden layers, and an output layer.

3. A neural network with more than two hidden layers can be called a deep neural network.
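
As promised above, here is a minimal ReLU sketch (standalone; the input values are made up purely for demonstration). ReLU zeroes out negative inputs, and it is this kink at 0 that breaks the linearity of stacked nn.Linear layers:

import torch
import torch.nn.functional as Func

t = torch.tensor([-2.0, -0.5, 0.0, 1.5])
print(Func.relu(t))  # tensor([0.0000, 0.0000, 0.0000, 1.5000])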

import torch
import matplotlib.pyplot as plt
from time import perf_counter
# Add a feature dimension: x has shape (100000, 1)
x = torch.unsqueeze(torch.linspace(-3,3,100000),dim=1)
y = x.pow(3) + 0.3*torch.rand(x.size())
plt.scatter(x.numpy(),y.numpy(),s=0.01)
plt.show()
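
If the unsqueeze step looks mysterious: nn.Linear expects a 2-D (batch, features) input, and unsqueeze(dim=1) adds that second dimension. A quick shape check (a sketch using a smaller size, purely for illustration):

t = torch.linspace(-3, 3, 5)
print(t.shape)                          # torch.Size([5])
print(torch.unsqueeze(t, dim=1).shape)  # torch.Size([5, 1])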

# Build a network by subclassing nn.Module
from torch import nn,optim
import torch.nn.functional as Func

class Net(nn.Module):
    def __init__(self, input_feature, num_hidden, outputs):
        super(Net,self).__init__()
        # Linear transform from the input features to the hidden layer
        self.hidden = nn.Linear(input_feature, num_hidden)
        # Linear transform from the hidden layer to the output layer
        self.out = nn.Linear(num_hidden, outputs)

    def forward(self, x):
        # Hidden layer's linear transform, followed by ReLU
        x = Func.relu(self.hidden(x))
        # Output layer's linear transform (no activation: raw regression output)
        x = self.out(x)
        return x
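
A quick smoke test of the class (the batch size 8 here is arbitrary, chosen only to confirm the shapes):

toy_net = Net(input_feature=1, num_hidden=20, outputs=1)
print(toy_net(torch.randn(8, 1)).shape)  # torch.Size([8, 1])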


CUDA = torch.cuda.is_available()
if CUDA:
    net = Net(input_feature=1, num_hidden=20, outputs=1).cuda()
    inputs = x.cuda()
    target = y.cuda()
else:
    net = Net(input_feature=1, num_hidden=20, outputs=1)
    inputs = x
    target = y
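
On recent PyTorch versions an equivalent and arguably tidier pattern is a single device object; this is an alternative sketch, not what the script above uses:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = Net(input_feature=1, num_hidden=20, outputs=1).to(device)
inputs, target = x.to(device), y.to(device)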

criterion = nn.MSELoss()
optimizer = optim.SGD(net.parameters(),lr=1e-2)
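
nn.MSELoss() with its default settings is just the mean of the squared errors over all elements; a hand-rolled check (made-up numbers, purely for illustration) gives the same value:

pred = torch.tensor([1.0, 2.0])
true = torch.tensor([0.0, 4.0])
print(nn.MSELoss()(pred, true))     # tensor(2.5000)
print(((pred - true) ** 2).mean())  # tensor(2.5000)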

def Train(model, criterion, optimizer, epochs):
    for epoch in range(epochs):
        # Forward pass: compute predictions and the loss
        output = model(inputs)
        loss = criterion(output, target)

        # Clear accumulated gradients, then backpropagate
        optimizer.zero_grad()
        loss.backward()

        optimizer.step()  # update the weights
        if epoch % 80 == 0:
            draw(output, loss)
    return model, loss

def draw(output, loss):
    if CUDA:
        output = output.cpu()
    # Clear the current axes before redrawing
    plt.cla()
    plt.scatter(x.numpy(), y.numpy())
    plt.plot(x.numpy(), output.detach().numpy(), 'r-', lw=5)
    plt.text(0.5, 0, 'loss=%s' % (loss.item()), fontdict={'size': 20, 'color': 'red'})
    plt.pause(0.005)

# Training

START = perf_counter()
net,loss = Train(net,criterion,optimizer,10000)
FINISH = perf_counter()
time = FINISH - START
print("time:%s" % time)
print("loss:",loss.item())
print("weights:",list(net.parameters()))
