大模型微调报错:RuntimeError: expected scalar type Half but found Float

微调chatglm 报错RuntimeError: expected scalar type Half but found Float

1. 背景

博主显卡:3090
最初的设置:bfloat16
开始训练后,线性层报错

2. 解决: 统一代码中所有精度

1)将模型和数据精度都设置为torch.float32/torch.float16

xxx = torch.tensor(xxx, dtype=torch.float32)
model.config.torch_dtype = torch.float32

2)将模型参数都设置为torch.float32/torch.float16

for param in model.parameters():
    # Check if parameter dtype is  Float (float32)
    if param.dtype == torch.float16:
        param.data = param.data.to(torch.float32)

你可能感兴趣的:(多模态大模型,MLLM,pytorch,深度学习,人工智能)