Machine Learning Practice Guide (5): GD/SGD/MSGD in Pseudocode

GD: Gradient Descent

while True:
    loss = f(params)                    # loss over the entire training set
    d_loss_wrt_params = ...             # compute gradient of the loss w.r.t. params
    params -= eta * d_loss_wrt_params   # eta is the learning rate
    if <stopping condition is met>:
        return params
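To make the loop concrete, here is a minimal runnable sketch in NumPy. The least-squares loss, the toy data, the learning rate eta = 0.1, and the gradient-norm stopping rule are all illustrative assumptions, not part of the pseudocode above.

import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 3))                    # toy inputs (assumed)
y = X @ np.array([1.0, -2.0, 0.5])               # targets from a known linear model

params = np.zeros(3)
eta = 0.1
for step in range(1000):
    grad = 2 * X.T @ (X @ params - y) / len(y)   # d_loss_wrt_params for mean squared error
    params -= eta * grad
    if np.linalg.norm(grad) < 1e-6:              # one possible stopping condition
        break
print(params)                                    # should approach [1.0, -2.0, 0.5]

Note that every update touches the full training set, which is exactly what SGD below avoids.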

SGD: Stochastic Gradient Descent

Train on one sample at a time:

for x_i, y_i in training_data:
    loss = f(params, x_i, y_i)          # loss on a single example
    d_loss_wrt_params = ...             # gradient estimated from that one example
    params -= eta * d_loss_wrt_params
    if <stopping condition is met>:
        return params

Going one step further, wrap this in an outer loop over epochs, reshuffling the data on each pass:

for j in range(epochs):
    random.shuffle(training_data)       # visit the samples in a new order each epoch
    for x_i, y_i in training_data:
        ...                             # same per-sample update as above
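A runnable sketch of this epoch loop, reusing the toy least-squares setup from the GD example above; the sample count, eta = 0.05, and the fixed 20 epochs are arbitrary choices for illustration.

import random
import numpy as np

rng = np.random.default_rng(0)
true_w = np.array([1.0, -2.0, 0.5])
training_data = [(x, x @ true_w) for x in rng.normal(size=(100, 3))]

params = np.zeros(3)
eta = 0.05
for epoch in range(20):
    random.shuffle(training_data)                # new visiting order each epoch
    for x_i, y_i in training_data:
        grad = 2 * (x_i @ params - y_i) * x_i    # gradient from one sample only
        params -= eta * grad
print(params)                                    # should approach true_w

Each update is cheap but noisy; shuffling keeps the noise from correlating with the data order.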

MSGD (Minibatch SGD): Minibatch Stochastic Gradient Descent

n = len(training_data)
mini_batch_size = ...                   # a hyperparameter, e.g. 32
mini_batches = [training_data[k:k+mini_batch_size]
                for k in range(0, n, mini_batch_size)]
for mini_batch in mini_batches:
    loss = f(params, mini_batch)        # average loss over the minibatch
    d_loss_wrt_params = ...             # gradient of that average
    params -= eta * d_loss_wrt_params
    if <stopping condition is met>:
        return params
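And a runnable minibatch version under the same toy assumptions; mini_batch_size = 10, eta = 0.1, and the 50 epochs are illustrative, and the index permutation stands in for random.shuffle.

import numpy as np

rng = np.random.default_rng(0)
true_w = np.array([1.0, -2.0, 0.5])
X = rng.normal(size=(100, 3))
y = X @ true_w

params = np.zeros(3)
eta = 0.1
mini_batch_size = 10
n = len(X)
for epoch in range(50):
    perm = rng.permutation(n)                           # shuffle indices each epoch
    for k in range(0, n, mini_batch_size):
        idx = perm[k:k + mini_batch_size]
        Xb, yb = X[idx], y[idx]
        grad = 2 * Xb.T @ (Xb @ params - yb) / len(yb)  # average gradient over the batch
        params -= eta * grad
print(params)                                           # should approach true_w

Minibatches trade off between the two extremes: each update averages out some of SGD's noise while still being far cheaper than a full GD pass, and the batched matrix operations vectorize well.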
