
Support bufferizable in okl #9787

Merged · 154 commits · Mar 6, 2023
49a6c9b
kernel support cuda graph
howin98 Dec 15, 2022
3d3a05d
make more infer before computing
howin98 Dec 16, 2022
9e95230
enable cuda graph
howin98 Dec 16, 2022
5cdc924
add taging cuda graph support on mlir in a pass
howin98 Dec 22, 2022
65c9da8
enable tag make effect in kernel launch op
howin98 Dec 22, 2022
e9af789
make lit checking
howin98 Dec 22, 2022
03e294d
aggergate cuda graph support ops register
howin98 Dec 23, 2022
1dc5970
enable okl wrap with cuda graph support
howin98 Dec 26, 2022
591a7a2
remove aggregate pass
howin98 Dec 27, 2022
1dbd103
Merge branch 'master' of github.com:Oneflow-Inc/oneflow into support-…
howin98 Dec 27, 2022
843e320
Update OKLOps.td
howin98 Dec 27, 2022
45c258a
Update OneFlowPasses.td
howin98 Dec 27, 2022
f33a4bb
Update CudaGraphSupport.cpp
howin98 Dec 27, 2022
adc5d01
Delete readme.md
howin98 Dec 27, 2022
e7c0524
Delete readme.md
howin98 Dec 27, 2022
1597a37
Update test_cuda_graph_split.py
howin98 Dec 27, 2022
fbe8ebe
Update test_resnet.py
howin98 Dec 27, 2022
49eba17
auto format by CI
oneflow-ci-bot Dec 27, 2022
7b994ca
rewrite test
howin98 Dec 27, 2022
90c660a
auto format by CI
oneflow-ci-bot Dec 27, 2022
fc7bcb5
Merge branch 'master' into support-cuda-graph-in-okl
howin98 Dec 27, 2022
402d22d
Merge branch 'master' into support-cuda-graph-in-okl
howin98 Dec 28, 2022
82ea85e
remove magic string in mlir func name
howin98 Dec 28, 2022
0b7b12c
remoev cuda graph magic str
howin98 Dec 28, 2022
0c2703f
fix
howin98 Dec 28, 2022
77b073d
auto format by CI
oneflow-ci-bot Dec 28, 2022
d55fbaa
add flag in wrap op pass
howin98 Dec 29, 2022
2f62aab
auto format by CI
oneflow-ci-bot Dec 29, 2022
a8097b5
fix
howin98 Dec 29, 2022
618781c
fix
howin98 Dec 29, 2022
e5235fd
Merge branch 'master' into support-cuda-graph-in-okl
howin98 Dec 29, 2022
8c9db19
Merge branch 'master' into support-cuda-graph-in-okl
howin98 Dec 29, 2022
a9e15ed
fix
howin98 Dec 29, 2022
efcbfb5
fix
howin98 Dec 30, 2022
7bc0028
fix
howin98 Dec 30, 2022
f7fc3c0
auto format by CI
oneflow-ci-bot Dec 30, 2022
7f63f10
fix
howin98 Dec 30, 2022
2c48a27
Merge branch 'support-cuda-graph-in-okl' of github.com:Oneflow-Inc/on…
howin98 Dec 30, 2022
051faae
Merge branch 'master' into support-cuda-graph-in-okl
howin98 Dec 30, 2022
292d322
fix wrap op in cuda_graph with oenflow v.cpu
howin98 Dec 31, 2022
5fae685
fix
howin98 Dec 31, 2022
0065835
Merge branch 'master' into support-cuda-graph-in-okl
howin98 Dec 31, 2022
876fb3a
Merge branch 'master' into support-cuda-graph-in-okl
howin98 Jan 2, 2023
72e26d3
Merge branch 'master' into support-cuda-graph-in-okl
howin98 Jan 3, 2023
d46880a
move and fix
howin98 Jan 4, 2023
5aa34f1
Merge branch 'master' into support-cuda-graph-in-okl
howin98 Jan 10, 2023
f9bce13
Refactor okl dialect (#9693)
howin98 Jan 10, 2023
ec73fc8
auto format by CI
oneflow-ci-bot Jan 10, 2023
78d81a8
Merge branch 'master' into support-cuda-graph-in-okl
howin98 Jan 10, 2023
60f7371
auto format by CI
oneflow-ci-bot Jan 10, 2023
d9f1d55
Merge branch 'master' into support-cuda-graph-in-okl
howin98 Jan 11, 2023
998fd01
rewrite readme.md
howin98 Jan 11, 2023
284181d
fix
howin98 Jan 11, 2023
31dc8c3
fix
howin98 Jan 11, 2023
8618e67
fix
howin98 Jan 11, 2023
bd088fe
rewrite readme
howin98 Jan 12, 2023
8917b32
rewrite readme
howin98 Jan 12, 2023
387d25c
Merge branch 'master' into support-cuda-graph-in-okl
howin98 Jan 12, 2023
ec936b6
Support sd in okl with cuda graph (#9759)
howin98 Jan 17, 2023
e472761
Merge branch 'master' into support-cuda-graph-in-okl
howin98 Jan 17, 2023
875cd46
Dev cutlass conv cuda graph (#9767)
howin98 Jan 17, 2023
433da23
fix
howin98 Jan 17, 2023
1569248
Merge branch 'master' into support-cuda-graph-in-okl
jackalcooper Jan 17, 2023
608a34d
Merge branch 'master' into support-cuda-graph-in-okl
howin98 Jan 18, 2023
06dad71
Merge branch 'master' into support-cuda-graph-in-okl
howin98 Jan 19, 2023
f502992
fix
howin98 Jan 20, 2023
907b4b2
Merge branch 'master' into support-cuda-graph-in-okl
howin98 Jan 20, 2023
ffc8159
not works
howin98 Jan 22, 2023
48cb3cc
before opt is ok
howin98 Jan 28, 2023
2411e55
stash
howin98 Jan 28, 2023
e8b0026
done
howin98 Jan 28, 2023
33c2947
fix
howin98 Jan 28, 2023
3ff92c5
okm is ok
howin98 Jan 28, 2023
9d2de6a
okm is must now
howin98 Jan 29, 2023
372c4c3
stash
howin98 Jan 29, 2023
7675eff
stash
howin98 Jan 30, 2023
26b38c8
refactor tmp buffer to memory pool
howin98 Jan 31, 2023
47a885b
stash
howin98 Feb 1, 2023
307a5cb
Merge branch 'master' of github.com:Oneflow-Inc/oneflow into support-…
howin98 Feb 1, 2023
108af3a
Merge branch 'master' into support-bufferizable-in-okl
jackalcooper Feb 1, 2023
0013a65
rename
howin98 Feb 1, 2023
0c6f2a9
Merge branch 'support-bufferizable-in-okl' of github.com:Oneflow-Inc/…
howin98 Feb 1, 2023
c26a491
rename
howin98 Feb 1, 2023
fca3ba3
rename
howin98 Feb 1, 2023
86d911d
fix
howin98 Feb 1, 2023
d82e605
fix
howin98 Feb 1, 2023
4f92ac5
udpate guard
jackalcooper Feb 1, 2023
95f63dc
fix
howin98 Feb 1, 2023
4f95338
Merge branch 'support-bufferizable-in-okl' of github.com:Oneflow-Inc/…
howin98 Feb 1, 2023
9d6d34f
fix
howin98 Feb 2, 2023
838acf3
fix
howin98 Feb 2, 2023
0288483
done
howin98 Feb 2, 2023
5563d2c
Merge branch 'master' into support-bufferizable-in-okl
howin98 Feb 2, 2023
938ba20
fix
howin98 Feb 2, 2023
406b18f
trim no-use pass
howin98 Feb 2, 2023
3f1ea54
fix
howin98 Feb 2, 2023
110e5db
fix
howin98 Feb 2, 2023
57022a5
fix
howin98 Feb 2, 2023
1a89834
fix
howin98 Feb 2, 2023
8708c5d
fix
howin98 Feb 2, 2023
2a02fd3
Merge branch 'master' into support-bufferizable-in-okl
howin98 Feb 2, 2023
1a7baa8
fix
howin98 Feb 3, 2023
1ac00d8
fix
howin98 Feb 3, 2023
2be82e5
add ops and tests
howin98 Feb 7, 2023
885b275
Merge branch 'master' of github.com:Oneflow-Inc/oneflow into support-…
howin98 Feb 7, 2023
8ad510c
done
howin98 Feb 7, 2023
9dea3f1
support memory size first
howin98 Feb 10, 2023
8fad29d
Merge branch 'master' of github.com:Oneflow-Inc/oneflow into support-…
howin98 Feb 14, 2023
e0e9b3d
fix stride error
howin98 Feb 15, 2023
b5009c0
fix typo
jackalcooper Feb 16, 2023
406ff32
work well on stable diffusion
howin98 Feb 20, 2023
46b1f3d
Merge branch 'support-bufferizable-in-okl' of github.com:Oneflow-Inc/…
howin98 Feb 20, 2023
ad7bcbd
fix
howin98 Feb 20, 2023
6186833
rename
howin98 Feb 20, 2023
f70b722
Merge branch 'master' of github.com:Oneflow-Inc/oneflow into support-…
howin98 Feb 21, 2023
b73b877
speed up
howin98 Feb 22, 2023
ceeb31d
fix cuda graph support
howin98 Feb 22, 2023
76c486b
rm push down algo
howin98 Feb 22, 2023
0f479f6
fix
howin98 Feb 22, 2023
ca02eb9
fix
howin98 Feb 22, 2023
b529847
fix
howin98 Feb 22, 2023
6d25327
fix
howin98 Feb 22, 2023
77e269c
fix
howin98 Feb 22, 2023
fc4a52b
fix
howin98 Feb 22, 2023
69098e4
Merge branch 'master' into support-bufferizable-in-okl
howin98 Feb 22, 2023
873c637
Merge branch 'master' into support-bufferizable-in-okl
howin98 Feb 23, 2023
5b9f088
Merge branch 'master' into support-bufferizable-in-okl
howin98 Feb 24, 2023
12a065c
Merge branch 'master' into support-bufferizable-in-okl
howin98 Feb 24, 2023
dadceff
fix
howin98 Feb 24, 2023
411fa1f
Merge branch 'master' into support-bufferizable-in-okl
howin98 Feb 26, 2023
7c4755f
trim and add mlir
howin98 Feb 28, 2023
c0bd1f4
Merge branch 'support-bufferizable-in-okl' of github.com:Oneflow-Inc/…
howin98 Feb 28, 2023
7d852fe
Merge branch 'master' into support-bufferizable-in-okl
howin98 Feb 28, 2023
fb56031
Merge branch 'master' into support-bufferizable-in-okl
howin98 Mar 2, 2023
d708e77
Merge branch 'master' into support-bufferizable-in-okl
howin98 Mar 2, 2023
5045c6d
fmt
jackalcooper Mar 2, 2023
cd42322
fmt
jackalcooper Mar 2, 2023
3f3bf71
fmt
jackalcooper Mar 2, 2023
d99c4a3
Merge branch 'master' into support-bufferizable-in-okl
howin98 Mar 2, 2023
c124f7b
fix
howin98 Mar 3, 2023
c2eb941
fix
howin98 Mar 3, 2023
e7d0d0c
Merge branch 'support-bufferizable-in-okl' of github.com:Oneflow-Inc/…
howin98 Mar 3, 2023
51c7c66
fix
howin98 Mar 3, 2023
fbc6f22
Merge branch 'master' into support-bufferizable-in-okl
howin98 Mar 3, 2023
8a72530
Merge branch 'master' into support-bufferizable-in-okl
mergify[bot] Mar 3, 2023
8b72dd8
Merge branch 'master' into support-bufferizable-in-okl
mergify[bot] Mar 3, 2023
9f89381
Merge branch 'master' into support-bufferizable-in-okl
mergify[bot] Mar 3, 2023
32cbfbb
auto format by CI
oneflow-ci-bot Mar 4, 2023
b598be2
Merge branch 'master' into support-bufferizable-in-okl
howin98 Mar 4, 2023
e682d49
ifx
howin98 Mar 5, 2023
6d8e6da
Merge branch 'master' into support-bufferizable-in-okl
howin98 Mar 5, 2023
b2bfaa6
fix
howin98 Mar 5, 2023
e1a3503
Merge branch 'support-bufferizable-in-okl' of github.com:Oneflow-Inc/…
howin98 Mar 5, 2023
ea77556
Merge branch 'master' into support-bufferizable-in-okl
howin98 Mar 5, 2023
178 changes: 17 additions & 161 deletions oneflow/core/job/intra_job_mem_sharing_util.cpp
@@ -56,11 +56,6 @@ namespace oneflow {

namespace {

struct MemBlockResultInfo {
size_t mem_block_size;
HashMap<RegstDescProto*, int64_t> regst_desc2offset;
};

int64_t GenDeviceUniqueId(int64_t machine_id, int64_t device_id) {
return (machine_id << 32) | device_id;
}
@@ -411,154 +406,11 @@ void GenRegstAllocFreeTimeLineAndRegstLifetimes(
CHECK(remain_regsts.empty());
}

// Judge whether a is a more suitable gap fit than b
bool SuitableThan(int64_t a, int64_t b) {
// The numbers have an order:
// A non-negative number is always more suitable than a negative number
// If a number is non-negative, then the smaller the better
// If a number is negative, then the larger the better
// 0 > 1 > 2 > ... > 999999999 > -1 > -2 > ... > -99999999
// Now we flip the positive part to make it "the larger the better".
if (a >= 0) { a = GetMaxVal<int64_t>() - a; }
if (b >= 0) { b = GetMaxVal<int64_t>() - b; }
return a > b;
}

void MemReusedAlgorithmAllocateByOrder(
const bool compact_insert, const std::vector<RegstDescProto*>& order,
const HashMap<RegstDescProto*, size_t>& regst_desc2size,
const HashMap<RegstDescProto*, std::pair<int32_t, int32_t>>& regst2lifetime,
MemBlockResultInfo* result) {
HashMap<RegstDescProto*, int64_t>* regst_desc2offset = &(result->regst_desc2offset);
// NOTE: It is important to make the variables local.
// It took me several days to find out that looking up the passed-in maps for size, order,
// and lifetime inside the loop would double the running time, so we copy the HashMaps
// into local vectors first.
int32_t total_register_num = order.size();
std::vector<int64_t> order2size(total_register_num);
std::vector<std::pair<int32_t, int32_t>> order2lifetime(total_register_num);
std::vector<int64_t> order2offset(total_register_num);
for (int32_t i = 0; i < total_register_num; i++) {
order2size[i] = regst_desc2size.at(order[i]);
order2lifetime[i] = regst2lifetime.at(order[i]);
}
size_t buffer_size = 1;
// Sort by offset
auto comp = [&order2offset](const auto& a, const auto& b) {
if (order2offset[a] != order2offset[b]) { return order2offset[a] < order2offset[b]; }
// Make sure we have a stable order even if we have the same offset for different registers
return a < b;
};
std::set<int32_t, decltype(comp)> sorted_registers(comp);
// Decide offset following the given order
for (int32_t inserting_id = 0; inserting_id < total_register_num; inserting_id++) {
const auto& inserting_lifetime = order2lifetime[inserting_id];
// At the beginning, try to insert the offset in the front of the whole memory pool.
int64_t inserting_offset = 0;
int64_t inserting_end = inserting_offset + order2size[inserting_id];
if (compact_insert) {
// Find the most suitable gap for the register
int64_t gap_head = 0;
int64_t inserting_size = order2size[inserting_id];
// difference = length of gap - length of the inserting register
int64_t diff_gap = 0, suitable_diff_gap = -1 - inserting_size;
for (const auto& curr_register : sorted_registers) {
// Ignore those non-excluded registers
if (IsLifetimeExcluded(inserting_lifetime, order2lifetime[curr_register])) {
if (gap_head < order2offset[curr_register]) {
// Find one gap
diff_gap = (order2offset[curr_register] - gap_head) - inserting_size;
// Compared with the previous suitable gap
if (SuitableThan(diff_gap, suitable_diff_gap)) {
suitable_diff_gap = diff_gap;
// We may insert the register into the gap
inserting_offset = gap_head;
}
// Update gap head
gap_head = order2offset[curr_register] + order2size[curr_register];
} else {
// No gap, update gap head
gap_head = std::max(gap_head, order2offset[curr_register] + order2size[curr_register]);
}
}
}
// Deal with the buffer_size, which may be the final gap
diff_gap = (buffer_size - gap_head) - inserting_size;
// Compared with the previous suitable gap
if (SuitableThan(diff_gap, suitable_diff_gap)) {
suitable_diff_gap = diff_gap;
// We may insert the register into the gap
inserting_offset = gap_head;
}
// If no gap is large enough to contain the current register
if (suitable_diff_gap < 0) {
// Prolong the maximum memory pool size by (-suitable_diff_gap)
buffer_size -= suitable_diff_gap;
int64_t gap_end = suitable_diff_gap + inserting_size + inserting_offset;
for (auto reverse_it = sorted_registers.rbegin(); reverse_it != sorted_registers.rend();
reverse_it++) {
// All the registers with offset < gap_end maintain their position
if (order2offset[*reverse_it] < gap_end) { break; }
// All the registers with offset >= gap_end move backward
order2offset[*reverse_it] -= suitable_diff_gap;
}
}

} else {
for (const auto& curr_register : sorted_registers) {
// i: inserting register, j: current register
// x: register offset, l: register size
// If x_i + l_i <= x_j, then the inserting register would be placed at x_i
if (order2offset[curr_register] >= inserting_end) { break; }
// If i and j are excluded, and x_i + l_i > x_j,
// then we try to place i at x_j + l_j and check the following registers
if (IsLifetimeExcluded(inserting_lifetime, order2lifetime[curr_register])) {
int64_t curr_end = order2offset[curr_register] + order2size[curr_register];
// Can not set inserting offset = current end directly.
// We might have two excluded registers like this:
// register a: [100, 10000]
// register b: [500, 600]
if (inserting_offset < curr_end) {
inserting_offset = curr_end;
inserting_end = inserting_offset + order2size[inserting_id];
}
}
}
// Update total size
if (inserting_end > buffer_size) { buffer_size = inserting_end; }
}
// Whether we broke out of the loop or it terminated naturally, we can place i at inserting_offset
order2offset[inserting_id] = inserting_offset;
sorted_registers.insert(inserting_id);
}

result->mem_block_size = buffer_size;
// Switch vector to HashMap
for (int32_t i = 0; i < total_register_num; i++) {
(*regst_desc2offset)[order[i]] = order2offset[i];
}
}

void MemReusedMemSizeFirstAlgo(
const bool compact_insert,
const HashMap<RegstDescProto*, std::pair<int32_t, int32_t>>& regst2lifetime,
const HashMap<RegstDescProto*, size_t>& mem_reused_regst2size, MemBlockResultInfo* result) {
std::vector<RegstDescProto*> order;
order.reserve(regst2lifetime.size());
for (const auto& pair : regst2lifetime) { order.emplace_back(pair.first); }
std::sort(order.begin(), order.end(), [&](RegstDescProto* lhs, RegstDescProto* rhs) {
size_t l_value = mem_reused_regst2size.at(lhs);
size_t r_value = mem_reused_regst2size.at(rhs);
if (l_value == r_value) { return regst2lifetime.at(lhs).first < regst2lifetime.at(rhs).first; }
return l_value > r_value;
});
MemReusedAlgorithmAllocateByOrder(compact_insert, order, mem_reused_regst2size, regst2lifetime,
result);
}

void MemReusedLifetimeFirstAlgo(
const bool compact_insert,
const HashMap<RegstDescProto*, std::pair<int32_t, int32_t>>& regst2lifetime,
const HashMap<RegstDescProto*, size_t>& mem_reused_regst2size, MemBlockResultInfo* result) {
const HashMap<RegstDescProto*, size_t>& mem_reused_regst2size,
MemBlockResultInfo<RegstDescProto*>* result) {
std::vector<RegstDescProto*> order;
order.reserve(regst2lifetime.size());
for (const auto& pair : regst2lifetime) { order.emplace_back(pair.first); }
@@ -575,7 +427,8 @@ void MemReusedLifetimeFirstAlgo(
void MemReusedTimeLineAlgo(
const bool compact_insert,
const HashMap<RegstDescProto*, std::pair<int32_t, int32_t>>& regst2lifetime,
const HashMap<RegstDescProto*, size_t>& mem_reused_regst2size, MemBlockResultInfo* result) {
const HashMap<RegstDescProto*, size_t>& mem_reused_regst2size,
MemBlockResultInfo<RegstDescProto*>* result) {
std::vector<RegstDescProto*> order;
order.reserve(regst2lifetime.size());
for (const auto& pair : regst2lifetime) { order.emplace_back(pair.first); }
@@ -594,7 +447,8 @@ void MemReusedTimeLineAlgo(
void MemReusedMemVolumeFirstAlgo(
const bool compact_insert,
const HashMap<RegstDescProto*, std::pair<int32_t, int32_t>>& regst2lifetime,
const HashMap<RegstDescProto*, size_t>& mem_reused_regst2size, MemBlockResultInfo* result) {
const HashMap<RegstDescProto*, size_t>& mem_reused_regst2size,
MemBlockResultInfo<RegstDescProto*>* result) {
std::vector<RegstDescProto*> order;
order.reserve(regst2lifetime.size());
auto ComputeMemoryVolume = [&](RegstDescProto* key) {
@@ -617,7 +471,8 @@ void MemReusedMemVolumeFirstAlgo(
void SelectAlgorithmGenMemBlockOffset4Regsts(
MemAllocAlgoType algo_id, const bool compact_insert,
const HashMap<RegstDescProto*, std::pair<int32_t, int32_t>>& regst2lifetime,
const HashMap<RegstDescProto*, size_t>& mem_reused_regst2size, MemBlockResultInfo* result) {
const HashMap<RegstDescProto*, size_t>& mem_reused_regst2size,
MemBlockResultInfo<RegstDescProto*>* result) {
CHECK_EQ(result->mem_block_size, 0);
CHECK(result->regst_desc2offset.empty());

@@ -661,7 +516,8 @@ int64_t CountMemAllocAlgoNum() {
return alloc_algo_num * compact_insert_num;
}

void InitAlgo2Result(HashMap<std::pair<MemAllocAlgoType, bool>, MemBlockResultInfo>* algo2result) {
void InitAlgo2Result(
HashMap<std::pair<MemAllocAlgoType, bool>, MemBlockResultInfo<RegstDescProto*>>* algo2result) {
CHECK(algo2result->empty());
std::vector<bool> compact_insert_algorithms;
const MemoryCompactInsertConf& mem_compact_insert_conf =
@@ -676,16 +532,16 @@ void InitAlgo2Result(HashMap<std::pair<MemAllocAlgoType, bool>, MemBlockResultIn
// NOTE: Experiments show that memory first might be good enough for some cases.
for (auto compact_insert : compact_insert_algorithms) {
if (mem_alloc_algo_conf.use_mem_size_first_algo()) {
(*algo2result)[{kMemSizeFirstAlgo, compact_insert}] = MemBlockResultInfo();
(*algo2result)[{kMemSizeFirstAlgo, compact_insert}] = MemBlockResultInfo<RegstDescProto*>();
}
if (mem_alloc_algo_conf.use_lifetime_first_algo()) {
(*algo2result)[{kLifetimeFirstAlgo, compact_insert}] = MemBlockResultInfo();
(*algo2result)[{kLifetimeFirstAlgo, compact_insert}] = MemBlockResultInfo<RegstDescProto*>();
}
if (mem_alloc_algo_conf.use_time_line_algo()) {
(*algo2result)[{kTimeLineAlgo, compact_insert}] = MemBlockResultInfo();
(*algo2result)[{kTimeLineAlgo, compact_insert}] = MemBlockResultInfo<RegstDescProto*>();
}
if (mem_alloc_algo_conf.use_mem_volume_first_algo()) {
(*algo2result)[{kMemVolumeFirstAlgo, compact_insert}] = MemBlockResultInfo();
(*algo2result)[{kMemVolumeFirstAlgo, compact_insert}] = MemBlockResultInfo<RegstDescProto*>();
}
}
}
@@ -725,7 +581,7 @@ void IntraJobMemSharingUtil::InferMemBlockId4MemReusedRegst(
}

// step 2: multi-thread run several algorithm for each mem chain
HashMap<int64_t, HashMap<std::pair<MemAllocAlgoType, bool>, MemBlockResultInfo>>
HashMap<int64_t, HashMap<std::pair<MemAllocAlgoType, bool>, MemBlockResultInfo<RegstDescProto*>>>
mem_chain2algo2result;
{
int64_t work_size = mem_chain2mem_reused_regsts.size() * CountMemAllocAlgoNum();
@@ -737,7 +593,7 @@ void IntraJobMemSharingUtil::InferMemBlockId4MemReusedRegst(
for (auto& pair : mem_chain2algo2result.at(mem_chain_id)) {
MemAllocAlgoType algo_id = pair.first.first;
bool compact_insert = pair.first.second;
MemBlockResultInfo* result = &pair.second;
MemBlockResultInfo<RegstDescProto*>* result = &pair.second;
thread_pool.AddWork([algo_id, compact_insert, mem_chain_id, &mem_chain2regst2lifetime,
&mem_reused_regst2size, result, &counter]() {
SelectAlgorithmGenMemBlockOffset4Regsts(algo_id, compact_insert,
@@ -752,7 +608,7 @@ void IntraJobMemSharingUtil::InferMemBlockId4MemReusedRegst(

// step 3: choose best one for each mem chain and set offset for inplace consumer regst
for (auto& pair : mem_chain2algo2result) {
MemBlockResultInfo* best_result = nullptr;
MemBlockResultInfo<RegstDescProto*>* best_result = nullptr;
for (auto& algo_result_pair : pair.second) {
if (!best_result || algo_result_pair.second.mem_block_size < best_result->mem_block_size) {
best_result = &algo_result_pair.second;