@@ -16,14 +16,16 @@ class LintCode(Enum):
16
16
def optimize_for_mobile (
17
17
script_module ,
18
18
optimization_blocklist : Set [MobileOptimizerType ] = None ,
19
- preserved_methods : List [AnyStr ] = None ):
19
+ preserved_methods : List [AnyStr ] = None ,
20
+ backend : str = 'CPU' ):
20
21
"""
21
22
Args:
22
23
script_module: An instance of torch script module with type of ScriptModule.
23
24
optimization_blocklist: A set with type of MobileOptimizerType. When set is not passed,
24
25
optimization method will run all the optimizer pass; otherwise, optimizer
25
26
method will run the optimization pass that is not included inside optimization_blocklist.
26
- perserved_methods: A list of methods that needed to be preserved when freeze_module pass is invoked.
27
+ perserved_methods: A list of methods that needed to be preserved when freeze_module pass is invoked
28
+ backend: Device type to use for running the result model ('CPU'(default) or 'Vulkan').
27
29
Returns:
28
30
A new optimized torch script module
29
31
"""
@@ -37,7 +39,13 @@ def optimize_for_mobile(
37
39
if preserved_methods is None :
38
40
preserved_methods = []
39
41
40
- optimized_cpp_module = torch ._C ._jit_pass_optimize_for_mobile (script_module ._c , optimization_blocklist , preserved_methods )
42
+ if backend == 'CPU' :
43
+ optimized_cpp_module = torch ._C ._jit_pass_optimize_for_mobile (script_module ._c , optimization_blocklist , preserved_methods )
44
+ elif backend == 'Vulkan' :
45
+ optimized_cpp_module = torch ._C ._jit_pass_vulkan_optimize_for_mobile (script_module ._c , preserved_methods )
46
+ else :
47
+ raise TypeError ("Unknown backend, must be one of 'CPU', 'Vulkan'" )
48
+
41
49
return torch .jit ._recursive .wrap_cpp_module (optimized_cpp_module )
42
50
43
51
0 commit comments