One line to auto-adjust batch_size: never hit CUDA out of memory again

Requires Hugging Face's accelerate library.

The core function (a decorator):

accelerate.find_executable_batch_size(function: callable = None, starting_batch_size: int = 128)

The official documentation is very short (and largely opaque to beginners); in practice you only need to change two lines.
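Under the hood the idea is simple: call the function, catch the CUDA OOM, halve batch_size, and retry. Below is a minimal sketch of that behavior (my own simplification for illustration, not the actual accelerate source; the real implementation also handles cuDNN errors):

import functools
import gc

import torch

def find_executable_batch_size_sketch(starting_batch_size=128):
    # Simplified stand-in for accelerate.find_executable_batch_size
    def decorator(function):
        @functools.wraps(function)
        def wrapper(*args, **kwargs):
            batch_size = starting_batch_size
            while batch_size > 0:
                try:
                    # batch_size is injected as the first positional argument
                    return function(batch_size, *args, **kwargs)
                except RuntimeError as e:
                    if 'out of memory' not in str(e):
                        raise  # not an OOM: re-raise unchanged
                    gc.collect()
                    torch.cuda.empty_cache()  # free cached blocks, then retry
                    batch_size //= 2
            raise RuntimeError('No executable batch size found, reached zero.')
        return wrapper
    return decorator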

Traditional code

Take a PyTorch valid_loop function as an example. valid_loop_fn must take batch_size as its first parameter, in the first positional slot; it cannot be passed as a keyword argument.

import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

# Validation loop: run one epoch over the dataset and return the average loss
def valid_loop_fn(batch_size: int, dataset, model, data_collator=None):
    model.eval()
    losses = []
    data_loader = DataLoader(dataset, batch_size=batch_size,
                             shuffle=False, collate_fn=data_collator)

    with torch.no_grad():
        tk0 = tqdm(data_loader, total=len(data_loader), desc='valid')
        for step, batch in enumerate(tk0):
            outputs = model(**batch)
            loss = outputs[0]  # assumes the model returns the loss first
            losses.append(loss.cpu().item())

    # average loss over the epoch
    avg_loss = np.mean(losses)
    return avg_loss

Normally, passing in a large batch_size such as 512 will crash:

valid_losses = valid_loop_fn(512, tokenized_dataset['test'], model, data_collator)

You get RuntimeError: CUDA out of memory, and then have to edit batch_size by hand and rerun, which quickly becomes tedious.
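For reference, the manual workaround is usually a retry loop like the one below (a sketch reusing the names from above): catch the OOM yourself, shrink batch_size, and try again. This is exactly the boilerplate that accelerate automates:

batch_size = 512
while batch_size > 0:
    try:
        valid_losses = valid_loop_fn(batch_size, tokenized_dataset['test'],
                                     model, data_collator)
        break
    except RuntimeError as e:
        if 'out of memory' not in str(e):
            raise  # some other error: don't swallow it
        torch.cuda.empty_cache()  # release cached blocks before retrying
        batch_size //= 2
        print(f'CUDA OOM, retrying with batch_size={batch_size}')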

Improved code

Now all it takes is decorating valid_loop_fn with accelerate.find_executable_batch_size, i.e. adding one line above the function definition: @accelerate.find_executable_batch_size(starting_batch_size=512). Whenever GPU memory runs out, batch_size is automatically halved (512 → 256 → 128 → …) and the function is retried instead of crashing.

import accelerate

# Validation loop: run one epoch over the dataset and return the average loss
@accelerate.find_executable_batch_size(starting_batch_size=512)
def valid_loop_fn(batch_size: int, dataset, model, data_collator=None):
    model.eval()
    losses = []
    data_loader = DataLoader(dataset, batch_size=batch_size,
                             shuffle=False, collate_fn=data_collator)

    with torch.no_grad():
        tk0 = tqdm(data_loader, total=len(data_loader), desc='valid')
        for step, batch in enumerate(tk0):
            outputs = model(**batch)
            loss = outputs[0]
            losses.append(loss.cpu().item())

    # average loss over the epoch
    avg_loss = np.mean(losses)
    return avg_loss

The other change is at the call site: you no longer pass batch_size when calling valid_loop_fn; in fact, passing it raises an error, because the decorator injects batch_size for you. You only supply the remaining arguments:

valid_losses = valid_loop_fn(tokenized_dataset['test'], model, data_collator)
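One detail worth knowing: on every OOM the decorator re-executes the decorated function from the top with the halved batch_size, so state created inside it (like the DataLoader here) is rebuilt cleanly on each attempt. If you want to see which batch_size finally fit, a simple option (purely illustrative, not an accelerate feature) is to log it at the top of the function:

@accelerate.find_executable_batch_size(starting_batch_size=512)
def valid_loop_fn(batch_size: int, dataset, model, data_collator=None):
    print(f'trying validation with batch_size={batch_size}')
    ...  # body unchanged from above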

With this tiny change, an oversized batch_size will never crash your run again.
(*^▽^*)
