动手学深度学习——权重衰退的简洁实现代码

import torch
import torch.nn as nn
import numpy as np
import sys
sys.path.append("..")
import d2lzh_pytorch as d2l
import sys
from matplotlib import pyplot as plt

n_train, n_test, num_inputs = 20, 100, 200
#训练数据集越小,越容易过拟合。训练数据集为20,测试数据集为100,特征的纬度选择200.
#数据越小,模型越简单,过拟合越容易发生

true_w, true_b = torch.ones(num_inputs, 1) * 0.01, 0.05
#真实的权重就是0.01*全1的一个向量,偏差b为0.05

"""
读取一个人工的数据集
"""
features = torch.randn((n_train + n_test, num_inputs)) #特征
labels = torch.matmul(features, true_w) + true_b #样本数
labels += torch.tensor(np.random.normal(0, 0.01,size=labels.size()), dtype=torch.float)
train_features, test_features = features[:n_train, :],features[n_train:, :]
train_labels, test_labels = labels[:n_train], labels[n_train:]


"""
定义训练和测试模型
"""
batch_size, num_epochs, lr = 1, 100, 0.003
net, loss = d2l.linreg, d2l.squared_loss
dataset = torch.utils.data.TensorDataset(train_features,train_labels)
train_iter = torch.utils.data.DataLoader(dataset, batch_size,shuffle=True)

def fit_and_plot_pytorch(wd):
    #对权重参数衰减,权重名称一般是以weight结尾
    net=nn.Linear(num_inputs,1)
    nn.init.normal_(net.weight,mean=0,std=1)
    nn.init.normal_(net.bias,mean=0,std=1)
    optimizer_w=torch.optim.SGD(params=[net.weight],lr=lr,weight_decay=wd)
    #对权重参数衰减
    optimizer_b=torch.optim.SGD(params=[net.bias],lr=lr)
    #不对偏差参数衰减
    train_ls,test_ls=[],[]
    for _ in range(num_epochs):
        for X,y in train_iter:
            l = loss(net(X), y).mean()
            optimizer_w.zero_grad()
            optimizer_b.zero_grad()

            l.backward()
            optimizer_w.step()
            optimizer_b.step()
            # 对两个optimizer实例分别调⽤step函数,从⽽分别更新权᯿和偏差
        train_ls.append(loss(net(train_features),
                             train_labels).mean().item())
        test_ls.append(loss(net(test_features),
                            test_labels).mean().item())
    d2l.semilogy(range(1, num_epochs + 1), train_ls, 'epochs', 'loss',
                 range(1, num_epochs + 1), test_ls, ['train', 'test'])
    print('L2 norm of w:', net.weight.data.norm().item())

fit_and_plot_pytorch(0)
plt.show()

fit_and_plot_pytorch(3)
plt.show()

动手学深度学习——权重衰退的简洁实现代码_第1张图片

动手学深度学习——权重衰退的简洁实现代码_第2张图片

动手学深度学习——权重衰退的简洁实现代码_第3张图片

 

 

 

你可能感兴趣的:(动手学深度学习,深度学习,人工智能,python,算法,pytorch)