 //===----------------------------------------------------------------------===//
 
 #include "src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.hpp"
-
 #include "mlir/Dialect/Traits.h"
+#include "mlir/IR/Threading.h"
 #include "llvm/ADT/STLExtras.h"
 
 #include "src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.hpp"
@@ -849,6 +849,8 @@ ElementsAttr ElementsAttrBuilder::reduce(ElementsAttr elms,
   if (axes.empty())
     return elms;
 
+  Type elementType = elms.getElementType();
+  MLIRContext *ctx = elementType.getContext();
   SmallVector<unsigned, 4> sortedAxes(axes);
   std::sort(sortedAxes.begin(), sortedAxes.end());
   assert(
@@ -885,22 +887,74 @@ ElementsAttr ElementsAttrBuilder::reduce(ElementsAttr elms,
 
   ShapedType reducedType = type.clone(reducedShape);
   return fromWideNums(reducedType, [&](MutableArrayRef<WideNum> dstNums) {
-    // Traverse and populate each element d in dstNums.
-    for (auto &idxoffs : StridesRange<1>(reducedShape, {reducedStrides})) {
-      WideNum &d = dstNums[idxoffs.flattenedIndex];
-      int64_t srcPos = idxoffs[0];
-      // Traverse all the elements that reduce together into d.
-      // srcNums elements may be repeated if there are zeros in axesStrides.
-      StridesRange<1> axesRange(axesShape, {axesStrides});
-      auto axesIter = axesRange.begin();
-      auto axesEnd = axesRange.end();
-      assert(axesIter->at(0) == 0 && "initial src offset must be zero");
-      d = srcNums.get()[srcPos];
-      while (++axesIter != axesEnd) {
-        int64_t srcOffset = axesIter->at(0);
-        d = reducer(d, srcNums.get()[srcPos + srcOffset]);
+    StridesRange<1> sRange(reducedShape, {reducedStrides});
+    StridesRange<1> axesRange(axesShape, {axesStrides});
+    SmallVector<std::pair<int64_t, uint64_t>, 4> batch;
+    for (auto &idxoffs : sRange)
+      batch.emplace_back(std::make_pair(idxoffs.flattenedIndex, idxoffs[0]));
+
+    auto fetchBatch = [&](size_t threadNumber, bool parallel) {
+      // Return all data without splitting for sequential execution.
+      if (!parallel)
+        return llvm::make_range(batch.begin(), batch.end());
+      // Each thread fetches the same tile size. The leftovers are assigned
+      // to the threads with the smallest thread numbers.
+      size_t tileSize = floor(batch.size() / ctx->getNumThreads());
+      size_t leftovers = batch.size() % ctx->getNumThreads();
+      int beginOffset;
+      if (threadNumber < leftovers) {
+        // For the first 'leftovers' threads, the tile size is larger by 1.
+        tileSize++;
+        beginOffset = threadNumber * tileSize;
+      } else {
+        // For the remaining threads, the start is shifted by 'leftovers'.
+        beginOffset = threadNumber * tileSize + leftovers;
       }
-    }
+      int endOffset = beginOffset + tileSize;
+      return llvm::make_range(
+          batch.begin() + beginOffset, batch.begin() + endOffset);
+    };
+
+    auto work = [&](size_t threadNumber, bool parallel = true) {
+      auto tile = fetchBatch(threadNumber, parallel);
+      // Traverse and populate each element d in dstNums.
+      for (auto b : tile) {
+        WideNum &d = dstNums[b.first];
+        int64_t srcPos = b.second;
+        // Traverse all the elements that reduce together into d.
+        // srcNums elements may be repeated if there are zeros in axesStrides.
+        auto axesIter = axesRange.begin();
+        auto axesEnd = axesRange.end();
+        assert(axesIter->at(0) == 0 && "initial src offset must be zero");
+        d = srcNums.get()[srcPos];
+        while (++axesIter != axesEnd) {
+          int64_t srcOffset = axesIter->at(0);
+          d = reducer(d, srcNums.get()[srcPos + srcOffset]);
+        }
+      }
+    };
+    // Using 'parallelFor()' introduces a large overhead. The following are
+    // actual measurement results on IBM z16, used to decide 'minCount'. We
+    // measured 'onnx.ReduceSum()' in 'test/mlir/onnx/onnx_constprop_parallel.mlir'
+    // with several input sizes. From these results, we decided to use 2000
+    // as the 'minCount'.
+    //
+    // inputCount | Sequential | Parallel with 2 threads
+    //            | (work())   | (parallelFor())
+    //            | (msec)     | (msec)
+    // --------------------------------------------------
+    //        400 |      0.065 |      0.153
+    //        800 |      0.115 |      0.164
+    //       1200 |      0.175 |      0.201
+    //       1600 |      0.226 |      0.228
+    //       2000 |      0.282 |      0.258
+    //       2400 |      0.336 |      0.284
+    constexpr size_t minCount = 2000;
+    size_t inputCount = batch.size() * axesRange.size();
+    if (inputCount < minCount)
+      work(0, /*parallel*/ false);
+    else
+      parallelFor(ctx, 0, ctx->getNumThreads(), work);
   });
 }
 
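For readers following the scheduling logic in the hunk above, here is a minimal standalone sketch, outside the patch, of the same even-split-with-leftovers partitioning that fetchBatch performs. The helper name tileBounds and the 10-items-over-4-threads numbers are illustrative assumptions, not part of the patch.

// Illustrative sketch only (not part of the patch): computes, for one thread,
// the [begin, end) range of batch items it should process, giving one extra
// item to the lowest-numbered threads, exactly as fetchBatch does above.
#include <cstddef>
#include <iostream>
#include <utility>

std::pair<size_t, size_t> tileBounds(
    size_t count, size_t numThreads, size_t threadNumber) {
  size_t tileSize = count / numThreads;  // even share per thread
  size_t leftovers = count % numThreads; // items that do not divide evenly
  size_t begin;
  if (threadNumber < leftovers) {
    // The first 'leftovers' threads each take one extra item.
    ++tileSize;
    begin = threadNumber * tileSize;
  } else {
    // Remaining threads start after all the enlarged tiles.
    begin = threadNumber * tileSize + leftovers;
  }
  return {begin, begin + tileSize};
}

int main() {
  // 10 items over 4 threads -> [0,3) [3,6) [6,8) [8,10).
  for (size_t t = 0; t < 4; ++t) {
    auto [b, e] = tileBounds(10, 4, t);
    std::cout << "thread " << t << ": [" << b << ", " << e << ")\n";
  }
}

In the patch itself, parallelFor(ctx, 0, ctx->getNumThreads(), work) hands each thread number to work, which calls fetchBatch to obtain its tile; when the total input count is below minCount (2000), the sequential fallback work(0, /*parallel*/ false) processes the whole batch as a single range.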