From bde3e295de56a8cf71de790a9fbfdf80fd025525 Mon Sep 17 00:00:00 2001 From: Adam Binford Date: Sun, 19 Jan 2025 15:44:51 -0500 Subject: [PATCH] Add get to CLI --- Cargo.lock | 1 + Cargo.toml | 1 + python/Cargo.toml | 1 + python/conftest.py | 13 +++- python/hdfs_native/__init__.py | 12 +++ python/hdfs_native/_internal.pyi | 6 ++ python/hdfs_native/cli.py | 121 ++++++++++++++++++++++++++++++- python/src/lib.rs | 40 +++++++++- python/tests/test_cli.py | 59 ++++++++++++--- python/tests/test_integration.py | 13 +++- rust/Cargo.toml | 2 +- 11 files changed, 249 insertions(+), 20 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 8235ca6..5e6f9bd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -810,6 +810,7 @@ version = "0.11.0" dependencies = [ "bytes", "env_logger", + "futures", "hdfs-native", "log", "pyo3", diff --git a/Cargo.toml b/Cargo.toml index b44de33..a88b274 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,6 +8,7 @@ resolver = "2" [workspace.dependencies] bytes = "1" env_logger = "0.11" +futures = "0.3" log = "0.4" thiserror = "2" tokio = "1" diff --git a/python/Cargo.toml b/python/Cargo.toml index 18c542a..36e7645 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -27,6 +27,7 @@ name = "hdfs_native._internal" [dependencies] bytes = { workspace = true } env_logger = { workspace = true } +futures = { workspace = true } hdfs-native = { path = "../rust" } log = { workspace = true } pyo3 = { version = "0.23", features = ["extension-module", "abi3", "abi3-py39"] } diff --git a/python/conftest.py b/python/conftest.py index bc2362f..5f4699b 100644 --- a/python/conftest.py +++ b/python/conftest.py @@ -43,9 +43,16 @@ def minidfs(): child.kill() -@pytest.fixture(scope="module") -def client(minidfs: str) -> Client: - return Client(minidfs) +@pytest.fixture +def client(minidfs: str): + client = Client(minidfs) + + try: + yield client + finally: + statuses = list(client.list_status("/")) + for status in statuses: + client.delete(status.path, True) @pytest.fixture(scope="module") diff --git a/python/hdfs_native/__init__.py b/python/hdfs_native/__init__.py index b9d16ba..102998a 100644 --- a/python/hdfs_native/__init__.py +++ b/python/hdfs_native/__init__.py @@ -39,6 +39,9 @@ def __init__(self, inner: "RawFileReader"): def __len__(self) -> int: return self.inner.file_length() + def __iter__(self) -> Iterator[bytes]: + return self.read_range_stream(0, len(self)) + def __enter__(self): # Don't need to do anything special here return self @@ -82,6 +85,15 @@ def read_range(self, offset: int, len: int) -> bytes: """Read `len` bytes from the file starting at `offset`. Doesn't affect the position in the file""" return self.inner.read_range(offset, len) + def read_range_stream(self, offset: int, len: int) -> Iterator[bytes]: + """ + Read `len` bytes from the file starting at `offset` as an iterator of bytes. Doesn't affect + the position in the file. + + This is the most efficient way to iteratively read a file. + """ + return self.inner.read_range_stream(offset, len) + def close(self) -> None: pass diff --git a/python/hdfs_native/_internal.pyi b/python/hdfs_native/_internal.pyi index 49174dd..e82ae7d 100644 --- a/python/hdfs_native/_internal.pyi +++ b/python/hdfs_native/_internal.pyi @@ -80,6 +80,12 @@ class RawFileReader: def read_range(self, offset: int, len: int) -> bytes: """Read `len` bytes from the file starting at `offset`. Doesn't affect the position in the file""" + def read_range_stream(self, offset: int, len: int) -> Iterator[bytes]: + """ + Read `len` bytes from the file starting at `offset` as an iterator of bytes. Doesn't affect + the position in the file. + """ + class RawFileWriter: def write(self, buf: Buffer) -> int: """Writes `buf` to the file""" diff --git a/python/hdfs_native/cli.py b/python/hdfs_native/cli.py index d7158ef..3fb602c 100644 --- a/python/hdfs_native/cli.py +++ b/python/hdfs_native/cli.py @@ -1,9 +1,11 @@ import functools import os import re +import shutil import sys from argparse import ArgumentParser, Namespace -from typing import List, Optional, Sequence +from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import List, Optional, Sequence, Tuple from urllib.parse import urlparse from hdfs_native import Client @@ -50,6 +52,42 @@ def _glob_path(client: Client, glob: str) -> List[str]: return [glob] +def _download_file( + client: Client, + remote_src: str, + local_dst: str, + force: bool = False, + preserve: bool = False, +) -> None: + if not force and os.path.exists(local_dst): + raise FileExistsError(f"{local_dst} already exists, use --force to overwrite") + + with client.read(remote_src) as remote_file: + with open(local_dst, "wb") as local_file: + for chunk in remote_file.read_range_stream(0, len(remote_file)): + local_file.write(chunk) + + if preserve: + status = client.get_file_info(remote_src) + os.utime( + local_dst, + (status.access_time / 1000, status.modification_time / 1000), + ) + os.chmod(local_dst, status.permission) + + +def _upload_file( + client: Client, + local_src: str, + remote_dst: str, + force: bool = False, + preserve: bool = False, +) -> None: + with open(local_src, "rb") as local_file: + with client.create(remote_dst) as remote_file: + shutil.copyfileobj(local_file, remote_file) + + def cat(args: Namespace): for src in args.src: client = _client_for_url(src) @@ -99,6 +137,48 @@ def chown(args: Namespace): client.set_owner(path, owner, group) +def get(args: Namespace): + paths: List[Tuple[Client, str]] = [] + + for url in args.src: + client = _client_for_url(url) + for path in _glob_path(client, _path_for_url(url)): + paths.append((client, path)) + + dst_is_dir = os.path.isdir(args.localdst) + + if len(paths) > 1 and not dst_is_dir: + raise ValueError("Destination must be directory when copying multiple files") + elif not dst_is_dir: + _download_file( + paths[0][0], + paths[0][1], + args.localdst, + force=args.force, + preserve=args.preserve, + ) + else: + with ThreadPoolExecutor(args.threads) as executor: + futures = [] + for client, path in paths: + filename = os.path.basename(path) + + futures.append( + executor.submit( + _download_file, + client, + path, + os.path.join(args.localdst, filename), + force=args.force, + preserve=args.preserve, + ) + ) + + # Iterate to raise any exceptions thrown + for f in as_completed(futures): + f.result() + + def mkdir(args: Namespace): create_parent = args.parent @@ -193,6 +273,45 @@ def main(in_args: Optional[Sequence[str]] = None): chown_parser.add_argument("path", nargs="+", help="File pattern to modify") chown_parser.set_defaults(func=chown) + get_parser = subparsers.add_parser( + "get", + aliases=["copyToLocal"], + help="Copy files to a local destination", + description="""Copy files matching a pattern to a local destination. + When copying multiple files, the destination must be a directory""", + ) + get_parser.add_argument( + "-p", + "--preserve", + action="store_true", + default=False, + help="Preserve timestamps and the mode", + ) + get_parser.add_argument( + "-f", + "--force", + action="store_true", + default=False, + help="Overwrite the destination if it already exists", + ) + get_parser.add_argument( + "-t", + "--threads", + type=int, + help="Number of threads to use", + default=1, + ) + get_parser.add_argument( + "src", + nargs="+", + help="Source patterns to copy", + ) + get_parser.add_argument( + "localdst", + help="Local destination to write to", + ) + get_parser.set_defaults(func=get) + mkdir_parser = subparsers.add_parser( "mkdir", help="Create a directory", diff --git a/python/src/lib.rs b/python/src/lib.rs index 7d948ac..6d09da5 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -1,6 +1,6 @@ use std::borrow::Cow; use std::collections::HashMap; -use std::sync::Arc; +use std::sync::{Arc, Mutex}; use ::hdfs_native::file::{FileReader, FileWriter}; use ::hdfs_native::WriteOptions; @@ -9,6 +9,8 @@ use ::hdfs_native::{ Client, }; use bytes::Bytes; +use futures::stream::BoxStream; +use futures::StreamExt; use hdfs_native::acl::{AclEntry, AclStatus}; use hdfs_native::client::ContentSummary; use pyo3::{exceptions::PyRuntimeError, prelude::*}; @@ -220,6 +222,32 @@ impl PyAclEntry { } } +#[pyclass(name = "FileReadStream")] +struct PyFileReadStream { + inner: Arc>>>, + rt: Arc, +} + +#[pymethods] +impl PyFileReadStream { + pub fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> { + slf + } + + fn __next__(slf: PyRefMut<'_, Self>) -> PyHdfsResult>> { + let inner = Arc::clone(&slf.inner); + let rt = Arc::clone(&slf.rt); + if let Some(result) = slf + .py() + .allow_threads(|| rt.block_on(inner.lock().unwrap().next())) + { + Ok(Some(Cow::from(result?.to_vec()))) + } else { + Ok(None) + } + } +} + #[pyclass] struct RawFileReader { inner: FileReader, @@ -258,6 +286,16 @@ impl RawFileReader { .to_vec(), )) } + + pub fn read_range_stream(&self, offset: usize, len: usize) -> PyFileReadStream { + let stream = Arc::new(Mutex::new( + self.inner.read_range_stream(offset, len).boxed(), + )); + PyFileReadStream { + inner: stream, + rt: Arc::clone(&self.rt), + } + } } #[pyclass(get_all, set_all, name = "WriteOptions")] diff --git a/python/tests/test_cli.py b/python/tests/test_cli.py index 6de04a9..597e038 100644 --- a/python/tests/test_cli.py +++ b/python/tests/test_cli.py @@ -1,5 +1,8 @@ import contextlib import io +import os +import stat +from tempfile import TemporaryDirectory import pytest @@ -27,9 +30,6 @@ def test_cat(client: Client): with pytest.raises(FileNotFoundError): cli_main(["cat", "/nonexistent"]) - client.delete("/testfile") - client.delete("/testfile2") - def test_chmod(client: Client): with pytest.raises(FileNotFoundError): @@ -52,8 +52,6 @@ def test_chmod(client: Client): with pytest.raises(ValueError): cli_main(["chmod", "2778", "/testfile"]) - client.delete("/testfile") - client.mkdirs("/testdir") client.create("/testdir/testfile").close() original_permission = client.get_file_info("/testdir/testfile").permission @@ -66,8 +64,6 @@ def test_chmod(client: Client): assert client.get_file_info("/testdir").permission == 0o700 assert client.get_file_info("/testdir/testfile").permission == 0o700 - client.delete("/testdir", True) - def test_chown(client: Client): with pytest.raises(FileNotFoundError): @@ -92,8 +88,6 @@ def test_chown(client: Client): assert status.owner == "newuser" assert status.group == "newgroup" - client.delete("/testfile") - client.mkdirs("/testdir") client.create("/testdir/testfile").close() file_status = client.get_file_info("/testdir/testfile") @@ -111,7 +105,50 @@ def test_chown(client: Client): assert status.owner == "testuser" assert status.group == "testgroup" - client.delete("/testdir", True) + +def test_get(client: Client): + data = b"0123456789" + + with pytest.raises(FileNotFoundError): + cli_main(["get", "/testfile", "testfile"]) + + with client.create("/testfile") as file: + file.write(data) + + status = client.get_file_info("/testfile") + + with TemporaryDirectory() as tmp_dir: + cli_main(["get", "/testfile", os.path.join(tmp_dir, "localfile")]) + with open(os.path.join(tmp_dir, "localfile"), "rb") as file: + assert file.read() == data + + cli_main(["get", "/testfile", tmp_dir]) + with open(os.path.join(tmp_dir, "testfile"), "rb") as file: + assert file.read() == data + + with pytest.raises(FileExistsError): + cli_main(["get", "/testfile", tmp_dir]) + + cli_main(["get", "-f", "-p", "/testfile", tmp_dir]) + st = os.stat(os.path.join(tmp_dir, "testfile")) + assert stat.S_IMODE(st.st_mode) == status.permission + assert int(st.st_atime * 1000) == status.access_time + assert int(st.st_mtime * 1000) == status.modification_time + + with client.create("/testfile2") as file: + file.write(data) + + with pytest.raises(ValueError): + cli_main(["get", "/testfile", "/testfile2", "notadir"]) + + with TemporaryDirectory() as tmp_dir: + cli_main(["get", "/testfile", "/testfile2", tmp_dir]) + + with open(os.path.join(tmp_dir, "testfile"), "rb") as file: + assert file.read() == data + + with open(os.path.join(tmp_dir, "testfile2"), "rb") as file: + assert file.read() == data def test_mkdir(client: Client): @@ -124,8 +161,6 @@ def test_mkdir(client: Client): cli_main(["mkdir", "-p", "/testdir/nested/dir"]) assert client.get_file_info("/testdir/nested/dir").isdir - client.delete("/testdir", True) - def test_mv(client: Client): client.create("/testfile").close() diff --git a/python/tests/test_integration.py b/python/tests/test_integration.py index 3291e52..71b9e1a 100644 --- a/python/tests/test_integration.py +++ b/python/tests/test_integration.py @@ -49,8 +49,17 @@ def test_integration(client: Client): with client.read("/testfile") as file: data = io.BytesIO(file.read()) - for i in range(0, 33 * 1024 * 1024): - assert data.read(4) == i.to_bytes(4, "big") + for i in range(0, 33 * 1024 * 1024): + assert data.read(4) == i.to_bytes(4, "big") + + data = io.BytesIO() + for chunk in file: + data.write(chunk) + + data.seek(0) + + for i in range(0, 33 * 1024 * 1024): + assert data.read(4) == i.to_bytes(4, "big") with client.read("/testfile") as file: # Skip first two ints diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 0d9b1a1..e320c83 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -22,7 +22,7 @@ crc = "3.2" ctr = "0.9" des = "0.8" dns-lookup = "2" -futures = "0.3" +futures = { workspace = true } g2p = "1" hex = "0.4" hmac = "0.12"