
Commit 4bbb6ad

nickgg authored and facebook-github-bot committed on Sep 21, 2020
[NNC] fix SyncThreads insertion and reenable CudaSharedMem test (pytorch#44909)
Summary: A previous fix for masking Cuda dimensions (pytorch#44733) changed how thread synchronization barriers are inserted in the Cuda CodeGen, causing the CudaSharedMemReduce_1 test to be flaky and ultimately disabled.

The issue is working out where these barriers must be inserted. Solving this optimally is very hard, and I think not possible without dependency analysis we don't have, so I've changed our logic to be quite pessimistic: we insert barriers before and after any blocks that have thread dimensions masked, even between blocks that have no data dependencies. This should be correct, but it is an area where we could improve performance. To address this somewhat, I've added a simplifier pass that removes obviously unnecessary syncThreads. To avoid this test going flaky again, I've added a check against the generated code to ensure there is a syncThread in the right place.

Also fixed a couple of non-functional clarity issues in the generated code: added the missing newline after Stores in the CudaPrinter, and prevented the PrioritizeLoad mutator from pulling out loads contained within simple Let statements (such as those produced by the Registerizer).

Pull Request resolved: pytorch#44909
Reviewed By: agolynski
Differential Revision: D23800565
Pulled By: nickgg
fbshipit-source-id: bddef1f40d8d461da965685f01d00b468d8a2c2f
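For illustration, a hand-written CUDA sketch of the shape the rewriter now produces for the shared-memory reduction test; this is not literal CodeGen output, and the input indexing is an assumption. The segment masked down to one thread is bracketed by barriers, and the simplifier then fuses any redundant adjacent ones:

// Illustrative only: hand-written CUDA mirroring the FileCheck pattern in
// test_cuda.cpp below; the a[...] indexing is an assumption, not real output.
__global__ void sharedMemReduceSketch(const float* a, float* b) {
  float c_ = 0.f; // per-thread partial sum
  for (int m = 0; m < 128; m++) {
    c_ = c_ + a[128 * threadIdx.x + m];
  }
  __syncthreads(); // barrier inserted before the masked segment
  if (threadIdx.x < 1) { // segment masked down to a single thread
    b[blockIdx.x] = 0.f;
  }
  __syncthreads(); // barrier inserted after; redundant neighbors get fused
  atomicAdd(&b[blockIdx.x], c_); // unmasked segment: no barrier added
}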
Parent: e2f49c8

File tree: 8 files changed, +187 −31 lines
test/cpp/tensorexpr/test_cuda.cpp (+19)

@@ -762,6 +762,25 @@ void testCudaSharedMemReduce_1() {

   // TODO: check the generated code for correctness.
   CudaCodeGen cuda_cg(loop_k1, a, b);
+
+  std::ostringstream oss;
+  oss << *cuda_cg.stmt();
+
+  // Check the c write is not masked, but the d write is.
+  const std::string& verification_pattern =
+      R"IR(
+# CHECK: c_ = 0
+# CHECK: for (int m = 0; m < 128
+# CHECK: c_ = c_ +
+# CHECK: __syncthreads();
+# CHECK: if (threadIdx.x<1
+# CHECK: b[blockIdx.x] =
+# CHECK: __syncthreads();
+# CHECK: atomicAdd(&b[blockIdx.x], c_)
+)IR";
+
+  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
+
   PaddedBuffer<float> a_v(1, M, N, "a_v");
   PaddedBuffer<float> b_v(1, "b_v");
   PaddedBuffer<float> b_ref(1, "b_ref");
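For context on the verification mechanism: torch::jit::testing::FileCheck matches each "# CHECK:" directive, in order, against the dumped IR text, each one after the previous match. A minimal self-contained sketch, with an invented input string:

#include <torch/csrc/jit/testing/file_check.h>
#include <string>

// Minimal FileCheck illustration: every "# CHECK:" must match in order.
// The checked text here is invented for the example.
void fileCheckExample() {
  const std::string text = "c_ = 0;\n__syncthreads();\natomicAdd(&b[0], c_);\n";
  torch::jit::testing::FileCheck().run(R"IR(
# CHECK: c_ = 0
# CHECK: __syncthreads();
# CHECK: atomicAdd
)IR", text);
}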

test/cpp/tensorexpr/test_simplify.cpp (+94 −5)

@@ -3687,11 +3687,9 @@ void testSimplifyFuseConditions() {
         Store::make(a, {1}, j, mask),
         nullptr),
   });
-
   Stmt* simplified = IRSimplifier::simplify(body);
   IS_NODE_WITH_NAME(Block, simplified, block);
   ASSERT_EQ(block->nstmts(), 3);
-
   auto it = block->begin();
   it++;
   IS_NODE_WITH_NAME(Cond, *it, cond);
@@ -3720,7 +3718,6 @@ void testSimplifyFuseConditions() {
         Store::make(a, {1}, j, mask),
         nullptr),
   });
-
   Stmt* simplified = IRSimplifier::simplify(body);
   IS_NODE_WITH_NAME(Block, simplified, block);
   ASSERT_EQ(block->nstmts(), 1);
@@ -3751,7 +3748,6 @@ void testSimplifyFuseConditions() {
         Store::make(a, {1}, j, mask),
         nullptr),
   });
-
   Stmt* simplified = IRSimplifier::simplify(body);
   IS_NODE_WITH_NAME(Block, simplified, block);
   ASSERT_EQ(block->nstmts(), 3);
@@ -3786,7 +3782,6 @@ void testSimplifyFuseConditions() {
             CompareSelectOperation::kLT),
         Store::make(a, {1}, i, mask),
         nullptr)});
-
   Stmt* simplified = IRSimplifier::simplify(body);
   IS_NODE_WITH_NAME(Block, simplified, block);
   ASSERT_EQ(block->nstmts(), 1);
@@ -3861,5 +3856,99 @@ void testSimplifyFuseConditions() {
   }
 }

+void testSimplifySyncThreads() {
+  KernelScope kernel_scope;
+  Buffer a(BufHandle("A", {4}, kInt));
+  auto mask = IntImm::make(1);
+  VarHandle i("i", kInt);
+
+  {
+    // Merge two inner SyncThreads.
+    auto body = Block::make({Store::make(a, {0}, 1, 1),
+                             new SyncThreads(),
+                             new SyncThreads(),
+                             Store::make(a, {1}, 0, 1)});
+    Stmt* simplified = IRSimplifier::simplify(body);
+    IS_NODE_WITH_NAME(Block, simplified, block);
+    ASSERT_EQ(block->nstmts(), 3);
+    auto it = block->begin();
+    IS_NODE(Store, *it++);
+    IS_NODE(SyncThreads, *it++);
+    IS_NODE(Store, *it++);
+  }
+
+  {
+    // Eliminate outer SyncThreads.
+    auto body = Block::make(
+        {new SyncThreads(), Store::make(a, {1}, 0, 1), new SyncThreads()});
+
+    Stmt* simplified = IRSimplifier::simplify(body);
+    IS_NODE_WITH_NAME(Block, simplified, block);
+    ASSERT_EQ(block->nstmts(), 1);
+    auto it = block->begin();
+    IS_NODE(Store, *it);
+  }
+
+  {
+    // Merge many inner SyncThreads.
+    auto body = Block::make({Store::make(a, {0}, 1, 1),
+                             new SyncThreads(),
+                             new SyncThreads(),
+                             new SyncThreads(),
+                             new SyncThreads(),
+                             new SyncThreads(),
+                             Store::make(a, {1}, 0, 1)});
+
+    Stmt* simplified = IRSimplifier::simplify(body);
+    IS_NODE_WITH_NAME(Block, simplified, block);
+    ASSERT_EQ(block->nstmts(), 3);
+    auto it = block->begin();
+    IS_NODE(Store, *it++);
+    IS_NODE(SyncThreads, *it++);
+    IS_NODE(Store, *it++);
+  }
+
+  {
+    // Merge multiple outer SyncThreads.
+    auto body = Block::make({new SyncThreads(),
+                             new SyncThreads(),
+                             Store::make(a, {1}, 0, 1),
+                             new SyncThreads(),
+                             new SyncThreads(),
+                             new SyncThreads(),
+                             new SyncThreads()});
+
+    Stmt* simplified = IRSimplifier::simplify(body);
+    IS_NODE_WITH_NAME(Block, simplified, block);
+    ASSERT_EQ(block->nstmts(), 1);
+    auto it = block->begin();
+    IS_NODE(Store, *it);
+  }
+
+  {
+    // Merge multiple sections.
+    auto body = Block::make({Store::make(a, {0}, 1, 1),
+                             new SyncThreads(),
+                             new SyncThreads(),
+                             Store::make(a, {1}, 0, 1),
+                             Store::make(a, {2}, 0, 1),
+                             new SyncThreads(),
+                             new SyncThreads(),
+                             new SyncThreads(),
+                             Store::make(a, {3}, 0, 1)});
+
+    Stmt* simplified = IRSimplifier::simplify(body);
+    IS_NODE_WITH_NAME(Block, simplified, block);
+    ASSERT_EQ(block->nstmts(), 6);
+    auto it = block->begin();
+    IS_NODE(Store, *it++);
+    IS_NODE(SyncThreads, *it++);
+    IS_NODE(Store, *it++);
+    IS_NODE(Store, *it++);
+    IS_NODE(SyncThreads, *it++);
+    IS_NODE(Store, *it++);
+  }
+}
+
 } // namespace jit
 } // namespace torch
test/cpp/tensorexpr/tests.h (+2 −1)

@@ -215,6 +215,7 @@ namespace jit {
   _(DontSimplifyRand) \
   _(SimplifyReorderForCond) \
   _(SimplifyFuseConditions) \
+  _(SimplifySyncThreads) \
   _(RegisterizerSimple) \
   _(RegisterizerLoop) \
   _(RegisterizerLoopFixedLoad) \
@@ -434,6 +435,7 @@ namespace jit {
   _(CudaOneBlockMultiThreadGlobalReduce1) \
   _(CudaNoThreadIdxWrite_1) \
   _(CudaLocalMemReduce_1) \
+  _(CudaSharedMemReduce_1) \
   _(CudaTestRand01) \
   _(CudaSigmoid) \
   _(CudaHalfCast) \
@@ -449,7 +451,6 @@ namespace jit {
   _(CudaMaskInnerLoopOneBlock) \
   _(CudaMaskMultiDimMultiAxis) \
   _(CudaMaskMultiDimMultiLevel)
-// _(CudaSharedMemReduce_1)

 #define DECLARE_TENSOREXPR_TEST(name) void test##name();
 TH_FORALL_TENSOREXPR_TESTS(DECLARE_TENSOREXPR_TEST)

torch/csrc/jit/tensorexpr/cuda_codegen.cpp (+21 −19)

@@ -436,6 +436,7 @@ void CudaPrinter::visit(const Store* v) {
     os() << *v->base_handle() << "[" << *v->flat_index() << "] = ";
   }
   os() << *v->value() << ";";
+  os() << std::endl;
 }

 void CudaPrinter::visit(const AtomicAdd* v) {
@@ -505,6 +506,9 @@ class PrioritizeLoad : public IRMutator {
     if (nested_if_then_else_ > 0) {
       return IRMutator::mutate(v);
     }
+    if (nested_let_) {
+      return IRMutator::mutate(v);
+    }
     if (thread_local_bufs_.count(v->base_handle()) > 0) {
       return IRMutator::mutate(v);
     }
@@ -566,6 +570,13 @@ class PrioritizeLoad : public IRMutator {
     return s;
   }

+  Stmt* mutate(const Let* v) override {
+    nested_let_ = true;
+    Stmt* s = IRMutator::mutate(v);
+    nested_let_ = false;
+    return s;
+  }
+
   Stmt* mutate(const Block* v) override {
     bool any_change = false;

@@ -631,8 +642,9 @@ class PrioritizeLoad : public IRMutator {
   // v = false_v;
   // }
   // int v2 = v + 2;
-  int nested_if_then_else_ = 0;
+  int nested_if_then_else_{0};
   const Store* nested_store_{nullptr};
+  bool nested_let_{false};
   std::unordered_set<const Var*> thread_local_bufs_;
 };

@@ -703,13 +715,6 @@ Stmt* GPUMetaVarRewriter::mutate(const For* v) {
         IRSimplifier::simplify(new Max(old_reach, v->stop(), true));
   }

-  // If a thread dimension has changed, insert a syncThreads in the enclosing
-  // Block.
-  if (last_thread_dim_ && !exprEquals(last_thread_dim_, v->stop())) {
-    need_sync_ = true;
-  }
-  last_thread_dim_ = v->stop();
-
   const Var* metaVar = gpu_thread_vars_[gpu_thread_index];
   body = Substitute(Stmt::clone(body), {{v->var(), metaVar}});
 }
@@ -750,16 +755,6 @@ Stmt* GPUMetaVarRewriter::mutate(const Block* v) {
       stmt_new = Stmt::clone(stmt_new);
     }

-    if (need_sync_) {
-      // sync is special, we never want to mask it and it is never part of
-      // another segment.
-      pushAndReset(false);
-      current.stmts().push_back(new SyncThreads());
-      pushAndReset(true);
-
-      need_sync_ = false;
-    }
-
     // Likewise, Allocate and Free should never be masked.
     if (dynamic_cast<Allocate*>(stmt) || dynamic_cast<Free*>(stmt)) {
       pushAndReset(false);
@@ -796,8 +791,8 @@ Stmt* GPUMetaVarRewriter::mutate(const Block* v) {
   }

   std::vector<Stmt*> stmts;
-  int rsqi = 0;
   for (auto& segment : innerSegments) {
+    bool need_sync = false;
     // We never mask loops, they'll mask their contents.
     if (!segment.mask()) {
       TORCH_INTERNAL_ASSERT(segment.stmts().size() == 1);
@@ -812,6 +807,7 @@ Stmt* GPUMetaVarRewriter::mutate(const Block* v) {
     auto& thread_extents = cuda_analysis_->gpu_thread_extents();
     for (size_t i = 0; i < gpu_thread_vars_.size(); ++i) {
       if (!exprEquals(current_thread_reach_[i], thread_extents[i])) {
+        need_sync = true;
         // Mask it against the current dimensions.
         inner = new Cond(
             new CompareSelect(
@@ -836,7 +832,13 @@ Stmt* GPUMetaVarRewriter::mutate(const Block* v) {
       }
     }

+    if (need_sync) {
+      stmts.push_back(new SyncThreads());
+    }
     stmts.push_back(inner);
+    if (need_sync) {
+      stmts.push_back(new SyncThreads());
+    }
   }

   return new Block(stmts);
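The net effect of the rewriter change, restated as a standalone sketch under stand-in types; Segment here is a simplified string-based stand-in for the real segment class, not the actual API:

#include <string>
#include <vector>

// Stand-in for the rewriter's segment type; the real code works on NNC IR
// nodes (Stmt*, Cond, SyncThreads), not strings.
struct Segment {
  std::string stmt;
  bool masked; // true if any thread dimension was masked against its extent
};

// Pessimistic placement: bracket every masked segment with barriers, even
// without a proven data dependency; the simplifier prunes the extras later.
std::vector<std::string> emitWithSyncs(const std::vector<Segment>& segments) {
  std::vector<std::string> out;
  for (const auto& seg : segments) {
    if (seg.masked) out.push_back("__syncthreads();");
    out.push_back(seg.stmt);
    if (seg.masked) out.push_back("__syncthreads();");
  }
  return out;
}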

torch/csrc/jit/tensorexpr/cuda_codegen.h (−2)

@@ -138,8 +138,6 @@ class GPUMetaVarRewriter : public IRMutator {
   std::vector<const Expr*> current_block_reach_;
   std::vector<const Expr*> current_thread_reach_;

-  bool need_sync_ = false;
-  const Expr* last_thread_dim_ = nullptr;
   const CudaAnalysis* cuda_analysis_;
 };

torch/csrc/jit/tensorexpr/ir_simplifier.cpp (+48 −2)

@@ -1930,7 +1930,6 @@ Block* TermExpander::fuseConditions(Block* v) {
     // erase, which shortens the list.
     stmts.pop_back();
     stmts.push_back(prev_cond);
-
     did_anything = true;
   }

@@ -1948,6 +1947,51 @@ Block* TermExpander::fuseConditions(Block* v) {
   return new Block(stmts);
 }

+Stmt* TermExpander::fuseSyncThreads(Block* block) {
+  // only really first if highest level Block.
+  bool first = block->get_parent() == nullptr;
+  SyncThreads* last = nullptr;
+  std::vector<Stmt*> stmts;
+  bool did_anything = false;
+
+  for (auto* s : *block) {
+    SyncThreads* sync = dynamic_cast<SyncThreads*>(s);
+    if (!sync) {
+      first = false;
+      last = nullptr;
+      stmts.push_back(s);
+      continue;
+    }
+
+    if (first || last) {
+      did_anything = true;
+      continue;
+    }
+
+    last = sync;
+    first = false;
+    stmts.push_back(s);
+  }
+
+  if (last) {
+    stmts.pop_back();
+    did_anything = true;
+  }
+
+  if (!did_anything) {
+    return block;
+  }
+
+  // clean up parents.
+  for (auto* s : stmts) {
+    if (s->get_parent() == block) {
+      block->remove_stmt(s);
+    }
+  }
+
+  return new Block({stmts});
+}
+
 Stmt* TermExpander::mutate(const Block* v) {
   Stmt* new_stmt = IRSimplifierBase::mutate(v);
   Block* new_block = dynamic_cast<Block*>(new_stmt);
@@ -1956,7 +2000,9 @@ Stmt* TermExpander::mutate(const Block* v) {
   }

   // fuseConditions will return the original block if it cannot fuse.
-  return fuseConditions(new_block);
+  new_block = fuseConditions(new_block);
+  /// fuseSyncThreads too.
+  return fuseSyncThreads(new_block);
 }

 bool exprEquals(const Expr* A, const Expr* B) {
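The fusion rules in fuseSyncThreads, restated over plain tokens as a sketch (not the real Stmt*/Block API): drop barriers at the head of the top-level block, collapse runs of consecutive barriers to one, and drop barriers at the tail.

#include <string>
#include <vector>

// Token-level sketch of fuseSyncThreads: "sync" models a SyncThreads node,
// any other string models a real statement.
std::vector<std::string> fuseSyncTokens(
    const std::vector<std::string>& stmts,
    bool top_level) {
  std::vector<std::string> out;
  bool first = top_level; // leading syncs are dead only at the top level
  bool last_was_sync = false;
  for (const auto& s : stmts) {
    if (s != "sync") {
      first = false;
      last_was_sync = false;
      out.push_back(s);
      continue;
    }
    if (first || last_was_sync) {
      continue; // leading or repeated sync: drop it
    }
    last_was_sync = true;
    out.push_back(s);
  }
  if (last_was_sync) {
    out.pop_back(); // trailing sync: drop it
  }
  return out;
}
// e.g. {"store", "sync", "sync", "store", "sync"} -> {"store", "sync", "store"}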

torch/csrc/jit/tensorexpr/ir_simplifier.h (+1)

@@ -570,6 +570,7 @@ class TORCH_API TermExpander : public IRSimplifierBase {

   // Override to enable condition fusing.
   Block* fuseConditions(Block* v);
+  Stmt* fuseSyncThreads(Block* block);
   Stmt* mutate(const Block* v) override;
 };

torch/csrc/jit/tensorexpr/stmt.h (+2 −2)

@@ -676,7 +676,7 @@ class TORCH_API For : public StmtNode<For> {
 // This node could only shows up as an internal with GPU backends.
 // TODO: move to this an internal IR.
 // TODO: make IR nodes extensible.
-class AtomicAdd : public StmtNode<AtomicAdd> {
+class TORCH_API AtomicAdd : public StmtNode<AtomicAdd> {
  public:
   AtomicAdd(
       const Buf* buf,
@@ -711,7 +711,7 @@ class AtomicAdd : public StmtNode<AtomicAdd> {
   const Expr* value_;
 };

-class SyncThreads : public StmtNode<SyncThreads> {
+class TORCH_API SyncThreads : public StmtNode<SyncThreads> {
  public:
   SyncThreads() {}
 };