打印模型参数量

打印模型参数量

total_params = sum(p.numel() for p in model.parameters())
print('参数量为:')
print(f'{total_params:,} total parameters.')

添加位置为模型训练模块下,需要注意的部分是,model.parameters()中的model为训练模型如下:

def Train(model, config, train_loader, test_loader):
    start_time = time.time()
    model.train()
    # 打印模型参数量
    total_params = sum(p.numel() for p in model.parameters())
    print('参数量为:')
    print(f'{total_params:,} total parameters.')

你可能感兴趣的:(自然语言处理,python,深度学习)