大模型-ChatGLM-6B模型部署与微调记录
大模型-ChatGLM-6B模型部署与微调记录
模型权重下载:
登录魔塔社区:https://modelscope.cn/models/ZhipuAI/chatglm2-6b
拷贝以下代码执行后,便可快速权重下载到本地
# 备注:最新模型版本要求modelscope >= 1.9.0
# pip install modelscope -U
from modelscope.utils.constant import Tasks
from modelscope import Model
from modelscope.pipelines import pipeline
model = Model.from_pretrained('ZhipuAI/chatglm2-6b', device_map='auto', revision='v1.0.12')
pipe = pipeline(task=Tasks.chat, model=model)
inputs = {'text':'你好', 'history': []}
result = pipe(inputs)
inputs = {'text':'介绍下清华大学', 'history': result['history']}
result = pipe(inputs)
print(result)
运行微调除 ChatGLM2-6B 的依赖之外,还需要安装以下依赖:
pip install rouge_chinese nltk jieba datasets
下载数据集
ADGEN 数据集任务为根据输入(content)生成一段广告词(summary)。
{
"content": "类型#上衣*版型#宽松*版型#显瘦*图案#线条*衣样式#衬衫*衣袖型#泡泡袖*衣款式#抽绳",
"summary": "这件衬衫的款式非常的宽松,利落的线条可以很好的隐藏身材上的小缺点,穿在身上有着很好的显瘦效果。领口装饰了一个可爱的抽绳,漂亮的绳结展现出了十足的个性,配合时尚的泡泡袖型,尽显女性甜美可爱的气息。"
}
参数解释:
PRE_SEQ_LEN=128
LR=2e-2
NUM_GPUS=2
torchrun --standalone --nnodes=1 --nproc-per-node=$NUM_GPUS main.py \
--do_train \
--train_file /home/data/project/GOOGOSOFT/LLM/ChatGLM2-6B-main/AdvertiseGen/train.json \
--validation_file /home/data/project/GOOGOSOFT/LLM/ChatGLM2-6B-main/AdvertiseGen/dev.json \
--preprocessing_num_workers 10 \
--prompt_column content \
--response_column summary \
--overwrite_cache \
--model_name_or_path /home/data/project/GOOGOSOFT/LLM/ChatGLM2-6B-main/ZhipuAI/chatglm2-6b \
--output_dir output/adgen-chatglm2-6b-pt-$PRE_SEQ_LEN-$LR \
--overwrite_output_dir \
--max_source_length 128 \
--max_target_length 256 \
--per_device_train_batch_size 20 \
--per_device_eval_batch_size 20 \
--gradient_accumulation_steps 16 \
--predict_with_generate \
--max_steps 6000 \
--logging_steps 10 \
--save_steps 1000 \
--learning_rate $LR \
--pre_seq_len $PRE_SEQ_LEN \
--quantization_bit 4
以下是一个 Python 脚本,用于计算 AdvertiseGen 数据集中 content 列的最大长度。此脚本假设数据集是 JSON 格式,文件路径为 AdvertiseGen/train.json。
脚本:计算最大 max_source_length
import json
# 数据集文件路径
train_file = "AdvertiseGen/train.json"
# 加载数据集
def load_data(file_path):
with open(file_path, "r", encoding="utf-8") as f:
data = json.load(f)
return data
# 计算最大输入长度
def calculate_max_source_length(data, column_name="content"):
lengths = [len(item[column_name]) for item in data if column_name in item]
max_length = max(lengths)
print(f"最大输入长度 (max_source_length): {max_length}")
return max_length
# 主函数
if __name__ == "__main__":
# 加载数据
data = load_data(train_file)
# 计算最大长度
max_source_length = calculate_max_source_length(data, column_name="content")
训练:
原文地址:https://blog.csdn.net/guoqingru0311/article/details/144745598
免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!