@@ -1,3 +1,4 @@
+import inspect
 import logging
 from copy import deepcopy
 from enum import Enum, auto
@@ -41,6 +42,10 @@ def get_state(self) -> RefitFlag:
         return self._state


+class DynamicShapeOutOfRangeException(Exception):
+    pass
+
+
 class MutableTorchTensorRTModule(object):
     """
     Initialize a MutableTorchTensorRTModule to seamlessly manipulate it like a regular PyTorch module.
@@ -65,7 +70,7 @@ def __init__(
             Union[torch.dtype, dtype]
         ] = _defaults.ENABLED_PRECISIONS,
         engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY,
-        immutable_weights: bool = _defaults.IMMUTABLE_WEIGHTS,
+        immutable_weights: bool = False,
         debug: bool = _defaults.DEBUG,
         num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS,
         workspace_size: int = _defaults.WORKSPACE_SIZE,
@@ -189,6 +194,9 @@ def __init__(
             "hardware_compatible": hardware_compatible,
             "timing_cache_path": timing_cache_path,
         }
+        self.arg_dynamic_shapes: Optional[tuple[Any]] = None
+        self.kwarg_dynamic_shapes: Optional[dict[Any, Any]] = None
+        self.total_dynamic_shape: Optional[dict[Any, Any]] = None

         self.settings = CompilationSettings(**compilation_options)
         self.run_info: Optional[tuple[Any, ...]] = None
@@ -203,6 +211,27 @@ def __init__(
203
211
)
204
212
self .init_finished = True
205
213
214
+ def set_dynamic_shape_hint (
215
+ self ,
216
+ args_dynamic_shape : tuple [dict [Any , Any ]],
217
+ kwargs_dynamic_shape : dict [str , Any ],
218
+ ) -> None :
219
+ assert isinstance (
220
+ args_dynamic_shape , tuple
221
+ ), "args dynamic shape has to be a tuple"
222
+ assert isinstance (
223
+ kwargs_dynamic_shape , dict
224
+ ), "args dynamic shape has to be a dictionary"
225
+ self .kwarg_dynamic_shapes = kwargs_dynamic_shape
226
+ self .arg_dynamic_shapes = args_dynamic_shape
227
+ self .total_dynamic_shape = self .kwarg_dynamic_shapes .copy ()
228
+ signature = list (
229
+ inspect .signature (self .original_model .forward ).parameters .keys ()
230
+ )
231
+ for i , arg in enumerate (self .arg_dynamic_shapes ):
232
+ self .total_dynamic_shape [signature [i ]] = arg
233
+ self .refit_state .set_state (RefitFlag .NEEDS_RECOMPILE )
234
+
206
235
def store_state_dict_metadata (self ) -> None :
207
236
for k , v in self .original_model .state_dict ().items ():
208
237
self .state_dict_metadata [k ] = v .shape
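For context, a minimal usage sketch of the new hint API. `MyModel` and the tensor shapes are illustrative placeholders; `torch.export.Dim` is the range-carrying object the checks below rely on (it exposes `.min` and `.max`):

```python
import torch
import torch_tensorrt

model = MyModel().eval().cuda()  # MyModel is a placeholder nn.Module
trt_module = torch_tensorrt.MutableTorchTensorRTModule(model)

# One dict per positional arg ({dim_index: Dim}); kwargs are keyed by name.
batch = torch.export.Dim("batch", min=1, max=8)
trt_module.set_dynamic_shape_hint(
    args_dynamic_shape=({0: batch},),
    kwargs_dynamic_shape={"b": {0: batch}},
)

trt_module(torch.rand(2, 3).cuda(), b=torch.rand(2, 3).cuda())  # compiles
trt_module(torch.rand(4, 3).cuda(), b=torch.rand(4, 3).cuda())  # in range, so the engine should be reused
```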
@@ -295,6 +324,7 @@ def compile(self) -> None:
             self.original_model,
             self.arg_inputs,
             kwargs=self.kwarg_inputs,
+            dynamic_shapes=self.total_dynamic_shape,
         )
         self.gm = dynamo_compile(
             self.exp_program,
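The merged `total_dynamic_shape` is keyed by the parameter names of the original `forward()`, which is one of the forms `torch.export.export` accepts for its `dynamic_shapes` argument. A standalone sketch of that rekeying logic (`toy_forward` is illustrative; the loop mirrors `set_dynamic_shape_hint`):

```python
import inspect
import torch

def toy_forward(x: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return x + b

batch = torch.export.Dim("batch", min=1, max=8)
args_hint = ({0: batch},)        # positional hints, one dict per arg
kwargs_hint = {"b": {0: batch}}  # keyword hints, keyed by name

# Rekey positional hints by the forward() parameter names.
total = kwargs_hint.copy()
names = list(inspect.signature(toy_forward).parameters.keys())
for i, hint in enumerate(args_hint):
    total[names[i]] = hint

print(total)  # {'b': {0: batch}, 'x': {0: batch}}, keys now match forward's parameters
```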
@@ -306,14 +336,26 @@ def compile(self) -> None:
         torch.cuda.empty_cache()

     def _validate_inputs(self, *args: Any, **kwargs: Any) -> None:
-        if (
-            not self.arg_inputs
-            or not MutableTorchTensorRTModule.check_inputs_equal(self.arg_inputs, args)
-            or not MutableTorchTensorRTModule.check_inputs_equal(
-                self.kwarg_inputs, kwargs
-            )
-        ):
+        try:
+            if (
+                not self.arg_inputs
+                or not MutableTorchTensorRTModule.check_inputs_equal(
+                    self.arg_inputs, args, dynamic_shapes=self.arg_dynamic_shapes
+                )
+                or not MutableTorchTensorRTModule.check_inputs_equal(
+                    self.kwarg_inputs, kwargs, dynamic_shapes=self.kwarg_dynamic_shapes
+                )
+            ):
+                logger.info("Input change detected.")
+                self.refit_state.set_state(RefitFlag.NEEDS_RECOMPILE)
+                self.store_inputs(args, kwargs)
+        except DynamicShapeOutOfRangeException as e:
             logger.info("Input change detected.")
+            logger.warning(e)
+            logger.warning("Recompiling the engine with static shape")
+            self.arg_dynamic_shapes = None
+            self.kwarg_dynamic_shapes = None
+            self.total_dynamic_shape = None
             self.refit_state.set_state(RefitFlag.NEEDS_RECOMPILE)
             self.store_inputs(args, kwargs)

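Continuing the earlier sketch: a shape outside the hinted range does not surface the exception to the caller. `_validate_inputs` catches `DynamicShapeOutOfRangeException`, clears all three hint attributes, and falls back to a static-shape recompile:

```python
# Batch 16 exceeds Dim("batch", min=1, max=8); the module logs the range
# violation, drops the dynamic-shape hints, and recompiles statically.
trt_module(torch.rand(16, 3).cuda(), b=torch.rand(16, 3).cuda())
```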
@@ -436,33 +478,66 @@ def __setattr__(self, name: str, value: Any) -> None:
     def check_inputs_equal(
         input1: Any,
         input2: Any,
+        dynamic_shapes: Any = None,
     ) -> bool:
-        # TODO: Add support for dynamic shape
+
         if isinstance(input1, (tuple, list)):
             if len(input1) != len(input2):
                 return False
-            for a, b in zip(input1, input2):
+            for (i, a), b in zip(enumerate(input1), input2):
                 if type(a) != type(b):
                     return False
-                if isinstance(a, torch.Tensor) and a.shape != b.shape:
-                    return False
-                elif isinstance(a, bool) and a != b:
+                if isinstance(a, bool) and a != b:
                     return False
+                elif isinstance(a, torch.Tensor) and a.shape != b.shape:
+                    if dynamic_shapes is None:
+                        return False
+                    else:
+                        tensor_dynamic_shape = dynamic_shapes[i]
+                        if not MutableTorchTensorRTModule.check_tensor_shapes_with_dynamic_shapes(
+                            a, b, tensor_dynamic_shape
+                        ):
+                            return False

         elif isinstance(input1, dict):
             if input1.keys() != input2.keys():
                 return False
-            for a, b in zip(input1.values(), input2.values()):
-                if type(a) != type(b):
+            for (ka, va), vb in zip(input1.items(), input2.values()):
+                if type(va) != type(vb):
                     return False
-                if isinstance(a, torch.Tensor) and a.shape != b.shape:
-                    return False
-                elif isinstance(a, bool) and a != b:
+                if isinstance(va, bool) and va != vb:
                     return False
+                elif isinstance(va, torch.Tensor) and va.shape != vb.shape:
+                    if dynamic_shapes is None:
+                        return False
+                    else:
+                        tensor_dynamic_shape = dynamic_shapes[ka]
+                        if not MutableTorchTensorRTModule.check_tensor_shapes_with_dynamic_shapes(
+                            va, vb, tensor_dynamic_shape
+                        ):
+                            return False
                 elif isinstance(
-                    a, (list, tuple, dict)
-                ) and not MutableTorchTensorRTModule.check_inputs_equal(a, b):
+                    va, (list, tuple, dict)
+                ) and not MutableTorchTensorRTModule.check_inputs_equal(
+                    va, vb, dynamic_shapes[ka] if dynamic_shapes else None
+                ):
+                    return False
+        return True
+
+    @staticmethod
+    def check_tensor_shapes_with_dynamic_shapes(
+        t1: torch.Tensor, t2: torch.Tensor, dynamic_shape: dict[int, Any]
+    ) -> bool:
+        for (i, axis_0), axis_1 in zip(enumerate(t1.shape), t2.shape):
+            if axis_0 != axis_1:
+                if i not in dynamic_shape:
                     return False
+                dyn = dynamic_shape[i]
+                if axis_1 > dyn.max or axis_1 < dyn.min:
+                    raise DynamicShapeOutOfRangeException(
+                        f"The input size ({axis_1}) of dimension ({i}) is not in dynamic shape range [{dyn.min}, {dyn.max}]!"
+                    )
+
         return True

     @staticmethod
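And a minimal, self-contained exercise of the new range check. `SimpleNamespace` stands in for `torch.export.Dim` here, since the check only reads `.min` and `.max`; the import path assumes the class is exposed from `torch_tensorrt` as in the examples above:

```python
from types import SimpleNamespace

import torch
from torch_tensorrt import MutableTorchTensorRTModule

batch = SimpleNamespace(min=1, max=8)

# Dim 0 differs, but 4 lies inside [1, 8] -> shapes are compatible.
MutableTorchTensorRTModule.check_tensor_shapes_with_dynamic_shapes(
    torch.rand(2, 3), torch.rand(4, 3), {0: batch}
)  # True

# 16 falls outside [1, 8] -> raises DynamicShapeOutOfRangeException.
MutableTorchTensorRTModule.check_tensor_shapes_with_dynamic_shapes(
    torch.rand(2, 3), torch.rand(16, 3), {0: batch}
)
```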