LLM系列 | 14: 实测OpenAI函数调用功能:以数据库问答为例

本文涉及的产品
实时计算 Flink 版,5000CU*H 3个月
检索分析服务 Elasticsearch 版,2核4GB开发者规格 1个月
大数据开发治理平台 DataWorks,不限时长
简介: 今天这篇小作文以数据库问答(Text2SQL)为例进一步介绍ChatGPT的函数调用。本文将介绍如何将模型生成的结果输入到自定义的函数中,并利用该功能实现数据库问答功能。

简介

黑云翻墨未遮山,白雨跳珠乱入船。
黑云翻墨未遮山,白雨跳珠乱入船.jpg

小伙伴们好,我是微信公众号《小窗幽记机器学习》的小编:划龙舟的小男孩。紧接前面几篇ChatGPT Prompt工程和应用系列文章:

更多、更新文章欢迎关注微信公众号:小窗幽记机器学习。后续会持续整理模型加速、模型部署、模型压缩、LLM、AI艺术等系列专题,敬请关注。

今天这篇小作文以数据库问答(Text2SQL)为例进一步介绍ChatGPT的函数调用。本文将介绍如何将模型生成的结果输入到自定义的函数中,并利用该功能实现数据库问答功能。为简单起见,这里将使用Chinook 示例数据库

需要特别注意:
生产环境中,生成的SQL可能存在较高风险。因为模型在生成正确的 SQL 这方面暂不完全可靠,小伙伴们评估谨慎使用

数据库相关

环境相关设置及其辅助函数代码请于文末附录部分。以下直接介绍示例数据库相关细节。

获取 Chinook 数据相关的信息:

import sqlite3

conn = sqlite3.connect("data/chinook.db")
print("Opened database successfully")

def get_table_names(conn):
    """Return a list of table names."""
    table_names = []
    tables = conn.execute("SELECT name FROM sqlite_master WHERE type='table';")
    for table in tables.fetchall():
        table_names.append(table[0])
    return table_names


def get_column_names(conn, table_name):
    """Return a list of column names."""
    column_names = []
    columns = conn.execute(f"PRAGMA table_info('{table_name}');").fetchall()
    for col in columns:
        column_names.append(col[1])
    return column_names


def get_database_info(conn):
    """Return a list of dicts containing the table name and columns for each table in the database."""
    table_dicts = []
    for table_name in get_table_names(conn):
        columns_names = get_column_names(conn, table_name)
        table_dicts.append({"table_name": table_name, "column_names": columns_names})
    return table_dicts

获取 db 中的table

可以获取chinook db中有哪些 table:

table_names = get_table_names(conn)
print("table_names=", table_names)

输出结果如下:

table_names= ['albums', 'sqlite_sequence', 'artists', 'customers', 'employees', 'genres', 'invoices', 'invoice_items', 'media_types', 'playlists', 'playlist_track', 'tracks', 'sqlite_stat1']

获取各 table 的schema

database_schema_dict = get_database_info(conn)
database_schema_string = "\n".join(
    [
        f"Table: {table['table_name']}\nColumns: {', '.join(table['column_names'])}"
        for table in database_schema_dict
    ]
)

database_schema_dict结果如下:

[{'table_name': 'albums', 'column_names': ['AlbumId', 'Title', 'ArtistId']},
 {'table_name': 'sqlite_sequence', 'column_names': ['name', 'seq']},
 {'table_name': 'artists', 'column_names': ['ArtistId', 'Name']},
 {'table_name': 'customers',
  'column_names': ['CustomerId',
   'FirstName',
   'LastName',
   'Company',
   'Address',
   'City',
   'State',
   'Country',
   'PostalCode',
   'Phone',
   'Fax',
   'Email',
   'SupportRepId']},
 {'table_name': 'employees',
  'column_names': ['EmployeeId',
   'LastName',
   'FirstName',
   'Title',
   'ReportsTo',
   'BirthDate',
   'HireDate',
   'Address',
   'City',
   'State',
   'Country',
   'PostalCode',
   'Phone',
   'Fax',
   'Email']},
 {'table_name': 'genres', 'column_names': ['GenreId', 'Name']},
 {'table_name': 'invoices',
  'column_names': ['InvoiceId',
   'CustomerId',
   'InvoiceDate',
   'BillingAddress',
   'BillingCity',
   'BillingState',
   'BillingCountry',
   'BillingPostalCode',
   'Total']},
 {'table_name': 'invoice_items',
  'column_names': ['InvoiceLineId',
   'InvoiceId',
   'TrackId',
   'UnitPrice',
   'Quantity']},
 {'table_name': 'media_types', 'column_names': ['MediaTypeId', 'Name']},
 {'table_name': 'playlists', 'column_names': ['PlaylistId', 'Name']},
 {'table_name': 'playlist_track', 'column_names': ['PlaylistId', 'TrackId']},
 {'table_name': 'tracks',
  'column_names': ['TrackId',
   'Name',
   'AlbumId',
   'MediaTypeId',
   'GenreId',
   'Composer',
   'Milliseconds',
   'Bytes',
   'UnitPrice']},
 {'table_name': 'sqlite_stat1', 'column_names': ['tbl', 'idx', 'stat']}]

database_schema_string结果如下:

'Table: albums\nColumns: AlbumId, Title, ArtistId\nTable: sqlite_sequence\nColumns: name, seq\nTable: artists\nColumns: ArtistId, Name\nTable: customers\nColumns: CustomerId, FirstName, LastName, Company, Address, City, State, Country, PostalCode, Phone, Fax, Email, SupportRepId\nTable: employees\nColumns: EmployeeId, LastName, FirstName, Title, ReportsTo, BirthDate, HireDate, Address, City, State, Country, PostalCode, Phone, Fax, Email\nTable: genres\nColumns: GenreId, Name\nTable: invoices\nColumns: InvoiceId, CustomerId, InvoiceDate, BillingAddress, BillingCity, BillingState, BillingCountry, BillingPostalCode, Total\nTable: invoice_items\nColumns: InvoiceLineId, InvoiceId, TrackId, UnitPrice, Quantity\nTable: media_types\nColumns: MediaTypeId, Name\nTable: playlists\nColumns: PlaylistId, Name\nTable: playlist_track\nColumns: PlaylistId, TrackId\nTable: tracks\nColumns: TrackId, Name, AlbumId, MediaTypeId, GenreId, Composer, Milliseconds, Bytes, UnitPrice\nTable: sqlite_stat1\nColumns: tbl, idx, stat'

定义相关函数

定义functions规范

注意,在定义functions规范时要将数据库的schema插入到函数规范中,这对模型来说是很重要的。

functions = [
    {
        "name": "ask_database",
        "description": "请使用以下函数来回答关于音乐的用户问题。输出结果应为一个完整的 SQL 查询。",
        "parameters": {
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": f"""
                            用于提取信息以回答用户问题的 SQL 查询。
                            SQL 查询应使用以下数据库模式编写:
                            {database_schema_string}
                            查询应以纯文本形式返回,而不是 JSON 格式。
                            """,
                }
            },
            "required": ["query"],
        },
    }
]

定义执行SQL语句的函数

# ChatGPT 生成的query会输入到 ask_database
def ask_database(conn, query):
    """Function to query SQLite database with a provided SQL query."""
    try:
        results = str(conn.execute(query).fetchall())
    except Exception as e:
        results = f"query failed with error: {e}"
    return results

# 根据 message["function_call"]["name"] 判断函数调用时机
def execute_function_call(message):
    if message["function_call"]["name"] == "ask_database":
        query = json.loads(message["function_call"]["arguments"])["query"]
        results = ask_database(conn, query)
    else:
        results = f"Error: function {message['function_call']['name']} does not exist"
    return results

示例1:查询曲目数量Top5的艺术家

messages = []
messages.append({"role": "system", "content": "基于 Chinook 音乐数据库生成 SQL 查询来回答用户的问题。"})
messages.append({"role": "user", "content": "你好,按照曲目数量,排名前5位的艺术家有谁?"})
chat_response = chat_completion_request(messages, functions)
print("chat_response=", chat_response.json())
assistant_message = chat_response.json()["choices"][0]["message"]
print("assistant_message=", assistant_message)
messages.append(assistant_message)
if assistant_message.get("function_call"):
    results = execute_function_call(assistant_message)
    messages.append({"role": "function", "name": assistant_message["function_call"]["name"], "content": results})
pretty_print_conversation(messages)

返回结果如下:

chat_response= {'id': 'chatcmpl-7TQdhItIU3FEgvvYliCGRoD0QetgX', 'object': 'chat.completion', 'created': 1687248269, 'model': 'gpt-3.5-turbo-0613', 'choices': [{'index': 0, 'message': {'role': 'assistant', 'content': None, 'function_call': {'name': 'ask_database', 'arguments': '{\n  "query": "SELECT artists.Name, COUNT(tracks.TrackId) AS NumTracks FROM artists JOIN albums ON artists.ArtistId = albums.ArtistId JOIN tracks ON albums.AlbumID = tracks.AlbumId GROUP BY artists.ArtistId ORDER BY NumTracks DESC LIMIT 5;"\n}'}}, 'finish_reason': 'function_call'}], 'usage': {'prompt_tokens': 448, 'completion_tokens': 67, 'total_tokens': 515}}

assistant_message= {'role': 'assistant', 'content': None, 'function_call': {'name': 'ask_database', 'arguments': '{\n  "query": "SELECT artists.Name, COUNT(tracks.TrackId) AS NumTracks FROM artists JOIN albums ON artists.ArtistId = albums.ArtistId JOIN tracks ON albums.AlbumID = tracks.AlbumId GROUP BY artists.ArtistId ORDER BY NumTracks DESC LIMIT 5;"\n}'}}

最终pretty_print_conversation(messages)的结果如下:

system: 基于 Chinook 音乐数据库生成 SQL 查询来回答用户的问题。

user: 你好,按照曲目数量,排名前5位的艺术家有谁?

assistant: {'name': 'ask_database', 'arguments': '{\n  "query": "SELECT artists.Name, COUNT(tracks.TrackId) AS NumTracks FROM artists JOIN albums ON artists.ArtistId = albums.ArtistId JOIN tracks ON albums.AlbumID = tracks.AlbumId GROUP BY artists.ArtistId ORDER BY NumTracks DESC LIMIT 5;"\n}'}

function (ask_database): [('Iron Maiden', 213), ('U2', 135), ('Led Zeppelin', 114), ('Metallica', 112), ('Deep Purple', 92)]

示例2:查询哪个专辑曲目最多

messages.append({"role": "user", "content": "曲目最多的专辑的名称是什么?"})
chat_response = chat_completion_request(messages, functions)
assistant_message = chat_response.json()["choices"][0]["message"]
messages.append(assistant_message)
if assistant_message.get("function_call"):
    results = execute_function_call(assistant_message)
    messages.append({"role": "function", "content": results, "name": assistant_message["function_call"]["name"]})
pretty_print_conversation(messages)

输出结果如下:

system: 基于 Chinook 音乐数据库生成 SQL 查询来回答用户的问题。

user: 你好,按照曲目数量,排名前5位的艺术家有谁?

assistant: {'name': 'ask_database', 'arguments': '{\n  "query": "SELECT artists.Name, COUNT(tracks.TrackId) AS NumTracks FROM artists JOIN albums ON artists.ArtistId = albums.ArtistId JOIN tracks ON albums.AlbumID = tracks.AlbumId GROUP BY artists.ArtistId ORDER BY NumTracks DESC LIMIT 5;"\n}'}

function (ask_database): [('Iron Maiden', 213), ('U2', 135), ('Led Zeppelin', 114), ('Metallica', 112), ('Deep Purple', 92)]

user: 曲目最多的专辑的名称是什么?

assistant: {'name': 'ask_database', 'arguments': '{\n  "query": "SELECT albums.Title, COUNT(tracks.TrackId) AS NumTracks FROM albums JOIN tracks ON albums.AlbumId = tracks.AlbumId GROUP BY albums.AlbumId ORDER BY NumTracks DESC LIMIT 1;"\n}'}

function (ask_database): [('Greatest Hits', 57)]

小结

通过上述示例可以确切感受openai函数调用功能的强大,这也为开发者构建更多稳健服务提供更强的保障。

附录

import json
import openai
import requests
import os
from tenacity import retry, wait_random_exponential, stop_after_attempt
from termcolor import colored

GPT_MODEL = "gpt-3.5-turbo-0613"
openai.api_key  = "sk-xxx"
os.environ['HTTP_PROXY'] = "xxx"
os.environ['HTTPS_PROXY'] = "xxx"

# 调用API的重试机制
@retry(wait=wait_random_exponential(min=1, max=40), stop=stop_after_attempt(3))
def chat_completion_request(messages, functions=None, function_call=None, model=GPT_MODEL):
    headers = {
        "Content-Type": "application/json",
        "Authorization": "Bearer " + openai.api_key,
    }
    json_data = {"model": model, "messages": messages}
    if functions is not None:
        json_data.update({"functions": functions})
    if function_call is not None:
        json_data.update({"function_call": function_call})
    try:
        response = requests.post(
            "https://api.openai.com/v1/chat/completions",
            headers=headers,
            json=json_data,
        )
        return response
    except Exception as e:
        print("Unable to generate ChatCompletion response")
        print(f"Exception: {e}")
        return e


# 处理输出,方便阅读
def pretty_print_conversation(messages):
    role_to_color = {
        "system": "red",
        "user": "green",
        "assistant": "blue",
        "function": "magenta",
    }
    formatted_messages = []
    for message in messages:
        if message["role"] == "system":
            formatted_messages.append(f"system: {message['content']}\n")
        elif message["role"] == "user":
            formatted_messages.append(f"user: {message['content']}\n")
        elif message["role"] == "assistant" and message.get("function_call"):
            formatted_messages.append(f"assistant: {message['function_call']}\n")
        elif message["role"] == "assistant" and not message.get("function_call"):
            formatted_messages.append(f"assistant: {message['content']}\n")
        elif message["role"] == "function":
            formatted_messages.append(f"function ({message['name']}): {message['content']}\n")
    for formatted_message in formatted_messages:
        print(
            colored(
                formatted_message,
                role_to_color[messages[formatted_messages.index(formatted_message)]["role"]],
            )
        )
相关文章
|
2天前
|
前端开发 关系型数据库 数据库
使用 Flask 连接数据库和用户登录功能进行数据库的CRUD
使用 Flask 连接数据库和用户登录功能进行数据库的CRUD
27 0
|
2天前
|
数据库 索引
评论功能里数据库的设计
【4月更文挑战第2天】本文探讨了评论系统的树形结构设计,提出了四种方法:邻接表、分段式path、Nested Set和Closure Table。针对评论业务功能,如加载评论页和查看回复,优先考虑邻接表和分段式path。采用邻接表思路,设计了评论表结构,包括Uid、Biz、BizID、RootID、PID、Content、索引和级联删除规则。同时提到了索引设计,如Uid、Biz+BizID、PID和Ctime/Utime,以优化查询性能。
45 3
|
2天前
|
人工智能 自然语言处理 开发工具
AI2 开源新 LLM,重新定义 open AI
艾伦人工智能研究所(Allen Institute for AI,简称 AI2)宣布推出一个名为 OLMo 7B 的新大语言模型,并开源发布了预训练数据和训练代码。OLMo 7B 被描述为 “一个真正开放的、最先进的大型语言模型”。
|
2天前
|
存储 安全 算法
【软件设计师备考 专题 】数据库的控制功能(并发控制、恢复、安全性、完整性)
【软件设计师备考 专题 】数据库的控制功能(并发控制、恢复、安全性、完整性)
60 0
|
2天前
|
存储 安全 机器人
【LLM】智能学生顾问构建技术学习(Lyrz SDK + OpenAI API )
【5月更文挑战第13天】智能学生顾问构建技术学习(Lyrz SDK + OpenAI API )
|
2天前
|
人工智能 API 开发工具
[译][AI OpenAI-doc] 函数调用 Beta
类似于聊天完成 API,助手 API 支持函数调用。函数调用允许您描述函数给助手 API,并让它智能地返回需要调用的函数及其参数。
|
2天前
|
数据管理 关系型数据库 MySQL
数据管理DMS产品使用合集之DMS可以接入其他平台的MySQL数据库,是否还支持无感知变更功能
阿里云数据管理DMS提供了全面的数据管理、数据库运维、数据安全、数据迁移与同步等功能,助力企业高效、安全地进行数据库管理和运维工作。以下是DMS产品使用合集的详细介绍。
|
2天前
|
存储 NoSQL 容灾
数据库非功能需求分析
本文探讨了业务研发在技术设计中如何满足非功能需求,重点关注数据库系统的角色。内容涵盖数据库的可用性、可靠性、性能、可修改性、安全性及成本。文章强调了根据业务场景选择合适的数据类型(如关系型、非关系型、内存型、图数据库和时间序列数据库)以及考虑数据容量和增长速度。对于性能需求,讨论了响应时间、吞吐量和并发处理能力。此外,还提到了升级路径、兼容性、备份方案和成本控制(硬件、软件和人力成本)在数据库管理中的重要性。
35 0
|
2天前
|
SQL 存储 安全
【软件设计师备考 专题 】数据库管理系统的功能和特征
【软件设计师备考 专题 】数据库管理系统的功能和特征
76 0
|
2天前
|
存储 SQL 数据挖掘
视图、触发器和存储过程:提升数据库功能
视图、触发器和存储过程:提升数据库功能
19 1

热门文章

最新文章