DB-GPT-HUB Text-to-SQL微调
项目介绍
DB-GPT-Hub是一个利用LLMs实现Text-to-SQL解析的实验项目,主要包含数据集收集、数据预处理、模型选择与构建和微调权重等步骤,通过这一系列的处理可以在提高Text-to-SQL能力的同时降低模型训练成本,让更多的开发者参与到Text-to-SQL的准确度提升工作当中,最终实现基于数据库的自动问答能力,让用户可以通过自然语言描述完成复杂数据库的查询操作等工作。
本次微调使用的基座模型是Qwen-14B-Chat。
spider数据集,包含训练数据8659条,测试数据1034条。
安装Python3.10
本人在windows上和linux都安装了各个版本的python,参考我的这篇文章,也可以使用构建Docker镜像的方式。
训练
将数据解压到dbgpt_hub/data
目录下,即dbgpt_hub/data/spider
生成数据
sh dbgpt_hub/scripts/gen_train_eval_data.sh
在单卡A6000训练,耗时11小时41分钟。
***** train metrics *****
epoch = 8.0
train_loss = 0.0281
train_runtime = 11:41:19.00
train_samples_per_second = 1.646
train_steps_per_second = 0.103
训练参数设置
CUDA_VISIBLE_DEVICES=1 python dbgpt_hub/train/sft_train.py \
--model_name_or_path /soft/Qwen-14B-Chat/ \
--do_train \
--dataset example_text2sql_train \
--max_source_length 2048 \
--max_target_length 512 \
--finetuning_type lora \
--lora_target c_attn \
--template chatml \
--lora_rank 64 \
--lora_alpha 32 \
--output_dir dbgpt_hub/output/adapter/qwen-14b-sql-lora \
--overwrite_cache \
--overwrite_output_dir \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 16 \
--lr_scheduler_type cosine_with_restarts \
--logging_steps 50 \
--save_steps 2000 \
--learning_rate 2e-4 \
--num_train_epochs 8 \
--plot_loss \
--bf16 >> ${train_log}
# --bf16#v100不支持bf16
测试
生成sql,如果不需要加载微调的checkpoint,将checkpoint_dir和finetuning_type去掉即可。
单卡A6000,每条sql生成大概1.45s。
sh ./dbgpt_hub/scripts/predict_sft.sh
# dbgpt_hub/scripts/predict_sft.sh
CUDA_VISIBLE_DEVICES=1 python dbgpt_hub/predict/predict.py \
--model_name_or_path /soft/Qwen-14B-Chat/ \
--template chatml \
--finetuning_type lora \
--checkpoint_dir dbgpt_hub/output/adapter/qwen-14b-sql-lora \
--predicted_out_filename pred_sql.sql >> ${pred_log}
评估
python dbgpt_hub/eval/evaluation.py --plug_value --input dbgpt_hub/output/pred/pred_sql.sql
微调前后对比
简单sql(248条) | 中等sql(446条) | 复杂sql(174条) | 其他(166条) | (1034条) | |
---|---|---|---|---|---|
微调前准确率 | 0.863 | 0.717 | 0.483 | 0.325 | 0.650 |
微调后准确率 | 0.935 | 0.785 | 0.603 | 0.416 | 0.731 |
其他数据集
- WikiSQL: 一个大型的语义解析数据集,由80,654个自然语句表述和24,241张表格的sql标注构成。WikiSQL中每一个问句的查询范围仅限于同一张表,不包含排序、分组、子查询等复杂操作。
- CHASE: 一个跨领域多轮交互text2sql中文数据集,包含5459个多轮问题组成的列表,一共17940个<query, SQL>二元组,涉及280个不同领域的数据库。
- BIRD-SQL:数据集是一个英文的大规模跨领域文本到SQL基准测试,特别关注大型数据库内容。该数据集包含12,751对文本到SQL数据对和95个数据库,总大小为33.4GB,跨越37个职业领域。BIRD-SQL数据集通过探索三个额外的挑战,即处理大规模和混乱的数据库值、外部知识推理和优化SQL执行效率,缩小了文本到SQL研究与实际应用之间的差距。
- CoSQL:是一个用于构建跨域对话文本到sql系统的语料库。它是Spider和SParC任务的对话版本。CoSQL由30k+回合和10k+带注释的SQL查询组成,这些查询来自Wizard-of-Oz的3k个对话集合,查询了跨越138个领域的200个复杂数据库。每个对话都模拟了一个真实的DB查询场景,其中一个工作人员作为用户探索数据库,一个SQL专家使用SQL检索答案,澄清模棱两可的问题,或者以其他方式通知。
- 按照NSQL的处理模板,对数据集做简单处理,共得到约20w条训练数据
问题解决
-
poetry安装问题
个人感觉poetry不太好用,更换了镜像源之后,解析下载缓慢,不知道是不是因为中断了poetry下载依赖,后续的依赖解析一直卡住,也没有超时提示。最后手动使用pip安装。
-
评估时内网nltk下载问题
nltk需要下载相关语料,内网无法下载,外网也容易超时,到这里下载,将packages文件夹重命名为nltk_data,拷贝到报错说明的几个位置中的一个即可。
Q.E.D.