LLM系列 | 14: 实测OpenAI函数调用功能:以数据库问答为例

本文涉及的产品
实时数仓Hologres,5000CU*H 100GB 3个月
智能开放搜索 OpenSearch行业算法版,1GB 20LCU 1个月
实时计算 Flink 版,5000CU*H 3个月
简介: 今天这篇小作文以数据库问答(Text2SQL)为例进一步介绍ChatGPT的函数调用。本文将介绍如何将模型生成的结果输入到自定义的函数中,并利用该功能实现数据库问答功能。

简介

黑云翻墨未遮山,白雨跳珠乱入船。
黑云翻墨未遮山,白雨跳珠乱入船.jpg

小伙伴们好,我是微信公众号《小窗幽记机器学习》的小编:划龙舟的小男孩。紧接前面几篇ChatGPT Prompt工程和应用系列文章:

更多、更新文章欢迎关注微信公众号:小窗幽记机器学习。后续会持续整理模型加速、模型部署、模型压缩、LLM、AI艺术等系列专题,敬请关注。

今天这篇小作文以数据库问答(Text2SQL)为例进一步介绍ChatGPT的函数调用。本文将介绍如何将模型生成的结果输入到自定义的函数中,并利用该功能实现数据库问答功能。为简单起见,这里将使用Chinook 示例数据库

需要特别注意:
生产环境中,生成的SQL可能存在较高风险。因为模型在生成正确的 SQL 这方面暂不完全可靠,小伙伴们评估谨慎使用

数据库相关

环境相关设置及其辅助函数代码请于文末附录部分。以下直接介绍示例数据库相关细节。

获取 Chinook 数据相关的信息:

import sqlite3

conn = sqlite3.connect("data/chinook.db")
print("Opened database successfully")

def get_table_names(conn):
    """Return a list of table names."""
    table_names = []
    tables = conn.execute("SELECT name FROM sqlite_master WHERE type='table';")
    for table in tables.fetchall():
        table_names.append(table[0])
    return table_names


def get_column_names(conn, table_name):
    """Return a list of column names."""
    column_names = []
    columns = conn.execute(f"PRAGMA table_info('{table_name}');").fetchall()
    for col in columns:
        column_names.append(col[1])
    return column_names


def get_database_info(conn):
    """Return a list of dicts containing the table name and columns for each table in the database."""
    table_dicts = []
    for table_name in get_table_names(conn):
        columns_names = get_column_names(conn, table_name)
        table_dicts.append({"table_name": table_name, "column_names": columns_names})
    return table_dicts

获取 db 中的table

可以获取chinook db中有哪些 table:

table_names = get_table_names(conn)
print("table_names=", table_names)

输出结果如下:

table_names= ['albums', 'sqlite_sequence', 'artists', 'customers', 'employees', 'genres', 'invoices', 'invoice_items', 'media_types', 'playlists', 'playlist_track', 'tracks', 'sqlite_stat1']

获取各 table 的schema

database_schema_dict = get_database_info(conn)
database_schema_string = "\n".join(
    [
        f"Table: {table['table_name']}\nColumns: {', '.join(table['column_names'])}"
        for table in database_schema_dict
    ]
)

database_schema_dict结果如下:

[{'table_name': 'albums', 'column_names': ['AlbumId', 'Title', 'ArtistId']},
 {'table_name': 'sqlite_sequence', 'column_names': ['name', 'seq']},
 {'table_name': 'artists', 'column_names': ['ArtistId', 'Name']},
 {'table_name': 'customers',
  'column_names': ['CustomerId',
   'FirstName',
   'LastName',
   'Company',
   'Address',
   'City',
   'State',
   'Country',
   'PostalCode',
   'Phone',
   'Fax',
   'Email',
   'SupportRepId']},
 {'table_name': 'employees',
  'column_names': ['EmployeeId',
   'LastName',
   'FirstName',
   'Title',
   'ReportsTo',
   'BirthDate',
   'HireDate',
   'Address',
   'City',
   'State',
   'Country',
   'PostalCode',
   'Phone',
   'Fax',
   'Email']},
 {'table_name': 'genres', 'column_names': ['GenreId', 'Name']},
 {'table_name': 'invoices',
  'column_names': ['InvoiceId',
   'CustomerId',
   'InvoiceDate',
   'BillingAddress',
   'BillingCity',
   'BillingState',
   'BillingCountry',
   'BillingPostalCode',
   'Total']},
 {'table_name': 'invoice_items',
  'column_names': ['InvoiceLineId',
   'InvoiceId',
   'TrackId',
   'UnitPrice',
   'Quantity']},
 {'table_name': 'media_types', 'column_names': ['MediaTypeId', 'Name']},
 {'table_name': 'playlists', 'column_names': ['PlaylistId', 'Name']},
 {'table_name': 'playlist_track', 'column_names': ['PlaylistId', 'TrackId']},
 {'table_name': 'tracks',
  'column_names': ['TrackId',
   'Name',
   'AlbumId',
   'MediaTypeId',
   'GenreId',
   'Composer',
   'Milliseconds',
   'Bytes',
   'UnitPrice']},
 {'table_name': 'sqlite_stat1', 'column_names': ['tbl', 'idx', 'stat']}]

database_schema_string结果如下:

'Table: albums\nColumns: AlbumId, Title, ArtistId\nTable: sqlite_sequence\nColumns: name, seq\nTable: artists\nColumns: ArtistId, Name\nTable: customers\nColumns: CustomerId, FirstName, LastName, Company, Address, City, State, Country, PostalCode, Phone, Fax, Email, SupportRepId\nTable: employees\nColumns: EmployeeId, LastName, FirstName, Title, ReportsTo, BirthDate, HireDate, Address, City, State, Country, PostalCode, Phone, Fax, Email\nTable: genres\nColumns: GenreId, Name\nTable: invoices\nColumns: InvoiceId, CustomerId, InvoiceDate, BillingAddress, BillingCity, BillingState, BillingCountry, BillingPostalCode, Total\nTable: invoice_items\nColumns: InvoiceLineId, InvoiceId, TrackId, UnitPrice, Quantity\nTable: media_types\nColumns: MediaTypeId, Name\nTable: playlists\nColumns: PlaylistId, Name\nTable: playlist_track\nColumns: PlaylistId, TrackId\nTable: tracks\nColumns: TrackId, Name, AlbumId, MediaTypeId, GenreId, Composer, Milliseconds, Bytes, UnitPrice\nTable: sqlite_stat1\nColumns: tbl, idx, stat'

定义相关函数

定义functions规范

注意,在定义functions规范时要将数据库的schema插入到函数规范中,这对模型来说是很重要的。

functions = [
    {
        "name": "ask_database",
        "description": "请使用以下函数来回答关于音乐的用户问题。输出结果应为一个完整的 SQL 查询。",
        "parameters": {
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": f"""
                            用于提取信息以回答用户问题的 SQL 查询。
                            SQL 查询应使用以下数据库模式编写:
                            {database_schema_string}
                            查询应以纯文本形式返回,而不是 JSON 格式。
                            """,
                }
            },
            "required": ["query"],
        },
    }
]

定义执行SQL语句的函数

# ChatGPT 生成的query会输入到 ask_database
def ask_database(conn, query):
    """Function to query SQLite database with a provided SQL query."""
    try:
        results = str(conn.execute(query).fetchall())
    except Exception as e:
        results = f"query failed with error: {e}"
    return results

# 根据 message["function_call"]["name"] 判断函数调用时机
def execute_function_call(message):
    if message["function_call"]["name"] == "ask_database":
        query = json.loads(message["function_call"]["arguments"])["query"]
        results = ask_database(conn, query)
    else:
        results = f"Error: function {message['function_call']['name']} does not exist"
    return results

示例1:查询曲目数量Top5的艺术家

messages = []
messages.append({"role": "system", "content": "基于 Chinook 音乐数据库生成 SQL 查询来回答用户的问题。"})
messages.append({"role": "user", "content": "你好,按照曲目数量,排名前5位的艺术家有谁?"})
chat_response = chat_completion_request(messages, functions)
print("chat_response=", chat_response.json())
assistant_message = chat_response.json()["choices"][0]["message"]
print("assistant_message=", assistant_message)
messages.append(assistant_message)
if assistant_message.get("function_call"):
    results = execute_function_call(assistant_message)
    messages.append({"role": "function", "name": assistant_message["function_call"]["name"], "content": results})
pretty_print_conversation(messages)

返回结果如下:

chat_response= {'id': 'chatcmpl-7TQdhItIU3FEgvvYliCGRoD0QetgX', 'object': 'chat.completion', 'created': 1687248269, 'model': 'gpt-3.5-turbo-0613', 'choices': [{'index': 0, 'message': {'role': 'assistant', 'content': None, 'function_call': {'name': 'ask_database', 'arguments': '{\n  "query": "SELECT artists.Name, COUNT(tracks.TrackId) AS NumTracks FROM artists JOIN albums ON artists.ArtistId = albums.ArtistId JOIN tracks ON albums.AlbumID = tracks.AlbumId GROUP BY artists.ArtistId ORDER BY NumTracks DESC LIMIT 5;"\n}'}}, 'finish_reason': 'function_call'}], 'usage': {'prompt_tokens': 448, 'completion_tokens': 67, 'total_tokens': 515}}

assistant_message= {'role': 'assistant', 'content': None, 'function_call': {'name': 'ask_database', 'arguments': '{\n  "query": "SELECT artists.Name, COUNT(tracks.TrackId) AS NumTracks FROM artists JOIN albums ON artists.ArtistId = albums.ArtistId JOIN tracks ON albums.AlbumID = tracks.AlbumId GROUP BY artists.ArtistId ORDER BY NumTracks DESC LIMIT 5;"\n}'}}

最终pretty_print_conversation(messages)的结果如下:

system: 基于 Chinook 音乐数据库生成 SQL 查询来回答用户的问题。

user: 你好,按照曲目数量,排名前5位的艺术家有谁?

assistant: {'name': 'ask_database', 'arguments': '{\n  "query": "SELECT artists.Name, COUNT(tracks.TrackId) AS NumTracks FROM artists JOIN albums ON artists.ArtistId = albums.ArtistId JOIN tracks ON albums.AlbumID = tracks.AlbumId GROUP BY artists.ArtistId ORDER BY NumTracks DESC LIMIT 5;"\n}'}

function (ask_database): [('Iron Maiden', 213), ('U2', 135), ('Led Zeppelin', 114), ('Metallica', 112), ('Deep Purple', 92)]

示例2:查询哪个专辑曲目最多

messages.append({"role": "user", "content": "曲目最多的专辑的名称是什么?"})
chat_response = chat_completion_request(messages, functions)
assistant_message = chat_response.json()["choices"][0]["message"]
messages.append(assistant_message)
if assistant_message.get("function_call"):
    results = execute_function_call(assistant_message)
    messages.append({"role": "function", "content": results, "name": assistant_message["function_call"]["name"]})
pretty_print_conversation(messages)

输出结果如下:

system: 基于 Chinook 音乐数据库生成 SQL 查询来回答用户的问题。

user: 你好,按照曲目数量,排名前5位的艺术家有谁?

assistant: {'name': 'ask_database', 'arguments': '{\n  "query": "SELECT artists.Name, COUNT(tracks.TrackId) AS NumTracks FROM artists JOIN albums ON artists.ArtistId = albums.ArtistId JOIN tracks ON albums.AlbumID = tracks.AlbumId GROUP BY artists.ArtistId ORDER BY NumTracks DESC LIMIT 5;"\n}'}

function (ask_database): [('Iron Maiden', 213), ('U2', 135), ('Led Zeppelin', 114), ('Metallica', 112), ('Deep Purple', 92)]

user: 曲目最多的专辑的名称是什么?

assistant: {'name': 'ask_database', 'arguments': '{\n  "query": "SELECT albums.Title, COUNT(tracks.TrackId) AS NumTracks FROM albums JOIN tracks ON albums.AlbumId = tracks.AlbumId GROUP BY albums.AlbumId ORDER BY NumTracks DESC LIMIT 1;"\n}'}

function (ask_database): [('Greatest Hits', 57)]

小结

通过上述示例可以确切感受openai函数调用功能的强大,这也为开发者构建更多稳健服务提供更强的保障。

附录

import json
import openai
import requests
import os
from tenacity import retry, wait_random_exponential, stop_after_attempt
from termcolor import colored

GPT_MODEL = "gpt-3.5-turbo-0613"
openai.api_key  = "sk-xxx"
os.environ['HTTP_PROXY'] = "xxx"
os.environ['HTTPS_PROXY'] = "xxx"

# 调用API的重试机制
@retry(wait=wait_random_exponential(min=1, max=40), stop=stop_after_attempt(3))
def chat_completion_request(messages, functions=None, function_call=None, model=GPT_MODEL):
    headers = {
        "Content-Type": "application/json",
        "Authorization": "Bearer " + openai.api_key,
    }
    json_data = {"model": model, "messages": messages}
    if functions is not None:
        json_data.update({"functions": functions})
    if function_call is not None:
        json_data.update({"function_call": function_call})
    try:
        response = requests.post(
            "https://api.openai.com/v1/chat/completions",
            headers=headers,
            json=json_data,
        )
        return response
    except Exception as e:
        print("Unable to generate ChatCompletion response")
        print(f"Exception: {e}")
        return e


# 处理输出,方便阅读
def pretty_print_conversation(messages):
    role_to_color = {
        "system": "red",
        "user": "green",
        "assistant": "blue",
        "function": "magenta",
    }
    formatted_messages = []
    for message in messages:
        if message["role"] == "system":
            formatted_messages.append(f"system: {message['content']}\n")
        elif message["role"] == "user":
            formatted_messages.append(f"user: {message['content']}\n")
        elif message["role"] == "assistant" and message.get("function_call"):
            formatted_messages.append(f"assistant: {message['function_call']}\n")
        elif message["role"] == "assistant" and not message.get("function_call"):
            formatted_messages.append(f"assistant: {message['content']}\n")
        elif message["role"] == "function":
            formatted_messages.append(f"function ({message['name']}): {message['content']}\n")
    for formatted_message in formatted_messages:
        print(
            colored(
                formatted_message,
                role_to_color[messages[formatted_messages.index(formatted_message)]["role"]],
            )
        )
相关文章
|
2月前
|
存储 NoSQL 关系型数据库
PolarDB开源数据库进阶课17 集成数据湖功能
本文介绍了如何在PolarDB数据库中接入pg_duckdb、pg_mooncake插件以支持数据湖功能, 可以读写对象存储的远程数据, 支持csv, parquet等格式, 支持delta等框架, 并显著提升OLAP性能。
124 1
|
3月前
|
Cloud Native 关系型数据库 分布式数据库
让PolarDB更了解您--PolarDB云原生数据库核心功能体验馆
让PolarDB更了解您——PolarDB云原生数据库核心功能体验馆,由阿里云数据库产品事业部负责人宋震分享。内容涵盖PolarDB技术布局、开源进展及体验馆三大部分。技术布局包括云计算加速数据库演进、数据处理需求带来的变革、软硬协同优化等;开源部分介绍了兼容MySQL和PostgreSQL的两款产品;体验馆则通过实际操作让用户直观感受Serverless、无感切换、SQL2Map等功能。
175 7
|
1月前
|
SQL Linux 数据库
【YashanDB知识库】崖山数据库Outline功能验证
本文来自YashanDB官网,主要测试了数据库优化器在不同场景下优先使用outline计划的功能。测试环境包括相同版本新增数据、绑定参数执行、单机主备架构以及数据库版本升级等场景。通过创建表、插入数据、收集统计信息和创建outline等步骤,验证了在各种情况下优化器均能优先采用存储的outline计划。测试结果表明,即使统计信息失效或数据库版本升级,outline功能依然稳定有效,确保查询计划的一致性和性能优化。详情可见[原文链接](https://www.yashandb.com/newsinfo/7488286.html?templateId=1718516)。
【YashanDB知识库】崖山数据库Outline功能验证
|
1月前
|
数据库连接 PHP 数据库
【YashanDB知识库】PHP使用ODBC使用数据库绑定参数功能异常
【YashanDB知识库】PHP使用ODBC使用数据库绑定参数功能异常
|
3月前
|
人工智能 知识图谱 Docker
KAG:增强 LLM 的专业能力!蚂蚁集团推出专业领域知识增强框架,支持逻辑推理和多跳问答
KAG 是蚂蚁集团推出的专业领域知识服务框架,通过知识增强提升大型语言模型在特定领域的问答性能,支持逻辑推理和多跳事实问答,显著提升推理和问答的准确性和效率。
961 46
KAG:增强 LLM 的专业能力!蚂蚁集团推出专业领域知识增强框架,支持逻辑推理和多跳问答
|
1月前
|
PHP 数据库
【YashanDB知识库】PHP使用OCI接口使用数据库绑定参数功能异常
【YashanDB知识库】PHP使用OCI接口使用数据库绑定参数功能异常
|
1月前
|
存储 NoSQL 关系型数据库
Apifox与Apipost数据库连接功能详细对比,让接口管理更高效!
Apipost 更加全面:无论是关系型还是非关系型数据库,它都为开发者提供了一站式解决方案,非常适合数据库架构复杂的大型项目。相对来说,Apifox偏重关系型分析和管理:若项目主要需求在于管理关系型数据库,而对非关系型的依赖较小,Apifox倒是可以应付。
56 2
|
2月前
|
存储 关系型数据库 分布式数据库
PolarDB开源数据库进阶课16 接入PostGIS全功能及应用举例
本文介绍了如何在PolarDB数据库中接入PostGIS插件全功能,实现地理空间数据处理。此外,文章还提供了使用PostGIS生成泰森多边形(Voronoi diagram)的具体示例,帮助用户理解其应用场景及操作方法。
83 1
|
1月前
|
SQL 关系型数据库 数据库连接
|
3月前
|
存储 Java 数据库连接
时序数据库TDengine 3.3.5.0 发布:高并发支持与增量备份功能引领新升级
TDengine 3.3.5.0 版本正式发布,带来多项更新与优化。新特性包括提升 MQTT 稳定性和高并发性能、新增 taosX 增量备份与恢复、支持 JDBC 和 Rust 连接器 STMT2 接口、灵活配置 Grafana Dashboard 等。性能优化涵盖查询内存管控、多级存储迁移、强密码策略等,全面提升时序数据管理的效率和可靠性。欢迎下载体验并提出宝贵意见。
87 5
下一篇
oss创建bucket