Skip to content

Commit e9941a5

Browse files
IvanKobzarevfacebook-github-bot
authored andcommittedSep 19, 2020
[vulkan][py] torch.utils.optimize_for_vulkan (pytorch#44903)
Summary: Pull Request resolved: pytorch#44903 Test Plan: Imported from OSS Reviewed By: kimishpatel Differential Revision: D23766039 Pulled By: IvanKobzarev fbshipit-source-id: dbdf484ee7d3a7719aab105efba51b92ebc51568
1 parent 572f7e0 commit e9941a5

File tree

5 files changed

+26
-9
lines changed

5 files changed

+26
-9
lines changed
 

‎torch/_C/__init__.pyi.in

+2
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,8 @@ def _jit_get_operation(op_name: str) -> Callable: ...
156156
def _jit_pass_optimize_for_mobile(module: 'torch.jit.ScriptModule',
157157
optimization_blocklist: Set[MobileOptimizerType],
158158
preserved_methods: List[AnyStr]) -> 'torch.jit.ScriptModule': ...
159+
def _jit_pass_vulkan_optimize_for_mobile(module: 'torch.jit.ScriptModule',
160+
preserved_methods: List[AnyStr]) -> 'torch.jit.ScriptModule': ...
159161
def _jit_pass_inline(Graph) -> None: ...
160162
def _jit_get_schemas_for_operator(name :str) -> List[FunctionSchema]: ...
161163
def _jit_can_fuse_on_cpu() -> _bool: ...

‎torch/csrc/jit/passes/vulkan_rewrite.cpp

+7-3
Original file line numberDiff line numberDiff line change
@@ -159,12 +159,14 @@ void vulkanFoldPrePackingOps(script::Module& m) {
159159
PrePackingOpsFolder(m, filter_fn, "prepack_folding");
160160
}
161161

162-
script::Module vulkanOptimizeForMobile(const script::Module& m) {
162+
script::Module vulkanOptimizeForMobile(
163+
const script::Module& m,
164+
const std::vector<std::string>& preserved_methods) {
163165
auto cloned_module = m.clone();
164166
cloned_module.eval();
165167
cloned_module = FoldConvBatchNorm(cloned_module);
166168
vulkanInsertPrePackedOps(cloned_module);
167-
cloned_module = freeze_module(cloned_module);
169+
cloned_module = freeze_module(cloned_module, preserved_methods);
168170
vulkanFusePrePackedConvWithClamp(cloned_module);
169171
vulkanFoldPrePackingOps(cloned_module);
170172
removeDropout(cloned_module);
@@ -193,7 +195,9 @@ void vulkanFoldPrePackingOps(script::Module& m) {
193195
"Vulkan is not enabled. Please build with USE_VULKAN=1");
194196
}
195197

196-
script::Module vulkanOptimizeForMobile(const script::Module& module) {
198+
script::Module vulkanOptimizeForMobile(
199+
const script::Module& module,
200+
const std::vector<std::string>& preserved_methods) {
197201
TORCH_INTERNAL_ASSERT(
198202
"Mobile optimizaiton only available with Vulkan at the moment. "
199203
"Vulkan is not enabled. Please build with USE_VULKAN=1");

‎torch/csrc/jit/passes/vulkan_rewrite.h

+3-1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ TORCH_API void vulkanInsertPrePackedOps(std::shared_ptr<Graph>& graph);
99
TORCH_API void vulkanInsertPrePackedOps(script::Module& module);
1010
TORCH_API void vulkanFusePrePackedConvWithClamp(script::Module& module);
1111
TORCH_API void vulkanFoldPrePackingOps(script::Module& module);
12-
TORCH_API script::Module vulkanOptimizeForMobile(const script::Module& module);
12+
TORCH_API script::Module vulkanOptimizeForMobile(
13+
const script::Module& module,
14+
const std::vector<std::string>& preserved_methods);
1315
} // namespace jit
1416
} // namespace torch

‎torch/csrc/jit/python/init.cpp

+3-2
Original file line numberDiff line numberDiff line change
@@ -647,8 +647,9 @@ void initJITBindings(PyObject* module) {
647647
})
648648
.def(
649649
"_jit_pass_vulkan_optimize_for_mobile",
650-
[](script::Module& module) {
651-
return vulkanOptimizeForMobile(module);
650+
[](script::Module& module,
651+
std::vector<std::string>& preserved_methods) {
652+
return vulkanOptimizeForMobile(module, preserved_methods);
652653
})
653654
.def(
654655
"_jit_pass_onnx_unpack_quantized_weights",

‎torch/utils/mobile_optimizer.py

+11-3
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,16 @@ class LintCode(Enum):
1616
def optimize_for_mobile(
1717
script_module,
1818
optimization_blocklist: Set[MobileOptimizerType] = None,
19-
preserved_methods: List[AnyStr] = None):
19+
preserved_methods: List[AnyStr] = None,
20+
backend: str = 'CPU'):
2021
"""
2122
Args:
2223
script_module: An instance of torch script module with type of ScriptModule.
2324
optimization_blocklist: A set with type of MobileOptimizerType. When set is not passed,
2425
optimization method will run all the optimizer pass; otherwise, optimizer
2526
method will run the optimization pass that is not included inside optimization_blocklist.
26-
perserved_methods: A list of methods that needed to be preserved when freeze_module pass is invoked.
27+
perserved_methods: A list of methods that needed to be preserved when freeze_module pass is invoked
28+
backend: Device type to use for running the result model ('CPU'(default) or 'Vulkan').
2729
Returns:
2830
A new optimized torch script module
2931
"""
@@ -37,7 +39,13 @@ def optimize_for_mobile(
3739
if preserved_methods is None:
3840
preserved_methods = []
3941

40-
optimized_cpp_module = torch._C._jit_pass_optimize_for_mobile(script_module._c, optimization_blocklist, preserved_methods)
42+
if backend == 'CPU':
43+
optimized_cpp_module = torch._C._jit_pass_optimize_for_mobile(script_module._c, optimization_blocklist, preserved_methods)
44+
elif backend == 'Vulkan':
45+
optimized_cpp_module = torch._C._jit_pass_vulkan_optimize_for_mobile(script_module._c, preserved_methods)
46+
else:
47+
raise TypeError("Unknown backend, must be one of 'CPU', 'Vulkan'")
48+
4149
return torch.jit._recursive.wrap_cpp_module(optimized_cpp_module)
4250

4351

0 commit comments

Comments
 (0)