PyTorch - 0. regression 线性回归

Linear Regression

linear regression就是线性回归问题。最基本的例子就是y=mx+b。

在machine learning里面,最简单的就是调包。scikit-learn包你满意。

X, Y = pd.read_csv('my-dataset.csv') + some preprocessing # 把训练数据拿出来
model = LinearRegression() # 建立一个线性回归模型
model.fit(X, Y) # 把X和Y放到模型里组成训练数据
Y-hat = model.predict(X) # 开始训练,用X去预测Y-hat, 然后比较真值Y和预测值Y-hat

和Keras或者scikit-learn不同的是,PyTorch里面没有定义好的model,也没有定义好的fit或者predict这种函数。所以要理解每一步在做什么,才能用PyTorch去建立一个模型。

首先我们有一些数据x和y,Data = {(x_1, y_1), (x_2, y_2), ... (x_N, y_N)}, N是数据的总数量。我们要找到一个线性关系来最好的对应所有的x和y,就是y=mx+b这条线。但我们要知道其实这条线是不一定存在的,或者是一定不存在的。因为最后找到的这条线不会完美的经过所有的数据点,只可能尽量地接近所有点。

所以就有了loss function,用MSE(mean squared error)来计算每一个Y-hat和Y之间的差距:

MSE = \frac{1}{N}\sum_{i=1}^{N}(y_i -\widehat{y_i})^2

 Y-hat就是mx+b, 所以就有了:

MSE = \frac{1}{N}\sum_{i=1}^{N}(y_i -(mx_i + b))^2

在这个公式里我们其实是在计算m和b,就是变量,公式表示就是:

m^*, b^* = arg\underset{m, b}{min}L

(公式不重要)

如果解linear regression里面的两 个变量,就用partial derivatives. \partial L/\partial m = 0, \partial L/\partial b = 0,

但别的模型就不行了,比如classification model。所以就要用到gradient descent。就是在一个loop里每次向gradient的方向移动一点点,然后就会走到那个最底点。

PyTorch

1. 联系模型: 

我们构建一个输入和一个输出的Linear model。用上面的例子就是一个x和一个y

# Create the linear regression model
model = nn.Linear(1, 1)

2. 训练模型:

这里和Keras/Scikit-learn不一样的地方就是我们要自己计算loss和定义optimizer.

这里还要注意,我们每次调用backward()的时候,PyTorch会积累gradients,所以每次要清零:optimizer.zero_grad().

# Loss and optimizer
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Train the model
n_epochs = 30
for i in range(n_epochs):
    # zero the parameter gradients
    optimizer.zero_grad()

    # forward pass
    outputs = model(inputs)
    loss = criterion(outputs, targets)

    # backward and optimize
    loss.backward()
    optimizer.step()

这里的输入也要注意。PyTorch的输入都是Torch Tensors,而不是Numpy arrays作为输入。

Array到Tensor的转换

# (样本数量 * 维度) 
X = X.reshape(N, 1) 
Y = Y.reshape(N, 1) 

# PyTorch 默认是float32
# Numpy 默认是float64
inputs = torch.from_numpy(X.astype(np.float32))
targets = torch.from_numpy(Y.astype(np.float32))

预测

predictions = model(inputs).detach().numpy()

你可能感兴趣的:(线性回归,机器学习)