pytorch backward 求梯度 累计 样式

pytorch backwad 函数计算梯度是 累计式的

关于 pytorch 的 backward()函数反向传播计算梯度是 累计式的,见下图(主要是图中用黑框框出来的部分内容)。
因为这样,所以才需要 optimizer.zero_grad()
pytorch backward 求梯度 累计 样式_第1张图片

利用 gradient accumulation 的框架

一般的优化框架是:

# loop through batches
for (inputs, labels) in data_loader:

    # extract inputs and labels
    inputs = inputs.to(device)
    labels = labels.to(device)

    # passes and weights update
    with torch.set_grad_enabled(True):
        
        # forward pass 
        preds = model(inputs)
        loss  = criterion(preds, labels)

        # backward pass
        loss.backward() 

        # weights update
        optimizer.step()
        optimizer.zero_grad()

利用 gradient accumulation 的框架是这样的。
为什么需要 gradient accumulation 呢?
因为可能会出现 训练集 batch size 比较大,电脑 显存吃不下的情况,这样就需要将一个 batch 分为几个小 batch训练,但是同时又都利用它们的gradient 信息。

# batch accumulation parameter
accum_iter = 4  

# loop through enumaretad batches
for batch_idx, (inputs, labels) in enumerate(data_loader):

    # extract inputs and labels
    inputs = inputs.to(device)
    labels = labels.to(device)

    # passes and weights update
    with torch.set_grad_enabled(True):
        
        # forward pass 
        preds = model(inputs)
        loss  = criterion(preds, labels)

        # normalize loss to account for batch accumulation
        loss = loss / accum_iter 

        # backward pass
        loss.backward()

        # weights update
        if ((batch_idx + 1) % accum_iter == 0) or (batch_idx + 1 == len(data_loader)):
            optimizer.step()
            optimizer.zero_grad()

你可能感兴趣的:(PyTorch,pytorch,python,深度学习)