LLM系列 | 14: 实测OpenAI函数调用功能:以数据库问答为例

本文涉及的产品
实时计算 Flink 版,5000CU*H 3个月
检索分析服务 Elasticsearch 版,2核4GB开发者规格 1个月
大数据开发治理平台 DataWorks,不限时长
简介: 今天这篇小作文以数据库问答(Text2SQL)为例进一步介绍ChatGPT的函数调用。本文将介绍如何将模型生成的结果输入到自定义的函数中,并利用该功能实现数据库问答功能。

简介

黑云翻墨未遮山,白雨跳珠乱入船。
黑云翻墨未遮山,白雨跳珠乱入船.jpg

小伙伴们好,我是微信公众号《小窗幽记机器学习》的小编:划龙舟的小男孩。紧接前面几篇ChatGPT Prompt工程和应用系列文章:

更多、更新文章欢迎关注微信公众号:小窗幽记机器学习。后续会持续整理模型加速、模型部署、模型压缩、LLM、AI艺术等系列专题,敬请关注。

今天这篇小作文以数据库问答(Text2SQL)为例进一步介绍ChatGPT的函数调用。本文将介绍如何将模型生成的结果输入到自定义的函数中,并利用该功能实现数据库问答功能。为简单起见,这里将使用Chinook 示例数据库

需要特别注意:
生产环境中,生成的SQL可能存在较高风险。因为模型在生成正确的 SQL 这方面暂不完全可靠,小伙伴们评估谨慎使用

数据库相关

环境相关设置及其辅助函数代码请于文末附录部分。以下直接介绍示例数据库相关细节。

获取 Chinook 数据相关的信息:

import sqlite3

conn = sqlite3.connect("data/chinook.db")
print("Opened database successfully")

def get_table_names(conn):
    """Return a list of table names."""
    table_names = []
    tables = conn.execute("SELECT name FROM sqlite_master WHERE type='table';")
    for table in tables.fetchall():
        table_names.append(table[0])
    return table_names


def get_column_names(conn, table_name):
    """Return a list of column names."""
    column_names = []
    columns = conn.execute(f"PRAGMA table_info('{table_name}');").fetchall()
    for col in columns:
        column_names.append(col[1])
    return column_names


def get_database_info(conn):
    """Return a list of dicts containing the table name and columns for each table in the database."""
    table_dicts = []
    for table_name in get_table_names(conn):
        columns_names = get_column_names(conn, table_name)
        table_dicts.append({"table_name": table_name, "column_names": columns_names})
    return table_dicts

获取 db 中的table

可以获取chinook db中有哪些 table:

table_names = get_table_names(conn)
print("table_names=", table_names)

输出结果如下:

table_names= ['albums', 'sqlite_sequence', 'artists', 'customers', 'employees', 'genres', 'invoices', 'invoice_items', 'media_types', 'playlists', 'playlist_track', 'tracks', 'sqlite_stat1']

获取各 table 的schema

database_schema_dict = get_database_info(conn)
database_schema_string = "\n".join(
    [
        f"Table: {table['table_name']}\nColumns: {', '.join(table['column_names'])}"
        for table in database_schema_dict
    ]
)

database_schema_dict结果如下:

[{'table_name': 'albums', 'column_names': ['AlbumId', 'Title', 'ArtistId']},
 {'table_name': 'sqlite_sequence', 'column_names': ['name', 'seq']},
 {'table_name': 'artists', 'column_names': ['ArtistId', 'Name']},
 {'table_name': 'customers',
  'column_names': ['CustomerId',
   'FirstName',
   'LastName',
   'Company',
   'Address',
   'City',
   'State',
   'Country',
   'PostalCode',
   'Phone',
   'Fax',
   'Email',
   'SupportRepId']},
 {'table_name': 'employees',
  'column_names': ['EmployeeId',
   'LastName',
   'FirstName',
   'Title',
   'ReportsTo',
   'BirthDate',
   'HireDate',
   'Address',
   'City',
   'State',
   'Country',
   'PostalCode',
   'Phone',
   'Fax',
   'Email']},
 {'table_name': 'genres', 'column_names': ['GenreId', 'Name']},
 {'table_name': 'invoices',
  'column_names': ['InvoiceId',
   'CustomerId',
   'InvoiceDate',
   'BillingAddress',
   'BillingCity',
   'BillingState',
   'BillingCountry',
   'BillingPostalCode',
   'Total']},
 {'table_name': 'invoice_items',
  'column_names': ['InvoiceLineId',
   'InvoiceId',
   'TrackId',
   'UnitPrice',
   'Quantity']},
 {'table_name': 'media_types', 'column_names': ['MediaTypeId', 'Name']},
 {'table_name': 'playlists', 'column_names': ['PlaylistId', 'Name']},
 {'table_name': 'playlist_track', 'column_names': ['PlaylistId', 'TrackId']},
 {'table_name': 'tracks',
  'column_names': ['TrackId',
   'Name',
   'AlbumId',
   'MediaTypeId',
   'GenreId',
   'Composer',
   'Milliseconds',
   'Bytes',
   'UnitPrice']},
 {'table_name': 'sqlite_stat1', 'column_names': ['tbl', 'idx', 'stat']}]

database_schema_string结果如下:

'Table: albums\nColumns: AlbumId, Title, ArtistId\nTable: sqlite_sequence\nColumns: name, seq\nTable: artists\nColumns: ArtistId, Name\nTable: customers\nColumns: CustomerId, FirstName, LastName, Company, Address, City, State, Country, PostalCode, Phone, Fax, Email, SupportRepId\nTable: employees\nColumns: EmployeeId, LastName, FirstName, Title, ReportsTo, BirthDate, HireDate, Address, City, State, Country, PostalCode, Phone, Fax, Email\nTable: genres\nColumns: GenreId, Name\nTable: invoices\nColumns: InvoiceId, CustomerId, InvoiceDate, BillingAddress, BillingCity, BillingState, BillingCountry, BillingPostalCode, Total\nTable: invoice_items\nColumns: InvoiceLineId, InvoiceId, TrackId, UnitPrice, Quantity\nTable: media_types\nColumns: MediaTypeId, Name\nTable: playlists\nColumns: PlaylistId, Name\nTable: playlist_track\nColumns: PlaylistId, TrackId\nTable: tracks\nColumns: TrackId, Name, AlbumId, MediaTypeId, GenreId, Composer, Milliseconds, Bytes, UnitPrice\nTable: sqlite_stat1\nColumns: tbl, idx, stat'

定义相关函数

定义functions规范

注意,在定义functions规范时要将数据库的schema插入到函数规范中,这对模型来说是很重要的。

functions = [
    {
        "name": "ask_database",
        "description": "请使用以下函数来回答关于音乐的用户问题。输出结果应为一个完整的 SQL 查询。",
        "parameters": {
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": f"""
                            用于提取信息以回答用户问题的 SQL 查询。
                            SQL 查询应使用以下数据库模式编写:
                            {database_schema_string}
                            查询应以纯文本形式返回,而不是 JSON 格式。
                            """,
                }
            },
            "required": ["query"],
        },
    }
]

定义执行SQL语句的函数

# ChatGPT 生成的query会输入到 ask_database
def ask_database(conn, query):
    """Function to query SQLite database with a provided SQL query."""
    try:
        results = str(conn.execute(query).fetchall())
    except Exception as e:
        results = f"query failed with error: {e}"
    return results

# 根据 message["function_call"]["name"] 判断函数调用时机
def execute_function_call(message):
    if message["function_call"]["name"] == "ask_database":
        query = json.loads(message["function_call"]["arguments"])["query"]
        results = ask_database(conn, query)
    else:
        results = f"Error: function {message['function_call']['name']} does not exist"
    return results

示例1:查询曲目数量Top5的艺术家

messages = []
messages.append({"role": "system", "content": "基于 Chinook 音乐数据库生成 SQL 查询来回答用户的问题。"})
messages.append({"role": "user", "content": "你好,按照曲目数量,排名前5位的艺术家有谁?"})
chat_response = chat_completion_request(messages, functions)
print("chat_response=", chat_response.json())
assistant_message = chat_response.json()["choices"][0]["message"]
print("assistant_message=", assistant_message)
messages.append(assistant_message)
if assistant_message.get("function_call"):
    results = execute_function_call(assistant_message)
    messages.append({"role": "function", "name": assistant_message["function_call"]["name"], "content": results})
pretty_print_conversation(messages)

返回结果如下:

chat_response= {'id': 'chatcmpl-7TQdhItIU3FEgvvYliCGRoD0QetgX', 'object': 'chat.completion', 'created': 1687248269, 'model': 'gpt-3.5-turbo-0613', 'choices': [{'index': 0, 'message': {'role': 'assistant', 'content': None, 'function_call': {'name': 'ask_database', 'arguments': '{\n  "query": "SELECT artists.Name, COUNT(tracks.TrackId) AS NumTracks FROM artists JOIN albums ON artists.ArtistId = albums.ArtistId JOIN tracks ON albums.AlbumID = tracks.AlbumId GROUP BY artists.ArtistId ORDER BY NumTracks DESC LIMIT 5;"\n}'}}, 'finish_reason': 'function_call'}], 'usage': {'prompt_tokens': 448, 'completion_tokens': 67, 'total_tokens': 515}}

assistant_message= {'role': 'assistant', 'content': None, 'function_call': {'name': 'ask_database', 'arguments': '{\n  "query": "SELECT artists.Name, COUNT(tracks.TrackId) AS NumTracks FROM artists JOIN albums ON artists.ArtistId = albums.ArtistId JOIN tracks ON albums.AlbumID = tracks.AlbumId GROUP BY artists.ArtistId ORDER BY NumTracks DESC LIMIT 5;"\n}'}}

最终pretty_print_conversation(messages)的结果如下:

system: 基于 Chinook 音乐数据库生成 SQL 查询来回答用户的问题。

user: 你好,按照曲目数量,排名前5位的艺术家有谁?

assistant: {'name': 'ask_database', 'arguments': '{\n  "query": "SELECT artists.Name, COUNT(tracks.TrackId) AS NumTracks FROM artists JOIN albums ON artists.ArtistId = albums.ArtistId JOIN tracks ON albums.AlbumID = tracks.AlbumId GROUP BY artists.ArtistId ORDER BY NumTracks DESC LIMIT 5;"\n}'}

function (ask_database): [('Iron Maiden', 213), ('U2', 135), ('Led Zeppelin', 114), ('Metallica', 112), ('Deep Purple', 92)]

示例2:查询哪个专辑曲目最多

messages.append({"role": "user", "content": "曲目最多的专辑的名称是什么?"})
chat_response = chat_completion_request(messages, functions)
assistant_message = chat_response.json()["choices"][0]["message"]
messages.append(assistant_message)
if assistant_message.get("function_call"):
    results = execute_function_call(assistant_message)
    messages.append({"role": "function", "content": results, "name": assistant_message["function_call"]["name"]})
pretty_print_conversation(messages)

输出结果如下:

system: 基于 Chinook 音乐数据库生成 SQL 查询来回答用户的问题。

user: 你好,按照曲目数量,排名前5位的艺术家有谁?

assistant: {'name': 'ask_database', 'arguments': '{\n  "query": "SELECT artists.Name, COUNT(tracks.TrackId) AS NumTracks FROM artists JOIN albums ON artists.ArtistId = albums.ArtistId JOIN tracks ON albums.AlbumID = tracks.AlbumId GROUP BY artists.ArtistId ORDER BY NumTracks DESC LIMIT 5;"\n}'}

function (ask_database): [('Iron Maiden', 213), ('U2', 135), ('Led Zeppelin', 114), ('Metallica', 112), ('Deep Purple', 92)]

user: 曲目最多的专辑的名称是什么?

assistant: {'name': 'ask_database', 'arguments': '{\n  "query": "SELECT albums.Title, COUNT(tracks.TrackId) AS NumTracks FROM albums JOIN tracks ON albums.AlbumId = tracks.AlbumId GROUP BY albums.AlbumId ORDER BY NumTracks DESC LIMIT 1;"\n}'}

function (ask_database): [('Greatest Hits', 57)]

小结

通过上述示例可以确切感受openai函数调用功能的强大,这也为开发者构建更多稳健服务提供更强的保障。

附录

import json
import openai
import requests
import os
from tenacity import retry, wait_random_exponential, stop_after_attempt
from termcolor import colored

GPT_MODEL = "gpt-3.5-turbo-0613"
openai.api_key  = "sk-xxx"
os.environ['HTTP_PROXY'] = "xxx"
os.environ['HTTPS_PROXY'] = "xxx"

# 调用API的重试机制
@retry(wait=wait_random_exponential(min=1, max=40), stop=stop_after_attempt(3))
def chat_completion_request(messages, functions=None, function_call=None, model=GPT_MODEL):
    headers = {
        "Content-Type": "application/json",
        "Authorization": "Bearer " + openai.api_key,
    }
    json_data = {"model": model, "messages": messages}
    if functions is not None:
        json_data.update({"functions": functions})
    if function_call is not None:
        json_data.update({"function_call": function_call})
    try:
        response = requests.post(
            "https://api.openai.com/v1/chat/completions",
            headers=headers,
            json=json_data,
        )
        return response
    except Exception as e:
        print("Unable to generate ChatCompletion response")
        print(f"Exception: {e}")
        return e


# 处理输出,方便阅读
def pretty_print_conversation(messages):
    role_to_color = {
        "system": "red",
        "user": "green",
        "assistant": "blue",
        "function": "magenta",
    }
    formatted_messages = []
    for message in messages:
        if message["role"] == "system":
            formatted_messages.append(f"system: {message['content']}\n")
        elif message["role"] == "user":
            formatted_messages.append(f"user: {message['content']}\n")
        elif message["role"] == "assistant" and message.get("function_call"):
            formatted_messages.append(f"assistant: {message['function_call']}\n")
        elif message["role"] == "assistant" and not message.get("function_call"):
            formatted_messages.append(f"assistant: {message['content']}\n")
        elif message["role"] == "function":
            formatted_messages.append(f"function ({message['name']}): {message['content']}\n")
    for formatted_message in formatted_messages:
        print(
            colored(
                formatted_message,
                role_to_color[messages[formatted_messages.index(formatted_message)]["role"]],
            )
        )
相关文章
|
1月前
|
前端开发 关系型数据库 数据库
使用 Flask 连接数据库和用户登录功能进行数据库的CRUD
使用 Flask 连接数据库和用户登录功能进行数据库的CRUD
43 0
|
1月前
|
数据库 索引
评论功能里数据库的设计
【4月更文挑战第2天】本文探讨了评论系统的树形结构设计,提出了四种方法:邻接表、分段式path、Nested Set和Closure Table。针对评论业务功能,如加载评论页和查看回复,优先考虑邻接表和分段式path。采用邻接表思路,设计了评论表结构,包括Uid、Biz、BizID、RootID、PID、Content、索引和级联删除规则。同时提到了索引设计,如Uid、Biz+BizID、PID和Ctime/Utime,以优化查询性能。
64 3
|
1月前
|
存储 SQL 关系型数据库
关系型数据库强大的查询功能
【5月更文挑战第9天】关系型数据库强大的查询功能
28 3
|
1月前
|
存储 安全 算法
【软件设计师备考 专题 】数据库的控制功能(并发控制、恢复、安全性、完整性)
【软件设计师备考 专题 】数据库的控制功能(并发控制、恢复、安全性、完整性)
65 0
|
3天前
|
存储 缓存 安全
LLM应用实战:当图谱问答(KBQA)集成大模型(三)
本文主要是针对KBQA方案基于LLM实现存在的问题进行优化,主要涉及到响应时间提升优化以及多轮对话效果优化,提供了具体的优化方案以及相应的prompt。
13 1
|
6天前
|
Prometheus 监控 关系型数据库
数据库实时监控功能
【6月更文挑战第9天】数据库实时监控功能
16 4
|
9天前
|
存储 监控 数据管理
数据库原理与应用——简答题练习(数据管理技术发展、数据库主要特征、数据模型、关系模型、实体性之间的关系、DBMS的功能、相关术语解释、数据库系统)
数据库原理与应用——简答题练习(数据管理技术发展、数据库主要特征、数据模型、关系模型、实体性之间的关系、DBMS的功能、相关术语解释、数据库系统)
22 0
|
24天前
|
JSON 自然语言处理 API
【LLM落地应用实战】LLM + TextIn文档解析技术实测
文档解析技术是从这些海量且复杂的数据中高效准确地提取有价值信息的关键。它从输入文档图像开始,经过图像处理、版面分析、内容识别和语义理解等流程,最终输出结构化电子文档或语义信息。通过文档解析技术,我们能够深入理解文档的结构、内容和主题,使得信息更易于检索、分析和利用。
|
27天前
|
SQL 关系型数据库 开发工具
Beekeeper Studio是一个多功能的数据库管理和开发工具
【5月更文挑战第19天】Beekeeper Studio是一个多功能的数据库管理和开发工具
39 5
|
1月前
|
存储 安全 机器人
【LLM】智能学生顾问构建技术学习(Lyrz SDK + OpenAI API )
【5月更文挑战第13天】智能学生顾问构建技术学习(Lyrz SDK + OpenAI API )
45 1

热门文章

最新文章