@@ -196,12 +196,11 @@ def __init__(
         }
         self.arg_dynamic_shapes: Optional[tuple[Any]] = None
         self.kwarg_dynamic_shapes: Optional[dict[Any, Any]] = None
-        self.total_dynamic_shape: Optional[dict[Any, Any]] = None
 
         self.settings = CompilationSettings(**compilation_options)
         self.run_info: Optional[tuple[Any, ...]] = None
         self.state_dict_metadata: dict[str, torch.Size] = {}
-        self.store_state_dict_metadata()
+        self._store_state_dict_metadata()
 
         cls = self.__class__
         self.__class__ = type(
@@ -211,11 +210,31 @@ def __init__(
         )
         self.init_finished = True
 
-    def set_dynamic_shape_hint(
+    def set_expected_dynamic_shape_range(
         self,
         args_dynamic_shape: tuple[dict[Any, Any]],
         kwargs_dynamic_shape: dict[str, Any],
     ) -> None:
+        """
+        Set the dynamic shape range. The shape hints should EXACTLY follow the arg_inputs and kwarg_inputs passed to the forward function
+        and should not omit any entries. If an input does not need a dynamic shape, an empty dictionary should be given
+        as the shape hint for that input.
+
+        Example:
+            def forward(a, b, c=0, d=0):
+                pass
+
+            seq_len = torch.export.Dim("seq_len", min=1, max=10)
+            args_dynamic_shape = ({0: seq_len}, {})  # b does not have a dynamic shape
+            kwargs_dynamic_shape = {'c': {0: seq_len}, 'd': {}}  # d does not have a dynamic shape
+            # Later when you call the function
+            forward(*(a, b), **{'c': ..., 'd': ...})
+
+
+        Arguments:
+            args_dynamic_shape (tuple[dict[Any, Any]]): Dynamic shape hint for the arg_inputs
+            kwargs_dynamic_shape (dict[str, Any]): Dynamic shape hint for the kwarg_inputs
+        """
         assert isinstance(
             args_dynamic_shape, tuple
         ), "args dynamic shape has to be a tuple"
@@ -224,19 +243,31 @@ def set_dynamic_shape_hint(
         ), "args dynamic shape has to be a dictionary"
         self.kwarg_dynamic_shapes = kwargs_dynamic_shape
         self.arg_dynamic_shapes = args_dynamic_shape
-        self.total_dynamic_shape = self.kwarg_dynamic_shapes.copy()
-        signature = list(
-            inspect.signature(self.original_model.forward).parameters.keys()
-        )
-        for i, arg in enumerate(self.arg_dynamic_shapes):
-            self.total_dynamic_shape[signature[i]] = arg
-        self.refit_state.set_state(RefitFlag.NEEDS_RECOMPILE)
 
         # Clear cached inputs
         self.arg_inputs = tuple()
         self.kwarg_inputs = {}
 
-    def store_state_dict_metadata(self) -> None:
+        self.refit_state.set_state(RefitFlag.NEEDS_RECOMPILE)
+
+    def _get_total_dynamic_shapes(self) -> dict[str, Any] | None:
+        if not self.arg_dynamic_shapes and not self.kwarg_dynamic_shapes:
+            return None
+        total_dynamic_shape = {}
+        if self.arg_dynamic_shapes:
+            signature = list(
+                inspect.signature(self.original_model.forward).parameters.keys()
+            )
+            for i, arg in enumerate(self.arg_dynamic_shapes):
+                total_dynamic_shape[signature[i]] = arg
+
+        if self.kwarg_dynamic_shapes:
+            for kwargs, kwargs_dynamic_shape in self.kwarg_dynamic_shapes.items():
+                total_dynamic_shape[kwargs] = kwargs_dynamic_shape
+
+        return total_dynamic_shape
+
+    def _store_state_dict_metadata(self) -> None:
         for k, v in self.original_model.state_dict().items():
             self.state_dict_metadata[k] = v.shape
 
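The new helper defers the merge until compile time instead of caching it in total_dynamic_shape, which is what lets the stale field be deleted above. A hypothetical standalone re-implementation of the same mapping, shown for clarity (strings stand in for torch.export.Dim objects):

    import inspect

    def merge_dynamic_shapes(forward_fn, arg_hints, kwarg_hints):
        # Re-key positional hints by forward()'s parameter names, then fold in
        # the kwarg hints, mirroring what _get_total_dynamic_shapes computes.
        if not arg_hints and not kwarg_hints:
            return None
        names = list(inspect.signature(forward_fn).parameters.keys())
        merged = {names[i]: hint for i, hint in enumerate(arg_hints or ())}
        merged.update(kwarg_hints or {})
        return merged

    def forward(a, b, c=0, d=0):
        pass

    print(merge_dynamic_shapes(forward, ({0: "seq_len"}, {}), {"c": {0: "seq_len"}, "d": {}}))
    # {'a': {0: 'seq_len'}, 'b': {}, 'c': {0: 'seq_len'}, 'd': {}}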
@@ -328,7 +359,7 @@ def compile(self) -> None:
             self.original_model,
             self.arg_inputs,
             kwargs=self.kwarg_inputs,
-            dynamic_shapes=self.total_dynamic_shape,
+            dynamic_shapes=self._get_total_dynamic_shapes(),
         )
         self.gm = dynamo_compile(
             self.exp_program,
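For context, torch.export.export accepts exactly this parameter-name-keyed mapping, which is why the merged dict can be passed straight through. A minimal sketch outside Torch-TensorRT, with an illustrative model:

    import torch

    class Toy(torch.nn.Module):
        def forward(self, a, b):
            return a @ b

    seq_len = torch.export.Dim("seq_len", min=1, max=10)
    # dynamic_shapes keys are forward() parameter names; {} leaves b fully static.
    ep = torch.export.export(
        Toy(),
        (torch.randn(4, 8), torch.randn(8, 2)),
        dynamic_shapes={"a": {0: seq_len}, "b": {}},
    )
    print(ep.graph_signature)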
@@ -340,40 +371,75 @@ def compile(self) -> None:
         torch.cuda.empty_cache()
 
     def _validate_inputs(self, *args: Any, **kwargs: Any) -> None:
+
+        if not self.arg_inputs:
+            logger.info("First time compilation initiated. This may take some time.")
+            self.refit_state.set_state(RefitFlag.NEEDS_RECOMPILE)
+            self._store_inputs(args, kwargs)
+            if self.arg_dynamic_shapes or self.kwarg_dynamic_shapes:
+                if not self._validates_dynamic_hints():
+                    logger.warning(
+                        "Invalid dynamic shape hint. Compiling module for the provided input shapes (static)"
+                    )
+                    self.arg_dynamic_shapes = None
+                    self.kwarg_dynamic_shapes = None
+            return
+
+        # If input does not equal or does not fall into dynamic shape range, recompile the engine
         try:
-            if (
-                not self.arg_inputs
-                or not MutableTorchTensorRTModule.check_inputs_equal(
-                    self.arg_inputs, args, dynamic_shapes=self.arg_dynamic_shapes
-                )
-                or not MutableTorchTensorRTModule.check_inputs_equal(
-                    self.kwarg_inputs, kwargs, dynamic_shapes=self.kwarg_dynamic_shapes
-                )
+            if not MutableTorchTensorRTModule._check_inputs_shape(
+                self.arg_inputs, args, dynamic_shapes=self.arg_dynamic_shapes
+            ) or not MutableTorchTensorRTModule._check_inputs_shape(
+                self.kwarg_inputs, kwargs, dynamic_shapes=self.kwarg_dynamic_shapes
             ):
                 logger.info("Input change detected.")
                 self.refit_state.set_state(RefitFlag.NEEDS_RECOMPILE)
-                self.store_inputs(args, kwargs)
+                self._store_inputs(args, kwargs)
         except DynamicShapeOutOfRangeException as e:
             logger.info("Input change detected.")
             logger.warning(e)
-            logger.warning("Recompiling the engine with static shape")
+            logger.warning(
+                "Provided inputs are outside the set expected shape range, recompiling module for the provided input shapes (static)"
+            )
             self.arg_dynamic_shapes = None
             self.kwarg_dynamic_shapes = None
-            self.total_dynamic_shape = None
             self.refit_state.set_state(RefitFlag.NEEDS_RECOMPILE)
-            self.store_inputs(args, kwargs)
+            self._store_inputs(args, kwargs)
 
-    def store_inputs(self, arg_inputs: Any, kwarg_inputs: Any) -> None:
+    def _validates_dynamic_hints(self) -> bool:
+        if self.arg_dynamic_shapes is None:
+            if self.arg_inputs:
+                logger.warning("arg_dynamic_shape is not provided!")
+        else:
+            if len(self.arg_dynamic_shapes) != len(self.arg_inputs):
+                logger.warning(
+                    f"Warning: The length of arg_inputs is {len(self.arg_inputs)} but the length of arg_dynamic_shape is {len(self.arg_dynamic_shapes)}!"
+                )
+                return False
+
+        if self.kwarg_dynamic_shapes is None:
+            if self.kwarg_inputs:
+                logger.warning("kwarg_dynamic_shape is not provided!")
+        else:
+            if self.kwarg_dynamic_shapes.keys() != self.kwarg_inputs.keys():
+                logger.warning(
+                    f"kwarg_inputs has {list(self.kwarg_inputs.keys())} but kwarg_dynamic_shape has {list(self.kwarg_dynamic_shapes.keys())}!"
+                )
+                return False
+
+        return True
+
+    def _store_inputs(self, arg_inputs: Any, kwarg_inputs: Any) -> None:
         self.arg_inputs = arg_inputs
         self.kwarg_inputs = kwarg_inputs
 
     @staticmethod
-    def process_kwarg_inputs(inputs: Any) -> Any:
+    def _process_kwarg_inputs(inputs: Any) -> Any:
         # Process kwarg inputs to be acceptable for Torch-TensorRT
         if isinstance(inputs, dict):
             # None should be excluded. AOT compile also does not allow dynamic control flow, bool is also excluded.
             return {
-                k: MutableTorchTensorRTModule.process_kwarg_inputs(v)
+                k: MutableTorchTensorRTModule._process_kwarg_inputs(v)
                 for k, v in inputs.items()
                 if (v is not None and not isinstance(v, bool))
             }
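Taken together, _validate_inputs now yields three observable behaviors on a shape change. Continuing the earlier illustrative sketch (assuming mutable_module and the seq_len hint from above):

    import torch

    a = torch.randn(4, 3).cuda()
    mutable_module(a, torch.randn(4, 3).cuda())    # first call: logs first-time compilation

    a2 = torch.randn(9, 3).cuda()
    mutable_module(a2, torch.randn(9, 3).cuda())   # 9 is inside [1, 10]: engine is reused

    a3 = torch.randn(12, 3).cuda()
    mutable_module(a3, torch.randn(12, 3).cuda())  # 12 exceeds max: hints dropped, static recompile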
@@ -384,7 +450,10 @@ def process_kwarg_inputs(inputs: Any) -> Any:
         elif isinstance(inputs, (list, tuple)):
             if None not in inputs:
                 return type(inputs)(
-                    [MutableTorchTensorRTModule.process_kwarg_inputs(v) for v in inputs]
+                    [
+                        MutableTorchTensorRTModule._process_kwarg_inputs(v)
+                        for v in inputs
+                    ]
                 )
 
         raise ValueError(
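The recursive filter rebuilds nested dicts, lists, and tuples while dropping None and bool leaves, since AOT export cannot trace dynamic control flow on them. A hypothetical standalone version with the same shape; the scalar/tensor passthrough branch is an assumption, as it falls outside this hunk:

    import torch

    def process_kwarg_inputs(inputs):
        # Rebuild the container structure, dropping None and bool leaves.
        if isinstance(inputs, dict):
            return {
                k: process_kwarg_inputs(v)
                for k, v in inputs.items()
                if v is not None and not isinstance(v, bool)
            }
        if isinstance(inputs, (list, tuple)) and None not in inputs:
            return type(inputs)(process_kwarg_inputs(v) for v in inputs)
        if isinstance(inputs, (torch.Tensor, int, float, str)):  # assumed passthrough
            return inputs
        raise ValueError("Unsupported input type")

    print(process_kwarg_inputs({"mask": None, "flag": True, "x": [torch.ones(2), 3]}))
    # {'x': [tensor([1., 1.]), 3]}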
@@ -394,7 +463,7 @@ def process_kwarg_inputs(inputs: Any) -> Any:
 
     def forward(self, *args: Any, **kwargs: Any) -> Any:
         # Step 1: Check whether the input shape has changed
-        kwargs = MutableTorchTensorRTModule.process_kwarg_inputs(kwargs)
+        kwargs = MutableTorchTensorRTModule._process_kwarg_inputs(kwargs)
         self._validate_inputs(*args, **kwargs)
 
         # Step 2: If the flag is unknown, it could be a recompile or refit.
@@ -406,7 +475,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
         if self.refit_state.get_state() == RefitFlag.NEEDS_RECOMPILE:
             logger.info("(Re)Compiling the engine...")
             self.compile()
-            self.store_state_dict_metadata()
+            self._store_state_dict_metadata()
             self.refit_state.set_state(RefitFlag.LIVE)
 
         elif self.refit_state.get_state() == RefitFlag.NEEDS_REFIT:
@@ -417,7 +486,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
                 logger.error(e)
                 logger.error("Model refit failed. Recompiling the graph module.")
                 self.compile()
-                self.store_state_dict_metadata()
+                self._store_state_dict_metadata()
                 self.refit_state.set_state(RefitFlag.LIVE)
 
         result = self.gm(*args, **kwargs)
@@ -427,7 +496,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
 
     def to(self, device: str) -> None:
         logger.warning("Original PyTorch model is moved. CPU offload may failed.")
-        self.orignial_model.to(device)
+        self.original_model.to(device)
 
     def __deepcopy__(self, memo: Any) -> Any:
         cls = self.__class__
@@ -479,7 +548,7 @@ def __setattr__(self, name: str, value: Any) -> None:
             object.__setattr__(self, name, value)
 
     @staticmethod
-    def check_inputs_equal(
+    def _check_inputs_shape(
         input1: Any,
        input2: Any,
        dynamic_shapes: Any = None,
@@ -495,10 +564,13 @@ def check_inputs_equal(
                     return False
                 elif isinstance(a, torch.Tensor) and a.shape != b.shape:
                     if dynamic_shapes is None:
+                        logger.warning(
+                            "Dynamic shape is not properly set but the input shape is changed!"
+                        )
                         return False
                     else:
                         tensor_dynamic_shape = dynamic_shapes[i]
-                        if not MutableTorchTensorRTModule.check_tensor_shapes_with_dynamic_shapes(
+                        if not MutableTorchTensorRTModule._check_tensor_shapes_with_dynamic_shapes(
                             a, b, tensor_dynamic_shape
                         ):
                             return False
@@ -513,28 +585,34 @@ def check_inputs_equal(
                     return False
                 elif isinstance(va, torch.Tensor) and va.shape != vb.shape:
                     if dynamic_shapes is None:
+                        logger.warning(
+                            "Dynamic shape is not properly set but the input shape is changed!"
+                        )
                         return False
                     else:
                         tensor_dynamic_shape = dynamic_shapes[ka]
-                        if not MutableTorchTensorRTModule.check_tensor_shapes_with_dynamic_shapes(
+                        if not MutableTorchTensorRTModule._check_tensor_shapes_with_dynamic_shapes(
                             va, vb, tensor_dynamic_shape
                         ):
                             return False
                 elif isinstance(
                     va, (list, tuple, dict)
-                ) and not MutableTorchTensorRTModule.check_inputs_equal(
+                ) and not MutableTorchTensorRTModule._check_inputs_shape(
                     va, vb, dynamic_shapes[ka] if dynamic_shapes else None
                 ):
                     return False
         return True
 
     @staticmethod
-    def check_tensor_shapes_with_dynamic_shapes(
+    def _check_tensor_shapes_with_dynamic_shapes(
         t1: torch.tensor, t2: torch.tensor, dynamic_shape: dict[int, Any]
     ) -> bool:
         for (i, axis_0), axis_1 in zip(enumerate(t1.shape), t2.shape):
             if axis_0 != axis_1:
                 if i not in dynamic_shape:
+                    logger.warning(
+                        "Dynamic shape does not include the axis on which input changes!"
+                    )
                     return False
                 dyn = dynamic_shape[i]
                 if axis_1 > dyn.max or axis_1 < dyn.min:
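The per-axis check consults dyn.min/dyn.max only for axes registered in the hint; a change on any unregistered axis fails immediately with the new warning. A hypothetical standalone version over raw shape tuples, simplified to return False where the real method reports the out-of-range case via DynamicShapeOutOfRangeException (caught in _validate_inputs):

    from dataclasses import dataclass

    @dataclass
    class DimRange:
        # Stand-in for torch.export.Dim's min/max attributes.
        min: int
        max: int

    def shapes_compatible(shape_1, shape_2, dynamic_shape):
        # Mirrors _check_tensor_shapes_with_dynamic_shapes on plain tuples.
        for i, (axis_0, axis_1) in enumerate(zip(shape_1, shape_2)):
            if axis_0 != axis_1:
                if i not in dynamic_shape:
                    return False              # changed on a static axis
                dyn = dynamic_shape[i]
                if axis_1 > dyn.max or axis_1 < dyn.min:
                    return False              # outside the declared range
        return True

    hint = {0: DimRange(min=1, max=10)}
    print(shapes_compatible((4, 3), (9, 3), hint))   # True: 9 is in [1, 10]
    print(shapes_compatible((4, 3), (12, 3), hint))  # False: 12 exceeds max
    print(shapes_compatible((4, 3), (4, 5), hint))   # False: axis 1 is static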