
Commit ab75f99

Parallelization of ConstProp compilation (#3042)
To reduce compilation time, this PR parallelizes the compilation of ConstProp using `parallelFor()`. This mainly speeds up constant propagation for reduction computations. When the input tensor is small, the computation runs sequentially, without this parallelization, to avoid the parallelization overhead. --------- Signed-off-by: Haruki Imai <[email protected]>
1 parent d8de38c commit ab75f99
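For reference, `parallelFor()` is the MLIR threading utility declared in mlir/IR/Threading.h: it runs a closure once per index in a range and falls back to serial execution when multithreading is disabled on the context. A minimal sketch of the pattern this PR builds on (the buffer and function names here are illustrative, not from the commit):

#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Threading.h"
#include <vector>

// Double every element of 'buf', one contiguous tile per thread.
void doubleAll(mlir::MLIRContext *ctx, std::vector<double> &buf) {
  size_t numThreads = ctx->getNumThreads();
  // Runs the closure for threadNumber = 0 .. numThreads-1; MLIR executes
  // sequentially instead if threading is disabled on the context.
  mlir::parallelFor(ctx, 0, numThreads, [&](size_t threadNumber) {
    size_t begin = buf.size() * threadNumber / numThreads;
    size_t end = buf.size() * (threadNumber + 1) / numThreads;
    for (size_t i = begin; i < end; ++i)
      buf[i] *= 2.0;
  });
}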

File tree

3 files changed: +141 -20 lines changed


src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.cpp

+70 -16
@@ -9,8 +9,8 @@
 //===----------------------------------------------------------------------===//
 
 #include "src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.hpp"
-
 #include "mlir/Dialect/Traits.h"
+#include "mlir/IR/Threading.h"
 #include "llvm/ADT/STLExtras.h"
 
 #include "src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.hpp"
@@ -849,6 +849,8 @@ ElementsAttr ElementsAttrBuilder::reduce(ElementsAttr elms,
   if (axes.empty())
     return elms;
 
+  Type elementType = elms.getElementType();
+  MLIRContext *ctx = elementType.getContext();
   SmallVector<unsigned, 4> sortedAxes(axes);
   std::sort(sortedAxes.begin(), sortedAxes.end());
   assert(
@@ -885,22 +887,74 @@ ElementsAttr ElementsAttrBuilder::reduce(ElementsAttr elms,
 
   ShapedType reducedType = type.clone(reducedShape);
   return fromWideNums(reducedType, [&](MutableArrayRef<WideNum> dstNums) {
-    // Traverse and populate each element d in dstNums.
-    for (auto &idxoffs : StridesRange<1>(reducedShape, {reducedStrides})) {
-      WideNum &d = dstNums[idxoffs.flattenedIndex];
-      int64_t srcPos = idxoffs[0];
-      // Traverse all the elements that reduce together into d.
-      // srcNums elements may be repeated if there are zeros in axesStrides.
-      StridesRange<1> axesRange(axesShape, {axesStrides});
-      auto axesIter = axesRange.begin();
-      auto axesEnd = axesRange.end();
-      assert(axesIter->at(0) == 0 && "initial src offset must be zero");
-      d = srcNums.get()[srcPos];
-      while (++axesIter != axesEnd) {
-        int64_t srcOffset = axesIter->at(0);
-        d = reducer(d, srcNums.get()[srcPos + srcOffset]);
+    StridesRange<1> sRange(reducedShape, {reducedStrides});
+    StridesRange<1> axesRange(axesShape, {axesStrides});
+    SmallVector<std::pair<int64_t, uint64_t>, 4> batch;
+    for (auto &idxoffs : sRange)
+      batch.emplace_back(std::make_pair(idxoffs.flattenedIndex, idxoffs[0]));
+
+    auto fetchBatch = [&](size_t threadNumber, bool parallel) {
+      // Return all data without splitting for sequential execution.
+      if (!parallel)
+        return llvm::make_range(batch.begin(), batch.end());
+      // Each thread fetches the same batch size. The leftovers are assigned
+      // to the threads with the smallest thread numbers.
+      size_t tileSize = floor(batch.size() / ctx->getNumThreads());
+      size_t leftovers = batch.size() % ctx->getNumThreads();
+      int beginOffset;
+      if (threadNumber < leftovers) {
+        // For the first few threads, it is as if the tile size is larger by 1.
+        tileSize++;
+        beginOffset = threadNumber * tileSize;
+      } else {
+        // For the last threads, it is as if the start is shifted by leftovers.
+        beginOffset = threadNumber * tileSize + leftovers;
       }
-    }
+      int endOffset = beginOffset + tileSize;
+      return llvm::make_range(
+          batch.begin() + beginOffset, batch.begin() + endOffset);
+    };
+
+    auto work = [&](size_t threadNumber, bool parallel = true) {
+      auto tile = fetchBatch(threadNumber, parallel);
+      // Traverse and populate each element d in dstNums.
+      for (auto b : tile) {
+        WideNum &d = dstNums[b.first];
+        int64_t srcPos = b.second;
+        // Traverse all the elements that reduce together into d.
+        // srcNums elements may be repeated if there are zeros in axesStrides.
+        auto axesIter = axesRange.begin();
+        auto axesEnd = axesRange.end();
+        assert(axesIter->at(0) == 0 && "initial src offset must be zero");
+        d = srcNums.get()[srcPos];
+        while (++axesIter != axesEnd) {
+          int64_t srcOffset = axesIter->at(0);
+          d = reducer(d, srcNums.get()[srcPos + srcOffset]);
+        }
+      }
+    };
+    // Using 'parallelFor()' introduces large overhead. The following are
+    // actual measurement results on IBM z16 used to decide the 'minCount'.
+    // We measured 'onnx.ReduceSum()' in
+    // 'test/mlir/onnx/onnx_constprop_parallel.mlir' using several input
+    // sizes. From these results, we decided to use 2000 as the 'minCount'.
+    //
+    // inputCount  | Sequential | Parallel with 2 threads
+    //             | (work())   | (parallelFor())
+    //             | (msec)     | (msec)
+    // --------------------------------------------------
+    //         400 |      0.065 |      0.153
+    //         800 |      0.115 |      0.164
+    //        1200 |      0.175 |      0.201
+    //        1600 |      0.226 |      0.228
+    //        2000 |      0.282 |      0.258
+    //        2400 |      0.336 |      0.284
+    constexpr size_t minCount = 2000;
+    size_t inputCount = batch.size() * axesRange.size();
+    if (inputCount < minCount)
+      work(0, /*parallel*/ false);
+    else
+      parallelFor(ctx, 0, ctx->getNumThreads(), work);
   });
 }
 
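To make the tile arithmetic in fetchBatch above concrete, here is a standalone sketch of the same begin/end computation (splitRange is an illustrative name, not in the commit). With 10 elements over 3 threads it yields the tiles [0,4), [4,7), [7,10): the single leftover element goes to thread 0.

#include <cstdio>
#include <utility>

// Same tiling arithmetic as fetchBatch: equal-size tiles, with the remainder
// spread one extra element each over the lowest-numbered threads.
std::pair<size_t, size_t> splitRange(
    size_t size, size_t numThreads, size_t threadNumber) {
  size_t tileSize = size / numThreads;
  size_t leftovers = size % numThreads;
  size_t begin;
  if (threadNumber < leftovers) {
    tileSize++; // the first 'leftovers' threads take one extra element
    begin = threadNumber * tileSize;
  } else {
    begin = threadNumber * tileSize + leftovers;
  }
  return {begin, begin + tileSize};
}

int main() {
  // 10 elements over 3 threads: prints [0, 4), [4, 7), [7, 10).
  for (size_t t = 0; t < 3; ++t) {
    auto [b, e] = splitRange(10, 3, t);
    std::printf("thread %zu: [%zu, %zu)\n", t, b, e);
  }
}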

src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.hpp

+41 -4
@@ -10,6 +10,7 @@
 
 #ifndef ONNX_MLIR_ELEM_ATTR_BUILDER_H
 #define ONNX_MLIR_ELEM_ATTR_BUILDER_H
+#include "mlir/IR/Threading.h"
 
 #include "src/Dialect/ONNX/ElementsAttr/BType.hpp"
 #include "src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.hpp"
@@ -244,10 +245,46 @@ class ElementsAttrBuilder {
   // Constructs a transformer that changes every element to the result of
   // applying the given function to the element.
   template <typename Function = WideNum (*)(WideNum)>
-  static inline Transformer functionTransformer(Function fun) {
-    return [fun = std::move(fun)](llvm::MutableArrayRef<WideNum> data) -> void {
-      for (WideNum &n : data)
-        n = fun(n);
+  inline Transformer functionTransformer(Function fun) {
+    mlir::MLIRContext *ctx = disposablePool.getContext();
+    return [fun = std::move(fun), ctx](
+               llvm::MutableArrayRef<WideNum> data) -> void {
+      auto fetchBatch = [&](size_t threadNumber, bool parallel) {
+        // Return all data without splitting for sequential execution.
+        if (!parallel)
+          return llvm::make_range(data.begin(), data.end());
+        // Each thread fetches the same data size. The leftovers are assigned
+        // to the threads with the smallest thread numbers.
+        size_t tileSize = floor(data.size() / ctx->getNumThreads());
+        size_t leftovers = data.size() % ctx->getNumThreads();
+        int beginOffset;
+        if (threadNumber < leftovers) {
+          // For the first few threads, it is as if the tile size is larger
+          // by 1.
+          tileSize++;
+          beginOffset = threadNumber * tileSize;
+        } else {
+          // For the last threads, it is as if the start is shifted by leftovers.
+          beginOffset = threadNumber * tileSize + leftovers;
+        }
+        int endOffset = beginOffset + tileSize;
+        return llvm::make_range(
+            data.begin() + beginOffset, data.begin() + endOffset);
+      };
+
+      auto work = [&](size_t threadNumber, bool parallel = true) {
+        auto tile = fetchBatch(threadNumber, parallel);
+        for (WideNum &n : tile)
+          n = fun(n);
+      };
+      // Using 'parallelFor()' introduces large overhead.
+      // To avoid this overhead, call work() directly if the input size is
+      // less than 'minCount'.
+      constexpr size_t minCount = 1000;
+      if (data.size() < minCount)
+        work(0, /*parallel*/ false);
+      else
+        parallelFor(ctx, 0, ctx->getNumThreads(), work);
     };
   }
 
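Note the signature change: functionTransformer is no longer static because the returned transformer now needs an MLIRContext (reached through disposablePool) to size and drive the thread pool. The context pointer is captured by value, so the Transformer stays valid after the builder call returns. A hedged sketch of that capture pattern with stand-in types (none of these names are from onnx-mlir):

#include <functional>
#include <vector>

// Stand-ins for the onnx-mlir types, for illustration only.
struct Context { size_t getNumThreads() { return 4; } };
using Transformer = std::function<void(std::vector<double> &)>;

struct Builder {
  Context *ctx; // stands in for disposablePool.getContext()

  // Non-static: needs 'ctx' from the instance. Capturing 'ctx' by value
  // keeps the returned Transformer usable after this call returns.
  Transformer functionTransformer(double (*fun)(double)) {
    Context *c = ctx;
    return [fun, c](std::vector<double> &data) {
      // ...split 'data' into c->getNumThreads() tiles as in the PR...
      for (double &n : data)
        n = fun(n);
    };
  }
};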
