14
14
# limitations under the License.
15
15
16
16
import contextlib
17
+ import heapq
17
18
import threading
18
19
from collections import deque
19
- from typing import Dict , Set
20
+ from typing import Dict , List , Set
20
21
21
22
from typing_extensions import Deque
22
23
@@ -210,6 +211,23 @@ def __init__(
210
211
# should be less than the minimum of this set (if not empty).
211
212
self ._unfinished_ids = set () # type: Set[int]
212
213
214
+ # We track the max position where we know everything before has been
215
+ # persisted. This is done by a) looking at the min across all instances
216
+ # and b) noting that if we have seen a run of persisted positions
217
+ # without gaps (e.g. 5, 6, 7) then we can skip forward (e.g. to 7).
218
+ #
219
+ # Note: There is no guarentee that the IDs generated by the sequence
220
+ # will be gapless; gaps can form when e.g. a transaction was rolled
221
+ # back. This means that sometimes we won't be able to skip forward the
222
+ # position even though everything has been persisted. However, since
223
+ # gaps should be relatively rare it's still worth doing the book keeping
224
+ # that allows us to skip forwards when there are gapless runs of
225
+ # positions.
226
+ self ._persisted_upto_position = (
227
+ min (self ._current_positions .values ()) if self ._current_positions else 0
228
+ )
229
+ self ._known_persisted_positions = [] # type: List[int]
230
+
213
231
self ._sequence_gen = PostgresSequenceGenerator (sequence_name )
214
232
215
233
def _load_current_ids (
@@ -234,9 +252,12 @@ def _load_current_ids(
234
252
235
253
return current_positions
236
254
237
- def _load_next_id_txn (self , txn ):
255
+ def _load_next_id_txn (self , txn ) -> int :
238
256
return self ._sequence_gen .get_next_id_txn (txn )
239
257
258
+ def _load_next_mult_id_txn (self , txn , n : int ) -> List [int ]:
259
+ return self ._sequence_gen .get_next_mult_txn (txn , n )
260
+
240
261
async def get_next (self ):
241
262
"""
242
263
Usage:
@@ -262,6 +283,34 @@ def manager():
262
283
263
284
return manager ()
264
285
286
+ async def get_next_mult (self , n : int ):
287
+ """
288
+ Usage:
289
+ with await stream_id_gen.get_next_mult(5) as stream_ids:
290
+ # ... persist events ...
291
+ """
292
+ next_ids = await self ._db .runInteraction (
293
+ "_load_next_mult_id" , self ._load_next_mult_id_txn , n
294
+ )
295
+
296
+ # Assert the fetched ID is actually greater than any ID we've already
297
+ # seen. If not, then the sequence and table have got out of sync
298
+ # somehow.
299
+ assert max (self .get_positions ().values (), default = 0 ) < min (next_ids )
300
+
301
+ with self ._lock :
302
+ self ._unfinished_ids .update (next_ids )
303
+
304
+ @contextlib .contextmanager
305
+ def manager ():
306
+ try :
307
+ yield next_ids
308
+ finally :
309
+ for i in next_ids :
310
+ self ._mark_id_as_finished (i )
311
+
312
+ return manager ()
313
+
265
314
def get_next_txn (self , txn : LoggingTransaction ):
266
315
"""
267
316
Usage:
@@ -326,3 +375,53 @@ def advance(self, instance_name: str, new_id: int):
326
375
self ._current_positions [instance_name ] = max (
327
376
new_id , self ._current_positions .get (instance_name , 0 )
328
377
)
378
+
379
+ self ._add_persisted_position (new_id )
380
+
381
+ def get_persisted_upto_position (self ) -> int :
382
+ """Get the max position where all previous positions have been
383
+ persisted.
384
+
385
+ Note: In the worst case scenario this will be equal to the minimum
386
+ position across writers. This means that the returned position here can
387
+ lag if one writer doesn't write very often.
388
+ """
389
+
390
+ with self ._lock :
391
+ return self ._persisted_upto_position
392
+
393
+ def _add_persisted_position (self , new_id : int ):
394
+ """Record that we have persisted a position.
395
+
396
+ This is used to keep the `_current_positions` up to date.
397
+ """
398
+
399
+ # We require that the lock is locked by caller
400
+ assert self ._lock .locked ()
401
+
402
+ heapq .heappush (self ._known_persisted_positions , new_id )
403
+
404
+ # We move the current min position up if the minimum current positions
405
+ # of all instances is higher (since by definition all positions less
406
+ # that that have been persisted).
407
+ min_curr = min (self ._current_positions .values ())
408
+ self ._persisted_upto_position = max (min_curr , self ._persisted_upto_position )
409
+
410
+ # We now iterate through the seen positions, discarding those that are
411
+ # less than the current min positions, and incrementing the min position
412
+ # if its exactly one greater.
413
+ #
414
+ # This is also where we discard items from `_known_persisted_positions`
415
+ # (to ensure the list doesn't infinitely grow).
416
+ while self ._known_persisted_positions :
417
+ if self ._known_persisted_positions [0 ] <= self ._persisted_upto_position :
418
+ heapq .heappop (self ._known_persisted_positions )
419
+ elif (
420
+ self ._known_persisted_positions [0 ] == self ._persisted_upto_position + 1
421
+ ):
422
+ heapq .heappop (self ._known_persisted_positions )
423
+ self ._persisted_upto_position += 1
424
+ else :
425
+ # There was a gap in seen positions, so there is nothing more to
426
+ # do.
427
+ break
0 commit comments