diff --git a/python/ray/_private/ray_perf.py b/python/ray/_private/ray_perf.py index 330527957d675..d1c1e7e1abf6c 100644 --- a/python/ray/_private/ray_perf.py +++ b/python/ray/_private/ray_perf.py @@ -9,6 +9,7 @@ import ray import ray.experimental.channel as ray_channel +from ray.dag import InputNode, MultiOutputNode logger = logging.getLogger(__name__) @@ -369,6 +370,58 @@ def read(self, chans): for reader in readers: ray.kill(reader) + # Tests for compiled DAGs. + + def _exec(dag): + output_channel = dag.execute(b"x") + output_channel.begin_read() + output_channel.end_read() + + def _exec_multi_output(dag): + output_channels = dag.execute(b"x") + for output_channel in output_channels: + output_channel.begin_read() + for output_channel in output_channels: + output_channel.end_read() + + @ray.remote + class Actor: + def echo(self, x): + return x + + a = Actor.remote() + with InputNode() as inp: + dag = a.echo.bind(inp) + + results += timeit("single-actor DAG calls", lambda: ray.get(dag.execute(b"x"))) + dag = dag.experimental_compile() + results += timeit("compiled single-actor DAG calls", lambda: _exec(dag)) + + del a + n_cpu = multiprocessing.cpu_count() // 2 + actors = [Actor.remote() for _ in range(n_cpu)] + with InputNode() as inp: + dag = MultiOutputNode([a.echo.bind(inp) for a in actors]) + results += timeit( + "scatter-gather DAG calls, n={n_cpu} actors", lambda: ray.get(dag.execute(b"x")) + ) + dag = dag.experimental_compile() + results += timeit( + f"compiled scatter-gather DAG calls, n={n_cpu} actors", + lambda: _exec_multi_output(dag), + ) + + actors = [Actor.remote() for _ in range(n_cpu)] + with InputNode() as inp: + dag = inp + for a in actors: + dag = a.echo.bind(dag) + results += timeit( + f"chain DAG calls, n={n_cpu} actors", lambda: ray.get(dag.execute(b"x")) + ) + dag = dag.experimental_compile() + results += timeit(f"compiled chain DAG calls, n={n_cpu} actors", lambda: _exec(dag)) + ray.shutdown() ############################ diff --git a/python/ray/dag/BUILD b/python/ray/dag/BUILD index bbb0fb48564b0..4cb5e5634a13b 100644 --- a/python/ray/dag/BUILD +++ b/python/ray/dag/BUILD @@ -68,3 +68,11 @@ py_test( tags = ["exclusive", "team:core", "ray_dag_tests"], deps = [":dag_lib"], ) + +py_test( + name = "test_accelerated_dag", + size = "medium", + srcs = dag_tests_srcs, + tags = ["exclusive", "team:core", "ray_dag_tests"], + deps = [":dag_lib"], +) diff --git a/python/ray/dag/class_node.py b/python/ray/dag/class_node.py index e4a74f83c7495..64f4615e1aada 100644 --- a/python/ray/dag/class_node.py +++ b/python/ray/dag/class_node.py @@ -202,3 +202,12 @@ def __str__(self) -> str: def get_method_name(self) -> str: return self._method_name + + def _get_remote_method(self, method_name): + method_body = getattr(self._parent_class_node, method_name) + return method_body + + def _get_actor_handle(self) -> Optional["ray.actor.ActorHandle"]: + if not isinstance(self._parent_class_node, ray.actor.ActorHandle): + return None + return self._parent_class_node diff --git a/python/ray/dag/compiled_dag_node.py b/python/ray/dag/compiled_dag_node.py new file mode 100644 index 0000000000000..4e86b5686f531 --- /dev/null +++ b/python/ray/dag/compiled_dag_node.py @@ -0,0 +1,416 @@ +import logging +from typing import Any, Dict, List, Tuple, Union, Optional + +from collections import defaultdict + +import ray +import ray.experimental.channel as ray_channel +from ray.util.annotations import DeveloperAPI + + +MAX_BUFFER_SIZE = int(100 * 1e6) # 100MB + +ChannelType = "ray.experimental.channel.Channel" + +logger = logging.getLogger(__name__) + + +@DeveloperAPI +def do_allocate_channel( + self, buffer_size_bytes: int, num_readers: int = 1 +) -> ChannelType: + """Generic actor method to allocate an output channel. + + Args: + buffer_size_bytes: The maximum size of messages in the channel. + num_readers: The number of readers per message. + + Returns: + The allocated channel. + """ + self._output_channel = ray_channel.Channel(buffer_size_bytes, num_readers) + return self._output_channel + + +@DeveloperAPI +def do_exec_compiled_task( + self, + inputs: List[Union[Any, ChannelType]], + actor_method_name: str, +) -> None: + """Generic actor method to begin executing a compiled DAG. This runs an + infinite loop to repeatedly read input channel(s), execute the given + method, and write output channel(s). It only exits if the actor dies or an + exception is thrown. + + Args: + inputs: The arguments to the task. Arguments that are not Channels will + get passed through to the actor method. If the argument is a channel, + it will be replaced by the value read from the channel before the + method execute. + actor_method_name: The name of the actual actor method to execute in + the loop. + """ + try: + method = getattr(self, actor_method_name) + + resolved_inputs = [] + input_channel_idxs = [] + # Add placeholders for input channels. + for idx, inp in enumerate(inputs): + if isinstance(inp, ray_channel.Channel): + input_channel_idxs.append((idx, inp)) + resolved_inputs.append(None) + else: + resolved_inputs.append(inp) + + while True: + for idx, channel in input_channel_idxs: + resolved_inputs[idx] = channel.begin_read() + + output_val = method(*resolved_inputs) + + self._output_channel.write(output_val) + for _, channel in input_channel_idxs: + channel.end_read() + + except Exception as e: + logging.warn(f"Compiled DAG task aborted with exception: {e}") + raise + + +@DeveloperAPI +class CompiledTask: + """Wraps the normal Ray DAGNode with some metadata.""" + + def __init__(self, idx: int, dag_node: "ray.dag.DAGNode"): + """ + Args: + idx: A unique index into the original DAG. + dag_node: The original DAG node created by the user. + """ + self.idx = idx + self.dag_node = dag_node + + self.downstream_node_idxs = set() + self.output_channel = None + + @property + def args(self) -> Tuple[Any]: + return self.dag_node.get_args() + + @property + def num_readers(self) -> int: + return len(self.downstream_node_idxs) + + def __str__(self) -> str: + return f""" +Node: {self.dag_node} +Arguments: {self.args} +Output: {self.output_channel} +""" + + +@DeveloperAPI +class CompiledDAG: + """Experimental class for accelerated execution. + + This class should not be called directly. Instead, create + a ray.dag and call experimental_compile(). + + See REP https://github.com/ray-project/enhancements/pull/48 for more + information. + """ + + def __init__(self, buffer_size_bytes: Optional[int]): + self._buffer_size_bytes: Optional[int] = buffer_size_bytes + if self._buffer_size_bytes is None: + self._buffer_size_bytes = MAX_BUFFER_SIZE + if not isinstance(self._buffer_size_bytes, int) or self._buffer_size_bytes <= 0: + raise ValueError( + "`buffer_size_bytes` must be a positive integer, found " + f"{self._buffer_size_bytes}" + ) + + # idx -> CompiledTask. + self.idx_to_task: Dict[int, "CompiledTask"] = {} + # DAGNode -> idx. + self.dag_node_to_idx: Dict["ray.dag.DAGNode", int] = {} + # idx counter. + self.counter: int = 0 + + # Attributes that are set during preprocessing. + # Preprocessing identifies the input node and output node. + self.input_task_idx: Optional[int] = None + self.output_task_idx: Optional[int] = None + self.has_single_output: bool = False + self.actor_task_count: Dict["ray._raylet.ActorID", int] = defaultdict(int) + + # Cached attributes that are set during compilation. + self.dag_input_channel: Optional[ChannelType] = None + self.dag_output_channels: Optional[ChannelType] = None + # ObjectRef for each worker's task. The task is an infinite loop that + # repeatedly executes the method specified in the DAG. + self.worker_task_refs: List["ray.ObjectRef"] = [] + + def _add_node(self, node: "ray.dag.DAGNode") -> None: + idx = self.counter + self.idx_to_task[idx] = CompiledTask(idx, node) + self.dag_node_to_idx[node] = idx + self.counter += 1 + + def _preprocess(self) -> None: + """Before compiling, preprocess the DAG to build an index from task to + upstream and downstream tasks, and to set the input and output node(s) + of the DAG. + + This function is idempotent. + """ + from ray.dag import ( + DAGNode, + ClassMethodNode, + FunctionNode, + InputAttributeNode, + InputNode, + MultiOutputNode, + ) + + self.input_task_idx, self.output_task_idx = None, None + self.actor_task_count.clear() + + # For each task node, set its upstream and downstream task nodes. + for idx, task in self.idx_to_task.items(): + dag_node = task.dag_node + if not ( + isinstance(dag_node, InputNode) + or isinstance(dag_node, MultiOutputNode) + or isinstance(dag_node, ClassMethodNode) + ): + if isinstance(dag_node, InputAttributeNode): + # TODO(swang): Support multi args. + raise NotImplementedError( + "Compiled DAGs currently do not support kwargs or " + "multiple args for InputNode" + ) + elif isinstance(dag_node, FunctionNode): + # TODO(swang): Support non-actor tasks. + raise NotImplementedError( + "Compiled DAGs currently only support actor method nodes" + ) + else: + raise ValueError( + f"Found unsupported node of type {type(task.dag_node)}" + ) + + if isinstance(dag_node, ClassMethodNode): + actor_handle = dag_node._get_actor_handle() + if actor_handle is None: + raise ValueError( + "Compiled DAGs can only bind methods to an actor " + "that is already created with Actor.remote()" + ) + self.actor_task_count[actor_handle._actor_id] += 1 + + for arg in task.args: + if isinstance(arg, DAGNode): + arg_idx = self.dag_node_to_idx[arg] + self.idx_to_task[arg_idx].downstream_node_idxs.add(idx) + + for actor_id, task_count in self.actor_task_count.items(): + if task_count > 1: + raise NotImplementedError( + "Compiled DAGs can contain at most one task per actor handle. " + f"Actor with ID {actor_id} appears {task_count}x." + ) + + # Find the input node to the DAG. + for idx, task in self.idx_to_task.items(): + if isinstance(task.dag_node, InputNode): + assert self.input_task_idx is None, "more than one InputNode found" + self.input_task_idx = idx + # TODO: Support no-input DAGs (use an empty object to signal). + if self.input_task_idx is None: + raise NotImplementedError( + "Compiled DAGs currently require exactly one InputNode" + ) + + # Find the (multi-)output node to the DAG. + for idx, task in self.idx_to_task.items(): + if len(task.downstream_node_idxs) == 0: + assert self.output_task_idx is None, "More than one output node found" + self.output_task_idx = idx + + assert self.output_task_idx is not None + output_node = self.idx_to_task[self.output_task_idx].dag_node + # Add an MultiOutputNode to the end of the DAG if it's not already there. + if not isinstance(output_node, MultiOutputNode): + self.has_single_output = True + output_node = MultiOutputNode([output_node]) + self._add_node(output_node) + self.output_task_idx = self.dag_node_to_idx[output_node] + # Preprocess one more time so that we have the right output node + # now. + self._preprocess() + + def _get_or_compile( + self, + ) -> Tuple[ChannelType, Union[ChannelType, List[ChannelType]]]: + """Compile an execution path. This allocates channels for adjacent + tasks to send/receive values. An infinite task is submitted to each + actor in the DAG that repeatedly receives from input channel(s) and + sends to output channel(s). + + This function is idempotent and will cache the previously allocated + channels. + + Returns: + A tuple of (input channel, output channel(s)). The input channel + that should be used by the caller to submit a DAG execution. The + output channel(s) should be read by the caller to get the DAG + output. + """ + from ray.dag import DAGNode, InputNode, MultiOutputNode, ClassMethodNode + + if self.input_task_idx is None: + self._preprocess() + + if self.dag_input_channel is not None: + assert self.dag_output_channels is not None + # Driver should ray.put on input, ray.get/release on output + return ( + self.dag_input_channel, + self.dag_output_channels, + ) + + queue = [self.input_task_idx] + visited = set() + # Create output buffers + while queue: + cur_idx = queue.pop(0) + if cur_idx in visited: + continue + visited.add(cur_idx) + + task = self.idx_to_task[cur_idx] + # Create an output buffer on the actor. + assert task.output_channel is None + if isinstance(task.dag_node, ClassMethodNode): + fn = task.dag_node._get_remote_method("__ray_call__") + task.output_channel = ray.get( + fn.remote( + do_allocate_channel, + buffer_size_bytes=self._buffer_size_bytes, + num_readers=task.num_readers, + ) + ) + elif isinstance(task.dag_node, InputNode): + task.output_channel = ray_channel.Channel( + buffer_size_bytes=self._buffer_size_bytes, + num_readers=task.num_readers, + ) + else: + assert isinstance(task.dag_node, MultiOutputNode) + + for idx in task.downstream_node_idxs: + queue.append(idx) + + for node_idx, task in self.idx_to_task.items(): + if node_idx == self.input_task_idx: + # We don't need to assign an actual task for the input node. + continue + + if node_idx == self.output_task_idx: + # We don't need to assign an actual task for the input node. + continue + + resolved_args = [] + has_at_least_one_channel_input = False + for arg in task.args: + if isinstance(arg, DAGNode): + arg_idx = self.dag_node_to_idx[arg] + arg_channel = self.idx_to_task[arg_idx].output_channel + assert arg_channel is not None + resolved_args.append(arg_channel) + has_at_least_one_channel_input = True + else: + resolved_args.append(arg) + # TODO: Support no-input DAGs (use an empty object to signal). + if not has_at_least_one_channel_input: + raise ValueError( + "Compiled DAGs require each task to take a " + "ray.dag.InputNode or at least one other DAGNode as an " + "input" + ) + + # Assign the task with the correct input and output buffers. + worker_fn = task.dag_node._get_remote_method("__ray_call__") + self.worker_task_refs.append( + worker_fn.remote( + do_exec_compiled_task, + resolved_args, + task.dag_node.get_method_name(), + ) + ) + + self.dag_input_channel = self.idx_to_task[self.input_task_idx].output_channel + + self.dag_output_channels = [] + for output in self.idx_to_task[self.output_task_idx].args: + assert isinstance(output, DAGNode) + output_idx = self.dag_node_to_idx[output] + self.dag_output_channels.append(self.idx_to_task[output_idx].output_channel) + + assert self.dag_input_channel + assert self.dag_output_channels + assert [ + output_channel is not None for output_channel in self.dag_output_channels + ] + # If no MultiOutputNode was specified during the DAG creation, there is only + # one output. Return a single output channel instead of a list of + # channels. + if self.has_single_output: + assert len(self.dag_output_channels) == 1 + self.dag_output_channels = self.dag_output_channels[0] + + # Driver should ray.put on input, ray.get/release on output + return (self.dag_input_channel, self.dag_output_channels) + + def execute( + self, + *args, + **kwargs, + ) -> Union[ChannelType, List[ChannelType]]: + """Execute this DAG using the compiled execution path. + + Args: + args: Args to the InputNode. + kwargs: Kwargs to the InputNode. Not supported yet. + + Returns: + A list of Channels that can be used to read the DAG result. + """ + # These errors should already be caught during compilation, but just in + # case. + if len(args) != 1: + raise NotImplementedError("Compiled DAGs support exactly one InputNode arg") + if len(kwargs) != 0: + raise NotImplementedError("Compiled DAGs do not support kwargs") + + input_channel, output_channels = self._get_or_compile() + input_channel.write(args[0]) + return output_channels + + +@DeveloperAPI +def build_compiled_dag_from_ray_dag( + dag: "ray.dag.DAGNode", buffer_size_bytes: Optional[int] +) -> "CompiledDAG": + compiled_dag = CompiledDAG(buffer_size_bytes) + + def _build_compiled_dag(node): + compiled_dag._add_node(node) + return node + + dag.apply_recursive(_build_compiled_dag) + compiled_dag._get_or_compile() + return compiled_dag diff --git a/python/ray/dag/dag_node.py b/python/ray/dag/dag_node.py index aef2f31e4f63e..5acfbd02f92cf 100644 --- a/python/ray/dag/dag_node.py +++ b/python/ray/dag/dag_node.py @@ -16,6 +16,8 @@ import uuid import asyncio +from ray.dag.compiled_dag_node import build_compiled_dag_from_ray_dag + T = TypeVar("T") @@ -103,6 +105,20 @@ async def get_object_refs_from_last_execute(self) -> Dict[str, Any]: def clear_cache(self): self.cache_from_last_execute = {} + def experimental_compile( + self, buffer_size_bytes: Optional[int] = None + ) -> "ray.dag.CompiledDAG": + """Compile an accelerated execution path for this DAG. + + Args: + buffer_size_bytes: The maximum size of messages that can be passed + between tasks in the DAG. + + Returns: + A compiled DAG. + """ + return build_compiled_dag_from_ray_dag(self, buffer_size_bytes) + def execute( self, *args, _ray_cache_refs: bool = False, **kwargs ) -> Union[ray.ObjectRef, "ray.actor.ActorHandle"]: diff --git a/python/ray/dag/tests/test_accelerated_dag.py b/python/ray/dag/tests/test_accelerated_dag.py new file mode 100644 index 0000000000000..b293ee8940267 --- /dev/null +++ b/python/ray/dag/tests/test_accelerated_dag.py @@ -0,0 +1,183 @@ +# coding: utf-8 +import logging +import os +import sys + +import pytest + +import ray +import ray.cluster_utils +from ray.dag import InputNode, MultiOutputNode +from ray.tests.conftest import * # noqa +from ray._private.test_utils import wait_for_condition + + +logger = logging.getLogger(__name__) + +if sys.platform != "linux": + pytest.skip("Skipping, requires Linux.", allow_module_level=True) + + +@ray.remote +class Actor: + def __init__(self, init_value): + print("__init__ PID", os.getpid()) + self.i = init_value + + def inc(self, x): + self.i += x + return self.i + + def append_to(self, lst): + lst.append(self.i) + return lst + + def inc_two(self, x, y): + self.i += x + self.i += y + return self.i + + +def test_basic(ray_start_regular): + a = Actor.remote(0) + with InputNode() as i: + dag = a.inc.bind(i) + + compiled_dag = dag.experimental_compile() + + for i in range(3): + output_channel = compiled_dag.execute(1) + # TODO(swang): Replace with fake ObjectRef. + result = output_channel.begin_read() + assert result == i + 1 + output_channel.end_read() + + +def test_regular_args(ray_start_regular): + # Test passing regular args to .bind in addition to DAGNode args. + a = Actor.remote(0) + with InputNode() as i: + dag = a.inc_two.bind(2, i) + + compiled_dag = dag.experimental_compile() + + for i in range(3): + output_channel = compiled_dag.execute(1) + # TODO(swang): Replace with fake ObjectRef. + result = output_channel.begin_read() + assert result == (i + 1) * 3 + output_channel.end_read() + + +@pytest.mark.parametrize("num_actors", [1, 4]) +def test_scatter_gather_dag(ray_start_regular, num_actors): + actors = [Actor.remote(0) for _ in range(num_actors)] + with InputNode() as i: + out = [a.inc.bind(i) for a in actors] + dag = MultiOutputNode(out) + + compiled_dag = dag.experimental_compile() + + for i in range(3): + output_channels = compiled_dag.execute(1) + # TODO(swang): Replace with fake ObjectRef. + results = [chan.begin_read() for chan in output_channels] + assert results == [i + 1] * num_actors + for chan in output_channels: + chan.end_read() + + +@pytest.mark.parametrize("num_actors", [1, 4]) +def test_chain_dag(ray_start_regular, num_actors): + actors = [Actor.remote(i) for i in range(num_actors)] + with InputNode() as inp: + dag = inp + for a in actors: + dag = a.append_to.bind(dag) + + compiled_dag = dag.experimental_compile() + + for i in range(3): + output_channel = compiled_dag.execute([]) + # TODO(swang): Replace with fake ObjectRef. + result = output_channel.begin_read() + assert result == list(range(num_actors)) + output_channel.end_read() + + +def test_dag_exception(ray_start_regular, capsys): + a = Actor.remote(0) + with InputNode() as inp: + dag = a.inc.bind(inp) + + compiled_dag = dag.experimental_compile() + compiled_dag.execute("hello") + wait_for_condition( + lambda: "Compiled DAG task aborted with exception" in capsys.readouterr().err + ) + + +def test_dag_errors(ray_start_regular): + a = Actor.remote(0) + dag = a.inc.bind(1) + with pytest.raises( + NotImplementedError, + match="Compiled DAGs currently require exactly one InputNode", + ): + dag.experimental_compile() + + a2 = Actor.remote(0) + with InputNode() as inp: + dag = MultiOutputNode([a.inc.bind(inp), a2.inc.bind(1)]) + with pytest.raises( + ValueError, + match="Compiled DAGs require each task to take a ray.dag.InputNode or " + "at least one other DAGNode as an input", + ): + dag.experimental_compile() + + @ray.remote + def f(x): + return x + + with InputNode() as inp: + dag = f.bind(inp) + with pytest.raises( + NotImplementedError, + match="Compiled DAGs currently only support actor method nodes", + ): + dag.experimental_compile() + + with InputNode() as inp: + dag = a.inc.bind(inp) + dag = a.inc.bind(dag) + with pytest.raises( + NotImplementedError, + match="Compiled DAGs can contain at most one task per actor handle.", + ): + dag.experimental_compile() + + with InputNode() as inp: + dag = a.inc_two.bind(inp[0], inp[1]) + with pytest.raises( + NotImplementedError, + match="Compiled DAGs currently do not support kwargs or multiple args " + "for InputNode", + ): + dag.experimental_compile() + + with InputNode() as inp: + dag = a.inc_two.bind(inp.x, inp.y) + with pytest.raises( + NotImplementedError, + match="Compiled DAGs currently do not support kwargs or multiple args " + "for InputNode", + ): + dag.experimental_compile() + + +if __name__ == "__main__": + if os.environ.get("PARALLEL_CI"): + sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__])) + else: + sys.exit(pytest.main(["-sv", __file__])) diff --git a/python/ray/experimental/channel.py b/python/ray/experimental/channel.py index e8ef9ad085f79..d8443f6a9db17 100644 --- a/python/ray/experimental/channel.py +++ b/python/ray/experimental/channel.py @@ -12,7 +12,7 @@ def _create_channel_ref( - buffer_size: int, + buffer_size_bytes: int, ) -> "ray.ObjectRef": """ Create a channel that can be read and written by co-located Ray processes. @@ -21,7 +21,7 @@ def _create_channel_ref( read the previous value. Only the channel creator may write to the channel. Args: - buffer_size: The number of bytes to allocate for the object data and + buffer_size_bytes: The number of bytes to allocate for the object data and metadata. Writes to the channel must produce serialized data and metadata less than or equal to this value. Returns: @@ -30,7 +30,7 @@ def _create_channel_ref( worker = ray._private.worker.global_worker worker.check_connected() - value = b"0" * buffer_size + value = b"0" * buffer_size_bytes try: object_ref = worker.put_object( @@ -52,7 +52,12 @@ class Channel: ray.wait. """ - def __init__(self, buffer_size: Optional[int] = None, num_readers: int = 1): + def __init__( + self, + buffer_size_bytes: Optional[int] = None, + num_readers: int = 1, + _base_ref: Optional["ray.ObjectRef"] = None, + ): """ Create a channel that can be read and written by co-located Ray processes. @@ -60,16 +65,25 @@ def __init__(self, buffer_size: Optional[int] = None, num_readers: int = 1): so the writer will block until reader(s) have read the previous value. Args: - buffer_size: The number of bytes to allocate for the object data and + buffer_size_bytes: The number of bytes to allocate for the object data and metadata. Writes to the channel must produce serialized data and metadata less than or equal to this value. Returns: Channel: A wrapper around ray.ObjectRef. """ - if buffer_size is None: - self._base_ref = None + if buffer_size_bytes is None: + if _base_ref is None: + raise ValueError( + "One of `buffer_size_bytes` or `_base_ref` must be provided" + ) + self._base_ref = _base_ref else: - self._base_ref = _create_channel_ref(buffer_size) + if not isinstance(buffer_size_bytes, int): + raise ValueError("buffer_size_bytes must be an integer") + self._base_ref = _create_channel_ref(buffer_size_bytes) + + if not isinstance(num_readers, int): + raise ValueError("num_readers must be an integer") self._num_readers = num_readers self._worker = ray._private.worker.global_worker @@ -77,9 +91,7 @@ def __init__(self, buffer_size: Optional[int] = None, num_readers: int = 1): @staticmethod def _from_base_ref(base_ref: "ray.ObjectRef", num_readers: int) -> "Channel": - chan = Channel(num_readers=num_readers) - chan._base_ref = base_ref - return chan + return Channel(num_readers=num_readers, _base_ref=base_ref) def __reduce__(self): return self._from_base_ref, (self._base_ref, self._num_readers)