Skip to content

Commit f10ec4a

Browse files
author
guorutao
committed
Merge remote-tracking branch 'origin/main' into main
# Conflicts: # README.md # pom.xml
2 parents 93ae903 + e6f803e commit f10ec4a

File tree

10 files changed

+337
-104
lines changed

10 files changed

+337
-104
lines changed

README.md

+4-3
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
# 简介
2-
Open AI ChatGPT流式输出。Open AI Stream output. ChatGPT Stream output.
2+
Open AI ChatGPT流式输出。Open AI Stream output. ChatGPT Stream output、
3+
支持Tokens计算。
34

45
**此项目只是对[chatgpt-java](https://github.com/Grt1228/chatgpt-java) SDK的一个简单示例项目,实现流式输出,仅做参考仅做参考仅做参考。大家最好还是自己基于SDK动手实现**
56
---
6-
### 目前本项目支持两种流式输出,基于[ChatGPT-Java SDK](https://github.com/Grt1228/chatgpt-java)
7+
### 目前本项目支持两种流式输出,支持Tokens计算,基于[ChatGPT-Java SDK](https://github.com/Grt1228/chatgpt-java)
78

89
流式输出实现方式 | 小程序 | 安卓 | ios | H5
910
---|---|---|---|---
@@ -61,7 +62,7 @@ sse实现:http://localhost:8000/
6162
websocket实现:http://localhost:8000/websocket
6263
```
6364
能打开此页面表示运行成功
64-
<img width="1080" alt="1" src="https://user-images.githubusercontent.com/27008803/224496424-b75465a0-32fb-491a-934c-c9c524cf5be7.png">
65+
<img width="954" alt="8ccfe107fc10deffdf7fac42b95547b" src="https://user-images.githubusercontent.com/27008803/230941561-79e344ed-b751-40c7-9a59-cbe5216923b1.png">
6566

6667

6768
代码其实很简单,小伙伴们可以下载代码来看下。

src/main/java/com/unfbx/chatgptsteamoutput/ChatgptSteamOutputApplication.java

+3-3
Original file line numberDiff line numberDiff line change
@@ -37,17 +37,17 @@ public static void main(String[] args) {
3737
@Bean
3838
public OpenAiStreamClient openAiStreamClient() {
3939
//本地开发需要配置代理地址
40-
Proxy proxy = new Proxy(Proxy.Type.HTTP, new InetSocketAddress("127.0.0.1", 7890));
40+
// Proxy proxy = new Proxy(Proxy.Type.HTTP, new InetSocketAddress("127.0.0.1", 7890));
4141
HttpLoggingInterceptor httpLoggingInterceptor = new HttpLoggingInterceptor(new OpenAILogger());
4242
//!!!!!!测试或者发布到服务器千万不要配置Level == BODY!!!!
4343
//!!!!!!测试或者发布到服务器千万不要配置Level == BODY!!!!
4444
httpLoggingInterceptor.setLevel(HttpLoggingInterceptor.Level.HEADERS);
4545
OkHttpClient okHttpClient = new OkHttpClient
4646
.Builder()
47-
.proxy(proxy)
47+
// .proxy(proxy)
4848
.addInterceptor(httpLoggingInterceptor)
4949
.connectTimeout(30, TimeUnit.SECONDS)
50-
.writeTimeout(60, TimeUnit.SECONDS)
50+
.writeTimeout(600, TimeUnit.SECONDS)
5151
.readTimeout(600, TimeUnit.SECONDS)
5252
.build();
5353
return OpenAiStreamClient
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,17 @@
11
package com.unfbx.chatgptsteamoutput.controller;
22

33
import cn.hutool.core.util.StrUtil;
4-
import cn.hutool.json.JSONUtil;
5-
import com.unfbx.chatgpt.OpenAiStreamClient;
6-
import com.unfbx.chatgpt.entity.billing.CreditGrantsResponse;
7-
import com.unfbx.chatgpt.entity.chat.Message;
84
import com.unfbx.chatgpt.exception.BaseException;
95
import com.unfbx.chatgpt.exception.CommonError;
10-
import com.unfbx.chatgptsteamoutput.config.LocalCache;
11-
import com.unfbx.chatgptsteamoutput.listener.OpenAISSEEventSourceListener;
6+
import com.unfbx.chatgptsteamoutput.controller.request.ChatRequest;
7+
import com.unfbx.chatgptsteamoutput.controller.response.ChatResponse;
8+
import com.unfbx.chatgptsteamoutput.service.SseService;
129
import lombok.extern.slf4j.Slf4j;
1310
import org.springframework.stereotype.Controller;
14-
import org.springframework.web.bind.annotation.CrossOrigin;
15-
import org.springframework.web.bind.annotation.GetMapping;
16-
import org.springframework.web.bind.annotation.RequestHeader;
17-
import org.springframework.web.bind.annotation.RequestParam;
11+
import org.springframework.web.bind.annotation.*;
1812
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
1913

20-
import java.io.IOException;
21-
import java.time.LocalDateTime;
22-
import java.util.ArrayList;
23-
import java.util.List;
14+
import javax.servlet.http.HttpServletResponse;
2415
import java.util.Map;
2516

2617
/**
@@ -33,53 +24,49 @@
3324
@Slf4j
3425
public class ChatController {
3526

36-
private final OpenAiStreamClient openAiStreamClient;
27+
private final SseService sseService;
3728

38-
public ChatController(OpenAiStreamClient openAiStreamClient) {
39-
this.openAiStreamClient = openAiStreamClient;
29+
public ChatController(SseService sseService) {
30+
this.sseService = sseService;
4031
}
4132

42-
@GetMapping("/chat")
33+
/**
34+
* 创建sse连接
35+
*
36+
* @param headers
37+
* @return
38+
*/
4339
@CrossOrigin
44-
public SseEmitter chat(@RequestParam("message") String msg, @RequestHeader Map<String, String> headers) throws IOException {
45-
//默认30秒超时,设置为0L则永不超时
46-
SseEmitter sseEmitter = new SseEmitter(0l);
47-
String uid = headers.get("uid");
48-
if (StrUtil.isBlank(uid)) {
49-
throw new BaseException(CommonError.SYS_ERROR);
50-
}
51-
String messageContext = (String) LocalCache.CACHE.get(uid);
52-
List<Message> messages = new ArrayList<>();
53-
if (StrUtil.isNotBlank(messageContext)) {
54-
messages = JSONUtil.toList(messageContext, Message.class);
55-
if (messages.size() >= 10) {
56-
messages = messages.subList(1, 10);
57-
}
58-
Message currentMessage = Message.builder().content(msg).role(Message.Role.USER).build();
59-
messages.add(currentMessage);
60-
} else {
61-
Message currentMessage = Message.builder().content(msg).role(Message.Role.USER).build();
62-
messages.add(currentMessage);
63-
}
64-
sseEmitter.send(SseEmitter.event().id(uid).name("连接成功!!!!").data(LocalDateTime.now()).reconnectTime(3000));
65-
sseEmitter.onCompletion(() -> {
66-
log.info(LocalDateTime.now() + ", uid#" + uid + ", on completion");
67-
});
68-
sseEmitter.onTimeout(() -> log.info(LocalDateTime.now() + ", uid#" + uid + ", on timeout#" + sseEmitter.getTimeout()));
69-
sseEmitter.onError(
70-
throwable -> {
71-
try {
72-
log.info(LocalDateTime.now() + ", uid#" + "765431" + ", on error#" + throwable.toString());
73-
sseEmitter.send(SseEmitter.event().id("765431").name("发生异常!").data(throwable.getMessage()).reconnectTime(3000));
74-
} catch (IOException e) {
75-
e.printStackTrace();
76-
}
77-
}
78-
);
79-
OpenAISSEEventSourceListener openAIEventSourceListener = new OpenAISSEEventSourceListener(sseEmitter);
80-
openAiStreamClient.streamChatCompletion(messages, openAIEventSourceListener);
81-
LocalCache.CACHE.put(uid, JSONUtil.toJsonStr(messages), LocalCache.TIMEOUT);
82-
return sseEmitter;
40+
@GetMapping("/createSse")
41+
public SseEmitter createConnect(@RequestHeader Map<String, String> headers) {
42+
String uid = getUid(headers);
43+
return sseService.createSse(uid);
44+
}
45+
46+
/**
47+
* 聊天接口
48+
*
49+
* @param chatRequest
50+
* @param headers
51+
*/
52+
@CrossOrigin
53+
@PostMapping("/chat")
54+
@ResponseBody
55+
public ChatResponse sseChat(@RequestBody ChatRequest chatRequest, @RequestHeader Map<String, String> headers, HttpServletResponse response) {
56+
String uid = getUid(headers);
57+
return sseService.sseChat(uid, chatRequest);
58+
}
59+
60+
/**
61+
* 关闭连接
62+
*
63+
* @param headers
64+
*/
65+
@CrossOrigin
66+
@GetMapping("/closeSse")
67+
public void closeConnect(@RequestHeader Map<String, String> headers) {
68+
String uid = getUid(headers);
69+
sseService.closeSse(uid);
8370
}
8471

8572
@GetMapping("")
@@ -92,4 +79,19 @@ public String websocket() {
9279
return "websocket.html";
9380
}
9481

82+
/**
83+
* 获取uid
84+
*
85+
* @param headers
86+
* @return
87+
*/
88+
private String getUid(Map<String, String> headers) {
89+
String uid = headers.get("uid");
90+
if (StrUtil.isBlank(uid)) {
91+
throw new BaseException(CommonError.SYS_ERROR);
92+
}
93+
return uid;
94+
}
95+
96+
9597
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
package com.unfbx.chatgptsteamoutput.controller.request;
2+
3+
import lombok.Data;
4+
5+
/**
6+
* 描述:
7+
*
8+
* @author https:www.unfbx.com
9+
* @sine 2023-04-08
10+
*/
11+
@Data
12+
public class ChatRequest {
13+
/**
14+
* 客户端发送的问题参数
15+
*/
16+
private String msg;
17+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
package com.unfbx.chatgptsteamoutput.controller.response;
2+
3+
import com.fasterxml.jackson.annotation.JsonProperty;
4+
import lombok.Data;
5+
6+
/**
7+
* 描述:
8+
*
9+
* @author https:www.unfbx.com
10+
* @sine 2023-04-08
11+
*/
12+
@Data
13+
public class ChatResponse {
14+
/**
15+
* 问题消耗tokens
16+
*/
17+
@JsonProperty("question_tokens")
18+
private long questionTokens = 0;
19+
}

src/main/java/com/unfbx/chatgptsteamoutput/entity/Chat.java

+6-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,12 @@
44
import lombok.Data;
55

66
import java.util.List;
7-
7+
/**
8+
* 描述:
9+
*
10+
* @author https:www.unfbx.com
11+
* @date 2023-04-10
12+
*/
813
@Data
914
public class Chat {
1015

src/main/java/com/unfbx/chatgptsteamoutput/listener/OpenAISSEEventSourceListener.java

+20-5
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import com.fasterxml.jackson.databind.ObjectMapper;
44
import com.unfbx.chatgpt.entity.chat.ChatCompletionResponse;
5+
import com.unfbx.chatgptsteamoutput.config.LocalCache;
56
import lombok.SneakyThrows;
67
import lombok.extern.slf4j.Slf4j;
78
import okhttp3.Response;
@@ -47,6 +48,10 @@ public void onEvent(EventSource eventSource, String id, String type, String data
4748
tokens += 1;
4849
if (data.equals("[DONE]")) {
4950
log.info("OpenAI返回数据结束了");
51+
sseEmitter.send(SseEmitter.event()
52+
.id("[TOKENS]")
53+
.data("<br/><br/>tokens:" + tokens())
54+
.reconnectTime(3000));
5055
sseEmitter.send(SseEmitter.event()
5156
.id("[DONE]")
5257
.data("[DONE]")
@@ -57,10 +62,16 @@ public void onEvent(EventSource eventSource, String id, String type, String data
5762
}
5863
ObjectMapper mapper = new ObjectMapper();
5964
ChatCompletionResponse completionResponse = mapper.readValue(data, ChatCompletionResponse.class); // 读取Json
60-
sseEmitter.send(SseEmitter.event()
61-
.id(completionResponse.getId())
62-
.data(completionResponse.getChoices().get(0).getDelta())
63-
.reconnectTime(3000));
65+
try {
66+
sseEmitter.send(SseEmitter.event()
67+
.id(completionResponse.getId())
68+
.data(completionResponse.getChoices().get(0).getDelta())
69+
.reconnectTime(3000));
70+
} catch (Exception e) {
71+
log.error("sse信息推送失败!");
72+
eventSource.cancel();
73+
e.printStackTrace();
74+
}
6475
}
6576

6677

@@ -74,7 +85,7 @@ public void onClosed(EventSource eventSource) {
7485
@SneakyThrows
7586
@Override
7687
public void onFailure(EventSource eventSource, Throwable t, Response response) {
77-
if(Objects.isNull(response)){
88+
if (Objects.isNull(response)) {
7889
return;
7990
}
8091
ResponseBody body = response.body();
@@ -86,6 +97,10 @@ public void onFailure(EventSource eventSource, Throwable t, Response response) {
8697
eventSource.cancel();
8798
}
8899

100+
/**
101+
* tokens
102+
* @return
103+
*/
89104
public long tokens() {
90105
return tokens;
91106
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
package com.unfbx.chatgptsteamoutput.service;
2+
3+
import com.unfbx.chatgptsteamoutput.controller.request.ChatRequest;
4+
import com.unfbx.chatgptsteamoutput.controller.response.ChatResponse;
5+
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
6+
7+
/**
8+
* 描述:
9+
*
10+
* @author https:www.unfbx.com
11+
* @date 2023-04-08
12+
*/
13+
public interface SseService {
14+
/**
15+
* 创建SSE
16+
* @param uid
17+
* @return
18+
*/
19+
SseEmitter createSse(String uid);
20+
21+
/**
22+
* 关闭SSE
23+
* @param uid
24+
*/
25+
void closeSse(String uid);
26+
27+
/**
28+
* 客户端发送消息到服务端
29+
* @param uid
30+
* @param chatRequest
31+
*/
32+
ChatResponse sseChat(String uid, ChatRequest chatRequest);
33+
}

0 commit comments

Comments
 (0)