From 335e5efd5f759d4029dc4a423eae26a90c591193 Mon Sep 17 00:00:00 2001 From: Maxwell Muoto <41130755+max-muoto@users.noreply.github.com> Date: Sun, 23 Jun 2024 14:28:10 -0500 Subject: [PATCH 01/12] feat: Improve/fix sqlite aggregration protocols --- stdlib/sqlite3/dbapi2.pyi | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/stdlib/sqlite3/dbapi2.pyi b/stdlib/sqlite3/dbapi2.pyi index 3cb4b93e88fe..5a5c221a3ce7 100644 --- a/stdlib/sqlite3/dbapi2.pyi +++ b/stdlib/sqlite3/dbapi2.pyi @@ -4,13 +4,14 @@ from _typeshed import ReadableBuffer, StrOrBytesPath, SupportsLenAndGetItem, Unu from collections.abc import Callable, Generator, Iterable, Iterator, Mapping from datetime import date, datetime, time from types import TracebackType -from typing import Any, Literal, Protocol, SupportsIndex, TypeVar, final, overload -from typing_extensions import Self, TypeAlias +from typing import Any, Literal, Protocol, SupportsIndex, final, overload +from typing_extensions import Self, TypeAlias, TypeVar _T = TypeVar("_T") _ConnectionT = TypeVar("_ConnectionT", bound=Connection) _CursorT = TypeVar("_CursorT", bound=Cursor) _SqliteData: TypeAlias = str | ReadableBuffer | int | float | None +_SQLType = TypeVar("_SQLType", bound=_SqliteData, default=_SqliteData, covariant=True) # Data that is passed through adapters can be of any type accepted by an adapter. _AdaptedInputData: TypeAlias = _SqliteData | Any # The Mapping must really be a dict, but making it invariant is too annoying. @@ -312,27 +313,27 @@ else: def register_adapter(type: type[_T], caster: _Adapter[_T], /) -> None: ... def register_converter(name: str, converter: _Converter, /) -> None: ... -class _AggregateProtocol(Protocol): - def step(self, value: int, /) -> object: ... - def finalize(self) -> int: ... +class _AggregateProtocol(Protocol[_SQLType]): + def step(self, *args: Any) -> object: ... + def finalize(self) -> _SQLType: ... -class _SingleParamWindowAggregateClass(Protocol): +class _SingleParamWindowAggregateClass(Protocol[_SQLType]): def step(self, param: Any, /) -> object: ... def inverse(self, param: Any, /) -> object: ... - def value(self) -> _SqliteData: ... - def finalize(self) -> _SqliteData: ... + def value(self) -> _SQLType: ... + def finalize(self) -> _SQLType: ... -class _AnyParamWindowAggregateClass(Protocol): +class _AnyParamWindowAggregateClass(Protocol[_SQLType]): def step(self, *args: Any) -> object: ... def inverse(self, *args: Any) -> object: ... - def value(self) -> _SqliteData: ... - def finalize(self) -> _SqliteData: ... + def value(self) -> _SQLType: ... + def finalize(self) -> _SQLType: ... -class _WindowAggregateClass(Protocol): +class _WindowAggregateClass(Protocol[_SQLType]): step: Callable[..., object] inverse: Callable[..., object] - def value(self) -> _SqliteData: ... - def finalize(self) -> _SqliteData: ... + def value(self) -> _SQLType: ... + def finalize(self) -> _SQLType: ... class Connection: @property From c76c6438565799858647906321fc901cb8af5db4 Mon Sep 17 00:00:00 2001 From: Maxwell Muoto <41130755+max-muoto@users.noreply.github.com> Date: Sun, 23 Jun 2024 14:34:05 -0500 Subject: [PATCH 02/12] Remove default --- stdlib/sqlite3/dbapi2.pyi | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/stdlib/sqlite3/dbapi2.pyi b/stdlib/sqlite3/dbapi2.pyi index 5a5c221a3ce7..af715c5d41d3 100644 --- a/stdlib/sqlite3/dbapi2.pyi +++ b/stdlib/sqlite3/dbapi2.pyi @@ -4,14 +4,14 @@ from _typeshed import ReadableBuffer, StrOrBytesPath, SupportsLenAndGetItem, Unu from collections.abc import Callable, Generator, Iterable, Iterator, Mapping from datetime import date, datetime, time from types import TracebackType -from typing import Any, Literal, Protocol, SupportsIndex, final, overload -from typing_extensions import Self, TypeAlias, TypeVar +from typing import Any, Literal, Protocol, SupportsIndex, TypeVar, final, overload +from typing_extensions import Self, TypeAlias _T = TypeVar("_T") _ConnectionT = TypeVar("_ConnectionT", bound=Connection) _CursorT = TypeVar("_CursorT", bound=Cursor) _SqliteData: TypeAlias = str | ReadableBuffer | int | float | None -_SQLType = TypeVar("_SQLType", bound=_SqliteData, default=_SqliteData, covariant=True) +_SQLType = TypeVar("_SQLType", bound=_SqliteData, covariant=True) # Data that is passed through adapters can be of any type accepted by an adapter. _AdaptedInputData: TypeAlias = _SqliteData | Any # The Mapping must really be a dict, but making it invariant is too annoying. From 694f7b41551c83d0762f26bdb0d05475083999d6 Mon Sep 17 00:00:00 2001 From: Maxwell Muoto <41130755+max-muoto@users.noreply.github.com> Date: Sun, 23 Jun 2024 14:35:16 -0500 Subject: [PATCH 03/12] Add default back --- stdlib/sqlite3/dbapi2.pyi | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/stdlib/sqlite3/dbapi2.pyi b/stdlib/sqlite3/dbapi2.pyi index af715c5d41d3..5a5c221a3ce7 100644 --- a/stdlib/sqlite3/dbapi2.pyi +++ b/stdlib/sqlite3/dbapi2.pyi @@ -4,14 +4,14 @@ from _typeshed import ReadableBuffer, StrOrBytesPath, SupportsLenAndGetItem, Unu from collections.abc import Callable, Generator, Iterable, Iterator, Mapping from datetime import date, datetime, time from types import TracebackType -from typing import Any, Literal, Protocol, SupportsIndex, TypeVar, final, overload -from typing_extensions import Self, TypeAlias +from typing import Any, Literal, Protocol, SupportsIndex, final, overload +from typing_extensions import Self, TypeAlias, TypeVar _T = TypeVar("_T") _ConnectionT = TypeVar("_ConnectionT", bound=Connection) _CursorT = TypeVar("_CursorT", bound=Cursor) _SqliteData: TypeAlias = str | ReadableBuffer | int | float | None -_SQLType = TypeVar("_SQLType", bound=_SqliteData, covariant=True) +_SQLType = TypeVar("_SQLType", bound=_SqliteData, default=_SqliteData, covariant=True) # Data that is passed through adapters can be of any type accepted by an adapter. _AdaptedInputData: TypeAlias = _SqliteData | Any # The Mapping must really be a dict, but making it invariant is too annoying. From 6b9e3a11e45833f6e0f64e797ba3720c46e8d510 Mon Sep 17 00:00:00 2001 From: Maxwell Muoto <41130755+max-muoto@users.noreply.github.com> Date: Sun, 23 Jun 2024 16:02:12 -0500 Subject: [PATCH 04/12] Add test --- .../test_cases/sqlite3/check_aggregations.py | 74 +++++++++++++++++++ .../check_connection.py} | 0 stdlib/sqlite3/dbapi2.pyi | 42 +++++------ 3 files changed, 95 insertions(+), 21 deletions(-) create mode 100644 stdlib/@tests/test_cases/sqlite3/check_aggregations.py rename stdlib/@tests/test_cases/{check_sqlite3.py => sqlite3/check_connection.py} (100%) diff --git a/stdlib/@tests/test_cases/sqlite3/check_aggregations.py b/stdlib/@tests/test_cases/sqlite3/check_aggregations.py new file mode 100644 index 000000000000..c5ee0308f3eb --- /dev/null +++ b/stdlib/@tests/test_cases/sqlite3/check_aggregations.py @@ -0,0 +1,74 @@ +import sqlite3 + +class WindowSumInt: + def __init__(self) -> None: + self.count = 0 + + def step(self, param: int) -> None: + self.count += param + + def value(self) -> int: + return self.count + + def inverse(self, param: int) -> None: + self.count -= param + + def finalize(self) -> int: + return self.count + + +con = sqlite3.connect(":memory:") +cur = con.execute("CREATE TABLE test(x, y)") +values = [ + ("a", 4), + ("b", 5), + ("c", 3), + ("d", 8), + ("e", 1), +] +cur.executemany("INSERT INTO test VALUES(?, ?)", values) +con.create_window_function("sumint", 1, WindowSumInt) +con.create_aggregate("sumint", 1, WindowSumInt) +cur.execute(""" + SELECT x, sumint(y) OVER ( + ORDER BY x ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING + ) AS sum_y + FROM test ORDER BY x +""") +con.close() + + +def create_window_function() -> WindowSumInt: + return WindowSumInt() + + +# A callable should work as well. +con.create_window_function("sumint", 1, create_window_function) +con.create_aggregate("sumint", 1, create_window_function) + +# With num_args set to 1, the callable should not be called with more than one. + +class WindowSumIntMultiArgs: + def __init__(self) -> None: + self.count = 0 + + def step(self, arg_1: int, arg_2: int) -> None: + self.count += arg_1 + arg_2 + + def value(self) -> int: + return self.count + + def inverse(self, arg_1: int, arg_2: int) -> None: + self.count -= arg_1 + arg_2 + + def finalize(self) -> int: + return self.count + + +# This should fail because the callable is called with more than one argument. +con.create_window_function("sumint", 1, WindowSumIntMultiArgs) # type: ignore +con.create_aggregate("sumint", 1, WindowSumIntMultiArgs) # type: ignore + +# With num_args set to -1, this should work. +con.create_window_function("sumint", 2, WindowSumIntMultiArgs) +con.create_aggregate("sumint", 2, WindowSumIntMultiArgs) diff --git a/stdlib/@tests/test_cases/check_sqlite3.py b/stdlib/@tests/test_cases/sqlite3/check_connection.py similarity index 100% rename from stdlib/@tests/test_cases/check_sqlite3.py rename to stdlib/@tests/test_cases/sqlite3/check_connection.py diff --git a/stdlib/sqlite3/dbapi2.pyi b/stdlib/sqlite3/dbapi2.pyi index 5a5c221a3ce7..c9286e6b5076 100644 --- a/stdlib/sqlite3/dbapi2.pyi +++ b/stdlib/sqlite3/dbapi2.pyi @@ -4,14 +4,14 @@ from _typeshed import ReadableBuffer, StrOrBytesPath, SupportsLenAndGetItem, Unu from collections.abc import Callable, Generator, Iterable, Iterator, Mapping from datetime import date, datetime, time from types import TracebackType -from typing import Any, Literal, Protocol, SupportsIndex, final, overload -from typing_extensions import Self, TypeAlias, TypeVar +from typing import Any, Literal, Protocol, SupportsIndex, final, overload, TypeVar +from typing_extensions import Self, TypeAlias _T = TypeVar("_T") _ConnectionT = TypeVar("_ConnectionT", bound=Connection) _CursorT = TypeVar("_CursorT", bound=Cursor) _SqliteData: TypeAlias = str | ReadableBuffer | int | float | None -_SQLType = TypeVar("_SQLType", bound=_SqliteData, default=_SqliteData, covariant=True) +_SQLType = TypeVar("_SQLType", bound=_SqliteData) # Data that is passed through adapters can be of any type accepted by an adapter. _AdaptedInputData: TypeAlias = _SqliteData | Any # The Mapping must really be a dict, but making it invariant is too annoying. @@ -313,27 +313,26 @@ else: def register_adapter(type: type[_T], caster: _Adapter[_T], /) -> None: ... def register_converter(name: str, converter: _Converter, /) -> None: ... -class _AggregateProtocol(Protocol[_SQLType]): - def step(self, *args: Any) -> object: ... +class _SingleParamAggregateProtocol(Protocol[_SQLType]): + def step(self, param: _SQLType, /) -> object: ... + def finalize(self) -> _SQLType: ... + +class _AnyParamAggregateProtocol(Protocol[_SQLType]): + def step(self, *args: _SQLType) -> object: ... def finalize(self) -> _SQLType: ... class _SingleParamWindowAggregateClass(Protocol[_SQLType]): - def step(self, param: Any, /) -> object: ... - def inverse(self, param: Any, /) -> object: ... + def step(self, param: _SQLType, /) -> object: ... + def inverse(self, param: _SQLType, /) -> object: ... def value(self) -> _SQLType: ... def finalize(self) -> _SQLType: ... class _AnyParamWindowAggregateClass(Protocol[_SQLType]): - def step(self, *args: Any) -> object: ... - def inverse(self, *args: Any) -> object: ... + def step(self, *args: _SQLType) -> object: ... + def inverse(self, *args: _SQLType) -> object: ... def value(self) -> _SQLType: ... def finalize(self) -> _SQLType: ... -class _WindowAggregateClass(Protocol[_SQLType]): - step: Callable[..., object] - inverse: Callable[..., object] - def value(self) -> _SQLType: ... - def finalize(self) -> _SQLType: ... class Connection: @property @@ -399,22 +398,23 @@ class Connection: def blobopen(self, table: str, column: str, row: int, /, *, readonly: bool = False, name: str = "main") -> Blob: ... def commit(self) -> None: ... - def create_aggregate(self, name: str, n_arg: int, aggregate_class: Callable[[], _AggregateProtocol]) -> None: ... + + @overload + def create_aggregate(self, name: str, n_arg: Literal[1], aggregate_class: Callable[[], _SingleParamAggregateProtocol[_SQLType]]) -> None: ... + @overload + def create_aggregate(self, name: str, n_arg: int, aggregate_class: Callable[[], _AnyParamAggregateProtocol[_SQLType]]) -> None: ... + if sys.version_info >= (3, 11): # num_params determines how many params will be passed to the aggregate class. We provide an overload # for the case where num_params = 1, which is expected to be the common case. @overload def create_window_function( - self, name: str, num_params: Literal[1], aggregate_class: Callable[[], _SingleParamWindowAggregateClass] | None, / + self, name: str, num_params: Literal[1], aggregate_class: Callable[[], _SingleParamWindowAggregateClass[_SQLType]] | None, / ) -> None: ... # And for num_params = -1, which means the aggregate must accept any number of parameters. @overload def create_window_function( - self, name: str, num_params: Literal[-1], aggregate_class: Callable[[], _AnyParamWindowAggregateClass] | None, / - ) -> None: ... - @overload - def create_window_function( - self, name: str, num_params: int, aggregate_class: Callable[[], _WindowAggregateClass] | None, / + self, name: str, num_params: int, aggregate_class: Callable[[], _AnyParamWindowAggregateClass[_SQLType]] | None, / ) -> None: ... def create_collation(self, name: str, callback: Callable[[str, str], int | SupportsIndex] | None, /) -> None: ... From d52702a394bf28dab8e8b51317dd6a22e005b0d9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 23 Jun 2024 21:03:59 +0000 Subject: [PATCH 05/12] [pre-commit.ci] auto fixes from pre-commit.com hooks --- .../test_cases/sqlite3/check_aggregations.py | 20 +++++++++---------- stdlib/sqlite3/dbapi2.pyi | 20 ++++++++++++------- 2 files changed, 22 insertions(+), 18 deletions(-) diff --git a/stdlib/@tests/test_cases/sqlite3/check_aggregations.py b/stdlib/@tests/test_cases/sqlite3/check_aggregations.py index c5ee0308f3eb..6247e0eaeb4f 100644 --- a/stdlib/@tests/test_cases/sqlite3/check_aggregations.py +++ b/stdlib/@tests/test_cases/sqlite3/check_aggregations.py @@ -1,5 +1,6 @@ import sqlite3 + class WindowSumInt: def __init__(self) -> None: self.count = 0 @@ -19,22 +20,18 @@ def finalize(self) -> int: con = sqlite3.connect(":memory:") cur = con.execute("CREATE TABLE test(x, y)") -values = [ - ("a", 4), - ("b", 5), - ("c", 3), - ("d", 8), - ("e", 1), -] +values = [("a", 4), ("b", 5), ("c", 3), ("d", 8), ("e", 1)] cur.executemany("INSERT INTO test VALUES(?, ?)", values) con.create_window_function("sumint", 1, WindowSumInt) con.create_aggregate("sumint", 1, WindowSumInt) -cur.execute(""" +cur.execute( + """ SELECT x, sumint(y) OVER ( ORDER BY x ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING ) AS sum_y FROM test ORDER BY x -""") +""" +) con.close() @@ -48,6 +45,7 @@ def create_window_function() -> WindowSumInt: # With num_args set to 1, the callable should not be called with more than one. + class WindowSumIntMultiArgs: def __init__(self) -> None: self.count = 0 @@ -66,8 +64,8 @@ def finalize(self) -> int: # This should fail because the callable is called with more than one argument. -con.create_window_function("sumint", 1, WindowSumIntMultiArgs) # type: ignore -con.create_aggregate("sumint", 1, WindowSumIntMultiArgs) # type: ignore +con.create_window_function("sumint", 1, WindowSumIntMultiArgs) # type: ignore +con.create_aggregate("sumint", 1, WindowSumIntMultiArgs) # type: ignore # With num_args set to -1, this should work. con.create_window_function("sumint", 2, WindowSumIntMultiArgs) diff --git a/stdlib/sqlite3/dbapi2.pyi b/stdlib/sqlite3/dbapi2.pyi index c9286e6b5076..0fd7740804dd 100644 --- a/stdlib/sqlite3/dbapi2.pyi +++ b/stdlib/sqlite3/dbapi2.pyi @@ -4,7 +4,7 @@ from _typeshed import ReadableBuffer, StrOrBytesPath, SupportsLenAndGetItem, Unu from collections.abc import Callable, Generator, Iterable, Iterator, Mapping from datetime import date, datetime, time from types import TracebackType -from typing import Any, Literal, Protocol, SupportsIndex, final, overload, TypeVar +from typing import Any, Literal, Protocol, SupportsIndex, TypeVar, final, overload from typing_extensions import Self, TypeAlias _T = TypeVar("_T") @@ -333,7 +333,6 @@ class _AnyParamWindowAggregateClass(Protocol[_SQLType]): def value(self) -> _SQLType: ... def finalize(self) -> _SQLType: ... - class Connection: @property def DataError(self) -> type[sqlite3.DataError]: ... @@ -398,18 +397,25 @@ class Connection: def blobopen(self, table: str, column: str, row: int, /, *, readonly: bool = False, name: str = "main") -> Blob: ... def commit(self) -> None: ... - @overload - def create_aggregate(self, name: str, n_arg: Literal[1], aggregate_class: Callable[[], _SingleParamAggregateProtocol[_SQLType]]) -> None: ... + def create_aggregate( + self, name: str, n_arg: Literal[1], aggregate_class: Callable[[], _SingleParamAggregateProtocol[_SQLType]] + ) -> None: ... @overload - def create_aggregate(self, name: str, n_arg: int, aggregate_class: Callable[[], _AnyParamAggregateProtocol[_SQLType]]) -> None: ... - + def create_aggregate( + self, name: str, n_arg: int, aggregate_class: Callable[[], _AnyParamAggregateProtocol[_SQLType]] + ) -> None: ... + if sys.version_info >= (3, 11): # num_params determines how many params will be passed to the aggregate class. We provide an overload # for the case where num_params = 1, which is expected to be the common case. @overload def create_window_function( - self, name: str, num_params: Literal[1], aggregate_class: Callable[[], _SingleParamWindowAggregateClass[_SQLType]] | None, / + self, + name: str, + num_params: Literal[1], + aggregate_class: Callable[[], _SingleParamWindowAggregateClass[_SQLType]] | None, + /, ) -> None: ... # And for num_params = -1, which means the aggregate must accept any number of parameters. @overload From 5377741f3a1d63639de5a9e5503e0c99d0c5ff7d Mon Sep 17 00:00:00 2001 From: Maxwell Muoto <41130755+max-muoto@users.noreply.github.com> Date: Sun, 23 Jun 2024 16:12:14 -0500 Subject: [PATCH 06/12] Fix test --- .../test_cases/sqlite3/check_aggregations.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/stdlib/@tests/test_cases/sqlite3/check_aggregations.py b/stdlib/@tests/test_cases/sqlite3/check_aggregations.py index 6247e0eaeb4f..601776e8bee4 100644 --- a/stdlib/@tests/test_cases/sqlite3/check_aggregations.py +++ b/stdlib/@tests/test_cases/sqlite3/check_aggregations.py @@ -50,23 +50,20 @@ class WindowSumIntMultiArgs: def __init__(self) -> None: self.count = 0 - def step(self, arg_1: int, arg_2: int) -> None: - self.count += arg_1 + arg_2 + def step(self, *args: int) -> None: + self.count += sum(args) def value(self) -> int: return self.count - def inverse(self, arg_1: int, arg_2: int) -> None: - self.count -= arg_1 + arg_2 + def inverse(self, *args: int) -> None: + self.count -= sum(args) def finalize(self) -> int: return self.count +con.create_window_function("sumint", 1, WindowSumIntMultiArgs) +con.create_aggregate("sumint", 1, WindowSumIntMultiArgs) -# This should fail because the callable is called with more than one argument. -con.create_window_function("sumint", 1, WindowSumIntMultiArgs) # type: ignore -con.create_aggregate("sumint", 1, WindowSumIntMultiArgs) # type: ignore - -# With num_args set to -1, this should work. con.create_window_function("sumint", 2, WindowSumIntMultiArgs) con.create_aggregate("sumint", 2, WindowSumIntMultiArgs) From 6169219f5ecabcf04470cf3115ee12bfc0f5244e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 23 Jun 2024 21:13:51 +0000 Subject: [PATCH 07/12] [pre-commit.ci] auto fixes from pre-commit.com hooks --- stdlib/@tests/test_cases/sqlite3/check_aggregations.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/stdlib/@tests/test_cases/sqlite3/check_aggregations.py b/stdlib/@tests/test_cases/sqlite3/check_aggregations.py index 601776e8bee4..0e331383fe91 100644 --- a/stdlib/@tests/test_cases/sqlite3/check_aggregations.py +++ b/stdlib/@tests/test_cases/sqlite3/check_aggregations.py @@ -62,7 +62,8 @@ def inverse(self, *args: int) -> None: def finalize(self) -> int: return self.count -con.create_window_function("sumint", 1, WindowSumIntMultiArgs) + +con.create_window_function("sumint", 1, WindowSumIntMultiArgs) con.create_aggregate("sumint", 1, WindowSumIntMultiArgs) con.create_window_function("sumint", 2, WindowSumIntMultiArgs) From 58fc21a99aed778825177b85653b7489de7dfd1f Mon Sep 17 00:00:00 2001 From: Maxwell Muoto <41130755+max-muoto@users.noreply.github.com> Date: Sun, 23 Jun 2024 16:15:21 -0500 Subject: [PATCH 08/12] Finalize tests --- .../test_cases/sqlite3/check_aggregations.py | 39 +++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/stdlib/@tests/test_cases/sqlite3/check_aggregations.py b/stdlib/@tests/test_cases/sqlite3/check_aggregations.py index 0e331383fe91..a0d3bf377010 100644 --- a/stdlib/@tests/test_cases/sqlite3/check_aggregations.py +++ b/stdlib/@tests/test_cases/sqlite3/check_aggregations.py @@ -68,3 +68,42 @@ def finalize(self) -> int: con.create_window_function("sumint", 2, WindowSumIntMultiArgs) con.create_aggregate("sumint", 2, WindowSumIntMultiArgs) + + +class WindowSumIntMismatchedArgs: + def __init__(self) -> None: + self.count = 0 + + def step(self, *args: str) -> None: + self.count += 34 + + def value(self) -> int: + return self.count + + def inverse(self, *args: int) -> None: + self.count -= 34 + + def finalize(self) -> str: + return str(self.count) + + +# Since the types for `inverse`, `step`, `finalize`, and `value` are not compatible, the following should fail. +con.create_window_function("sumint", 1, WindowSumIntMismatchedArgs) # type: ignore +con.create_window_function("sumint", 2, WindowSumIntMismatchedArgs) # type: ignore + + + +class AggMismatchedArgs: + def __init__(self) -> None: + self.count = 0 + + def step(self, *args: str) -> None: + self.count += 34 + + def finalize(self) -> int: + return self.count + + +# Since the types for `step` and `finalize` are not compatible, the following should fail. +con.create_aggregate("sumint", 1, AggMismatchedArgs) # type: ignore +con.create_aggregate("sumint", 2, AggMismatchedArgs) # type: ignore From 4e5b605138fb0db8bd6abb5222c2ac9377cb2f69 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 23 Jun 2024 21:17:08 +0000 Subject: [PATCH 09/12] [pre-commit.ci] auto fixes from pre-commit.com hooks --- stdlib/@tests/test_cases/sqlite3/check_aggregations.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/stdlib/@tests/test_cases/sqlite3/check_aggregations.py b/stdlib/@tests/test_cases/sqlite3/check_aggregations.py index a0d3bf377010..67cd08b35ca6 100644 --- a/stdlib/@tests/test_cases/sqlite3/check_aggregations.py +++ b/stdlib/@tests/test_cases/sqlite3/check_aggregations.py @@ -85,14 +85,13 @@ def inverse(self, *args: int) -> None: def finalize(self) -> str: return str(self.count) - + # Since the types for `inverse`, `step`, `finalize`, and `value` are not compatible, the following should fail. con.create_window_function("sumint", 1, WindowSumIntMismatchedArgs) # type: ignore con.create_window_function("sumint", 2, WindowSumIntMismatchedArgs) # type: ignore - class AggMismatchedArgs: def __init__(self) -> None: self.count = 0 @@ -102,7 +101,7 @@ def step(self, *args: str) -> None: def finalize(self) -> int: return self.count - + # Since the types for `step` and `finalize` are not compatible, the following should fail. con.create_aggregate("sumint", 1, AggMismatchedArgs) # type: ignore From 7f4149075b01659ab5beedbfe2a73752a7e36a42 Mon Sep 17 00:00:00 2001 From: Maxwell Muoto <41130755+max-muoto@users.noreply.github.com> Date: Sun, 23 Jun 2024 16:22:00 -0500 Subject: [PATCH 10/12] Tweak --- stdlib/@tests/test_cases/sqlite3/check_aggregations.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/stdlib/@tests/test_cases/sqlite3/check_aggregations.py b/stdlib/@tests/test_cases/sqlite3/check_aggregations.py index 67cd08b35ca6..fe3b58927416 100644 --- a/stdlib/@tests/test_cases/sqlite3/check_aggregations.py +++ b/stdlib/@tests/test_cases/sqlite3/check_aggregations.py @@ -35,13 +35,13 @@ def finalize(self) -> int: con.close() -def create_window_function() -> WindowSumInt: +def _create_window_function() -> WindowSumInt: return WindowSumInt() # A callable should work as well. -con.create_window_function("sumint", 1, create_window_function) -con.create_aggregate("sumint", 1, create_window_function) +con.create_window_function("sumint", 1, _create_window_function) +con.create_aggregate("sumint", 1, _create_window_function) # With num_args set to 1, the callable should not be called with more than one. From c1ca8212266a5489d7c7a99d8438d82f9aae203e Mon Sep 17 00:00:00 2001 From: Maxwell Muoto <41130755+max-muoto@users.noreply.github.com> Date: Sun, 23 Jun 2024 16:23:30 -0500 Subject: [PATCH 11/12] Fix tests --- .../test_cases/sqlite3/check_aggregations.py | 24 ++++++++++++------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/stdlib/@tests/test_cases/sqlite3/check_aggregations.py b/stdlib/@tests/test_cases/sqlite3/check_aggregations.py index fe3b58927416..6e60474a727c 100644 --- a/stdlib/@tests/test_cases/sqlite3/check_aggregations.py +++ b/stdlib/@tests/test_cases/sqlite3/check_aggregations.py @@ -1,5 +1,5 @@ import sqlite3 - +import sys class WindowSumInt: def __init__(self) -> None: @@ -22,7 +22,10 @@ def finalize(self) -> int: cur = con.execute("CREATE TABLE test(x, y)") values = [("a", 4), ("b", 5), ("c", 3), ("d", 8), ("e", 1)] cur.executemany("INSERT INTO test VALUES(?, ?)", values) -con.create_window_function("sumint", 1, WindowSumInt) + +if sys.version_info >= (3, 11): + con.create_window_function("sumint", 1, WindowSumInt) + con.create_aggregate("sumint", 1, WindowSumInt) cur.execute( """ @@ -40,8 +43,9 @@ def _create_window_function() -> WindowSumInt: # A callable should work as well. -con.create_window_function("sumint", 1, _create_window_function) -con.create_aggregate("sumint", 1, _create_window_function) +if sys.version_info >= (3, 11): + con.create_window_function("sumint", 1, _create_window_function) + con.create_aggregate("sumint", 1, _create_window_function) # With num_args set to 1, the callable should not be called with more than one. @@ -63,10 +67,11 @@ def finalize(self) -> int: return self.count -con.create_window_function("sumint", 1, WindowSumIntMultiArgs) -con.create_aggregate("sumint", 1, WindowSumIntMultiArgs) +if sys.version_info >= (3, 11): + con.create_window_function("sumint", 1, WindowSumIntMultiArgs) + con.create_window_function("sumint", 2, WindowSumIntMultiArgs) -con.create_window_function("sumint", 2, WindowSumIntMultiArgs) +con.create_aggregate("sumint", 1, WindowSumIntMultiArgs) con.create_aggregate("sumint", 2, WindowSumIntMultiArgs) @@ -88,8 +93,9 @@ def finalize(self) -> str: # Since the types for `inverse`, `step`, `finalize`, and `value` are not compatible, the following should fail. -con.create_window_function("sumint", 1, WindowSumIntMismatchedArgs) # type: ignore -con.create_window_function("sumint", 2, WindowSumIntMismatchedArgs) # type: ignore +if sys.version_info >= (3, 11): + con.create_window_function("sumint", 1, WindowSumIntMismatchedArgs) # type: ignore + con.create_window_function("sumint", 2, WindowSumIntMismatchedArgs) # type: ignore class AggMismatchedArgs: From 8bea1e035026aad10f98e22b2c4929cbfdc7c181 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 23 Jun 2024 21:25:16 +0000 Subject: [PATCH 12/12] [pre-commit.ci] auto fixes from pre-commit.com hooks --- stdlib/@tests/test_cases/sqlite3/check_aggregations.py | 1 + 1 file changed, 1 insertion(+) diff --git a/stdlib/@tests/test_cases/sqlite3/check_aggregations.py b/stdlib/@tests/test_cases/sqlite3/check_aggregations.py index 6e60474a727c..8eebe71a39e0 100644 --- a/stdlib/@tests/test_cases/sqlite3/check_aggregations.py +++ b/stdlib/@tests/test_cases/sqlite3/check_aggregations.py @@ -1,6 +1,7 @@ import sqlite3 import sys + class WindowSumInt: def __init__(self) -> None: self.count = 0