     quantize_static_fx,
     quantize_dynamic_fx,
     prepare_qat_fx,
+    register_observed_custom_module_mapping,
+    register_quantized_custom_module_mapping,
 )

 from torch.quantization import (

@@ -482,6 +484,140 @@ def forward(self, x):
         # Verify that loaded state dict produces same results.
         self.assertEqual(quant(x), quant_2(x))

+    @skipIfNoFBGEMM
+    def test_custom_module_class(self):
+        class CustomModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv = torch.nn.Conv2d(1, 1, 1)
+
+            def forward(self, x):
+                return self.conv(x)
+
+        class ObservedCustomModule(torch.nn.Module):
+            def __init__(self, conv):
+                super().__init__()
+                self.conv = conv
+
+            def forward(self, x):
+                return self.conv(x)
+
+            @classmethod
+            def from_float(cls, float_module):
+                assert hasattr(float_module, 'qconfig')
+                observed = cls(float_module.conv)
+                observed.qconfig = float_module.qconfig
+                return observed
+
+        class QuantizedCustomModule(torch.nn.Module):
+            def __init__(self, conv):
+                super().__init__()
+                self.conv = conv
+
+            def forward(self, x):
+                return self.conv(x)
+
+            @classmethod
+            def from_observed(cls, observed_module):
+                assert hasattr(observed_module, 'qconfig')
+                assert hasattr(observed_module, 'activation_post_process')
+                observed_module.conv.activation_post_process = \
+                    observed_module.activation_post_process
+                quantized = cls(nnq.Conv2d.from_float(observed_module.conv))
+                return quantized
+
+        class DynamicallyQuantizedCustomModule(torch.nn.Module):
+            def __init__(self, conv):
+                super().__init__()
+                self.conv = conv
+
+            def forward(self, x):
+                return self.conv(x)
+
+            @classmethod
+            def from_observed(cls, observed_module):
+                assert hasattr(observed_module, 'qconfig')
+                assert hasattr(observed_module, 'activation_post_process')
+                quantized = cls(nnqd.Conv2d.from_float(observed_module.conv))
+                return quantized
+
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv = torch.nn.Conv2d(1, 1, 1)
+                self.custom = CustomModule()
+
+            def forward(self, x):
+                x = self.conv(x)
+                x = self.custom(x)
+                return x
+
+        class RefM(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv1 = torch.nn.Conv2d(1, 1, 1)
+                self.conv2 = torch.nn.Conv2d(1, 1, 1)
+
+            def forward(self, x):
+                x = self.conv1(x)
+                x = self.conv2(x)
+                return x
+
+        data = torch.randn(1, 1, 1, 1)
+        # instantiate M and RefM and align the parameters
+        original_m = M()
+        original_ref_m = RefM()
+        original_ref_m.conv1.weight = torch.nn.Parameter(original_m.conv.weight.detach())
+        original_ref_m.conv1.bias = torch.nn.Parameter(original_m.conv.bias.detach())
+        original_ref_m.conv2.weight = torch.nn.Parameter(original_m.custom.conv.weight.detach())
+        original_ref_m.conv2.bias = torch.nn.Parameter(original_m.custom.conv.bias.detach())
+
+        from torch.fx.symbolic_trace import Tracer
+
+        # define a custom tracer to not trace through the custom module
+
+        class CustomTracer(Tracer):
+            def is_leaf_module(self, m, module_qualified_name):
+                return (m.__module__.startswith('torch.nn') and
+                        not isinstance(m, torch.nn.Sequential)) or \
+                    isinstance(m, CustomModule)
+
+        # TODO: add other quant types after mixed mode support
+        for quant_type in [QuantType.STATIC]:
+            # register observed and quantized custom module classes
+            register_observed_custom_module_mapping(CustomModule, ObservedCustomModule)
+            register_quantized_custom_module_mapping(CustomModule, QuantizedCustomModule)
+
+            m = CustomTracer().trace(original_m).eval()
+            qconfig_dict = {'': default_qconfig}
+            # check prepared model
+            m = prepare_static_fx(m, qconfig_dict)
+            # calibration
+            m(data)
+            # all activation observers are inserted in the top level module
+            count_check = {
+                ns.call_module(torch.quantization.MinMaxObserver): 3
+            }
+            self.checkGraphModuleNodes(m, expected_node_occurrence=count_check)
+
+            # check converted/quantized model
+            m = convert_static_fx(m)
+            count_check = {
+                ns.call_function(torch.quantize_per_tensor) : 1,
+                ns.call_module(nnq.Conv2d) : 1,
+                ns.call_method('dequantize') : 1,
+            }
+            self.checkGraphModuleNodes(m, expected_node_occurrence=count_check)
+            res = m(data)
+
+            # quantize the reference model
+            ref_m = symbolic_trace(original_ref_m).eval()
+            ref_m = prepare_fx(ref_m, qconfig_dict)
+            ref_m(data)
+            ref_m = convert_fx(ref_m)
+            ref_res = ref_m(data)
+            self.assertEqual(res, ref_res)
+
 class TestQuantizeFxOps(QuantizationTestCase):
     """Unit tests for individual ops
     """
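
For a quick read of what the new API expects from callers, here is a condensed sketch of the flow the test above exercises, stripped of the test harness. It is a sketch, not part of the diff: it reuses the CustomModule, ObservedCustomModule, QuantizedCustomModule and M classes defined in the test body, and it assumes prepare_static_fx, convert_static_fx and the two register_* helpers are importable from the same module whose import block is extended at the top of the diff (that module's name is not visible in this hunk).

# Sketch only. Assumes CustomModule, ObservedCustomModule, QuantizedCustomModule
# and M are defined as in the test above, and that prepare_static_fx,
# convert_static_fx and the register_* helpers come from the import block
# extended at the top of the diff.
import torch
from torch.fx.symbolic_trace import Tracer
from torch.quantization import default_qconfig

# 1. Map the float custom module class to its observed and quantized counterparts.
register_observed_custom_module_mapping(CustomModule, ObservedCustomModule)
register_quantized_custom_module_mapping(CustomModule, QuantizedCustomModule)

# 2. Trace with a tracer that keeps CustomModule as a leaf so FX does not
#    recurse into it (same is_leaf_module logic as the test's CustomTracer).
class CustomTracer(Tracer):
    def is_leaf_module(self, m, module_qualified_name):
        return (m.__module__.startswith('torch.nn') and
                not isinstance(m, torch.nn.Sequential)) or \
            isinstance(m, CustomModule)

traced = CustomTracer().trace(M()).eval()

# 3. Prepare, calibrate and convert like any other FX graph mode model.
prepared = prepare_static_fx(traced, {'': default_qconfig})
prepared(torch.randn(1, 1, 1, 1))   # calibration pass
quantized = convert_static_fx(prepared)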