
Commit d1a1161

bwasti authored and facebook-github-bot committed Sep 25, 2020
[static runtime] Add _out variants and reuse memory (pytorch#44128)
Summary: Pull Request resolved: pytorch#44128

Test Plan: Imported from OSS

Reviewed By: hlu1
Differential Revision: D23604304
Pulled By: bwasti
fbshipit-source-id: 06a23cb75700a0fc733069071843b7b498e7b9e9
1 parent d1d9017 commit d1a1161
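The heart of this change: supported ops no longer allocate a fresh output tensor on every execution. Instead, each such node is bound to a closure that calls the op's _out variant, writing into an output tensor that is allocated once and then cached in the runtime's workspace. A minimal standalone sketch of that pattern against the public ATen API (illustrative only, not the commit's code; the out-first argument order matches ATen at the time of this commit):

#include <ATen/ATen.h>

int main() {
  at::Tensor a = at::randn({2048, 512});
  at::Tensor b = at::randn({2048, 512});
  // Allocated once, starting empty: the first _out call resizes it, and
  // every later call writes into the same storage instead of allocating.
  at::Tensor out = at::empty({0}, a.options());
  for (int i = 0; i < 5; ++i) {
    at::add_out(out, a, b, /*alpha=*/1);
  }
  return 0;
}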

File tree

8 files changed: +192 −65

benchmarks/static_runtime/CMakeLists.txt
caffe2/CMakeLists.txt
test/test_static_runtime.py
tools/build_variables.bzl
torch/csrc/jit/runtime/static/impl.cpp
torch/csrc/jit/runtime/static/impl.h
torch/csrc/jit/runtime/static/ops.cpp
torch/csrc/jit/runtime/static/ops.h

 
benchmarks/static_runtime/CMakeLists.txt

+5 −1

@@ -1,3 +1,7 @@
-list(APPEND STATIC_RUNTIME_BENCHMARK_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/deep_wide_pt_bench.cc)
 list(APPEND STATIC_RUNTIME_BENCHMARK_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/deep_wide_pt.cc)
+list(APPEND STATIC_RUNTIME_BENCHMARK_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/deep_wide_pt_bench.cc)
 set(STATIC_RUNTIME_BENCHMARK_SRCS ${STATIC_RUNTIME_BENCHMARK_SRCS} PARENT_SCOPE)
+
+list(APPEND STATIC_RUNTIME_TEST_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/deep_wide_pt.cc)
+list(APPEND STATIC_RUNTIME_TEST_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/test_static_runtime.cc)
+set(STATIC_RUNTIME_TEST_SRCS ${STATIC_RUNTIME_TEST_SRCS} PARENT_SCOPE)

‎caffe2/CMakeLists.txt

+2 −0

@@ -1247,7 +1247,9 @@ endif()
 if(BUILD_STATIC_RUNTIME_BENCHMARK)
   add_subdirectory(${TORCH_ROOT}/benchmarks/static_runtime ${PROJECT_BINARY_DIR}/bin)
   add_executable(static_runtime_bench "${STATIC_RUNTIME_BENCHMARK_SRCS}")
+  add_executable(static_runtime_test "${STATIC_RUNTIME_TEST_SRCS}")
   target_link_libraries(static_runtime_bench torch_library benchmark)
+  target_link_libraries(static_runtime_test torch_library gtest_main)
 endif()

 if(BUILD_MOBILE_BENCHMARK)

‎test/test_static_runtime.py

+8 −5

@@ -106,7 +106,8 @@ def test_multihead_attention_layer(self):
         DROPOUT = 0.1
         device = torch.device("cpu")
         attention = MultiHeadAttentionLayer(HID_DIM, HEADS, DROPOUT, device).to(device)
-        src = torch.randn(BATCH_SIZE, QUERY_LEN, HID_DIM).to(device)
+        with torch.no_grad():
+            src = torch.randn(BATCH_SIZE, QUERY_LEN, HID_DIM).to(device)
         src_mask = (src > 0)[:, :, 0].unsqueeze(1).unsqueeze(2).to(device)

         attention.eval()
@@ -129,17 +130,19 @@ def test_mlp(self):
         bot_l_acc = StaticRuntime(bot_l)
         top_l = create_mlp(ln_top, sigmoid_top)
         top_l_acc = StaticRuntime(top_l)
-        bot_inp = torch.randn(2048, 512)  # torch.Size([2048, 512])
-        top_inp = torch.randn(2048, 100)  # torch.Size([2048, 100])
+        with torch.no_grad():
+            bot_inp = torch.randn(2048, 512)  # torch.Size([2048, 512])
+            top_inp = torch.randn(2048, 100)  # torch.Size([2048, 100])
         ref_bot = bot_l(bot_inp)
         acc_bot = bot_l_acc(bot_inp)[0]
         torch.testing.assert_allclose(acc_bot, ref_bot)
         ref_top = top_l(top_inp)
         acc_top = top_l_acc(top_inp)[0]
         torch.testing.assert_allclose(acc_top, ref_top)
         for _ in range(5):
-            bot_inp = torch.randn(2048, 512)  # torch.Size([2048, 512])
-            top_inp = torch.randn(2048, 100)  # torch.Size([2048, 100])
+            with torch.no_grad():
+                bot_inp = torch.randn(2048, 512)  # torch.Size([2048, 512])
+                top_inp = torch.randn(2048, 100)  # torch.Size([2048, 100])
             ref_bot = bot_l(bot_inp)
             acc_bot = bot_l_acc(bot_inp)[0]
             torch.testing.assert_allclose(acc_bot, ref_bot)
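The torch.no_grad() wrappers above keep autograd state off the freshly created inputs, matching the inference-only contract of the static runtime. For reference, a hedged sketch of the same guard on the C++ side (assumes libtorch; torch::NoGradGuard is the C++ counterpart of torch.no_grad(), and make_bot_input is a hypothetical helper):

#include <torch/torch.h>

// Create an input with gradient tracking disabled, mirroring the
// `with torch.no_grad():` blocks in the test above (shape of bot_inp).
torch::Tensor make_bot_input() {
  torch::NoGradGuard no_grad;  // RAII guard; tracking resumes on scope exit
  return torch::randn({2048, 512});
}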

‎tools/build_variables.bzl

+1 −0

@@ -219,6 +219,7 @@ core_sources_full = [
     "torch/csrc/jit/runtime/profiling_record.cpp",
     "torch/csrc/jit/runtime/symbolic_script.cpp",
     "torch/csrc/jit/runtime/static/impl.cpp",
+    "torch/csrc/jit/runtime/static/ops.cpp",
     "torch/csrc/jit/serialization/import.cpp",
     "torch/csrc/jit/serialization/import_export_helpers.cpp",
     "torch/csrc/jit/serialization/import_source.cpp",

‎torch/csrc/jit/runtime/static/impl.cpp

+6 −57

@@ -4,6 +4,7 @@
 #include <torch/csrc/jit/passes/freeze_module.h>
 #include <torch/csrc/jit/passes/remove_mutation.h>
 #include <torch/csrc/jit/passes/subgraph_rewrite.h>
+#include <torch/csrc/jit/runtime/static/ops.h>
 #include <torch/csrc/jit/runtime/vararg_functions.h>

 namespace torch {
@@ -12,48 +13,6 @@ namespace jit {
 using c10::DispatchKey;
 using c10::RegisterOperators;

-static auto reg =
-    RegisterOperators()
-        .op("static::add(Tensor a, Tensor b) -> Tensor",
-            RegisterOperators::options().kernel(
-                DispatchKey::CPU,
-                [](at::Tensor a, at::Tensor b) -> at::Tensor { return a + b; }))
-        .op("static::mul.a(Tensor a, Tensor b) -> Tensor",
-            RegisterOperators::options().kernel(
-                DispatchKey::CPU,
-                [](at::Tensor a, at::Tensor b) -> at::Tensor { return a * b; }))
-        .op("static::mul.b(Tensor a, int b) -> Tensor",
-            RegisterOperators::options().kernel(
-                DispatchKey::CPU,
-                [](at::Tensor a, int64_t b) -> at::Tensor { return a * b; }));
-
-#define SUPPORTED_OPS(F) \
-  F(aten::__getitem__) \
-  F(aten::add) \
-  F(aten::addmm) \
-  F(aten::bmm) \
-  F(aten::cat) \
-  F(aten::clamp) \
-  F(aten::contiguous) \
-  F(aten::div) \
-  F(aten::flatten) \
-  F(aten::index_put_) \
-  F(aten::isnan) \
-  F(aten::matmul) \
-  F(aten::mul) \
-  F(aten::permute) \
-  F(aten::relu) \
-  F(aten::sigmoid) \
-  F(aten::size) \
-  F(aten::softmax) \
-  F(aten::t) \
-  F(aten::to) \
-  F(aten::transpose) \
-  F(aten::view) \
-  F(prim::Constant) \
-  F(prim::ListConstruct) \
-  F(prim::TupleConstruct)
-
 StaticRuntime::StaticRuntime(const torch::jit::Module& m)
     : module_(m.copy()), graph_(nullptr) {
   module_.eval();
@@ -84,19 +43,6 @@ StaticRuntime::StaticRuntime(const torch::jit::Module& m)
     }
   }

-  SubgraphRewriter sr;
-  sr.RegisterRewritePattern(
-      R"IR(
-  graph(%x, %w, %s):
-    %r = aten::add(%x, %w, %s)
-    return (%r))IR",
-      R"IR(
-  graph(%x, %w, %s):
-    %y = static::add(%x, %w)
-    %r = static::mul(%y, %s)
-    return (%r))IR");
-  sr.runOnGraph(graph_);
-
   // remove unused input 0 from graph
   if (graph_->inputs().at(0)->type()->is_module()) {
     if (!graph_->inputs().at(0)->hasUses()) {
@@ -157,10 +103,13 @@ ProcessedNode::ProcessedNode(Node* node) : node_(node) {
     CHECK(op.hasOperation());
     op_ = op.getOperation(node);
   }
+  if (canRunOutOfPlace(node)) {
+    fn_ = getOutOfPlaceOperation(node);
+  }
 }

 void ProcessedNode::run(StaticRuntime::ConstantMap& workspace) const {
-  if (use_stack_) {
+  if (!fn_) {
     std::vector<IValue> stack;
     const size_t size = node_->inputs().size();
     stack.reserve(size);
@@ -201,7 +150,7 @@ void ProcessedNode::run(StaticRuntime::ConstantMap& workspace) const {
       workspace[node_->outputs()[i]] = stack[i];
     }
   } else {
-    TORCH_CHECK(0, "Non-stack execution not yet implemented");
+    (*fn_)(workspace);
   }
 }
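After this change, ProcessedNode dispatches on whether an out-of-place closure was bound at construction: if fn_ is empty it takes the generic stack-based interpreter path, otherwise it invokes the closure, which reuses its cached output tensor. A self-contained sketch of that optional-closure dispatch (illustrative stand-in types; the real code uses c10::optional rather than std::optional):

#include <functional>
#include <iostream>
#include <optional>

struct Workspace {};  // stands in for StaticRuntime::ConstantMap

struct NodeRunner {
  std::optional<std::function<void(Workspace&)>> fn_;

  void run(Workspace& ws) const {
    if (!fn_) {
      std::cout << "generic path: push inputs onto a stack, call the Operation\n";
    } else {
      (*fn_)(ws);  // bound _out kernel that writes into a cached output
    }
  }
};

int main() {
  Workspace ws;
  NodeRunner generic;  // fn_ unset: falls back to the interpreter path
  NodeRunner optimized;
  optimized.fn_ = [](Workspace&) { std::cout << "out-of-place kernel\n"; };
  generic.run(ws);
  optimized.run(ws);
  return 0;
}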

‎torch/csrc/jit/runtime/static/impl.h

+1 −2

@@ -53,8 +53,7 @@ class ProcessedNode {
  private:
   Node* node_;
   c10::optional<Operation> op_;
-  // if false, we have an optimized version
-  bool use_stack_ = true;
+  c10::optional<std::function<void(StaticRuntime::ConstantMap&)>> fn_;
 };

 } // namespace jit

‎torch/csrc/jit/runtime/static/ops.cpp

+128 −0

@@ -0,0 +1,128 @@
+#include <torch/csrc/jit/runtime/static/ops.h>
+#include <torch/csrc/jit/ir/ir.h>
+
+namespace torch {
+namespace jit {
+
+bool canRunOutOfPlace(Node* n) {
+  auto str = std::string(n->kind().toQualString());
+  if ((str == "aten::add") || (str == "aten::mul") || (str == "aten::addmm") ||
+      (str == "aten::bmm") || (str == "aten::sigmoid") ||
+      (str == "aten::cat")) {
+    return true;
+  }
+  return false;
+}
+
+std::function<void(StaticRuntime::ConstantMap&)> getOutOfPlaceOperation(
+    Node* n) {
+  auto create_empty_from = [](const at::Tensor& t) {
+    return at::empty({0}, t.options());
+  };
+
+  if (n->kind() == c10::Symbol::fromQualString("aten::add")) {
+    auto out = n->outputs().at(0);
+    auto in0 = n->inputs().at(0);
+    auto in1 = n->inputs().at(1);
+    auto in2 = n->inputs().at(2);
+    return [=](StaticRuntime::ConstantMap& ws) {
+      auto in0_t = ws.at(in0).toTensor();
+      auto in1_t = ws.at(in1).toTensor();
+      auto in2_s = ws.at(in2).toScalar();
+      if (!ws.count(out)) {
+        ws.emplace(out, create_empty_from(in0_t));
+      }
+      auto out_t = ws.at(out).toTensor();
+      at::native::add_out(out_t, in0_t, in1_t, in2_s);
+    };
+  } else if (n->kind() == c10::Symbol::fromQualString("aten::mul")) {
+    auto out = n->outputs().at(0);
+    auto in0 = n->inputs().at(0);
+    auto in1 = n->inputs().at(1);
+    return [=](StaticRuntime::ConstantMap& ws) {
+      auto in0_t = ws.at(in0).toTensor();
+      auto in1_t = ws.at(in1).toTensor();
+      if (!ws.count(out)) {
+        ws.emplace(out, create_empty_from(in0_t));
+      }
+      auto out_t = ws.at(out).toTensor();
+      at::native::mul_out(out_t, in0_t, in1_t);
+    };
+  } else if (n->kind() == c10::Symbol::fromQualString("aten::addmm")) {
+    auto out = n->outputs().at(0);
+    auto in0 = n->inputs().at(0);
+    auto in1 = n->inputs().at(1);
+    auto in2 = n->inputs().at(2);
+    auto in3 = n->inputs().at(3);
+    auto in4 = n->inputs().at(4);
+    return [=](StaticRuntime::ConstantMap& ws) {
+      auto in0_t = ws.at(in0).toTensor();
+      auto in1_t = ws.at(in1).toTensor();
+      auto in2_t = ws.at(in2).toTensor();
+      auto in3_s = ws.at(in3).toScalar();
+      auto in4_s = ws.at(in4).toScalar();
+      if (!ws.count(out)) {
+        ws.emplace(out, create_empty_from(in0_t));
+      }
+      auto out_t = ws.at(out).toTensor();
+      at::native::addmm_cpu_out(out_t, in0_t, in1_t, in2_t, in3_s, in4_s);
+    };
+  } else if (n->kind() == c10::Symbol::fromQualString("aten::clamp")) {
+    auto out = n->outputs().at(0);
+    auto in0 = n->inputs().at(0);
+    auto in1 = n->inputs().at(1);
+    auto in2 = n->inputs().at(2);
+    return [=](StaticRuntime::ConstantMap& ws) {
+      auto in0_t = ws.at(in0).toTensor();
+      auto in1_s = ws.at(in1).toScalar();
+      auto in2_s = ws.at(in2).toScalar();
+      if (!ws.count(out)) {
+        ws.emplace(out, create_empty_from(in0_t));
+      }
+      auto out_t = ws.at(out).toTensor();
+      at::native::clamp_out(out_t, in0_t, in1_s, in2_s);
+    };
+  } else if (n->kind() == c10::Symbol::fromQualString("aten::bmm")) {
+    auto out = n->outputs().at(0);
+    auto in0 = n->inputs().at(0);
+    auto in1 = n->inputs().at(1);
+    return [=](StaticRuntime::ConstantMap& ws) {
+      auto in0_t = ws.at(in0).toTensor();
+      auto in1_t = ws.at(in1).toTensor();
+      if (!ws.count(out)) {
+        ws.emplace(out, create_empty_from(in0_t));
+      }
+      auto out_t = ws.at(out).toTensor();
+      at::native::bmm_out_cpu(out_t, in0_t, in1_t);
+    };
+  } else if (n->kind() == c10::Symbol::fromQualString("aten::cat")) {
+    auto out = n->outputs().at(0);
+    auto in0 = n->inputs().at(0);
+    auto in1 = n->inputs().at(1);
+    return [=](StaticRuntime::ConstantMap& ws) {
+      auto in0_tl = ws.at(in0).toTensorVector();
+      auto in1_i = ws.at(in1).toInt();
+      if (!ws.count(out)) {
+        ws.emplace(out, create_empty_from(in0_tl[0]));
+      }
+      auto out_t = ws.at(out).toTensor();
+      at::native::cat_out(out_t, in0_tl, in1_i);
+    };
+  } else if (n->kind() == c10::Symbol::fromQualString("aten::sigmoid")) {
+    auto out = n->outputs().at(0);
+    auto in0 = n->inputs().at(0);
+    return [=](StaticRuntime::ConstantMap& ws) {
+      auto in0_t = ws.at(in0).toTensor();
+      if (!ws.count(out)) {
+        ws.emplace(out, create_empty_from(in0_t));
+      }
+      auto out_t = ws.at(out).toTensor();
+      at::native::sigmoid_out(out_t, in0_t);
+    };
+  }
+
+  return [](StaticRuntime::ConstantMap&) { TORCH_CHECK(0); };
+}
+
+} // namespace jit
+} // namespace torch
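Every branch above follows the same recipe: capture the node's Value* handles, lazily allocate the output into the workspace on the first call (create_empty_from yields an empty tensor with matching dtype/device options), then hand the cached tensor to an at::native _out kernel on every call. A runnable distillation of that recipe, with a string-keyed map standing in for ConstantMap (hypothetical names; out-first _out signature as in this era of ATen):

#include <ATen/ATen.h>

#include <string>
#include <unordered_map>

int main() {
  // String keys stand in for the IR Value* handles used by the real code.
  std::unordered_map<std::string, at::Tensor> ws;
  ws.emplace("in", at::randn({8, 8}));

  auto sigmoid_step = [&ws]() {
    const at::Tensor& in = ws.at("in");
    if (!ws.count("out")) {
      // First run: allocate an empty tensor with matching options.
      ws.emplace("out", at::empty({0}, in.options()));
    }
    at::Tensor out = ws.at("out");
    at::sigmoid_out(out, in);  // subsequent runs reuse the cached storage
  };

  sigmoid_step();  // allocates "out"
  sigmoid_step();  // writes into the same buffer
  return 0;
}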

‎torch/csrc/jit/runtime/static/ops.h

+41 −0

@@ -0,0 +1,41 @@
+#pragma once
+
+#include <torch/csrc/jit/ir/ir.h>
+#include <torch/csrc/jit/runtime/static/impl.h>
+
+namespace torch {
+namespace jit {
+
+bool canRunOutOfPlace(Node* n);
+std::function<void(StaticRuntime::ConstantMap&)> getOutOfPlaceOperation(
+    Node* n);
+
+#define SUPPORTED_OPS(F) \
+  F(aten::__getitem__) \
+  F(aten::add) \
+  F(aten::addmm) \
+  F(aten::bmm) \
+  F(aten::cat) \
+  F(aten::clamp) \
+  F(aten::contiguous) \
+  F(aten::div) \
+  F(aten::flatten) \
+  F(aten::index_put_) \
+  F(aten::isnan) \
+  F(aten::matmul) \
+  F(aten::mul) \
+  F(aten::permute) \
+  F(aten::relu) \
+  F(aten::sigmoid) \
+  F(aten::size) \
+  F(aten::softmax) \
+  F(aten::t) \
+  F(aten::to) \
+  F(aten::transpose) \
+  F(aten::view) \
+  F(prim::Constant) \
+  F(prim::ListConstruct) \
+  F(prim::TupleConstruct)
+
+} // namespace jit
+} // namespace torch
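SUPPORTED_OPS is an X-macro: it applies its argument F to every op the static runtime claims to handle, so each consumer decides what F(op) expands to. The commit moves the macro here without showing a caller; a hypothetical consumer that stringizes the list would look like this:

#include <torch/csrc/jit/runtime/static/ops.h>

#include <iostream>
#include <string>
#include <vector>

// Hypothetical consumer: each F(op) expands to the string literal of its
// argument followed by a comma, producing an initializer list.
#define STRINGIZE_OP(op) #op,
static const std::vector<std::string> kSupportedOps = {
    SUPPORTED_OPS(STRINGIZE_OP)};
#undef STRINGIZE_OP

int main() {
  for (const auto& op : kSupportedOps) {
    std::cout << op << "\n";  // aten::__getitem__, aten::add, ...
  }
  return 0;
}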
