Skip to content

Commit 3724e24

Browse files
authoredJan 30, 2025
Python: Add vector search to Postgres connector (microsoft#10213)
### Motivation and Context Following up on microsoft#8951, this PR adds an implementation of `VectorSearchBase` to `PostgresCollection`. This implementation provides vectorized search and does not implement text search or vectorizable text search. Unit and integration tests are added, and the `python/samples/getting_started/third_party/postgres-memory.ipynb` notebook was expanded to include vector search in the example. ### Contribution Checklist - [x] The code builds clean without any errors or warnings - [x] The PR follows the [SK Contribution Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md) and the [pre-submission formatting script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts) raises no violations - [x] All unit tests pass, and I have added new tests where possible - [x] I didn't break anyone 😄
1 parent f32972c commit 3724e24

File tree

8 files changed

+955
-136
lines changed

8 files changed

+955
-136
lines changed
 

‎python/samples/getting_started/third_party/postgres-memory.ipynb

+400-27
Large diffs are not rendered by default.

‎python/semantic_kernel/connectors/memory/postgres/constants.py

+4
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@
55
# Limitation based on pgvector documentation https://github.com/pgvector/pgvector#what-if-i-want-to-index-vectors-with-more-than-2000-dimensions
66
MAX_DIMENSIONALITY = 2000
77

8+
# The name of the column that returns distance value in the database.
9+
# It is used in the similarity search query. Must not conflict with model property.
10+
DISTANCE_COLUMN_NAME = "sk_pg_distance"
11+
812
# Environment Variables
913
PGHOST_ENV_VAR = "PGHOST"
1014
PGPORT_ENV_VAR = "PGPORT"

‎python/semantic_kernel/connectors/memory/postgres/postgres_collection.py

+311-71
Large diffs are not rendered by default.

‎python/semantic_kernel/connectors/memory/postgres/postgres_settings.py

+23-21
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from psycopg.conninfo import conninfo_to_dict
66
from psycopg_pool import AsyncConnectionPool
7+
from psycopg_pool.abc import ACT
78
from pydantic import Field, SecretStr
89

910
from semantic_kernel.connectors.memory.postgres.constants import (
@@ -14,10 +15,7 @@
1415
PGSSL_MODE_ENV_VAR,
1516
PGUSER_ENV_VAR,
1617
)
17-
from semantic_kernel.exceptions.memory_connector_exceptions import (
18-
MemoryConnectorConnectionException,
19-
MemoryConnectorInitializationError,
20-
)
18+
from semantic_kernel.exceptions.memory_connector_exceptions import MemoryConnectorConnectionException
2119
from semantic_kernel.kernel_pydantic import KernelBaseSettings
2220
from semantic_kernel.utils.experimental_decorator import experimental_class
2321

@@ -89,30 +87,34 @@ def get_connection_args(self) -> dict[str, Any]:
8987
if self.password:
9088
result["password"] = self.password.get_secret_value()
9189

92-
# Ensure required values
93-
if "host" not in result:
94-
raise MemoryConnectorInitializationError("host is required. Please set PGHOST or connection_string.")
95-
if "dbname" not in result:
96-
raise MemoryConnectorInitializationError(
97-
"database is required. Please set PGDATABASE or connection_string."
98-
)
99-
if "user" not in result:
100-
raise MemoryConnectorInitializationError("user is required. Please set PGUSER or connection_string.")
101-
if "password" not in result:
102-
raise MemoryConnectorInitializationError(
103-
"password is required. Please set PGPASSWORD or connection_string."
104-
)
105-
10690
return result
10791

108-
async def create_connection_pool(self) -> AsyncConnectionPool:
109-
"""Creates a connection pool based off of settings."""
92+
async def create_connection_pool(
93+
self, connection_class: type[ACT] | None = None, **kwargs: Any
94+
) -> AsyncConnectionPool:
95+
"""Creates a connection pool based off of settings.
96+
97+
Args:
98+
connection_class: The connection class to use.
99+
kwargs: Additional keyword arguments to pass to the connection class.
100+
101+
Returns:
102+
The connection pool.
103+
"""
110104
try:
105+
# Only pass connection_class if it is specified; otherwise let psycopg use its default connection class
106+
extra_args: dict[str, Any] = {} if connection_class is None else {"connection_class": connection_class}
107+
111108
pool = AsyncConnectionPool(
112109
min_size=self.min_pool,
113110
max_size=self.max_pool,
114111
open=False,
115-
kwargs=self.get_connection_args(),
112+
# kwargs are passed to the connection class
113+
kwargs={
114+
**self.get_connection_args(),
115+
**kwargs,
116+
},
117+
**extra_args,
116118
)
117119
await pool.open()
118120
except Exception as e:

‎python/semantic_kernel/connectors/memory/postgres/postgres_store.py

+7-8
Original file line numberDiff line numberDiff line change
@@ -4,21 +4,20 @@
44
import sys
55
from typing import Any, TypeVar
66

7-
if sys.version_info >= (3, 12):
8-
from typing import override # pragma: no cover
9-
else:
10-
from typing_extensions import override # pragma: no cover
11-
127
from psycopg import sql
138
from psycopg_pool import AsyncConnectionPool
149

1510
from semantic_kernel.connectors.memory.postgres.postgres_collection import PostgresCollection
1611
from semantic_kernel.connectors.memory.postgres.postgres_memory_store import DEFAULT_SCHEMA
17-
from semantic_kernel.data.record_definition.vector_store_model_definition import VectorStoreRecordDefinition
18-
from semantic_kernel.data.vector_storage.vector_store import VectorStore
19-
from semantic_kernel.data.vector_storage.vector_store_record_collection import VectorStoreRecordCollection
12+
from semantic_kernel.data import VectorStore, VectorStoreRecordCollection, VectorStoreRecordDefinition
2013
from semantic_kernel.utils.experimental_decorator import experimental_class
2114

15+
if sys.version_info >= (3, 12):
16+
from typing import override # pragma: no cover
17+
else:
18+
from typing_extensions import override # pragma: no cover
19+
20+
2221
logger: logging.Logger = logging.getLogger(__name__)
2322

2423
TModel = TypeVar("TModel")

‎python/semantic_kernel/connectors/memory/postgres/utils.py

+41-4
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,9 @@ def python_type_to_postgres(python_type_str: str) -> str | None:
5252
return None
5353

5454

55-
def convert_row_to_dict(row: tuple[Any, ...], fields: list[tuple[str, VectorStoreRecordField]]) -> dict[str, Any]:
55+
def convert_row_to_dict(
56+
row: tuple[Any, ...], fields: list[tuple[str, VectorStoreRecordField | None]]
57+
) -> dict[str, Any]:
5658
"""Convert a row from a PostgreSQL query to a dictionary.
5759
5860
Uses the field information to map the row values to the corresponding field names.
@@ -65,11 +67,12 @@ def convert_row_to_dict(row: tuple[Any, ...], fields: list[tuple[str, VectorStor
6567
A dictionary representation of the row.
6668
"""
6769

68-
def _convert(v: Any | None, field: VectorStoreRecordField) -> Any | None:
70+
def _convert(v: Any | None, field: VectorStoreRecordField | None) -> Any | None:
6971
if v is None:
7072
return None
71-
if isinstance(field, VectorStoreRecordVectorField):
72-
# psycopg returns vector as a string
73+
if isinstance(field, VectorStoreRecordVectorField) and isinstance(v, str):
74+
# psycopg returns vector as a string if pgvector is not loaded.
75+
# If pgvector is registered with the connection, no conversion is required.
7376
return json.loads(v)
7477
return v
7578

@@ -109,6 +112,8 @@ def get_vector_index_ops_str(distance_function: DistanceFunction) -> str:
109112
>>> get_vector_index_ops_str(DistanceFunction.COSINE)
110113
'vector_cosine_ops'
111114
"""
115+
if distance_function == DistanceFunction.COSINE_DISTANCE:
116+
return "vector_cosine_ops"
112117
if distance_function == DistanceFunction.COSINE_SIMILARITY:
113118
return "vector_cosine_ops"
114119
if distance_function == DistanceFunction.DOT_PROD:
@@ -121,6 +126,38 @@ def get_vector_index_ops_str(distance_function: DistanceFunction) -> str:
121126
raise ValueError(f"Unsupported distance function: {distance_function}")
122127

123128

129+
def get_vector_distance_ops_str(distance_function: DistanceFunction) -> str:
130+
"""Get the PostgreSQL distance operator string for a given distance function.
131+
132+
Args:
133+
distance_function: The distance function for which the operator string is needed.
134+
135+
Note:
136+
For the COSINE_SIMILARITY and DOT_PROD distance functions,
137+
there are additional query steps to retrieve the correct distance.
138+
For dot product, take -1 * inner product, as <#> returns the negative inner product
139+
since Postgres only supports ASC order index scans on operators.
140+
For cosine similarity, take 1 - cosine distance.
141+
142+
Returns:
143+
The PostgreSQL distance operator string for the given distance function.
144+
145+
Raises:
146+
ValueError: If the distance function is unsupported.
147+
"""
148+
if distance_function == DistanceFunction.COSINE_DISTANCE:
149+
return "<=>"
150+
if distance_function == DistanceFunction.COSINE_SIMILARITY:
151+
return "<=>"
152+
if distance_function == DistanceFunction.DOT_PROD:
153+
return "<#>"
154+
if distance_function == DistanceFunction.EUCLIDEAN_DISTANCE:
155+
return "<->"
156+
if distance_function == DistanceFunction.MANHATTAN:
157+
return "<+>"
158+
raise ValueError(f"Unsupported distance function: {distance_function}")
159+
160+
124161
async def ensure_open(connection_pool: AsyncConnectionPool) -> AsyncConnectionPool:
125162
"""Ensure the connection pool is open.
126163

‎python/tests/integration/memory/vector_stores/postgres/test_postgres_int.py

+35-4
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Copyright (c) Microsoft. All rights reserved.
22

33
import uuid
4-
from collections.abc import AsyncGenerator
4+
from collections.abc import AsyncGenerator, Sequence
55
from contextlib import asynccontextmanager
66
from typing import Annotated, Any
77

@@ -11,6 +11,7 @@
1111
from pydantic import BaseModel
1212

1313
from semantic_kernel.connectors.memory.postgres import PostgresSettings, PostgresStore
14+
from semantic_kernel.connectors.memory.postgres.postgres_collection import PostgresCollection
1415
from semantic_kernel.data import (
1516
DistanceFunction,
1617
IndexKind,
@@ -20,6 +21,7 @@
2021
VectorStoreRecordVectorField,
2122
vectorstoremodel,
2223
)
24+
from semantic_kernel.data.vector_search.vector_search_options import VectorSearchOptions
2325
from semantic_kernel.exceptions.memory_connector_exceptions import (
2426
MemoryConnectorConnectionException,
2527
MemoryConnectorInitializationError,
@@ -49,13 +51,13 @@
4951
class SimpleDataModel(BaseModel):
5052
id: Annotated[int, VectorStoreRecordKeyField()]
5153
embedding: Annotated[
52-
list[float],
54+
list[float] | None,
5355
VectorStoreRecordVectorField(
5456
index_kind=IndexKind.HNSW,
5557
dimensions=3,
5658
distance_function=DistanceFunction.COSINE_SIMILARITY,
5759
),
58-
]
60+
] = None
5961
data: Annotated[
6062
dict[str, Any],
6163
VectorStoreRecordDataField(has_embedding=True, embedding_property_name="embedding", property_type="JSONB"),
@@ -97,7 +99,9 @@ async def vector_store() -> AsyncGenerator[PostgresStore, None]:
9799

98100

99101
@asynccontextmanager
100-
async def create_simple_collection(vector_store: PostgresStore):
102+
async def create_simple_collection(
103+
vector_store: PostgresStore,
104+
) -> AsyncGenerator[PostgresCollection[int, SimpleDataModel], None]:
101105
"""Returns a collection with a unique name that is deleted after the context.
102106
103107
This can be moved to use a fixture with scope=function and loop_scope=session
@@ -107,6 +111,7 @@ async def create_simple_collection(vector_store: PostgresStore):
107111
suffix = str(uuid.uuid4()).replace("-", "")[:8]
108112
collection_id = f"test_collection_{suffix}"
109113
collection = vector_store.get_collection(collection_id, SimpleDataModel)
114+
assert isinstance(collection, PostgresCollection)
110115
await collection.create_collection()
111116
try:
112117
yield collection
@@ -213,6 +218,7 @@ async def test_upsert_get_and_delete_batch(vector_store: PostgresStore):
213218
# this should return only the two existing records.
214219
result = await simple_collection.get_batch([1, 2, 3])
215220
assert result is not None
221+
assert isinstance(result, Sequence)
216222
assert len(result) == 2
217223
assert result[0] is not None
218224
assert result[0].id == record1.id
@@ -226,3 +232,28 @@ async def test_upsert_get_and_delete_batch(vector_store: PostgresStore):
226232
await simple_collection.delete_batch([1, 2])
227233
result_after_delete = await simple_collection.get_batch([1, 2])
228234
assert result_after_delete is None
235+
236+
237+
async def test_search(vector_store: PostgresStore):
238+
async with create_simple_collection(vector_store) as simple_collection:
239+
records = [
240+
SimpleDataModel(id=1, embedding=[1.0, 0.0, 0.0], data={"key": "value1"}),
241+
SimpleDataModel(id=2, embedding=[0.8, 0.2, 0.0], data={"key": "value2"}),
242+
SimpleDataModel(id=3, embedding=[0.6, 0.0, 0.4], data={"key": "value3"}),
243+
SimpleDataModel(id=4, embedding=[1.0, 1.0, 0.0], data={"key": "value4"}),
244+
SimpleDataModel(id=5, embedding=[0.0, 1.0, 1.0], data={"key": "value5"}),
245+
SimpleDataModel(id=6, embedding=[1.0, 0.0, 1.0], data={"key": "value6"}),
246+
]
247+
248+
await simple_collection.upsert_batch(records)
249+
250+
try:
251+
search_results = await simple_collection.vectorized_search(
252+
[1.0, 0.0, 0.0], options=VectorSearchOptions(top=3, include_total_count=True)
253+
)
254+
assert search_results is not None
255+
assert search_results.total_count == 3
256+
assert {result.record.id async for result in search_results.results} == {1, 2, 3}
257+
258+
finally:
259+
await simple_collection.delete_batch([r.id for r in records])

‎python/tests/unit/connectors/memory/postgres/test_postgres_store.py

+134-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import Annotated, Any
66
from unittest.mock import AsyncMock, MagicMock, Mock, patch
77

8+
import pytest
89
import pytest_asyncio
910
from psycopg import AsyncConnection, AsyncCursor
1011
from psycopg_pool import AsyncConnectionPool
@@ -13,6 +14,8 @@
1314
from semantic_kernel.connectors.ai.open_ai.prompt_execution_settings.open_ai_prompt_execution_settings import (
1415
OpenAIEmbeddingPromptExecutionSettings,
1516
)
17+
from semantic_kernel.connectors.memory.postgres.constants import DISTANCE_COLUMN_NAME
18+
from semantic_kernel.connectors.memory.postgres.postgres_collection import PostgresCollection
1619
from semantic_kernel.connectors.memory.postgres.postgres_settings import PostgresSettings
1720
from semantic_kernel.connectors.memory.postgres.postgres_store import PostgresStore
1821
from semantic_kernel.data.const import DistanceFunction, IndexKind
@@ -22,6 +25,7 @@
2225
VectorStoreRecordKeyField,
2326
VectorStoreRecordVectorField,
2427
)
28+
from semantic_kernel.data.vector_search.vector_search_options import VectorSearchOptions
2529

2630

2731
@fixture(scope="function")
@@ -76,6 +80,9 @@ class SimpleDataModel:
7680
]
7781

7882

83+
# region VectorStore Tests
84+
85+
7986
async def test_vector_store_defaults(vector_store: PostgresStore) -> None:
8087
assert vector_store.connection_pool is not None
8188
async with vector_store.connection_pool.connection() as conn:
@@ -236,7 +243,130 @@ async def test_get_records(vector_store: PostgresStore, mock_cursor: Mock) -> No
236243
assert records[2].data == {"key": "value3"}
237244

238245

239-
# Test settings
246+
# endregion
247+
248+
# region Vector Search tests
249+
250+
251+
@pytest.mark.parametrize(
252+
"distance_function, operator, subquery_distance, include_vectors, include_total_count",
253+
[
254+
(DistanceFunction.COSINE_SIMILARITY, "<=>", f'1 - subquery."{DISTANCE_COLUMN_NAME}"', False, False),
255+
(DistanceFunction.COSINE_DISTANCE, "<=>", None, False, False),
256+
(DistanceFunction.DOT_PROD, "<#>", f'-1 * subquery."{DISTANCE_COLUMN_NAME}"', True, False),
257+
(DistanceFunction.EUCLIDEAN_DISTANCE, "<->", None, False, True),
258+
(DistanceFunction.MANHATTAN, "<+>", None, True, True),
259+
],
260+
)
261+
async def test_vector_search(
262+
vector_store: PostgresStore,
263+
mock_cursor: Mock,
264+
distance_function: DistanceFunction,
265+
operator: str,
266+
subquery_distance: str | None,
267+
include_vectors: bool,
268+
include_total_count: bool,
269+
) -> None:
270+
@vectorstoremodel
271+
@dataclass
272+
class SimpleDataModel:
273+
id: Annotated[int, VectorStoreRecordKeyField()]
274+
embedding: Annotated[
275+
list[float],
276+
VectorStoreRecordVectorField(
277+
embedding_settings={"embedding": OpenAIEmbeddingPromptExecutionSettings(dimensions=1536)},
278+
index_kind=IndexKind.HNSW,
279+
dimensions=1536,
280+
distance_function=distance_function,
281+
property_type="float",
282+
),
283+
]
284+
data: Annotated[
285+
dict[str, Any],
286+
VectorStoreRecordDataField(has_embedding=True, embedding_property_name="embedding", property_type="JSONB"),
287+
]
288+
289+
collection = vector_store.get_collection("test_collection", SimpleDataModel)
290+
assert isinstance(collection, PostgresCollection)
291+
292+
search_results = await collection.vectorized_search(
293+
[1.0, 2.0, 3.0],
294+
options=VectorSearchOptions(
295+
top=10, skip=5, include_vectors=include_vectors, include_total_count=include_total_count
296+
),
297+
)
298+
if include_total_count:
299+
# When the total count is requested, the query is issued immediately by vectorized_search
300+
assert mock_cursor.execute.call_count == 1
301+
else:
302+
# Total count is not included, query is issued when iterating over results
303+
assert mock_cursor.execute.call_count == 0
304+
async for _ in search_results.results:
305+
pass
306+
assert mock_cursor.execute.call_count == 1
307+
308+
execute_args, _ = mock_cursor.execute.call_args
309+
310+
assert (search_results.total_count is not None) == include_total_count
311+
312+
statement = execute_args[0]
313+
statement_str = statement.as_string()
314+
315+
expected_columns = '"id", "data"'
316+
if include_vectors:
317+
expected_columns = '"id", "embedding", "data"'
318+
319+
expected_statement = (
320+
f'SELECT {expected_columns}, "embedding" {operator} %s as "{DISTANCE_COLUMN_NAME}" '
321+
'FROM "public"."test_collection" '
322+
f'ORDER BY "{DISTANCE_COLUMN_NAME}" LIMIT 10 OFFSET 5'
323+
)
324+
325+
if subquery_distance:
326+
expected_statement = (
327+
f'SELECT subquery.*, {subquery_distance} AS "{DISTANCE_COLUMN_NAME}" FROM ('
328+
+ expected_statement
329+
+ ") AS subquery"
330+
)
331+
332+
assert statement_str == expected_statement
333+
334+
335+
async def test_model_post_init_conflicting_distance_column_name(vector_store: PostgresStore) -> None:
336+
@vectorstoremodel
337+
@dataclass
338+
class ConflictingDataModel:
339+
id: Annotated[int, VectorStoreRecordKeyField()]
340+
sk_pg_distance: Annotated[
341+
float, VectorStoreRecordDataField()
342+
] # Note: test depends on value of DISTANCE_COLUMN_NAME constant
343+
344+
embedding: Annotated[
345+
list[float],
346+
VectorStoreRecordVectorField(
347+
embedding_settings={"embedding": OpenAIEmbeddingPromptExecutionSettings(dimensions=1536)},
348+
index_kind=IndexKind.HNSW,
349+
dimensions=1536,
350+
distance_function=DistanceFunction.COSINE_SIMILARITY,
351+
property_type="float",
352+
),
353+
]
354+
data: Annotated[
355+
dict[str, Any],
356+
VectorStoreRecordDataField(has_embedding=True, embedding_property_name="embedding", property_type="JSONB"),
357+
]
358+
359+
collection = vector_store.get_collection("test_collection", ConflictingDataModel)
360+
assert isinstance(collection, PostgresCollection)
361+
362+
# Ensure that the distance column name has been changed to avoid conflict
363+
assert collection._distance_column_name != DISTANCE_COLUMN_NAME
364+
assert collection._distance_column_name.startswith(f"{DISTANCE_COLUMN_NAME}_")
365+
366+
367+
# endregion
368+
369+
# region Settings tests
240370

241371

242372
def test_settings_connection_string(monkeypatch) -> None:
@@ -290,3 +420,6 @@ def test_settings_env_vars(monkeypatch) -> None:
290420
assert conn_info["dbname"] == "dbname"
291421
assert conn_info["user"] == "user"
292422
assert conn_info["password"] == "password"
423+
424+
425+
# endregion

0 commit comments

Comments
 (0)
Please sign in to comment.