Python 3.8.8 (default, Apr 13 2021, 15:08:03) [MSC v.1916 64 bit (AMD64)]
Type 'copyright', 'credits' or 'license' for more information
IPython 7.22.0 -- An enhanced Interactive Python. Type '?' for help.
PyDev console: using IPython 7.22.0
Python 3.8.8 (default, Apr 13 2021, 15:08:03) [MSC v.1916 64 bit (AMD64)] on win32
import math
import numpy as np
import torch
from torch import nn
from d2l import torch as d2l
Backend Qt5Agg is interactive backend. Turning interactive mode on.
# Number of polynomial terms: degrees 0, 1, ..., max_degree - 1 are generated.
max_degree = 20
# Sizes of the training and test splits.
n_train = 100
n_test = 100
# Ground-truth coefficients: a max_degree-long vector where only the first
# four entries are non-zero; every higher-degree coefficient stays 0.
true_w = np.zeros(max_degree)
true_w[:4] = [5, 1.2, -3.4, 5.6]
# One standard-normal scalar input per example (train and test together).
features = np.random.normal(size=(n_train + n_test, 1))
features
Out[5]:
array([[-1.24877991e+00],
[ 8.28209948e-01],
[-5.60350353e-02],
[ 1.09842452e+00],
[ 1.13567338e-01],
[ 4.85657757e-01],
[-4.20929085e-01],
[-4.34989122e-02],
[ 9.37545283e-01],
[-5.83074156e-01],
[ 8.86417121e-01],
[ 1.03349135e+00],
[ 4.73495832e-01],
[-2.76563829e-01],
[-3.36949842e-01],
[-1.05634152e+00],
[ 2.70768877e-02],
[ 3.38999550e-01],
[-1.01361759e-01],
[ 4.02025173e-03],
[-7.54897633e-01],
[ 1.72726079e+00],
[ 4.95219082e-02],
[-3.41822842e-01],
[ 5.49566206e-01],
[-5.98505929e-01],
[-9.37318187e-02],
[-2.82199809e-01],
[-1.96982592e-01],
[-2.93186566e-02],
[ 6.65900487e-01],
[ 1.69821343e-01],
[-1.48140575e+00],
[ 7.00007070e-04],
[-3.69277304e-01],
[-3.23709345e-01],
[ 3.00839480e-01],
[-4.69767089e-01],
[ 1.10927271e+00],
[ 5.64456307e-01],
[ 5.13324209e-02],
[-1.32305190e+00],
[-6.66163775e-02],
[ 3.36522935e-01],
[ 1.36290439e+00],
[ 1.55449498e-01],
[-2.97874388e-01],
[-1.21726427e+00],
[-6.84886759e-01],
[ 4.20792991e-01],
[-5.18045361e-01],
[-3.75052307e-01],
[-3.96546493e-01],
[ 8.62088944e-01],
[-3.20745866e+00],
[-1.59177394e+00],
[ 1.89175614e-02],
[ 2.86639744e-01],
[ 2.11043140e-01],
[ 1.50832055e-01],
[ 1.68166399e+00],
[-8.12642422e-01],
[-3.77610823e-01],
[ 1.05191317e+00],
[-1.04033316e+00],
[-1.17078385e+00],
[-2.24901389e+00],
[-8.80283992e-01],
[ 7.64857999e-01],
[-1.33705116e+00],
[ 1.66214283e+00],
[-2.26089673e-01],
[-1.95577731e-01],
[ 9.74724880e-01],
[ 5.03271540e-01],
[ 4.62495625e-01],
[ 2.96670037e-01],
[-1.45279641e+00],
[-7.22165093e-01],
[ 6.34779405e-01],
[ 1.44516446e+00],
[ 1.01794945e+00],
[-3.23743200e-01],
[ 2.69098483e-01],
[ 1.01854272e+00],
[ 1.31611845e+00],
[ 1.55891243e+00],
[-1.15601101e+00],
[ 7.87161538e-01],
[-5.82257846e-01],
[-1.54901781e+00],
[ 9.30331582e-01],
[ 3.44433061e-01],
[ 1.88846232e-01],
[-2.60092370e-01],
[ 2.30275526e+00],
[ 8.19156798e-01],
[-1.07635430e+00],
[ 3.04595056e-01],
[-3.66152037e-01],
[ 1.68987245e-01],
[ 2.42085877e-01],
[ 1.35590463e+00],
[ 1.90207817e+00],
[ 7.23011558e-01],
[ 1.35794493e-01],
[ 2.43643875e-01],
[-3.08209917e-02],
[ 3.66283553e-01],
[-4.39776223e-01],
[-8.33263555e-02],
[ 4.38833963e-01],
[-7.41108035e-01],
[ 4.94164793e-01],
[-5.56791095e-02],
[ 1.11634799e+00],
[ 1.24135082e+00],
[-6.35433536e-01],
[ 6.13647699e-01],
[-1.67862348e+00],
[-7.66517812e-01],
[-5.91435826e-01],
[ 1.84324927e-01],
[ 1.21430958e+00],
[ 1.13863138e+00],
[ 1.01644315e+00],
[ 2.24761108e-01],
[-1.29769881e+00],
[-3.76616998e-01],
[-6.98516296e-01],
[ 1.94997980e-01],
[-5.44454501e-01],
[ 1.05314092e+00],
[ 3.60557306e-01],
[-7.84796429e-01],
[-5.82938869e-01],
[ 9.68154455e-01],
[-7.81071259e-01],
[ 9.94622425e-01],
[ 8.04636823e-01],
[ 6.58799822e-02],
[ 1.60391581e+00],
[ 5.76989203e-01],
[ 1.45813026e+00],
[-4.90776022e-01],
[ 1.59270405e-01],
[ 1.65989704e+00],
[ 2.67181631e-01],
[-7.62639947e-01],
[-2.67281396e+00],
[-5.05441093e-01],
[-6.70531654e-02],
[-2.81798924e-01],
[-8.90950214e-01],
[-9.89637660e-03],
[ 7.71098411e-01],
[-1.27349581e-01],
[-1.43933978e+00],
[ 7.16563759e-01],
[-3.34302839e-01],
[-2.22298829e+00],
[ 7.74795379e-01],
[-2.42183622e-01],
[-5.15941065e-01],
[-2.45396930e-01],
[-8.92292197e-01],
[ 1.22152077e+00],
[-1.30969202e-01],
[-8.73798264e-01],
[-1.60433545e-01],
[ 8.08214014e-02],
[ 2.37668105e-01],
[ 4.36946348e-01],
[ 1.45773835e+00],
[ 6.97224637e-01],
[-7.38344966e-01],
[-1.00882934e+00],
[-9.76357960e-01],
[ 2.56287117e+00],
[-4.73031264e-01],
[ 2.38196523e-02],
[ 9.92168679e-01],
[-2.55290885e-01],
[-5.94431216e-01],
[-1.93339405e-02],
[-4.02053662e-01],
[-2.31414830e-01],
[-1.94373832e-01],
[-7.89412244e-01],
[-1.65454374e-01],
[ 1.76983933e+00],
[-4.45153005e-01],
[ 1.73557897e+00],
[-8.55271220e-01],
[-7.57470328e-02],
[-4.62766499e-01],
[-7.70784988e-02],
[ 5.39440504e-01],
[ 1.17672795e+00],
[ 1.94881896e+00]])
# Shuffle the examples in place, then raise each one to the powers
# 0, 1, ..., max_degree - 1 (one column per power) via broadcasting
# a (n, 1) column against a (1, max_degree) row of exponents.
np.random.shuffle(features)
exponents = np.arange(max_degree)[np.newaxis, :]
poly_features = features ** exponents
poly_features
Out[8]:
array([[ 1.00000000e+00, -1.45279641e+00, 2.11061741e+00, ...,
-5.72111784e+02, 8.31161947e+02, -1.20750909e+03],
[ 1.00000000e+00, -1.33705116e+00, 1.78770581e+00, ...,
-1.39481447e+02, 1.86493831e+02, -2.49351794e+02],
[ 1.00000000e+00, -5.60350353e-02, 3.13992518e-03, ...,
-5.29436706e-22, 2.96670045e-23, -1.66239164e-24],
...,
[ 1.00000000e+00, -7.38344966e-01, 5.45153290e-01, ...,
-5.75981341e-03, 4.25272924e-03, -3.13998123e-03],
[ 1.00000000e+00, 3.04595056e-01, 9.27781481e-02, ...,
1.67220087e-09, 5.09344118e-10, 1.55143700e-10],
[ 1.00000000e+00, 2.43643875e-01, 5.93623377e-02, ...,
3.75701957e-11, 9.15374806e-12, 2.23025464e-12]])
# Divide column k by k! (math.gamma(k + 1) == k!) so that the high-degree
# terms stay numerically small; broadcasting applies one divisor per column.
factorials = np.array([math.gamma(k + 1) for k in range(max_degree)])
poly_features /= factorials
# labels has shape (n_train + n_test,): the noise-free polynomial value
# plus zero-mean Gaussian noise with standard deviation 0.1.
labels = poly_features.dot(true_w)
labels += np.random.normal(scale=0.1, size=labels.shape)
# Convert every NumPy array into a float32 torch tensor.
true_w, features, poly_features, labels = (
    torch.tensor(arr, dtype=torch.float32)
    for arr in (true_w, features, poly_features, labels)
)
features[:2], poly_features[:2], labels[:2]
Out[10]:
(tensor([[-1.4528],
[-1.3371]]),
tensor([[ 1.0000e+00, -1.4528e+00, 1.0553e+00, -5.1105e-01, 1.8561e-01,
-5.3932e-02, 1.3059e-02, -2.7102e-03, 4.9217e-04, -7.9447e-05,
1.1542e-05, -1.5244e-06, 1.8455e-07, -2.0624e-08, 2.1402e-09,
-2.0729e-10, 1.8822e-11, -1.6085e-12, 1.2982e-13, -9.9265e-15],
[ 1.0000e+00, -1.3371e+00, 8.9385e-01, -3.9838e-01, 1.3316e-01,
-3.5609e-02, 7.9352e-03, -1.5157e-03, 2.5332e-04, -3.7633e-05,
5.0317e-06, -6.1161e-07, 6.8146e-08, -7.0088e-09, 6.6937e-10,
-5.9665e-11, 4.9860e-12, -3.9215e-13, 2.9129e-14, -2.0498e-15]]),
tensor([-3.0809, -1.8968]))
def evaluate_loss(net, data_iter, loss):  #@save
    """Return the average per-element loss of ``net`` over ``data_iter``.

    Args:
        net: callable mapping a feature batch ``X`` to predictions.
        data_iter: iterable of ``(X, y)`` batches; ``y`` is reshaped to the
            shape of ``net(X)`` before the loss is applied.
        loss: loss callable ``loss(prediction, target)``. Both an unreduced
            loss (``reduction='none'``, one value per element) and a
            mean-reduced loss (e.g. the default ``nn.MSELoss()``) are handled;
            a ``reduction='sum'`` loss is not supported.

    Returns:
        float: total loss divided by the total number of target elements.

    Raises:
        ZeroDivisionError: if ``data_iter`` yields no batches.
    """
    total_loss, total_count = 0.0, 0
    # Evaluation only: no need to build an autograd graph.
    with torch.no_grad():
        for X, y in data_iter:
            out = net(X)
            y = y.reshape(out.shape)
            l = loss(out, y)
            n = y.numel()
            # Weight each batch by its size. For an unreduced loss
            # l.mean() * n == l.sum(); for a mean-reduced loss it recovers the
            # batch total, so a short final batch no longer counts as much as
            # a full one.
            total_loss += float(l.mean()) * n
            total_count += n
    return total_loss / total_count
def train(train_features, test_features, train_labels, test_labels,
          num_epochs=400, *, lr=0.01, batch_size=10):
    """Fit a bias-free linear model with SGD and plot train/test loss.

    Trains ``nn.Linear(num_inputs, 1, bias=False)`` on mean-squared error,
    adds the train and test losses to a log-scale plot after the first epoch
    and then every 20 epochs, and finally prints the learned weights.

    Args:
        train_features: training inputs; the size of the last dimension is
            used as the model's input width (presumably a 2-D
            ``(n_samples, n_features)`` tensor -- confirm at call sites).
        test_features: held-out inputs with the same number of columns.
        train_labels: training targets; reshaped to a column ``(-1, 1)``.
        test_labels: held-out targets; reshaped to a column ``(-1, 1)``.
        num_epochs: number of passes over the training data.
        lr: SGD learning rate (default 0.01, the previously fixed value).
        batch_size: minibatch size, capped at the number of training labels
            (default 10, the previously fixed value).

    Returns:
        The trained ``nn.Sequential`` network.
    """
    loss = nn.MSELoss()
    num_inputs = train_features.shape[-1]
    # No bias term: it is already provided by the polynomial features.
    net = nn.Sequential(nn.Linear(num_inputs, 1, bias=False))
    # Never request a batch larger than the training set.
    batch_size = min(batch_size, train_labels.shape[0])
    train_iter = d2l.load_array((train_features, train_labels.reshape(-1, 1)),
                                batch_size)
    test_iter = d2l.load_array((test_features, test_labels.reshape(-1, 1)),
                               batch_size, is_train=False)
    trainer = torch.optim.SGD(net.parameters(), lr=lr)
    animator = d2l.Animator(xlabel='epoch', ylabel='loss', yscale='log',
                            xlim=[1, num_epochs], ylim=[1e-3, 1e2],
                            legend=['train', 'test'])
    for epoch in range(num_epochs):
        d2l.train_epoch_ch3(net, train_iter, loss, trainer)
        # Plot after the first epoch, then every 20 epochs.
        if epoch == 0 or (epoch + 1) % 20 == 0:
            animator.add(epoch + 1, (evaluate_loss(net, train_iter, loss),
                                     evaluate_loss(net, test_iter, loss)))
    print('weight:', net[0].weight.detach().numpy())
    return net
# Fit using only the first four polynomial columns (degrees 0-3), which are
# exactly the columns whose true coefficients are non-zero.
train(train_features=poly_features[:n_train, :4],
      test_features=poly_features[n_train:, :4],
      train_labels=labels[:n_train],
      test_labels=labels[n_train:])