# Copyright (c) Microsoft. All rights reserved.

import logging
import os
from collections.abc import AsyncGenerator, Mapping, Sequence
from html import unescape
from typing import TYPE_CHECKING, Any

import yaml
from pydantic import Field, ValidationError, model_validator

from semantic_kernel.connectors.ai.chat_completion_client_base import ChatCompletionClientBase
from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecutionSettings
from semantic_kernel.connectors.ai.text_completion_client_base import TextCompletionClientBase
from semantic_kernel.const import DEFAULT_SERVICE_NAME
from semantic_kernel.contents.chat_history import ChatHistory
from semantic_kernel.contents.chat_message_content import ChatMessageContent
from semantic_kernel.contents.text_content import TextContent
from semantic_kernel.exceptions import FunctionExecutionException, FunctionInitializationError
from semantic_kernel.exceptions.function_exceptions import PromptRenderingException
from semantic_kernel.filters.filter_types import FilterTypes
from semantic_kernel.filters.functions.function_invocation_context import FunctionInvocationContext
from semantic_kernel.filters.kernel_filters_extension import _rebuild_prompt_render_context
from semantic_kernel.filters.prompts.prompt_render_context import PromptRenderContext
from semantic_kernel.functions.function_result import FunctionResult
from semantic_kernel.functions.kernel_arguments import KernelArguments
from semantic_kernel.functions.kernel_function import TEMPLATE_FORMAT_MAP, KernelFunction
from semantic_kernel.functions.kernel_function_metadata import KernelFunctionMetadata
from semantic_kernel.functions.kernel_parameter_metadata import KernelParameterMetadata
from semantic_kernel.functions.prompt_rendering_result import PromptRenderingResult
from semantic_kernel.prompt_template.const import KERNEL_TEMPLATE_FORMAT_NAME, TEMPLATE_FORMAT_TYPES
from semantic_kernel.prompt_template.prompt_template_base import PromptTemplateBase
from semantic_kernel.prompt_template.prompt_template_config import PromptTemplateConfig

if TYPE_CHECKING:
    from semantic_kernel.services.ai_service_client_base import AIServiceClientBase

logger: logging.Logger = logging.getLogger(__name__)

# File names looked up by `KernelFunctionFromPrompt.from_directory`.
PROMPT_FILE_NAME = "skprompt.txt"
CONFIG_FILE_NAME = "config.json"

# Return-parameter metadata attached to every prompt function's metadata.
PROMPT_RETURN_PARAM = KernelParameterMetadata(
    name="return",
    description="The completion result",
    default_value=None,
    type="FunctionResult",  # type: ignore
    is_required=True,
)


class KernelFunctionFromPrompt(KernelFunction):
    """Semantic Kernel Function from a prompt."""

    prompt_template: PromptTemplateBase
    prompt_execution_settings: dict[str, PromptExecutionSettings] = Field(default_factory=dict)

    def __init__(
        self,
        function_name: str,
        plugin_name: str | None = None,
        description: str | None = None,
        prompt: str | None = None,
        template_format: TEMPLATE_FORMAT_TYPES = KERNEL_TEMPLATE_FORMAT_NAME,
        prompt_template: PromptTemplateBase | None = None,
        prompt_template_config: PromptTemplateConfig | None = None,
        prompt_execution_settings: PromptExecutionSettings
        | Sequence[PromptExecutionSettings]
        | Mapping[str, PromptExecutionSettings]
        | None = None,
    ) -> None:
        """Initializes a new instance of the KernelFunctionFromPrompt class.

        At least one of `prompt`, `prompt_template_config` or `prompt_template` must be supplied.
        When `prompt_template` is not given, one is built from `prompt_template_config`
        (which itself is built from `prompt` and `template_format` when absent).

        Args:
            function_name (str): The name of the function
            plugin_name (str): The name of the plugin
            description (str): The description for the function
            prompt (Optional[str]): The prompt
            template_format (Optional[str]): The template format, default is "semantic-kernel"
            prompt_template (Optional[KernelPromptTemplate]): The prompt template
            prompt_template_config (Optional[PromptTemplateConfig]): The prompt template configuration
            prompt_execution_settings (Optional): instance, list or dict of PromptExecutionSettings
                to be used by the function, can also be supplied through prompt_template_config,
                but the supplied one is used if both are present.

        Raises:
            FunctionInitializationError: If no prompt source is supplied, or if the
                KernelFunctionMetadata cannot be created.
        """
        if not prompt and not prompt_template_config and not prompt_template:
            # Implicit concatenation (instead of a backslash continuation inside the literal)
            # keeps source indentation out of the error message.
            raise FunctionInitializationError(
                "The prompt cannot be empty, must be supplied directly, "
                "through prompt_template_config or in the prompt_template."
            )

        if prompt and prompt_template_config and prompt_template_config.template != prompt:
            logger.warning(
                f"Prompt ({prompt}) and PromptTemplateConfig ({prompt_template_config.template}) both supplied, "
                "using the template in PromptTemplateConfig, ignoring prompt."
            )
        if template_format and prompt_template_config and prompt_template_config.template_format != template_format:
            logger.warning(
                f"Template ({template_format}) and PromptTemplateConfig ({prompt_template_config.template_format}) "
                "both supplied, using the template format in PromptTemplateConfig, ignoring template."
            )
        if not prompt_template:
            if not prompt_template_config:
                # prompt must be there if prompt_template and prompt_template_config is not supplied
                prompt_template_config = PromptTemplateConfig(
                    name=function_name,
                    description=description,
                    template=prompt,
                    template_format=template_format,
                )
            prompt_template = TEMPLATE_FORMAT_MAP[prompt_template_config.template_format](
                prompt_template_config=prompt_template_config
            )  # type: ignore

        try:
            metadata = KernelFunctionMetadata(
                name=function_name,
                plugin_name=plugin_name,
                description=description,
                parameters=prompt_template.prompt_template_config.get_kernel_parameter_metadata(),  # type: ignore
                is_prompt=True,
                is_asynchronous=True,
                return_parameter=PROMPT_RETURN_PARAM,
            )
        except ValidationError as exc:
            raise FunctionInitializationError("Failed to create KernelFunctionMetadata") from exc
        super().__init__(
            metadata=metadata,
            prompt_template=prompt_template,  # type: ignore
            prompt_execution_settings=prompt_execution_settings or {},  # type: ignore
        )

    @model_validator(mode="before")
    @classmethod
    def rewrite_execution_settings(
        cls,
        data: Any,
    ) -> Any:
        """Rewrite execution settings to a dictionary.

        If the prompt_execution_settings is not a dictionary, it is converted to a dictionary.
        If it is not supplied, but prompt_template is, the prompt_template's execution settings are used.

        Args:
            data: The raw input handed to the model before validation.

        Returns:
            The same `data` object; when it is a dict its "prompt_execution_settings" entry
            may have been replaced by a dict keyed on service id. Non-dict input is returned as is.
        """
        if isinstance(data, dict):
            prompt_execution_settings = data.get("prompt_execution_settings")
            prompt_template = data.get("prompt_template")
            if not prompt_execution_settings:
                if prompt_template:
                    # Fall back to the settings carried by the template's config.
                    prompt_execution_settings = prompt_template.prompt_template_config.execution_settings
                    data["prompt_execution_settings"] = prompt_execution_settings
                if not prompt_execution_settings:
                    return data
            if isinstance(prompt_execution_settings, PromptExecutionSettings):
                data["prompt_execution_settings"] = {
                    prompt_execution_settings.service_id or DEFAULT_SERVICE_NAME: prompt_execution_settings
                }
            if isinstance(prompt_execution_settings, Sequence):
                data["prompt_execution_settings"] = {
                    s.service_id or DEFAULT_SERVICE_NAME: s for s in prompt_execution_settings
                }
        return data

    async def _invoke_internal(self, context: FunctionInvocationContext) -> None:
        """Invokes the function with the given arguments.

        Renders the prompt, then calls the selected chat or text completion service and
        stores the outcome in `context.result`. If prompt rendering already produced a
        function result, that result is used and no service is called.

        Raises:
            FunctionExecutionException: If the service call fails, or the chat service
                returns no completions.
            ValueError: If the selected service is neither a chat nor a text completion client.
        """
        prompt_render_result = await self._render_prompt(context)
        if prompt_render_result.function_result is not None:
            context.result = prompt_render_result.function_result
            return

        if isinstance(prompt_render_result.ai_service, ChatCompletionClientBase):
            chat_history = ChatHistory.from_rendered_prompt(prompt_render_result.rendered_prompt)
            try:
                chat_message_contents = await prompt_render_result.ai_service.get_chat_message_contents(
                    chat_history=chat_history,
                    settings=prompt_render_result.execution_settings,
                    **{"kernel": context.kernel, "arguments": context.arguments},
                )
            except Exception as exc:
                raise FunctionExecutionException(f"Error occurred while invoking function {self.name}: {exc}") from exc

            if not chat_message_contents:
                raise FunctionExecutionException(f"No completions returned while invoking function {self.name}")

            context.result = self._create_function_result(
                completions=chat_message_contents,
                chat_history=chat_history,
                arguments=context.arguments,
                prompt=prompt_render_result.rendered_prompt,
            )
            return

        if isinstance(prompt_render_result.ai_service, TextCompletionClientBase):
            try:
                # The text service receives the prompt with HTML entities unescaped.
                texts = await prompt_render_result.ai_service.get_text_contents(
                    prompt=unescape(prompt_render_result.rendered_prompt),
                    settings=prompt_render_result.execution_settings,
                )
            except Exception as exc:
                raise FunctionExecutionException(f"Error occurred while invoking function {self.name}: {exc}") from exc

            context.result = self._create_function_result(
                completions=texts, arguments=context.arguments, prompt=prompt_render_result.rendered_prompt
            )
            return

        raise ValueError(f"Service `{type(prompt_render_result.ai_service).__name__}` is not a valid AI service")

    async def _invoke_internal_stream(self, context: FunctionInvocationContext) -> None:
        """Invokes the function stream with the given arguments.

        Renders the prompt and stores a FunctionResult in `context.result` whose value is
        the object returned by the service's streaming method (it is not consumed here).
        If prompt rendering already produced a function result, that result is used instead.

        Raises:
            FunctionExecutionException: If the selected service is neither a chat nor a
                text completion client.
        """
        prompt_render_result = await self._render_prompt(context, is_streaming=True)
        if prompt_render_result.function_result is not None:
            context.result = prompt_render_result.function_result
            return

        if isinstance(prompt_render_result.ai_service, ChatCompletionClientBase):
            chat_history = ChatHistory.from_rendered_prompt(prompt_render_result.rendered_prompt)
            value: AsyncGenerator = prompt_render_result.ai_service.get_streaming_chat_message_contents(
                chat_history=chat_history,
                settings=prompt_render_result.execution_settings,
                **{"kernel": context.kernel, "arguments": context.arguments},
            )
        elif isinstance(prompt_render_result.ai_service, TextCompletionClientBase):
            # Unescape here as well so streaming and non-streaming text completion
            # send the same prompt to the service.
            value = prompt_render_result.ai_service.get_streaming_text_contents(
                prompt=unescape(prompt_render_result.rendered_prompt),
                settings=prompt_render_result.execution_settings,
            )
        else:
            raise FunctionExecutionException(
                f"Service `{type(prompt_render_result.ai_service)}` is not a valid AI service"
            )

        context.result = FunctionResult(
            function=self.metadata, value=value, rendered_prompt=prompt_render_result.rendered_prompt
        )

    async def _render_prompt(
        self, context: FunctionInvocationContext, is_streaming: bool = False
    ) -> PromptRenderingResult:
        """Render the prompt and apply the prompt rendering filters.

        Args:
            context: The invocation context supplying the kernel and arguments.
            is_streaming: Forwarded to the PromptRenderContext.

        Returns:
            The rendered prompt together with the AI service and execution settings selected
            by the kernel, plus any function result set on the render context.

        Raises:
            PromptRenderingException: If no rendered prompt is present after the stack has run.
        """
        # Fill in declared defaults before rendering (mutates context.arguments).
        self.update_arguments_with_defaults(context.arguments)

        _rebuild_prompt_render_context()
        prompt_render_context = PromptRenderContext(
            function=self, kernel=context.kernel, arguments=context.arguments, is_streaming=is_streaming
        )

        stack = context.kernel.construct_call_stack(
            filter_type=FilterTypes.PROMPT_RENDERING,
            inner_function=self._inner_render_prompt,
        )
        await stack(prompt_render_context)

        if prompt_render_context.rendered_prompt is None:
            raise PromptRenderingException("Prompt rendering failed, no rendered prompt was returned.")
        selected_service: tuple["AIServiceClientBase", PromptExecutionSettings] = context.kernel.select_ai_service(
            function=self, arguments=context.arguments
        )
        return PromptRenderingResult(
            rendered_prompt=prompt_render_context.rendered_prompt,
            ai_service=selected_service[0],
            execution_settings=selected_service[1],
            function_result=prompt_render_context.function_result,
        )

    async def _inner_render_prompt(self, context: PromptRenderContext) -> None:
        """Render the prompt using the prompt template and store it on the context."""
        context.rendered_prompt = await self.prompt_template.render(context.kernel, context.arguments)

    def _create_function_result(
        self,
        completions: list[ChatMessageContent] | list[TextContent],
        arguments: KernelArguments,
        chat_history: ChatHistory | None = None,
        prompt: str | None = None,
    ) -> FunctionResult:
        """Creates a function result with the given completions.

        Args:
            completions: The completions; used as the result value.
            arguments: The arguments, stored under the "arguments" metadata key.
            chat_history: Stored under the "messages" metadata key when truthy.
            prompt: Stored under the "prompt" metadata key when truthy, and always passed
                as the result's rendered_prompt.

        Returns:
            The assembled FunctionResult.
        """
        metadata: dict[str, Any] = {
            "arguments": arguments,
            "metadata": [completion.metadata for completion in completions],
        }
        if chat_history:
            metadata["messages"] = chat_history
        if prompt:
            metadata["prompt"] = prompt
        return FunctionResult(
            function=self.metadata,
            value=completions,
            metadata=metadata,
            rendered_prompt=prompt,
        )

    def update_arguments_with_defaults(self, arguments: KernelArguments) -> None:
        """Update any missing values with their defaults.

        For every input variable of the prompt template config that is absent from
        `arguments`, its default is inserted, unless that default equals one of
        None, "", False or 0. `arguments` is mutated in place.
        """
        # A tuple (not a set) is used for the membership test so that unhashable
        # defaults such as lists or dicts do not raise TypeError.
        skipped_defaults = (None, "", False, 0)
        for parameter in self.prompt_template.prompt_template_config.input_variables:
            if parameter.name not in arguments and parameter.default not in skipped_defaults:
                arguments[parameter.name] = parameter.default

    @classmethod
    def from_yaml(cls, yaml_str: str, plugin_name: str | None = None) -> "KernelFunctionFromPrompt":
        """Creates a new instance of the KernelFunctionFromPrompt class from a YAML string.

        Args:
            yaml_str: YAML text that must load to a mapping accepted by PromptTemplateConfig.
            plugin_name: The name of the plugin.

        Returns:
            KernelFunctionFromPrompt: The kernel function from prompt

        Raises:
            FunctionInitializationError: If the YAML is invalid, does not represent a
                dictionary, or PromptTemplateConfig fails validation.
        """
        try:
            data = yaml.safe_load(yaml_str)
        except yaml.YAMLError as exc:  # pragma: no cover
            raise FunctionInitializationError(f"Invalid YAML content: {yaml_str}, error: {exc}") from exc

        if not isinstance(data, dict):
            raise FunctionInitializationError(f"The YAML content must represent a dictionary, got {yaml_str}")

        try:
            prompt_template_config = PromptTemplateConfig(**data)
        except ValidationError as exc:
            raise FunctionInitializationError(
                f"Error initializing PromptTemplateConfig: {exc} from yaml data: {data}"
            ) from exc
        return cls(
            function_name=prompt_template_config.name,
            plugin_name=plugin_name,
            description=prompt_template_config.description,
            prompt_template_config=prompt_template_config,
            template_format=prompt_template_config.template_format,
        )

    @classmethod
    def from_directory(cls, path: str, plugin_name: str | None = None) -> "KernelFunctionFromPrompt":
        """Creates a new instance of the KernelFunctionFromPrompt class from a directory.

        The directory needs to contain:
        - A prompt file named `skprompt.txt`
        - A config file named `config.json`

        The function name is the last component of `path`.

        Args:
            path: The directory holding the prompt and config files.
            plugin_name: The name of the plugin.

        Returns:
            KernelFunctionFromPrompt: The kernel function from prompt

        Raises:
            FunctionInitializationError: If either required file is missing.
        """
        prompt_path = os.path.join(path, PROMPT_FILE_NAME)
        config_path = os.path.join(path, CONFIG_FILE_NAME)
        prompt_exists = os.path.exists(prompt_path)
        config_exists = os.path.exists(config_path)
        if not config_exists and not prompt_exists:
            raise FunctionInitializationError(
                f"{PROMPT_FILE_NAME} and {CONFIG_FILE_NAME} files are required to create a "
                f"function from a directory, path: {path!s}."
            )
        if not config_exists:
            raise FunctionInitializationError(
                f"{CONFIG_FILE_NAME} files are required to create a function from a directory, "
                f"path: {path!s}, prompt file is there."
            )
        if not prompt_exists:
            raise FunctionInitializationError(
                f"{PROMPT_FILE_NAME} files are required to create a function from a directory, "
                f"path: {path!s}, config file is there."
            )

        # Normalise first: basename() of a path with a trailing separator is "".
        function_name = os.path.basename(os.path.normpath(path))

        # Read both files as UTF-8 rather than the locale-dependent default encoding.
        with open(config_path, encoding="utf-8") as config_file:
            prompt_template_config = PromptTemplateConfig.from_json(config_file.read())
        prompt_template_config.name = function_name

        with open(prompt_path, encoding="utf-8") as prompt_file:
            prompt_template_config.template = prompt_file.read()

        prompt_template = TEMPLATE_FORMAT_MAP[prompt_template_config.template_format](  # type: ignore
            prompt_template_config=prompt_template_config
        )
        return cls(
            function_name=function_name,
            plugin_name=plugin_name,
            prompt_template=prompt_template,
            prompt_template_config=prompt_template_config,
            template_format=prompt_template_config.template_format,
            description=prompt_template_config.description,
        )