Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 1e40ad1

Browse files
tarsilansipunk
and
ansipunk
authoredFeb 21, 2024
🪛 Moving to SQLAlchemy 2.0 (encode#540)
* 🪛 Added support for SQLAlchemy 2.0 * Added common and dialects packages to handle the new SQLAlchemy 2.0+ * 🪲 Fix specific asyncpg oriented test --------- Co-authored-by: ansipunk <[email protected]>
1 parent c2e4c5b commit 1e40ad1

19 files changed

+394
-275
lines changed
 

‎.github/workflows/test-suite.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ jobs:
1414

1515
strategy:
1616
matrix:
17-
python-version: ["3.7", "3.8", "3.9", "3.10"]
17+
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"]
1818

1919
services:
2020
mysql:

‎README.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ values = [
8585
]
8686
await database.execute_many(query=query, values=values)
8787

88-
# Run a database query.
88+
# Run a database query.
8989
query = "SELECT * FROM HighScores"
9090
rows = await database.fetch_all(query=query)
9191
print('High Scores:', rows)
@@ -115,4 +115,4 @@ for examples of how to start using databases together with SQLAlchemy core expre
115115
[quart]: https://gitlab.com/pgjones/quart
116116
[aiohttp]: https://github.com/aio-libs/aiohttp
117117
[tornado]: https://github.com/tornadoweb/tornado
118-
[fastapi]: https://github.com/tiangolo/fastapi
118+
[fastapi]: https://github.com/tiangolo/fastapi

‎databases/backends/aiopg.py

+44-25
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,20 @@
55
import uuid
66

77
import aiopg
8-
from aiopg.sa.engine import APGCompiler_psycopg2
9-
from sqlalchemy.dialects.postgresql.psycopg2 import PGDialect_psycopg2
108
from sqlalchemy.engine.cursor import CursorResultMetaData
119
from sqlalchemy.engine.interfaces import Dialect, ExecutionContext
1210
from sqlalchemy.engine.row import Row
1311
from sqlalchemy.sql import ClauseElement
1412
from sqlalchemy.sql.ddl import DDLElement
1513

16-
from databases.core import DatabaseURL
14+
from databases.backends.common.records import Record, Row, create_column_maps
15+
from databases.backends.compilers.psycopg import PGCompiler_psycopg
16+
from databases.backends.dialects.psycopg import PGDialect_psycopg
17+
from databases.core import LOG_EXTRA, DatabaseURL
1718
from databases.interfaces import (
1819
ConnectionBackend,
1920
DatabaseBackend,
20-
Record,
21+
Record as RecordInterface,
2122
TransactionBackend,
2223
)
2324

@@ -34,10 +35,10 @@ def __init__(
3435
self._pool: typing.Union[aiopg.Pool, None] = None
3536

3637
def _get_dialect(self) -> Dialect:
37-
dialect = PGDialect_psycopg2(
38+
dialect = PGDialect_psycopg(
3839
json_serializer=json.dumps, json_deserializer=lambda x: x
3940
)
40-
dialect.statement_compiler = APGCompiler_psycopg2
41+
dialect.statement_compiler = PGCompiler_psycopg
4142
dialect.implicit_returning = True
4243
dialect.supports_native_enum = True
4344
dialect.supports_smallserial = True # 9.2+
@@ -117,50 +118,55 @@ async def release(self) -> None:
117118
await self._database._pool.release(self._connection)
118119
self._connection = None
119120

120-
async def fetch_all(self, query: ClauseElement) -> typing.List[Record]:
121+
async def fetch_all(self, query: ClauseElement) -> typing.List[RecordInterface]:
121122
assert self._connection is not None, "Connection is not acquired"
122-
query_str, args, context = self._compile(query)
123+
query_str, args, result_columns, context = self._compile(query)
124+
column_maps = create_column_maps(result_columns)
125+
dialect = self._dialect
126+
123127
cursor = await self._connection.cursor()
124128
try:
125129
await cursor.execute(query_str, args)
126130
rows = await cursor.fetchall()
127131
metadata = CursorResultMetaData(context, cursor.description)
128-
return [
132+
rows = [
129133
Row(
130134
metadata,
131135
metadata._processors,
132136
metadata._keymap,
133-
Row._default_key_style,
134137
row,
135138
)
136139
for row in rows
137140
]
141+
return [Record(row, result_columns, dialect, column_maps) for row in rows]
138142
finally:
139143
cursor.close()
140144

141-
async def fetch_one(self, query: ClauseElement) -> typing.Optional[Record]:
145+
async def fetch_one(self, query: ClauseElement) -> typing.Optional[RecordInterface]:
142146
assert self._connection is not None, "Connection is not acquired"
143-
query_str, args, context = self._compile(query)
147+
query_str, args, result_columns, context = self._compile(query)
148+
column_maps = create_column_maps(result_columns)
149+
dialect = self._dialect
144150
cursor = await self._connection.cursor()
145151
try:
146152
await cursor.execute(query_str, args)
147153
row = await cursor.fetchone()
148154
if row is None:
149155
return None
150156
metadata = CursorResultMetaData(context, cursor.description)
151-
return Row(
157+
row = Row(
152158
metadata,
153159
metadata._processors,
154160
metadata._keymap,
155-
Row._default_key_style,
156161
row,
157162
)
163+
return Record(row, result_columns, dialect, column_maps)
158164
finally:
159165
cursor.close()
160166

161167
async def execute(self, query: ClauseElement) -> typing.Any:
162168
assert self._connection is not None, "Connection is not acquired"
163-
query_str, args, context = self._compile(query)
169+
query_str, args, _, _ = self._compile(query)
164170
cursor = await self._connection.cursor()
165171
try:
166172
await cursor.execute(query_str, args)
@@ -173,7 +179,7 @@ async def execute_many(self, queries: typing.List[ClauseElement]) -> None:
173179
cursor = await self._connection.cursor()
174180
try:
175181
for single_query in queries:
176-
single_query, args, context = self._compile(single_query)
182+
single_query, args, _, _ = self._compile(single_query)
177183
await cursor.execute(single_query, args)
178184
finally:
179185
cursor.close()
@@ -182,36 +188,37 @@ async def iterate(
182188
self, query: ClauseElement
183189
) -> typing.AsyncGenerator[typing.Any, None]:
184190
assert self._connection is not None, "Connection is not acquired"
185-
query_str, args, context = self._compile(query)
191+
query_str, args, result_columns, context = self._compile(query)
192+
column_maps = create_column_maps(result_columns)
193+
dialect = self._dialect
186194
cursor = await self._connection.cursor()
187195
try:
188196
await cursor.execute(query_str, args)
189197
metadata = CursorResultMetaData(context, cursor.description)
190198
async for row in cursor:
191-
yield Row(
199+
record = Row(
192200
metadata,
193201
metadata._processors,
194202
metadata._keymap,
195-
Row._default_key_style,
196203
row,
197204
)
205+
yield Record(record, result_columns, dialect, column_maps)
198206
finally:
199207
cursor.close()
200208

201209
def transaction(self) -> TransactionBackend:
202210
return AiopgTransaction(self)
203211

204-
def _compile(
205-
self, query: ClauseElement
206-
) -> typing.Tuple[str, dict, CompilationContext]:
212+
def _compile(self, query: ClauseElement) -> typing.Tuple[str, list, tuple]:
207213
compiled = query.compile(
208214
dialect=self._dialect, compile_kwargs={"render_postcompile": True}
209215
)
210-
211216
execution_context = self._dialect.execution_ctx_cls()
212217
execution_context.dialect = self._dialect
213218

214219
if not isinstance(query, DDLElement):
220+
compiled_params = sorted(compiled.params.items())
221+
215222
args = compiled.construct_params()
216223
for key, val in args.items():
217224
if key in compiled._bind_processors:
@@ -224,11 +231,23 @@ def _compile(
224231
compiled._ad_hoc_textual,
225232
compiled._loose_column_name_matching,
226233
)
234+
235+
mapping = {
236+
key: "$" + str(i) for i, (key, _) in enumerate(compiled_params, start=1)
237+
}
238+
compiled_query = compiled.string % mapping
239+
result_map = compiled._result_columns
240+
227241
else:
228242
args = {}
243+
result_map = None
244+
compiled_query = compiled.string
229245

230-
logger.debug("Query: %s\nArgs: %s", compiled.string, args)
231-
return compiled.string, args, CompilationContext(execution_context)
246+
query_message = compiled_query.replace(" \n", " ").replace("\n", " ")
247+
logger.debug(
248+
"Query: %s Args: %s", query_message, repr(tuple(args)), extra=LOG_EXTRA
249+
)
250+
return compiled.string, args, result_map, CompilationContext(execution_context)
232251

233252
@property
234253
def raw_connection(self) -> aiopg.connection.Connection:

‎databases/backends/asyncmy.py

+41-22
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,15 @@
77
from sqlalchemy.dialects.mysql import pymysql
88
from sqlalchemy.engine.cursor import CursorResultMetaData
99
from sqlalchemy.engine.interfaces import Dialect, ExecutionContext
10-
from sqlalchemy.engine.row import Row
1110
from sqlalchemy.sql import ClauseElement
1211
from sqlalchemy.sql.ddl import DDLElement
1312

13+
from databases.backends.common.records import Record, Row, create_column_maps
1414
from databases.core import LOG_EXTRA, DatabaseURL
1515
from databases.interfaces import (
1616
ConnectionBackend,
1717
DatabaseBackend,
18-
Record,
18+
Record as RecordInterface,
1919
TransactionBackend,
2020
)
2121

@@ -108,50 +108,57 @@ async def release(self) -> None:
108108
await self._database._pool.release(self._connection)
109109
self._connection = None
110110

111-
async def fetch_all(self, query: ClauseElement) -> typing.List[Record]:
111+
async def fetch_all(self, query: ClauseElement) -> typing.List[RecordInterface]:
112112
assert self._connection is not None, "Connection is not acquired"
113-
query_str, args, context = self._compile(query)
113+
query_str, args, result_columns, context = self._compile(query)
114+
column_maps = create_column_maps(result_columns)
115+
dialect = self._dialect
116+
114117
async with self._connection.cursor() as cursor:
115118
try:
116119
await cursor.execute(query_str, args)
117120
rows = await cursor.fetchall()
118121
metadata = CursorResultMetaData(context, cursor.description)
119-
return [
122+
rows = [
120123
Row(
121124
metadata,
122125
metadata._processors,
123126
metadata._keymap,
124-
Row._default_key_style,
125127
row,
126128
)
127129
for row in rows
128130
]
131+
return [
132+
Record(row, result_columns, dialect, column_maps) for row in rows
133+
]
129134
finally:
130135
await cursor.close()
131136

132-
async def fetch_one(self, query: ClauseElement) -> typing.Optional[Record]:
137+
async def fetch_one(self, query: ClauseElement) -> typing.Optional[RecordInterface]:
133138
assert self._connection is not None, "Connection is not acquired"
134-
query_str, args, context = self._compile(query)
139+
query_str, args, result_columns, context = self._compile(query)
140+
column_maps = create_column_maps(result_columns)
141+
dialect = self._dialect
135142
async with self._connection.cursor() as cursor:
136143
try:
137144
await cursor.execute(query_str, args)
138145
row = await cursor.fetchone()
139146
if row is None:
140147
return None
141148
metadata = CursorResultMetaData(context, cursor.description)
142-
return Row(
149+
row = Row(
143150
metadata,
144151
metadata._processors,
145152
metadata._keymap,
146-
Row._default_key_style,
147153
row,
148154
)
155+
return Record(row, result_columns, dialect, column_maps)
149156
finally:
150157
await cursor.close()
151158

152159
async def execute(self, query: ClauseElement) -> typing.Any:
153160
assert self._connection is not None, "Connection is not acquired"
154-
query_str, args, context = self._compile(query)
161+
query_str, args, _, _ = self._compile(query)
155162
async with self._connection.cursor() as cursor:
156163
try:
157164
await cursor.execute(query_str, args)
@@ -166,7 +173,7 @@ async def execute_many(self, queries: typing.List[ClauseElement]) -> None:
166173
async with self._connection.cursor() as cursor:
167174
try:
168175
for single_query in queries:
169-
single_query, args, context = self._compile(single_query)
176+
single_query, args, _, _ = self._compile(single_query)
170177
await cursor.execute(single_query, args)
171178
finally:
172179
await cursor.close()
@@ -175,36 +182,37 @@ async def iterate(
175182
self, query: ClauseElement
176183
) -> typing.AsyncGenerator[typing.Any, None]:
177184
assert self._connection is not None, "Connection is not acquired"
178-
query_str, args, context = self._compile(query)
185+
query_str, args, result_columns, context = self._compile(query)
186+
column_maps = create_column_maps(result_columns)
187+
dialect = self._dialect
179188
async with self._connection.cursor() as cursor:
180189
try:
181190
await cursor.execute(query_str, args)
182191
metadata = CursorResultMetaData(context, cursor.description)
183192
async for row in cursor:
184-
yield Row(
193+
record = Row(
185194
metadata,
186195
metadata._processors,
187196
metadata._keymap,
188-
Row._default_key_style,
189197
row,
190198
)
199+
yield Record(record, result_columns, dialect, column_maps)
191200
finally:
192201
await cursor.close()
193202

194203
def transaction(self) -> TransactionBackend:
195204
return AsyncMyTransaction(self)
196205

197-
def _compile(
198-
self, query: ClauseElement
199-
) -> typing.Tuple[str, dict, CompilationContext]:
206+
def _compile(self, query: ClauseElement) -> typing.Tuple[str, list, tuple]:
200207
compiled = query.compile(
201208
dialect=self._dialect, compile_kwargs={"render_postcompile": True}
202209
)
203-
204210
execution_context = self._dialect.execution_ctx_cls()
205211
execution_context.dialect = self._dialect
206212

207213
if not isinstance(query, DDLElement):
214+
compiled_params = sorted(compiled.params.items())
215+
208216
args = compiled.construct_params()
209217
for key, val in args.items():
210218
if key in compiled._bind_processors:
@@ -217,12 +225,23 @@ def _compile(
217225
compiled._ad_hoc_textual,
218226
compiled._loose_column_name_matching,
219227
)
228+
229+
mapping = {
230+
key: "$" + str(i) for i, (key, _) in enumerate(compiled_params, start=1)
231+
}
232+
compiled_query = compiled.string % mapping
233+
result_map = compiled._result_columns
234+
220235
else:
221236
args = {}
237+
result_map = None
238+
compiled_query = compiled.string
222239

223-
query_message = compiled.string.replace(" \n", " ").replace("\n", " ")
224-
logger.debug("Query: %s Args: %s", query_message, repr(args), extra=LOG_EXTRA)
225-
return compiled.string, args, CompilationContext(execution_context)
240+
query_message = compiled_query.replace(" \n", " ").replace("\n", " ")
241+
logger.debug(
242+
"Query: %s Args: %s", query_message, repr(tuple(args)), extra=LOG_EXTRA
243+
)
244+
return compiled.string, args, result_map, CompilationContext(execution_context)
226245

227246
@property
228247
def raw_connection(self) -> asyncmy.connection.Connection:

‎databases/backends/common/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)
Please sign in to comment.