17
17
import os
18
18
import shutil
19
19
from io import BytesIO
20
- from typing import IO , TYPE_CHECKING , Dict , List , Optional , Set , Tuple
20
+ from typing import IO , TYPE_CHECKING , Any , Dict , List , Optional , Set , Tuple
21
21
22
22
from matrix_common .types .mxc_uri import MXCUri
23
23
32
32
NotFoundError ,
33
33
RequestSendFailed ,
34
34
SynapseError ,
35
+ cs_error ,
35
36
)
36
37
from synapse .config .repository import ThumbnailRequirement
38
+ from synapse .http .server import respond_with_json
37
39
from synapse .http .site import SynapseRequest
38
40
from synapse .logging .context import defer_to_thread
39
41
from synapse .media ._base import (
@@ -301,8 +303,62 @@ async def create_content(
301
303
302
304
return MXCUri (self .server_name , media_id )
303
305
306
+ def respond_not_yet_uploaded (self , request : SynapseRequest ) -> None :
307
+ respond_with_json (
308
+ request ,
309
+ 404 ,
310
+ cs_error ("Media has not been uploaded yet" , code = Codes .NOT_YET_UPLOADED ),
311
+ send_cors = True ,
312
+ )
313
+
314
+ async def get_local_media_info (
315
+ self , request : SynapseRequest , media_id : str , max_timeout_ms : int
316
+ ) -> Optional [Dict [str , Any ]]:
317
+ """Gets the info dictionary for given local media ID. If the media has
318
+ not been uploaded yet, this function will wait up to ``max_timeout_ms``
319
+ milliseconds for the media to be uploaded.
320
+ Args:
321
+ request: The incoming request.
322
+ media_id: The media ID of the content. (This is the same as
323
+ the file_id for local content.)
324
+ max_timeout_ms: the maximum number of milliseconds to wait for the
325
+ media to be uploaded.
326
+ Returns:
327
+ Either the info dictionary for the given local media ID or
328
+ ``None``. If ``None``, then no further processing is necessary as
329
+ this function will send the necessary JSON response.
330
+ """
331
+ wait_until = self .clock .time_msec () + max_timeout_ms
332
+ while True :
333
+ # Get the info for the media
334
+ media_info = await self .store .get_local_media (media_id )
335
+ if not media_info :
336
+ respond_404 (request )
337
+ return None
338
+
339
+ if media_info ["quarantined_by" ]:
340
+ logger .info ("Media is quarantined" )
341
+ respond_404 (request )
342
+ return None
343
+
344
+ # The file has been uploaded, so stop looping
345
+ if media_info .get ("media_length" ) is not None :
346
+ return media_info
347
+
348
+ if self .clock .time_msec () >= wait_until :
349
+ break
350
+
351
+ await self .clock .sleep (0.5 )
352
+
353
+ self .respond_not_yet_uploaded (request )
354
+ return None
355
+
304
356
async def get_local_media (
305
- self , request : SynapseRequest , media_id : str , name : Optional [str ]
357
+ self ,
358
+ request : SynapseRequest ,
359
+ media_id : str ,
360
+ name : Optional [str ],
361
+ max_timeout_ms : int ,
306
362
) -> None :
307
363
"""Responds to requests for local media, if exists, or returns 404.
308
364
@@ -312,13 +368,14 @@ async def get_local_media(
312
368
the file_id for local content.)
313
369
name: Optional name that, if specified, will be used as
314
370
the filename in the Content-Disposition header of the response.
371
+ max_timeout_ms: the maximum number of milliseconds to wait for the
372
+ media to be uploaded.
315
373
316
374
Returns:
317
375
Resolves once a response has successfully been written to request
318
376
"""
319
- media_info = await self .store .get_local_media (media_id )
320
- if not media_info or media_info ["quarantined_by" ]:
321
- respond_404 (request )
377
+ media_info = await self .get_local_media_info (request , media_id , max_timeout_ms )
378
+ if not media_info :
322
379
return
323
380
324
381
self .mark_recently_accessed (None , media_id )
@@ -343,6 +400,7 @@ async def get_remote_media(
343
400
server_name : str ,
344
401
media_id : str ,
345
402
name : Optional [str ],
403
+ max_timeout_ms : int ,
346
404
) -> None :
347
405
"""Respond to requests for remote media.
348
406
@@ -352,6 +410,8 @@ async def get_remote_media(
352
410
media_id: The media ID of the content (as defined by the remote server).
353
411
name: Optional name that, if specified, will be used as
354
412
the filename in the Content-Disposition header of the response.
413
+ max_timeout_ms: the maximum number of milliseconds to wait for the
414
+ media to be uploaded.
355
415
356
416
Returns:
357
417
Resolves once a response has successfully been written to request
@@ -377,27 +437,31 @@ async def get_remote_media(
377
437
key = (server_name , media_id )
378
438
async with self .remote_media_linearizer .queue (key ):
379
439
responder , media_info = await self ._get_remote_media_impl (
380
- server_name , media_id
440
+ server_name , media_id , max_timeout_ms
381
441
)
382
442
383
443
# We deliberately stream the file outside the lock
384
- if responder :
444
+ if responder and media_info :
385
445
media_type = media_info ["media_type" ]
386
446
media_length = media_info ["media_length" ]
387
447
upload_name = name if name else media_info ["upload_name" ]
388
448
await respond_with_responder (
389
449
request , responder , media_type , media_length , upload_name
390
450
)
391
451
else :
392
- respond_404 (request )
452
+ self . respond_not_yet_uploaded (request )
393
453
394
- async def get_remote_media_info (self , server_name : str , media_id : str ) -> dict :
454
+ async def get_remote_media_info (
455
+ self , server_name : str , media_id : str , max_timeout_ms : int
456
+ ) -> dict :
395
457
"""Gets the media info associated with the remote file, downloading
396
458
if necessary.
397
459
398
460
Args:
399
461
server_name: Remote server_name where the media originated.
400
462
media_id: The media ID of the content (as defined by the remote server).
463
+ max_timeout_ms: the maximum number of milliseconds to wait for the
464
+ media to be uploaded.
401
465
402
466
Returns:
403
467
The media info of the file
@@ -413,7 +477,7 @@ async def get_remote_media_info(self, server_name: str, media_id: str) -> dict:
413
477
key = (server_name , media_id )
414
478
async with self .remote_media_linearizer .queue (key ):
415
479
responder , media_info = await self ._get_remote_media_impl (
416
- server_name , media_id
480
+ server_name , media_id , max_timeout_ms
417
481
)
418
482
419
483
# Ensure we actually use the responder so that it releases resources
@@ -424,7 +488,7 @@ async def get_remote_media_info(self, server_name: str, media_id: str) -> dict:
424
488
return media_info
425
489
426
490
async def _get_remote_media_impl (
427
- self , server_name : str , media_id : str
491
+ self , server_name : str , media_id : str , max_timeout_ms : int
428
492
) -> Tuple [Optional [Responder ], dict ]:
429
493
"""Looks for media in local cache, if not there then attempt to
430
494
download from remote server.
@@ -433,6 +497,8 @@ async def _get_remote_media_impl(
433
497
server_name: Remote server_name where the media originated.
434
498
media_id: The media ID of the content (as defined by the
435
499
remote server).
500
+ max_timeout_ms: the maximum number of milliseconds to wait for the
501
+ media to be uploaded.
436
502
437
503
Returns:
438
504
A tuple of responder and the media info of the file.
@@ -463,8 +529,7 @@ async def _get_remote_media_impl(
463
529
464
530
try :
465
531
media_info = await self ._download_remote_file (
466
- server_name ,
467
- media_id ,
532
+ server_name , media_id , max_timeout_ms
468
533
)
469
534
except SynapseError :
470
535
raise
@@ -497,6 +562,7 @@ async def _download_remote_file(
497
562
self ,
498
563
server_name : str ,
499
564
media_id : str ,
565
+ max_timeout_ms : int ,
500
566
) -> dict :
501
567
"""Attempt to download the remote file from the given server name,
502
568
using the given file_id as the local id.
@@ -506,7 +572,8 @@ async def _download_remote_file(
506
572
media_id: The media ID of the content (as defined by the
507
573
remote server). This is different than the file_id, which is
508
574
locally generated.
509
- file_id: Local file ID
575
+ max_timeout_ms: the maximum number of milliseconds to wait for the
576
+ media to be uploaded.
510
577
511
578
Returns:
512
579
The media info of the file.
@@ -530,7 +597,8 @@ async def _download_remote_file(
530
597
# tell the remote server to 404 if it doesn't
531
598
# recognise the server_name, to make sure we don't
532
599
# end up with a routing loop.
533
- "allow_remote" : "false"
600
+ "allow_remote" : "false" ,
601
+ "timeout_ms" : str (max_timeout_ms ),
534
602
},
535
603
)
536
604
except RequestSendFailed as e :
0 commit comments