
Commit 374d8f9

eavanvalkenburg and moonbox3 authored Feb 19, 2025
Python: filter improvements (microsoft#10588)
### Motivation and Context

Fixes some inconsistency in the filter setup between dotnet and python.
Adds a sample for how to use a filter to retry a prompt with a different model.
Adds a sample for semantic caching.

Closes: microsoft#10572 microsoft#5924 microsoft#10595

### Description

### Contribution Checklist

- [x] The code builds clean without any errors or warnings
- [x] The PR follows the [SK Contribution Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md) and the [pre-submission formatting script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts) raises no violations
- [x] All unit tests pass, and I have added new tests where possible
- [x] I didn't break anyone 😄

---------

Co-authored-by: Evan Mattson <[email protected]>
1 parent 14a3263 commit 374d8f9

19 files changed: +384 -79 lines changed
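Both new samples in this commit use the same filter shape: an async callable that receives a context object and a `next` delegate, registered on the kernel with `add_filter`. A minimal sketch of that shape (the logging filter itself is hypothetical, not code from this commit; the registration call mirrors the samples below):

# Minimal sketch of the filter shape used by the new samples; the filter body is hypothetical.
from collections.abc import Awaitable, Callable

from semantic_kernel import Kernel
from semantic_kernel.filters.filter_types import FilterTypes
from semantic_kernel.filters.functions.function_invocation_context import FunctionInvocationContext


async def logging_filter(
    context: FunctionInvocationContext,
    next: Callable[[FunctionInvocationContext], Awaitable[None]],
) -> None:
    print(f"invoking {context.function.name}")
    await next(context)  # run the function (and any later filters)
    print(f"result: {context.result}")


kernel = Kernel()
kernel.add_filter(FilterTypes.FUNCTION_INVOCATION, logging_filter)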
 
@@ -0,0 +1,143 @@
# Copyright (c) Microsoft. All rights reserved.

import asyncio
import time
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field
from typing import Annotated
from uuid import uuid4

from semantic_kernel.connectors.ai.embeddings.embedding_generator_base import EmbeddingGeneratorBase
from semantic_kernel.connectors.ai.open_ai.services.open_ai_chat_completion import OpenAIChatCompletion
from semantic_kernel.connectors.ai.open_ai.services.open_ai_text_embedding import OpenAITextEmbedding
from semantic_kernel.connectors.memory.in_memory.in_memory_store import InMemoryVectorStore
from semantic_kernel.data.record_definition import vectorstoremodel
from semantic_kernel.data.record_definition.vector_store_record_fields import (
    VectorStoreRecordDataField,
    VectorStoreRecordKeyField,
    VectorStoreRecordVectorField,
)
from semantic_kernel.data.vector_search.vector_search_options import VectorSearchOptions
from semantic_kernel.data.vector_search.vectorized_search import VectorizedSearchMixin
from semantic_kernel.data.vector_storage.vector_store import VectorStore
from semantic_kernel.data.vector_storage.vector_store_record_collection import VectorStoreRecordCollection
from semantic_kernel.filters.filter_types import FilterTypes
from semantic_kernel.filters.functions.function_invocation_context import FunctionInvocationContext
from semantic_kernel.filters.prompts.prompt_render_context import PromptRenderContext
from semantic_kernel.functions.function_result import FunctionResult
from semantic_kernel.kernel import Kernel

COLLECTION_NAME = "llm_responses"
RECORD_ID_KEY = "cache_record_id"


# Define a simple data model to store the prompt, the result, and the prompt embedding.
@vectorstoremodel
@dataclass
class CacheRecord:
    prompt: Annotated[str, VectorStoreRecordDataField(embedding_property_name="prompt_embedding")]
    result: Annotated[str, VectorStoreRecordDataField(is_full_text_searchable=True)]
    prompt_embedding: Annotated[list[float], VectorStoreRecordVectorField(dimensions=1526)] = field(
        default_factory=list
    )
    id: Annotated[str, VectorStoreRecordKeyField] = field(default_factory=lambda: str(uuid4()))


# Define the filters, one for caching the results and one for using the cache.
class PromptCacheFilter:
    """A filter to cache the results of the prompt rendering and function invocation."""

    def __init__(
        self,
        embedding_service: EmbeddingGeneratorBase,
        vector_store: VectorStore,
        collection_name: str = COLLECTION_NAME,
        score_threshold: float = 0.2,
    ):
        self.embedding_service = embedding_service
        self.vector_store = vector_store
        self.collection: VectorStoreRecordCollection[str, CacheRecord] = vector_store.get_collection(
            collection_name, data_model_type=CacheRecord
        )
        self.score_threshold = score_threshold

    async def on_prompt_render(
        self, context: PromptRenderContext, next: Callable[[PromptRenderContext], Awaitable[None]]
    ):
        """Filter to cache the rendered prompt and the result of the function.

        It uses the score threshold to determine if the result should be cached.
        The direction of the comparison is based on the default distance metric for
        the in-memory vector store, which is cosine distance, so the closer to 0 the
        closer the match.
        """
        await next(context)
        assert context.rendered_prompt  # nosec
        prompt_embedding = await self.embedding_service.generate_raw_embeddings([context.rendered_prompt])
        await self.collection.create_collection_if_not_exists()
        assert isinstance(self.collection, VectorizedSearchMixin)  # nosec
        results = await self.collection.vectorized_search(
            vector=prompt_embedding[0], options=VectorSearchOptions(vector_field_name="prompt_embedding", top=1)
        )
        async for result in results.results:
            if result.score < self.score_threshold:
                context.function_result = FunctionResult(
                    function=context.function.metadata,
                    value=result.record.result,
                    rendered_prompt=context.rendered_prompt,
                    metadata={RECORD_ID_KEY: result.record.id},
                )

    async def on_function_invocation(
        self, context: FunctionInvocationContext, next: Callable[[FunctionInvocationContext], Awaitable[None]]
    ):
        """Filter to store the result in the cache if it is new."""
        await next(context)
        result = context.result
        if result and result.rendered_prompt and RECORD_ID_KEY not in result.metadata:
            prompt_embedding = await self.embedding_service.generate_embeddings([result.rendered_prompt])
            cache_record = CacheRecord(
                prompt=result.rendered_prompt,
                result=str(result),
                prompt_embedding=prompt_embedding[0],
            )
            await self.collection.create_collection_if_not_exists()
            await self.collection.upsert(cache_record)


async def execute_async(kernel: Kernel, title: str, prompt: str):
    """Helper method to execute and log time."""
    print(f"{title}: {prompt}")
    start = time.time()
    result = await kernel.invoke_prompt(prompt)
    elapsed = time.time() - start
    print(f"\tElapsed Time: {elapsed:.3f}")
    return result


async def main():
    # create the kernel and add the chat service and the embedding service
    kernel = Kernel()
    chat = OpenAIChatCompletion(service_id="default")
    embedding = OpenAITextEmbedding(service_id="embedder")
    kernel.add_service(chat)
    kernel.add_service(embedding)
    # create the in-memory vector store
    vector_store = InMemoryVectorStore()
    # create the cache filter and add the filters to the kernel
    cache = PromptCacheFilter(embedding_service=embedding, vector_store=vector_store)
    kernel.add_filter(FilterTypes.PROMPT_RENDERING, cache.on_prompt_render)
    kernel.add_filter(FilterTypes.FUNCTION_INVOCATION, cache.on_function_invocation)

    # Run the sample
    print("\nIn-memory cache sample:")
    r1 = await execute_async(kernel, "First run", "What's the tallest building in New York?")
    print(f"\tResult 1: {r1}")
    r2 = await execute_async(kernel, "Second run", "How are you today?")
    print(f"\tResult 2: {r2}")
    r3 = await execute_async(kernel, "Third run", "What is the highest building in New York City?")
    print(f"\tResult 3: {r3}")


if __name__ == "__main__":
    asyncio.run(main())
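The caching behavior hinges on one detail of `on_prompt_render` above: assigning `context.function_result` inside a prompt-render filter short-circuits the invocation, so the cached value is returned instead of calling the model. A stripped-down sketch of just that mechanism (the canned value stands in for the vector-store lookup; it is not part of the sample):

# Stripped-down sketch of the short-circuit used by PromptCacheFilter:
# setting context.function_result in a prompt-render filter skips the LLM call.
from collections.abc import Awaitable, Callable

from semantic_kernel.filters.prompts.prompt_render_context import PromptRenderContext
from semantic_kernel.functions.function_result import FunctionResult


async def short_circuit_filter(
    context: PromptRenderContext, next: Callable[[PromptRenderContext], Awaitable[None]]
) -> None:
    await next(context)  # let the prompt render first
    # In the sample this value comes from the cache hit; here it is canned.
    context.function_result = FunctionResult(
        function=context.function.metadata,
        value="cached response",
        rendered_prompt=context.rendered_prompt,
    )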

‎python/samples/concepts/filtering/function_invocation_filters_stream.py

+15 -12

@@ -4,7 +4,6 @@
 import logging
 import os
 from collections.abc import Callable, Coroutine
-from functools import reduce
 from typing import Any
 
 from semantic_kernel import Kernel
@@ -38,17 +37,21 @@ async def streaming_exception_handling(
 ):
     await next(context)
 
-    async def override_stream(stream):
-        try:
-            async for partial in stream:
-                yield partial
-        except Exception as e:
-            yield [
-                StreamingChatMessageContent(role=AuthorRole.ASSISTANT, content=f"Exception caught: {e}", choice_index=0)
-            ]
+    if context.is_streaming:
 
-    stream = context.result.value
-    context.result = FunctionResult(function=context.result.function, value=override_stream(stream))
+        async def override_stream(stream):
+            try:
+                async for partial in stream:
+                    yield partial
+            except Exception as e:
+                yield [
+                    StreamingChatMessageContent(
+                        role=AuthorRole.ASSISTANT, content=f"Exception caught: {e}", choice_index=0
+                    )
+                ]
+
+        stream = context.result.value
+        context.result = FunctionResult(function=context.result.function, value=override_stream(stream))
 
 
 async def chat(chat_history: ChatHistory) -> bool:
@@ -77,7 +80,7 @@ async def chat(chat_history: ChatHistory) -> bool:
     print("")
     chat_history.add_user_message(user_input)
     if streamed_chunks:
-        streaming_chat_message = reduce(lambda first, second: first + second, streamed_chunks)
+        streaming_chat_message = sum(streamed_chunks[1:], streamed_chunks[0])
     chat_history.add_message(streaming_chat_message)
     return True
 
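The last hunk swaps `functools.reduce` for `sum(streamed_chunks[1:], streamed_chunks[0])`. Both combine the streamed chunks pairwise with `+`; `sum` just needs an explicit start value and drops the `functools` import. A tiny self-contained illustration with a stand-in chunk type (not the SK `StreamingChatMessageContent` class):

# Summing objects that define __add__ concatenates them pairwise.
# "Chunk" is a stand-in for StreamingChatMessageContent, which supports "+".
from dataclasses import dataclass


@dataclass
class Chunk:
    text: str

    def __add__(self, other: "Chunk") -> "Chunk":
        return Chunk(self.text + other.text)


chunks = [Chunk("Hel"), Chunk("lo "), Chunk("world")]
# Equivalent to reduce(lambda a, b: a + b, chunks), without importing functools:
message = sum(chunks[1:], chunks[0])
print(message.text)  # -> "Hello world"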

@@ -0,0 +1,98 @@
# Copyright (c) Microsoft. All rights reserved.

import asyncio
import logging
from collections.abc import Awaitable, Callable

from semantic_kernel import Kernel
from semantic_kernel.connectors.ai.open_ai.prompt_execution_settings.open_ai_prompt_execution_settings import (
    OpenAIChatPromptExecutionSettings,
)
from semantic_kernel.connectors.ai.open_ai.services.open_ai_chat_completion import OpenAIChatCompletion
from semantic_kernel.filters import FunctionInvocationContext
from semantic_kernel.filters.filter_types import FilterTypes
from semantic_kernel.functions.kernel_arguments import KernelArguments

# This sample shows how to use a filter to use a fallback service if the default service fails to execute the function.
# this works by replacing the settings that point to the default service
# with the settings that point to the fallback service
# after the default service fails to execute the function.

logger = logging.getLogger(__name__)


class RetryFilter:
    """A filter that retries the function invocation with a different model if it fails."""

    def __init__(self, default_service_id: str, fallback_service_id: str):
        """Initialize the filter with the default and fallback service ids."""
        self.default_service_id = default_service_id
        self.fallback_service_id = fallback_service_id

    async def retry_filter(
        self,
        context: FunctionInvocationContext,
        next: Callable[[FunctionInvocationContext], Awaitable[None]],
    ) -> None:
        """A filter that retries the function invocation with a different model if it fails."""
        try:
            # try the default function
            await next(context)
        except Exception as ex:
            print("Expected failure to execute the function: ", ex)
            # if the default function fails, try the fallback function
            if (
                context.arguments
                and context.arguments.execution_settings
                and self.default_service_id in context.arguments.execution_settings
            ):
                # get the settings for the default service
                settings = context.arguments.execution_settings.pop(self.default_service_id)
                settings.service_id = self.fallback_service_id
                # add them back with the right service id
                context.arguments.execution_settings[self.fallback_service_id] = settings
                # try again!
                await next(context)
            else:
                raise ex


async def main() -> None:
    # set the ids for the default and fallback services
    default_service_id = "default_service"
    fallback_service_id = "fallback_service"
    kernel = Kernel()
    # create the filter with the ids
    retry_filter = RetryFilter(default_service_id=default_service_id, fallback_service_id=fallback_service_id)
    # add the filter to the kernel
    kernel.add_filter(FilterTypes.FUNCTION_INVOCATION, retry_filter.retry_filter)

    # add the default and fallback services
    default_service = OpenAIChatCompletion(service_id=default_service_id, api_key="invalid_key")
    kernel.add_service(default_service)
    fallback_service = OpenAIChatCompletion(service_id=fallback_service_id)
    kernel.add_service(fallback_service)

    # create the settings for the request
    request_settings = OpenAIChatPromptExecutionSettings(service_id=default_service_id)
    # invoke a simple prompt function
    response = await kernel.invoke_prompt(
        function_name="retry_function",
        prompt="How are you today?",
        arguments=KernelArguments(settings=request_settings),
    )

    print("Model response: ", response)

    # Sample output:
    # Expected failure to execute the function: Error occurred while invoking function retry_function:
    # ("<class 'semantic_kernel.connectors.ai.open_ai.services.open_ai_chat_completion.OpenAIChatCompletion'> service
    # failed to complete the prompt", AuthenticationError("Error code: 401 - {'error': {'message': 'Incorrect API key
    # provided: invalid_key. You can find your API key at https://platform.openai.com/account/api-keys.', 'type':
    # 'invalid_request_error', 'param': None, 'code': 'invalid_api_key'}}"))
    # Model response: I'm just a program, so I don't experience feelings, but I'm here and ready to help you out.
    # How can I assist you today?


if __name__ == "__main__":
    asyncio.run(main())
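The retry works because execution settings live in the kernel arguments as a mapping keyed by service id; re-keying the entry makes the second `next(context)` resolve to the fallback service. A plain-dict illustration of that swap (stand-in values, not the sample's settings objects):

# Plain-dict illustration of the settings swap in RetryFilter; the nested dicts
# stand in for PromptExecutionSettings instances keyed by service_id.
execution_settings = {"default_service": {"service_id": "default_service", "temperature": 0.7}}

# After the default service fails, move the same settings under the fallback id:
settings = execution_settings.pop("default_service")
settings["service_id"] = "fallback_service"
execution_settings["fallback_service"] = settings

print(execution_settings)
# {'fallback_service': {'service_id': 'fallback_service', 'temperature': 0.7}}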

‎python/samples/concepts/filtering/retry_with_filters.py

+3 -4

@@ -2,8 +2,7 @@
 
 import asyncio
 import logging
-from collections.abc import Callable, Coroutine
-from typing import Any
+from collections.abc import Awaitable, Callable
 
 from samples.concepts.setup.chat_completion_services import Services, get_chat_completion_service_and_request_settings
 from semantic_kernel import Kernel
@@ -34,7 +33,7 @@ def __init__(self):
         self._invocation_count = 0
 
     @kernel_function(name="GetWeather", description="Get the weather of the day at the current location.")
-    def get_wather(self) -> str:
+    def get_weather(self) -> str:
         """Get the weather of the day at the current location.
 
         Simulates a call to an external service to get the weather.
@@ -50,7 +49,7 @@ def get_wather(self) -> str:
 
 async def retry_filter(
     context: FunctionInvocationContext,
-    next: Callable[[FunctionInvocationContext], Coroutine[Any, Any, None]],
+    next: Callable[[FunctionInvocationContext], Awaitable[None]],
 ) -> None:
     """A filter that retries the function invocation if it fails.
 
‎python/semantic_kernel/connectors/ai/chat_completion_client_base.py

+3

@@ -157,6 +157,7 @@ async def get_chat_message_contents
                 function_call=function_call,
                 chat_history=chat_history,
                 arguments=kwargs.get("arguments"),
+                execution_settings=settings,
                 function_call_count=fc_count,
                 request_index=request_index,
                 function_behavior=settings.function_choice_behavior,
@@ -289,6 +290,8 @@ async def get_streaming_chat_message_contents
                 function_call=function_call,
                 chat_history=chat_history,
                 arguments=kwargs.get("arguments"),
+                is_streaming=True,
+                execution_settings=settings,
                 function_call_count=fc_count,
                 request_index=request_index,
                 function_behavior=settings.function_choice_behavior,
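These additions forward `execution_settings` and `is_streaming` into the function invocation context, which is what the new samples rely on: `context.is_streaming` in the streaming exception filter and `context.arguments.execution_settings` in the retry filter. A hypothetical filter that reads both fields, as a sketch:

# Hypothetical filter reading the fields newly forwarded into the context.
from collections.abc import Awaitable, Callable

from semantic_kernel.filters.functions.function_invocation_context import FunctionInvocationContext


async def inspect_context_filter(
    context: FunctionInvocationContext,
    next: Callable[[FunctionInvocationContext], Awaitable[None]],
) -> None:
    print(f"streaming invocation: {context.is_streaming}")
    if context.arguments and context.arguments.execution_settings:
        # execution settings are keyed by service id (see the retry filter sample)
        print(f"configured services: {list(context.arguments.execution_settings)}")
    await next(context)

# Register with: kernel.add_filter(FilterTypes.FUNCTION_INVOCATION, inspect_context_filter)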

‎python/semantic_kernel/connectors/memory/in_memory/in_memory_collection.py

+4 -4

@@ -6,16 +6,15 @@
 
 from pydantic import Field
 
-from semantic_kernel.data.filter_clauses.any_tags_equal_to_filter_clause import AnyTagsEqualTo
-from semantic_kernel.data.filter_clauses.equal_to_filter_clause import EqualTo
-
 if sys.version_info >= (3, 12):
     from typing import override  # pragma: no cover
 else:
     from typing_extensions import override  # pragma: no cover
 
 from semantic_kernel.connectors.memory.in_memory.const import DISTANCE_FUNCTION_MAP
 from semantic_kernel.data.const import DistanceFunction
+from semantic_kernel.data.filter_clauses.any_tags_equal_to_filter_clause import AnyTagsEqualTo
+from semantic_kernel.data.filter_clauses.equal_to_filter_clause import EqualTo
 from semantic_kernel.data.filter_clauses.filter_clause_base import FilterClauseBase
 from semantic_kernel.data.kernel_search_results import KernelSearchResults
 from semantic_kernel.data.record_definition.vector_store_model_definition import VectorStoreRecordDefinition
@@ -29,6 +28,7 @@
 from semantic_kernel.data.vector_search.vectorized_search import VectorizedSearchMixin
 from semantic_kernel.exceptions import VectorSearchExecutionException, VectorStoreModelValidationError
 from semantic_kernel.kernel_types import OneOrMany
+from semantic_kernel.utils.list_handler import empty_generator
 
 KEY_TYPES = str | int | float
 
@@ -171,7 +171,7 @@ async def _inner_search_vectorized(
             ),
             total_count=len(return_records) if options and options.include_total_count else None,
         )
-        return KernelSearchResults(results=None)
+        return KernelSearchResults(results=empty_generator())
 
     async def _generate_return_list(
         self, return_records: dict[KEY_TYPES, float], options: VectorSearchOptions | None
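Returning `empty_generator()` instead of `results=None` means callers can always `async for` over the search results, even when nothing matched. The real helper lives in `semantic_kernel.utils.list_handler`; an assumed equivalent (not the library implementation) looks roughly like this:

# Assumed equivalent of empty_generator(); not the library's actual code.
import asyncio
from collections.abc import AsyncIterable
from typing import Any


async def empty_generator() -> AsyncIterable[Any]:
    # Iterating an empty tuple yields nothing, but the yield keeps this an async generator.
    for item in ():
        yield item


async def demo() -> None:
    async for _ in empty_generator():
        print("never reached")  # loop body never runs


asyncio.run(demo())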
