Skip to content

Commit 5098f85

Browse files
authored
Regenerate (#26)
* Update app.py * Update base.py * Update app.py * Update app.py * Update app.py * Update app.py * Update app.py * Update app.py * Update app.py * Update app.py
1 parent e08de4e commit 5098f85

File tree

2 files changed

+28
-6
lines changed

2 files changed

+28
-6
lines changed

app.py

+23-5
Original file line numberDiff line numberDiff line change
@@ -33,16 +33,29 @@ def revise(history, latest_message):
3333
return history, ''
3434

3535

36-
def revoke(history):
36+
def revoke(history, last_state):
3737
if len(history) >= 1:
3838
history.pop()
39+
last_state[0] = history
40+
last_state[1] = ''
41+
last_state[2] = ''
3942
return history
4043

4144

4245
def interrupt(allow_generate):
4346
allow_generate[0] = False
4447

4548

49+
def regenerate(last_state, max_length, top_p, temperature, allow_generate):
50+
history, query, continue_message = last_state
51+
if len(query) == 0:
52+
print("Please input a query first.")
53+
return
54+
for x in predictor.predict_continue(query, continue_message, max_length, top_p,
55+
temperature, allow_generate, history, last_state):
56+
yield x
57+
58+
4659
# 搭建 UI 界面
4760
with gr.Blocks(css=""".message {
4861
width: inherit !important;
@@ -61,7 +74,7 @@ def interrupt(allow_generate):
6174
""")
6275
with gr.Row():
6376
with gr.Column(scale=4):
64-
chatbot = gr.Chatbot(elem_id="chat-box", show_label=False).style(height=800)
77+
chatbot = gr.Chatbot(elem_id="chat-box", show_label=False).style(height=850)
6578
with gr.Column(scale=1):
6679
with gr.Row():
6780
max_length = gr.Slider(32, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
@@ -78,21 +91,26 @@ def interrupt(allow_generate):
7891
show_label=False, placeholder="Revise message", lines=2).style(container=False)
7992
revise_btn = gr.Button("修订")
8093
revoke_btn = gr.Button("撤回")
94+
regenerate_btn = gr.Button("重新生成")
8195
interrupt_btn = gr.Button("终止生成")
8296

8397
history = gr.State([])
8498
allow_generate = gr.State([True])
8599
blank_input = gr.State("")
100+
last_state = gr.State([[], '', '']) # history, query, continue_message
86101
generate_button.click(
87102
predictor.predict_continue,
88-
inputs=[query, blank_input, max_length, top_p, temperature, allow_generate, history],
103+
inputs=[query, blank_input, max_length, top_p, temperature, allow_generate, history, last_state],
89104
outputs=[chatbot, query])
90105
revise_btn.click(revise, inputs=[history, revise_message], outputs=[chatbot, revise_message])
91-
revoke_btn.click(revoke, inputs=[history], outputs=[chatbot])
106+
revoke_btn.click(revoke, inputs=[history, last_state], outputs=[chatbot])
92107
continue_btn.click(
93108
predictor.predict_continue,
94-
inputs=[query, continue_message, max_length, top_p, temperature, allow_generate, history],
109+
inputs=[query, continue_message, max_length, top_p, temperature, allow_generate, history, last_state],
95110
outputs=[chatbot, query, continue_message])
111+
regenerate_btn.click(regenerate, inputs=[last_state, max_length, top_p, temperature, allow_generate],
112+
outputs=[chatbot, query, continue_message])
96113
interrupt_btn.click(interrupt, inputs=[allow_generate])
114+
97115
demo.queue(concurrency_count=4).launch(server_name='0.0.0.0', server_port=7860, share=False, inbrowser=False)
98116
demo.close()

predictors/base.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import copy
12
from abc import ABC, abstractmethod
23

34

@@ -27,8 +28,11 @@ def stream_chat_continue(self, *args, **kwargs):
2728
raise NotImplementedError
2829

2930
def predict_continue(self, query, latest_message, max_length, top_p,
30-
temperature, allow_generate, history, *args,
31+
temperature, allow_generate, history, last_state, *args,
3132
**kwargs):
33+
last_state[0] = copy.deepcopy(history)
34+
last_state[1] = query
35+
last_state[2] = latest_message
3236
if history is None:
3337
history = []
3438
allow_generate[0] = True

0 commit comments

Comments
 (0)