程序调用大模型返回结构化输出(JSON)
大家很多时候使用langchain、llamaindex等大模型框架的时候,一直很头疼的就是大模型的answer不一定就按照约定的结构化数据返回,那么下面就以讯飞的spark为例解决这个问题
星火模型调用类
按照OpenAI的sdk的格式,封装了讯飞星火的调用类如下:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
'''
@File : SparkApi.py
@Time : 2024/08/08 15:30:54
@Author : CrissChan
@Version : 1.0
@Site : https://blog.csdn.net/crisschan
@Desc : 按照OpenAI的SDK格式封装的讯飞星火大模型的调用类
'''
import _thread as thread
import base64
import datetime
import hashlib
import hmac
import json
from urllib.parse import urlparse, urlencode
import ssl
from datetime import datetime
from time import mktime
from wsgiref.handlers import format_date_time
import websocket
class SparkAI:
def __init__(self, appid, api_key, api_secret, spark_url, domain):
self.appid = appid
self.api_key = api_key
self.api_secret = api_secret
self.spark_url = spark_url
self.domain = domain
self.answer = ""
def _create_url(self):
# 生成URL的逻辑保持不变
now = datetime.now()
date = format_date_time(mktime(now.timetuple()))
host = urlparse(self.spark_url).netloc
path = urlparse(self.spark_url).path
signature_origin = f"host: {host}\ndate: {date}\nGET {path} HTTP/1.1"
signature_sha = hmac.new(self.api_secret.encode('utf-8'), signature_origin.encode('utf-8'),
digestmod=hashlib.sha256).digest()
signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')
authorization_origin = f'api_key="{self.api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
v = {
"authorization": authorization,
"date": date,
"host": host
}
return self.spark_url + '?' + urlencode(v)
def _on_message(self, ws, message):
data = json.loads(message)
code = data['header']['code']
if code != 0:
print(f'请求错误: {code}, {data}')
ws.close()
else:
choices = data["payload"]["choices"]
status = choices["status"]
content = choices["text"][0]["content"]
print(content, end="")
self.answer += content
if status == 2:
ws.close()
def _gen_params(self, question):
return {
"header": {"app_id": self.appid, "uid": "1234"},
"parameter": {
"chat": {
"domain": self.domain,
"random_threshold": 0.5,
"max_tokens": 2048,
"auditing": "default"
}
},
"payload": {
"message": {
"text": [
{"role": "user", "content": question}
]
}
}
}
def chat(self, question):
self.answer = "" # 重置答案
ws_url = self._create_url()
ws = websocket.WebSocketApp(
ws_url,
on_message=self._on_message,
on_error=lambda ws, error: print("### error:", error),
on_close=lambda ws, close_status_code, close_msg: print(" "),
on_open=lambda ws: ws.send(json.dumps(self._gen_params(question)))
)
ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
return self.answer
调用SparkAI返回结构数据
将星火大模型的appid、api_secret、apikey都存放在.env配置文件中,通过dotenv将内容读出来存入变量中(_变量的介绍可以再次文章中参考https://blog.csdn.net/crisschan/article/details/133277855)
from dotenv import load_dotenv, find_dotenv
_=load_dotenv(find_dotenv())
appid = os.getenv("SPARK_APP_ID")
api_secret=os.getenv("SPARK_APP_SECRET")
api_key=os.getenv("SPARK_APP_KEY")
#用于配置大模型版本,默认"general/generalv2"
# domain = "general" # v1.5版本
domain = "generalv2" # v2.0版本
#云端环境的服务地址
spark_url = "ws://spark-api.xf-yun.com/v2.1/chat" # v2.0环境的地址ws(s)://spark-api.xf-yun.com/v2.1/chat
client = SparkAI(appid=appid, api_key=api_key, api_secret=api_secret, spark_url=spark_url, domain=domain)
下面设计结构化返回的内容如下代码。
import os
from pydantic import BaseModel
# 这个定义创建了一个数据模型,其中:
#name 必须是字符串
#date 必须是字符串
#participants 必须是字符串列表
class CalendarEvent(BaseModel):
name: str
date: str
participants: list[str]
Pydantic是一个用于数据验证和设置管理的 Python 库,BaseModel 是 Pydantic 中的一个核心类,用于创建数据模型。通过继承 BaseModel,可以轻松定义具有类型提示的数据结构,这些模型可以自动验证数据,确保数据符合预期的格式和类型。使用 Pydantic 的 BaseModel 可以更容易地处理和验证从 AI 响应中提取的结构化数据,提高代码的可靠性和可读性。在通过如下代码的约束,就可以收获一个按照格式化好的json。
prompt = """
系统:提取事件信息。
用户:Alice和Bob将在周五参加科学展览会。
请以JSON格式提供以下信息:
{
"name": "事件名称",
"date": "事件日期",
"participants": ["参与者列表"]
}
"""
response = client.chat(prompt)
try:
event_dict = json.loads(response)
event = CalendarEvent(**event_dict)
print(f"活动名称: {event.name}")
print(f"日期: {event.date}")
print(f"参与者: {', '.join(event.participants)}")
except json.JSONDecodeError:
print("无法解析响应为JSON格式")
except Exception as e:
print(f"处理响应时出错: {str(e)}")
运行后得到如下结果:
{
"name": "科学展览会",
"date": "周五",
"participants": ["Alice", "Bob"]
}
活动名称: 科学展览会
日期: 周五
参与者: Alice, Bob