1
1
package com .unfbx .chatgptsteamoutput .controller ;
2
2
3
3
import cn .hutool .core .util .StrUtil ;
4
- import cn .hutool .json .JSONUtil ;
5
- import com .unfbx .chatgpt .OpenAiStreamClient ;
6
- import com .unfbx .chatgpt .entity .billing .CreditGrantsResponse ;
7
- import com .unfbx .chatgpt .entity .chat .Message ;
8
4
import com .unfbx .chatgpt .exception .BaseException ;
9
5
import com .unfbx .chatgpt .exception .CommonError ;
10
- import com .unfbx .chatgptsteamoutput .config .LocalCache ;
11
- import com .unfbx .chatgptsteamoutput .listener .OpenAISSEEventSourceListener ;
6
+ import com .unfbx .chatgptsteamoutput .controller .request .ChatRequest ;
7
+ import com .unfbx .chatgptsteamoutput .controller .response .ChatResponse ;
8
+ import com .unfbx .chatgptsteamoutput .service .SseService ;
12
9
import lombok .extern .slf4j .Slf4j ;
13
10
import org .springframework .stereotype .Controller ;
14
- import org .springframework .web .bind .annotation .CrossOrigin ;
15
- import org .springframework .web .bind .annotation .GetMapping ;
16
- import org .springframework .web .bind .annotation .RequestHeader ;
17
- import org .springframework .web .bind .annotation .RequestParam ;
11
+ import org .springframework .web .bind .annotation .*;
18
12
import org .springframework .web .servlet .mvc .method .annotation .SseEmitter ;
19
13
20
- import java .io .IOException ;
21
- import java .time .LocalDateTime ;
22
- import java .util .ArrayList ;
23
- import java .util .List ;
14
+ import javax .servlet .http .HttpServletResponse ;
24
15
import java .util .Map ;
25
16
26
17
/**
33
24
@ Slf4j
34
25
public class ChatController {
35
26
36
- private final OpenAiStreamClient openAiStreamClient ;
27
+ private final SseService sseService ;
37
28
38
- public ChatController (OpenAiStreamClient openAiStreamClient ) {
39
- this .openAiStreamClient = openAiStreamClient ;
29
+ public ChatController (SseService sseService ) {
30
+ this .sseService = sseService ;
40
31
}
41
32
42
- @ GetMapping ("/chat" )
33
+ /**
34
+ * 创建sse连接
35
+ *
36
+ * @param headers
37
+ * @return
38
+ */
43
39
@ CrossOrigin
44
- public SseEmitter chat (@ RequestParam ("message" ) String msg , @ RequestHeader Map <String , String > headers ) throws IOException {
45
- //默认30秒超时,设置为0L则永不超时
46
- SseEmitter sseEmitter = new SseEmitter (0l );
47
- String uid = headers .get ("uid" );
48
- if (StrUtil .isBlank (uid )) {
49
- throw new BaseException (CommonError .SYS_ERROR );
50
- }
51
- String messageContext = (String ) LocalCache .CACHE .get (uid );
52
- List <Message > messages = new ArrayList <>();
53
- if (StrUtil .isNotBlank (messageContext )) {
54
- messages = JSONUtil .toList (messageContext , Message .class );
55
- if (messages .size () >= 10 ) {
56
- messages = messages .subList (1 , 10 );
57
- }
58
- Message currentMessage = Message .builder ().content (msg ).role (Message .Role .USER ).build ();
59
- messages .add (currentMessage );
60
- } else {
61
- Message currentMessage = Message .builder ().content (msg ).role (Message .Role .USER ).build ();
62
- messages .add (currentMessage );
63
- }
64
- sseEmitter .send (SseEmitter .event ().id (uid ).name ("连接成功!!!!" ).data (LocalDateTime .now ()).reconnectTime (3000 ));
65
- sseEmitter .onCompletion (() -> {
66
- log .info (LocalDateTime .now () + ", uid#" + uid + ", on completion" );
67
- });
68
- sseEmitter .onTimeout (() -> log .info (LocalDateTime .now () + ", uid#" + uid + ", on timeout#" + sseEmitter .getTimeout ()));
69
- sseEmitter .onError (
70
- throwable -> {
71
- try {
72
- log .info (LocalDateTime .now () + ", uid#" + "765431" + ", on error#" + throwable .toString ());
73
- sseEmitter .send (SseEmitter .event ().id ("765431" ).name ("发生异常!" ).data (throwable .getMessage ()).reconnectTime (3000 ));
74
- } catch (IOException e ) {
75
- e .printStackTrace ();
76
- }
77
- }
78
- );
79
- OpenAISSEEventSourceListener openAIEventSourceListener = new OpenAISSEEventSourceListener (sseEmitter );
80
- openAiStreamClient .streamChatCompletion (messages , openAIEventSourceListener );
81
- LocalCache .CACHE .put (uid , JSONUtil .toJsonStr (messages ), LocalCache .TIMEOUT );
82
- return sseEmitter ;
40
+ @ GetMapping ("/createSse" )
41
+ public SseEmitter createConnect (@ RequestHeader Map <String , String > headers ) {
42
+ String uid = getUid (headers );
43
+ return sseService .createSse (uid );
44
+ }
45
+
46
+ /**
47
+ * 聊天接口
48
+ *
49
+ * @param chatRequest
50
+ * @param headers
51
+ */
52
+ @ CrossOrigin
53
+ @ PostMapping ("/chat" )
54
+ @ ResponseBody
55
+ public ChatResponse sseChat (@ RequestBody ChatRequest chatRequest , @ RequestHeader Map <String , String > headers , HttpServletResponse response ) {
56
+ String uid = getUid (headers );
57
+ return sseService .sseChat (uid , chatRequest );
58
+ }
59
+
60
+ /**
61
+ * 关闭连接
62
+ *
63
+ * @param headers
64
+ */
65
+ @ CrossOrigin
66
+ @ GetMapping ("/closeSse" )
67
+ public void closeConnect (@ RequestHeader Map <String , String > headers ) {
68
+ String uid = getUid (headers );
69
+ sseService .closeSse (uid );
83
70
}
84
71
85
72
@ GetMapping ("" )
@@ -92,4 +79,19 @@ public String websocket() {
92
79
return "websocket.html" ;
93
80
}
94
81
82
+ /**
83
+ * 获取uid
84
+ *
85
+ * @param headers
86
+ * @return
87
+ */
88
+ private String getUid (Map <String , String > headers ) {
89
+ String uid = headers .get ("uid" );
90
+ if (StrUtil .isBlank (uid )) {
91
+ throw new BaseException (CommonError .SYS_ERROR );
92
+ }
93
+ return uid ;
94
+ }
95
+
96
+
95
97
}
0 commit comments