ChatGLM2-6B在Windows下的微调

ChatGLM2-6B在Windows下的微调

零、重要参考资料

1、ChatGLM2-6B! 我跑通啦!本地部署+微调(windows系统):这是最关键的一篇文章,提供了Windows下的脚本
2、LangChain + ChatGLM2-6B 搭建个人专属知识库:提供了基本的训练思路。

一、前提

1、已完成ChatGLM2-6B的部署,假设部署位置为D:_ChatGPT\langchain-chatglm_test\ChatGLM2-6B
2、部署环境
Windows 10 专业版、已安装CUDA11.3、Anaconda3,有显卡NVIDIA GeForce RTX 3060 Laptop GPU。

二、总体思路

由于官方文档和一般博客中都是在Linux环境下完成,所以在Windows下主要注意两点:
1、huggingface下载的chatglm2-6b模型的目录不能有减号存在,否则报错。
2、使用bat文件替代官方文档中的sh文件。

三、安装依赖及环境准备

1、进入Anaconda Powershell Prompt

2、进入虚拟环境

conda activate langchain-chatglm_test

3、安装依赖

pip install rouge_chinese nltk jieba datasets

4、禁用W&B,如果不禁用可能会中断微调训练,以防万一

setx WANDB_DISABLED true

四、准备数据集

1、在ChatGLM2-6B的ptuning目录下创建train.json 和 dev.json这两个文件,文件中的数据如下:

{"content": "你好,你是谁", "summary": "你好,我是树先生的助手小6。"}
{"content": "你是谁", "summary": "你好,我是树先生的助手小6。"}
{"content": "树先生是谁", "summary": "树先生是一个程序员,热衷于用技术探索商业价值,持续努力为粉丝带来价值输出,运营公众号《程序员树先生》。"}
{"content": "介绍下树先生", "summary": "树先生是一个程序员,热衷于用技术探索商业价值,持续努力为粉丝带来价值输出,运营公众号《程序员树先生》。"}
{"content": "树先生", "summary": "树先生是一个程序员,热衷于用技术探索商业价值,持续努力为粉丝带来价值输出,运营公众号《程序员树先生》。"}

2、这里为了简化,只准备了5条测试数据,实际使用的时候肯定需要大量的训练数据。如下为train.json和dev.json的

五、创建训练和推理脚本

1、ChatGLM2-6B默认只提供了Linux下训练和推理使用的train.sh和evaluate.sh脚本,没有提供WIndows下的脚本,因此需要自己创建脚本。
2、在ptuning目录下创建train.bat脚本,文件内容如下:

set PRE_SEQ_LEN=128
set LR=2e-2
set NUM_GPUS=1

python main.py ^
    --do_train ^
    --train_file train.json ^
    --validation_file dev.json ^
    --preprocessing_num_workers 10 ^
    --prompt_column content ^
    --response_column summary ^
    --overwrite_cache ^
    --model_name_or_path D:\_ChatGPT\_common\chatglm2_6b ^
    --output_dir output/adgen-chatglm2-6b-pt-%PRE_SEQ_LEN%-%LR% ^
    --overwrite_output_dir ^
    --max_source_length 128 ^
    --max_target_length 128 ^
    --per_device_train_batch_size 1 ^
    --per_device_eval_batch_size 1 ^
    --gradient_accumulation_steps 16 ^
    --predict_with_generate ^
    --max_steps 3000 ^
    --logging_steps 10 ^
    --save_steps 1000 ^
    --learning_rate %LR% ^
    --pre_seq_len %PRE_SEQ_LEN% ^
    --quantization_bit 4

注意model_name_or_path后跟的是实际的从huggingface下载的chatglm2-6b模型文件的位置,这个路径里不能有减号存在。
train.json、dev.json这里放的是两个文件的实际位置,可以根据需要修改。

3、在ptuning目录下创建evaluate.bat脚本,文件内容如下:

set PRE_SEQ_LEN=128
set CHECKPOINT=adgen-chatglm2-6b-pt-128-2e-2
set STEP=3000
set NUM_GPUS=1

python main.py ^
    --do_predict ^
    --validation_file dev.json ^
    --test_file dev.json ^
    --overwrite_cache ^
    --prompt_column content ^
    --response_column summary ^
    --model_name_or_path D:\_ChatGPT\_common\chatglm2_6b ^
    --ptuning_checkpoint ./output/%CHECKPOINT%/checkpoint-%STEP% ^
    --output_dir ./output/%CHECKPOINT% ^
    --overwrite_output_dir ^
    --max_source_length 128 ^
    --max_target_length 128 ^
    --per_device_eval_batch_size 1 ^
    --predict_with_generate ^
    --pre_seq_len %PRE_SEQ_LEN% ^
    --quantization_bit 4

六、训练和推理

1、进入Anaconda Powershell Prompt

2、进入虚拟环境

conda activate langchain-chatglm_test

3、进入ptuning目录

cd D:\_ChatGPT\langchain-chatglm_test\ChatGLM2-6B\ptuning

4、训练:训练需要比较长的时间,大概几个小时。

.\train.bat

5、推理:由于数量小,所以推理比较快

.\evaluate.bat

执行完成后,会生成评测文件,评测指标为中文 Rouge score 和 BLEU-4。生成的结果保存在 ./output/adgen-chatglm2-6b-pt-32-2e-2/generated_predictions.txt。我们准备了 5 条推理数据,所以相应的在文件中会有 5 条评测数据,labels 是 dev.json 中的预测输出,predict 是 ChatGLM2-6B 生成的结果,对比预测输出和生成结果,评测模型训练的好坏。如果不满意调整训练的参数再次进行训练。

七、创建脚本,部署微调后的模型

1、本来在Linux下可以修改ptuning目录下的web_demo.sh脚本即可实现部署,在Windows下需要在ptuning目录下自行创建web_demo.bat脚本,内容如下:

python web_demo.py ^
    --model_name_or_path D:\_ChatGPT\_common\chatglm2_6b ^
    --ptuning_checkpoint output\adgen-chatglm2-6b-pt-128-2e-2\checkpoint-3000 ^
	--pre_seq_len 128

2、修改ptuning目录下的web_demo.py脚本,使模型能被本地访问:

demo.queue().launch(share=False, inbrowser=True, server_name='0.0.0.0', server_port=7860)

八、启动应用

1、进入Anaconda Powershell Prompt

2、进入虚拟环境

conda activate langchain-chatglm_test

3、进入ptuning目录

cd D:\_ChatGPT\langchain-chatglm_test\ChatGLM2-6B\ptuning

4、启动微调后的模型(注意启动前关闭fanqiang软件cd)

.\web_demo.bat

5、这时问他你训练过的问题,发觉已经使用的是微调后的模型了。
ChatGLM2-6B在Windows下的微调_第1张图片

你可能感兴趣的:(大模型,windows,语言模型,chatgpt)