使用pytorch框架能够极大的减少线性回归的代码量,尽管如此,我们也应该学会线性回归的具体实现过程。
获取数据集时任何模型训练开始的基础,这里的数据集是我们自己生成的,y=wx+b+噪声,生成y和x。
# 生成数据集 """⽣成 y = Xw + b + 噪声。"""
def synthetic_data(w, b, num_examples):
X = torch.normal(0, 1, (num_examples, len(w)))
y = torch.matmul(X, w) + b
y += torch.normal(0, 0.01, y.shape)
return X, y.reshape((-1, 1))
初始化w,b,然后调用生成数据集函数,查看生成的数据集。
# 初始化w,和b
true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, 1000)
print(features[0], '\n', labels[0])
print(len(features))
调用d2l包里面的函数,这里直接调用。
# 调用画图函数
d2l.set_figsize()
d2l.plt.scatter(features[:, (1)].detach().numpy(), labels.detach().numpy(), 1)
plt.show()
生成数据集后,读取数据集也是很重要的,这里使用小批量读取数据集的方法。
# 读取数据集
def data_iter(batch_size, features, labels):
num_examples = len(features)
indices = list(range(num_examples))
random.shuffle(indices) # random.shuffle可以打乱数组,列表的顺序,但对张量无效
for i in range(0, num_examples, batch_size):
batch_indices = torch.tensor(indices[i:min(i + batch_size,
num_examples)]) # 一般来说可以直接i到i+batch,这里是考虑到num_examples比batch小的情况 初始化张量是可以是列表,数组
yield features[batch_indices], labels[batch_indices]
batch_size = 10
for X, y in data_iter(batch_size, features, labels):
print(X, '\n', y)
这里的batch_size不仅设置了小批量读取数据集的大小,也是后面小批量梯度下降的大小。
w是正态分布,b初始化为0
w = torch.normal(0, 0.01, size=(2, 1), requires_grad=True)
b = torch.zeros(1, requires_grad=True)
模型为y=wx+b
def linreg(X, w, b):
return torch.matmul(X, w) + b
def squared_loss(y_hat, y):
return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2
小批量梯度下降算法,随机抽取一部分数据更新梯度,相当于摸着石头过河的时候选择一片区域看看深度。
def sgd(params, lr, batch_size):
with torch.no_grad():
for param in params:
param -= lr * param.grad / batch_size
param.grad.zero_()
# 训练
lr = 0.03
num_epochs = 3
net = linreg
loss = squared_loss
for epochs in range(num_epochs):
for X, y in data_iter(batch_size, features, labels): # X=feature 小批量 y=labels小批量
l = loss(net(X, w, b), y)
l.sum().backward()
sgd([w, b], lr, batch_size)
with torch.no_grad():
train_l = loss(net(features, w, b), labels)
print("epoch:", epochs + 1, '\n', 'loss:', float(train_l.mean()))
print(true_w-w.reshape(true_w.shape))
print(true_b-b)