Skip to content

Commit 675a090

Browse files
authored
feat: Anthropic Claude API support (#1239)
1 parent 5ccba23 commit 675a090

File tree

6 files changed

+581
-7
lines changed

6 files changed

+581
-7
lines changed

memgpt/cli/cli_config.py

+95-6
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
from memgpt.llm_api.openai import openai_get_model_list
2222
from memgpt.llm_api.azure_openai import azure_openai_get_model_list
2323
from memgpt.llm_api.google_ai import google_ai_get_model_list, google_ai_get_model_context_window
24+
from memgpt.llm_api.anthropic import anthropic_get_model_list, antropic_get_model_context_window
25+
from memgpt.llm_api.llm_api_tools import LLM_API_PROVIDER_OPTIONS
2426
from memgpt.local_llm.constants import DEFAULT_ENDPOINTS, DEFAULT_OLLAMA_MODEL, DEFAULT_WRAPPER_NAME
2527
from memgpt.local_llm.utils import get_available_wrappers
2628
from memgpt.server.utils import shorten_key_middle
@@ -64,14 +66,14 @@ def configure_llm_endpoint(config: MemGPTConfig, credentials: MemGPTCredentials)
6466
# get default
6567
default_model_endpoint_type = config.default_llm_config.model_endpoint_type
6668
if config.default_llm_config.model_endpoint_type is not None and config.default_llm_config.model_endpoint_type not in [
67-
"openai",
68-
"azure",
69-
"google_ai",
69+
provider for provider in LLM_API_PROVIDER_OPTIONS if provider != "local"
7070
]: # local model
7171
default_model_endpoint_type = "local"
7272

7373
provider = questionary.select(
74-
"Select LLM inference provider:", choices=["openai", "azure", "google_ai", "local"], default=default_model_endpoint_type
74+
"Select LLM inference provider:",
75+
choices=LLM_API_PROVIDER_OPTIONS,
76+
default=default_model_endpoint_type,
7577
).ask()
7678
if provider is None:
7779
raise KeyboardInterrupt
@@ -184,6 +186,46 @@ def configure_llm_endpoint(config: MemGPTConfig, credentials: MemGPTCredentials)
184186

185187
model_endpoint_type = "google_ai"
186188

189+
elif provider == "anthropic":
190+
# check for key
191+
if credentials.anthropic_key is None:
192+
# allow key to get pulled from env vars
193+
anthropic_api_key = os.getenv("ANTHROPIC_API_KEY", None)
194+
# if we still can't find it, ask for it as input
195+
if anthropic_api_key is None:
196+
while anthropic_api_key is None or len(anthropic_api_key) == 0:
197+
# Ask for API key as input
198+
anthropic_api_key = questionary.password(
199+
"Enter your Anthropic API key (starts with 'sk-', see https://console.anthropic.com/settings/keys):"
200+
).ask()
201+
if anthropic_api_key is None:
202+
raise KeyboardInterrupt
203+
credentials.anthropic_key = anthropic_api_key
204+
credentials.save()
205+
else:
206+
# Give the user an opportunity to overwrite the key
207+
anthropic_api_key = None
208+
default_input = (
209+
shorten_key_middle(credentials.anthropic_key) if credentials.anthropic_key.startswith("sk-") else credentials.anthropic_key
210+
)
211+
anthropic_api_key = questionary.password(
212+
"Enter your Anthropic API key (starts with 'sk-', see https://console.anthropic.com/settings/keys):",
213+
default=default_input,
214+
).ask()
215+
if anthropic_api_key is None:
216+
raise KeyboardInterrupt
217+
# If the user modified it, use the new one
218+
if anthropic_api_key != default_input:
219+
credentials.anthropic_key = anthropic_api_key
220+
credentials.save()
221+
222+
model_endpoint_type = "anthropic"
223+
model_endpoint = "https://api.anthropic.com/v1"
224+
model_endpoint = questionary.text("Override default endpoint:", default=model_endpoint).ask()
225+
if model_endpoint is None:
226+
raise KeyboardInterrupt
227+
provider = "anthropic"
228+
187229
else: # local models
188230
# backend_options_old = ["webui", "webui-legacy", "llamacpp", "koboldcpp", "ollama", "lmstudio", "lmstudio-legacy", "vllm", "openai"]
189231
backend_options = builtins.list(DEFAULT_ENDPOINTS.keys())
@@ -291,6 +333,12 @@ def get_model_options(
291333
model_options = [mo for mo in model_options if str(mo).startswith("gemini") and "-pro" in str(mo)]
292334
# model_options = ["gemini-pro"]
293335

336+
elif model_endpoint_type == "anthropic":
337+
if credentials.anthropic_key is None:
338+
raise ValueError("Missing Anthropic API key")
339+
fetched_model_options = anthropic_get_model_list(url=model_endpoint, api_key=credentials.anthropic_key)
340+
model_options = [obj["name"] for obj in fetched_model_options]
341+
294342
else:
295343
# Attempt to do OpenAI endpoint style model fetching
296344
# TODO support local auth with api-key header
@@ -382,6 +430,26 @@ def configure_model(config: MemGPTConfig, credentials: MemGPTCredentials, model_
382430
if model is None:
383431
raise KeyboardInterrupt
384432

433+
elif model_endpoint_type == "anthropic":
434+
try:
435+
fetched_model_options = get_model_options(
436+
credentials=credentials, model_endpoint_type=model_endpoint_type, model_endpoint=model_endpoint
437+
)
438+
except Exception as e:
439+
# NOTE: if this fails, it means the user's key is probably bad
440+
typer.secho(
441+
f"Failed to get model list from {model_endpoint} - make sure your API key and endpoints are correct!", fg=typer.colors.RED
442+
)
443+
raise e
444+
445+
model = questionary.select(
446+
"Select default model:",
447+
choices=fetched_model_options,
448+
default=fetched_model_options[0],
449+
).ask()
450+
if model is None:
451+
raise KeyboardInterrupt
452+
385453
else: # local models
386454

387455
# ask about local auth
@@ -522,8 +590,8 @@ def configure_model(config: MemGPTConfig, credentials: MemGPTCredentials, model_
522590
fetched_context_window,
523591
"custom",
524592
]
525-
except:
526-
print(f"Failed to get model details for model '{model}' on Google AI API")
593+
except Exception as e:
594+
print(f"Failed to get model details for model '{model}' on Google AI API ({str(e)})")
527595

528596
context_window_input = questionary.select(
529597
"Select your model's context window (see https://cloud.google.com/vertex-ai/generative-ai/docs/learn/model-versioning#gemini-model-versions):",
@@ -533,6 +601,27 @@ def configure_model(config: MemGPTConfig, credentials: MemGPTCredentials, model_
533601
if context_window_input is None:
534602
raise KeyboardInterrupt
535603

604+
elif model_endpoint_type == "anthropic":
605+
try:
606+
fetched_context_window = str(
607+
antropic_get_model_context_window(url=model_endpoint, api_key=credentials.anthropic_key, model=model)
608+
)
609+
print(f"Got context window {fetched_context_window} for model {model}")
610+
context_length_options = [
611+
fetched_context_window,
612+
"custom",
613+
]
614+
except Exception as e:
615+
print(f"Failed to get model details for model '{model}' ({str(e)})")
616+
617+
context_window_input = questionary.select(
618+
"Select your model's context window (see https://docs.anthropic.com/claude/docs/models-overview):",
619+
choices=context_length_options,
620+
default=context_length_options[0],
621+
).ask()
622+
if context_window_input is None:
623+
raise KeyboardInterrupt
624+
536625
else:
537626

538627
# Ask the user to specify the context length

memgpt/credentials.py

+8
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@ class MemGPTCredentials:
3232
google_ai_key: Optional[str] = None
3333
google_ai_service_endpoint: Optional[str] = None
3434

35+
# anthropic config
36+
anthropic_key: Optional[str] = None
37+
3538
# azure config
3639
azure_auth_type: str = "api_key"
3740
azure_key: Optional[str] = None
@@ -77,6 +80,8 @@ def load(cls) -> "MemGPTCredentials":
7780
# gemini
7881
"google_ai_key": get_field(config, "google_ai", "key"),
7982
"google_ai_service_endpoint": get_field(config, "google_ai", "service_endpoint"),
83+
# anthropic
84+
"anthropic_key": get_field(config, "anthropic", "key"),
8085
# open llm
8186
"openllm_auth_type": get_field(config, "openllm", "auth_type"),
8287
"openllm_key": get_field(config, "openllm", "key"),
@@ -113,6 +118,9 @@ def save(self):
113118
set_field(config, "google_ai", "key", self.google_ai_key)
114119
set_field(config, "google_ai", "service_endpoint", self.google_ai_service_endpoint)
115120

121+
# anthropic
122+
set_field(config, "anthropic", "key", self.anthropic_key)
123+
116124
# openllm config
117125
set_field(config, "openllm", "auth_type", self.openllm_auth_type)
118126
set_field(config, "openllm", "key", self.openllm_key)

memgpt/data_types.py

+72
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,78 @@ def to_openai_dict(self, max_tool_id_length=TOOL_CALL_ID_MAX_LEN) -> dict:
294294

295295
return openai_message
296296

297+
def to_anthropic_dict(self, inner_thoughts_xml_tag="thinking") -> dict:
    """Convert this Message into an Anthropic Messages API message dict.

    Args:
        inner_thoughts_xml_tag: XML tag name used to wrap the assistant's
            free text ("inner thoughts"). Anthropic's docs recommend a
            <thinking> tag when combining CoT with tool use. Pass None or
            "" to disable wrapping.

    Returns:
        A dict in Anthropic's message format. NOTE: Anthropic has no
        "tool" role, so tool results are emitted as role "user" messages
        containing a "tool_result" content block.

    Raises:
        ValueError: if the role is "system" (Anthropic takes the system
            prompt as a top-level API parameter, not a message) or any
            unrecognized role.
    """

    def add_xml_tag(string: str, xml_tag: Optional[str]):
        # NOTE: Anthropic docs recommend a <thinking> tag when using CoT + tool use
        # BUGFIX: closing tag was previously emitted as "</tag" (missing trailing '>')
        return f"<{xml_tag}>{string}</{xml_tag}>" if xml_tag else string

    if self.role == "system":
        # System prompts are passed via the API's top-level `system` parameter
        raise ValueError("Anthropic 'system' role not supported")

    elif self.role == "user":
        assert all([v is not None for v in [self.text, self.role]]), vars(self)
        anthropic_message = {
            "content": self.text,
            "role": self.role,
        }
        # Optional field, do not include if null
        if self.name is not None:
            anthropic_message["name"] = self.name

    elif self.role == "assistant":
        # Assistant turns must carry text, tool calls, or both
        assert self.tool_calls is not None or self.text is not None
        anthropic_message = {
            "role": self.role,
        }
        content = []
        if self.text is not None:
            content.append(
                {
                    "type": "text",
                    "text": add_xml_tag(string=self.text, xml_tag=inner_thoughts_xml_tag),
                }
            )
        if self.tool_calls is not None:
            for tool_call in self.tool_calls:
                # tool_call.function["arguments"] is a JSON string; Anthropic wants a parsed object
                content.append(
                    {
                        "type": "tool_use",
                        "id": tool_call.id,
                        "name": tool_call.function["name"],
                        "input": json.loads(tool_call.function["arguments"]),
                    }
                )

        # If the only content was text, unpack it back into a singleton
        # TODO
        anthropic_message["content"] = content

        # Optional fields, do not include if null
        if self.name is not None:
            anthropic_message["name"] = self.name

    elif self.role == "tool":
        # NOTE: Anthropic uses role "user" for "tool" responses
        assert all([v is not None for v in [self.role, self.tool_call_id]]), vars(self)
        anthropic_message = {
            "role": "user",  # NOTE: diff
            "content": [
                # TODO support error types etc
                {
                    "type": "tool_result",
                    "tool_use_id": self.tool_call_id,
                    "content": self.text,
                }
            ],
        }

    else:
        raise ValueError(self.role)

    return anthropic_message
368+
297369
def to_google_ai_dict(self, put_inner_thoughts_in_kwargs: bool = True) -> dict:
298370
"""Go from Message class to Google AI REST message object
299371

0 commit comments

Comments
 (0)