17
17
import os
18
18
import shutil
19
19
from io import BytesIO
20
- from typing import IO , TYPE_CHECKING , Dict , List , Optional , Set , Tuple
20
+ from typing import IO , TYPE_CHECKING , Any , Dict , List , Optional , Set , Tuple
21
21
22
22
from matrix_common .types .mxc_uri import MXCUri
23
23
32
32
NotFoundError ,
33
33
RequestSendFailed ,
34
34
SynapseError ,
35
+ cs_error ,
35
36
)
36
37
from synapse .config .repository import ThumbnailRequirement
38
+ from synapse .http .server import respond_with_json
37
39
from synapse .http .site import SynapseRequest
38
40
from synapse .logging .context import defer_to_thread
39
41
from synapse .media ._base import (
@@ -300,8 +302,62 @@ async def create_content(
300
302
301
303
return MXCUri (self .server_name , media_id )
302
304
305
+ def respond_not_yet_uploaded (self , request : SynapseRequest ) -> None :
306
+ respond_with_json (
307
+ request ,
308
+ 404 ,
309
+ cs_error ("Media has not been uploaded yet" , code = Codes .NOT_YET_UPLOADED ),
310
+ send_cors = True ,
311
+ )
312
+
313
+ async def get_local_media_info (
314
+ self , request : SynapseRequest , media_id : str , max_timeout_ms : int
315
+ ) -> Optional [Dict [str , Any ]]:
316
+ """Gets the info dictionary for given local media ID. If the media has
317
+ not been uploaded yet, this function will wait up to ``max_timeout_ms``
318
+ milliseconds for the media to be uploaded.
319
+ Args:
320
+ request: The incoming request.
321
+ media_id: The media ID of the content. (This is the same as
322
+ the file_id for local content.)
323
+ max_timeout_ms: the maximum number of milliseconds to wait for the
324
+ media to be uploaded.
325
+ Returns:
326
+ Either the info dictionary for the given local media ID or
327
+ ``None``. If ``None``, then no further processing is necessary as
328
+ this function will send the necessary JSON response.
329
+ """
330
+ wait_until = self .clock .time_msec () + max_timeout_ms
331
+ while True :
332
+ # Get the info for the media
333
+ media_info = await self .store .get_local_media (media_id )
334
+ if not media_info :
335
+ respond_404 (request )
336
+ return None
337
+
338
+ if media_info ["quarantined_by" ]:
339
+ logger .info ("Media is quarantined" )
340
+ respond_404 (request )
341
+ return None
342
+
343
+ # The file has been uploaded, so stop looping
344
+ if media_info .get ("media_length" ) is not None :
345
+ return media_info
346
+
347
+ if self .clock .time_msec () >= wait_until :
348
+ break
349
+
350
+ await self .clock .sleep (0.5 )
351
+
352
+ self .respond_not_yet_uploaded (request )
353
+ return None
354
+
303
355
async def get_local_media (
304
- self , request : SynapseRequest , media_id : str , name : Optional [str ]
356
+ self ,
357
+ request : SynapseRequest ,
358
+ media_id : str ,
359
+ name : Optional [str ],
360
+ max_timeout_ms : int ,
305
361
) -> None :
306
362
"""Responds to requests for local media, if exists, or returns 404.
307
363
@@ -311,13 +367,14 @@ async def get_local_media(
311
367
the file_id for local content.)
312
368
name: Optional name that, if specified, will be used as
313
369
the filename in the Content-Disposition header of the response.
370
+ max_timeout_ms: the maximum number of milliseconds to wait for the
371
+ media to be uploaded.
314
372
315
373
Returns:
316
374
Resolves once a response has successfully been written to request
317
375
"""
318
- media_info = await self .store .get_local_media (media_id )
319
- if not media_info or media_info ["quarantined_by" ]:
320
- respond_404 (request )
376
+ media_info = await self .get_local_media_info (request , media_id , max_timeout_ms )
377
+ if not media_info :
321
378
return
322
379
323
380
self .mark_recently_accessed (None , media_id )
@@ -342,6 +399,7 @@ async def get_remote_media(
342
399
server_name : str ,
343
400
media_id : str ,
344
401
name : Optional [str ],
402
+ max_timeout_ms : int ,
345
403
) -> None :
346
404
"""Respond to requests for remote media.
347
405
@@ -351,6 +409,8 @@ async def get_remote_media(
351
409
media_id: The media ID of the content (as defined by the remote server).
352
410
name: Optional name that, if specified, will be used as
353
411
the filename in the Content-Disposition header of the response.
412
+ max_timeout_ms: the maximum number of milliseconds to wait for the
413
+ media to be uploaded.
354
414
355
415
Returns:
356
416
Resolves once a response has successfully been written to request
@@ -368,27 +428,31 @@ async def get_remote_media(
368
428
key = (server_name , media_id )
369
429
async with self .remote_media_linearizer .queue (key ):
370
430
responder , media_info = await self ._get_remote_media_impl (
371
- server_name , media_id
431
+ server_name , media_id , max_timeout_ms
372
432
)
373
433
374
434
# We deliberately stream the file outside the lock
375
- if responder :
435
+ if responder and media_info :
376
436
media_type = media_info ["media_type" ]
377
437
media_length = media_info ["media_length" ]
378
438
upload_name = name if name else media_info ["upload_name" ]
379
439
await respond_with_responder (
380
440
request , responder , media_type , media_length , upload_name
381
441
)
382
442
else :
383
- respond_404 (request )
443
+ self . respond_not_yet_uploaded (request )
384
444
385
- async def get_remote_media_info (self , server_name : str , media_id : str ) -> dict :
445
+ async def get_remote_media_info (
446
+ self , server_name : str , media_id : str , max_timeout_ms : int
447
+ ) -> dict :
386
448
"""Gets the media info associated with the remote file, downloading
387
449
if necessary.
388
450
389
451
Args:
390
452
server_name: Remote server_name where the media originated.
391
453
media_id: The media ID of the content (as defined by the remote server).
454
+ max_timeout_ms: the maximum number of milliseconds to wait for the
455
+ media to be uploaded.
392
456
393
457
Returns:
394
458
The media info of the file
@@ -404,7 +468,7 @@ async def get_remote_media_info(self, server_name: str, media_id: str) -> dict:
404
468
key = (server_name , media_id )
405
469
async with self .remote_media_linearizer .queue (key ):
406
470
responder , media_info = await self ._get_remote_media_impl (
407
- server_name , media_id
471
+ server_name , media_id , max_timeout_ms
408
472
)
409
473
410
474
# Ensure we actually use the responder so that it releases resources
@@ -415,7 +479,7 @@ async def get_remote_media_info(self, server_name: str, media_id: str) -> dict:
415
479
return media_info
416
480
417
481
async def _get_remote_media_impl (
418
- self , server_name : str , media_id : str
482
+ self , server_name : str , media_id : str , max_timeout_ms : int
419
483
) -> Tuple [Optional [Responder ], dict ]:
420
484
"""Looks for media in local cache, if not there then attempt to
421
485
download from remote server.
@@ -424,6 +488,8 @@ async def _get_remote_media_impl(
424
488
server_name: Remote server_name where the media originated.
425
489
media_id: The media ID of the content (as defined by the
426
490
remote server).
491
+ max_timeout_ms: the maximum number of milliseconds to wait for the
492
+ media to be uploaded.
427
493
428
494
Returns:
429
495
A tuple of responder and the media info of the file.
@@ -454,8 +520,7 @@ async def _get_remote_media_impl(
454
520
455
521
try :
456
522
media_info = await self ._download_remote_file (
457
- server_name ,
458
- media_id ,
523
+ server_name , media_id , max_timeout_ms
459
524
)
460
525
except SynapseError :
461
526
raise
@@ -488,6 +553,7 @@ async def _download_remote_file(
488
553
self ,
489
554
server_name : str ,
490
555
media_id : str ,
556
+ max_timeout_ms : int ,
491
557
) -> dict :
492
558
"""Attempt to download the remote file from the given server name,
493
559
using the given file_id as the local id.
@@ -497,7 +563,8 @@ async def _download_remote_file(
497
563
media_id: The media ID of the content (as defined by the
498
564
remote server). This is different than the file_id, which is
499
565
locally generated.
500
- file_id: Local file ID
566
+ max_timeout_ms: the maximum number of milliseconds to wait for the
567
+ media to be uploaded.
501
568
502
569
Returns:
503
570
The media info of the file.
@@ -521,7 +588,8 @@ async def _download_remote_file(
521
588
# tell the remote server to 404 if it doesn't
522
589
# recognise the server_name, to make sure we don't
523
590
# end up with a routing loop.
524
- "allow_remote" : "false"
591
+ "allow_remote" : "false" ,
592
+ "timeout_ms" : str (max_timeout_ms ),
525
593
},
526
594
)
527
595
except RequestSendFailed as e :
0 commit comments