Commit 02fe5f2

Authored Dec 18, 2023
Support ChatGLM3, version upgrade (#44)
* Update download_model.py
* Create start_offline_cmd.bat
* Update download_model.py
* Update download_model.py
* Update download_model.py
* Update requirements.txt
* Create test_models.py
* Update app.py
* Update test_models.py
* 1
* Create chatglm3_predictor.py
* Update chatglm3_predictor.py
* Update requirements.txt
* Update app.py
* Update chatglm3_predictor.py
* Update base.py
* Update chatglm3_predictor.py
* Update chatglm3_predictor.py
* Update download_model.py
* Update chatglm3_predictor.py
* Update base.py
* Update base.py
* Update base.py
* Update app.py
* Update chatglm3_predictor.py
* Update app.py
1 parent eefba1b · commit 02fe5f2

11 files changed: +2175 −17 lines
 

app.py

+14 −8
@@ -8,9 +8,12 @@
 print('Done'.center(64, '-'))
 
 # Load the model
-model_name = 'THUDM/chatglm2-6b'
+model_name = 'THUDM/chatglm3-6b'
 
-if 'chatglm2' in model_name.lower():
+if 'chatglm3' in model_name.lower():
+    from predictors.chatglm3_predictor import ChatGLM3
+    predictor = ChatGLM3(model_name)
+elif 'chatglm2' in model_name.lower():
     from predictors.chatglm2_predictor import ChatGLM2
     predictor = ChatGLM2(model_name)
 elif 'chatglm' in model_name.lower():
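The new branch has to run before the bare 'chatglm' check, since 'chatglm3' and 'chatglm2' both contain 'chatglm' as a substring. A minimal sketch of the same dispatch, factored into a helper so the ordering is explicit (load_predictor is an illustrative name, not part of the commit; the imports mirror the repo's predictors package):

    def load_predictor(model_name: str):
        # Most specific substring first: 'chatglm3' also matches 'chatglm'.
        name = model_name.lower()
        if 'chatglm3' in name:
            from predictors.chatglm3_predictor import ChatGLM3
            return ChatGLM3(model_name)
        if 'chatglm2' in name:
            from predictors.chatglm2_predictor import ChatGLM2
            return ChatGLM2(model_name)
        raise ValueError(f'No predictor registered for {model_name!r}')

    predictor = load_predictor('THUDM/chatglm3-6b')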
@@ -31,7 +34,10 @@
 
 
 def revise(history, latest_message):
-    history[-1] = (history[-1][0], latest_message)
+    if isinstance(history[-1], tuple):
+        history[-1] = (history[-1][0], latest_message)
+    elif isinstance(history[-1], dict):
+        history[-1]['content'] = latest_message
     return history, ''
 
 
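The revise() change is needed because the two model families store chat history differently: ChatGLM and ChatGLM2 keep (query, response) tuples, while ChatGLM3 works with role/content dicts, which is what the isinstance checks distinguish. A self-contained sketch of both shapes (the sample messages are made up for illustration):

    # Tuple-style history (ChatGLM / ChatGLM2): one (query, response) pair per turn.
    history_v2 = [('Hi', 'Hello!')]
    # Dict-style history (ChatGLM3): one role/content message per entry.
    history_v3 = [{'role': 'user', 'content': 'Hi'},
                  {'role': 'assistant', 'content': 'Hello!'}]

    def revise(history, latest_message):
        if isinstance(history[-1], tuple):
            history[-1] = (history[-1][0], latest_message)
        elif isinstance(history[-1], dict):
            history[-1]['content'] = latest_message
        return history, ''

    assert revise(history_v2, 'Hey!')[0][-1] == ('Hi', 'Hey!')
    assert revise(history_v3, 'Hey!')[0][-1]['content'] == 'Hey!'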
@@ -76,21 +82,21 @@ def regenerate(last_state, max_length, top_p, temperature, allow_generate):
     """)
     with gr.Row():
         with gr.Column(scale=4):
-            chatbot = gr.Chatbot(elem_id="chat-box", show_label=False).style(height=850)
+            chatbot = gr.Chatbot(elem_id="chat-box", show_label=False, height=850)
         with gr.Column(scale=1):
             with gr.Row():
                 max_length = gr.Slider(32, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
                 top_p = gr.Slider(0.01, 1, value=0.7, step=0.01, label="Top P", interactive=True)
                 temperature = gr.Slider(0.01, 5, value=0.95, step=0.01, label="Temperature", interactive=True)
             with gr.Row():
-                query = gr.Textbox(show_label=False, placeholder="Prompts", lines=4).style(container=False)
+                query = gr.Textbox(show_label=False, placeholder="Prompts", lines=4)
                 generate_button = gr.Button("生成")
             with gr.Row():
                 continue_message = gr.Textbox(
-                    show_label=False, placeholder="Continue message", lines=2).style(container=False)
+                    show_label=False, placeholder="Continue message", lines=2)
                 continue_btn = gr.Button("续写")
                 revise_message = gr.Textbox(
-                    show_label=False, placeholder="Revise message", lines=2).style(container=False)
+                    show_label=False, placeholder="Revise message", lines=2)
                 revise_btn = gr.Button("修订")
                 revoke_btn = gr.Button("撤回")
                 regenerate_btn = gr.Button("重新生成")
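Every .style() removal in this hunk tracks the same upstream change: Gradio 4 dropped Component.style(), and arguments such as the Chatbot height moved into the component constructors. (The button labels 生成, 续写, 修订, 撤回, and 重新生成 read Generate, Continue, Revise, Revoke, and Regenerate.) A minimal sketch of the new calling convention; note the commit drops container=False entirely, although Gradio 4 still accepts container as a constructor keyword:

    import gradio as gr

    with gr.Blocks() as demo:
        # Gradio 4.x: styling is passed at construction time; .style() no longer exists.
        chatbot = gr.Chatbot(elem_id="chat-box", show_label=False, height=850)
        query = gr.Textbox(show_label=False, placeholder="Prompts", lines=4,
                           container=False)  # container moved into the constructor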
@@ -114,5 +120,5 @@ def regenerate(last_state, max_length, top_p, temperature, allow_generate):
                        outputs=[chatbot, query, continue_message])
     interrupt_btn.click(interrupt, inputs=[allow_generate])
 
-    demo.queue(concurrency_count=4).launch(server_name='0.0.0.0', server_port=7860, share=False, inbrowser=False)
+    demo.queue().launch(server_name='0.0.0.0', server_port=7860, share=False, inbrowser=False)
     demo.close()
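The queue() change is part of the same Gradio 4 migration: the concurrency_count parameter was removed, so the commit falls back to the library default. If the old cap of four concurrent workers is wanted, Gradio 4 exposes it as default_concurrency_limit; a hedged sketch of the closest equivalent:

    # Gradio 4.x replacement for the removed queue(concurrency_count=4):
    # default_concurrency_limit caps how many events run concurrently.
    demo.queue(default_concurrency_limit=4).launch(
        server_name='0.0.0.0', server_port=7860, share=False, inbrowser=False)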

chatglm3/configuration_chatglm.py

+61
@@ -0,0 +1,61 @@
+from transformers import PretrainedConfig
+
+
+class ChatGLMConfig(PretrainedConfig):
+    model_type = "chatglm"
+    def __init__(
+        self,
+        num_layers=28,
+        padded_vocab_size=65024,
+        hidden_size=4096,
+        ffn_hidden_size=13696,
+        kv_channels=128,
+        num_attention_heads=32,
+        seq_length=2048,
+        hidden_dropout=0.0,
+        classifier_dropout=None,
+        attention_dropout=0.0,
+        layernorm_epsilon=1e-5,
+        rmsnorm=True,
+        apply_residual_connection_post_layernorm=False,
+        post_layer_norm=True,
+        add_bias_linear=False,
+        add_qkv_bias=False,
+        bias_dropout_fusion=True,
+        multi_query_attention=False,
+        multi_query_group_num=1,
+        apply_query_key_layer_scaling=True,
+        attention_softmax_in_fp32=True,
+        fp32_residual_connection=False,
+        quantization_bit=0,
+        pre_seq_len=None,
+        prefix_projection=False,
+        **kwargs
+    ):
+        self.num_layers = num_layers
+        self.vocab_size = padded_vocab_size
+        self.padded_vocab_size = padded_vocab_size
+        self.hidden_size = hidden_size
+        self.ffn_hidden_size = ffn_hidden_size
+        self.kv_channels = kv_channels
+        self.num_attention_heads = num_attention_heads
+        self.seq_length = seq_length
+        self.hidden_dropout = hidden_dropout
+        self.classifier_dropout = classifier_dropout
+        self.attention_dropout = attention_dropout
+        self.layernorm_epsilon = layernorm_epsilon
+        self.rmsnorm = rmsnorm
+        self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
+        self.post_layer_norm = post_layer_norm
+        self.add_bias_linear = add_bias_linear
+        self.add_qkv_bias = add_qkv_bias
+        self.bias_dropout_fusion = bias_dropout_fusion
+        self.multi_query_attention = multi_query_attention
+        self.multi_query_group_num = multi_query_group_num
+        self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
+        self.attention_softmax_in_fp32 = attention_softmax_in_fp32
+        self.fp32_residual_connection = fp32_residual_connection
+        self.quantization_bit = quantization_bit
+        self.pre_seq_len = pre_seq_len
+        self.prefix_projection = prefix_projection
+        super().__init__(**kwargs)
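The new file mirrors the upstream THUDM ChatGLM configuration class; its defaults describe the 6B checkpoint, and because ChatGLMConfig subclasses PretrainedConfig it inherits JSON serialization and the save_pretrained/from_pretrained round trip. A minimal usage sketch (the output directory and the seq_length override are illustrative):

    from chatglm3.configuration_chatglm import ChatGLMConfig

    # Defaults describe chatglm3-6b; any field can be overridden per instance.
    config = ChatGLMConfig(seq_length=8192)
    config.save_pretrained('chatglm3-config')  # writes config.json with model_type "chatglm"
    reloaded = ChatGLMConfig.from_pretrained('chatglm3-config')
    assert reloaded.seq_length == 8192
    # vocab_size and padded_vocab_size are both set from the same argument in __init__.
    assert reloaded.vocab_size == reloaded.padded_vocab_size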
