21
21
from memgpt .llm_api .openai import openai_get_model_list
22
22
from memgpt .llm_api .azure_openai import azure_openai_get_model_list
23
23
from memgpt .llm_api .google_ai import google_ai_get_model_list , google_ai_get_model_context_window
24
+ from memgpt .llm_api .anthropic import anthropic_get_model_list , antropic_get_model_context_window
25
+ from memgpt .llm_api .llm_api_tools import LLM_API_PROVIDER_OPTIONS
24
26
from memgpt .local_llm .constants import DEFAULT_ENDPOINTS , DEFAULT_OLLAMA_MODEL , DEFAULT_WRAPPER_NAME
25
27
from memgpt .local_llm .utils import get_available_wrappers
26
28
from memgpt .server .utils import shorten_key_middle
@@ -64,14 +66,14 @@ def configure_llm_endpoint(config: MemGPTConfig, credentials: MemGPTCredentials)
64
66
# get default
65
67
default_model_endpoint_type = config .default_llm_config .model_endpoint_type
66
68
if config .default_llm_config .model_endpoint_type is not None and config .default_llm_config .model_endpoint_type not in [
67
- "openai" ,
68
- "azure" ,
69
- "google_ai" ,
69
+ provider for provider in LLM_API_PROVIDER_OPTIONS if provider != "local"
70
70
]: # local model
71
71
default_model_endpoint_type = "local"
72
72
73
73
provider = questionary .select (
74
- "Select LLM inference provider:" , choices = ["openai" , "azure" , "google_ai" , "local" ], default = default_model_endpoint_type
74
+ "Select LLM inference provider:" ,
75
+ choices = LLM_API_PROVIDER_OPTIONS ,
76
+ default = default_model_endpoint_type ,
75
77
).ask ()
76
78
if provider is None :
77
79
raise KeyboardInterrupt
@@ -184,6 +186,46 @@ def configure_llm_endpoint(config: MemGPTConfig, credentials: MemGPTCredentials)
184
186
185
187
model_endpoint_type = "google_ai"
186
188
189
+ elif provider == "anthropic" :
190
+ # check for key
191
+ if credentials .anthropic_key is None :
192
+ # allow key to get pulled from env vars
193
+ anthropic_api_key = os .getenv ("ANTHROPIC_API_KEY" , None )
194
+ # if we still can't find it, ask for it as input
195
+ if anthropic_api_key is None :
196
+ while anthropic_api_key is None or len (anthropic_api_key ) == 0 :
197
+ # Ask for API key as input
198
+ anthropic_api_key = questionary .password (
199
+ "Enter your Anthropic API key (starts with 'sk-', see https://console.anthropic.com/settings/keys):"
200
+ ).ask ()
201
+ if anthropic_api_key is None :
202
+ raise KeyboardInterrupt
203
+ credentials .anthropic_key = anthropic_api_key
204
+ credentials .save ()
205
+ else :
206
+ # Give the user an opportunity to overwrite the key
207
+ anthropic_api_key = None
208
+ default_input = (
209
+ shorten_key_middle (credentials .anthropic_key ) if credentials .anthropic_key .startswith ("sk-" ) else credentials .anthropic_key
210
+ )
211
+ anthropic_api_key = questionary .password (
212
+ "Enter your Anthropic API key (starts with 'sk-', see https://console.anthropic.com/settings/keys):" ,
213
+ default = default_input ,
214
+ ).ask ()
215
+ if anthropic_api_key is None :
216
+ raise KeyboardInterrupt
217
+ # If the user modified it, use the new one
218
+ if anthropic_api_key != default_input :
219
+ credentials .anthropic_key = anthropic_api_key
220
+ credentials .save ()
221
+
222
+ model_endpoint_type = "anthropic"
223
+ model_endpoint = "https://api.anthropic.com/v1"
224
+ model_endpoint = questionary .text ("Override default endpoint:" , default = model_endpoint ).ask ()
225
+ if model_endpoint is None :
226
+ raise KeyboardInterrupt
227
+ provider = "anthropic"
228
+
187
229
else : # local models
188
230
# backend_options_old = ["webui", "webui-legacy", "llamacpp", "koboldcpp", "ollama", "lmstudio", "lmstudio-legacy", "vllm", "openai"]
189
231
backend_options = builtins .list (DEFAULT_ENDPOINTS .keys ())
@@ -291,6 +333,12 @@ def get_model_options(
291
333
model_options = [mo for mo in model_options if str (mo ).startswith ("gemini" ) and "-pro" in str (mo )]
292
334
# model_options = ["gemini-pro"]
293
335
336
+ elif model_endpoint_type == "anthropic" :
337
+ if credentials .anthropic_key is None :
338
+ raise ValueError ("Missing Anthropic API key" )
339
+ fetched_model_options = anthropic_get_model_list (url = model_endpoint , api_key = credentials .anthropic_key )
340
+ model_options = [obj ["name" ] for obj in fetched_model_options ]
341
+
294
342
else :
295
343
# Attempt to do OpenAI endpoint style model fetching
296
344
# TODO support local auth with api-key header
@@ -382,6 +430,26 @@ def configure_model(config: MemGPTConfig, credentials: MemGPTCredentials, model_
382
430
if model is None :
383
431
raise KeyboardInterrupt
384
432
433
+ elif model_endpoint_type == "anthropic" :
434
+ try :
435
+ fetched_model_options = get_model_options (
436
+ credentials = credentials , model_endpoint_type = model_endpoint_type , model_endpoint = model_endpoint
437
+ )
438
+ except Exception as e :
439
+ # NOTE: if this fails, it means the user's key is probably bad
440
+ typer .secho (
441
+ f"Failed to get model list from { model_endpoint } - make sure your API key and endpoints are correct!" , fg = typer .colors .RED
442
+ )
443
+ raise e
444
+
445
+ model = questionary .select (
446
+ "Select default model:" ,
447
+ choices = fetched_model_options ,
448
+ default = fetched_model_options [0 ],
449
+ ).ask ()
450
+ if model is None :
451
+ raise KeyboardInterrupt
452
+
385
453
else : # local models
386
454
387
455
# ask about local auth
@@ -522,8 +590,8 @@ def configure_model(config: MemGPTConfig, credentials: MemGPTCredentials, model_
522
590
fetched_context_window ,
523
591
"custom" ,
524
592
]
525
- except :
526
- print (f"Failed to get model details for model '{ model } ' on Google AI API" )
593
+ except Exception as e :
594
+ print (f"Failed to get model details for model '{ model } ' on Google AI API ( { str ( e ) } ) " )
527
595
528
596
context_window_input = questionary .select (
529
597
"Select your model's context window (see https://cloud.google.com/vertex-ai/generative-ai/docs/learn/model-versioning#gemini-model-versions):" ,
@@ -533,6 +601,27 @@ def configure_model(config: MemGPTConfig, credentials: MemGPTCredentials, model_
533
601
if context_window_input is None :
534
602
raise KeyboardInterrupt
535
603
604
+ elif model_endpoint_type == "anthropic" :
605
+ try :
606
+ fetched_context_window = str (
607
+ antropic_get_model_context_window (url = model_endpoint , api_key = credentials .anthropic_key , model = model )
608
+ )
609
+ print (f"Got context window { fetched_context_window } for model { model } " )
610
+ context_length_options = [
611
+ fetched_context_window ,
612
+ "custom" ,
613
+ ]
614
+ except Exception as e :
615
+ print (f"Failed to get model details for model '{ model } ' ({ str (e )} )" )
616
+
617
+ context_window_input = questionary .select (
618
+ "Select your model's context window (see https://docs.anthropic.com/claude/docs/models-overview):" ,
619
+ choices = context_length_options ,
620
+ default = context_length_options [0 ],
621
+ ).ask ()
622
+ if context_window_input is None :
623
+ raise KeyboardInterrupt
624
+
536
625
else :
537
626
538
627
# Ask the user to specify the context length
0 commit comments