K mid-term assessment: the key to reducing GPU memory usage

Loading model parameters without occupying GPU memory

see_memory_usage('message')

```python
# 4. Load the checkpoint parameters
state_dict = torch.load('../train-output/' + args.model_name_or_path.split('/')[-1] + '/unet/diffusion_pytorch_model.bin', map_location='cpu')
# state_dict = torch.load('../train-output/' + args.model_name_or_path.split('/')[-1] + '/unet/diffusion_pytorch_model.bin')
```

Without map_location='cpu', torch.load restores each tensor to the device it was saved from, so the loaded parameters end up occupying GPU memory.

see_memory_usage('message')
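
To make the effect concrete, here is a rough, self-contained sketch (a small nn.Linear stands in for the UNet, and the checkpoint filename is made up); it relies on the see_memory_usage helper defined below:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the UNet: save a checkpoint from the GPU, then free the model.
model = nn.Linear(1024, 1024).cuda()
torch.save(model.state_dict(), 'checkpoint.bin')  # made-up filename
del model
torch.cuda.empty_cache()

see_memory_usage('before load')

# map_location='cpu' forces every tensor in the checkpoint into host RAM;
# without it, torch.load would restore the tensors straight onto the GPU
# they were saved from.
state_dict = torch.load('checkpoint.bin', map_location='cpu')

see_memory_usage('after load: GPU allocation (MA) is unchanged')
```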

You can monitor GPU memory usage by adding the following function:

```python
import gc

import psutil
import torch


def see_memory_usage(message):
    # Python doesn't do real-time garbage collection, so trigger it
    # explicitly to get accurate memory reports.
    gc.collect()

    # Print the message (in distributed training this should only run on rank 0).
    print(message)
    print(f"MA {round(torch.cuda.memory_allocated() / (1024 * 1024 * 1024), 2)} GB \
        Max_MA {round(torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024), 2)} GB \
        CA {round(torch.cuda.memory_reserved() / (1024 * 1024 * 1024), 2)} GB \
        Max_CA {round(torch.cuda.max_memory_reserved() / (1024 * 1024 * 1024), 2)} GB ")

    vm_stats = psutil.virtual_memory()
    used_GB = round((vm_stats.total - vm_stats.available) / (1024**3), 2)
    print(f'CPU Virtual Memory:  used = {used_GB} GB, percent = {vm_stats.percent}%')

    # Reset the peak-memory counters so the next call reports fresh peaks.
    if hasattr(torch.cuda, 'reset_peak_memory_stats'):
        torch.cuda.reset_peak_memory_stats()
```
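
For example (the tensor size below is arbitrary, chosen only to make the numbers easy to see):

```python
import torch

see_memory_usage('baseline')

# Allocate ~1 GiB on the GPU: 256 * 1024 * 1024 float32 elements * 4 bytes.
x = torch.zeros(256, 1024, 1024, device='cuda')
see_memory_usage('after allocating x')  # MA and Max_MA grow by ~1 GB

del x
torch.cuda.empty_cache()
see_memory_usage('after freeing x')  # MA drops; Max_MA counts only since the previous call
```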
