|
| 1 | +# Copyright (c) Microsoft. All rights reserved. |
| 2 | + |
| 3 | +import asyncio |
| 4 | +import copy |
| 5 | +import logging |
| 6 | +import sys |
| 7 | +from typing import Any |
| 8 | + |
| 9 | +if sys.version_info >= (3, 12): |
| 10 | + from typing import override # pragma: no cover |
| 11 | +else: |
| 12 | + from typing_extensions import override # pragma: no cover |
| 13 | + |
| 14 | +from numpy import array, ndarray |
| 15 | +from openai import AsyncOpenAI |
| 16 | +from pydantic import ValidationError |
| 17 | + |
| 18 | +from semantic_kernel.connectors.ai.embeddings.embedding_generator_base import EmbeddingGeneratorBase |
| 19 | +from semantic_kernel.connectors.ai.nvidia.prompt_execution_settings.nvidia_prompt_execution_settings import ( |
| 20 | + NvidiaEmbeddingPromptExecutionSettings, |
| 21 | +) |
| 22 | +from semantic_kernel.connectors.ai.nvidia.services.nvidia_handler import NvidiaHandler |
| 23 | +from semantic_kernel.connectors.ai.nvidia.services.nvidia_model_types import NvidiaModelTypes |
| 24 | +from semantic_kernel.connectors.ai.nvidia.settings.nvidia_settings import NvidiaSettings |
| 25 | +from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecutionSettings |
| 26 | +from semantic_kernel.exceptions.service_exceptions import ServiceInitializationError |
| 27 | +from semantic_kernel.utils.feature_stage_decorator import experimental |
| 28 | + |
| 29 | +logger: logging.Logger = logging.getLogger(__name__) |
| 30 | + |
| 31 | + |
@experimental
class NvidiaTextEmbedding(NvidiaHandler, EmbeddingGeneratorBase):
    """Nvidia text embedding service.

    Sends embedding requests through an ``AsyncOpenAI`` client, optionally
    splitting the input texts into batches that are requested concurrently.
    """

    def __init__(
        self,
        ai_model_id: str | None = None,
        api_key: str | None = None,
        base_url: str | None = None,
        client: AsyncOpenAI | None = None,
        env_file_path: str | None = None,
        service_id: str | None = None,
    ) -> None:
        """Initializes a new instance of the NvidiaTextEmbedding class.

        Args:
            ai_model_id (str | None): NVIDIA embedding model id. When no model id
                can be resolved from the arguments or the settings,
                "nvidia/nv-embedqa-e5-v5" is used and a warning is logged.
            api_key (str | None): NVIDIA API key. (Env var NVIDIA_API_KEY)
                A missing key only triggers a warning here; requests may fail later.
            base_url (str | None): The url of the NVIDIA endpoint, for more
                information refer to https://docs.api.nvidia.com/nim/reference/
                (Env var NVIDIA_BASE_URL)
            client (AsyncOpenAI | None): An existing client to use. When omitted,
                a new client is built from the resolved API key and base url. (Optional)
            env_file_path (str | None): Use the environment settings file as
                a fallback to environment variables. (Optional)
            service_id (str | None): Service ID for the model. Defaults to the
                resolved embedding model id. (Optional)

        Raises:
            ServiceInitializationError: If the NVIDIA settings fail validation.
        """
        try:
            nvidia_settings = NvidiaSettings.create(
                api_key=api_key,
                base_url=base_url,
                embedding_model_id=ai_model_id,
                env_file_path=env_file_path,
            )
        except ValidationError as ex:
            raise ServiceInitializationError("Failed to create NVIDIA settings.", ex) from ex
        if not nvidia_settings.embedding_model_id:
            nvidia_settings.embedding_model_id = "nvidia/nv-embedqa-e5-v5"
            logger.warning(f"Default embedding model set as: {nvidia_settings.embedding_model_id}")

        # Resolve the secret once. The key is optional, so it must not be
        # dereferenced unconditionally (that raised AttributeError on None).
        api_key_value = nvidia_settings.api_key.get_secret_value() if nvidia_settings.api_key else None
        if api_key_value is None:
            logger.warning("API_KEY is missing, inference may fail.")
        if not client:
            # NOTE(review): with api_key=None the OpenAI client presumably applies its
            # own key resolution (e.g. environment fallback) - confirm against openai docs.
            client = AsyncOpenAI(api_key=api_key_value, base_url=nvidia_settings.base_url)
        super().__init__(
            ai_model_id=nvidia_settings.embedding_model_id,
            api_key=api_key_value,
            ai_model_type=NvidiaModelTypes.EMBEDDING,
            service_id=service_id or nvidia_settings.embedding_model_id,
            env_file_path=env_file_path,
            client=client,
        )

    @override
    async def generate_embeddings(
        self,
        texts: list[str],
        settings: "PromptExecutionSettings | None" = None,
        batch_size: int | None = None,
        **kwargs: Any,
    ) -> ndarray:
        """Returns embeddings for the given texts as a numpy array.

        Thin wrapper around generate_raw_embeddings that converts the raw
        result with numpy.array.

        Args:
            texts (list[str]): The texts to generate embeddings for.
            settings (PromptExecutionSettings | None): The settings to use for the request.
            batch_size (int | None): The batch size to use for the request.
            kwargs (Any): Additional arguments to pass to the request.
        """
        raw_embeddings = await self.generate_raw_embeddings(texts, settings, batch_size, **kwargs)
        return array(raw_embeddings)

    @override
    async def generate_raw_embeddings(
        self,
        texts: list[str],
        settings: "PromptExecutionSettings | None" = None,
        batch_size: int | None = None,
        **kwargs: Any,
    ) -> Any:
        """Returns embeddings for the given texts in the unedited format.

        The texts are split into batches of batch_size items (a single batch
        when batch_size is not given); all batches are requested concurrently
        and the results are concatenated in input order.

        Args:
            texts (list[str]): The texts to generate embeddings for.
            settings (NvidiaEmbeddingPromptExecutionSettings): The settings to use for the request.
                The object passed in is not modified; a copy is used internally.
            batch_size (int): The batch size to use for the request.
            kwargs (Any): Additional arguments, set as attributes on the
                (copied) settings before the request is sent.

        Returns:
            A list with the raw embeddings of all batches, or an empty list
            when texts is empty.
        """
        if not settings:
            settings = NvidiaEmbeddingPromptExecutionSettings(ai_model_id=self.ai_model_id)
        elif not isinstance(settings, NvidiaEmbeddingPromptExecutionSettings):
            settings = self.get_prompt_execution_settings_from_settings(settings)
        else:
            # Work on a private copy so the per-call overrides below do not
            # leak into the caller's settings object and later calls.
            settings = copy.deepcopy(settings)
        assert isinstance(settings, NvidiaEmbeddingPromptExecutionSettings)  # nosec
        if settings.ai_model_id is None:
            settings.ai_model_id = self.ai_model_id
        for key, value in kwargs.items():
            setattr(settings, key, value)

        # move input_type and truncate to extra-body
        if not settings.extra_body:
            settings.extra_body = {}
        settings.extra_body.setdefault("input_type", settings.input_type)
        if settings.truncate is not None:
            settings.extra_body.setdefault("truncate", settings.truncate)

        # Nothing to embed. Also avoids range() with a zero step below when
        # batch_size is not given and len(texts) == 0.
        if not texts:
            return []

        step = batch_size or len(texts)
        tasks = []
        for start in range(0, len(texts), step):
            # Each request gets its own settings copy because `input` differs per batch.
            batch_settings = copy.deepcopy(settings)
            batch_settings.input = texts[start : start + step]
            tasks.append(self._send_request(settings=batch_settings))

        # gather preserves task order, so embeddings line up with the input texts.
        results = await asyncio.gather(*tasks)

        raw_embeddings: list[Any] = []
        for raw_embedding in results:
            assert isinstance(raw_embedding, list)  # nosec
            raw_embeddings.extend(raw_embedding)

        return raw_embeddings

    def get_prompt_execution_settings_class(self) -> type["PromptExecutionSettings"]:
        """Get the request settings class."""
        return NvidiaEmbeddingPromptExecutionSettings

    @classmethod
    def from_dict(cls: type["NvidiaTextEmbedding"], settings: dict[str, Any]) -> "NvidiaTextEmbedding":
        """Initialize an NVIDIA text embedding service from a dictionary of settings.

        Args:
            settings: A dictionary of settings for the service. The keys
                "ai_model_id", "api_key", "base_url", "env_file_path" and
                "service_id" are read; missing keys are passed as None.
        """
        return cls(
            ai_model_id=settings.get("ai_model_id"),
            api_key=settings.get("api_key"),
            base_url=settings.get("base_url"),
            env_file_path=settings.get("env_file_path"),
            service_id=settings.get("service_id"),
        )
0 commit comments