Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Set our own stream position from the current sequence value on startup #17309

Merged
merged 9 commits into from
Jun 17, 2024
1 change: 1 addition & 0 deletions changelog.d/17309.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
When rolling back to a previous Synapse version and then forwards again to this release, don't require server operators to manually run SQL.
23 changes: 20 additions & 3 deletions synapse/storage/util/id_generators.py
Original file line number Diff line number Diff line change
@@ -276,9 +276,6 @@ def __init__(
# no active writes in progress.
self._max_position_of_local_instance = self._max_seen_allocated_stream_id

# This goes and fills out the above state from the database.
self._load_current_ids(db_conn, tables)

self._sequence_gen = build_sequence_generator(
db_conn=db_conn,
database_engine=db.engine,
@@ -303,6 +300,13 @@ def __init__(
positive=positive,
)

# This goes and fills out the above state from the database.
# This may read on the PostgreSQL sequence, and
# SequenceGenerator.check_consistency might have fixed up the sequence, which
# means the SequenceGenerator needs to be setup before we read the value from
# the sequence.
self._load_current_ids(db_conn, tables, sequence_name)

self._max_seen_allocated_stream_id = max(
self._current_positions.values(), default=1
)
@@ -327,6 +331,7 @@ def _load_current_ids(
self,
db_conn: LoggingDatabaseConnection,
tables: List[Tuple[str, str, str]],
sequence_name: str,
) -> None:
cur = db_conn.cursor(txn_name="_load_current_ids")

@@ -360,6 +365,18 @@ def _load_current_ids(
if instance in self._writers
}

# If we're a writer, we can assume we're at the end of the stream
# Usually, we would get that from the stream_positions, but in some cases,
# like if we rolled back Synapse, the stream_positions table might not be up to
# date. If we're using Postgres for the sequences, we can just use the current
# sequence value as our own position.
if self._instance_name in self._writers:
if isinstance(self._db.engine, PostgresEngine):
cur.execute(f"SELECT last_value FROM {sequence_name}")
row = cur.fetchone()
assert row is not None
self._current_positions[self._instance_name] = row[0]

# We set the `_persisted_upto_position` to be the minimum of all current
# positions. If empty we use the max stream ID from the DB table.
min_stream_id = min(self._current_positions.values(), default=None)
301 changes: 126 additions & 175 deletions tests/storage/test_id_generators.py
Original file line number Diff line number Diff line change
@@ -18,7 +18,7 @@
# [This file includes modifications made by New Vector Limited]
#
#
from typing import List, Optional
from typing import Dict, List, Optional

from twisted.test.proto_helpers import MemoryReactor

@@ -42,9 +42,13 @@


class MultiWriterIdGeneratorBase(HomeserverTestCase):
positive: bool = True
tables: List[str] = ["foobar"]

def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.db_pool: DatabasePool = self.store.db_pool
self.instances: Dict[str, MultiWriterIdGenerator] = {}

self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))

@@ -57,18 +61,22 @@ def _setup_db(self, txn: LoggingTransaction) -> None:
if USE_POSTGRES_FOR_TESTS:
txn.execute("CREATE SEQUENCE foobar_seq")

txn.execute(
"""
CREATE TABLE foobar (
stream_id BIGINT NOT NULL,
instance_name TEXT NOT NULL,
data TEXT
);
"""
)
for table in self.tables:
txn.execute(
"""
CREATE TABLE %s (
stream_id BIGINT NOT NULL,
instance_name TEXT NOT NULL,
data TEXT
);
"""
% (table,)
)

def _create_id_generator(
self, instance_name: str = "master", writers: Optional[List[str]] = None
self,
instance_name: str = "master",
writers: Optional[List[str]] = None,
) -> MultiWriterIdGenerator:
def _create(conn: LoggingDatabaseConnection) -> MultiWriterIdGenerator:
return MultiWriterIdGenerator(
@@ -77,36 +85,93 @@ def _create(conn: LoggingDatabaseConnection) -> MultiWriterIdGenerator:
notifier=self.hs.get_replication_notifier(),
stream_name="test_stream",
instance_name=instance_name,
tables=[("foobar", "instance_name", "stream_id")],
tables=[(table, "instance_name", "stream_id") for table in self.tables],
sequence_name="foobar_seq",
writers=writers or ["master"],
positive=self.positive,
)

return self.get_success_or_raise(self.db_pool.runWithConnection(_create))
self.instances[instance_name] = self.get_success_or_raise(
self.db_pool.runWithConnection(_create)
)
return self.instances[instance_name]

def _replicate(self, instance_name: str) -> None:
"""Similate a replication event for the given instance."""

writer = self.instances[instance_name]
token = writer.get_current_token_for_writer(instance_name)
for generator in self.instances.values():
if writer != generator:
generator.advance(instance_name, token)

def _replicate_all(self) -> None:
"""Similate a replication event for all instances."""

def _insert_rows(self, instance_name: str, number: int) -> None:
for instance_name in self.instances:
self._replicate(instance_name)

def _insert_row(
self, instance_name: str, stream_id: int, table: Optional[str] = None
) -> None:
"""Insert one row as the given instance with given stream_id."""

if table is None:
table = self.tables[0]

factor = 1 if self.positive else -1

def _insert(txn: LoggingTransaction) -> None:
txn.execute(
"INSERT INTO %s VALUES (?, ?)" % (table,),
(
stream_id,
instance_name,
),
)
txn.execute(
"""
INSERT INTO stream_positions VALUES ('test_stream', ?, ?)
ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ?
""",
(instance_name, stream_id * factor, stream_id * factor),
)

self.get_success(self.db_pool.runInteraction("_insert_row", _insert))

def _insert_rows(
self,
instance_name: str,
number: int,
table: Optional[str] = None,
update_stream_table: bool = True,
) -> None:
"""Insert N rows as the given instance, inserting with stream IDs pulled
from the postgres sequence.
"""

if table is None:
table = self.tables[0]

factor = 1 if self.positive else -1

def _insert(txn: LoggingTransaction) -> None:
for _ in range(number):
next_val = self.seq_gen.get_next_id_txn(txn)
txn.execute(
"INSERT INTO foobar (stream_id, instance_name) VALUES (?, ?)",
(
next_val,
instance_name,
),
"INSERT INTO %s (stream_id, instance_name) VALUES (?, ?)"
% (table,),
(next_val, instance_name),
)

txn.execute(
"""
INSERT INTO stream_positions VALUES ('test_stream', ?, ?)
ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ?
""",
(instance_name, next_val, next_val),
)
if update_stream_table:
txn.execute(
"""
INSERT INTO stream_positions VALUES ('test_stream', ?, ?)
ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ?
""",
(instance_name, next_val * factor, next_val * factor),
)

self.get_success(self.db_pool.runInteraction("_insert_rows", _insert))

@@ -353,7 +418,9 @@ def test_get_persisted_upto_position_get_next(self) -> None:

id_gen = self._create_id_generator("first", writers=["first", "second"])

self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
# When the writer is created, it assumes its own position is the current head of
# the sequence
self.assertEqual(id_gen.get_positions(), {"first": 5, "second": 5})

self.assertEqual(id_gen.get_persisted_upto_position(), 5)

@@ -375,11 +442,13 @@ def test_multi_instance(self) -> None:
correctly.
"""
self._insert_rows("first", 3)
self._insert_rows("second", 4)

first_id_gen = self._create_id_generator("first", writers=["first", "second"])

self._insert_rows("second", 4)
second_id_gen = self._create_id_generator("second", writers=["first", "second"])

self._replicate_all()

self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
@@ -398,6 +467,9 @@ async def _get_next_async() -> None:
self.assertEqual(
first_id_gen.get_positions(), {"first": 3, "second": 7}
)
self.assertEqual(
second_id_gen.get_positions(), {"first": 3, "second": 7}
)
self.assertEqual(first_id_gen.get_persisted_upto_position(), 7)

self.get_success(_get_next_async())
@@ -432,18 +504,20 @@ def test_multi_instance_empty_row(self) -> None:
"""
# Insert some rows for two out of three of the ID gens.
self._insert_rows("first", 3)
self._insert_rows("second", 4)

first_id_gen = self._create_id_generator(
"first", writers=["first", "second", "third"]
)

self._insert_rows("second", 4)
second_id_gen = self._create_id_generator(
"second", writers=["first", "second", "third"]
)
third_id_gen = self._create_id_generator(
"third", writers=["first", "second", "third"]
)

self._replicate_all()

self.assertEqual(
first_id_gen.get_positions(), {"first": 3, "second": 7, "third": 7}
)
@@ -546,11 +620,13 @@ async def _get_next_async() -> None:

def test_minimal_local_token(self) -> None:
self._insert_rows("first", 3)
self._insert_rows("second", 4)

first_id_gen = self._create_id_generator("first", writers=["first", "second"])

self._insert_rows("second", 4)
second_id_gen = self._create_id_generator("second", writers=["first", "second"])

self._replicate_all()

self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
self.assertEqual(first_id_gen.get_minimal_local_current_token(), 3)

@@ -562,15 +638,17 @@ def test_current_token_gap(self) -> None:
token when there are no writes.
"""
self._insert_rows("first", 3)
self._insert_rows("second", 4)

first_id_gen = self._create_id_generator(
"first", writers=["first", "second", "third"]
)

self._insert_rows("second", 4)
second_id_gen = self._create_id_generator(
"second", writers=["first", "second", "third"]
)

self._replicate_all()

self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 7)
self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
self.assertEqual(second_id_gen.get_current_token(), 7)
@@ -609,68 +687,13 @@ async def _get_next_async() -> None:
self.assertEqual(second_id_gen.get_current_token(), 7)


class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
class BackwardsMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
"""Tests MultiWriterIdGenerator that produce *negative* stream IDs."""

if not USE_POSTGRES_FOR_TESTS:
skip = "Requires Postgres"

def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.db_pool: DatabasePool = self.store.db_pool

self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))

def _setup_db(self, txn: LoggingTransaction) -> None:
txn.execute("CREATE SEQUENCE foobar_seq")
txn.execute(
"""
CREATE TABLE foobar (
stream_id BIGINT NOT NULL,
instance_name TEXT NOT NULL,
data TEXT
);
"""
)

def _create_id_generator(
self, instance_name: str = "master", writers: Optional[List[str]] = None
) -> MultiWriterIdGenerator:
def _create(conn: LoggingDatabaseConnection) -> MultiWriterIdGenerator:
return MultiWriterIdGenerator(
conn,
self.db_pool,
notifier=self.hs.get_replication_notifier(),
stream_name="test_stream",
instance_name=instance_name,
tables=[("foobar", "instance_name", "stream_id")],
sequence_name="foobar_seq",
writers=writers or ["master"],
positive=False,
)

return self.get_success(self.db_pool.runWithConnection(_create))

def _insert_row(self, instance_name: str, stream_id: int) -> None:
"""Insert one row as the given instance with given stream_id."""

def _insert(txn: LoggingTransaction) -> None:
txn.execute(
"INSERT INTO foobar VALUES (?, ?)",
(
stream_id,
instance_name,
),
)
txn.execute(
"""
INSERT INTO stream_positions VALUES ('test_stream', ?, ?)
ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ?
""",
(instance_name, -stream_id, -stream_id),
)

self.get_success(self.db_pool.runInteraction("_insert_row", _insert))
positive = False

def test_single_instance(self) -> None:
"""Test that reads and writes from a single process are handled
@@ -716,7 +739,7 @@ def test_multiple_instance(self) -> None:
async def _get_next_async() -> None:
async with id_gen_1.get_next() as stream_id:
self._insert_row("first", stream_id)
id_gen_2.advance("first", stream_id)
self._replicate("first")

self.get_success(_get_next_async())

@@ -728,7 +751,7 @@ async def _get_next_async() -> None:
async def _get_next_async2() -> None:
async with id_gen_2.get_next() as stream_id:
self._insert_row("second", stream_id)
id_gen_1.advance("second", stream_id)
self._replicate("second")

self.get_success(_get_next_async2())

@@ -738,98 +761,26 @@ async def _get_next_async2() -> None:
self.assertEqual(id_gen_2.get_persisted_upto_position(), -2)


class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
class MultiTableMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
if not USE_POSTGRES_FOR_TESTS:
skip = "Requires Postgres"

def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.db_pool: DatabasePool = self.store.db_pool

self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))

def _setup_db(self, txn: LoggingTransaction) -> None:
txn.execute("CREATE SEQUENCE foobar_seq")
txn.execute(
"""
CREATE TABLE foobar1 (
stream_id BIGINT NOT NULL,
instance_name TEXT NOT NULL,
data TEXT
);
"""
)

txn.execute(
"""
CREATE TABLE foobar2 (
stream_id BIGINT NOT NULL,
instance_name TEXT NOT NULL,
data TEXT
);
"""
)

def _create_id_generator(
self, instance_name: str = "master", writers: Optional[List[str]] = None
) -> MultiWriterIdGenerator:
def _create(conn: LoggingDatabaseConnection) -> MultiWriterIdGenerator:
return MultiWriterIdGenerator(
conn,
self.db_pool,
notifier=self.hs.get_replication_notifier(),
stream_name="test_stream",
instance_name=instance_name,
tables=[
("foobar1", "instance_name", "stream_id"),
("foobar2", "instance_name", "stream_id"),
],
sequence_name="foobar_seq",
writers=writers or ["master"],
)

return self.get_success_or_raise(self.db_pool.runWithConnection(_create))

def _insert_rows(
self,
table: str,
instance_name: str,
number: int,
update_stream_table: bool = True,
) -> None:
"""Insert N rows as the given instance, inserting with stream IDs pulled
from the postgres sequence.
"""

def _insert(txn: LoggingTransaction) -> None:
for _ in range(number):
txn.execute(
"INSERT INTO %s VALUES (nextval('foobar_seq'), ?)" % (table,),
(instance_name,),
)
if update_stream_table:
txn.execute(
"""
INSERT INTO stream_positions VALUES ('test_stream', ?, lastval())
ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = lastval()
""",
(instance_name,),
)

self.get_success(self.db_pool.runInteraction("_insert_rows", _insert))
tables = ["foobar1", "foobar2"]

def test_load_existing_stream(self) -> None:
"""Test creating ID gens with multiple tables that have rows from after
the position in `stream_positions` table.
"""
self._insert_rows("foobar1", "first", 3)
self._insert_rows("foobar2", "second", 3)
self._insert_rows("foobar2", "second", 1, update_stream_table=False)

self._insert_rows("first", 3, table="foobar1")
first_id_gen = self._create_id_generator("first", writers=["first", "second"])

self._insert_rows("second", 3, table="foobar2")
self._insert_rows("second", 1, table="foobar2", update_stream_table=False)
second_id_gen = self._create_id_generator("second", writers=["first", "second"])

self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 6})
self._replicate_all()

self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
self.assertEqual(first_id_gen.get_persisted_upto_position(), 7)
Loading