Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit d882fbc

Browse files
Update type hints for Cursor to match PEP 249. (#9299)
1 parent 5a9cdaa commit d882fbc

File tree

5 files changed

+47
-17
lines changed

5 files changed

+47
-17
lines changed

changelog.d/9299.misc

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Update the `Cursor` type hints to better match PEP 249.

synapse/storage/database.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -158,8 +158,8 @@ def close(self) -> None:
158158
def commit(self) -> None:
159159
self.conn.commit()
160160

161-
def rollback(self, *args, **kwargs) -> None:
162-
self.conn.rollback(*args, **kwargs)
161+
def rollback(self) -> None:
162+
self.conn.rollback()
163163

164164
def __enter__(self) -> "Connection":
165165
self.conn.__enter__()
@@ -244,12 +244,15 @@ def call_on_exception(
244244
assert self.exception_callbacks is not None
245245
self.exception_callbacks.append((callback, args, kwargs))
246246

247+
def fetchone(self) -> Optional[Tuple]:
248+
return self.txn.fetchone()
249+
250+
def fetchmany(self, size: Optional[int] = None) -> List[Tuple]:
251+
return self.txn.fetchmany(size=size)
252+
247253
def fetchall(self) -> List[Tuple]:
248254
return self.txn.fetchall()
249255

250-
def fetchone(self) -> Tuple:
251-
return self.txn.fetchone()
252-
253256
def __iter__(self) -> Iterator[Tuple]:
254257
return self.txn.__iter__()
255258

@@ -754,6 +757,7 @@ def cursor_to_dict(cursor: Cursor) -> List[Dict[str, Any]]:
754757
Returns:
755758
A list of dicts where the key is the column header.
756759
"""
760+
assert cursor.description is not None, "cursor.description was None"
757761
col_headers = [intern(str(column[0])) for column in cursor.description]
758762
results = [dict(zip(col_headers, row)) for row in cursor]
759763
return results

synapse/storage/prepare_database.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -619,9 +619,9 @@ def _get_or_create_schema_state(
619619

620620
txn.execute("SELECT version, upgraded FROM schema_version")
621621
row = txn.fetchone()
622-
current_version = int(row[0]) if row else None
623622

624-
if current_version:
623+
if row is not None:
624+
current_version = int(row[0])
625625
txn.execute(
626626
"SELECT file FROM applied_schema_deltas WHERE version >= ?",
627627
(current_version,),

synapse/storage/types.py

+29-8
Original file line numberDiff line numberDiff line change
@@ -12,31 +12,52 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15-
from typing import Any, Iterable, Iterator, List, Optional, Tuple
15+
from typing import Any, Iterator, List, Mapping, Optional, Sequence, Tuple, Union
1616

1717
from typing_extensions import Protocol
1818

1919
"""
2020
Some very basic protocol definitions for the DB-API2 classes specified in PEP-249
2121
"""
2222

23+
_Parameters = Union[Sequence[Any], Mapping[str, Any]]
24+
2325

2426
class Cursor(Protocol):
25-
def execute(self, sql: str, parameters: Iterable[Any] = ...) -> Any:
27+
def execute(self, sql: str, parameters: _Parameters = ...) -> Any:
2628
...
2729

28-
def executemany(self, sql: str, parameters: Iterable[Iterable[Any]]) -> Any:
30+
def executemany(self, sql: str, parameters: Sequence[_Parameters]) -> Any:
2931
...
3032

31-
def fetchall(self) -> List[Tuple]:
33+
def fetchone(self) -> Optional[Tuple]:
34+
...
35+
36+
def fetchmany(self, size: Optional[int] = ...) -> List[Tuple]:
3237
...
3338

34-
def fetchone(self) -> Tuple:
39+
def fetchall(self) -> List[Tuple]:
3540
...
3641

3742
@property
38-
def description(self) -> Any:
39-
return None
43+
def description(
44+
self,
45+
) -> Optional[
46+
Sequence[
47+
# Note that this is an approximate typing based on sqlite3 and other
48+
# drivers, and may not be entirely accurate.
49+
Tuple[
50+
str,
51+
Optional[Any],
52+
Optional[int],
53+
Optional[int],
54+
Optional[int],
55+
Optional[int],
56+
Optional[int],
57+
]
58+
]
59+
]:
60+
...
4061

4162
@property
4263
def rowcount(self) -> int:
@@ -59,7 +80,7 @@ def close(self) -> None:
5980
def commit(self) -> None:
6081
...
6182

62-
def rollback(self, *args, **kwargs) -> None:
83+
def rollback(self) -> None:
6384
...
6485

6586
def __enter__(self) -> "Connection":

synapse/storage/util/sequence.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,9 @@ def __init__(self, sequence_name: str):
106106

107107
def get_next_id_txn(self, txn: Cursor) -> int:
108108
txn.execute("SELECT nextval(?)", (self._sequence_name,))
109-
return txn.fetchone()[0]
109+
fetch_res = txn.fetchone()
110+
assert fetch_res is not None
111+
return fetch_res[0]
110112

111113
def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
112114
txn.execute(
@@ -147,7 +149,9 @@ def check_consistency(
147149
txn.execute(
148150
"SELECT last_value, is_called FROM %(seq)s" % {"seq": self._sequence_name}
149151
)
150-
last_value, is_called = txn.fetchone()
152+
fetch_res = txn.fetchone()
153+
assert fetch_res is not None
154+
last_value, is_called = fetch_res
151155

152156
# If we have an associated stream check the stream_positions table.
153157
max_in_stream_positions = None

0 commit comments

Comments
 (0)