Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Detail typing of Connection and Cursor to further match PEP 249 #9299

Merged
merged 9 commits into from
Feb 5, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/9299.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Update internal database-related types to better match PEP 249.
19 changes: 13 additions & 6 deletions synapse/storage/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,8 @@ def close(self) -> None:
def commit(self) -> None:
self.conn.commit()

def rollback(self, *args, **kwargs) -> None:
self.conn.rollback(*args, **kwargs)
def rollback(self) -> None:
self.conn.rollback()

def __enter__(self) -> "Connection":
self.conn.__enter__()
Expand Down Expand Up @@ -244,12 +244,15 @@ def call_on_exception(
assert self.exception_callbacks is not None
self.exception_callbacks.append((callback, args, kwargs))

def fetchone(self) -> Optional[Tuple]:
return self.txn.fetchone()

def fetchmany(self, size: Optional[int] = None) -> List[Tuple]:
return self.txn.fetchmany(size=size)

def fetchall(self) -> List[Tuple]:
return self.txn.fetchall()

def fetchone(self) -> Tuple:
return self.txn.fetchone()

def __iter__(self) -> Iterator[Tuple]:
return self.txn.__iter__()

Expand Down Expand Up @@ -754,7 +757,11 @@ def cursor_to_dict(cursor: Cursor) -> List[Dict[str, Any]]:
Returns:
A list of dicts where the key is the column header.
"""
col_headers = [intern(str(column[0])) for column in cursor.description]
col_headers = (
[intern(str(column[0])) for column in cursor.description]
if cursor.description is not None
else []
)
results = [dict(zip(col_headers, row)) for row in cursor]
return results

Expand Down
3 changes: 2 additions & 1 deletion synapse/storage/prepare_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,9 +619,10 @@ def _get_or_create_schema_state(

txn.execute("SELECT version, upgraded FROM schema_version")
row = txn.fetchone()
current_version = int(row[0]) if row else None
current_version = int(row[0]) if row is not None and len(row) > 0 else None

if current_version:
assert row is not None
txn.execute(
"SELECT file FROM applied_schema_deltas WHERE version >= ?",
(current_version,),
Expand Down
33 changes: 26 additions & 7 deletions synapse/storage/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,30 +12,49 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Iterable, Iterator, List, Optional, Tuple
from typing import Any, Iterator, List, Mapping, Optional, Sequence, Tuple, Union

from typing_extensions import Protocol

"""
Some very basic protocol definitions for the DB-API2 classes specified in PEP-249
"""

_Parameters = Union[Sequence[Any], Mapping[str, Any]]


class Cursor(Protocol):
def execute(self, sql: str, parameters: Iterable[Any] = ...) -> Any:
def execute(self, sql: str, parameters: _Parameters = ...) -> Any:
...

def executemany(self, sql: str, parameters: Iterable[Iterable[Any]]) -> Any:
def executemany(self, sql: str, parameters: Sequence[_Parameters]) -> Any:
...

def fetchall(self) -> List[Tuple]:
def fetchone(self) -> Optional[Tuple]:
...

def fetchone(self) -> Tuple:
def fetchmany(self, size: Optional[int] = ...) -> List[Tuple]:
...

def fetchall(self) -> List[Tuple]:
...

@property
def description(self) -> Any:
def description(
self,
) -> Optional[
Sequence[
Tuple[
Any,
Any,
Optional[Any],
Optional[Any],
Optional[Any],
Optional[Any],
Optional[Any],
]
]
]:
return None

@property
Expand All @@ -59,7 +78,7 @@ def close(self) -> None:
def commit(self) -> None:
...

def rollback(self, *args, **kwargs) -> None:
def rollback(self) -> None:
...

def __enter__(self) -> "Connection":
Expand Down
13 changes: 11 additions & 2 deletions synapse/storage/util/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,11 @@ def __init__(self, sequence_name: str):

def get_next_id_txn(self, txn: Cursor) -> int:
txn.execute("SELECT nextval(?)", (self._sequence_name,))
return txn.fetchone()[0]
fetch_res = txn.fetchone()
if fetch_res is None:
# this should technically not happen, unless the connection/cursor/transaction has been corrupted
raise RuntimeError("nextval for self._sequence_name returned empty")
return fetch_res[0]

def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
txn.execute(
Expand Down Expand Up @@ -147,7 +151,12 @@ def check_consistency(
txn.execute(
"SELECT last_value, is_called FROM %(seq)s" % {"seq": self._sequence_name}
)
last_value, is_called = txn.fetchone()
fetch_res = txn.fetchone()
if fetch_res is None:
raise RuntimeError(
"Got no value from sequence"
) # should never happen, but doesn't hurt to check
last_value, is_called = fetch_res

# If we have an associated stream check the stream_positions table.
max_in_stream_positions = None
Expand Down