Commit 28e9d11

raspawar and lramesh-2409 authored Mar 13, 2025
Python: Nvidia Embedding Connector (microsoft#10410)
### Motivation and Context

This PR adds an NVIDIA Embedding Connector. The connector enables integration with the NVIDIA NIM API for text embeddings, allowing you to use NVIDIA's embedding models within the Semantic Kernel framework.

### Contribution Checklist

- [x] The code builds clean without any errors or warnings
- [x] The PR follows the [SK Contribution Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md) and the [pre-submission formatting script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts) raises no violations
- [x] All unit tests pass, and I have added new tests where possible
- [x] I didn't break anyone 😄

Co-authored-by: Lakshmi Ramesh <[email protected]>
1 parent 044cb33 commit 28e9d11

File tree

13 files changed

+526
-0
lines changed

 

‎python/samples/concepts/setup/ALL_SETTINGS.md

+1
@@ -30,6 +30,7 @@
 | | [VertexAITextEmbedding](../../../semantic_kernel/connectors/ai/google/google_ai/services/google_ai_text_embedding.py) | project_id, <br> region, <br> embedding_model_id | VERTEX_AI_PROJECT_ID, <br> VERTEX_AI_REGION, <br> VERTEX_AI_EMBEDDING_MODEL_ID | Yes, <br> No, <br> Yes | |
 | HuggingFace | [HuggingFaceTextCompletion](../../../semantic_kernel/connectors/ai/hugging_face/services/hf_text_completion.py) | ai_model_id | N/A | Yes | |
 | | [HuggingFaceTextEmbedding](../../../semantic_kernel/connectors/ai/hugging_face/services/hf_text_embedding.py) | ai_model_id | N/A | Yes | |
+| NVIDIA NIM | [NvidiaTextEmbedding](../../../semantic_kernel/connectors/ai/nvidia/services/nvidia_text_embedding.py) | ai_model_id, <br> api_key, <br> base_url | NVIDIA_EMBEDDING_MODEL_ID, <br> NVIDIA_API_KEY, <br> NVIDIA_BASE_URL | No, <br> Yes, <br> No | [NvidiaAISettings](../../../semantic_kernel/connectors/ai/nvidia/settings/nvidia_settings.py) |
 | Mistral AI | [MistralAIChatCompletion](../../../semantic_kernel/connectors/ai/mistral_ai/services/mistral_ai_chat_completion.py) | ai_model_id, <br> api_key | MISTRALAI_CHAT_MODEL_ID, <br> MISTRALAI_API_KEY | Yes, <br> Yes | [MistralAISettings](../../../semantic_kernel/connectors/ai/mistral_ai/settings/mistral_ai_settings.py) |
 | | [MistralAITextEmbedding](../../../semantic_kernel/connectors/ai/mistral_ai/services/mistral_ai_text_embedding.py) | ai_model_id, <br> api_key | MISTRALAI_EMBEDDING_MODEL_ID, <br> MISTRALAI_API_KEY | Yes, <br> Yes | |
 | Ollama | [OllamaChatCompletion](../../../semantic_kernel/connectors/ai/ollama/services/ollama_chat_completion.py) | ai_model_id, <br> host | OLLAMA_CHAT_MODEL_ID, <br> OLLAMA_HOST | Yes, <br> No | [OllamaSettings](../../../semantic_kernel/connectors/ai/ollama/ollama_settings.py) |

‎python/semantic_kernel/connectors/ai/README.md

+1
@@ -45,6 +45,7 @@ All base clients inherit from the [`AIServiceClientBase`](../../services/ai_serv
 | | [`HuggingFaceTextEmbedding`](./hugging_face/services/hf_text_embedding.py) |
 | Mistral AI | [`MistralAIChatCompletion`](./mistral_ai/services/mistral_ai_chat_completion.py) |
 | | [`MistralAITextEmbedding`](./mistral_ai/services/mistral_ai_text_embedding.py) |
+| [Nvidia](./nvidia/README.md) | [`NvidiaTextEmbedding`](./nvidia/services/nvidia_text_embedding.py) |
 | Ollama | [`OllamaChatCompletion`](./ollama/services/ollama_chat_completion.py) |
 | | [`OllamaTextCompletion`](./ollama/services/ollama_text_completion.py) |
 | | [`OllamaTextEmbedding`](./ollama/services/ollama_text_embedding.py) |
‎python/semantic_kernel/connectors/ai/nvidia/README.md

@@ -0,0 +1,32 @@
# semantic_kernel.connectors.ai.nvidia

This connector enables integration with the NVIDIA NIM API for text embeddings, allowing you to use NVIDIA's embedding models within the Semantic Kernel framework.

## Quick start

### Initialize the kernel

```python
import semantic_kernel as sk

kernel = sk.Kernel()
```

### Add the NVIDIA text embedding service

You can provide your API key directly or through an environment variable.

```python
from semantic_kernel.connectors.ai.nvidia import NvidiaTextEmbedding

embedding_service = NvidiaTextEmbedding(
    ai_model_id="nvidia/nv-embedqa-e5-v5",  # default model if not specified
    api_key="your-nvidia-api-key",  # can also use the NVIDIA_API_KEY env variable
    service_id="nvidia-embeddings",  # optional service identifier
)
```

### Add the embedding service to the kernel

```python
kernel.add_service(embedding_service)
```

### Generate embeddings for text

```python
texts = ["Hello, world!", "Semantic Kernel is awesome"]
embeddings = await kernel.get_service("nvidia-embeddings").generate_embeddings(texts)
```
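The embeddings returned above are vectors of floats, so downstream code can compare texts numerically. A minimal, dependency-free sketch of cosine similarity over two such vectors (`cosine_similarity` is an illustrative helper, not part of Semantic Kernel):

```python
import math


def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Cosine similarity between two embedding vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    return dot / (norm_a * norm_b)


# Vectors pointing the same way score 1.0; orthogonal vectors score 0.0.
print(cosine_similarity([1.0, 0.0], [2.0, 0.0]))  # 1.0
print(cosine_similarity([1.0, 0.0], [0.0, 3.0]))  # 0.0
```

In practice you would feed it two rows of the array returned by `generate_embeddings`.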
‎python/semantic_kernel/connectors/ai/nvidia/__init__.py

@@ -0,0 +1,15 @@
```python
# Copyright (c) Microsoft. All rights reserved.

from semantic_kernel.connectors.ai.nvidia.prompt_execution_settings.nvidia_prompt_execution_settings import (
    NvidiaEmbeddingPromptExecutionSettings,
    NvidiaPromptExecutionSettings,
)
from semantic_kernel.connectors.ai.nvidia.services.nvidia_text_embedding import NvidiaTextEmbedding
from semantic_kernel.connectors.ai.nvidia.settings.nvidia_settings import NvidiaSettings

__all__ = [
    "NvidiaEmbeddingPromptExecutionSettings",
    "NvidiaPromptExecutionSettings",
    "NvidiaSettings",
    "NvidiaTextEmbedding",
]
```
@@ -0,0 +1 @@
+# Copyright (c) Microsoft. All rights reserved.
‎python/semantic_kernel/connectors/ai/nvidia/prompt_execution_settings/nvidia_prompt_execution_settings.py

@@ -0,0 +1,41 @@
```python
# Copyright (c) Microsoft. All rights reserved.

from typing import Annotated, Any, Literal

from pydantic import Field

from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecutionSettings


class NvidiaPromptExecutionSettings(PromptExecutionSettings):
    """Settings for NVIDIA prompt execution."""

    format: Literal["json"] | None = None
    options: dict[str, Any] | None = None

    def prepare_settings_dict(self, **kwargs) -> dict[str, Any]:
        """Prepare the settings as a dictionary for sending to the AI service.

        By default, this method excludes the service_id and extension_data fields,
        as well as any fields that are None.
        """
        return self.model_dump(
            exclude={"service_id", "extension_data", "structured_json_response", "input_type", "truncate"},
            exclude_none=True,
            by_alias=True,
        )


class NvidiaEmbeddingPromptExecutionSettings(NvidiaPromptExecutionSettings):
    """Settings for NVIDIA embedding prompt execution."""

    input: str | list[str] | None = None
    ai_model_id: Annotated[str | None, Field(serialization_alias="model")] = None
    encoding_format: Literal["float", "base64"] = "float"
    truncate: Literal["NONE", "START", "END"] = "NONE"
    input_type: Literal["passage", "query"] = "query"  # required param with default value "query"
    user: str | None = None
    extra_headers: dict | None = None
    extra_body: dict | None = None
    timeout: float | None = None
    dimensions: Annotated[int | None, Field(gt=0)] = None
```
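`prepare_settings_dict` drops `None` fields, serializes `ai_model_id` under its wire alias `model`, and excludes `input_type` and `truncate` because the connector sends those through `extra_body` instead. A standalone sketch of that payload shaping without pydantic (`prepare_payload` is a hypothetical helper, not part of the connector):

```python
def prepare_payload(settings: dict) -> dict:
    """Mimic the connector's payload shaping: drop None values, rename
    'ai_model_id' to its wire alias 'model', and strip fields that travel
    via extra_body instead (input_type, truncate)."""
    excluded = {"service_id", "extension_data", "input_type", "truncate"}
    payload = {}
    for key, value in settings.items():
        if value is None or key in excluded:
            continue
        payload["model" if key == "ai_model_id" else key] = value
    return payload


raw = {
    "ai_model_id": "nvidia/nv-embedqa-e5-v5",
    "input": ["hello"],
    "encoding_format": "float",
    "input_type": "query",  # sent via extra_body by the connector
    "truncate": "NONE",     # sent via extra_body by the connector
    "dimensions": None,     # None values are dropped
}
print(prepare_payload(raw))
# {'model': 'nvidia/nv-embedqa-e5-v5', 'input': ['hello'], 'encoding_format': 'float'}
```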
@@ -0,0 +1 @@
+# Copyright (c) Microsoft. All rights reserved.
‎python/semantic_kernel/connectors/ai/nvidia/services/nvidia_handler.py

@@ -0,0 +1,92 @@
```python
# Copyright (c) Microsoft. All rights reserved.

import logging
from abc import ABC
from typing import Any, ClassVar

from openai import AsyncOpenAI, AsyncStream
from openai.types import CreateEmbeddingResponse

from semantic_kernel.connectors.ai.nvidia import NvidiaPromptExecutionSettings
from semantic_kernel.connectors.ai.nvidia.services.nvidia_model_types import NvidiaModelTypes
from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecutionSettings
from semantic_kernel.const import USER_AGENT
from semantic_kernel.exceptions import ServiceResponseException
from semantic_kernel.kernel_pydantic import KernelBaseModel

logger: logging.Logger = logging.getLogger(__name__)

RESPONSE_TYPE = list[Any]


class NvidiaHandler(KernelBaseModel, ABC):
    """Internal class for calls to the NVIDIA API."""

    MODEL_PROVIDER_NAME: ClassVar[str] = "nvidia"
    client: AsyncOpenAI
    ai_model_type: NvidiaModelTypes = (
        NvidiaModelTypes.EMBEDDING
    )  # TODO: revert this to chat after adding support for chat completion  # noqa: TD002
    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0

    async def _send_request(self, settings: PromptExecutionSettings) -> RESPONSE_TYPE:
        """Send a request to the NVIDIA API."""
        if self.ai_model_type == NvidiaModelTypes.EMBEDDING:
            assert isinstance(settings, NvidiaPromptExecutionSettings)  # nosec
            return await self._send_embedding_request(settings)

        raise NotImplementedError(f"Model type {self.ai_model_type} is not supported")

    async def _send_embedding_request(self, settings: NvidiaPromptExecutionSettings) -> list[Any]:
        """Send a request to the (OpenAI-compatible) embeddings endpoint."""
        try:
            # Unsupported parameters are excluded from the main dict and sent via extra_body.
            response = await self.client.embeddings.create(**settings.prepare_settings_dict())
            self.store_usage(response)
            return [x.embedding for x in response.data]
        except Exception as ex:
            raise ServiceResponseException(
                f"{type(self)} service failed to generate embeddings",
                ex,
            ) from ex

    def store_usage(self, response: CreateEmbeddingResponse) -> None:
        """Store the usage information from the response."""
        if not isinstance(response, AsyncStream) and response.usage:
            logger.info(f"NVIDIA usage: {response.usage}")
            self.prompt_tokens += response.usage.prompt_tokens
            self.total_tokens += response.usage.total_tokens
            if hasattr(response.usage, "completion_tokens"):
                self.completion_tokens += response.usage.completion_tokens

    def to_dict(self) -> dict[str, str]:
        """Create a dict of the service settings."""
        client_settings = {
            "api_key": self.client.api_key,
            "default_headers": {k: v for k, v in self.client.default_headers.items() if k != USER_AGENT},
        }
        if self.client.organization:
            client_settings["org_id"] = self.client.organization
        base = self.model_dump(
            exclude={
                "prompt_tokens",
                "completion_tokens",
                "total_tokens",
                "api_type",
                "ai_model_type",
                "service_id",
                "client",
            },
            by_alias=True,
            exclude_none=True,
        )
        base.update(client_settings)
        return base
```
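`store_usage` accumulates token counts across requests, guarding with `hasattr` because embedding responses carry no completion tokens. A self-contained sketch of the same accumulation with a dummy usage object (`Usage` and `UsageTracker` are illustrative names, not connector classes):

```python
from dataclasses import dataclass


@dataclass
class Usage:
    """Dummy stand-in for the usage block of an embeddings response."""
    prompt_tokens: int
    total_tokens: int


class UsageTracker:
    """Accumulates usage across calls, like NvidiaHandler.store_usage."""

    def __init__(self) -> None:
        self.prompt_tokens = 0
        self.completion_tokens = 0
        self.total_tokens = 0

    def store_usage(self, usage: Usage) -> None:
        self.prompt_tokens += usage.prompt_tokens
        self.total_tokens += usage.total_tokens
        # Embedding responses have no completion tokens, so this is a no-op
        # here; the real handler guards with hasattr for the same reason.
        if hasattr(usage, "completion_tokens"):
            self.completion_tokens += usage.completion_tokens


tracker = UsageTracker()
tracker.store_usage(Usage(prompt_tokens=5, total_tokens=5))
tracker.store_usage(Usage(prompt_tokens=7, total_tokens=7))
print(tracker.prompt_tokens, tracker.total_tokens)  # 12 12
```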
‎python/semantic_kernel/connectors/ai/nvidia/services/nvidia_model_types.py

@@ -0,0 +1,9 @@
```python
# Copyright (c) Microsoft. All rights reserved.

from enum import Enum


class NvidiaModelTypes(Enum):
    """NVIDIA model types; currently only embedding is supported."""

    EMBEDDING = "embedding"
```
‎python/semantic_kernel/connectors/ai/nvidia/services/nvidia_text_embedding.py

@@ -0,0 +1,164 @@
```python
# Copyright (c) Microsoft. All rights reserved.

import asyncio
import copy
import logging
import sys
from typing import Any

if sys.version_info >= (3, 12):
    from typing import override  # pragma: no cover
else:
    from typing_extensions import override  # pragma: no cover

from numpy import array, ndarray
from openai import AsyncOpenAI
from pydantic import ValidationError

from semantic_kernel.connectors.ai.embeddings.embedding_generator_base import EmbeddingGeneratorBase
from semantic_kernel.connectors.ai.nvidia.prompt_execution_settings.nvidia_prompt_execution_settings import (
    NvidiaEmbeddingPromptExecutionSettings,
)
from semantic_kernel.connectors.ai.nvidia.services.nvidia_handler import NvidiaHandler
from semantic_kernel.connectors.ai.nvidia.services.nvidia_model_types import NvidiaModelTypes
from semantic_kernel.connectors.ai.nvidia.settings.nvidia_settings import NvidiaSettings
from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecutionSettings
from semantic_kernel.exceptions.service_exceptions import ServiceInitializationError
from semantic_kernel.utils.feature_stage_decorator import experimental

logger: logging.Logger = logging.getLogger(__name__)


@experimental
class NvidiaTextEmbedding(NvidiaHandler, EmbeddingGeneratorBase):
    """NVIDIA text embedding service."""

    def __init__(
        self,
        ai_model_id: str | None = None,
        api_key: str | None = None,
        base_url: str | None = None,
        client: AsyncOpenAI | None = None,
        env_file_path: str | None = None,
        service_id: str | None = None,
    ) -> None:
        """Initializes a new instance of the NvidiaTextEmbedding class.

        Args:
            ai_model_id (str | None): NVIDIA embedding model ID, for example
                nvidia/nv-embedqa-e5-v5. (Env var NVIDIA_EMBEDDING_MODEL_ID)
            api_key (str | None): NVIDIA API key. (Env var NVIDIA_API_KEY)
            base_url (str | None): The URL of the NVIDIA endpoint; for more
                information refer to https://docs.api.nvidia.com/nim/reference/.
                (Env var NVIDIA_BASE_URL)
            client (AsyncOpenAI | None): An existing client to use. (Optional)
            env_file_path (str | None): Use the environment settings file as
                a fallback to environment variables. (Optional)
            service_id (str | None): Service ID for the model. (Optional)
        """
        try:
            nvidia_settings = NvidiaSettings.create(
                api_key=api_key,
                base_url=base_url,
                embedding_model_id=ai_model_id,
                env_file_path=env_file_path,
            )
        except ValidationError as ex:
            raise ServiceInitializationError("Failed to create NVIDIA settings.", ex) from ex
        if not nvidia_settings.embedding_model_id:
            nvidia_settings.embedding_model_id = "nvidia/nv-embedqa-e5-v5"
            logger.warning(f"Default embedding model set as: {nvidia_settings.embedding_model_id}")
        if not nvidia_settings.api_key:
            logger.warning("API key is missing, inference may fail.")
        if not client:
            client = AsyncOpenAI(api_key=nvidia_settings.api_key.get_secret_value(), base_url=nvidia_settings.base_url)
        super().__init__(
            ai_model_id=nvidia_settings.embedding_model_id,
            api_key=nvidia_settings.api_key.get_secret_value() if nvidia_settings.api_key else None,
            ai_model_type=NvidiaModelTypes.EMBEDDING,
            service_id=service_id or nvidia_settings.embedding_model_id,
            env_file_path=env_file_path,
            client=client,
        )

    @override
    async def generate_embeddings(
        self,
        texts: list[str],
        settings: "PromptExecutionSettings | None" = None,
        batch_size: int | None = None,
        **kwargs: Any,
    ) -> ndarray:
        raw_embeddings = await self.generate_raw_embeddings(texts, settings, batch_size, **kwargs)
        return array(raw_embeddings)

    @override
    async def generate_raw_embeddings(
        self,
        texts: list[str],
        settings: "PromptExecutionSettings | None" = None,
        batch_size: int | None = None,
        **kwargs: Any,
    ) -> Any:
        """Returns embeddings for the given texts in the unedited format.

        Args:
            texts (list[str]): The texts to generate embeddings for.
            settings (PromptExecutionSettings | None): The settings to use for the request.
            batch_size (int | None): The batch size to use for the request.
            kwargs (dict[str, Any]): Additional arguments to pass to the request.
        """
        if not settings:
            settings = NvidiaEmbeddingPromptExecutionSettings(ai_model_id=self.ai_model_id)
        elif not isinstance(settings, NvidiaEmbeddingPromptExecutionSettings):
            settings = self.get_prompt_execution_settings_from_settings(settings)
        assert isinstance(settings, NvidiaEmbeddingPromptExecutionSettings)  # nosec
        if settings.ai_model_id is None:
            settings.ai_model_id = self.ai_model_id
        for key, value in kwargs.items():
            setattr(settings, key, value)

        # Move input_type and truncate to extra_body.
        if not settings.extra_body:
            settings.extra_body = {}
        settings.extra_body.setdefault("input_type", settings.input_type)
        if settings.truncate is not None:
            settings.extra_body.setdefault("truncate", settings.truncate)

        raw_embeddings = []
        tasks = []

        batch_size = batch_size or len(texts)
        for i in range(0, len(texts), batch_size):
            batch = texts[i : i + batch_size]
            batch_settings = copy.deepcopy(settings)
            batch_settings.input = batch
            tasks.append(self._send_request(settings=batch_settings))

        results = await asyncio.gather(*tasks)
        for raw_embedding in results:
            assert isinstance(raw_embedding, list)  # nosec
            raw_embeddings.extend(raw_embedding)

        return raw_embeddings

    def get_prompt_execution_settings_class(self) -> type["PromptExecutionSettings"]:
        """Get the request settings class."""
        return NvidiaEmbeddingPromptExecutionSettings

    @classmethod
    def from_dict(cls: type["NvidiaTextEmbedding"], settings: dict[str, Any]) -> "NvidiaTextEmbedding":
        """Initialize an NVIDIA service from a dictionary of settings.

        Args:
            settings: A dictionary of settings for the service.
        """
        return cls(
            ai_model_id=settings.get("ai_model_id"),
            api_key=settings.get("api_key"),
            env_file_path=settings.get("env_file_path"),
            service_id=settings.get("service_id"),
        )
```
@@ -0,0 +1 @@
+# Copyright (c) Microsoft. All rights reserved.
‎python/semantic_kernel/connectors/ai/nvidia/settings/nvidia_settings.py

@@ -0,0 +1,34 @@
```python
# Copyright (c) Microsoft. All rights reserved.

from typing import ClassVar

from pydantic import SecretStr

from semantic_kernel.kernel_pydantic import KernelBaseSettings


class NvidiaSettings(KernelBaseSettings):
    """NVIDIA model settings.

    The settings are first loaded from environment variables with the prefix 'NVIDIA_'. If the
    environment variables are not found, the settings can be loaded from a .env file with the
    encoding 'utf-8'. If the settings are not found in the .env file, the settings are ignored;
    however, validation will fail alerting that the settings are missing.

    Settings for prefix 'NVIDIA_':
    - api_key: SecretStr - NVIDIA API key. (Env var NVIDIA_API_KEY)
    - base_url: str - The URL of the NVIDIA endpoint; for more information refer to
      https://docs.api.nvidia.com/nim/reference/. Defaults to
      https://integrate.api.nvidia.com/v1. (Env var NVIDIA_BASE_URL)
    - embedding_model_id: str | None - The NVIDIA embedding model ID to use, for example,
      nvidia/nv-embed-v1. (Env var NVIDIA_EMBEDDING_MODEL_ID)
    - env_file_path: if provided, the .env settings are read from this file path location.
    """

    env_prefix: ClassVar[str] = "NVIDIA_"

    api_key: SecretStr
    base_url: str = "https://integrate.api.nvidia.com/v1"
    embedding_model_id: str | None = None
```
@@ -0,0 +1,134 @@
```python
# Copyright (c) Microsoft. All rights reserved.

from unittest.mock import AsyncMock, patch

import pytest
from openai import AsyncClient
from openai.resources.embeddings import AsyncEmbeddings

from semantic_kernel.connectors.ai.nvidia.prompt_execution_settings.nvidia_prompt_execution_settings import (
    NvidiaEmbeddingPromptExecutionSettings,
)
from semantic_kernel.connectors.ai.nvidia.services.nvidia_text_embedding import NvidiaTextEmbedding
from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecutionSettings
from semantic_kernel.exceptions.service_exceptions import ServiceInitializationError, ServiceResponseException


@pytest.fixture
def nvidia_unit_test_env(monkeypatch, exclude_list, override_env_param_dict):
    """Fixture to set environment variables for NvidiaTextEmbedding."""
    if exclude_list is None:
        exclude_list = []
    if override_env_param_dict is None:
        override_env_param_dict = {}

    env_vars = {"NVIDIA_API_KEY": "test_api_key", "NVIDIA_EMBEDDING_MODEL_ID": "test_embedding_model_id"}
    env_vars.update(override_env_param_dict)

    for key, value in env_vars.items():
        if key not in exclude_list:
            monkeypatch.setenv(key, value)
        else:
            monkeypatch.delenv(key, raising=False)

    return env_vars


def test_init(nvidia_unit_test_env):
    nvidia_text_embedding = NvidiaTextEmbedding()

    assert nvidia_text_embedding.client is not None
    assert isinstance(nvidia_text_embedding.client, AsyncClient)
    assert nvidia_text_embedding.ai_model_id == nvidia_unit_test_env["NVIDIA_EMBEDDING_MODEL_ID"]
    assert nvidia_text_embedding.get_prompt_execution_settings_class() == NvidiaEmbeddingPromptExecutionSettings


def test_init_validation_fail() -> None:
    with pytest.raises(ServiceInitializationError):
        NvidiaTextEmbedding(api_key="34523", ai_model_id={"test": "dict"})


def test_init_to_from_dict(nvidia_unit_test_env):
    default_headers = {"X-Unit-Test": "test-guid"}
    settings = {
        "ai_model_id": nvidia_unit_test_env["NVIDIA_EMBEDDING_MODEL_ID"],
        "api_key": nvidia_unit_test_env["NVIDIA_API_KEY"],
        "default_headers": default_headers,
    }
    text_embedding = NvidiaTextEmbedding.from_dict(settings)
    dumped_settings = text_embedding.to_dict()
    assert dumped_settings["ai_model_id"] == settings["ai_model_id"]
    assert dumped_settings["api_key"] == settings["api_key"]


@patch.object(AsyncEmbeddings, "create", new_callable=AsyncMock)
async def test_embedding_calls_with_parameters(mock_create, nvidia_unit_test_env) -> None:
    ai_model_id = "NV-Embed-QA"
    texts = ["hello world", "goodbye world"]
    embedding_dimensions = 1536

    nvidia_text_embedding = NvidiaTextEmbedding(ai_model_id=ai_model_id)

    await nvidia_text_embedding.generate_embeddings(texts, dimensions=embedding_dimensions)

    mock_create.assert_awaited_once_with(
        input=texts,
        model=ai_model_id,
        dimensions=embedding_dimensions,
        encoding_format="float",
        extra_body={"input_type": "query", "truncate": "NONE"},
    )


@patch.object(AsyncEmbeddings, "create", new_callable=AsyncMock)
async def test_embedding_calls_with_settings(mock_create, nvidia_unit_test_env) -> None:
    ai_model_id = "test_model_id"
    texts = ["hello world", "goodbye world"]
    settings = NvidiaEmbeddingPromptExecutionSettings(service_id="default")
    nvidia_text_embedding = NvidiaTextEmbedding(service_id="default", ai_model_id=ai_model_id)

    await nvidia_text_embedding.generate_embeddings(texts, settings=settings)

    mock_create.assert_awaited_once_with(
        input=texts,
        model=ai_model_id,
        encoding_format="float",
        extra_body={"input_type": "query", "truncate": "NONE"},
    )


@patch.object(AsyncEmbeddings, "create", new_callable=AsyncMock, side_effect=Exception)
async def test_embedding_fail(mock_create, nvidia_unit_test_env) -> None:
    ai_model_id = "test_model_id"
    texts = ["hello world", "goodbye world"]

    nvidia_text_embedding = NvidiaTextEmbedding(ai_model_id=ai_model_id)
    with pytest.raises(ServiceResponseException):
        await nvidia_text_embedding.generate_embeddings(texts)


@patch.object(AsyncEmbeddings, "create", new_callable=AsyncMock)
async def test_embedding_pes(mock_create, nvidia_unit_test_env) -> None:
    ai_model_id = "test_model_id"
    texts = ["hello world", "goodbye world"]

    pes = PromptExecutionSettings(service_id="x", ai_model_id=ai_model_id)
    nvidia_text_embedding = NvidiaTextEmbedding(ai_model_id=ai_model_id)

    await nvidia_text_embedding.generate_raw_embeddings(texts, pes)

    mock_create.assert_awaited_once_with(
        input=texts,
        model=ai_model_id,
        encoding_format="float",
        extra_body={"input_type": "query", "truncate": "NONE"},
    )
```
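These tests patch `AsyncEmbeddings.create` with an `AsyncMock` so no network call is made, then assert on the exact keyword arguments the connector sent. The same pattern in a self-contained sketch, with `EmbeddingClient` and `embed` standing in for the OpenAI resource and the code under test:

```python
import asyncio
from unittest.mock import AsyncMock, patch


class EmbeddingClient:
    """Minimal stand-in for the OpenAI embeddings resource."""

    async def create(self, **kwargs):
        raise RuntimeError("real network calls are not expected in unit tests")


async def embed(client: EmbeddingClient, texts: list[str]) -> dict:
    """Code under test: forwards texts to the embeddings endpoint."""
    return await client.create(input=texts, model="test-model")


# Patch the async method on the class so no request is made, then assert
# on the exact keyword arguments, as the connector tests do.
with patch.object(EmbeddingClient, "create", new_callable=AsyncMock) as mock_create:
    mock_create.return_value = {"data": []}
    result = asyncio.run(embed(EmbeddingClient(), ["hello"]))
    mock_create.assert_awaited_once_with(input=["hello"], model="test-model")
    call_kwargs = dict(mock_create.await_args.kwargs)

print(result)  # {'data': []}
```

Patching at the class level (rather than on one instance) is what lets the real tests intercept calls made by a client constructed inside `NvidiaTextEmbedding`.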

0 commit comments
