Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Anthropic Claude API support #1239

Merged
merged 2 commits into from
Apr 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 95 additions & 6 deletions memgpt/cli/cli_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
from memgpt.llm_api.openai import openai_get_model_list
from memgpt.llm_api.azure_openai import azure_openai_get_model_list
from memgpt.llm_api.google_ai import google_ai_get_model_list, google_ai_get_model_context_window
from memgpt.llm_api.anthropic import anthropic_get_model_list, antropic_get_model_context_window
from memgpt.llm_api.llm_api_tools import LLM_API_PROVIDER_OPTIONS
from memgpt.local_llm.constants import DEFAULT_ENDPOINTS, DEFAULT_OLLAMA_MODEL, DEFAULT_WRAPPER_NAME
from memgpt.local_llm.utils import get_available_wrappers
from memgpt.server.utils import shorten_key_middle
Expand Down Expand Up @@ -64,14 +66,14 @@ def configure_llm_endpoint(config: MemGPTConfig, credentials: MemGPTCredentials)
# get default
default_model_endpoint_type = config.default_llm_config.model_endpoint_type
if config.default_llm_config.model_endpoint_type is not None and config.default_llm_config.model_endpoint_type not in [
"openai",
"azure",
"google_ai",
provider for provider in LLM_API_PROVIDER_OPTIONS if provider != "local"
]: # local model
default_model_endpoint_type = "local"

provider = questionary.select(
"Select LLM inference provider:", choices=["openai", "azure", "google_ai", "local"], default=default_model_endpoint_type
"Select LLM inference provider:",
choices=LLM_API_PROVIDER_OPTIONS,
default=default_model_endpoint_type,
).ask()
if provider is None:
raise KeyboardInterrupt
Expand Down Expand Up @@ -184,6 +186,46 @@ def configure_llm_endpoint(config: MemGPTConfig, credentials: MemGPTCredentials)

model_endpoint_type = "google_ai"

elif provider == "anthropic":
# check for key
if credentials.anthropic_key is None:
# allow key to get pulled from env vars
anthropic_api_key = os.getenv("ANTHROPIC_API_KEY", None)
# if we still can't find it, ask for it as input
if anthropic_api_key is None:
while anthropic_api_key is None or len(anthropic_api_key) == 0:
# Ask for API key as input
anthropic_api_key = questionary.password(
"Enter your Anthropic API key (starts with 'sk-', see https://console.anthropic.com/settings/keys):"
).ask()
if anthropic_api_key is None:
raise KeyboardInterrupt
credentials.anthropic_key = anthropic_api_key
credentials.save()
else:
# Give the user an opportunity to overwrite the key
anthropic_api_key = None
default_input = (
shorten_key_middle(credentials.anthropic_key) if credentials.anthropic_key.startswith("sk-") else credentials.anthropic_key
)
anthropic_api_key = questionary.password(
"Enter your Anthropic API key (starts with 'sk-', see https://console.anthropic.com/settings/keys):",
default=default_input,
).ask()
if anthropic_api_key is None:
raise KeyboardInterrupt
# If the user modified it, use the new one
if anthropic_api_key != default_input:
credentials.anthropic_key = anthropic_api_key
credentials.save()

model_endpoint_type = "anthropic"
model_endpoint = "https://api.anthropic.com/v1"
model_endpoint = questionary.text("Override default endpoint:", default=model_endpoint).ask()
if model_endpoint is None:
raise KeyboardInterrupt
provider = "anthropic"

else: # local models
# backend_options_old = ["webui", "webui-legacy", "llamacpp", "koboldcpp", "ollama", "lmstudio", "lmstudio-legacy", "vllm", "openai"]
backend_options = builtins.list(DEFAULT_ENDPOINTS.keys())
Expand Down Expand Up @@ -291,6 +333,12 @@ def get_model_options(
model_options = [mo for mo in model_options if str(mo).startswith("gemini") and "-pro" in str(mo)]
# model_options = ["gemini-pro"]

elif model_endpoint_type == "anthropic":
if credentials.anthropic_key is None:
raise ValueError("Missing Anthropic API key")
fetched_model_options = anthropic_get_model_list(url=model_endpoint, api_key=credentials.anthropic_key)
model_options = [obj["name"] for obj in fetched_model_options]

else:
# Attempt to do OpenAI endpoint style model fetching
# TODO support local auth with api-key header
Expand Down Expand Up @@ -382,6 +430,26 @@ def configure_model(config: MemGPTConfig, credentials: MemGPTCredentials, model_
if model is None:
raise KeyboardInterrupt

elif model_endpoint_type == "anthropic":
try:
fetched_model_options = get_model_options(
credentials=credentials, model_endpoint_type=model_endpoint_type, model_endpoint=model_endpoint
)
except Exception as e:
# NOTE: if this fails, it means the user's key is probably bad
typer.secho(
f"Failed to get model list from {model_endpoint} - make sure your API key and endpoints are correct!", fg=typer.colors.RED
)
raise e

model = questionary.select(
"Select default model:",
choices=fetched_model_options,
default=fetched_model_options[0],
).ask()
if model is None:
raise KeyboardInterrupt

else: # local models

# ask about local auth
Expand Down Expand Up @@ -522,8 +590,8 @@ def configure_model(config: MemGPTConfig, credentials: MemGPTCredentials, model_
fetched_context_window,
"custom",
]
except:
print(f"Failed to get model details for model '{model}' on Google AI API")
except Exception as e:
print(f"Failed to get model details for model '{model}' on Google AI API ({str(e)})")

context_window_input = questionary.select(
"Select your model's context window (see https://cloud.google.com/vertex-ai/generative-ai/docs/learn/model-versioning#gemini-model-versions):",
Expand All @@ -533,6 +601,27 @@ def configure_model(config: MemGPTConfig, credentials: MemGPTCredentials, model_
if context_window_input is None:
raise KeyboardInterrupt

elif model_endpoint_type == "anthropic":
try:
fetched_context_window = str(
antropic_get_model_context_window(url=model_endpoint, api_key=credentials.anthropic_key, model=model)
)
print(f"Got context window {fetched_context_window} for model {model}")
context_length_options = [
fetched_context_window,
"custom",
]
except Exception as e:
print(f"Failed to get model details for model '{model}' ({str(e)})")

context_window_input = questionary.select(
"Select your model's context window (see https://docs.anthropic.com/claude/docs/models-overview):",
choices=context_length_options,
default=context_length_options[0],
).ask()
if context_window_input is None:
raise KeyboardInterrupt

else:

# Ask the user to specify the context length
Expand Down
8 changes: 8 additions & 0 deletions memgpt/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ class MemGPTCredentials:
google_ai_key: Optional[str] = None
google_ai_service_endpoint: Optional[str] = None

# anthropic config
anthropic_key: Optional[str] = None

# azure config
azure_auth_type: str = "api_key"
azure_key: Optional[str] = None
Expand Down Expand Up @@ -77,6 +80,8 @@ def load(cls) -> "MemGPTCredentials":
# gemini
"google_ai_key": get_field(config, "google_ai", "key"),
"google_ai_service_endpoint": get_field(config, "google_ai", "service_endpoint"),
# anthropic
"anthropic_key": get_field(config, "anthropic", "key"),
# open llm
"openllm_auth_type": get_field(config, "openllm", "auth_type"),
"openllm_key": get_field(config, "openllm", "key"),
Expand Down Expand Up @@ -113,6 +118,9 @@ def save(self):
set_field(config, "google_ai", "key", self.google_ai_key)
set_field(config, "google_ai", "service_endpoint", self.google_ai_service_endpoint)

# anthropic
set_field(config, "anthropic", "key", self.anthropic_key)

# openllm config
set_field(config, "openllm", "auth_type", self.openllm_auth_type)
set_field(config, "openllm", "key", self.openllm_key)
Expand Down
72 changes: 72 additions & 0 deletions memgpt/data_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,78 @@ def to_openai_dict(self, max_tool_id_length=TOOL_CALL_ID_MAX_LEN) -> dict:

return openai_message

def to_anthropic_dict(self, inner_thoughts_xml_tag="thinking") -> dict:
# raise NotImplementedError

def add_xml_tag(string: str, xml_tag: Optional[str]):
# NOTE: Anthropic docs recommends using <thinking> tag when using CoT + tool use
return f"<{xml_tag}>{string}</{xml_tag}" if xml_tag else string

if self.role == "system":
raise ValueError(f"Anthropic 'system' role not supported")

elif self.role == "user":
assert all([v is not None for v in [self.text, self.role]]), vars(self)
anthropic_message = {
"content": self.text,
"role": self.role,
}
# Optional field, do not include if null
if self.name is not None:
anthropic_message["name"] = self.name

elif self.role == "assistant":
assert self.tool_calls is not None or self.text is not None
anthropic_message = {
"role": self.role,
}
content = []
if self.text is not None:
content.append(
{
"type": "text",
"text": add_xml_tag(string=self.text, xml_tag=inner_thoughts_xml_tag),
}
)
if self.tool_calls is not None:
for tool_call in self.tool_calls:
content.append(
{
"type": "tool_use",
"id": tool_call.id,
"name": tool_call.function["name"],
"input": json.loads(tool_call.function["arguments"]),
}
)

# If the only content was text, unpack it back into a singleton
# TODO
anthropic_message["content"] = content

# Optional fields, do not include if null
if self.name is not None:
anthropic_message["name"] = self.name

elif self.role == "tool":
# NOTE: Anthropic uses role "user" for "tool" responses
assert all([v is not None for v in [self.role, self.tool_call_id]]), vars(self)
anthropic_message = {
"role": "user", # NOTE: diff
"content": [
# TODO support error types etc
{
"type": "tool_result",
"tool_use_id": self.tool_call_id,
"content": self.text,
}
],
}

else:
raise ValueError(self.role)

return anthropic_message

def to_google_ai_dict(self, put_inner_thoughts_in_kwargs: bool = True) -> dict:
"""Go from Message class to Google AI REST message object

Expand Down
Loading
Loading