Skip to content

Commit 8e0714a

Browse files
Ilia Cherniavskii, facebook-github-bot
Ilia Cherniavskii
authored and committed Jul 1, 2020
[rfc] Reduce number of coin flips in RecordFunction (pytorch#40758)
Summary: Pull Request resolved: pytorch#40758 Currently we flip a coin for each sampled callback each time we run RecordFunction, this PR is an attempt to skip most of the coin flips (for the low-probability observers) and keep the distribution close to the original one Test Plan: CI and record_function_benchmark ``` (python_venv) iliacher@devgpu151:~/local/pytorch (reduce_coin_flops)$ ./build/bin/record_function_benchmark Warmup time: 30108 us. Time per iteration (1x1): 1496.78 us. Time per iteration (16x16): 2142.46 us. Pure RecordFunction runtime of 10000000 iterations 687929 us, number of callback invocations: 978 (python_venv) iliacher@devgpu151:~/local/pytorch (reduce_coin_flops)$ ./build/bin/record_function_benchmark Warmup time: 19051 us. Time per iteration (1x1): 1581.89 us. Time per iteration (16x16): 2195.67 us. Pure RecordFunction runtime of 10000000 iterations 682402 us, number of callback invocations: 1023 (python_venv) iliacher@devgpu151:~/local/pytorch (reduce_coin_flops)$ ./build/bin/record_function_benchmark Warmup time: 18715 us. Time per iteration (1x1): 1566.11 us. Time per iteration (16x16): 2131.17 us. Pure RecordFunction runtime of 10000000 iterations 693571 us, number of callback invocations: 963 (python_venv) iliacher@devgpu151:~/local/pytorch (reduce_coin_flops)$ (python_venv) iliacher@devgpu151:~/local/pytorch (reduce_coin_flops)$ ./build/bin/record_function_benchmark Warmup time: 18814 us. Time per iteration (1x1): 1536.2 us. Time per iteration (16x16): 1985.82 us. Pure RecordFunction runtime of 10000000 iterations 944959 us, number of callback invocations: 1015 (python_venv) iliacher@devgpu151:~/local/pytorch (reduce_coin_flops)$ ./build/bin/record_function_benchmark Warmup time: 18278 us. Time per iteration (1x1): 1526.32 us. Time per iteration (16x16): 2093.77 us. 
Pure RecordFunction runtime of 10000000 iterations 985307 us, number of callback invocations: 1013 (python_venv) iliacher@devgpu151:~/local/pytorch (reduce_coin_flops)$ ./build/bin/record_function_benchmark Warmup time: 18545 us. Time per iteration (1x1): 1524.65 us. Time per iteration (16x16): 2080 us. Pure RecordFunction runtime of 10000000 iterations 952835 us, number of callback invocations: 1048 ``` Reviewed By: dzhulgakov Differential Revision: D22320879 Pulled By: ilia-cher fbshipit-source-id: 2193f07d2f7625814fe7bc3cc85ba4092fe036bc
1 parent 179dbd4 commit 8e0714a

File tree

3 files changed

+84
-32
lines changed

3 files changed

+84
-32
lines changed
 

‎aten/src/ATen/record_function.cpp

+45-3
Original file line numberDiff line numberDiff line change
@@ -212,16 +212,58 @@ inline CallbackManager& manager() {
212212
return _manager;
213213
}
214214

215-
} // namespace
215+
// Low probability constant: RecordFunctionCallback::shouldRun() takes its
// reduced-coin-flip path for callbacks whose sampling probability is
// strictly below this threshold.
216+
const double kLowProb = 0.001;
217+
// Per-thread countdown used by shouldRun(): while it is non-zero, a
// low-probability check decrements it and returns false without drawing a
// random number; when it is zero it is refilled from sample_geometric().
thread_local int tries_left_ = 0;
216218

217-
/* static */
218-
double RecordFunctionCallback::sample_zero_one() {
219+
// Draws one value from std::geometric_distribution<int> with success
// probability kLowProb, i.e. the number of failed trials before the first
// success.
int sample_geometric() {
220+
// One Mersenne Twister engine per thread, seeded from std::random_device
// the first time this function runs on that thread.
static thread_local auto gen =
221+
std::make_unique<std::mt19937>(std::random_device()());
222+
std::geometric_distribution<int> dist(kLowProb);
223+
return dist(*gen);
224+
}
225+
226+
// Returns a double drawn from std::uniform_real_distribution<double>(0.0, 1.0),
// i.e. uniformly from the half-open interval [0.0, 1.0).
double sample_zero_one() {
219227
// One Mersenne Twister engine per thread, seeded from std::random_device
// the first time this function runs on that thread. This engine is a
// separate object from the one inside sample_geometric().
static thread_local auto gen =
220228
std::make_unique<std::mt19937>(std::random_device()());
221229
std::uniform_real_distribution<double> dist(0.0, 1.0);
222230
return dist(*gen);
223231
}
224232

233+
} // namespace
234+
235+
// Decides whether this callback should run for the given scope.
//
// Checks, in order: the scope filter (checkScope), the optional should_run_
// predicate, and finally random sampling driven by sampling_prob_.
// Returns true when the callback should run, false otherwise.
bool RecordFunctionCallback::shouldRun(RecordScope scope) const {
236+
// first check whether this callback is interested in
237+
// the given scope type
238+
if (!checkScope(scope)) {
239+
return false;
240+
}
241+
// if we have registered should_run_ function, use it
242+
// (its result is returned directly; sampling_prob_ is not consulted)
if (should_run_) {
243+
return should_run_(*this);
244+
}
245+
// otherwise potentially do the uniform sampling
246+
// (a sampling_prob_ of exactly 1.0 skips random number generation entirely)
if (sampling_prob_ != 1.0) {
247+
// model the low probability events as events happening
248+
// with prob. kLowProb followed by another sampling with
249+
// prob. (sampling_prob_ / kLowProb), then replace the coin
250+
// flip for kLowProb with a thread local number of tries tries_left_
251+
// sampled from the geometric distribution
252+
// NOTE(review): tries_left_ is a single thread_local counter rather than a
// per-callback member, so every callback with sampling_prob_ < kLowProb that
// is checked on this thread shares and decrements the same countdown.
if (sampling_prob_ < kLowProb) {
253+
if (tries_left_ == 0) {
254+
// Countdown exhausted: draw the length of the next run of skipped checks,
// then do the second-stage coin flip for this check.
tries_left_ = sample_geometric();
255+
return (sample_zero_one() < sampling_prob_ / kLowProb);
256+
} else {
257+
// Still inside a run of skipped checks: no random draw is made.
--tries_left_;
258+
return false;
259+
}
260+
} else {
261+
// Probabilities at or above kLowProb use one uniform draw per check.
return (sample_zero_one() < sampling_prob_);
262+
}
263+
}
264+
return true;
265+
}
266+
225267
// Returns sorted_tls_callbacks_ by value.
// NOTE(review): sorted_tls_callbacks_ is declared outside this hunk; from its
// name it is presumably the calling thread's sorted callback list - confirm.
RecordFunctionCallbacks _getTLSCallbacks() {
226268
return sorted_tls_callbacks_;
227269
}

‎aten/src/ATen/record_function.h

+2-19
Original file line numberDiff line numberDiff line change
@@ -303,23 +303,8 @@ class TORCH_API RecordFunctionCallback {
303303
return end_;
304304
}
305305

306-
// whether this callbacks should run in the given scope
307-
inline bool shouldRun(RecordScope scope) const {
308-
// first check whether this callback is interested in
309-
// the given scope type
310-
if (!checkScope(scope)) {
311-
return false;
312-
}
313-
// if we have registered should_run_ function, use it
314-
if (should_run_) {
315-
return should_run_(*this);
316-
}
317-
// otherwise potentially do the uniform sampling
318-
if (sampling_prob_ != 1.0) {
319-
return (sample_zero_one() < sampling_prob_);
320-
}
321-
return true;
322-
}
306+
// whether the callbacks should run in the given scope
307+
bool shouldRun(RecordScope scope) const;
323308

324309
private:
325310
std::function<void(const RecordFunction&)> start_;
@@ -329,8 +314,6 @@ class TORCH_API RecordFunctionCallback {
329314
bool needs_ids_ = false;
330315
double sampling_prob_ = 1.0;
331316
std::array<bool, static_cast<size_t>(RecordScope::NUM_SCOPES)> scopes_ = {};
332-
333-
static double sample_zero_one();
334317
};
335318

336319
// Using macro to minimize inputs copies,

‎binaries/record_function_benchmark.cc

+37-10
Original file line numberDiff line numberDiff line change
@@ -8,32 +8,31 @@
88
#include <ctime>
99

1010
C10_DEFINE_int(iter, 100, "Number of iterations");
11-
C10_DEFINE_int(warmup_iter, 10, "Number of warmup iterations")
11+
C10_DEFINE_int(warmup_iter, 10, "Number of warmup iterations");
12+
C10_DEFINE_int(rec_fn_iter, 10e6,
13+
"Number of iterations for the pure RecordFunction benchmark");
1214

1315
namespace {
1416
const int kInnerIter = 100;
1517
const int kNumSampledCb = 2;
1618
const int kTensorSize = 16;
1719
const int kSmallTensorSize = 1;
1820
const float kSampingProb = 0.1;
19-
}
2021

22+
const float kLowSamplingProb = 0.0001;
23+
}
2124

22-
void setupCallbacks() {
25+
void setupBenchmarkCallbacks() {
2326
// non-sampled callback
2427
at::addGlobalCallback(at::RecordFunctionCallback(
25-
[&](const at::RecordFunction& fn) {
26-
return true;
27-
},
28+
[&](const at::RecordFunction& fn) {},
2829
[](const at::RecordFunction&) {})
2930
.needsInputs(true));
3031

3132
// sampled
3233
for (auto idx = 0; idx < kNumSampledCb; ++idx) {
3334
at::addGlobalCallback(at::RecordFunctionCallback(
34-
[](const at::RecordFunction& fn) {
35-
return true;
36-
},
35+
[](const at::RecordFunction& fn) {},
3736
[](const at::RecordFunction&) {})
3837
.needsInputs(true)
3938
.samplingProb(kSampingProb)
@@ -61,7 +60,8 @@ int main(int argc, char** argv) {
6160
return -1;
6261
}
6362

64-
setupCallbacks();
63+
at::enableRecordFunction();
64+
setupBenchmarkCallbacks();
6565

6666
auto duration = runBench(kSmallTensorSize, FLAGS_warmup_iter);
6767
std::cout << "Warmup time: " << duration << " us." << std::endl;
@@ -76,5 +76,32 @@ int main(int argc, char** argv) {
7676
<< " us." << std::endl;
7777
}
7878

79+
at::clearCallbacks();
80+
81+
int cb_count = 0;
82+
at::addGlobalCallback(at::RecordFunctionCallback(
83+
[&](const at::RecordFunction& fn) {
84+
++cb_count;
85+
},
86+
[](const at::RecordFunction&) {})
87+
.needsInputs(true)
88+
.samplingProb(kLowSamplingProb)
89+
);
90+
91+
typedef std::chrono::high_resolution_clock clock;
92+
typedef std::chrono::microseconds us;
93+
std::chrono::time_point<clock> start_time = clock::now();
94+
for (auto n = 0; n < FLAGS_rec_fn_iter; ++n) {
95+
RECORD_USER_SCOPE("test");
96+
}
97+
duration = static_cast<float>(
98+
std::chrono::duration_cast<us>(clock::now() - start_time).count());
99+
std::cout << "Pure RecordFunction runtime of " << FLAGS_rec_fn_iter
100+
<< " iterations " << duration
101+
<< " us, number of callback invocations: " << cb_count
102+
<< ", expected number: ~" << (int)(FLAGS_rec_fn_iter * kLowSamplingProb)
103+
<< " invocations" << std::endl;
104+
105+
at::clearCallbacks();
79106
return 0;
80107
}

0 commit comments

Comments
 (0)
Please sign in to comment.