
Compile the QAT-trained model with TensorRT #3622

Open
@pribadihcr


Bug Description

Run the following notebook:

https://github.com/pytorch/TensorRT/blob/main/notebooks/qat-ptq-workflow.ipynb

When executing this section:

quant_nn.TensorQuantizer.use_fb_fake_quant = True
with torch.no_grad():
    data = iter(val_dataloader)
    images, _ = next(data)
    jit_model = torch.jit.trace(q_model, images.to("cuda"))
    torch.jit.save(jit_model, "mobilenetv2_qat.jit.pt")
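As an aside, in Python 3 an iterator is advanced with the builtin `next()` rather than a `.next()` method (the Python 2 spelling), which applies to the DataLoader iterator above. A minimal illustration with a plain iterator standing in for `val_dataloader`:

```python
# Python 3 iterators implement __next__ and are consumed via next();
# calling .next() on them raises AttributeError.
batches = iter([([1, 2, 3], "labels")])  # stand-in for iter(val_dataloader)
images, _ = next(batches)
print(images)  # [1, 2, 3]
```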

Got the following error:

   torch.jit.save(jit_model, "mobilenetv2_qat.jit.pt")
  File "anaconda3/lib/python3.11/site-packages/torch/jit/_serialization.py", line 84, in save
    m.save(f, _extra_files=_extra_files)
  File "anaconda3/lib/python3.11/site-packages/torch/jit/_script.py", line 754, in save
    return self._c.save(str(f), **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: 
Could not export Python function call 'FakeTensorQuantFunction'. Remove calls to Python functions before export. Did you forget to add @script or @script_method annotation? If this is a nn.ModuleList, add it to __constants__:

I worked around this export error with:

def disable_fake_quant(model):
    for mod in model.modules():
        if isinstance(mod, quant_nn.TensorQuantizer):
            mod.disable()

Then call the disable_fake_quant function before tracing:

disable_fake_quant(q_model)
quant_nn.TensorQuantizer.use_fb_fake_quant = True
with torch.no_grad():
    data = iter(val_dataloader)
    images, _ = next(data)
    jit_model = torch.jit.trace(q_model, images.to("cuda"))
    torch.jit.save(jit_model, "mobilenetv2_qat.jit.pt")
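For context, the workaround simply walks every submodule and turns off each quantizer it finds. A self-contained sketch of that traversal, using stand-in classes instead of the real `quant_nn.TensorQuantizer` (the `disable()` method name mirrors the pytorch_quantization API; the toy classes here are assumptions for illustration only):

```python
# Stand-in mimicking quant_nn.TensorQuantizer's disable() method.
class TensorQuantizer:
    def __init__(self):
        self.disabled = False

    def disable(self):
        # In the real library this bypasses fake quantization at runtime.
        self.disabled = True

class ToyModel:
    def __init__(self):
        # Mix of quantizer and non-quantizer "modules".
        self._mods = [TensorQuantizer(), "conv", TensorQuantizer()]

    def modules(self):
        return self._mods

def disable_fake_quant(model):
    # Same traversal pattern as the workaround above.
    for mod in model.modules():
        if isinstance(mod, TensorQuantizer):
            mod.disable()

model = ToyModel()
disable_fake_quant(model)
print([m.disabled for m in model.modules()
       if isinstance(m, TensorQuantizer)])  # [True, True]
```

Note that with the quantizers disabled, the traced graph no longer carries quantization behavior, which is likely why TensorRT then complains about missing scaling factors below.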

But now compilation fails with the following error:

ERROR: [Torch-TensorRT TorchScript Conversion Context] - IBuilder::buildSerializedNetwork: Error Code 4: Internal Error (Calibration failure occurred with no scaling factors detected. This could be due to no int8 calibrator or insufficient custom scales for network layers. Please see int8 sample to setup calibration correctly.)
Traceback (most recent call last):
  trt_mod = torch_tensorrt.compile(qat_model, **compile_spec)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "anaconda3/lib/python3.11/site-packages/torch_tensorrt/_compile.py", line 212, in compile
    compiled_ts_module: torch.jit.ScriptModule = torchscript_compile(
                                                 ^^^^^^^^^^^^^^^^^^^^
  File "anaconda3/lib/python3.11/site-packages/torch_tensorrt/ts/_compiler.py", line 154, in compile
    compiled_cpp_mod = _C.compile_graph(module._c, _parse_compile_spec(spec))
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: [Error thrown at core/conversion/conversionctx/ConversionCtx.cpp:169] Building serialized network failed in TensorRT

Environment:
torch_tensorrt == 2.5.0+cu118
