训练大模型时显存占用影响因素总结(以starcoderplus 15.5B为例)

背景:实验室计算资源紧张,单卡最大显存是RTX 3090 24G, 但是又要用大模型(指参数量达到10B)做实验。

模型文件获取(大约60B) https://huggingface.co/bigcode/starcoderplus/tree/main

角度一: 

看模型是否提供了降低参数精度的加载方式,如

self.model = AutoModelForCausalLM.from_pretrained(checkpoint,device_map='auto',load_in_8bit=True)

角度二:

针对transformers提供的pipeline

目前已经探索的参数有

(1) max_new_tokens

我将max_new_tokens设置成max_new_tokens=len(input_prompt)*0.2+50,因为输入字符串的长度和编码后的token sequence比值大约是5:1,所以乘上0.2,加的50是给输出预留长度。显存占用会同时受输入和预计输出长度的影响。

你可能感兴趣的:(提示工程,大模型,提示工程,starcoder)