@@ -33,16 +33,29 @@ def revise(history, latest_message):
33
33
return history , ''
34
34
35
35
36
- def revoke (history ):
36
+ def revoke (history , last_state ):
37
37
if len (history ) >= 1 :
38
38
history .pop ()
39
+ last_state [0 ] = history
40
+ last_state [1 ] = ''
41
+ last_state [2 ] = ''
39
42
return history
40
43
41
44
42
45
def interrupt (allow_generate ):
43
46
allow_generate [0 ] = False
44
47
45
48
49
+ def regenerate (last_state , max_length , top_p , temperature , allow_generate ):
50
+ history , query , continue_message = last_state
51
+ if len (query ) == 0 :
52
+ print ("Please input a query first." )
53
+ return
54
+ for x in predictor .predict_continue (query , continue_message , max_length , top_p ,
55
+ temperature , allow_generate , history , last_state ):
56
+ yield x
57
+
58
+
46
59
# 搭建 UI 界面
47
60
with gr .Blocks (css = """.message {
48
61
width: inherit !important;
@@ -61,7 +74,7 @@ def interrupt(allow_generate):
61
74
""" )
62
75
with gr .Row ():
63
76
with gr .Column (scale = 4 ):
64
- chatbot = gr .Chatbot (elem_id = "chat-box" , show_label = False ).style (height = 800 )
77
+ chatbot = gr .Chatbot (elem_id = "chat-box" , show_label = False ).style (height = 850 )
65
78
with gr .Column (scale = 1 ):
66
79
with gr .Row ():
67
80
max_length = gr .Slider (32 , 4096 , value = 2048 , step = 1.0 , label = "Maximum length" , interactive = True )
@@ -78,21 +91,26 @@ def interrupt(allow_generate):
78
91
show_label = False , placeholder = "Revise message" , lines = 2 ).style (container = False )
79
92
revise_btn = gr .Button ("修订" )
80
93
revoke_btn = gr .Button ("撤回" )
94
+ regenerate_btn = gr .Button ("重新生成" )
81
95
interrupt_btn = gr .Button ("终止生成" )
82
96
83
97
history = gr .State ([])
84
98
allow_generate = gr .State ([True ])
85
99
blank_input = gr .State ("" )
100
+ last_state = gr .State ([[], '' , '' ]) # history, query, continue_message
86
101
generate_button .click (
87
102
predictor .predict_continue ,
88
- inputs = [query , blank_input , max_length , top_p , temperature , allow_generate , history ],
103
+ inputs = [query , blank_input , max_length , top_p , temperature , allow_generate , history , last_state ],
89
104
outputs = [chatbot , query ])
90
105
revise_btn .click (revise , inputs = [history , revise_message ], outputs = [chatbot , revise_message ])
91
- revoke_btn .click (revoke , inputs = [history ], outputs = [chatbot ])
106
+ revoke_btn .click (revoke , inputs = [history , last_state ], outputs = [chatbot ])
92
107
continue_btn .click (
93
108
predictor .predict_continue ,
94
- inputs = [query , continue_message , max_length , top_p , temperature , allow_generate , history ],
109
+ inputs = [query , continue_message , max_length , top_p , temperature , allow_generate , history , last_state ],
95
110
outputs = [chatbot , query , continue_message ])
111
+ regenerate_btn .click (regenerate , inputs = [last_state , max_length , top_p , temperature , allow_generate ],
112
+ outputs = [chatbot , query , continue_message ])
96
113
interrupt_btn .click (interrupt , inputs = [allow_generate ])
114
+
97
115
demo .queue (concurrency_count = 4 ).launch (server_name = '0.0.0.0' , server_port = 7860 , share = False , inbrowser = False )
98
116
demo .close ()
0 commit comments