"""Deep learning demo 1: linear regression from scratch with mini-batch SGD."""

import numpy as np
import torch
import random


def data_iter(features: torch.Tensor, labels: torch.Tensor, batch_size: int):
    """Yield shuffled mini-batches of ``(features, labels)``.

    The sample order is reshuffled with :func:`random.shuffle` on every call.
    Rows are selected along dimension 0 of both tensors with the same index
    tensor, so ``features[k]`` stays paired with ``labels[k]``. The final
    batch is smaller than ``batch_size`` when the number of samples is not a
    multiple of it.

    Args:
        features: Tensor whose first dimension indexes samples.
        labels: Tensor whose first dimension indexes samples.
        batch_size: Maximum number of samples per batch; must be positive.

    Yields:
        ``(features_batch, labels_batch)`` tuples.

    Raises:
        ValueError: If ``features`` and ``labels`` have a different number of
            samples, or ``batch_size`` is not positive. Because this is a
            generator, the error surfaces on the first ``next()`` call.
    """
    num_examples = len(labels)
    if len(features) != num_examples:
        # Fail loudly instead of yielding nothing, which would hide data bugs.
        raise ValueError(
            f"features and labels must have the same number of samples, "
            f"got {len(features)} and {num_examples}"
        )
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    indices = list(range(num_examples))
    random.shuffle(indices)
    for start in range(0, num_examples, batch_size):
        # Slicing clamps at the end of the list, so the last batch may be short.
        batch = torch.tensor(indices[start:start + batch_size])
        yield features.index_select(0, batch), labels.index_select(0, batch)


def linereg(x, w, b):
    """Evaluate the linear model ``x @ w + b``.

    ``x`` and ``w`` are multiplied with ``Tensor.mm``, which requires both to
    be 2-D matrices. ``b`` is then added using standard broadcasting.
    """
    scores = x.mm(w)
    return scores + b


def sgd_loss(y_hat, y):
    """Return the element-wise halved squared error ``(y_hat - y) ** 2 / 2``.

    Despite its name this is the squared loss, not an optimiser. ``y`` is
    first viewed to ``y_hat``'s shape, so it must have the same number of
    elements and be view-compatible. No reduction is applied; the result has
    the shape of ``y_hat``.
    """
    residual = y_hat - y.view(y_hat.size())
    return residual.pow(2) / 2


def optimization(params, lr, batch_size):
    """Apply one in-place mini-batch SGD step to every tensor in ``params``.

    Each parameter is updated as ``param -= lr * param.grad / batch_size``.
    When the loss was summed over the batch, dividing by ``batch_size`` makes
    the step follow the gradient of the mean loss. Gradients are not zeroed
    here; the caller must reset them.

    Args:
        params: Iterable of leaf tensors with ``requires_grad=True``.
        lr: Learning rate.
        batch_size: Number of samples the loss was summed over.
    """
    # Update under no_grad rather than through ``.data`` so autograd still
    # tracks the in-place modification (version counter) correctly.
    with torch.no_grad():
        for param in params:
            if param.grad is None:
                # No gradient was computed for this tensor (e.g. it did not
                # take part in the loss); leave it unchanged, as torch.optim does.
                continue
            param -= lr * param.grad / batch_size


def test():
    """Train a linear-regression model on synthetic data with mini-batch SGD.

    Generates ``num_data`` samples from ``y = X @ [2, -3.4] + 3`` plus Gaussian
    noise (std 0.01). It then fits ``w`` and ``b`` using the hand-written
    ``data_iter`` / ``linereg`` / ``sgd_loss`` / ``optimization`` helpers.
    After every epoch it prints the mean loss over the full dataset, and at
    the end it prints the learned parameters next to the true ones.
    """
    num_features = 2
    num_data = 1000
    init_w = [2, -3.4]
    init_b = 3

    # Synthetic data: labels = X @ init_w + init_b + small Gaussian noise.
    features = torch.normal(0, 1, (num_data, num_features))
    labels = features[:, 0] * init_w[0] + features[:, 1] * init_w[1] + init_b
    e = torch.normal(0, 0.01, size=labels.size())
    print(features.size())
    print(labels.size())
    labels += e
    batch_size = 10

    # Show a single shuffled mini-batch as a sanity check.
    for x, y in data_iter(features, labels, batch_size):
        print('x:', x)
        print('y:', y)
        break

    epoch = 4
    lr = 0.01
    model = linereg
    loss_fun = sgd_loss
    w = torch.normal(0, 1, [num_features, 1], requires_grad=True)
    b = torch.ones(1, requires_grad=True)
    for i in range(epoch):
        for x, y in data_iter(features, labels, batch_size):
            y_hat = model(x, w, b)
            loss = loss_fun(y_hat, y).sum()
            loss.backward()
            # Use the real batch length: the last batch is smaller than
            # batch_size whenever num_data is not a multiple of it.
            optimization([w, b], lr, len(y))
            # backward() accumulates into .grad, so reset before the next batch.
            w.grad.zero_()
            b.grad.zero_()
        # Evaluate on the whole dataset once per epoch without building a graph.
        with torch.no_grad():
            train_loss = loss_fun(model(features, w, b), labels).mean()
        print('epoch:', i, 'loss:', train_loss.item())

    print('learned w:', w.detach().view(-1).tolist(), 'true w:', init_w)
    print('learned b:', b.item(), 'true b:', init_b)


if __name__ == '__main__':
    # Run the training demo only when executed as a script, not on import.
    test()

# Related topics: deep learning, linear regression, artificial intelligence.