环境
SpringBoot2.x
Maven3.x
JDK1.8
具体步骤
1. 了解相关api信息
首先前往openai api官网OpenAI API,登录自己的账号,然后选择API reference,然后就可以查看相关的模型比如“gpt-3.5-turbo”,还有很多的api。
这里是一些请求体信息:
curl https://api.openai.com/v1/chat/completions \ -H "Content-Type: application/json" \ -H "Authorization: Bearer $OPENAI_API_KEY" \ -d '{ "model": "gpt-3.5-turbo", "messages": [ { "role": "system", "content": "You are a helpful assistant." }, { "role": "user", "content": "Hello!" } ] }'
这里是响应信息:
{ "id": "chatcmpl-123", "object": "chat.completion", "created": 1677652288, "model": "gpt-3.5-turbo-0125", "system_fingerprint": "fp_44709d6fcb", "choices": [{ "index": 0, "message": { "role": "assistant", "content": "\n\nHello there, how may I assist you today?", }, "logprobs": null, "finish_reason": "stop" }], "usage": { "prompt_tokens": 9, "completion_tokens": 12, "total_tokens": 21 } }
最重要的是key,如果是新号则去申请,会送一个5美元的key,或者购买一个key。对了,国内的话需要设置代理才能访问。
2. 创建SpringBoot工程
创建工程后添加依赖
<dependency> <groupId>org.springframework.boot</groupId> <artifactId>spring-boot-starter-webflux</artifactId> </dependency>
请求参数类:
public class ChatParams { // 模型 private String model = "gpt-3.5-turbo"; // 消息列表,获取上下文携带返回消息 private List<ChatMessage> messages; // 采样温度,较高的值(如 0.8)将使输出更加随机,而较低的值(如 0.2)将使其更加集中和确定 @DecimalMin("0.0") @DecimalMax("2.0") private Double temperature = 1.0; // 核心采用,其中模型考虑具有top_p概率质量的令牌的结果。因此,0.1 意味着只考虑包含前 10% 概率质量的结果集 @DecimalMin("0.0") @DecimalMax("1.0") private Double top_p = 1.0; // 结果数,如果需要返回多个结果可以设置 n > 1 @Min(1) private Integer n = 1; // 当前用户id,用于处理滥用行为 private Long userId; }
注意,本次是进行流式响应,所以还需一个stream参数:
public class ChatStreamParams extends ChatParams{ private boolean stream = true; }
后端使用WebClient向apenai api发起post请求以获得流式数据
@Transactional @Override public void chatStream2(ChatStreamParams chatParams, String chatId, String userId, Session session) { Flux<String> stringFlux = webClient.post() .uri(apiHost+"/v1/chat/completions") .header("Authorization", "Bearer " + apiKey) .bodyValue(chatParams) .retrieve() .bodyToFlux(String.class); stringFlux .takeWhile(data -> { try { JSONObject json = new JSONObject(data); JSONArray choices = json.getJSONArray("choices"); String finish_reason = choices.getJSONObject(0).getString("finish_reason"); if (finish_reason.equals("stop")) { String content = contentBuilder.toString(); System.out.println("content"+contentBuilder); gptMessageService.save(chatId, Long.valueOf(userId), content, SystemConstants.MESSAGE_TYPE_GPT); contentBuilder.setLength(0); } return !finish_reason.equals("stop"); // 继续处理数据,直到遇到停止条件 } catch (Exception e) { e.printStackTrace(); return false; // 数据格式异常,停止处理数据 } }) .subscribe( data -> { try { JSONObject json = new JSONObject(data); JSONArray choices = json.getJSONArray("choices"); String finish_reason = choices.getJSONObject(0).getString("finish_reason"); if (finish_reason.equals("stop")) { return; } String delta = choices.getJSONObject(0).getString("delta"); System.out.println(data); JSONObject msg = new JSONObject(delta); String content = msg.getString("content"); // 拼接消息,保存到数据库 contentBuilder.append(content); session.getBasicRemote().sendText(content); }catch (Exception e) { e.printStackTrace(); } }, error -> { System.out.println("错误"); error.printStackTrace(); } ); }
参数释义:
chatParams:发送给api的请求参数;
chatId: 会话id;
userId: 用户id;
session: WebSocket的session,前后端流式响应会用到websoket
apiHost: 代理地址;
apiKey: key
注意:在subscribe回调函数里拿到流式数据,在takeWhile回调函数里设置停止条件
好了,通过这些步骤应该能流式响应了,如果遇到啥问题可以留言哦,看到会回。