Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Nccl ops correction changes #3387

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
Open

Nccl ops correction changes #3387

wants to merge 12 commits into from

Conversation

apbose
Copy link
Collaborator

@apbose apbose commented Feb 10, 2025

No description provided.

@apbose apbose requested a review from narendasan February 10, 2025 19:16
@github-actions github-actions bot added component: lowering Issues re: The lowering / preprocessing passes component: conversion Issues re: Conversion stage component: api [Python] Issues re: Python API component: runtime component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths component: torch_compile labels Feb 10, 2025
Comment on lines 68 to 76
# transpose key deleted since not desirable to lower it to permute
to_delete = {
key
for key in settings_aot_autograd["decompositions"]
if "detach" in key._name or "transpose" in key._name
}

for key in to_delete:
del settings_aot_autograd["decompositions"][key]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we have a remove_detach lowering pass. Can that help here ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is not helping here, because the graph explicitly does not have detach ops to remove the nodes. Instead it encounters this in https://github.com/pytorch/pytorch/blob/main/torch/_decomp/decompositions.py#L2153. This might be due to the %hook_result_3 = call_function[target=torch._dynamo.variables.tensor.prim_to_local](args = (%outputs_3,), kwargs = {}) where it gets the DTensor to local tensor and might need to detach the distributed operation.

if aten.detach in settings_aot_autograd["decompositions"]:
del settings_aot_autograd["decompositions"][aten.detach]
# transpose key deleted since not desirable to lower it to permute
to_delete = {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will this apply to all cases not just NCCL?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You mean in the non distributed example? I am not sure about that answer, I added this for the llama3 example since I was issues in the model lowering and it was generating graph breaks at the wrong part, leading to complex input error. It can be added to all cases in case if we want to not lower transpose to permute.

Copy link
Collaborator Author

@apbose apbose Mar 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Regarding the discussion

  1. detach: remove_detach does not help since the graph explicitly does not have detach ops to remove the nodes. Instead it encounters this in https://github.com/pytorch/pytorch/blob/main/torch/_decomp/decompositions.py#L2153. This might be due to the %hook_result_3 = call_function[target=torch._dynamo.variables.tensor.prim_to_local](args = (%outputs_3,), kwargs = {}) where it moves the DTensor to local tensor and needs to detach the distributed operation. This is in tensor_parallel_simple_example.py

  2. transpose: transpose is more for tackling the tensor_parallel_llama3.py. The broad modification I did to handle the complex nos, are:
    a. Modifying the placeholder node shape and type
    b. Modifying the inputs to the reshape and slice ops with complex inputs
    c. Replace the complex tensorrt mul
    I see that if I decompose transpose to permute, the graph in gpu_0, has output of complex tensor mul as complex64 or complex 128 which goes as input to acc_* graph causing complex input error. Transpose being in the graph helps in it handling the complex input in gpu_0 graph partition only.

Regarding the discussion, would removal of transpose from decomposition affect the result- I would think no, since this is not removal of op like detach, but instead it is just that we do not lower it to permute. But you could provide me more insights if not

  1. Also if we would want it to be model specific and not apply to all models, I think it can be the next step to include it either in the UI or something like we do in the non distributed models, including in the torch_disabled_decomposition dictionary which applies to all model. Specifying UI, since we are talking about model specific disabled decomposition. As of now since this part of code applies to only distributed model, it should be good to go.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, so the code needs to be restructured to make it clear that is is not the main codepath.

settings, engine_cache = parse_dynamo_kwargs(kwargs)
    if settings.use_aot_joint_export:
        return _pretraced_backend(gm, sample_inputs, settings, engine_cache)
    logger.debug("Wrapping the backend with aot_autograd\n")
    _pretraced_backend_autograd = functools.partial(
        _pretraced_backend, settings=settings, engine_cache=engine_cache
    )
    settings_aot_autograd = {}
    settings_aot_autograd["decompositions"] = get_decompositions(
        settings.enable_experimental_decompositions
    )
    # This is added since detach lowering leads to alias nodes
    # Error - View operation returned a tensor that is the same as the input base tensor
    # torch nop_decompositions in torch/_decomp/decompositions.py
    # transpose key deleted since not desirable to lower it to permute
    to_delete = {...

Its not immediately obvious that most models will run through

if settings.use_aot_joint_export:
        return _pretraced_backend(gm, sample_inputs, settings, engine_cache)

And I am still not sure if such a broad change should be made even in the case of MGMN. How would a user/we know know that this change is needed?

Copy link
Collaborator Author

@apbose apbose Mar 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for pointing that out, I restructured the code.
As for the case of exclusion of decompositions for MGMN, if we want it to be model specific, could this be the next thing in the UI to make it model specific.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • Now the new bool flag (use_distributed_mode_trace) is better than previous approach for code restructuring since previous one was not conveying general path (pretraced_backend) clearly.
  • These changes should not be model specific ideally. Infact, I was wondering if we really need the above flag. Because it might be not straightforward to users when to actually use it (as you mentioned in the comments). If we only support all_gather_into_tensor and reduce_scatter_tensor, we should detect if these ops exist in the graph pre-tracing and then redirect it to aot_autograd accordingly.
  • There maybe other distributed variants of these ops which are not supported but users might expect them to be supported once use_distributed_mode_trace is enabled. It's broad.

Copy link
Collaborator Author

@apbose apbose Mar 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the detailed analysis and bringing these points up!

  1. Yes the new bool flag (use_distributed_mode_trace) is better
  2. When you say the changes should not be model specific, are you referring to the specific decompositions to be excluded or change of the tracing from aot_export_joint_simple?
  • Assuming the latter- change of the tracing from aot_export_joint_simple. Right now for distributed case using the torch.distributed libraries I have encountered these two ops all_gather_into_tensor and reduce_scatter_tensor which lead to DTensors, requiring this tracing mode as directed by Pytorch folks. I am yet to encounter the other ops, need to get into the task of running more models with this code once I am able to wrap this up. From user perspective, if the user is using distributed libs leading to these ops and consequently these distributed tensors, then this flag would be required. Thats what my idea is right now.
    That being said we can discuss further on it if you see the point in changing the trace if these ops are present. That would be a design change and we could work on it in the PR with the C++ changes.
  • As for the former-specific decompositions to be excluded, handling complex inputs is just an added feature which is required if we want these distributed models to run end to end. TensorRT does not support it. It is an added complication to support depending on graph of model if the user wants to run the model end to end. For this we require these decompositions to be excluded, which can be made more generic with time as we proceed with running more models with this setup. Anyways with the development of MLIR TRT kernels to handle rotary embedding, this would not be required.
  1. Yes I do get this point. Other distributed variants of these ops are maybe untested. But I would say this flag tells the user if distributed mode trace should be enabled or not, rather than saying if other variant ops will work right out of box. We can discuss more on this.

As I mentioned earlier, we could maybe merge this PR for nccl ops support and discuss these points for further improvement.

@github-actions github-actions bot added component: tests Issues re: Tests component: converters Issues re: Specific op converters labels Feb 28, 2025
Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/distributed/test_nccl_ops.py	2025-02-28 04:08:09.204507+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/distributed/test_nccl_ops.py	2025-02-28 04:08:34.494240+00:00
@@ -23,11 +23,11 @@

from conversion.harness import DispatchTestCase


class TestGatherNcclOpsConverter(DispatchTestCase):
-    @parameterized.expand([(8)])
+    @parameterized.expand([8])
    def test_nccl_ops(self, linear_layer_dim):
        class DistributedGatherModel(nn.Module):
            def __init__(self, input_dim):
                super().__init__()
                self.fc = torch.nn.Linear(input_dim, input_dim)
@@ -48,11 +48,11 @@
            inputs,
            use_dynamo_tracer=True,
            fuse_distributed_ops=True,
        )

-    @parameterized.expand([(8)])
+    @parameterized.expand([8])
    def test_nccl_ops_scatter(self, linear_layer_dim):

        class DistributedReduceScatterModel(nn.Module):
            def __init__(self, input_dim):
                super().__init__()

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/distributed/test_nccl_ops.py	2025-02-28 19:22:33.272875+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/distributed/test_nccl_ops.py	2025-02-28 19:23:00.029705+00:00
@@ -15,11 +15,11 @@

from conversion.harness import DispatchTestCase


class TestGatherNcclOpsConverter(DispatchTestCase):
-    @parameterized.expand([(8)])
+    @parameterized.expand([8])
    def test_nccl_ops(self, linear_layer_dim):
        class DistributedGatherModel(nn.Module):
            def __init__(self, input_dim):
                super().__init__()
                self.fc = torch.nn.Linear(input_dim, input_dim)
@@ -40,11 +40,11 @@
            inputs,
            use_dynamo_tracer=True,
            fuse_distributed_ops=True,
        )

-    @parameterized.expand([(8)])
+    @parameterized.expand([8])
    def test_nccl_ops_scatter(self, linear_layer_dim):

        class DistributedReduceScatterModel(nn.Module):
            def __init__(self, input_dim):
                super().__init__()

@apbose apbose force-pushed the nccl_ops_additional branch from 091c83f to 67b970e Compare February 28, 2025 19:24
Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/distributed/test_nccl_ops.py	2025-02-28 19:24:59.719676+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/distributed/test_nccl_ops.py	2025-02-28 19:25:27.751500+00:00
@@ -15,11 +15,11 @@

from conversion.harness import DispatchTestCase


class TestGatherNcclOpsConverter(DispatchTestCase):
-    @parameterized.expand([(8)])
+    @parameterized.expand([8])
    def test_nccl_ops(self, linear_layer_dim):
        class DistributedGatherModel(nn.Module):
            def __init__(self, input_dim):
                super().__init__()
                self.fc = torch.nn.Linear(input_dim, input_dim)
@@ -40,11 +40,11 @@
            inputs,
            use_dynamo_tracer=True,
            fuse_distributed_ops=True,
        )

-    @parameterized.expand([(8)])
+    @parameterized.expand([8])
    def test_nccl_ops_scatter(self, linear_layer_dim):

        class DistributedReduceScatterModel(nn.Module):
            def __init__(self, input_dim):
                super().__init__()

@@ -364,6 +364,15 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
(i.contiguous() if isinstance(i, torch.Tensor) else torch.tensor(i).cuda())
for i in inputs
]

for i, contiguous_input in enumerate(contiguous_inputs):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does the C++ API not need these changes?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not clear on this aspect. Could you please let me know what would be required as part of this. I am running distributed python example and did not encounter the requirement for this.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean try running the example using the C++ runtime, Id expect that it doesnt handle complex numerics correctly

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is showing this

[rank0]:  File "/root/.pyenv/versions/3.11.11/lib/python3.11/site-packages/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py", line 304, in forward
[rank0]:    outputs: List[torch.Tensor] = torch.ops.tensorrt.execute_engine(
[rank0]:                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:  File "/root/.pyenv/versions/3.11.11/lib/python3.11/site-packages/torch/_ops.py", line 1158, in __call__
[rank0]:    return self._op(*args, **(kwargs or {}))
[rank0]:           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: RuntimeError: [Error thrown at core/runtime/execute_engine.cpp:111] Expected inputs[i].dtype() == expected_type to be true but got false
[rank0]: Expected input tensors to have type Float, found type c10::complex<float>

for C++ runtime.
Can we please record these as future To-dos for this feature, and merge this PR so that the main feature is merged to the main?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Uh, I would not consider this feature shipable without supporting the C++ runtime, but it can be in a different PR

Copy link
Collaborator Author

@apbose apbose Mar 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes I agree I can do it in the next PR. As such this change in code is to enable running llama3 distributed end to end on torchTRT for supporting complex inputs. At present it runs on python runtime but then additional changes are needed for C++ runtime.
The implementation of NCCL ops plugin which is the main feature and its basic functioning should be working.

os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = str(port)
os.environ["TRTLLM_PLUGINS_PATH"] = (
find_repo_root() + "/lib/libnvinfer_plugin_tensorrt_llm.so"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Assume a tmp path instead of some arbitrary location in the source tree

@@ -351,6 +351,7 @@ def generate_graph(
enable_passes: bool,
propagate_shapes: bool = False,
settings: CompilationSettings = CompilationSettings(),
fuse_distributed_ops: bool = False,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need this setting in the harness? Since it only applies to distributed, you should 1. have a test for the pass specifically, 2. have a test for converting the custom op without a pass and 3. have a test that applies the pass then converts the custom op

Copy link
Collaborator Author

@apbose apbose Mar 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had the same doubt actually.

  1. I can look for the test for pass
  2. Without a pass, it is not possible to encounter the custom op.
gathered_tensor = torch.ops._c10d_functional.all_gather_into_tensor(x, world_size, group_name)
gathered_tensor = torch.ops._c10d_functional.wait_tensor(gathered_tensor)

together form the nccl_ops code, and for that the pass is required

  1. That is what is done right now. Maybe applying the pass in harness is not the right place. Where would you suggest it to be? Since we want to test the converters only, can't harness have a specific option saying distributed=True?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The mpirun -n 1 --allow-run-as-root python test_distributed_simple_example.py ideally covers all of it

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually just putting enable_passes=True, works in the converter test. Earlier the constant folding was giving an error, but that is resolved.


OS="$(uname -s)"
ARCH="$(uname -m)"
PYTHON_VERSION="$(python3 -c 'import sys; print(f"cp{sys.version_info.major}{sys.version_info.minor}")')"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does it matter what python version the system is using? Can we just use the same .so for all python versions?

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/distributed/test_nccl_ops.py	2025-03-03 16:25:48.020598+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/distributed/test_nccl_ops.py	2025-03-03 16:26:14.671617+00:00
@@ -15,11 +15,11 @@

from conversion.harness import DispatchTestCase


class TestGatherNcclOpsConverter(DispatchTestCase):
-    @parameterized.expand([(8)])
+    @parameterized.expand([8])
    def test_nccl_ops(self, linear_layer_dim):
        class DistributedGatherModel(nn.Module):
            def __init__(self, input_dim):
                super().__init__()
                self.fc = torch.nn.Linear(input_dim, input_dim)
@@ -40,11 +40,11 @@
            inputs,
            use_dynamo_tracer=True,
            fuse_distributed_ops=True,
        )

-    @parameterized.expand([(8)])
+    @parameterized.expand([8])
    def test_nccl_ops_scatter(self, linear_layer_dim):

        class DistributedReduceScatterModel(nn.Module):
            def __init__(self, input_dim):
                super().__init__()

@apbose
Copy link
Collaborator Author

apbose commented Mar 3, 2025

I am not sure why lint is showing this

Skipping .ipynb files as Jupyter dependencies are not installed.
You can fix this by running ``pip install "black[jupyter]"``
would reformat /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/distributed/test_nccl_ops.py

Oh no! 💥 💔 💥
1 file would be reformatted, 595 files would be left unchanged.
Skipping .ipynb files as Jupyter dependencies are not installed.
You can fix this by running ``pip install "black[jupyter]"``
would reformat /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/distributed/test_nccl_ops.py

All done! ✨ 🍰 ✨
1 file would be reformatted, 595 files would be left unchanged.

I see no error in local pre-commit

@apbose apbose force-pushed the nccl_ops_additional branch from 0f6966f to ca8478a Compare March 3, 2025 17:06
Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/distributed/test_nccl_ops.py	2025-03-03 17:06:44.889577+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/distributed/test_nccl_ops.py	2025-03-03 17:07:10.979519+00:00
@@ -15,11 +15,11 @@

from conversion.harness import DispatchTestCase


class TestGatherNcclOpsConverter(DispatchTestCase):
-    @parameterized.expand([(8)])
+    @parameterized.expand([8])
    def test_nccl_ops(self, linear_layer_dim):
        class DistributedGatherModel(nn.Module):
            def __init__(self, input_dim):
                super().__init__()
                self.fc = torch.nn.Linear(input_dim, input_dim)
@@ -40,11 +40,11 @@
            inputs,
            use_dynamo_tracer=True,
            fuse_distributed_ops=True,
        )

-    @parameterized.expand([(8)])
+    @parameterized.expand([8])
    def test_nccl_ops_scatter(self, linear_layer_dim):

        class DistributedReduceScatterModel(nn.Module):
            def __init__(self, input_dim):
                super().__init__()

@apbose apbose force-pushed the nccl_ops_additional branch from ca8478a to 4b79bfb Compare March 3, 2025 18:14
Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/distributed/test_nccl_ops.py	2025-03-03 18:15:05.376989+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/distributed/test_nccl_ops.py	2025-03-03 18:15:31.109367+00:00
@@ -15,11 +15,11 @@

from conversion.harness import DispatchTestCase


class TestGatherNcclOpsConverter(DispatchTestCase):
-    @parameterized.expand([(8)])
+    @parameterized.expand([8])
    def test_nccl_ops(self, linear_layer_dim):
        class DistributedGatherModel(nn.Module):
            def __init__(self, input_dim):
                super().__init__()
                self.fc = torch.nn.Linear(input_dim, input_dim)
@@ -40,11 +40,11 @@
            inputs,
            use_dynamo_tracer=True,
            enable_passes=True,
        )

-    @parameterized.expand([(8)])
+    @parameterized.expand([8])
    def test_nccl_ops_scatter(self, linear_layer_dim):

        class DistributedReduceScatterModel(nn.Module):
            def __init__(self, input_dim):
                super().__init__()

@apbose apbose force-pushed the nccl_ops_additional branch from 946a15f to abe373a Compare March 4, 2025 17:54
@apbose apbose force-pushed the nccl_ops_additional branch from abe373a to 3adafaf Compare March 4, 2025 19:13
@@ -92,7 +92,7 @@ class CompilationSettings:
enable_weight_streaming (bool): Enable weight streaming.
enable_cross_compile_for_windows (bool): By default this is False means TensorRT engines can only be executed on the same platform where they were built.
True will enable cross-platform compatibility which allows the engine to be built on Linux and run on Windows
use_aot_joint_export (bool): Use aot_export_joint_simple, else wrap backend with AOT_autograd, required for distributed tensors
USE_DISTRIBUTED_MODE_TRACE (bool): Use aot_export_joint_simple, else wrap backend with AOT_autograd, required for distributed tensors
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should not be all caps?

Comment on lines +58 to +59
"Wrapping the backend with aot_autograd for Distributed examples\n"
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider changing the message to - "Using aot_autograd to trace the graph. Enable this only if the model includes distributed operations"

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure about only,

  1. have a warning for when users should use this trace mode but arent
  2. have an Info level message when this mode is being used Using AOTAutograd tracer for model lowering or something to that extent

@@ -121,7 +131,7 @@ def _pretraced_backend(
)

# Invoke AOTAutograd to translate operators to aten
if settings.use_aot_joint_export:
if not settings.use_distributed_mode_trace:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need this check now ? since we return the aot_autograd block before ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes we would be needing this since while wrapping it in aot_autograd , we need to pass the pretraced_backend still to the fw_compiler arg in aot_autograd.
aot_print_fn = aot_function(fn, fw_compiler=compiler_fn, bw_compiler=compiler_fn) But then we won't want it to do aot_joint_export tracing there

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
cla signed component: api [Python] Issues re: Python API component: conversion Issues re: Conversion stage component: converters Issues re: Specific op converters component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths component: lowering Issues re: The lowering / preprocessing passes component: tests Issues re: Tests component: torch_compile
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants