diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 170e9558..642629af 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -15,7 +15,7 @@ jobs: - uses: "actions/checkout@v3" - uses: "actions/setup-python@v4" with: - python-version: 3.7 + python-version: 3.8 - name: "Install dependencies" run: "scripts/install" - name: "Build package & docs" diff --git a/.github/workflows/test-suite.yml b/.github/workflows/test-suite.yml index bc271a65..f85ca99a 100644 --- a/.github/workflows/test-suite.yml +++ b/.github/workflows/test-suite.yml @@ -14,7 +14,7 @@ jobs: strategy: matrix: - python-version: ["3.7", "3.8", "3.9", "3.10"] + python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] services: mysql: diff --git a/CHANGELOG.md b/CHANGELOG.md index 4816bc16..e8ef0174 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,123 +4,225 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). +## 0.9.0 (February 23th, 2024) + +### Changed + +* Drop support for Python 3.7 and add support for Python 3.12 ([#583][#583]) +* Add support for SQLAlchemy 2+ ([#540][#540]) +* Allow SSL string parameters in PostgresSQL URL ([#575][#575]) and ([#576][#576]) + +[#583]: https://github.com/encode/databases/pull/583 +[#540]: https://github.com/encode/databases/pull/540 +[#575]: https://github.com/encode/databases/pull/575 +[#576]: https://github.com/encode/databases/pull/576 + +## 0.8.0 (August 28th, 2023) + +### Added + +* Allow SQLite query parameters and support cached databases ([#561][#561]) +* Support for unix socket for aiomysql and asyncmy ([#551][#551]) + +[#551]: https://github.com/encode/databases/pull/551 +[#561]: https://github.com/encode/databases/pull/546 + +### Changed + +* Change isolation connections and transactions during concurrent usage ([#546][#546]) +* Bump requests from 2.28.1 to 2.31.0 ([#562][#562]) +* Bump starlette from 0.20.4 to 0.27.0 ([#560][#560]) +* Bump up asyncmy version to fix `No module named 'asyncmy.connection'` ([#553][#553]) +* Bump wheel from 0.37.1 to 0.38.1 ([#524][#524]) + +[#546]: https://github.com/encode/databases/pull/546 +[#562]: https://github.com/encode/databases/pull/562 +[#560]: https://github.com/encode/databases/pull/560 +[#553]: https://github.com/encode/databases/pull/553 +[#524]: https://github.com/encode/databases/pull/524 + +### Fixed + +* Fix the type-hints using more standard mode ([#526][#526]) + +[#526]: https://github.com/encode/databases/pull/526 + ## 0.7.0 (Dec 18th, 2022) ### Fixed -* Fixed breaking changes in SQLAlchemy cursor; supports `>=1.4.42,<1.5` (#513). -* Wrapped types in `typing.Optional` where applicable (#510). +* Fixed breaking changes in SQLAlchemy cursor; supports `>=1.4.42,<1.5` ([#513][#513]) +* Wrapped types in `typing.Optional` where applicable ([#510][#510]) + +[#513]: https://github.com/encode/databases/pull/513 +[#510]: https://github.com/encode/databases/pull/510 ## 0.6.2 (Nov 7th, 2022) ### Changed -* Pinned SQLAlchemy `<=1.4.41` to avoid breaking changes (#520). +* Pinned SQLAlchemy `<=1.4.41` to avoid breaking changes ([#520][#520]) + +[#520]: https://github.com/encode/databases/pull/520 ## 0.6.1 (Aug 9th, 2022) ### Fixed -* Improve typing for `Transaction` (#493) -* Allow string indexing into Record (#501) +* Improve typing for `Transaction` ([#493][#493]) +* Allow string indexing into Record ([#501][#501]) + +[#493]: https://github.com/encode/databases/pull/493 +[#501]: https://github.com/encode/databases/pull/501 ## 0.6.0 (May 29th, 2022) -* Dropped Python 3.6 support (#458) +* Dropped Python 3.6 support ([#458][#458]) + +[#458]: https://github.com/encode/databases/pull/458 ### Added -* Add _mapping property to the result set interface (#447 ) -* Add contributing docs (#453 ) +* Add \_mapping property to the result set interface ([#447][#447]) +* Add contributing docs ([#453][#453]) + +[#447]: https://github.com/encode/databases/pull/447 +[#453]: https://github.com/encode/databases/pull/453 ### Fixed -* Fix query result named access (#448) -* Fix connections getting into a bad state when a task is cancelled (#457) -* Revert #328 parallel transactions (#472) -* Change extra installations to specific drivers (#436) +* Fix query result named access ([#448][#448]) +* Fix connections getting into a bad state when a task is cancelled ([#457][#457]) +* Revert #328 parallel transactions ([#472][#472]) +* Change extra installations to specific drivers ([#436][#436]) + +[#448]: https://github.com/encode/databases/pull/448 +[#457]: https://github.com/encode/databases/pull/457 +[#472]: https://github.com/encode/databases/pull/472 +[#436]: https://github.com/encode/databases/pull/436 ## 0.5.4 (January 14th, 2022) ### Added -* Support for Unix domain in connections (#423) -* Added `asyncmy` MySQL driver (#382) +* Support for Unix domain in connections ([#423][#423]) +* Added `asyncmy` MySQL driver ([#382][#382]) + +[#423]: https://github.com/encode/databases/pull/423 +[#382]: https://github.com/encode/databases/pull/382 ### Fixed -* Fix SQLite fetch queries with multiple parameters (#435) -* Changed `Record` type to `Sequence` (#408) +* Fix SQLite fetch queries with multiple parameters ([#435][#435]) +* Changed `Record` type to `Sequence` ([#408][#408]) + +[#435]: https://github.com/encode/databases/pull/435 +[#408]: https://github.com/encode/databases/pull/408 ## 0.5.3 (October 10th, 2021) ### Added -* Support `dialect+driver` for default database drivers like `postgresql+asyncpg` (#396) +* Support `dialect+driver` for default database drivers like `postgresql+asyncpg` ([#396][#396]) + +[#396]: https://github.com/encode/databases/pull/396 ### Fixed -* Documentation of low-level transaction (#390) +* Documentation of low-level transaction ([#390][#390]) + +[#390]: https://github.com/encode/databases/pull/390 ## 0.5.2 (September 10th, 2021) ### Fixed -* Reset counter for failed connections (#385) -* Avoid dangling task-local connections after Database.disconnect() (#211) +* Reset counter for failed connections ([#385][#385]) +* Avoid dangling task-local connections after Database.disconnect() ([#211][#211]) + +[#385]: https://github.com/encode/databases/pull/385 +[#211]: https://github.com/encode/databases/pull/211 ## 0.5.1 (September 2nd, 2021) ### Added -* Make database `connect` and `disconnect` calls idempotent (#379) +* Make database `connect` and `disconnect` calls idempotent ([#379][#379]) + +[#379]: https://github.com/encode/databases/pull/379 ### Fixed -* Fix `in_` and `notin_` queries in SQLAlchemy 1.4 (#378) +* Fix `in_` and `notin_` queries in SQLAlchemy 1.4 ([#378][#378]) + +[#378]: https://github.com/encode/databases/pull/378 ## 0.5.0 (August 26th, 2021) ### Added -* Support SQLAlchemy 1.4 (#299) + +* Support SQLAlchemy 1.4 ([#299][#299]) + +[#299]: https://github.com/encode/databases/pull/299 ### Fixed -* Fix concurrent transactions (#328) +* Fix concurrent transactions ([#328][#328]) + +[#328]: https://github.com/encode/databases/pull/328 ## 0.4.3 (March 26th, 2021) ### Fixed -* Pin SQLAlchemy to <1.4 (#314) +* Pin SQLAlchemy to <1.4 ([#314][#314]) + +[#314]: https://github.com/encode/databases/pull/314 ## 0.4.2 (March 14th, 2021) ### Fixed -* Fix memory leak with asyncpg for SQLAlchemy generic functions (#273) +* Fix memory leak with asyncpg for SQLAlchemy generic functions ([#273][#273]) + +[#273]: https://github.com/encode/databases/pull/273 ## 0.4.1 (November 16th, 2020) ### Fixed -* Remove package dependency on the synchronous DB drivers (#256) +* Remove package dependency on the synchronous DB drivers ([#256][#256]) + +[#256]: https://github.com/encode/databases/pull/256 ## 0.4.0 (October 20th, 2020) ### Added -* Use backend native fetch_val() implementation when available (#132) -* Replace psycopg2-binary with psycopg2 (#204) -* Speed up PostgresConnection fetch() and iterate() (#193) -* Access asyncpg Record field by key on raw query (#207) -* Allow setting min_size and max_size in postgres DSN (#210) -* Add option pool_recycle in postgres DSN (#233) -* Allow extra transaction options (#242) +* Use backend native fetch_val() implementation when available ([#132][#132]) +* Replace psycopg2-binary with psycopg2 ([#204][#204]) +* Speed up PostgresConnection fetch() and iterate() ([#193][#193]) +* Access asyncpg Record field by key on raw query ([#207][#207]) +* Allow setting min_size and max_size in postgres DSN ([#210][#210]) +* Add option pool_recycle in postgres DSN ([#233][#233]) +* Allow extra transaction options ([#242][#242]) + +[#132]: https://github.com/encode/databases/pull/132 +[#204]: https://github.com/encode/databases/pull/204 +[#193]: https://github.com/encode/databases/pull/193 +[#207]: https://github.com/encode/databases/pull/207 +[#210]: https://github.com/encode/databases/pull/210 +[#233]: https://github.com/encode/databases/pull/233 +[#242]: https://github.com/encode/databases/pull/242 ### Fixed -* Fix type hinting for sqlite backend (#227) -* Fix SQLAlchemy DDL statements (#226) -* Make fetch_val call fetch_one for type conversion (#246) -* Unquote username and password in DatabaseURL (#248) +* Fix type hinting for sqlite backend ([#227][#227]) +* Fix SQLAlchemy DDL statements ([#226][#226]) +* Make fetch_val call fetch_one for type conversion ([#246][#246]) +* Unquote username and password in DatabaseURL ([#248][#248]) + +[#227]: https://github.com/encode/databases/pull/227 +[#226]: https://github.com/encode/databases/pull/226 +[#246]: https://github.com/encode/databases/pull/246 +[#248]: https://github.com/encode/databases/pull/248 diff --git a/README.md b/README.md index ba16a104..f40cd173 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ Databases is suitable for integrating against any async Web framework, such as [ **Documentation**: [https://www.encode.io/databases/](https://www.encode.io/databases/) -**Requirements**: Python 3.7+ +**Requirements**: Python 3.8+ --- @@ -85,7 +85,7 @@ values = [ ] await database.execute_many(query=query, values=values) -# Run a database query. +# Run a database query. query = "SELECT * FROM HighScores" rows = await database.fetch_all(query=query) print('High Scores:', rows) diff --git a/databases/__init__.py b/databases/__init__.py index cfb75242..e7390984 100644 --- a/databases/__init__.py +++ b/databases/__init__.py @@ -1,4 +1,4 @@ from databases.core import Database, DatabaseURL -__version__ = "0.7.0" +__version__ = "0.9.0" __all__ = ["Database", "DatabaseURL"] diff --git a/databases/backends/aiopg.py b/databases/backends/aiopg.py index 8668b2b9..0b4d95a3 100644 --- a/databases/backends/aiopg.py +++ b/databases/backends/aiopg.py @@ -5,19 +5,20 @@ import uuid import aiopg -from aiopg.sa.engine import APGCompiler_psycopg2 -from sqlalchemy.dialects.postgresql.psycopg2 import PGDialect_psycopg2 from sqlalchemy.engine.cursor import CursorResultMetaData from sqlalchemy.engine.interfaces import Dialect, ExecutionContext from sqlalchemy.engine.row import Row from sqlalchemy.sql import ClauseElement from sqlalchemy.sql.ddl import DDLElement -from databases.core import DatabaseURL +from databases.backends.common.records import Record, Row, create_column_maps +from databases.backends.compilers.psycopg import PGCompiler_psycopg +from databases.backends.dialects.psycopg import PGDialect_psycopg +from databases.core import LOG_EXTRA, DatabaseURL from databases.interfaces import ( ConnectionBackend, DatabaseBackend, - Record, + Record as RecordInterface, TransactionBackend, ) @@ -34,10 +35,10 @@ def __init__( self._pool: typing.Union[aiopg.Pool, None] = None def _get_dialect(self) -> Dialect: - dialect = PGDialect_psycopg2( + dialect = PGDialect_psycopg( json_serializer=json.dumps, json_deserializer=lambda x: x ) - dialect.statement_compiler = APGCompiler_psycopg2 + dialect.statement_compiler = PGCompiler_psycopg dialect.implicit_returning = True dialect.supports_native_enum = True dialect.supports_smallserial = True # 9.2+ @@ -117,30 +118,35 @@ async def release(self) -> None: await self._database._pool.release(self._connection) self._connection = None - async def fetch_all(self, query: ClauseElement) -> typing.List[Record]: + async def fetch_all(self, query: ClauseElement) -> typing.List[RecordInterface]: assert self._connection is not None, "Connection is not acquired" - query_str, args, context = self._compile(query) + query_str, args, result_columns, context = self._compile(query) + column_maps = create_column_maps(result_columns) + dialect = self._dialect + cursor = await self._connection.cursor() try: await cursor.execute(query_str, args) rows = await cursor.fetchall() metadata = CursorResultMetaData(context, cursor.description) - return [ + rows = [ Row( metadata, metadata._processors, metadata._keymap, - Row._default_key_style, row, ) for row in rows ] + return [Record(row, result_columns, dialect, column_maps) for row in rows] finally: cursor.close() - async def fetch_one(self, query: ClauseElement) -> typing.Optional[Record]: + async def fetch_one(self, query: ClauseElement) -> typing.Optional[RecordInterface]: assert self._connection is not None, "Connection is not acquired" - query_str, args, context = self._compile(query) + query_str, args, result_columns, context = self._compile(query) + column_maps = create_column_maps(result_columns) + dialect = self._dialect cursor = await self._connection.cursor() try: await cursor.execute(query_str, args) @@ -148,19 +154,19 @@ async def fetch_one(self, query: ClauseElement) -> typing.Optional[Record]: if row is None: return None metadata = CursorResultMetaData(context, cursor.description) - return Row( + row = Row( metadata, metadata._processors, metadata._keymap, - Row._default_key_style, row, ) + return Record(row, result_columns, dialect, column_maps) finally: cursor.close() async def execute(self, query: ClauseElement) -> typing.Any: assert self._connection is not None, "Connection is not acquired" - query_str, args, context = self._compile(query) + query_str, args, _, _ = self._compile(query) cursor = await self._connection.cursor() try: await cursor.execute(query_str, args) @@ -173,7 +179,7 @@ async def execute_many(self, queries: typing.List[ClauseElement]) -> None: cursor = await self._connection.cursor() try: for single_query in queries: - single_query, args, context = self._compile(single_query) + single_query, args, _, _ = self._compile(single_query) await cursor.execute(single_query, args) finally: cursor.close() @@ -182,36 +188,37 @@ async def iterate( self, query: ClauseElement ) -> typing.AsyncGenerator[typing.Any, None]: assert self._connection is not None, "Connection is not acquired" - query_str, args, context = self._compile(query) + query_str, args, result_columns, context = self._compile(query) + column_maps = create_column_maps(result_columns) + dialect = self._dialect cursor = await self._connection.cursor() try: await cursor.execute(query_str, args) metadata = CursorResultMetaData(context, cursor.description) async for row in cursor: - yield Row( + record = Row( metadata, metadata._processors, metadata._keymap, - Row._default_key_style, row, ) + yield Record(record, result_columns, dialect, column_maps) finally: cursor.close() def transaction(self) -> TransactionBackend: return AiopgTransaction(self) - def _compile( - self, query: ClauseElement - ) -> typing.Tuple[str, dict, CompilationContext]: + def _compile(self, query: ClauseElement) -> typing.Tuple[str, list, tuple]: compiled = query.compile( dialect=self._dialect, compile_kwargs={"render_postcompile": True} ) - execution_context = self._dialect.execution_ctx_cls() execution_context.dialect = self._dialect if not isinstance(query, DDLElement): + compiled_params = sorted(compiled.params.items()) + args = compiled.construct_params() for key, val in args.items(): if key in compiled._bind_processors: @@ -224,11 +231,23 @@ def _compile( compiled._ad_hoc_textual, compiled._loose_column_name_matching, ) + + mapping = { + key: "$" + str(i) for i, (key, _) in enumerate(compiled_params, start=1) + } + compiled_query = compiled.string % mapping + result_map = compiled._result_columns + else: args = {} + result_map = None + compiled_query = compiled.string - logger.debug("Query: %s\nArgs: %s", compiled.string, args) - return compiled.string, args, CompilationContext(execution_context) + query_message = compiled_query.replace(" \n", " ").replace("\n", " ") + logger.debug( + "Query: %s Args: %s", query_message, repr(tuple(args)), extra=LOG_EXTRA + ) + return compiled.string, args, result_map, CompilationContext(execution_context) @property def raw_connection(self) -> aiopg.connection.Connection: diff --git a/databases/backends/asyncmy.py b/databases/backends/asyncmy.py index 749e5afe..040a4346 100644 --- a/databases/backends/asyncmy.py +++ b/databases/backends/asyncmy.py @@ -7,15 +7,15 @@ from sqlalchemy.dialects.mysql import pymysql from sqlalchemy.engine.cursor import CursorResultMetaData from sqlalchemy.engine.interfaces import Dialect, ExecutionContext -from sqlalchemy.engine.row import Row from sqlalchemy.sql import ClauseElement from sqlalchemy.sql.ddl import DDLElement +from databases.backends.common.records import Record, Row, create_column_maps from databases.core import LOG_EXTRA, DatabaseURL from databases.interfaces import ( ConnectionBackend, DatabaseBackend, - Record, + Record as RecordInterface, TransactionBackend, ) @@ -40,6 +40,7 @@ def _get_connection_kwargs(self) -> dict: max_size = url_options.get("max_size") pool_recycle = url_options.get("pool_recycle") ssl = url_options.get("ssl") + unix_socket = url_options.get("unix_socket") if min_size is not None: kwargs["minsize"] = int(min_size) @@ -49,6 +50,8 @@ def _get_connection_kwargs(self) -> dict: kwargs["pool_recycle"] = int(pool_recycle) if ssl is not None: kwargs["ssl"] = {"true": True, "false": False}[ssl.lower()] + if unix_socket is not None: + kwargs["unix_socket"] = unix_socket for key, value in self._options.items(): # Coerce 'min_size' and 'max_size' for consistency. @@ -105,30 +108,37 @@ async def release(self) -> None: await self._database._pool.release(self._connection) self._connection = None - async def fetch_all(self, query: ClauseElement) -> typing.List[Record]: + async def fetch_all(self, query: ClauseElement) -> typing.List[RecordInterface]: assert self._connection is not None, "Connection is not acquired" - query_str, args, context = self._compile(query) + query_str, args, result_columns, context = self._compile(query) + column_maps = create_column_maps(result_columns) + dialect = self._dialect + async with self._connection.cursor() as cursor: try: await cursor.execute(query_str, args) rows = await cursor.fetchall() metadata = CursorResultMetaData(context, cursor.description) - return [ + rows = [ Row( metadata, metadata._processors, metadata._keymap, - Row._default_key_style, row, ) for row in rows ] + return [ + Record(row, result_columns, dialect, column_maps) for row in rows + ] finally: await cursor.close() - async def fetch_one(self, query: ClauseElement) -> typing.Optional[Record]: + async def fetch_one(self, query: ClauseElement) -> typing.Optional[RecordInterface]: assert self._connection is not None, "Connection is not acquired" - query_str, args, context = self._compile(query) + query_str, args, result_columns, context = self._compile(query) + column_maps = create_column_maps(result_columns) + dialect = self._dialect async with self._connection.cursor() as cursor: try: await cursor.execute(query_str, args) @@ -136,19 +146,19 @@ async def fetch_one(self, query: ClauseElement) -> typing.Optional[Record]: if row is None: return None metadata = CursorResultMetaData(context, cursor.description) - return Row( + row = Row( metadata, metadata._processors, metadata._keymap, - Row._default_key_style, row, ) + return Record(row, result_columns, dialect, column_maps) finally: await cursor.close() async def execute(self, query: ClauseElement) -> typing.Any: assert self._connection is not None, "Connection is not acquired" - query_str, args, context = self._compile(query) + query_str, args, _, _ = self._compile(query) async with self._connection.cursor() as cursor: try: await cursor.execute(query_str, args) @@ -163,7 +173,7 @@ async def execute_many(self, queries: typing.List[ClauseElement]) -> None: async with self._connection.cursor() as cursor: try: for single_query in queries: - single_query, args, context = self._compile(single_query) + single_query, args, _, _ = self._compile(single_query) await cursor.execute(single_query, args) finally: await cursor.close() @@ -172,36 +182,37 @@ async def iterate( self, query: ClauseElement ) -> typing.AsyncGenerator[typing.Any, None]: assert self._connection is not None, "Connection is not acquired" - query_str, args, context = self._compile(query) + query_str, args, result_columns, context = self._compile(query) + column_maps = create_column_maps(result_columns) + dialect = self._dialect async with self._connection.cursor() as cursor: try: await cursor.execute(query_str, args) metadata = CursorResultMetaData(context, cursor.description) async for row in cursor: - yield Row( + record = Row( metadata, metadata._processors, metadata._keymap, - Row._default_key_style, row, ) + yield Record(record, result_columns, dialect, column_maps) finally: await cursor.close() def transaction(self) -> TransactionBackend: return AsyncMyTransaction(self) - def _compile( - self, query: ClauseElement - ) -> typing.Tuple[str, dict, CompilationContext]: + def _compile(self, query: ClauseElement) -> typing.Tuple[str, list, tuple]: compiled = query.compile( dialect=self._dialect, compile_kwargs={"render_postcompile": True} ) - execution_context = self._dialect.execution_ctx_cls() execution_context.dialect = self._dialect if not isinstance(query, DDLElement): + compiled_params = sorted(compiled.params.items()) + args = compiled.construct_params() for key, val in args.items(): if key in compiled._bind_processors: @@ -214,12 +225,23 @@ def _compile( compiled._ad_hoc_textual, compiled._loose_column_name_matching, ) + + mapping = { + key: "$" + str(i) for i, (key, _) in enumerate(compiled_params, start=1) + } + compiled_query = compiled.string % mapping + result_map = compiled._result_columns + else: args = {} + result_map = None + compiled_query = compiled.string - query_message = compiled.string.replace(" \n", " ").replace("\n", " ") - logger.debug("Query: %s Args: %s", query_message, repr(args), extra=LOG_EXTRA) - return compiled.string, args, CompilationContext(execution_context) + query_message = compiled_query.replace(" \n", " ").replace("\n", " ") + logger.debug( + "Query: %s Args: %s", query_message, repr(tuple(args)), extra=LOG_EXTRA + ) + return compiled.string, args, result_map, CompilationContext(execution_context) @property def raw_connection(self) -> asyncmy.connection.Connection: diff --git a/databases/backends/common/__init__.py b/databases/backends/common/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/databases/backends/common/records.py b/databases/backends/common/records.py new file mode 100644 index 00000000..e963af50 --- /dev/null +++ b/databases/backends/common/records.py @@ -0,0 +1,136 @@ +import enum +import typing +from datetime import date, datetime, time + +from sqlalchemy.engine.interfaces import Dialect +from sqlalchemy.engine.row import Row as SQLRow +from sqlalchemy.sql.compiler import _CompileLabel +from sqlalchemy.sql.schema import Column +from sqlalchemy.sql.sqltypes import JSON +from sqlalchemy.types import TypeEngine + +from databases.interfaces import Record as RecordInterface + +DIALECT_EXCLUDE = {"postgresql"} + + +class Record(RecordInterface): + __slots__ = ( + "_row", + "_result_columns", + "_dialect", + "_column_map", + "_column_map_int", + "_column_map_full", + ) + + def __init__( + self, + row: typing.Any, + result_columns: tuple, + dialect: Dialect, + column_maps: typing.Tuple[ + typing.Mapping[typing.Any, typing.Tuple[int, TypeEngine]], + typing.Mapping[int, typing.Tuple[int, TypeEngine]], + typing.Mapping[str, typing.Tuple[int, TypeEngine]], + ], + ) -> None: + self._row = row + self._result_columns = result_columns + self._dialect = dialect + self._column_map, self._column_map_int, self._column_map_full = column_maps + + @property + def _mapping(self) -> typing.Mapping: + return self._row + + def keys(self) -> typing.KeysView: + return self._mapping.keys() + + def values(self) -> typing.ValuesView: + return self._mapping.values() + + def __getitem__(self, key: typing.Any) -> typing.Any: + if len(self._column_map) == 0: + return self._row[key] + elif isinstance(key, Column): + idx, datatype = self._column_map_full[str(key)] + elif isinstance(key, int): + idx, datatype = self._column_map_int[key] + else: + idx, datatype = self._column_map[key] + + raw = self._row[idx] + processor = datatype._cached_result_processor(self._dialect, None) + + if self._dialect.name in DIALECT_EXCLUDE: + if processor is not None and isinstance(raw, (int, str, float)): + return processor(raw) + + return raw + + def __iter__(self) -> typing.Iterator: + return iter(self._row.keys()) + + def __len__(self) -> int: + return len(self._row) + + def __getattr__(self, name: str) -> typing.Any: + try: + return self.__getitem__(name) + except KeyError as e: + raise AttributeError(e.args[0]) from e + + +class Row(SQLRow): + def __getitem__(self, key: typing.Any) -> typing.Any: + """ + An instance of a Row in SQLAlchemy allows the access + to the Row._fields as tuple and the Row._mapping for + the values. + """ + if isinstance(key, int): + return super().__getitem__(key) + + idx = self._key_to_index[key][0] + return super().__getitem__(idx) + + def keys(self): + return self._mapping.keys() + + def values(self): + return self._mapping.values() + + +def create_column_maps( + result_columns: typing.Any, +) -> typing.Tuple[ + typing.Mapping[typing.Any, typing.Tuple[int, TypeEngine]], + typing.Mapping[int, typing.Tuple[int, TypeEngine]], + typing.Mapping[str, typing.Tuple[int, TypeEngine]], +]: + """ + Generate column -> datatype mappings from the column definitions. + + These mappings are used throughout PostgresConnection methods + to initialize Record-s. The underlying DB driver does not do type + conversion for us so we have wrap the returned asyncpg.Record-s. + + :return: Three mappings from different ways to address a column to \ + corresponding column indexes and datatypes: \ + 1. by column identifier; \ + 2. by column index; \ + 3. by column name in Column sqlalchemy objects. + """ + column_map, column_map_int, column_map_full = {}, {}, {} + for idx, (column_name, _, column, datatype) in enumerate(result_columns): + column_map[column_name] = (idx, datatype) + column_map_int[idx] = (idx, datatype) + + # Added in SQLA 2.0 and _CompileLabels do not have _annotations + # When this happens, the mapping is on the second position + if isinstance(column[0], _CompileLabel): + column_map_full[str(column[2])] = (idx, datatype) + else: + column_map_full[str(column[0])] = (idx, datatype) + return column_map, column_map_int, column_map_full diff --git a/databases/backends/compilers/__init__.py b/databases/backends/compilers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/databases/backends/compilers/psycopg.py b/databases/backends/compilers/psycopg.py new file mode 100644 index 00000000..654c22a1 --- /dev/null +++ b/databases/backends/compilers/psycopg.py @@ -0,0 +1,17 @@ +from sqlalchemy.dialects.postgresql.psycopg import PGCompiler_psycopg + + +class APGCompiler_psycopg2(PGCompiler_psycopg): + def construct_params(self, *args, **kwargs): + pd = super().construct_params(*args, **kwargs) + + for column in self.prefetch: + pd[column.key] = self._exec_default(column.default) + + return pd + + def _exec_default(self, default): + if default.is_callable: + return default.arg(self.dialect) + else: + return default.arg diff --git a/databases/backends/dialects/__init__.py b/databases/backends/dialects/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/databases/backends/dialects/psycopg.py b/databases/backends/dialects/psycopg.py new file mode 100644 index 00000000..07bd1880 --- /dev/null +++ b/databases/backends/dialects/psycopg.py @@ -0,0 +1,46 @@ +""" +All the unique changes for the databases package +with the custom Numeric as the deprecated pypostgresql +for backwards compatibility and to make sure the +package can go to SQLAlchemy 2.0+. +""" + +import typing + +from sqlalchemy import types, util +from sqlalchemy.dialects.postgresql.base import PGDialect, PGExecutionContext +from sqlalchemy.engine import processors +from sqlalchemy.types import Float, Numeric + + +class PGExecutionContext_psycopg(PGExecutionContext): + ... + + +class PGNumeric(Numeric): + def bind_processor( + self, dialect: typing.Any + ) -> typing.Union[str, None]: # pragma: no cover + return processors.to_str + + def result_processor( + self, dialect: typing.Any, coltype: typing.Any + ) -> typing.Union[float, None]: # pragma: no cover + if self.asdecimal: + return None + else: + return processors.to_float + + +class PGDialect_psycopg(PGDialect): + colspecs = util.update_copy( + PGDialect.colspecs, + { + types.Numeric: PGNumeric, + types.Float: Float, + }, + ) + execution_ctx_cls = PGExecutionContext_psycopg + + +dialect = PGDialect_psycopg diff --git a/databases/backends/mysql.py b/databases/backends/mysql.py index 6b86042f..792f3685 100644 --- a/databases/backends/mysql.py +++ b/databases/backends/mysql.py @@ -7,15 +7,15 @@ from sqlalchemy.dialects.mysql import pymysql from sqlalchemy.engine.cursor import CursorResultMetaData from sqlalchemy.engine.interfaces import Dialect, ExecutionContext -from sqlalchemy.engine.row import Row from sqlalchemy.sql import ClauseElement from sqlalchemy.sql.ddl import DDLElement +from databases.backends.common.records import Record, Row, create_column_maps from databases.core import LOG_EXTRA, DatabaseURL from databases.interfaces import ( ConnectionBackend, DatabaseBackend, - Record, + Record as RecordInterface, TransactionBackend, ) @@ -40,6 +40,7 @@ def _get_connection_kwargs(self) -> dict: max_size = url_options.get("max_size") pool_recycle = url_options.get("pool_recycle") ssl = url_options.get("ssl") + unix_socket = url_options.get("unix_socket") if min_size is not None: kwargs["minsize"] = int(min_size) @@ -49,6 +50,8 @@ def _get_connection_kwargs(self) -> dict: kwargs["pool_recycle"] = int(pool_recycle) if ssl is not None: kwargs["ssl"] = {"true": True, "false": False}[ssl.lower()] + if unix_socket is not None: + kwargs["unix_socket"] = unix_socket for key, value in self._options.items(): # Coerce 'min_size' and 'max_size' for consistency. @@ -105,30 +108,34 @@ async def release(self) -> None: await self._database._pool.release(self._connection) self._connection = None - async def fetch_all(self, query: ClauseElement) -> typing.List[Record]: + async def fetch_all(self, query: ClauseElement) -> typing.List[RecordInterface]: assert self._connection is not None, "Connection is not acquired" - query_str, args, context = self._compile(query) + query_str, args, result_columns, context = self._compile(query) + column_maps = create_column_maps(result_columns) + dialect = self._dialect cursor = await self._connection.cursor() try: await cursor.execute(query_str, args) rows = await cursor.fetchall() metadata = CursorResultMetaData(context, cursor.description) - return [ + rows = [ Row( metadata, metadata._processors, metadata._keymap, - Row._default_key_style, row, ) for row in rows ] + return [Record(row, result_columns, dialect, column_maps) for row in rows] finally: await cursor.close() - async def fetch_one(self, query: ClauseElement) -> typing.Optional[Record]: + async def fetch_one(self, query: ClauseElement) -> typing.Optional[RecordInterface]: assert self._connection is not None, "Connection is not acquired" - query_str, args, context = self._compile(query) + query_str, args, result_columns, context = self._compile(query) + column_maps = create_column_maps(result_columns) + dialect = self._dialect cursor = await self._connection.cursor() try: await cursor.execute(query_str, args) @@ -136,19 +143,19 @@ async def fetch_one(self, query: ClauseElement) -> typing.Optional[Record]: if row is None: return None metadata = CursorResultMetaData(context, cursor.description) - return Row( + row = Row( metadata, metadata._processors, metadata._keymap, - Row._default_key_style, row, ) + return Record(row, result_columns, dialect, column_maps) finally: await cursor.close() async def execute(self, query: ClauseElement) -> typing.Any: assert self._connection is not None, "Connection is not acquired" - query_str, args, context = self._compile(query) + query_str, args, _, _ = self._compile(query) cursor = await self._connection.cursor() try: await cursor.execute(query_str, args) @@ -163,7 +170,7 @@ async def execute_many(self, queries: typing.List[ClauseElement]) -> None: cursor = await self._connection.cursor() try: for single_query in queries: - single_query, args, context = self._compile(single_query) + single_query, args, _, _ = self._compile(single_query) await cursor.execute(single_query, args) finally: await cursor.close() @@ -172,36 +179,37 @@ async def iterate( self, query: ClauseElement ) -> typing.AsyncGenerator[typing.Any, None]: assert self._connection is not None, "Connection is not acquired" - query_str, args, context = self._compile(query) + query_str, args, result_columns, context = self._compile(query) + column_maps = create_column_maps(result_columns) + dialect = self._dialect cursor = await self._connection.cursor() try: await cursor.execute(query_str, args) metadata = CursorResultMetaData(context, cursor.description) async for row in cursor: - yield Row( + record = Row( metadata, metadata._processors, metadata._keymap, - Row._default_key_style, row, ) + yield Record(record, result_columns, dialect, column_maps) finally: await cursor.close() def transaction(self) -> TransactionBackend: return MySQLTransaction(self) - def _compile( - self, query: ClauseElement - ) -> typing.Tuple[str, dict, CompilationContext]: + def _compile(self, query: ClauseElement) -> typing.Tuple[str, list, tuple]: compiled = query.compile( dialect=self._dialect, compile_kwargs={"render_postcompile": True} ) - execution_context = self._dialect.execution_ctx_cls() execution_context.dialect = self._dialect if not isinstance(query, DDLElement): + compiled_params = sorted(compiled.params.items()) + args = compiled.construct_params() for key, val in args.items(): if key in compiled._bind_processors: @@ -214,12 +222,23 @@ def _compile( compiled._ad_hoc_textual, compiled._loose_column_name_matching, ) + + mapping = { + key: "$" + str(i) for i, (key, _) in enumerate(compiled_params, start=1) + } + compiled_query = compiled.string % mapping + result_map = compiled._result_columns + else: args = {} + result_map = None + compiled_query = compiled.string - query_message = compiled.string.replace(" \n", " ").replace("\n", " ") - logger.debug("Query: %s Args: %s", query_message, repr(args), extra=LOG_EXTRA) - return compiled.string, args, CompilationContext(execution_context) + query_message = compiled_query.replace(" \n", " ").replace("\n", " ") + logger.debug( + "Query: %s Args: %s", query_message, repr(tuple(args)), extra=LOG_EXTRA + ) + return compiled.string, args, result_map, CompilationContext(execution_context) @property def raw_connection(self) -> aiomysql.connection.Connection: diff --git a/databases/backends/postgres.py b/databases/backends/postgres.py index e30c12d7..c42688e1 100644 --- a/databases/backends/postgres.py +++ b/databases/backends/postgres.py @@ -2,13 +2,12 @@ import typing import asyncpg -from sqlalchemy.dialects.postgresql import pypostgresql from sqlalchemy.engine.interfaces import Dialect from sqlalchemy.sql import ClauseElement from sqlalchemy.sql.ddl import DDLElement -from sqlalchemy.sql.schema import Column -from sqlalchemy.types import TypeEngine +from databases.backends.common.records import Record, create_column_maps +from databases.backends.dialects.psycopg import dialect as psycopg_dialect from databases.core import LOG_EXTRA, DatabaseURL from databases.interfaces import ( ConnectionBackend, @@ -30,7 +29,7 @@ def __init__( self._pool = None def _get_dialect(self) -> Dialect: - dialect = pypostgresql.dialect(paramstyle="pyformat") + dialect = psycopg_dialect(paramstyle="pyformat") dialect.implicit_returning = True dialect.supports_native_enum = True @@ -55,7 +54,8 @@ def _get_connection_kwargs(self) -> dict: if max_size is not None: kwargs["max_size"] = int(max_size) if ssl is not None: - kwargs["ssl"] = {"true": True, "false": False}[ssl.lower()] + ssl = ssl.lower() + kwargs["ssl"] = {"true": True, "false": False}.get(ssl, ssl) kwargs.update(self._options) @@ -82,82 +82,6 @@ def connection(self) -> "PostgresConnection": return PostgresConnection(self, self._dialect) -class Record(RecordInterface): - __slots__ = ( - "_row", - "_result_columns", - "_dialect", - "_column_map", - "_column_map_int", - "_column_map_full", - ) - - def __init__( - self, - row: asyncpg.Record, - result_columns: tuple, - dialect: Dialect, - column_maps: typing.Tuple[ - typing.Mapping[typing.Any, typing.Tuple[int, TypeEngine]], - typing.Mapping[int, typing.Tuple[int, TypeEngine]], - typing.Mapping[str, typing.Tuple[int, TypeEngine]], - ], - ) -> None: - self._row = row - self._result_columns = result_columns - self._dialect = dialect - self._column_map, self._column_map_int, self._column_map_full = column_maps - - @property - def _mapping(self) -> typing.Mapping: - return self._row - - def keys(self) -> typing.KeysView: - import warnings - - warnings.warn( - "The `Row.keys()` method is deprecated to mimic SQLAlchemy behaviour, " - "use `Row._mapping.keys()` instead.", - DeprecationWarning, - ) - return self._mapping.keys() - - def values(self) -> typing.ValuesView: - import warnings - - warnings.warn( - "The `Row.values()` method is deprecated to mimic SQLAlchemy behaviour, " - "use `Row._mapping.values()` instead.", - DeprecationWarning, - ) - return self._mapping.values() - - def __getitem__(self, key: typing.Any) -> typing.Any: - if len(self._column_map) == 0: # raw query - return self._row[key] - elif isinstance(key, Column): - idx, datatype = self._column_map_full[str(key)] - elif isinstance(key, int): - idx, datatype = self._column_map_int[key] - else: - idx, datatype = self._column_map[key] - raw = self._row[idx] - processor = datatype._cached_result_processor(self._dialect, None) - - if processor is not None: - return processor(raw) - return raw - - def __iter__(self) -> typing.Iterator: - return iter(self._row.keys()) - - def __len__(self) -> int: - return len(self._row) - - def __getattr__(self, name: str) -> typing.Any: - return self._mapping.get(name) - - class PostgresConnection(ConnectionBackend): def __init__(self, database: PostgresBackend, dialect: Dialect): self._database = database @@ -180,7 +104,7 @@ async def fetch_all(self, query: ClauseElement) -> typing.List[RecordInterface]: query_str, args, result_columns = self._compile(query) rows = await self._connection.fetch(query_str, *args) dialect = self._dialect - column_maps = self._create_column_maps(result_columns) + column_maps = create_column_maps(result_columns) return [Record(row, result_columns, dialect, column_maps) for row in rows] async def fetch_one(self, query: ClauseElement) -> typing.Optional[RecordInterface]: @@ -193,7 +117,7 @@ async def fetch_one(self, query: ClauseElement) -> typing.Optional[RecordInterfa row, result_columns, self._dialect, - self._create_column_maps(result_columns), + create_column_maps(result_columns), ) async def fetch_val( @@ -213,7 +137,7 @@ async def fetch_val( async def execute(self, query: ClauseElement) -> typing.Any: assert self._connection is not None, "Connection is not acquired" - query_str, args, result_columns = self._compile(query) + query_str, args, _ = self._compile(query) return await self._connection.fetchval(query_str, *args) async def execute_many(self, queries: typing.List[ClauseElement]) -> None: @@ -222,7 +146,7 @@ async def execute_many(self, queries: typing.List[ClauseElement]) -> None: # loop through multiple executes here, which should all end up # using the same prepared statement. for single_query in queries: - single_query, args, result_columns = self._compile(single_query) + single_query, args, _ = self._compile(single_query) await self._connection.execute(single_query, *args) async def iterate( @@ -230,7 +154,7 @@ async def iterate( ) -> typing.AsyncGenerator[typing.Any, None]: assert self._connection is not None, "Connection is not acquired" query_str, args, result_columns = self._compile(query) - column_maps = self._create_column_maps(result_columns) + column_maps = create_column_maps(result_columns) async for row in self._connection.cursor(query_str, *args): yield Record(row, result_columns, self._dialect, column_maps) @@ -255,7 +179,6 @@ def _compile(self, query: ClauseElement) -> typing.Tuple[str, list, tuple]: processors[key](val) if key in processors else val for key, val in compiled_params ] - result_map = compiled._result_columns else: compiled_query = compiled.string @@ -268,34 +191,6 @@ def _compile(self, query: ClauseElement) -> typing.Tuple[str, list, tuple]: ) return compiled_query, args, result_map - @staticmethod - def _create_column_maps( - result_columns: tuple, - ) -> typing.Tuple[ - typing.Mapping[typing.Any, typing.Tuple[int, TypeEngine]], - typing.Mapping[int, typing.Tuple[int, TypeEngine]], - typing.Mapping[str, typing.Tuple[int, TypeEngine]], - ]: - """ - Generate column -> datatype mappings from the column definitions. - - These mappings are used throughout PostgresConnection methods - to initialize Record-s. The underlying DB driver does not do type - conversion for us so we have wrap the returned asyncpg.Record-s. - - :return: Three mappings from different ways to address a column to \ - corresponding column indexes and datatypes: \ - 1. by column identifier; \ - 2. by column index; \ - 3. by column name in Column sqlalchemy objects. - """ - column_map, column_map_int, column_map_full = {}, {}, {} - for idx, (column_name, _, column, datatype) in enumerate(result_columns): - column_map[column_name] = (idx, datatype) - column_map_int[idx] = (idx, datatype) - column_map_full[str(column[0])] = (idx, datatype) - return column_map, column_map_int, column_map_full - @property def raw_connection(self) -> asyncpg.connection.Connection: assert self._connection is not None, "Connection is not acquired" diff --git a/databases/backends/sqlite.py b/databases/backends/sqlite.py index 19464627..16e17e9e 100644 --- a/databases/backends/sqlite.py +++ b/databases/backends/sqlite.py @@ -1,22 +1,19 @@ import logging +import sqlite3 import typing import uuid +from urllib.parse import urlencode import aiosqlite from sqlalchemy.dialects.sqlite import pysqlite from sqlalchemy.engine.cursor import CursorResultMetaData from sqlalchemy.engine.interfaces import Dialect, ExecutionContext -from sqlalchemy.engine.row import Row from sqlalchemy.sql import ClauseElement from sqlalchemy.sql.ddl import DDLElement +from databases.backends.common.records import Record, Row, create_column_maps from databases.core import LOG_EXTRA, DatabaseURL -from databases.interfaces import ( - ConnectionBackend, - DatabaseBackend, - Record, - TransactionBackend, -) +from databases.interfaces import ConnectionBackend, DatabaseBackend, TransactionBackend logger = logging.getLogger("databases") @@ -33,19 +30,12 @@ def __init__( self._pool = SQLitePool(self._database_url, **self._options) async def connect(self) -> None: - pass - # assert self._pool is None, "DatabaseBackend is already running" - # self._pool = await aiomysql.create_pool( - # host=self._database_url.hostname, - # port=self._database_url.port or 3306, - # user=self._database_url.username or getpass.getuser(), - # password=self._database_url.password, - # db=self._database_url.database, - # autocommit=True, - # ) + ... async def disconnect(self) -> None: - pass + # if it extsis, remove reference to connection to cached in-memory database on disconnect + if self._pool._memref: + self._pool._memref = None # assert self._pool is not None, "DatabaseBackend is not running" # self._pool.close() # await self._pool.wait_closed() @@ -57,12 +47,20 @@ def connection(self) -> "SQLiteConnection": class SQLitePool: def __init__(self, url: DatabaseURL, **options: typing.Any) -> None: - self._url = url + self._database = url.database + self._memref = None + # add query params to database connection string + if url.options: + self._database += "?" + urlencode(url.options) self._options = options + if url.options and "cache" in url.options: + # reference to a connection to the cached in-memory database must be held to keep it from being deleted + self._memref = sqlite3.connect(self._database, **self._options) + async def acquire(self) -> aiosqlite.Connection: connection = aiosqlite.connect( - database=self._url.database, isolation_level=None, **self._options + database=self._database, isolation_level=None, **self._options ) await connection.__aenter__() return connection @@ -93,42 +91,46 @@ async def release(self) -> None: async def fetch_all(self, query: ClauseElement) -> typing.List[Record]: assert self._connection is not None, "Connection is not acquired" - query_str, args, context = self._compile(query) + query_str, args, result_columns, context = self._compile(query) + column_maps = create_column_maps(result_columns) + dialect = self._dialect async with self._connection.execute(query_str, args) as cursor: rows = await cursor.fetchall() metadata = CursorResultMetaData(context, cursor.description) - return [ + rows = [ Row( metadata, metadata._processors, metadata._keymap, - Row._default_key_style, row, ) for row in rows ] + return [Record(row, result_columns, dialect, column_maps) for row in rows] async def fetch_one(self, query: ClauseElement) -> typing.Optional[Record]: assert self._connection is not None, "Connection is not acquired" - query_str, args, context = self._compile(query) + query_str, args, result_columns, context = self._compile(query) + column_maps = create_column_maps(result_columns) + dialect = self._dialect async with self._connection.execute(query_str, args) as cursor: row = await cursor.fetchone() if row is None: return None metadata = CursorResultMetaData(context, cursor.description) - return Row( + row = Row( metadata, metadata._processors, metadata._keymap, - Row._default_key_style, row, ) + return Record(row, result_columns, dialect, column_maps) async def execute(self, query: ClauseElement) -> typing.Any: assert self._connection is not None, "Connection is not acquired" - query_str, args, context = self._compile(query) + query_str, args, result_columns, context = self._compile(query) async with self._connection.cursor() as cursor: await cursor.execute(query_str, args) if cursor.lastrowid == 0: @@ -144,34 +146,37 @@ async def iterate( self, query: ClauseElement ) -> typing.AsyncGenerator[typing.Any, None]: assert self._connection is not None, "Connection is not acquired" - query_str, args, context = self._compile(query) + query_str, args, result_columns, context = self._compile(query) + column_maps = create_column_maps(result_columns) + dialect = self._dialect + async with self._connection.execute(query_str, args) as cursor: metadata = CursorResultMetaData(context, cursor.description) async for row in cursor: - yield Row( + record = Row( metadata, metadata._processors, metadata._keymap, - Row._default_key_style, row, ) + yield Record(record, result_columns, dialect, column_maps) def transaction(self) -> TransactionBackend: return SQLiteTransaction(self) - def _compile( - self, query: ClauseElement - ) -> typing.Tuple[str, list, CompilationContext]: + def _compile(self, query: ClauseElement) -> typing.Tuple[str, list, tuple]: compiled = query.compile( dialect=self._dialect, compile_kwargs={"render_postcompile": True} ) - execution_context = self._dialect.execution_ctx_cls() execution_context.dialect = self._dialect args = [] + result_map = None if not isinstance(query, DDLElement): + compiled_params = sorted(compiled.params.items()) + params = compiled.construct_params() for key in compiled.positiontup: raw_val = params[key] @@ -189,11 +194,20 @@ def _compile( compiled._loose_column_name_matching, ) - query_message = compiled.string.replace(" \n", " ").replace("\n", " ") + mapping = { + key: "$" + str(i) for i, (key, _) in enumerate(compiled_params, start=1) + } + compiled_query = compiled.string % mapping + result_map = compiled._result_columns + + else: + compiled_query = compiled.string + + query_message = compiled_query.replace(" \n", " ").replace("\n", " ") logger.debug( "Query: %s Args: %s", query_message, repr(tuple(args)), extra=LOG_EXTRA ) - return compiled.string, args, CompilationContext(execution_context) + return compiled.string, args, result_map, CompilationContext(execution_context) @property def raw_connection(self) -> aiosqlite.core.Connection: diff --git a/databases/core.py b/databases/core.py index 8394ab5c..d55dd3c8 100644 --- a/databases/core.py +++ b/databases/core.py @@ -3,6 +3,7 @@ import functools import logging import typing +import weakref from contextvars import ContextVar from types import TracebackType from urllib.parse import SplitResult, parse_qsl, unquote, urlsplit @@ -11,7 +12,7 @@ from sqlalchemy.sql import ClauseElement from databases.importer import import_from_string -from databases.interfaces import DatabaseBackend, Record +from databases.interfaces import DatabaseBackend, Record, TransactionBackend try: # pragma: no cover import click @@ -35,6 +36,11 @@ logger = logging.getLogger("databases") +_ACTIVE_TRANSACTIONS: ContextVar[ + typing.Optional["weakref.WeakKeyDictionary['Transaction', 'TransactionBackend']"] +] = ContextVar("databases:active_transactions", default=None) + + class Database: SUPPORTED_BACKENDS = { "postgresql": "databases.backends.postgres:PostgresBackend", @@ -45,6 +51,8 @@ class Database: "sqlite": "databases.backends.sqlite:SQLiteBackend", } + _connection_map: "weakref.WeakKeyDictionary[asyncio.Task, 'Connection']" + def __init__( self, url: typing.Union[str, "DatabaseURL"], @@ -55,6 +63,7 @@ def __init__( self.url = DatabaseURL(url) self.options = options self.is_connected = False + self._connection_map = weakref.WeakKeyDictionary() self._force_rollback = force_rollback @@ -63,14 +72,35 @@ def __init__( assert issubclass(backend_cls, DatabaseBackend) self._backend = backend_cls(self.url, **self.options) - # Connections are stored as task-local state. - self._connection_context: ContextVar = ContextVar("connection_context") - # When `force_rollback=True` is used, we use a single global # connection, within a transaction that always rolls back. self._global_connection: typing.Optional[Connection] = None self._global_transaction: typing.Optional[Transaction] = None + @property + def _current_task(self) -> asyncio.Task: + task = asyncio.current_task() + if not task: + raise RuntimeError("No currently active asyncio.Task found") + return task + + @property + def _connection(self) -> typing.Optional["Connection"]: + return self._connection_map.get(self._current_task) + + @_connection.setter + def _connection( + self, connection: typing.Optional["Connection"] + ) -> typing.Optional["Connection"]: + task = self._current_task + + if connection is None: + self._connection_map.pop(task, None) + else: + self._connection_map[task] = connection + + return self._connection + async def connect(self) -> None: """ Establish the connection pool. @@ -89,7 +119,7 @@ async def connect(self) -> None: assert self._global_connection is None assert self._global_transaction is None - self._global_connection = Connection(self._backend) + self._global_connection = Connection(self, self._backend) self._global_transaction = self._global_connection.transaction( force_rollback=True ) @@ -113,7 +143,7 @@ async def disconnect(self) -> None: self._global_transaction = None self._global_connection = None else: - self._connection_context = ContextVar("connection_context") + self._connection = None await self._backend.disconnect() logger.info( @@ -187,12 +217,10 @@ def connection(self) -> "Connection": if self._global_connection is not None: return self._global_connection - try: - return self._connection_context.get() - except LookupError: - connection = Connection(self._backend) - self._connection_context.set(connection) - return connection + if not self._connection: + self._connection = Connection(self, self._backend) + + return self._connection def transaction( self, *, force_rollback: bool = False, **kwargs: typing.Any @@ -215,7 +243,8 @@ def _get_backend(self) -> str: class Connection: - def __init__(self, backend: DatabaseBackend) -> None: + def __init__(self, database: Database, backend: DatabaseBackend) -> None: + self._database = database self._backend = backend self._connection_lock = asyncio.Lock() @@ -249,6 +278,7 @@ async def __aexit__( self._connection_counter -= 1 if self._connection_counter == 0: await self._connection.release() + self._database._connection = None async def fetch_all( self, @@ -326,7 +356,7 @@ def _build_query( return query.bindparams(**values) if values is not None else query elif values: - return query.values(**values) + return query.values(**values) # type: ignore return query @@ -345,6 +375,37 @@ def __init__( self._force_rollback = force_rollback self._extra_options = kwargs + @property + def _connection(self) -> "Connection": + # Returns the same connection if called multiple times + return self._connection_callable() + + @property + def _transaction(self) -> typing.Optional["TransactionBackend"]: + transactions = _ACTIVE_TRANSACTIONS.get() + if transactions is None: + return None + + return transactions.get(self, None) + + @_transaction.setter + def _transaction( + self, transaction: typing.Optional["TransactionBackend"] + ) -> typing.Optional["TransactionBackend"]: + transactions = _ACTIVE_TRANSACTIONS.get() + if transactions is None: + transactions = weakref.WeakKeyDictionary() + else: + transactions = transactions.copy() + + if transaction is None: + transactions.pop(self, None) + else: + transactions[self] = transaction + + _ACTIVE_TRANSACTIONS.set(transactions) + return transactions.get(self, None) + async def __aenter__(self) -> "Transaction": """ Called when entering `async with database.transaction()` @@ -385,7 +446,6 @@ async def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any: return wrapper # type: ignore async def start(self) -> "Transaction": - self._connection = self._connection_callable() self._transaction = self._connection._connection.transaction() async with self._connection._transaction_lock: @@ -401,15 +461,19 @@ async def commit(self) -> None: async with self._connection._transaction_lock: assert self._connection._transaction_stack[-1] is self self._connection._transaction_stack.pop() + assert self._transaction is not None await self._transaction.commit() await self._connection.__aexit__() + self._transaction = None async def rollback(self) -> None: async with self._connection._transaction_lock: assert self._connection._transaction_stack[-1] is self self._connection._transaction_stack.pop() + assert self._transaction is not None await self._transaction.rollback() await self._connection.__aexit__() + self._transaction = None class _EmptyNetloc(str): diff --git a/docs/connections_and_transactions.md b/docs/connections_and_transactions.md index aa45537d..2b8f6b68 100644 --- a/docs/connections_and_transactions.md +++ b/docs/connections_and_transactions.md @@ -7,14 +7,14 @@ that transparently handles the use of either transactions or savepoints. ## Connecting and disconnecting -You can control the database connect/disconnect, by using it as a async context manager. +You can control the database connection pool with an async context manager: ```python async with Database(DATABASE_URL) as database: ... ``` -Or by using explicit connection and disconnection: +Or by using the explicit `.connect()` and `.disconnect()` methods: ```python database = Database(DATABASE_URL) @@ -23,6 +23,8 @@ await database.connect() await database.disconnect() ``` +Connections within this connection pool are acquired for each new `asyncio.Task`. + If you're integrating against a web framework, then you'll probably want to hook into framework startup or shutdown events. For example, with [Starlette][starlette] you would use the following: @@ -67,6 +69,7 @@ A transaction can be acquired from the database connection pool: async with database.transaction(): ... ``` + It can also be acquired from a specific database connection: ```python @@ -95,8 +98,51 @@ async def create_users(request): ... ``` -Transaction blocks are managed as task-local state. Nested transactions -are fully supported, and are implemented using database savepoints. +Transaction state is tied to the connection used in the currently executing asynchronous task. +If you would like to influence an active transaction from another task, the connection must be +shared. This state is _inherited_ by tasks that share the same connection: + +```python +async def add_excitement(connnection: databases.core.Connection, id: int): + await connection.execute( + "UPDATE notes SET text = CONCAT(text, '!!!') WHERE id = :id", + {"id": id} + ) + + +async with Database(database_url) as database: + async with database.transaction(): + # This note won't exist until the transaction closes... + await database.execute( + "INSERT INTO notes(id, text) values (1, 'databases is cool')" + ) + # ...but child tasks can use this connection now! + await asyncio.create_task(add_excitement(database.connection(), id=1)) + + await database.fetch_val("SELECT text FROM notes WHERE id=1") + # ^ returns: "databases is cool!!!" +``` + +Nested transactions are fully supported, and are implemented using database savepoints: + +```python +async with databases.Database(database_url) as db: + async with db.transaction() as outer: + # Do something in the outer transaction + ... + + # Suppress to prevent influence on the outer transaction + with contextlib.suppress(ValueError): + async with db.transaction(): + # Do something in the inner transaction + ... + + raise ValueError('Abort the inner transaction') + + # Observe the results of the outer transaction, + # without effects from the inner transaction. + await db.fetch_all('SELECT * FROM ...') +``` Transaction isolation-level can be specified if the driver backend supports that: diff --git a/docs/index.md b/docs/index.md index b18de817..7c3cebf2 100644 --- a/docs/index.md +++ b/docs/index.md @@ -17,7 +17,7 @@ expression language, and provides support for PostgreSQL, MySQL, and SQLite. Databases is suitable for integrating against any async Web framework, such as [Starlette][starlette], [Sanic][sanic], [Responder][responder], [Quart][quart], [aiohttp][aiohttp], [Tornado][tornado], or [FastAPI][fastapi]. -**Requirements**: Python 3.7+ +**Requirements**: Python 3.8+ --- @@ -83,7 +83,7 @@ values = [ ] await database.execute_many(query=query, values=values) -# Run a database query. +# Run a database query. query = "SELECT * FROM HighScores" rows = await database.fetch_all(query=query) print('High Scores:', rows) diff --git a/requirements.txt b/requirements.txt index 0699d3cc..8b05a46e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,25 +1,26 @@ -e . # Async database drivers -asyncmy==0.2.5 -aiomysql==0.1.1 -aiopg==1.3.4 -aiosqlite==0.17.0 -asyncpg==0.26.0 +asyncmy==0.2.9 +aiomysql==0.2.0 +aiopg==1.4.0 +aiosqlite==0.20.0 +asyncpg==0.29.0 # Sync database drivers for standard tooling around setup/teardown/migrations. -psycopg2-binary==2.9.3 -pymysql==1.0.2 +psycopg==3.1.18 +pymysql==1.1.0 # Testing autoflake==1.4 black==22.6.0 +httpx==0.24.1 isort==5.10.1 mypy==0.971 pytest==7.1.2 pytest-cov==3.0.0 -starlette==0.20.4 -requests==2.28.1 +starlette==0.36.2 +requests==2.31.0 # Documentation mkdocs==1.3.1 @@ -29,3 +30,4 @@ mkautodoc==0.1.0 # Packaging twine==4.0.1 wheel==0.38.1 +setuptools==69.0.3 diff --git a/scripts/clean b/scripts/clean index f01cc831..d7388629 100755 --- a/scripts/clean +++ b/scripts/clean @@ -9,6 +9,12 @@ fi if [ -d 'databases.egg-info' ] ; then rm -r databases.egg-info fi +if [ -d '.mypy_cache' ] ; then + rm -r .mypy_cache +fi +if [ -d '.pytest_cache' ] ; then + rm -r .pytest_cache +fi find databases -type f -name "*.py[co]" -delete find databases -type d -name __pycache__ -delete diff --git a/setup.cfg b/setup.cfg index da1831fd..b4182c83 100644 --- a/setup.cfg +++ b/setup.cfg @@ -2,6 +2,11 @@ disallow_untyped_defs = True ignore_missing_imports = True no_implicit_optional = True +disallow_any_generics = false +disallow_untyped_decorators = true +implicit_reexport = true +disallow_incomplete_defs = true +exclude = databases/backends [tool:isort] profile = black diff --git a/setup.py b/setup.py index 3725cab9..41c0c584 100644 --- a/setup.py +++ b/setup.py @@ -37,7 +37,7 @@ def get_packages(package): setup( name="databases", version=get_version("databases"), - python_requires=">=3.7", + python_requires=">=3.8", url="https://github.com/encode/databases", license="BSD", description="Async database support for Python.", @@ -47,7 +47,7 @@ def get_packages(package): author_email="tom@tomchristie.com", packages=get_packages("databases"), package_data={"databases": ["py.typed"]}, - install_requires=["sqlalchemy>=1.4.42,<1.5"], + install_requires=["sqlalchemy>=2.0.7"], extras_require={ "postgresql": ["asyncpg"], "asyncpg": ["asyncpg"], @@ -66,10 +66,11 @@ def get_packages(package): "Operating System :: OS Independent", "Topic :: Internet :: WWW/HTTP", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3 :: Only", ], zip_safe=False, diff --git a/tests/test_connection_options.py b/tests/test_connection_options.py index e6fe6849..81ce2ac7 100644 --- a/tests/test_connection_options.py +++ b/tests/test_connection_options.py @@ -46,12 +46,24 @@ def test_postgres_ssl(): assert kwargs == {"ssl": True} +def test_postgres_ssl_verify_full(): + backend = PostgresBackend("postgres://localhost/database?ssl=verify-full") + kwargs = backend._get_connection_kwargs() + assert kwargs == {"ssl": "verify-full"} + + def test_postgres_explicit_ssl(): backend = PostgresBackend("postgres://localhost/database", ssl=True) kwargs = backend._get_connection_kwargs() assert kwargs == {"ssl": True} +def test_postgres_explicit_ssl_verify_full(): + backend = PostgresBackend("postgres://localhost/database", ssl="verify-full") + kwargs = backend._get_connection_kwargs() + assert kwargs == {"ssl": "verify-full"} + + def test_postgres_no_extra_options(): backend = PostgresBackend("postgres://localhost/database") kwargs = backend._get_connection_kwargs() @@ -77,6 +89,15 @@ def test_mysql_pool_size(): assert kwargs == {"minsize": 1, "maxsize": 20} +@pytest.mark.skipif(sys.version_info >= (3, 10), reason="requires python3.9 or lower") +def test_mysql_unix_socket(): + backend = MySQLBackend( + "mysql+aiomysql://username:password@/testsuite?unix_socket=/tmp/mysqld/mysqld.sock" + ) + kwargs = backend._get_connection_kwargs() + assert kwargs == {"unix_socket": "/tmp/mysqld/mysqld.sock"} + + @pytest.mark.skipif(sys.version_info >= (3, 10), reason="requires python3.9 or lower") def test_mysql_explicit_pool_size(): backend = MySQLBackend("mysql://localhost/database", min_size=1, max_size=20) @@ -114,6 +135,15 @@ def test_asyncmy_pool_size(): assert kwargs == {"minsize": 1, "maxsize": 20} +@pytest.mark.skipif(sys.version_info < (3, 7), reason="requires python3.7 or higher") +def test_asyncmy_unix_socket(): + backend = AsyncMyBackend( + "mysql+asyncmy://username:password@/testsuite?unix_socket=/tmp/mysqld/mysqld.sock" + ) + kwargs = backend._get_connection_kwargs() + assert kwargs == {"unix_socket": "/tmp/mysqld/mysqld.sock"} + + @pytest.mark.skipif(sys.version_info < (3, 7), reason="requires python3.7 or higher") def test_asyncmy_explicit_pool_size(): backend = AsyncMyBackend("mysql://localhost/database", min_size=1, max_size=20) diff --git a/tests/test_database_url.py b/tests/test_database_url.py index 9eea4fa6..7aa15926 100644 --- a/tests/test_database_url.py +++ b/tests/test_database_url.py @@ -69,6 +69,11 @@ def test_database_url_options(): u = DatabaseURL("postgresql://localhost/mydatabase?pool_size=20&ssl=true") assert u.options == {"pool_size": "20", "ssl": "true"} + u = DatabaseURL( + "mysql+asyncmy://username:password@/testsuite?unix_socket=/tmp/mysqld/mysqld.sock" + ) + assert u.options == {"unix_socket": "/tmp/mysqld/mysqld.sock"} + def test_replace_database_url_components(): u = DatabaseURL("postgresql://localhost/mydatabase") diff --git a/tests/test_databases.py b/tests/test_databases.py index a7545e31..d9d9e4d6 100644 --- a/tests/test_databases.py +++ b/tests/test_databases.py @@ -1,9 +1,13 @@ import asyncio import datetime import decimal +import enum import functools +import gc +import itertools import os -import re +import sqlite3 +from typing import MutableMapping from unittest.mock import MagicMock, patch import pytest @@ -52,6 +56,47 @@ def process_result_value(self, value, dialect): sqlalchemy.Column("published", sqlalchemy.DateTime), ) +# Used to test Date +events = sqlalchemy.Table( + "events", + metadata, + sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True), + sqlalchemy.Column("date", sqlalchemy.Date), +) + + +# Used to test Time +daily_schedule = sqlalchemy.Table( + "daily_schedule", + metadata, + sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True), + sqlalchemy.Column("time", sqlalchemy.Time), +) + + +class TshirtSize(enum.Enum): + SMALL = "SMALL" + MEDIUM = "MEDIUM" + LARGE = "LARGE" + XL = "XL" + + +class TshirtColor(enum.Enum): + BLUE = 0 + GREEN = 1 + YELLOW = 2 + RED = 3 + + +# Used to test Enum +tshirt_size = sqlalchemy.Table( + "tshirt_size", + metadata, + sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True), + sqlalchemy.Column("size", sqlalchemy.Enum(TshirtSize)), + sqlalchemy.Column("color", sqlalchemy.Enum(TshirtColor)), +) + # Used to test JSON session = sqlalchemy.Table( "session", @@ -111,6 +156,9 @@ def create_test_database(): engine = sqlalchemy.create_engine(url) metadata.drop_all(engine) + # Run garbage collection to ensure any in-memory databases are dropped + gc.collect() + def async_adapter(wrapped_func): """ @@ -167,24 +215,24 @@ async def test_queries(database_url): assert result["completed"] == True # fetch_val() - query = sqlalchemy.sql.select([notes.c.text]) + query = sqlalchemy.sql.select(*[notes.c.text]) result = await database.fetch_val(query=query) assert result == "example1" # fetch_val() with no rows - query = sqlalchemy.sql.select([notes.c.text]).where( + query = sqlalchemy.sql.select(*[notes.c.text]).where( notes.c.text == "impossible" ) result = await database.fetch_val(query=query) assert result is None # fetch_val() with a different column - query = sqlalchemy.sql.select([notes.c.id, notes.c.text]) + query = sqlalchemy.sql.select(*[notes.c.id, notes.c.text]) result = await database.fetch_val(query=query, column=1) assert result == "example1" # row access (needed to maintain test coverage for Record.__getitem__ in postgres backend) - query = sqlalchemy.sql.select([notes.c.text]) + query = sqlalchemy.sql.select(*[notes.c.text]) result = await database.fetch_one(query=query) assert result["text"] == "example1" assert result[0] == "example1" @@ -244,6 +292,7 @@ async def test_queries_raw(database_url): query = "SELECT completed FROM notes WHERE text = :text" result = await database.fetch_val(query=query, values={"text": "example1"}) assert result == True + query = "SELECT * FROM notes WHERE text = :text" result = await database.fetch_val( query=query, values={"text": "example1"}, column="completed" @@ -354,7 +403,7 @@ async def test_results_support_column_reference(database_url): await database.execute(query, values) # fetch_all() - query = sqlalchemy.select([articles, custom_date]) + query = sqlalchemy.select(*[articles, custom_date]) results = await database.fetch_all(query=query) assert len(results) == 1 assert results[0][articles.c.title] == "Hello, world Article" @@ -477,6 +526,254 @@ async def test_transaction_commit(database_url): assert len(results) == 1 +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_transaction_context_child_task_inheritance(database_url): + """ + Ensure that transactions are inherited by child tasks. + """ + async with Database(database_url) as database: + + async def check_transaction(transaction, active_transaction): + # Should have inherited the same transaction backend from the parent task + assert transaction._transaction is active_transaction + + async with database.transaction() as transaction: + await asyncio.create_task( + check_transaction(transaction, transaction._transaction) + ) + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_transaction_context_child_task_inheritance_example(database_url): + """ + Ensure that child tasks may influence inherited transactions. + """ + # This is an practical example of the above test. + async with Database(database_url) as database: + async with database.transaction(): + # Create a note + await database.execute( + notes.insert().values(id=1, text="setup", completed=True) + ) + + # Change the note from the same task + await database.execute( + notes.update().where(notes.c.id == 1).values(text="prior") + ) + + # Confirm the change + result = await database.fetch_one(notes.select().where(notes.c.id == 1)) + assert result.text == "prior" + + async def run_update_from_child_task(connection): + # Change the note from a child task + await connection.execute( + notes.update().where(notes.c.id == 1).values(text="test") + ) + + await asyncio.create_task(run_update_from_child_task(database.connection())) + + # Confirm the child's change + result = await database.fetch_one(notes.select().where(notes.c.id == 1)) + assert result.text == "test" + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_transaction_context_sibling_task_isolation(database_url): + """ + Ensure that transactions are isolated between sibling tasks. + """ + start = asyncio.Event() + end = asyncio.Event() + + async with Database(database_url) as database: + + async def check_transaction(transaction): + await start.wait() + # Parent task is now in a transaction, we should not + # see its transaction backend since this task was + # _started_ in a context where no transaction was active. + assert transaction._transaction is None + end.set() + + transaction = database.transaction() + assert transaction._transaction is None + task = asyncio.create_task(check_transaction(transaction)) + + async with transaction: + start.set() + assert transaction._transaction is not None + await end.wait() + + # Cleanup for "Task not awaited" warning + await task + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_transaction_context_sibling_task_isolation_example(database_url): + """ + Ensure that transactions are running in sibling tasks are isolated from eachother. + """ + # This is an practical example of the above test. + setup = asyncio.Event() + done = asyncio.Event() + + async def tx1(connection): + async with connection.transaction(): + await db.execute( + notes.insert(), values={"id": 1, "text": "tx1", "completed": False} + ) + setup.set() + await done.wait() + + async def tx2(connection): + async with connection.transaction(): + await setup.wait() + result = await db.fetch_all(notes.select()) + assert result == [], result + done.set() + + async with Database(database_url) as db: + await asyncio.gather(tx1(db), tx2(db)) + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_connection_cleanup_contextmanager(database_url): + """ + Ensure that task connections are not persisted unecessarily. + """ + + ready = asyncio.Event() + done = asyncio.Event() + + async def check_child_connection(database: Database): + async with database.connection(): + ready.set() + await done.wait() + + async with Database(database_url) as database: + # Should have a connection in this task + # .connect is lazy, it doesn't create a Connection, but .connection does + connection = database.connection() + assert isinstance(database._connection_map, MutableMapping) + assert database._connection_map.get(asyncio.current_task()) is connection + + # Create a child task and see if it registers a connection + task = asyncio.create_task(check_child_connection(database)) + await ready.wait() + assert database._connection_map.get(task) is not None + assert database._connection_map.get(task) is not connection + + # Let the child task finish, and see if it cleaned up + done.set() + await task + # This is normal exit logic cleanup, the WeakKeyDictionary + # shouldn't have cleaned up yet since the task is still referenced + assert task not in database._connection_map + + # Context manager closes, all open connections are removed + assert isinstance(database._connection_map, MutableMapping) + assert len(database._connection_map) == 0 + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_connection_cleanup_garbagecollector(database_url): + """ + Ensure that connections for tasks are not persisted unecessarily, even + if exit handlers are not called. + """ + database = Database(database_url) + await database.connect() + + created = asyncio.Event() + + async def check_child_connection(database: Database): + # neither .disconnect nor .__aexit__ are called before deleting this task + database.connection() + created.set() + + task = asyncio.create_task(check_child_connection(database)) + await created.wait() + assert task in database._connection_map + await task + del task + gc.collect() + + # Should not have a connection for the task anymore + assert len(database._connection_map) == 0 + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_transaction_context_cleanup_contextmanager(database_url): + """ + Ensure that contextvar transactions are not persisted unecessarily. + """ + from databases.core import _ACTIVE_TRANSACTIONS + + assert _ACTIVE_TRANSACTIONS.get() is None + + async with Database(database_url) as database: + async with database.transaction() as transaction: + open_transactions = _ACTIVE_TRANSACTIONS.get() + assert isinstance(open_transactions, MutableMapping) + assert open_transactions.get(transaction) is transaction._transaction + + # Context manager closes, open_transactions is cleaned up + open_transactions = _ACTIVE_TRANSACTIONS.get() + assert isinstance(open_transactions, MutableMapping) + assert open_transactions.get(transaction, None) is None + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_transaction_context_cleanup_garbagecollector(database_url): + """ + Ensure that contextvar transactions are not persisted unecessarily, even + if exit handlers are not called. + + This test should be an XFAIL, but cannot be due to the way that is hangs + during teardown. + """ + from databases.core import _ACTIVE_TRANSACTIONS + + assert _ACTIVE_TRANSACTIONS.get() is None + + async with Database(database_url) as database: + transaction = database.transaction() + await transaction.start() + + # Should be tracking the transaction + open_transactions = _ACTIVE_TRANSACTIONS.get() + assert isinstance(open_transactions, MutableMapping) + assert open_transactions.get(transaction) is transaction._transaction + + # neither .commit, .rollback, nor .__aexit__ are called + del transaction + gc.collect() + + # TODO(zevisert,review): Could skip instead of using the logic below + # A strong reference to the transaction is kept alive by the connection's + # ._transaction_stack, so it is still be tracked at this point. + assert len(open_transactions) == 1 + + # If that were magically cleared, the transaction would be cleaned up, + # but as it stands this always causes a hang during teardown at + # `Database(...).disconnect()` if the transaction is not closed. + transaction = database.connection()._transaction_stack[-1] + await transaction.rollback() + del transaction + + # Now with the transaction rolled-back, it should be cleaned up. + assert len(open_transactions) == 0 + + @pytest.mark.parametrize("database_url", DATABASE_URLS) @async_adapter async def test_transaction_commit_serializable(database_url): @@ -498,6 +795,7 @@ def insert_independently(): query = notes.insert().values(text="example1", completed=True) conn.execute(query) + conn.close() def delete_independently(): engine = sqlalchemy.create_engine(str(database_url)) @@ -505,6 +803,7 @@ def delete_independently(): query = notes.delete() conn.execute(query) + conn.close() async with Database(database_url) as database: async with database.transaction(force_rollback=True, isolation="serializable"): @@ -609,17 +908,44 @@ async def insert_data(raise_exception): with pytest.raises(RuntimeError): await insert_data(raise_exception=True) - query = notes.select() - results = await database.fetch_all(query=query) + results = await database.fetch_all(query=notes.select()) assert len(results) == 0 await insert_data(raise_exception=False) - query = notes.select() - results = await database.fetch_all(query=query) + results = await database.fetch_all(query=notes.select()) assert len(results) == 1 +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_transaction_decorator_concurrent(database_url): + """ + Ensure that @database.transaction() can be called concurrently. + """ + + database = Database(database_url) + + @database.transaction() + async def insert_data(): + await database.execute( + query=notes.insert().values(text="example", completed=True) + ) + + async with database: + await asyncio.gather( + insert_data(), + insert_data(), + insert_data(), + insert_data(), + insert_data(), + insert_data(), + ) + + results = await database.fetch_all(query=notes.select()) + assert len(results) == 6 + + @pytest.mark.parametrize("database_url", DATABASE_URLS) @async_adapter async def test_datetime_field(database_url): @@ -644,6 +970,52 @@ async def test_datetime_field(database_url): assert results[0]["published"] == now +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_date_field(database_url): + """ + Test Date columns, to ensure records are coerced to/from proper Python types. + """ + + async with Database(database_url) as database: + async with database.transaction(force_rollback=True): + now = datetime.date.today() + + # execute() + query = events.insert() + values = {"date": now} + await database.execute(query, values) + + # fetch_all() + query = events.select() + results = await database.fetch_all(query=query) + assert len(results) == 1 + assert results[0]["date"] == now + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_time_field(database_url): + """ + Test Time columns, to ensure records are coerced to/from proper Python types. + """ + + async with Database(database_url) as database: + async with database.transaction(force_rollback=True): + now = datetime.datetime.now().time().replace(microsecond=0) + + # execute() + query = daily_schedule.insert() + values = {"time": now} + await database.execute(query, values) + + # fetch_all() + query = daily_schedule.select() + results = await database.fetch_all(query=query) + assert len(results) == 1 + assert results[0]["time"] == now + + @pytest.mark.parametrize("database_url", DATABASE_URLS) @async_adapter async def test_decimal_field(database_url): @@ -673,7 +1045,32 @@ async def test_decimal_field(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) @async_adapter -async def test_json_field(database_url): +async def test_enum_field(database_url): + """ + Test enum columns, to ensure correct cross-database support. + """ + + async with Database(database_url) as database: + async with database.transaction(force_rollback=True): + # execute() + size = TshirtSize.SMALL + color = TshirtColor.GREEN + values = {"size": size, "color": color} + query = tshirt_size.insert() + await database.execute(query, values) + + # fetch_all() + query = tshirt_size.select() + results = await database.fetch_all(query=query) + + assert len(results) == 1 + assert results[0]["size"] == size + assert results[0]["color"] == color + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_json_dict_field(database_url): """ Test JSON columns, to ensure correct cross-database support. """ @@ -689,10 +1086,34 @@ async def test_json_field(database_url): # fetch_all() query = session.select() results = await database.fetch_all(query=query) + assert len(results) == 1 assert results[0]["data"] == {"text": "hello", "boolean": True, "int": 1} +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_json_list_field(database_url): + """ + Test JSON columns, to ensure correct cross-database support. + """ + + async with Database(database_url) as database: + async with database.transaction(force_rollback=True): + # execute() + data = ["lemon", "raspberry", "lime", "pumice"] + values = {"data": data} + query = session.insert() + await database.execute(query, values) + + # fetch_all() + query = session.select() + results = await database.fetch_all(query=query) + + assert len(results) == 1 + assert results[0]["data"] == ["lemon", "raspberry", "lime", "pumice"] + + @pytest.mark.parametrize("database_url", DATABASE_URLS) @async_adapter async def test_custom_field(database_url): @@ -789,15 +1210,16 @@ async def test_connect_and_disconnect(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) @async_adapter -async def test_connection_context(database_url): - """ - Test connection contexts are task-local. - """ +async def test_connection_context_same_task(database_url): async with Database(database_url) as database: async with database.connection() as connection_1: async with database.connection() as connection_2: assert connection_1 is connection_2 + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_connection_context_multiple_sibling_tasks(database_url): async with Database(database_url) as database: connection_1 = None connection_2 = None @@ -817,9 +1239,8 @@ async def get_connection_2(): connection_2 = connection await test_complete.wait() - loop = asyncio.get_event_loop() - task_1 = loop.create_task(get_connection_1()) - task_2 = loop.create_task(get_connection_2()) + task_1 = asyncio.create_task(get_connection_1()) + task_2 = asyncio.create_task(get_connection_2()) while connection_1 is None or connection_2 is None: await asyncio.sleep(0.000001) assert connection_1 is not connection_2 @@ -828,6 +1249,61 @@ async def get_connection_2(): await task_2 +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_connection_context_multiple_tasks(database_url): + async with Database(database_url) as database: + parent_connection = database.connection() + connection_1 = None + connection_2 = None + task_1_ready = asyncio.Event() + task_2_ready = asyncio.Event() + test_complete = asyncio.Event() + + async def get_connection_1(): + nonlocal connection_1 + + async with database.connection() as connection: + connection_1 = connection + task_1_ready.set() + await test_complete.wait() + + async def get_connection_2(): + nonlocal connection_2 + + async with database.connection() as connection: + connection_2 = connection + task_2_ready.set() + await test_complete.wait() + + task_1 = asyncio.create_task(get_connection_1()) + task_2 = asyncio.create_task(get_connection_2()) + await task_1_ready.wait() + await task_2_ready.wait() + + assert connection_1 is not parent_connection + assert connection_2 is not parent_connection + assert connection_1 is not connection_2 + + test_complete.set() + await task_1 + await task_2 + + +@pytest.mark.parametrize( + "database_url1,database_url2", + ( + pytest.param(db1, db2, id=f"{db1} | {db2}") + for (db1, db2) in itertools.combinations(DATABASE_URLS, 2) + ), +) +@async_adapter +async def test_connection_context_multiple_databases(database_url1, database_url2): + async with Database(database_url1) as database1: + async with Database(database_url2) as database2: + assert database1.connection() is not database2.connection() + + @pytest.mark.parametrize("database_url", DATABASE_URLS) @async_adapter async def test_connection_context_with_raw_connection(database_url): @@ -961,16 +1437,59 @@ async def test_database_url_interface(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) @async_adapter async def test_concurrent_access_on_single_connection(database_url): - database_url = DatabaseURL(database_url) - if database_url.dialect != "postgresql": - pytest.skip("Test requires `pg_sleep()`") - async with Database(database_url, force_rollback=True) as database: async def db_lookup(): - await database.fetch_one("SELECT pg_sleep(1)") + await database.fetch_one("SELECT 1 AS value") + + await asyncio.gather( + db_lookup(), + db_lookup(), + ) + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_concurrent_transactions_on_single_connection(database_url: str): + async with Database(database_url) as database: + + @database.transaction() + async def db_lookup(): + await database.fetch_one(query="SELECT 1 AS value") + + await asyncio.gather( + db_lookup(), + db_lookup(), + ) + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_concurrent_tasks_on_single_connection(database_url: str): + async with Database(database_url) as database: + + async def db_lookup(): + await database.fetch_one(query="SELECT 1 AS value") - await asyncio.gather(db_lookup(), db_lookup()) + await asyncio.gather( + asyncio.create_task(db_lookup()), + asyncio.create_task(db_lookup()), + ) + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_concurrent_task_transactions_on_single_connection(database_url: str): + async with Database(database_url) as database: + + @database.transaction() + async def db_lookup(): + await database.fetch_one(query="SELECT 1 AS value") + + await asyncio.gather( + asyncio.create_task(db_lookup()), + asyncio.create_task(db_lookup()), + ) @pytest.mark.parametrize("database_url", DATABASE_URLS) @@ -1075,52 +1594,6 @@ async def test_column_names(database_url, select_query): assert results[0]["completed"] == True -@pytest.mark.parametrize("database_url", DATABASE_URLS) -@async_adapter -async def test_posgres_interface(database_url): - """ - Since SQLAlchemy 1.4, `Row.values()` is removed and `Row.keys()` is deprecated. - Custom postgres interface mimics more or less this behaviour by deprecating those - two methods - """ - database_url = DatabaseURL(database_url) - - if database_url.scheme not in ["postgresql", "postgresql+asyncpg"]: - pytest.skip("Test is only for asyncpg") - - async with Database(database_url) as database: - async with database.transaction(force_rollback=True): - query = notes.insert() - values = {"text": "example1", "completed": True} - await database.execute(query, values) - - query = notes.select() - result = await database.fetch_one(query=query) - - with pytest.warns( - DeprecationWarning, - match=re.escape( - "The `Row.keys()` method is deprecated to mimic SQLAlchemy behaviour, " - "use `Row._mapping.keys()` instead." - ), - ): - assert ( - list(result.keys()) - == [k for k in result] - == ["id", "text", "completed"] - ) - - with pytest.warns( - DeprecationWarning, - match=re.escape( - "The `Row.values()` method is deprecated to mimic SQLAlchemy behaviour, " - "use `Row._mapping.values()` instead." - ), - ): - # avoid checking `id` at index 0 since it may change depending on the launched tests - assert list(result.values())[1:] == ["example1", True] - - @pytest.mark.parametrize("database_url", DATABASE_URLS) @async_adapter async def test_postcompile_queries(database_url): @@ -1153,6 +1626,82 @@ async def test_result_named_access(database_url): assert result.completed is True +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_mapping_property_interface(database_url): + """ + Test that all connections implement interface with `_mapping` property + """ + async with Database(database_url) as database: + query = notes.select() + single_result = await database.fetch_one(query=query) + assert single_result._mapping["text"] == "example1" + assert single_result._mapping["completed"] is True + + list_result = await database.fetch_all(query=query) + assert list_result[0]._mapping["text"] == "example1" + assert list_result[0]._mapping["completed"] is True + + +@async_adapter +async def test_should_not_maintain_ref_when_no_cache_param(): + async with Database( + "sqlite:///file::memory:", + uri=True, + ) as database: + query = sqlalchemy.schema.CreateTable(notes) + await database.execute(query) + + query = notes.insert() + values = {"text": "example1", "completed": True} + with pytest.raises(sqlite3.OperationalError): + await database.execute(query, values) + + +@async_adapter +async def test_should_maintain_ref_when_cache_param(): + async with Database( + "sqlite:///file::memory:?cache=shared", + uri=True, + ) as database: + query = sqlalchemy.schema.CreateTable(notes) + await database.execute(query) + + query = notes.insert() + values = {"text": "example1", "completed": True} + await database.execute(query, values) + + query = notes.select().where(notes.c.text == "example1") + result = await database.fetch_one(query=query) + assert result.text == "example1" + assert result.completed is True + + +@async_adapter +async def test_should_remove_ref_on_disconnect(): + async with Database( + "sqlite:///file::memory:?cache=shared", + uri=True, + ) as database: + query = sqlalchemy.schema.CreateTable(notes) + await database.execute(query) + + query = notes.insert() + values = {"text": "example1", "completed": True} + await database.execute(query, values) + + # Run garbage collection to reset the database if we dropped the reference + gc.collect() + + async with Database( + "sqlite:///file::memory:?cache=shared", + uri=True, + ) as database: + query = notes.select() + with pytest.raises(sqlite3.OperationalError): + await database.fetch_all(query=query) + + @pytest.mark.parametrize("database_url", DATABASE_URLS) @async_adapter async def test_mapping_property_interface(database_url):