Skip to content

Commit 092402d

Browse files
committed
fix race condition to avoid downloading the same artifact in multiple threads and trying to store it in the same location of the artifact cache
1 parent 6349005 commit 092402d

File tree

3 files changed

+61
-12
lines changed

3 files changed

+61
-12
lines changed

src/poetry/utils/cache.py

+16-6
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@
55
import json
66
import logging
77
import shutil
8+
import threading
89
import time
910

11+
from collections import defaultdict
1012
from pathlib import Path
1113
from typing import TYPE_CHECKING
1214
from typing import Any
@@ -188,6 +190,9 @@ def _deserialize(self, data_raw: bytes) -> CacheItem[T]:
188190
class ArtifactCache:
189191
def __init__(self, *, cache_dir: Path) -> None:
190192
self._cache_dir = cache_dir
193+
self._archive_locks: defaultdict[Path, threading.Lock] = defaultdict(
194+
threading.Lock
195+
)
191196

192197
def get_cache_directory_for_link(self, link: Link) -> Path:
193198
key_parts = {"url": link.url_without_fragment}
@@ -253,13 +258,18 @@ def get_cached_archive_for_link(
253258
cache_dir, strict=strict, filename=link.filename, env=env
254259
)
255260
if cached_archive is None and strict and download_func is not None:
256-
cache_dir.mkdir(parents=True, exist_ok=True)
257261
cached_archive = cache_dir / link.filename
258-
try:
259-
download_func(link.url, cached_archive)
260-
except BaseException:
261-
cached_archive.unlink(missing_ok=True)
262-
raise
262+
with self._archive_locks[cached_archive]:
263+
# Check again if the archive exists (under the lock) to avoid
264+
# duplicate downloads because it may have already been downloaded
265+
# by another thread in the meantime
266+
if not cached_archive.exists():
267+
cache_dir.mkdir(parents=True, exist_ok=True)
268+
try:
269+
download_func(link.url, cached_archive)
270+
except BaseException:
271+
cached_archive.unlink(missing_ok=True)
272+
raise
263273

264274
return cached_archive
265275

tests/installation/test_executor.py

+8-6
Original file line numberDiff line numberDiff line change
@@ -582,14 +582,16 @@ def test_executor_should_delete_incomplete_downloads(
582582
pool: RepositoryPool,
583583
mock_file_downloads: None,
584584
env: MockEnv,
585-
fixture_dir: FixtureDirGetter,
586585
) -> None:
587-
fixture = fixture_dir("distributions") / "demo-0.1.0-py2.py3-none-any.whl"
588-
destination_fixture = tmp_path / "tomlkit-0.5.3-py2.py3-none-any.whl"
589-
shutil.copyfile(str(fixture), str(destination_fixture))
586+
cached_archive = tmp_path / "tomlkit-0.5.3-py2.py3-none-any.whl"
587+
588+
def download_fail(*_: Any) -> None:
589+
cached_archive.touch() # broken archive
590+
raise Exception("Download error")
591+
590592
mocker.patch(
591593
"poetry.installation.executor.Executor._download_archive",
592-
side_effect=Exception("Download error"),
594+
side_effect=download_fail,
593595
)
594596
mocker.patch(
595597
"poetry.utils.cache.ArtifactCache._get_cached_archive",
@@ -607,7 +609,7 @@ def test_executor_should_delete_incomplete_downloads(
607609
with pytest.raises(Exception, match="Download error"):
608610
executor._download(Install(Package("tomlkit", "0.5.3")))
609611

610-
assert not destination_fixture.exists()
612+
assert not cached_archive.exists()
611613

612614

613615
def verify_installed_distribution(

tests/utils/test_cache.py

+37
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from __future__ import annotations
22

3+
import concurrent.futures
34
import shutil
5+
import traceback
46

57
from pathlib import Path
68
from typing import TYPE_CHECKING
@@ -322,6 +324,41 @@ def test_get_found_cached_archive_for_link(
322324
assert Path(cached) == archive
323325

324326

327+
def test_get_cached_archive_for_link_no_race_condition(
328+
tmp_path: Path, mocker: MockerFixture
329+
) -> None:
330+
cache = ArtifactCache(cache_dir=tmp_path)
331+
link = Link("https://files.python-poetry.org/demo-0.1.0.tar.gz")
332+
333+
def replace_file(_: str, dest: Path) -> None:
334+
dest.unlink(missing_ok=True)
335+
# write some data (so it takes a while) to provoke possible race conditions
336+
dest.write_text("a" * 2**20)
337+
338+
download_mock = mocker.Mock(side_effect=replace_file)
339+
340+
with concurrent.futures.ThreadPoolExecutor() as executor:
341+
tasks = []
342+
for _ in range(4):
343+
tasks.append(
344+
executor.submit(
345+
cache.get_cached_archive_for_link, # type: ignore[arg-type]
346+
link,
347+
strict=True,
348+
download_func=download_mock,
349+
)
350+
)
351+
concurrent.futures.wait(tasks)
352+
results = set()
353+
for task in tasks:
354+
try:
355+
results.add(task.result())
356+
except Exception:
357+
pytest.fail(traceback.format_exc())
358+
assert results == {cache.get_cache_directory_for_link(link) / link.filename}
359+
download_mock.assert_called_once()
360+
361+
325362
def test_get_cached_archive_for_git() -> None:
326363
"""Smoke test that checks that no assertion is raised."""
327364
cache = ArtifactCache(cache_dir=Path())

0 commit comments

Comments
 (0)