|
3 | 3 | import time
|
4 | 4 |
|
5 | 5 | import pytest
|
| 6 | +from psycopg_pool import PoolTimeout |
6 | 7 | from pydantic import ValidationError
|
7 | 8 |
|
8 | 9 | from semantic_kernel.connectors.memory.postgres import PostgresMemoryStore
|
@@ -52,147 +53,162 @@ def test_constructor(connection_string):
|
52 | 53 | @pytest.mark.asyncio
|
53 | 54 | async def test_create_and_does_collection_exist(connection_string):
|
54 | 55 | memory = PostgresMemoryStore(connection_string, 2, 1, 5)
|
55 |
| - |
56 |
| - await memory.create_collection("test_collection") |
57 |
| - result = await memory.does_collection_exist("test_collection") |
58 |
| - assert result is not None |
| 56 | + try: |
| 57 | + await memory.create_collection("test_collection") |
| 58 | + result = await memory.does_collection_exist("test_collection") |
| 59 | + assert result is not None |
| 60 | + except PoolTimeout: |
| 61 | + pytest.skip("PoolTimeout exception raised, skipping test.") |
59 | 62 |
|
60 | 63 |
|
61 | 64 | @pytest.mark.asyncio
|
62 | 65 | async def test_get_collections(connection_string):
|
63 | 66 | memory = PostgresMemoryStore(connection_string, 2, 1, 5)
|
64 | 67 |
|
65 |
| - await memory.create_collection("test_collection") |
66 |
| - result = await memory.get_collections() |
67 |
| - assert "test_collection" in result |
| 68 | + try: |
| 69 | + await memory.create_collection("test_collection") |
| 70 | + result = await memory.get_collections() |
| 71 | + assert "test_collection" in result |
| 72 | + except PoolTimeout: |
| 73 | + pytest.skip("PoolTimeout exception raised, skipping test.") |
68 | 74 |
|
69 | 75 |
|
70 | 76 | @pytest.mark.asyncio
|
71 | 77 | async def test_delete_collection(connection_string):
|
72 | 78 | memory = PostgresMemoryStore(connection_string, 2, 1, 5)
|
| 79 | + try: |
| 80 | + await memory.create_collection("test_collection") |
73 | 81 |
|
74 |
| - await memory.create_collection("test_collection") |
75 |
| - |
76 |
| - result = await memory.get_collections() |
77 |
| - assert "test_collection" in result |
| 82 | + result = await memory.get_collections() |
| 83 | + assert "test_collection" in result |
78 | 84 |
|
79 |
| - await memory.delete_collection("test_collection") |
80 |
| - result = await memory.get_collections() |
81 |
| - assert "test_collection" not in result |
| 85 | + await memory.delete_collection("test_collection") |
| 86 | + result = await memory.get_collections() |
| 87 | + assert "test_collection" not in result |
| 88 | + except PoolTimeout: |
| 89 | + pytest.skip("PoolTimeout exception raised, skipping test.") |
82 | 90 |
|
83 | 91 |
|
84 | 92 | @pytest.mark.asyncio
|
85 | 93 | async def test_does_collection_exist(connection_string):
|
86 | 94 | memory = PostgresMemoryStore(connection_string, 2, 1, 5)
|
87 |
| - |
88 |
| - await memory.create_collection("test_collection") |
89 |
| - result = await memory.does_collection_exist("test_collection") |
90 |
| - assert result is True |
| 95 | + try: |
| 96 | + await memory.create_collection("test_collection") |
| 97 | + result = await memory.does_collection_exist("test_collection") |
| 98 | + assert result is True |
| 99 | + except PoolTimeout: |
| 100 | + pytest.skip("PoolTimeout exception raised, skipping test.") |
91 | 101 |
|
92 | 102 |
|
93 | 103 | @pytest.mark.asyncio
|
94 | 104 | async def test_upsert_and_get(connection_string, memory_record1):
|
95 | 105 | memory = PostgresMemoryStore(connection_string, 2, 1, 5)
|
96 |
| - |
97 |
| - await memory.create_collection("test_collection") |
98 |
| - await memory.upsert("test_collection", memory_record1) |
99 |
| - result = await memory.get("test_collection", memory_record1._id, with_embedding=True) |
100 |
| - assert result is not None |
101 |
| - assert result._id == memory_record1._id |
102 |
| - assert result._text == memory_record1._text |
103 |
| - assert result._timestamp == memory_record1._timestamp |
104 |
| - for i in range(len(result._embedding)): |
105 |
| - assert result._embedding[i] == memory_record1._embedding[i] |
| 106 | + try: |
| 107 | + await memory.create_collection("test_collection") |
| 108 | + await memory.upsert("test_collection", memory_record1) |
| 109 | + result = await memory.get("test_collection", memory_record1._id, with_embedding=True) |
| 110 | + assert result is not None |
| 111 | + assert result._id == memory_record1._id |
| 112 | + assert result._text == memory_record1._text |
| 113 | + assert result._timestamp == memory_record1._timestamp |
| 114 | + for i in range(len(result._embedding)): |
| 115 | + assert result._embedding[i] == memory_record1._embedding[i] |
| 116 | + except PoolTimeout: |
| 117 | + pytest.skip("PoolTimeout exception raised, skipping test.") |
106 | 118 |
|
107 | 119 |
|
108 |
| -@pytest.mark.xfail(reason="Test failing with reason couldn't: get a connection after 30.00 sec") |
109 | 120 | @pytest.mark.asyncio
|
110 | 121 | async def test_upsert_batch_and_get_batch(connection_string, memory_record1, memory_record2):
|
111 | 122 | memory = PostgresMemoryStore(connection_string, 2, 1, 5)
|
| 123 | + try: |
| 124 | + await memory.create_collection("test_collection") |
| 125 | + await memory.upsert_batch("test_collection", [memory_record1, memory_record2]) |
112 | 126 |
|
113 |
| - await memory.create_collection("test_collection") |
114 |
| - await memory.upsert_batch("test_collection", [memory_record1, memory_record2]) |
115 |
| - |
116 |
| - results = await memory.get_batch( |
117 |
| - "test_collection", |
118 |
| - [memory_record1._id, memory_record2._id], |
119 |
| - with_embeddings=True, |
120 |
| - ) |
121 |
| - |
122 |
| - assert len(results) == 2 |
123 |
| - assert results[0]._id in [memory_record1._id, memory_record2._id] |
124 |
| - assert results[1]._id in [memory_record1._id, memory_record2._id] |
| 127 | + results = await memory.get_batch( |
| 128 | + "test_collection", |
| 129 | + [memory_record1._id, memory_record2._id], |
| 130 | + with_embeddings=True, |
| 131 | + ) |
| 132 | + assert len(results) == 2 |
| 133 | + assert results[0]._id in [memory_record1._id, memory_record2._id] |
| 134 | + assert results[1]._id in [memory_record1._id, memory_record2._id] |
| 135 | + except PoolTimeout: |
| 136 | + pytest.skip("PoolTimeout exception raised, skipping test.") |
125 | 137 |
|
126 | 138 |
|
127 |
| -@pytest.mark.xfail(reason="Test failing with reason couldn't: get a connection after 30.00 sec") |
128 | 139 | @pytest.mark.asyncio
|
129 | 140 | async def test_remove(connection_string, memory_record1):
|
130 | 141 | memory = PostgresMemoryStore(connection_string, 2, 1, 5)
|
| 142 | + try: |
| 143 | + await memory.create_collection("test_collection") |
| 144 | + await memory.upsert("test_collection", memory_record1) |
131 | 145 |
|
132 |
| - await memory.create_collection("test_collection") |
133 |
| - await memory.upsert("test_collection", memory_record1) |
134 |
| - |
135 |
| - result = await memory.get("test_collection", memory_record1._id, with_embedding=True) |
136 |
| - assert result is not None |
| 146 | + result = await memory.get("test_collection", memory_record1._id, with_embedding=True) |
| 147 | + assert result is not None |
137 | 148 |
|
138 |
| - await memory.remove("test_collection", memory_record1._id) |
139 |
| - with pytest.raises(ServiceResourceNotFoundError): |
140 |
| - _ = await memory.get("test_collection", memory_record1._id, with_embedding=True) |
| 149 | + await memory.remove("test_collection", memory_record1._id) |
| 150 | + with pytest.raises(ServiceResourceNotFoundError): |
| 151 | + await memory.get("test_collection", memory_record1._id, with_embedding=True) |
| 152 | + except PoolTimeout: |
| 153 | + pytest.skip("PoolTimeout exception raised, skipping test.") |
141 | 154 |
|
142 | 155 |
|
143 |
| -@pytest.mark.xfail(reason="Test failing with reason couldn't: get a connection after 30.00 sec") |
144 | 156 | @pytest.mark.asyncio
|
145 | 157 | async def test_remove_batch(connection_string, memory_record1, memory_record2):
|
146 | 158 | memory = PostgresMemoryStore(connection_string, 2, 1, 5)
|
| 159 | + try: |
| 160 | + await memory.create_collection("test_collection") |
| 161 | + await memory.upsert_batch("test_collection", [memory_record1, memory_record2]) |
| 162 | + await memory.remove_batch("test_collection", [memory_record1._id, memory_record2._id]) |
| 163 | + with pytest.raises(ServiceResourceNotFoundError): |
| 164 | + _ = await memory.get("test_collection", memory_record1._id, with_embedding=True) |
147 | 165 |
|
148 |
| - await memory.create_collection("test_collection") |
149 |
| - await memory.upsert_batch("test_collection", [memory_record1, memory_record2]) |
150 |
| - await memory.remove_batch("test_collection", [memory_record1._id, memory_record2._id]) |
151 |
| - with pytest.raises(ServiceResourceNotFoundError): |
152 |
| - _ = await memory.get("test_collection", memory_record1._id, with_embedding=True) |
153 |
| - |
154 |
| - with pytest.raises(ServiceResourceNotFoundError): |
155 |
| - _ = await memory.get("test_collection", memory_record2._id, with_embedding=True) |
| 166 | + with pytest.raises(ServiceResourceNotFoundError): |
| 167 | + _ = await memory.get("test_collection", memory_record2._id, with_embedding=True) |
| 168 | + except PoolTimeout: |
| 169 | + pytest.skip("PoolTimeout exception raised, skipping test.") |
156 | 170 |
|
157 | 171 |
|
158 |
| -@pytest.mark.xfail(reason="Test failing with reason couldn't: get a connection after 30.00 sec") |
159 | 172 | @pytest.mark.asyncio
|
160 | 173 | async def test_get_nearest_match(connection_string, memory_record1, memory_record2):
|
161 | 174 | memory = PostgresMemoryStore(connection_string, 2, 1, 5)
|
162 |
| - |
163 |
| - await memory.create_collection("test_collection") |
164 |
| - await memory.upsert_batch("test_collection", [memory_record1, memory_record2]) |
165 |
| - test_embedding = memory_record1.embedding.copy() |
166 |
| - test_embedding[0] = test_embedding[0] + 0.01 |
167 |
| - |
168 |
| - result = await memory.get_nearest_match( |
169 |
| - "test_collection", test_embedding, min_relevance_score=0.0, with_embedding=True |
170 |
| - ) |
171 |
| - assert result is not None |
172 |
| - assert result[0]._id == memory_record1._id |
173 |
| - assert result[0]._text == memory_record1._text |
174 |
| - assert result[0]._timestamp == memory_record1._timestamp |
175 |
| - for i in range(len(result[0]._embedding)): |
176 |
| - assert result[0]._embedding[i] == memory_record1._embedding[i] |
| 175 | + try: |
| 176 | + await memory.create_collection("test_collection") |
| 177 | + await memory.upsert_batch("test_collection", [memory_record1, memory_record2]) |
| 178 | + test_embedding = memory_record1.embedding.copy() |
| 179 | + test_embedding[0] = test_embedding[0] + 0.01 |
| 180 | + |
| 181 | + result = await memory.get_nearest_match( |
| 182 | + "test_collection", test_embedding, min_relevance_score=0.0, with_embedding=True |
| 183 | + ) |
| 184 | + assert result is not None |
| 185 | + assert result[0]._id == memory_record1._id |
| 186 | + assert result[0]._text == memory_record1._text |
| 187 | + assert result[0]._timestamp == memory_record1._timestamp |
| 188 | + for i in range(len(result[0]._embedding)): |
| 189 | + assert result[0]._embedding[i] == memory_record1._embedding[i] |
| 190 | + except PoolTimeout: |
| 191 | + pytest.skip("PoolTimeout exception raised, skipping test.") |
177 | 192 |
|
178 | 193 |
|
179 | 194 | @pytest.mark.asyncio
|
180 |
| -@pytest.mark.xfail(reason="The test is failing due to a timeout.") |
181 | 195 | async def test_get_nearest_matches(connection_string, memory_record1, memory_record2, memory_record3):
|
182 | 196 | memory = PostgresMemoryStore(connection_string, 2, 1, 5)
|
183 |
| - |
184 |
| - await memory.create_collection("test_collection") |
185 |
| - await memory.upsert_batch("test_collection", [memory_record1, memory_record2, memory_record3]) |
186 |
| - test_embedding = memory_record2.embedding |
187 |
| - test_embedding[0] = test_embedding[0] + 0.025 |
188 |
| - |
189 |
| - result = await memory.get_nearest_matches( |
190 |
| - "test_collection", |
191 |
| - test_embedding, |
192 |
| - limit=2, |
193 |
| - min_relevance_score=0.0, |
194 |
| - with_embeddings=True, |
195 |
| - ) |
196 |
| - assert len(result) == 2 |
197 |
| - assert result[0][0]._id in [memory_record3._id, memory_record2._id] |
198 |
| - assert result[1][0]._id in [memory_record3._id, memory_record2._id] |
| 197 | + try: |
| 198 | + await memory.create_collection("test_collection") |
| 199 | + await memory.upsert_batch("test_collection", [memory_record1, memory_record2, memory_record3]) |
| 200 | + test_embedding = memory_record2.embedding |
| 201 | + test_embedding[0] = test_embedding[0] + 0.025 |
| 202 | + |
| 203 | + result = await memory.get_nearest_matches( |
| 204 | + "test_collection", |
| 205 | + test_embedding, |
| 206 | + limit=2, |
| 207 | + min_relevance_score=0.0, |
| 208 | + with_embeddings=True, |
| 209 | + ) |
| 210 | + assert len(result) == 2 |
| 211 | + assert result[0][0]._id in [memory_record3._id, memory_record2._id] |
| 212 | + assert result[1][0]._id in [memory_record3._id, memory_record2._id] |
| 213 | + except PoolTimeout: |
| 214 | + pytest.skip("PoolTimeout exception raised, skipping test.") |
0 commit comments