自学内容网 自学内容网

Vanna使用ollama分析本地MySQL数据库 加入redis保存训练记录

相关代码

在这里插入图片描述

from vanna.base.base import VannaBase
from vanna.chromadb import ChromaDB_VectorStore
from vanna.ollama import Ollama
import logging
import os
import requests
import json
import pandas as pd
import chromadb
import redis
import pickle
from IPython.display import display

logging.basicConfig(level=logging.INFO)

class MyVanna(ChromaDB_VectorStore, Ollama):
    def __init__(self, config=None):
        # 初始化配置
        self.config = {
            'model': 'llama2:latest',
            'ollama_host': 'http://127.0.0.1:11434',
            'verbose': True,
            'temperature': 0.1,
            'collection_name': 'my_vanna_collection',
            'redis_host': '127.0.0.1',
            'redis_port': 6379,
            'redis_db': 5,
            'redis_password': '123456',
            'redis_key_prefix': 'vanna_training:'
        }
        if config:
            self.config.update(config)
            
        # 初始化 ChromaDB
        self.chroma_client = chromadb.PersistentClient(path=self.config['chroma_db_path'])
        try:
            self._collection = self.chroma_client.get_collection(self.config['collection_name'])
            logging.info(f"获取已存在的集合: {self.config['collection_name']}")
        except:
            self._collection = self.chroma_client.create_collection(self.config['collection_name'])
            logging.info(f"创建新的集合: {self.config['collection_name']}")
            
        # 初始化 Redis 连接
        try:
            self.redis_client = redis.Redis(
                host=self.config['redis_host'],
                port=self.config['redis_port'],
                db=self.config['redis_db'],
                password=self.config['redis_password'],
                decode_responses=False,
                socket_timeout=5,
                retry_on_timeout=True
            )
            # 测试连接
            self.redis_client.ping()
            logging.info("Redis 连接成功")
        except Exception as e:
            logging.error(f"Redis 连接错误: {str(e)}")
            raise
            
        # 初始化父类
        ChromaDB_VectorStore.__init__(self, config=self.config)
        Ollama.__init__(self, config=self.config)
        
        self._ddl = None
        
    def submit_prompt(self, prompt, **kwargs):
        """重写 submit_prompt 方法"""
        try:
            url = f"{self.config['ollama_host']}/api/generate"
            
            # 如果传入的是消息列表,则组合成单个提示词
            if isinstance(prompt, list):
                full_prompt = "\n".join([msg.get('content', '') for msg in prompt if isinstance(msg, dict)])
            else:
                full_prompt = prompt
                
            data = {
                "model": self.config['model'],
                "prompt": full_prompt,
                "stream": False
            }
            
            headers = {
                "Content-Type": "application/json"
            }
            
            logging.info(f"发送请求到 Ollama: {url}")
            logging.debug(f"请求数据: {json.dumps(data, ensure_ascii=False)}")
            
            response = requests.post(url, json=data, headers=headers)
            response.raise_for_status()
            
            response_data = response.json()
            logging.debug(f"Ollama 响应: {json.dumps(response_data, ensure_ascii=False)}")
            
            if 'response' in response_data:
                return response_data['response'].strip()
            else:
                logging.error(f"Ollama 响应格式错误: {response_data}")
                raise ValueError("无效的 Ollama 响应格式")
                
        except Exception as e:
            logging.error(f"提交 prompt 错误: {str(e)}")
            raise
            
    def train(self, ddl=None, question=None, sql=None, documentation=None):
        """重写 train 方法,使用 Redis"""
        try:
            if ddl:
                self._ddl = ddl
                # 保存 DDL 到 Redis
                self.redis_client.set(f"{self.config['redis_key_prefix']}ddl", ddl)
                logging.info("DDL 已保存到 Redis")
                
            if question and sql:
                # 准备训练数据
                data = {
                    'question': question,
                    'sql': sql,
                    'documentation': documentation or ''
                }
                
                # 生成唯一 ID
                import hashlib
                doc_id = hashlib.md5(json.dumps(data, ensure_ascii=False).encode()).hexdigest()
                
                # 保存到 Redis
                key = f"{self.config['redis_key_prefix']}data:{doc_id}"
                self.redis_client.set(key, pickle.dumps(data))
                
                # 将 ID 添加到训练数据集合中
                self.redis_client.sadd(f"{self.config['redis_key_prefix']}data_ids", doc_id)
                
                logging.info(f"训练数据已保存到 Redis: {data}")
                
            return True
            
        except Exception as e:
            logging.error(f"训练错误: {str(e)}")
            raise
            
    def get_sql_prompt(self, question, ddl=None, similar_questions=None, similar_sql=None, 
                      initial_prompt=None, question_sql_list=None, ddl_list=None, doc_list=None,
                      **kwargs):
        """重写 get_sql_prompt 方法"""
        # 使用存储的 DDL
        if not ddl and self._ddl:
            ddl = self._ddl
            
        # 构建提示词
        prompt = "你是一个 SQL 专家。请根据以下信息生成 SQL 查询。\n\n"
        
        prompt += "### 数据库结构:\n"
        if ddl:
            prompt += f"{ddl}\n\n"
            
        # 添加文档说明
        if doc_list:
            prompt += "### 相关文档:\n"
            for doc in doc_list:
                prompt += f"{doc}\n"
            prompt += "\n"
            
        prompt += "### 问题:\n"
        prompt += f"{question}\n\n"
        
        if similar_questions and similar_sql:
            prompt += "### 相似问题和对应的 SQL:\n"
            for q, s in zip(similar_questions, similar_sql):
                prompt += f"\n问题: {q}\nSQL: {s}\n"
                
        prompt += "\n### 请生成对应的 SQL 查询 汉字转为简体:\n"
        return prompt
            
    def generate_sql(self, question, **kwargs):
        try:
            if self._ddl:
                kwargs['ddl'] = self._ddl
            return super().generate_sql(question, **kwargs)
        except Exception as e:
            logging.error(f"SQL 生成错误: {str(e)}")
            raise
            
    def get_related_ddl(self, question=None, **kwargs):
        """重写 get_related_ddl 方法,从 Redis 获取 DDL"""
        try:
            if self._ddl:
                return self._ddl
            
            # 从 Redis 获取 DDL
            ddl = self.redis_client.get(f"{self.config['redis_key_prefix']}ddl")
            if ddl:
                self._ddl = ddl.decode()
                return self._ddl
            
            return None
        except Exception as e:
            logging.error(f"获取 DDL 错误: {str(e)}")
            return None

    def generate_plotly_code(self, question, sql_result=None, **kwargs):
        """重写 generate_plotly_code 方法"""
        try:
            # 构建提示词
            prompt = self.get_plotly_prompt(question, sql_result=sql_result, **kwargs)
            
            # 添加系统提示词
            system_prompt = "你是一个数据可视化专家。请根据用户的需求生成 Plotly 图表代码。只返回 Python 代码,不需要其他解释。如果繁体转为简体。"
            full_prompt = f"{system_prompt}\n\n{prompt}"
            
            # 直接调用 submit_prompt
            return self.submit_prompt(full_prompt, is_plotly=True)
            
        except Exception as e:
            logging.error(f"生成图表代码错误: {str(e)}")
            raise
            
    def get_plotly_prompt(self, question, sql=None, sql_result=None, **kwargs):
        """重写 get_plotly_prompt 方法"""
        prompt = f"""请根据以下信息生成 Plotly 图表代码:

问题:{question}

SQL查询:{sql if sql else ''}

查询结果:{sql_result if sql_result else ''}

要求:
1. 使用 Plotly Express 生成图表
2. 只返回 Python 代码
3. 不要包含任何解释或说明
4. 确保代码的正确性
5. 如果繁体转为简体
"""
        return prompt

    def get_training_data(self):
        """重写 get_training_data 方法,使用 Redis"""
        try:
            # 获取所有训练数据 ID
            data_ids = self.redis_client.smembers(f"{self.config['redis_key_prefix']}data_ids")
            
            if not data_ids:
                logging.info("Redis 中没有找到训练数据")
                return pd.DataFrame(columns=['question', 'sql', 'documentation'])
            
            # 获取所有训练数据
            documents = []
            for doc_id in data_ids:
                try:
                    key = f"{self.config['redis_key_prefix']}data:{doc_id.decode()}"
                    data = self.redis_client.get(key)
                    if data:
                        doc_data = pickle.loads(data)
                        documents.append(doc_data)
                        logging.info(f"从 Redis 加载数据: {doc_data}")
                except Exception as e:
                    logging.error(f"处理 Redis 数据时出错: {e}")
                    continue
            
            # 创建 DataFrame
            if documents:
                df = pd.DataFrame(documents)
                logging.info(f"已加载 {len(df)} 条训练数据")
                return df
            else:
                logging.info("没有找到有效的训练数据")
                return pd.DataFrame(columns=['question', 'sql', 'documentation'])
                
        except Exception as e:
            logging.error(f"获取训练数据错误: {str(e)}")
            return pd.DataFrame(columns=['question', 'sql', 'documentation'])

def train_model(vn):
    try:
        # 训练 DDL
        print("开始训练 DDL...")
        ddl = """
         CREATE TABLE `customer` (
        `name` int NOT NULL COMMENT '姓名',
        `gender` int DEFAULT NULL COMMENT '性别(男性=1/女性=2)',
        `id_card` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT '身份证',
        `mobile` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT '手机',
        `nation` varchar(10) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT '民族',
        `residential_city` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT '居住城市',
        `age` int DEFAULT NULL COMMENT '岁数 年纪',
        `salary` int NOT NULL COMMENT '薪水',
        PRIMARY KEY (`name`)
        ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='customer'
        """
        vn.train(ddl=ddl)
        print("DDL 训练完成")

        # 训练示例查询
        examples = [
            {
                'question': "宁波有多少客户?",
                'sql': "SELECT COUNT(*) as count FROM customer WHERE residential_city like '%宁波%'"
            },
            {
                'question': "有多少女性客户?",
                'sql': "SELECT COUNT(*) as count FROM customer WHERE gender = 2"
            },
            {
                'question': "客户平均年龄是多少?",
                'sql': "SELECT AVG(age) as average_age FROM customer"
            },
            {
                'question': "客户平均薪水是多少?",
                'sql': "SELECT AVG(salary) as average_salary FROM customer"
            }
        ]
        
        for example in examples:
            print(f"\n训练示例: {example['question']}")
            vn.train(question=example['question'], sql=example['sql'])
            
        
        print("\n所有训练完成")
        result = vn.ask("宁波有多少客户?")
        print(f"\n查询问题: 宁波有多少客户?\n查询结果: {result}")
        
    except Exception as e:
        logging.error(f"训练错误: {str(e)}")
        raise

if __name__ == "__main__":
    try:
        # 初始化 Vanna
        vn = MyVanna()
        
        # 连接数据库
        vn.connect_to_mysql(
            host='localhost',
            dbname='test',
            user='root',
            password='123456',
            port=3306
        )
        
        # 训练模型
        train_model(vn)
        
        # 启动 Flask 应用
        from vanna.flask import VannaFlaskApp
        app = VannaFlaskApp(vn)
        app.run(host='0.0.0.0', port=7123)
        
    except Exception as e:
        logging.error(f"程序运行错误: {str(e)}")

CREATE TABLE `customer` (
  `name` int NOT NULL COMMENT '姓名',
  `gender` int DEFAULT NULL COMMENT '性别(男性=1/女性=2)',
  `id_card` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT '身份证',
  `mobile` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT '手机',
  `nation` varchar(10) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT '民族',
  `residential_city` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT '居住城市',
  `age` int DEFAULT NULL COMMENT '岁数 年纪',
  `salary` int NOT NULL COMMENT '薪水',
  `id` int NOT NULL AUTO_INCREMENT COMMENT 'id',
  PRIMARY KEY (`id`)
) ENGINE=InnoDB AUTO_INCREMENT=21 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='customer';
INSERT INTO `customer` (`name`,`gender`,`id_card`,`mobile`,`nation`,`residential_city`,`age`,`salary`,`id`)  VALUES ('1','1','330201199001011234','13800001111','汉族','宁波','27','5520','1');
INSERT INTO `customer` (`name`,`gender`,`id_card`,`mobile`,`nation`,`residential_city`,`age`,`salary`,`id`)  VALUES ('2','2','330201199102022345','13800002222','汉族','宁波','70','7042','2');
INSERT INTO `customer` (`name`,`gender`,`id_card`,`mobile`,`nation`,`residential_city`,`age`,`salary`,`id`)  VALUES ('3','1','330201199203033456','13800003333','回族','宁波','94','4119','3');
INSERT INTO `customer` (`name`,`gender`,`id_card`,`mobile`,`nation`,`residential_city`,`age`,`salary`,`id`)  VALUES ('4','2','330201199304044567','13800004444','汉族','宁波','60','4886','4');
INSERT INTO `customer` (`name`,`gender`,`id_card`,`mobile`,`nation`,`residential_city`,`age`,`salary`,`id`)  VALUES ('5','1','330201199405055678','13800005555','壮族','宁波','5','5762','5');
INSERT INTO `customer` (`name`,`gender`,`id_card`,`mobile`,`nation`,`residential_city`,`age`,`salary`,`id`)  VALUES ('6','1','110101199506066789','13800006666','汉族','北京','58','5515','6');
INSERT INTO `customer` (`name`,`gender`,`id_card`,`mobile`,`nation`,`residential_city`,`age`,`salary`,`id`)  VALUES ('7','2','310101199607077890','13800007777','汉族','上海','69','2927','7');
INSERT INTO `customer` (`name`,`gender`,`id_card`,`mobile`,`nation`,`residential_city`,`age`,`salary`,`id`)  VALUES ('8','1','440101199708088901','13800008888','满族','广州','90','5979','8');
INSERT INTO `customer` (`name`,`gender`,`id_card`,`mobile`,`nation`,`residential_city`,`age`,`salary`,`id`)  VALUES ('9','2','500101199809099012','13800009999','汉族','重庆','91','7256','9');
INSERT INTO `customer` (`name`,`gender`,`id_card`,`mobile`,`nation`,`residential_city`,`age`,`salary`,`id`)  VALUES ('10','1','610101199910101123','13800010000','回族','西安','28','4067','10');
INSERT INTO `customer` (`name`,`gender`,`id_card`,`mobile`,`nation`,`residential_city`,`age`,`salary`,`id`)  VALUES ('11','2','320101199001111234','13800011111','汉族','南京','13','1979','11');
INSERT INTO `customer` (`name`,`gender`,`id_card`,`mobile`,`nation`,`residential_city`,`age`,`salary`,`id`)  VALUES ('12','1','330101199002121345','13800012222','畲族','杭州','8','994','12');
INSERT INTO `customer` (`name`,`gender`,`id_card`,`mobile`,`nation`,`residential_city`,`age`,`salary`,`id`)  VALUES ('13','2','420101199003131456','13800013333','汉族','武汉','29','1073','13');
INSERT INTO `customer` (`name`,`gender`,`id_card`,`mobile`,`nation`,`residential_city`,`age`,`salary`,`id`)  VALUES ('14','1','510101199004141567','13800014444','彝族','成都','84','1441','14');
INSERT INTO `customer` (`name`,`gender`,`id_card`,`mobile`,`nation`,`residential_city`,`age`,`salary`,`id`)  VALUES ('15','2','350101199005151678','13800015555','汉族','福州','33','7725','15');
INSERT INTO `customer` (`name`,`gender`,`id_card`,`mobile`,`nation`,`residential_city`,`age`,`salary`,`id`)  VALUES ('16','1','370101199006161789','13800016666','汉族','济南','89','3821','16');
INSERT INTO `customer` (`name`,`gender`,`id_card`,`mobile`,`nation`,`residential_city`,`age`,`salary`,`id`)  VALUES ('17','2','430101199007171890','13800017777','苗族','长沙','86','3082','17');
INSERT INTO `customer` (`name`,`gender`,`id_card`,`mobile`,`nation`,`residential_city`,`age`,`salary`,`id`)  VALUES ('18','1','220101199008181901','13800018888','汉族','长春','48','4170','18');
INSERT INTO `customer` (`name`,`gender`,`id_card`,`mobile`,`nation`,`residential_city`,`age`,`salary`,`id`)  VALUES ('19','2','450101199009192012','13800019999','壮族','南宁','30','1498','19');
INSERT INTO `customer` (`name`,`gender`,`id_card`,`mobile`,`nation`,`residential_city`,`age`,`salary`,`id`)  VALUES ('20','1','130101199010202123','13800020000','汉族','石家庄','54','941','20');

1. 系统概述

这是一个基于 Ollama 和 Redis 的智能 SQL 问答系统,可以将自然语言问题转换为 SQL 查询语句。系统具有以下主要特点:

  • 基于 LLM (Large Language Model) 的自然语言转 SQL
  • 支持训练数据的持久化存储
  • 提供 REST API 接口
  • 支持数据可视化生成

2. 核心组件

2.1 MyVanna 类

主要继承关系:

class MyVanna(ChromaDB_VectorStore, Ollama)

核心配置参数:

self.config = {
    'model': 'llama2:latest',          # LLM 模型
    'ollama_host': 'http://127.0.0.1:11434',  # Ollama 服务地址
    'temperature': 0.1,                # 生成温度
    'redis_host': '127.0.0.1',        # Redis 配置
    'redis_port': 6379,
    'redis_db': 5,
    'redis_password': '123456',
    'redis_key_prefix': 'vanna_training:'
}

3. 主要功能模块

3.1 提示词生成 (Prompt Engineering)

def get_sql_prompt(self, question, ddl=None, similar_questions=None, similar_sql=None, ...):

提示词结构:

  1. 角色定义
  2. 数据库结构说明
  3. 相关文档
  4. 用户问题
  5. 相似问题参考
  6. 输出要求

3.2 训练功能

def train(self, ddl=None, question=None, sql=None, documentation=None):

训练数据包含:

  • DDL(数据库结构)
  • 问题-SQL 对
  • 相关文档

存储方式:

  • 使用 Redis 持久化
  • 使用 hash 作为唯一标识
  • 支持批量训练

3.3 SQL 生成

def generate_sql(self, question, **kwargs):

工作流程:

  1. 获取相关 DDL
  2. 构建提示词
  3. 调用 LLM 生成 SQL
  4. 错误处理和日志记录

3.4 数据可视化

def generate_plotly_code(self, question, sql_result=None, **kwargs):

特点:

  • 使用 Plotly 生成可视化代码
  • 支持 SQL 结果的直接可视化
  • 自动处理中文编码

4. 示例训练数据

examples = [
    {
        'question': "宁波有多少客户?",
        'sql': "SELECT COUNT(*) as count FROM customer WHERE residential_city like '%宁波%'"
    },
    {
        'question': "有多少女性客户?",
        'sql': "SELECT COUNT(*) as count FROM customer WHERE gender = 2"
    }
    # ...
]

5. 部署和使用

5.1 环境要求

  • Python 3.x
  • Redis 服务
  • MySQL 数据库
  • Ollama 服务

5.2 启动服务

if __name__ == "__main__":
    vn = MyVanna()
    vn.connect_to_mysql(...)
    train_model(vn)
    app = VannaFlaskApp(vn)
    app.run(host='0.0.0.0', port=7123)

6. 改进建议

  1. 错误处理优化

    • 添加更详细的错误类型
    • 实现错误重试机制
  2. 性能优化

    • 添加缓存机制
    • 实现批量处理
  3. 安全性增强

    • 添加 SQL 注入防护
    • 实现访问控制
  4. 功能扩展

    • 支持更多数据库类型
    • 添加更多可视化选项
    • 实现对话历史记录

7. 总结

该系统通过结合 LLM 和传统数据库技术,实现了一个灵活的自然语言到 SQL 的转换系统。其模块化设计和可扩展性使其适合在实际业务场景中使用和扩展。

主要优势:

  • 模块化设计
  • 可扩展架构
  • 完整的训练流程
  • 持久化存储支持

潜在改进空间:

  • 性能优化
  • 安全性增强
  • 功能扩展
  • 错误处理完善

原文地址:https://blog.csdn.net/hzether/article/details/143840276

免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!