学习笔记:动手学深度学习 16 模型选择、欠拟合、过拟合

多项式回归

Python 3.8.8 (default, Apr 13 2021, 15:08:03) [MSC v.1916 64 bit (AMD64)]
Type 'copyright', 'credits' or 'license' for more information
IPython 7.22.0 -- An enhanced Interactive Python. Type '?' for help.
PyDev console: using IPython 7.22.0
Python 3.8.8 (default, Apr 13 2021, 15:08:03) [MSC v.1916 64 bit (AMD64)] on win32
import math
import numpy as np
import torch
from torch import nn
from d2l import torch as d2l
Backend Qt5Agg is interactive backend. Turning interactive mode on.
max_degree = 20  # 多项式的最大阶数
n_train, n_test = 100, 100  # 训练和测试数据集大小
true_w = np.zeros(max_degree)  # 分配大量的空间
true_w[0:4] = np.array([5, 1.2, -3.4, 5.6])#20维,只有前四个有值
features = np.random.normal(size=(n_train + n_test, 1))
features 
Out[5]: 
array([[-1.24877991e+00],
       [ 8.28209948e-01],
       [-5.60350353e-02],
       [ 1.09842452e+00],
       [ 1.13567338e-01],
       [ 4.85657757e-01],
       [-4.20929085e-01],
       [-4.34989122e-02],
       [ 9.37545283e-01],
       [-5.83074156e-01],
       [ 8.86417121e-01],
       [ 1.03349135e+00],
       [ 4.73495832e-01],
       [-2.76563829e-01],
       [-3.36949842e-01],
       [-1.05634152e+00],
       [ 2.70768877e-02],
       [ 3.38999550e-01],
       [-1.01361759e-01],
       [ 4.02025173e-03],
       [-7.54897633e-01],
       [ 1.72726079e+00],
       [ 4.95219082e-02],
       [-3.41822842e-01],
       [ 5.49566206e-01],
       [-5.98505929e-01],
       [-9.37318187e-02],
       [-2.82199809e-01],
       [-1.96982592e-01],
       [-2.93186566e-02],
       [ 6.65900487e-01],
       [ 1.69821343e-01],
       [-1.48140575e+00],
       [ 7.00007070e-04],
       [-3.69277304e-01],
       [-3.23709345e-01],
       [ 3.00839480e-01],
       [-4.69767089e-01],
       [ 1.10927271e+00],
       [ 5.64456307e-01],
       [ 5.13324209e-02],
       [-1.32305190e+00],
       [-6.66163775e-02],
       [ 3.36522935e-01],
       [ 1.36290439e+00],
       [ 1.55449498e-01],
       [-2.97874388e-01],
       [-1.21726427e+00],
       [-6.84886759e-01],
       [ 4.20792991e-01],
       [-5.18045361e-01],
       [-3.75052307e-01],
       [-3.96546493e-01],
       [ 8.62088944e-01],
       [-3.20745866e+00],
       [-1.59177394e+00],
       [ 1.89175614e-02],
       [ 2.86639744e-01],
       [ 2.11043140e-01],
       [ 1.50832055e-01],
       [ 1.68166399e+00],
       [-8.12642422e-01],
       [-3.77610823e-01],
       [ 1.05191317e+00],
       [-1.04033316e+00],
       [-1.17078385e+00],
       [-2.24901389e+00],
       [-8.80283992e-01],
       [ 7.64857999e-01],
       [-1.33705116e+00],
       [ 1.66214283e+00],
       [-2.26089673e-01],
       [-1.95577731e-01],
       [ 9.74724880e-01],
       [ 5.03271540e-01],
       [ 4.62495625e-01],
       [ 2.96670037e-01],
       [-1.45279641e+00],
       [-7.22165093e-01],
       [ 6.34779405e-01],
       [ 1.44516446e+00],
       [ 1.01794945e+00],
       [-3.23743200e-01],
       [ 2.69098483e-01],
       [ 1.01854272e+00],
       [ 1.31611845e+00],
       [ 1.55891243e+00],
       [-1.15601101e+00],
       [ 7.87161538e-01],
       [-5.82257846e-01],
       [-1.54901781e+00],
       [ 9.30331582e-01],
       [ 3.44433061e-01],
       [ 1.88846232e-01],
       [-2.60092370e-01],
       [ 2.30275526e+00],
       [ 8.19156798e-01],
       [-1.07635430e+00],
       [ 3.04595056e-01],
       [-3.66152037e-01],
       [ 1.68987245e-01],
       [ 2.42085877e-01],
       [ 1.35590463e+00],
       [ 1.90207817e+00],
       [ 7.23011558e-01],
       [ 1.35794493e-01],
       [ 2.43643875e-01],
       [-3.08209917e-02],
       [ 3.66283553e-01],
       [-4.39776223e-01],
       [-8.33263555e-02],
       [ 4.38833963e-01],
       [-7.41108035e-01],
       [ 4.94164793e-01],
       [-5.56791095e-02],
       [ 1.11634799e+00],
       [ 1.24135082e+00],
       [-6.35433536e-01],
       [ 6.13647699e-01],
       [-1.67862348e+00],
       [-7.66517812e-01],
       [-5.91435826e-01],
       [ 1.84324927e-01],
       [ 1.21430958e+00],
       [ 1.13863138e+00],
       [ 1.01644315e+00],
       [ 2.24761108e-01],
       [-1.29769881e+00],
       [-3.76616998e-01],
       [-6.98516296e-01],
       [ 1.94997980e-01],
       [-5.44454501e-01],
       [ 1.05314092e+00],
       [ 3.60557306e-01],
       [-7.84796429e-01],
       [-5.82938869e-01],
       [ 9.68154455e-01],
       [-7.81071259e-01],
       [ 9.94622425e-01],
       [ 8.04636823e-01],
       [ 6.58799822e-02],
       [ 1.60391581e+00],
       [ 5.76989203e-01],
       [ 1.45813026e+00],
       [-4.90776022e-01],
       [ 1.59270405e-01],
       [ 1.65989704e+00],
       [ 2.67181631e-01],
       [-7.62639947e-01],
       [-2.67281396e+00],
       [-5.05441093e-01],
       [-6.70531654e-02],
       [-2.81798924e-01],
       [-8.90950214e-01],
       [-9.89637660e-03],
       [ 7.71098411e-01],
       [-1.27349581e-01],
       [-1.43933978e+00],
       [ 7.16563759e-01],
       [-3.34302839e-01],
       [-2.22298829e+00],
       [ 7.74795379e-01],
       [-2.42183622e-01],
       [-5.15941065e-01],
       [-2.45396930e-01],
       [-8.92292197e-01],
       [ 1.22152077e+00],
       [-1.30969202e-01],
       [-8.73798264e-01],
       [-1.60433545e-01],
       [ 8.08214014e-02],
       [ 2.37668105e-01],
       [ 4.36946348e-01],
       [ 1.45773835e+00],
       [ 6.97224637e-01],
       [-7.38344966e-01],
       [-1.00882934e+00],
       [-9.76357960e-01],
       [ 2.56287117e+00],
       [-4.73031264e-01],
       [ 2.38196523e-02],
       [ 9.92168679e-01],
       [-2.55290885e-01],
       [-5.94431216e-01],
       [-1.93339405e-02],
       [-4.02053662e-01],
       [-2.31414830e-01],
       [-1.94373832e-01],
       [-7.89412244e-01],
       [-1.65454374e-01],
       [ 1.76983933e+00],
       [-4.45153005e-01],
       [ 1.73557897e+00],
       [-8.55271220e-01],
       [-7.57470328e-02],
       [-4.62766499e-01],
       [-7.70784988e-02],
       [ 5.39440504e-01],
       [ 1.17672795e+00],
       [ 1.94881896e+00]])
np.random.shuffle(features)
poly_features = np.power(features, np.arange(max_degree).reshape(1, -1))
poly_features
Out[8]: 
array([[ 1.00000000e+00, -1.45279641e+00,  2.11061741e+00, ...,
        -5.72111784e+02,  8.31161947e+02, -1.20750909e+03],
       [ 1.00000000e+00, -1.33705116e+00,  1.78770581e+00, ...,
        -1.39481447e+02,  1.86493831e+02, -2.49351794e+02],
       [ 1.00000000e+00, -5.60350353e-02,  3.13992518e-03, ...,
        -5.29436706e-22,  2.96670045e-23, -1.66239164e-24],
       ...,
       [ 1.00000000e+00, -7.38344966e-01,  5.45153290e-01, ...,
        -5.75981341e-03,  4.25272924e-03, -3.13998123e-03],
       [ 1.00000000e+00,  3.04595056e-01,  9.27781481e-02, ...,
         1.67220087e-09,  5.09344118e-10,  1.55143700e-10],
       [ 1.00000000e+00,  2.43643875e-01,  5.93623377e-02, ...,
         3.75701957e-11,  9.15374806e-12,  2.23025464e-12]])
for i in range(max_degree):
    poly_features[:, i] /= math.gamma(i + 1)  # `gamma(n)` = (n-1)!
# `labels`的维度: (`n_train` + `n_test`,)
labels = np.dot(poly_features, true_w)
labels += np.random.normal(scale=0.1, size=labels.shape)
# NumPy ndarray转换为tensor
true_w, features, poly_features, labels = [torch.tensor(x, dtype=
    d2l.float32) for x in [true_w, features, poly_features, labels]]
features[:2], poly_features[:2, :], labels[:2]
Out[10]: 
(tensor([[-1.4528],
         [-1.3371]]),
 tensor([[ 1.0000e+00, -1.4528e+00,  1.0553e+00, -5.1105e-01,  1.8561e-01,
          -5.3932e-02,  1.3059e-02, -2.7102e-03,  4.9217e-04, -7.9447e-05,
           1.1542e-05, -1.5244e-06,  1.8455e-07, -2.0624e-08,  2.1402e-09,
          -2.0729e-10,  1.8822e-11, -1.6085e-12,  1.2982e-13, -9.9265e-15],
         [ 1.0000e+00, -1.3371e+00,  8.9385e-01, -3.9838e-01,  1.3316e-01,
          -3.5609e-02,  7.9352e-03, -1.5157e-03,  2.5332e-04, -3.7633e-05,
           5.0317e-06, -6.1161e-07,  6.8146e-08, -7.0088e-09,  6.6937e-10,
          -5.9665e-11,  4.9860e-12, -3.9215e-13,  2.9129e-14, -2.0498e-15]]),
 tensor([-3.0809, -1.8968]))
def evaluate_loss(net, data_iter, loss):  #@save
    """评估给定数据集上模型的损失。"""
    metric = d2l.Accumulator(2)  # 损失的总和, 样本数量
    for X, y in data_iter:
        out = net(X)
        y = y.reshape(out.shape)
        l = loss(out, y)
        metric.add(l.sum(), l.numel())
    return metric[0] / metric[1]
def train(train_features, test_features, train_labels, test_labels,
          num_epochs=400):
    loss = nn.MSELoss()
    input_shape = train_features.shape[-1]
    # 不设置偏置,因为我们已经在多项式特征中实现了它
    net = nn.Sequential(nn.Linear(input_shape, 1, bias=False)) #线性网络
    batch_size = min(10, train_labels.shape[0])
    train_iter = d2l.load_array((train_features, train_labels.reshape(-1,1)),
                                batch_size)
    test_iter = d2l.load_array((test_features, test_labels.reshape(-1,1)),
                               batch_size, is_train=False)
    trainer = torch.optim.SGD(net.parameters(), lr=0.01)
    animator = d2l.Animator(xlabel='epoch', ylabel='loss', yscale='log',
                            xlim=[1, num_epochs], ylim=[1e-3, 1e2],
                            legend=['train', 'test'])
    for epoch in range(num_epochs):
        d2l.train_epoch_ch3(net, train_iter, loss, trainer)
        if epoch == 0 or (epoch + 1) % 20 == 0:
            animator.add(epoch + 1, (evaluate_loss(net, train_iter, loss),
                                     evaluate_loss(net, test_iter, loss)))
    print('weight:', net[0].weight.data.numpy())
    
train(poly_features[:n_train, :4], poly_features[n_train:, :4],
      labels[:n_train], labels[n_train:])
weight: [[ 4.9856424 1.1942737 -3.3927531 5.609303 ]] # 从多项式特征中选择前2个维度,即 1, x train(poly_features[:n_train, :2], poly_features[n_train:, :2], labels[:n_train], labels[n_train:])
weight: [[3.4970608 4.6152606]] # 从多项式特征中选取所有维度 train(poly_features[:n_train, :], poly_features[n_train:, :], labels[:n_train], labels[n_train:], num_epochs=1500)
weight: [[ 4.9796734 1.2244654 -3.3590324 5.4656415 -0.13239844 0.29832777 0.28756332 -0.06417518 0.26940146 -0.00781312 -0.10101216 0.09097816 0.22320819 0.00915462 -0.02727097 -0.15652183 0.16408099 0.2014766 0.18008365 -0.12054636]]

学习笔记:动手学深度学习 16 模型选择、欠拟合、过拟合_第1张图片

 学习笔记:动手学深度学习 16 模型选择、欠拟合、过拟合_第2张图片

 

学习笔记:动手学深度学习 16 模型选择、欠拟合、过拟合_第3张图片

 

你可能感兴趣的:(深度学习,回归,人工智能)