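Introduction
Black clouds spill like ink yet fail to hide the hills; white rain bounces like pearls into the boat.
Hello everyone, I'm the editor of the WeChat public account 《小窗幽记机器学习》, 划龙舟的小男孩. This post follows the earlier articles in the ChatGPT Prompt engineering and applications series: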
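- 04: A Guide to Writing ChatGPT Prompts
- 05: How to Optimize ChatGPT Prompts?
- 06: ChatGPT Prompt in Practice: Text Summarization, Inference & Transformation
- 07: ChatGPT Prompt in Practice: A Smart Customer-Service Email Example
- 08: ChatGPT Prompt in Practice: How to Build an Order-Taking Bot with ChatGPT?
- 09: Building a Smart Customer-Service System on ChatGPT (Query Classification, Safety Moderation & Injection Prevention)
- 10: How to Write Chain-of-Thought Prompts? A Smart Customer-Service Example
- 11: Is LangChain in Trouble? Hands-On with ChatGPT Function Calling: A Weather Q&A Example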
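For more and newer articles, follow the WeChat public account 小窗幽记机器学习. Upcoming series will cover model acceleration, model deployment, model compression, LLMs, AI art, and more.
Today's post uses database question answering (Text2SQL) as an example to take a further look at ChatGPT function calling. It shows how to pass model-generated output to a custom function and how to use that to implement database Q&A. For simplicity, it uses the Chinook sample database.
Important note:
Running generated SQL in a production environment can carry considerable risk. The model is not yet fully reliable at generating correct SQL, so evaluate carefully before using it.
Database
For the environment setup and helper-function code, see the appendix at the end of this post. The sample database is described directly below.
Retrieve information about the Chinook data: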
import sqlite3
conn = sqlite3.connect("data/chinook.db")
print("Opened database successfully")
def get_table_names(conn):
    """Return a list of table names."""
    table_names = []
    tables = conn.execute("SELECT name FROM sqlite_master WHERE type='table';")
    for table in tables.fetchall():
        table_names.append(table[0])
    return table_names
def get_column_names(conn, table_name):
    """Return a list of column names."""
    column_names = []
    columns = conn.execute(f"PRAGMA table_info('{table_name}');").fetchall()
    for col in columns:
        column_names.append(col[1])
    return column_names
def get_database_info(conn):
    """Return a list of dicts containing the table name and columns for each table in the database."""
    table_dicts = []
    for table_name in get_table_names(conn):
        columns_names = get_column_names(conn, table_name)
        table_dicts.append({"table_name": table_name, "column_names": columns_names})
    return table_dicts
Listing the tables in the db
The tables in the chinook db can be listed as follows:
table_names = get_table_names(conn)
print("table_names=", table_names)
The output is as follows:
table_names= ['albums', 'sqlite_sequence', 'artists', 'customers', 'employees', 'genres', 'invoices', 'invoice_items', 'media_types', 'playlists', 'playlist_track', 'tracks', 'sqlite_stat1']
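The list above includes sqlite_sequence and sqlite_stat1, which are SQLite's internal bookkeeping tables rather than part of the music data. Because the schema is later pasted into the function specification, it may be worth leaving them out. The following is a minimal sketch and not part of the original code: `get_user_table_names` is a hypothetical variant of `get_table_names`, and the small in-memory database only stands in for Chinook.

```python
import sqlite3


def get_user_table_names(conn):
    """Return table names, skipping SQLite's internal sqlite_* tables."""
    rows = conn.execute(
        "SELECT name FROM sqlite_master "
        "WHERE type='table' AND name NOT LIKE 'sqlite_%' "
        "ORDER BY name;"
    ).fetchall()
    return [row[0] for row in rows]


# Tiny in-memory stand-in for Chinook (an assumption for illustration only).
demo_conn = sqlite3.connect(":memory:")
# AUTOINCREMENT makes SQLite create the internal sqlite_sequence table.
demo_conn.execute(
    "CREATE TABLE artists (ArtistId INTEGER PRIMARY KEY AUTOINCREMENT, Name TEXT)"
)
demo_conn.execute(
    "CREATE TABLE albums (AlbumId INTEGER PRIMARY KEY, Title TEXT, ArtistId INTEGER)"
)

print(get_user_table_names(demo_conn))
```

Sorting by name is a choice made here for predictable output; the original helper returns tables in the order stored in sqlite_master.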
Getting the schema of each table
database_schema_dict = get_database_info(conn)
database_schema_string = "\n".join(
    [
        f"Table: {table['table_name']}\nColumns: {', '.join(table['column_names'])}"
        for table in database_schema_dict
    ]
)
database_schema_dict
The result is as follows:
[{'table_name': 'albums', 'column_names': ['AlbumId', 'Title', 'ArtistId']},
{'table_name': 'sqlite_sequence', 'column_names': ['name', 'seq']},
{'table_name': 'artists', 'column_names': ['ArtistId', 'Name']},
{'table_name': 'customers',
'column_names': ['CustomerId',
'FirstName',
'LastName',
'Company',
'Address',
'City',
'State',
'Country',
'PostalCode',
'Phone',
'Fax',
'Email',
'SupportRepId']},
{'table_name': 'employees',
'column_names': ['EmployeeId',
'LastName',
'FirstName',
'Title',
'ReportsTo',
'BirthDate',
'HireDate',
'Address',
'City',
'State',
'Country',
'PostalCode',
'Phone',
'Fax',
'Email']},
{'table_name': 'genres', 'column_names': ['GenreId', 'Name']},
{'table_name': 'invoices',
'column_names': ['InvoiceId',
'CustomerId',
'InvoiceDate',
'BillingAddress',
'BillingCity',
'BillingState',
'BillingCountry',
'BillingPostalCode',
'Total']},
{'table_name': 'invoice_items',
'column_names': ['InvoiceLineId',
'InvoiceId',
'TrackId',
'UnitPrice',
'Quantity']},
{'table_name': 'media_types', 'column_names': ['MediaTypeId', 'Name']},
{'table_name': 'playlists', 'column_names': ['PlaylistId', 'Name']},
{'table_name': 'playlist_track', 'column_names': ['PlaylistId', 'TrackId']},
{'table_name': 'tracks',
'column_names': ['TrackId',
'Name',
'AlbumId',
'MediaTypeId',
'GenreId',
'Composer',
'Milliseconds',
'Bytes',
'UnitPrice']},
{'table_name': 'sqlite_stat1', 'column_names': ['tbl', 'idx', 'stat']}]
database_schema_string
The result is as follows:
'Table: albums\nColumns: AlbumId, Title, ArtistId\nTable: sqlite_sequence\nColumns: name, seq\nTable: artists\nColumns: ArtistId, Name\nTable: customers\nColumns: CustomerId, FirstName, LastName, Company, Address, City, State, Country, PostalCode, Phone, Fax, Email, SupportRepId\nTable: employees\nColumns: EmployeeId, LastName, FirstName, Title, ReportsTo, BirthDate, HireDate, Address, City, State, Country, PostalCode, Phone, Fax, Email\nTable: genres\nColumns: GenreId, Name\nTable: invoices\nColumns: InvoiceId, CustomerId, InvoiceDate, BillingAddress, BillingCity, BillingState, BillingCountry, BillingPostalCode, Total\nTable: invoice_items\nColumns: InvoiceLineId, InvoiceId, TrackId, UnitPrice, Quantity\nTable: media_types\nColumns: MediaTypeId, Name\nTable: playlists\nColumns: PlaylistId, Name\nTable: playlist_track\nColumns: PlaylistId, TrackId\nTable: tracks\nColumns: TrackId, Name, AlbumId, MediaTypeId, GenreId, Composer, Milliseconds, Bytes, UnitPrice\nTable: sqlite_stat1\nColumns: tbl, idx, stat'
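Defining the functions
Defining the functions specification
Note that the database schema must be inserted into the function specification when it is defined. This matters a great deal, because the specification is how the model learns which tables and columns exist.
functions = [
    {
        "name": "ask_database",
        # English gloss: "Use this function to answer user questions about music.
        # The output should be a complete SQL query."
        "description": "请使用以下函数来回答关于音乐的用户问题。输出结果应为一个完整的 SQL 查询。",
        "parameters": {
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    # English gloss: "SQL query that extracts the information needed to answer
                    # the user's question. Write the SQL against the following database schema:
                    # {database_schema_string}. Return the query as plain text, not as JSON."
                    "description": f"""
                            用于提取信息以回答用户问题的 SQL 查询。
                            SQL 查询应使用以下数据库模式编写:
                            {database_schema_string}
                            查询应以纯文本形式返回,而不是 JSON 格式。
                            """,
                }
            },
            "required": ["query"],
        },
    }
]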
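The appendix's chat_completion_request also accepts a function_call argument that the examples in this post leave unset. In the Chat Completions API version used here, function_call can be "auto", "none", or a dict such as {"name": "ask_database"} that forces the named function to be called. The sketch below only assembles the request body and makes no network call; `build_payload` is a hypothetical helper that mirrors the appendix code, and `demo_functions` is an abbreviated stand-in for the specification above.

```python
def build_payload(messages, functions=None, function_call=None, model="gpt-3.5-turbo-0613"):
    """Assemble the JSON body the same way chat_completion_request does."""
    payload = {"model": model, "messages": messages}
    if functions is not None:
        payload["functions"] = functions
    if function_call is not None:
        payload["function_call"] = function_call
    return payload


# Abbreviated stand-in for the functions specification (illustration only).
demo_functions = [
    {
        "name": "ask_database",
        "description": "Answer questions about music with a SQL query.",
        "parameters": {
            "type": "object",
            "properties": {"query": {"type": "string"}},
            "required": ["query"],
        },
    }
]

payload = build_payload(
    [{"role": "user", "content": "Which album has the most tracks?"}],
    functions=demo_functions,
    function_call={"name": "ask_database"},  # force this function to be called
)
print(payload["function_call"])
```

Forcing the call can be useful when every user turn is expected to hit the database; leaving function_call unset lets the model decide.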
Defining the function that executes SQL statements
import json  # also imported in the appendix
# The query generated by ChatGPT is passed to ask_database
def ask_database(conn, query):
    """Function to query SQLite database with a provided SQL query."""
    try:
        results = str(conn.execute(query).fetchall())
    except Exception as e:
        results = f"query failed with error: {e}"
    return results
# Dispatch on message["function_call"]["name"] to decide which function to run
def execute_function_call(message):
    if message["function_call"]["name"] == "ask_database":
        # "arguments" is a JSON-encoded string, so it has to be parsed first
        query = json.loads(message["function_call"]["arguments"])["query"]
        results = ask_database(conn, query)
    else:
        results = f"Error: function {message['function_call']['name']} does not exist"
    return results
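ask_database executes whatever SQL the model produced on an ordinary read-write connection, which is where the risk flagged in the introduction comes from. One way to narrow that risk is to open the database read-only and cap the number of rows returned. This is a minimal sketch under stated assumptions, not the original author's code: `open_readonly`, `ask_database_readonly`, and `max_rows` are hypothetical names, and the throwaway file stands in for data/chinook.db.

```python
import os
import sqlite3
import tempfile
from pathlib import Path


def open_readonly(db_path):
    """Open an existing SQLite file in read-only mode via a file: URI."""
    uri = Path(db_path).resolve().as_uri() + "?mode=ro"
    return sqlite3.connect(uri, uri=True)


def ask_database_readonly(conn, query, max_rows=50):
    """Run a query and return at most max_rows rows as a string."""
    try:
        return str(conn.execute(query).fetchmany(max_rows))
    except Exception as e:
        return f"query failed with error: {e}"


# Build a throwaway database file (a stand-in for data/chinook.db).
demo_path = os.path.join(tempfile.mkdtemp(), "demo.db")
setup_conn = sqlite3.connect(demo_path)
setup_conn.execute("CREATE TABLE artists (ArtistId INTEGER PRIMARY KEY, Name TEXT)")
setup_conn.execute("INSERT INTO artists (Name) VALUES ('Iron Maiden'), ('U2')")
setup_conn.commit()
setup_conn.close()

ro_conn = open_readonly(demo_path)
# A SELECT works as before.
print(ask_database_readonly(ro_conn, "SELECT Name FROM artists ORDER BY ArtistId;"))
# A write is rejected by SQLite and comes back as an error string.
print(ask_database_readonly(ro_conn, "DELETE FROM artists;"))
```

A read-only connection does not make generated SQL correct; it only keeps a wrong or malicious statement from modifying data.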
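Example 1: the top 5 artists by number of tracks
messages = []
# System prompt, English gloss: "Answer the user's questions by generating SQL queries against the Chinook music database."
messages.append({"role": "system", "content": "基于 Chinook 音乐数据库生成 SQL 查询来回答用户的问题。"})
# User message, English gloss: "Hi, who are the top 5 artists by number of tracks?"
messages.append({"role": "user", "content": "你好,按照曲目数量,排名前5位的艺术家有谁?"})
chat_response = chat_completion_request(messages, functions)
print("chat_response=", chat_response.json())
assistant_message = chat_response.json()["choices"][0]["message"]
print("assistant_message=", assistant_message)
messages.append(assistant_message)
# If the model asked for a function call, run it and append the result as a "function" message
if assistant_message.get("function_call"):
    results = execute_function_call(assistant_message)
    messages.append({"role": "function", "name": assistant_message["function_call"]["name"], "content": results})
pretty_print_conversation(messages)
The returned result is as follows: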
chat_response= {'id': 'chatcmpl-7TQdhItIU3FEgvvYliCGRoD0QetgX', 'object': 'chat.completion', 'created': 1687248269, 'model': 'gpt-3.5-turbo-0613', 'choices': [{'index': 0, 'message': {'role': 'assistant', 'content': None, 'function_call': {'name': 'ask_database', 'arguments': '{\n "query": "SELECT artists.Name, COUNT(tracks.TrackId) AS NumTracks FROM artists JOIN albums ON artists.ArtistId = albums.ArtistId JOIN tracks ON albums.AlbumID = tracks.AlbumId GROUP BY artists.ArtistId ORDER BY NumTracks DESC LIMIT 5;"\n}'}}, 'finish_reason': 'function_call'}], 'usage': {'prompt_tokens': 448, 'completion_tokens': 67, 'total_tokens': 515}}
assistant_message= {'role': 'assistant', 'content': None, 'function_call': {'name': 'ask_database', 'arguments': '{\n "query": "SELECT artists.Name, COUNT(tracks.TrackId) AS NumTracks FROM artists JOIN albums ON artists.ArtistId = albums.ArtistId JOIN tracks ON albums.AlbumID = tracks.AlbumId GROUP BY artists.ArtistId ORDER BY NumTracks DESC LIMIT 5;"\n}'}}
The final output of pretty_print_conversation(messages) is as follows:
system: 基于 Chinook 音乐数据库生成 SQL 查询来回答用户的问题。
user: 你好,按照曲目数量,排名前5位的艺术家有谁?
assistant: {'name': 'ask_database', 'arguments': '{\n "query": "SELECT artists.Name, COUNT(tracks.TrackId) AS NumTracks FROM artists JOIN albums ON artists.ArtistId = albums.ArtistId JOIN tracks ON albums.AlbumID = tracks.AlbumId GROUP BY artists.ArtistId ORDER BY NumTracks DESC LIMIT 5;"\n}'}
function (ask_database): [('Iron Maiden', 213), ('U2', 135), ('Led Zeppelin', 114), ('Metallica', 112), ('Deep Purple', 92)]
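In the response shown above, the arguments field of function_call is a JSON-encoded string rather than a dict, and the model's output is not guaranteed to be valid JSON. The following sketch parses it defensively; `extract_query` is a hypothetical helper, and the two sample calls are made up to mirror the shape of the real response (with shortened SQL).

```python
import json


def extract_query(function_call):
    """Return the SQL string from a function_call dict, or None if it cannot be parsed."""
    try:
        arguments = json.loads(function_call["arguments"])
    except (KeyError, TypeError, json.JSONDecodeError):
        return None
    if not isinstance(arguments, dict):
        return None
    query = arguments.get("query")
    return query if isinstance(query, str) else None


# Same shape as the function_call in the response (SQL shortened for the demo).
good_call = {
    "name": "ask_database",
    "arguments": '{\n "query": "SELECT Name FROM artists LIMIT 5;"\n}',
}
# Truncated JSON, as might happen if generation is cut off.
bad_call = {"name": "ask_database", "arguments": '{"query": "SELECT Name FROM'}

print(extract_query(good_call))
print(extract_query(bad_call))
```

Returning None lets the caller decide what to do next, for example reporting the parse failure back to the model as the function result.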
Example 2: which album has the most tracks
# User message, English gloss: "What is the name of the album with the most tracks?"
messages.append({"role": "user", "content": "曲目最多的专辑的名称是什么?"})
chat_response = chat_completion_request(messages, functions)
assistant_message = chat_response.json()["choices"][0]["message"]
messages.append(assistant_message)
if assistant_message.get("function_call"):
    results = execute_function_call(assistant_message)
    messages.append({"role": "function", "content": results, "name": assistant_message["function_call"]["name"]})
pretty_print_conversation(messages)
The output is as follows:
system: 基于 Chinook 音乐数据库生成 SQL 查询来回答用户的问题。
user: 你好,按照曲目数量,排名前5位的艺术家有谁?
assistant: {'name': 'ask_database', 'arguments': '{\n "query": "SELECT artists.Name, COUNT(tracks.TrackId) AS NumTracks FROM artists JOIN albums ON artists.ArtistId = albums.ArtistId JOIN tracks ON albums.AlbumID = tracks.AlbumId GROUP BY artists.ArtistId ORDER BY NumTracks DESC LIMIT 5;"\n}'}
function (ask_database): [('Iron Maiden', 213), ('U2', 135), ('Led Zeppelin', 114), ('Metallica', 112), ('Deep Purple', 92)]
user: 曲目最多的专辑的名称是什么?
assistant: {'name': 'ask_database', 'arguments': '{\n "query": "SELECT albums.Title, COUNT(tracks.TrackId) AS NumTracks FROM albums JOIN tracks ON albums.AlbumId = tracks.AlbumId GROUP BY albums.AlbumId ORDER BY NumTracks DESC LIMIT 1;"\n}'}
function (ask_database): [('Greatest Hits', 57)]
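Both examples stop once the function result has been appended, so the model has not yet turned the rows into a natural-language reply. Sending the conversation back to the model is the usual next step. The sketch below shows that loop with the model call and the function execution passed in as callables; `run_until_answer`, `fake_request`, and `fake_execute` are hypothetical names, and the fakes exist only so the sketch runs offline. In the real flow, request_fn would presumably wrap chat_completion_request(messages, functions) and return .json()["choices"][0]["message"], and execute_fn would be execute_function_call.

```python
def run_until_answer(messages, request_fn, execute_fn, max_rounds=3):
    """Call the model repeatedly until it replies with plain content.

    request_fn(messages) must return an assistant message dict;
    execute_fn(assistant_message) must return the function result as a string.
    """
    for _ in range(max_rounds):
        assistant_message = request_fn(messages)
        messages.append(assistant_message)
        call = assistant_message.get("function_call")
        if not call:
            return assistant_message["content"]
        messages.append(
            {"role": "function", "name": call["name"], "content": execute_fn(assistant_message)}
        )
    return None  # gave up: the model kept requesting function calls


# Offline stand-ins so the sketch runs without the API (assumptions for illustration).
def fake_request(messages):
    if messages[-1]["role"] == "function":
        return {"role": "assistant", "content": "Result: " + messages[-1]["content"]}
    return {
        "role": "assistant",
        "content": None,
        "function_call": {"name": "ask_database", "arguments": '{"query": "SELECT 1;"}'},
    }


def fake_execute(assistant_message):
    return "[('Greatest Hits', 57)]"


demo_messages = [{"role": "user", "content": "Which album has the most tracks?"}]
answer = run_until_answer(demo_messages, fake_request, fake_execute)
print(answer)
```

The max_rounds cap is a design choice made here to keep a misbehaving model from looping indefinitely.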
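Summary
The examples above show how capable OpenAI's function calling feature is: the model turns a natural-language question into a structured call that ordinary code can execute. This gives developers a stronger foundation for building more robust services.
Appendix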
import json
import openai
import requests
import os
from tenacity import retry, wait_random_exponential, stop_after_attempt
from termcolor import colored
GPT_MODEL = "gpt-3.5-turbo-0613"
openai.api_key = "sk-xxx"
os.environ['HTTP_PROXY'] = "xxx"
os.environ['HTTPS_PROXY'] = "xxx"
# Retry mechanism for API calls
# Note: exceptions are caught and returned inside the function, so @retry never sees a failure to retry.
@retry(wait=wait_random_exponential(min=1, max=40), stop=stop_after_attempt(3))
def chat_completion_request(messages, functions=None, function_call=None, model=GPT_MODEL):
    headers = {
        "Content-Type": "application/json",
        "Authorization": "Bearer " + openai.api_key,
    }
    json_data = {"model": model, "messages": messages}
    if functions is not None:
        json_data.update({"functions": functions})
    if function_call is not None:
        json_data.update({"function_call": function_call})
    try:
        response = requests.post(
            "https://api.openai.com/v1/chat/completions",
            headers=headers,
            json=json_data,
        )
        return response
    except Exception as e:
        print("Unable to generate ChatCompletion response")
        print(f"Exception: {e}")
        return e
# Format the output so it is easier to read
def pretty_print_conversation(messages):
    role_to_color = {
        "system": "red",
        "user": "green",
        "assistant": "blue",
        "function": "magenta",
    }
    for message in messages:
        role = message["role"]
        if role == "system":
            text = f"system: {message['content']}\n"
        elif role == "user":
            text = f"user: {message['content']}\n"
        elif role == "assistant" and message.get("function_call"):
            text = f"assistant: {message['function_call']}\n"
        elif role == "assistant":
            text = f"assistant: {message['content']}\n"
        elif role == "function":
            text = f"function ({message['name']}): {message['content']}\n"
        else:
            continue  # skip roles this helper does not know how to format
        # Color each line by the role of the message it came from
        print(colored(text, role_to_color[role]))