Commit b5a5ea1
Make emitted egal code more loopy
The strategy here is to look at (data, padding) pairs and RLE them into loops, so that repeated adjacent patterns are compared with a loop rather than getting unrolled. On the test case from #54109, this makes compilation essentially instant, while also being faster at runtime (it turns out LLVM spends a massive amount of time on the unrolled comparison, and the resulting code is bad anyway).

There are some obvious further enhancements possible here:

1. The `memcmp` calls have small constant sizes. LLVM has a pass to expand these inline with better code, but we don't have it turned on. We should consider vendoring it, though we may want to add some shortcutting to it to avoid having it iterate through each function.
2. This only does one level of sequence matching. It could be recursed to turn things into nested loops.

However, this solves the immediate issue, so hopefully it's a useful start.

Fixes #54109.
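As a rough, standalone illustration of that strategy (this sketch is not part of the commit; the `Chunk`/`Desc`/`rle` names are invented here), think of each field as contributing a chunk of comparable data bytes followed by padding bytes, with adjacent identical chunks collapsed into a single descriptor carrying a repeat count:

    // Standalone sketch of the RLE-into-loops idea; not Julia's actual codegen helper.
    #include <cstddef>
    #include <cstdio>
    #include <vector>

    struct Chunk { std::size_t data_bytes, padding_bytes; }; // one field: data, then padding
    struct Desc  { std::size_t nrepeats, data_bytes, padding_bytes; };

    // Collapse adjacent identical (data, padding) chunks into repeat descriptors.
    static std::vector<Desc> rle(const std::vector<Chunk> &chunks) {
        std::vector<Desc> out;
        for (const Chunk &c : chunks) {
            if (!out.empty() && out.back().data_bytes == c.data_bytes &&
                out.back().padding_bytes == c.padding_bytes)
                out.back().nrepeats += 1; // same pattern as the previous chunk: extend the run
            else
                out.push_back({1, c.data_bytes, c.padding_bytes});
        }
        return out;
    }

    int main() {
        // 897 fields shaped like DefaultOr54109{Float64} from the test: 9 data bytes
        // (Float64 + Bool), then 7 padding bytes, for a 16-byte element.
        std::vector<Chunk> layout(897, Chunk{9, 7});
        for (const Desc &d : rle(layout))
            std::printf("loop %zu times: memcmp %zu data bytes, skip %zu padding bytes\n",
                        d.nrepeats, d.data_bytes, d.padding_bytes);
        // Prints: loop 897 times: memcmp 9 data bytes, skip 7 padding bytes
    }

With the 897-field torture type from the test below, every chunk is identical, so the whole comparison collapses to one descriptor and hence one loop, instead of 897 unrolled comparisons.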
1 parent 7ba1b33 commit b5a5ea1

File tree

4 files changed, +214 -11 lines changed


src/codegen.cpp

+136 -1

@@ -3358,6 +3358,58 @@ static Value *emit_bitsunion_compare(jl_codectx_t &ctx, const jl_cgval_t &arg1,
     return phi;
 }

+struct egal_desc {
+    size_t offset;
+    size_t nrepeats;
+    size_t data_bytes;
+    size_t padding_bytes;
+};
+
+template <typename callback>
+static size_t emit_masked_bits_compare(callback &emit_desc, jl_datatype_t *aty, egal_desc &current_desc)
+{
+    // Memcmp, but with masked padding
+    size_t data_bytes = 0;
+    size_t padding_bytes = 0;
+    size_t nfields = jl_datatype_nfields(aty);
+    size_t total_size = jl_datatype_size(aty);
+    for (size_t i = 0; i < nfields; ++i) {
+        size_t offset = jl_field_offset(aty, i);
+        size_t fend = i == nfields - 1 ? total_size : jl_field_offset(aty, i + 1);
+        size_t fsz = jl_field_size(aty, i);
+        jl_datatype_t *fty = (jl_datatype_t*)jl_field_type(aty, i);
+        if (jl_field_isptr(aty, i) || !fty->layout->flags.haspadding) {
+            // The field has no internal padding
+            data_bytes += fsz;
+            if (offset + fsz == fend) {
+                // The field has no padding after. Merge this into the current
+                // comparison range and go to next field.
+            } else {
+                padding_bytes = fend - offset - fsz;
+                // Found padding. Either merge this into the current comparison
+                // range, or emit the old one and start a new one.
+                if (current_desc.data_bytes == data_bytes &&
+                    current_desc.padding_bytes == padding_bytes) {
+                    // Same as the previous range, just note that down, so we
+                    // emit this as a loop.
+                    current_desc.nrepeats += 1;
+                } else {
+                    if (current_desc.nrepeats != 0)
+                        emit_desc(current_desc);
+                    current_desc.nrepeats = 1;
+                    current_desc.data_bytes = data_bytes;
+                    current_desc.padding_bytes = padding_bytes;
+                }
+                data_bytes = 0;
+            }
+        } else {
+            // The field may have internal padding. Recurse this.
+            data_bytes += emit_masked_bits_compare(emit_desc, fty, current_desc);
+        }
+    }
+    return data_bytes;
+}
+
 static Value *emit_bits_compare(jl_codectx_t &ctx, jl_cgval_t arg1, jl_cgval_t arg2)
 {
     ++EmittedBitsCompares;
@@ -3396,7 +3448,7 @@ static Value *emit_bits_compare(jl_codectx_t &ctx, jl_cgval_t arg1, jl_cgval_t a
     if (at->isAggregateType()) { // Struct or Array
         jl_datatype_t *sty = (jl_datatype_t*)arg1.typ;
         size_t sz = jl_datatype_size(sty);
-        if (sz > 512 && !sty->layout->flags.haspadding) {
+        if (sz > 512 && !sty->layout->flags.haspadding && sty->layout->flags.isbitsegal) {
             Value *varg1 = arg1.ispointer() ? data_pointer(ctx, arg1) :
                 value_to_pointer(ctx, arg1).V;
             Value *varg2 = arg2.ispointer() ? data_pointer(ctx, arg2) :
@@ -3433,6 +3485,89 @@ static Value *emit_bits_compare(jl_codectx_t &ctx, jl_cgval_t arg1, jl_cgval_t a
             }
             return ctx.builder.CreateICmpEQ(answer, ConstantInt::get(getInt32Ty(ctx.builder.getContext()), 0));
         }
+        else if (sz > 512 && jl_struct_try_layout(sty) && sty->layout->flags.isbitsegal) {
+            Type *TInt8 = getInt8Ty(ctx.builder.getContext());
+            Type *TpInt8 = getInt8PtrTy(ctx.builder.getContext());
+            Type *TInt1 = getInt1Ty(ctx.builder.getContext());
+            Value *varg1 = arg1.ispointer() ? data_pointer(ctx, arg1) :
+                value_to_pointer(ctx, arg1).V;
+            Value *varg2 = arg2.ispointer() ? data_pointer(ctx, arg2) :
+                value_to_pointer(ctx, arg2).V;
+            varg1 = emit_pointer_from_objref(ctx, varg1);
+            varg2 = emit_pointer_from_objref(ctx, varg2);
+            varg1 = emit_bitcast(ctx, varg1, TpInt8);
+            varg2 = emit_bitcast(ctx, varg2, TpInt8);
+
+            Value *answer = nullptr;
+            auto emit_desc = [&](egal_desc desc) {
+                Value *ptr1 = varg1;
+                Value *ptr2 = varg2;
+                if (desc.offset != 0) {
+                    ptr1 = ctx.builder.CreateConstInBoundsGEP1_32(TInt8, ptr1, desc.offset);
+                    ptr2 = ctx.builder.CreateConstInBoundsGEP1_32(TInt8, ptr2, desc.offset);
+                }
+
+                Value *new_ptr1 = ptr1;
+                Value *endptr1 = nullptr;
+                BasicBlock *postBB = nullptr;
+                BasicBlock *loopBB = nullptr;
+                PHINode *answerphi = nullptr;
+                if (desc.nrepeats != 1) {
+                    // Set up loop
+                    endptr1 = ctx.builder.CreateConstInBoundsGEP1_32(TInt8, ptr1, desc.nrepeats * (desc.data_bytes + desc.padding_bytes));
+
+                    BasicBlock *currBB = ctx.builder.GetInsertBlock();
+                    loopBB = BasicBlock::Create(ctx.builder.getContext(), "egal_loop", ctx.f);
+                    postBB = BasicBlock::Create(ctx.builder.getContext(), "post", ctx.f);
+                    ctx.builder.CreateBr(loopBB);
+
+                    ctx.builder.SetInsertPoint(loopBB);
+                    answerphi = ctx.builder.CreatePHI(TInt1, 2);
+                    answerphi->addIncoming(answer ? answer : ConstantInt::get(TInt1, 1), currBB);
+                    answer = answerphi;
+
+                    PHINode *itr1 = ctx.builder.CreatePHI(ptr1->getType(), 2);
+                    PHINode *itr2 = ctx.builder.CreatePHI(ptr2->getType(), 2);
+
+                    new_ptr1 = ctx.builder.CreateConstInBoundsGEP1_32(TInt8, itr1, desc.data_bytes + desc.padding_bytes);
+                    itr1->addIncoming(ptr1, currBB);
+                    itr1->addIncoming(new_ptr1, loopBB);
+
+                    Value *new_ptr2 = ctx.builder.CreateConstInBoundsGEP1_32(TInt8, itr2, desc.data_bytes + desc.padding_bytes);
+                    itr2->addIncoming(ptr2, currBB);
+                    itr2->addIncoming(new_ptr2, loopBB);
+
+                    ptr1 = itr1;
+                    ptr2 = itr2;
+                }
+
+                // Emit memcmp. TODO: LLVM has a pass to expand this for additional
+                // performance.
+                Value *this_answer = ctx.builder.CreateCall(prepare_call(memcmp_func),
+                    { ptr1,
+                      ptr2,
+                      ConstantInt::get(ctx.types().T_size, desc.data_bytes) });
+                this_answer = ctx.builder.CreateICmpEQ(this_answer, ConstantInt::get(getInt32Ty(ctx.builder.getContext()), 0));
+                answer = answer ? ctx.builder.CreateAnd(answer, this_answer) : this_answer;
+                if (endptr1) {
+                    answerphi->addIncoming(answer, loopBB);
+                    Value *loopend = ctx.builder.CreateICmpEQ(new_ptr1, endptr1);
+                    ctx.builder.CreateCondBr(loopend, postBB, loopBB);
+                    ctx.builder.SetInsertPoint(postBB);
+                }
+            };
+            egal_desc current_desc = {0};
+            size_t trailing_data_bytes = emit_masked_bits_compare(emit_desc, sty, current_desc);
+            assert(current_desc.nrepeats != 0);
+            emit_desc(current_desc);
+            if (trailing_data_bytes != 0) {
+                current_desc.nrepeats = 1;
+                current_desc.data_bytes = trailing_data_bytes;
+                current_desc.padding_bytes = 0;
+                emit_desc(current_desc);
+            }
+            return answer;
+        }
         else {
             jl_svec_t *types = sty->types;
             Value *answer = ConstantInt::get(getInt1Ty(ctx.builder.getContext()), 1);
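For intuition, the comparison that the `emit_desc` lambda builds for one descriptor behaves at runtime roughly like the following C++. This is an illustrative equivalent only; the commit emits LLVM IR directly, and `egal_desc_sketch`/`compare_desc` are hypothetical names used just for this sketch:

    // Rough runtime equivalent of the IR emitted for a single descriptor; for
    // illustration only (the real code builds LLVM instructions, not C++).
    #include <cstddef>
    #include <cstring>

    struct egal_desc_sketch { std::size_t offset, nrepeats, data_bytes, padding_bytes; };

    // Compare the data bytes of every repeat, striding over the padding bytes.
    static bool compare_desc(const char *a, const char *b, const egal_desc_sketch &d) {
        const char *p1 = a + d.offset;
        const char *p2 = b + d.offset;
        std::size_t stride = d.data_bytes + d.padding_bytes;
        bool answer = true;
        for (std::size_t i = 0; i < d.nrepeats; i++) {
            // Only the data bytes are compared; the padding bytes are skipped by the stride.
            bool chunk_eq = std::memcmp(p1, p2, d.data_bytes) == 0;
            answer = answer && chunk_eq;
            p1 += stride;
            p2 += stride;
        }
        return answer;
    }

    int main() {
        // Two 32-byte objects seen as 2 repeats of (9 data, 7 padding) bytes; the
        // padding regions (bytes 9..15 and 25..31) may differ without affecting the result.
        char obj1[32] = {0}, obj2[32] = {0};
        obj1[10] = 0x55; // scribble on padding of obj1 only
        egal_desc_sketch d = {0, 2, 9, 7};
        return compare_desc(obj1, obj2, d) ? 0 : 1; // returns 0: the objects compare equal
    }

As in the emitted IR, the loop runs for all `nrepeats` iterations and folds the per-chunk results into a running `answer` (via the phi node and `CreateAnd` above) rather than branching out early on the first mismatch.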

src/datatype.c

+23 -9

@@ -180,6 +180,7 @@ static jl_datatype_layout_t *jl_get_layout(uint32_t sz,
                                            uint32_t npointers,
                                            uint32_t alignment,
                                            int haspadding,
+                                           int isbitsegal,
                                            int arrayelem,
                                            jl_fielddesc32_t desc[],
                                            uint32_t pointers[]) JL_NOTSAFEPOINT
@@ -226,6 +227,7 @@ static jl_datatype_layout_t *jl_get_layout(uint32_t sz,
     flddesc->nfields = nfields;
     flddesc->alignment = alignment;
     flddesc->flags.haspadding = haspadding;
+    flddesc->flags.isbitsegal = isbitsegal;
     flddesc->flags.fielddesc_type = fielddesc_type;
     flddesc->flags.arrayelem_isboxed = arrayelem == 1;
     flddesc->flags.arrayelem_isunion = arrayelem == 2;
@@ -504,6 +506,7 @@ void jl_get_genericmemory_layout(jl_datatype_t *st)
     int isunboxed = jl_islayout_inline(eltype, &elsz, &al) && (kind != (jl_value_t*)jl_atomic_sym || jl_is_datatype(eltype));
     int isunion = isunboxed && jl_is_uniontype(eltype);
     int haspadding = 1; // we may want to eventually actually compute this more precisely
+    int isbitsegal = 0;
     int nfields = 0; // aka jl_is_layout_opaque
     int npointers = 1;
     int zi;
@@ -562,7 +565,7 @@ void jl_get_genericmemory_layout(jl_datatype_t *st)
     else
         arrayelem = 0;
     assert(!st->layout);
-    st->layout = jl_get_layout(elsz, nfields, npointers, al, haspadding, arrayelem, NULL, pointers);
+    st->layout = jl_get_layout(elsz, nfields, npointers, al, haspadding, isbitsegal, arrayelem, NULL, pointers);
     st->zeroinit = zi;
     //st->has_concrete_subtype = 1;
     //st->isbitstype = 0;
@@ -673,6 +676,7 @@ void jl_compute_field_offsets(jl_datatype_t *st)
     size_t alignm = 1;
     int zeroinit = 0;
     int haspadding = 0;
+    int isbitsegal = 1;
     int homogeneous = 1;
     int needlock = 0;
     uint32_t npointers = 0;
@@ -687,19 +691,30 @@ void jl_compute_field_offsets(jl_datatype_t *st)
                 throw_ovf(should_malloc, desc, st, fsz);
             desc[i].isptr = 0;
             if (jl_is_uniontype(fld)) {
-                haspadding = 1;
                 fsz += 1; // selector byte
                 zeroinit = 1;
+                // TODO: Some unions could be bits comparable.
+                isbitsegal = 0;
             }
             else {
                 uint32_t fld_npointers = ((jl_datatype_t*)fld)->layout->npointers;
                 if (((jl_datatype_t*)fld)->layout->flags.haspadding)
                     haspadding = 1;
+                if (!((jl_datatype_t*)fld)->layout->flags.isbitsegal)
+                    isbitsegal = 0;
                 if (i >= nfields - st->name->n_uninitialized && fld_npointers &&
                     fld_npointers * sizeof(void*) != fsz) {
-                    // field may be undef (may be uninitialized and contains pointer),
-                    // and contains non-pointer fields of non-zero sizes.
-                    haspadding = 1;
+                    // For field types that contain pointers, we allow inlinealloc
+                    // as long as the field type itself is always fully initialized.
+                    // In such a case, we use the first pointer in the inlined field
+                    // as the #undef marker (if it is zero, we treat the whole inline
+                    // struct as #undef). However, we do not zero-initialize the whole
+                    // struct, so the non-pointer parts of the inline allocation may
+                    // be arbitrary, but still need to compare egal (because all #undef
+                    // representations are egal). Because of this, we cannot bitscompare
+                    // them.
+                    // TODO: Consider zero-initializing the whole struct.
+                    isbitsegal = 0;
                 }
                 if (!zeroinit)
                     zeroinit = ((jl_datatype_t*)fld)->zeroinit;
@@ -715,8 +730,7 @@ void jl_compute_field_offsets(jl_datatype_t *st)
                 zeroinit = 1;
                 npointers++;
                 if (!jl_pointer_egal(fld)) {
-                    // this somewhat poorly named flag says whether some of the bits can be non-unique
-                    haspadding = 1;
+                    isbitsegal = 0;
                 }
             }
             if (isatomic && fsz > MAX_ATOMIC_SIZE)
@@ -777,7 +791,7 @@ void jl_compute_field_offsets(jl_datatype_t *st)
         }
     }
     assert(ptr_i == npointers);
-    st->layout = jl_get_layout(sz, nfields, npointers, alignm, haspadding, 0, desc, pointers);
+    st->layout = jl_get_layout(sz, nfields, npointers, alignm, haspadding, isbitsegal, 0, desc, pointers);
    if (should_malloc) {
        free(desc);
        if (npointers)
@@ -931,7 +945,7 @@ JL_DLLEXPORT jl_datatype_t *jl_new_primitivetype(jl_value_t *name, jl_module_t *
     bt->ismutationfree = 1;
     bt->isidentityfree = 1;
     bt->isbitstype = (parameters == jl_emptysvec);
-    bt->layout = jl_get_layout(nbytes, 0, 0, alignm, 0, 0, NULL, NULL);
+    bt->layout = jl_get_layout(nbytes, 0, 0, alignm, 0, 1, 0, NULL, NULL);
     bt->instance = NULL;
     return bt;
 }

src/julia.h

+4 -1

@@ -574,7 +574,10 @@ typedef struct {
         // metadata bit only for GenericMemory eltype layout
         uint16_t arrayelem_isboxed : 1;
         uint16_t arrayelem_isunion : 1;
-        uint16_t padding : 11;
+        // If set, this type's egality can be determined entirely by comparing
+        // the non-padding bits of this datatype.
+        uint16_t isbitsegal : 1;
+        uint16_t padding : 10;
     } flags;
     // union {
     //     jl_fielddesc8_t field8[nfields];

test/compiler/codegen.jl

+51

@@ -873,3 +873,54 @@ if Sys.ARCH === :x86_64
         end
     end
 end
+
+# #54109 - Excessive LLVM time for egal
+struct DefaultOr54109{T}
+    x::T
+    default::Bool
+end
+
+@eval struct Torture1_54109
+    $((Expr(:(::), Symbol("x$i"), DefaultOr54109{Float64}) for i = 1:897)...)
+end
+Torture1_54109() = Torture1_54109((DefaultOr54109(1.0, false) for i = 1:897)...)
+
+@eval struct Torture2_54109
+    $((Expr(:(::), Symbol("x$i"), DefaultOr54109{Float64}) for i = 1:400)...)
+    $((Expr(:(::), Symbol("x$(i+400)"), DefaultOr54109{Int16}) for i = 1:400)...)
+end
+Torture2_54109() = Torture2_54109((DefaultOr54109(1.0, false) for i = 1:400)..., (DefaultOr54109(Int16(1), false) for i = 1:400)...)
+
+@noinline egal_any54109(x, @nospecialize(y::Any)) = x === Base.compilerbarrier(:type, y)
+
+let ir1 = get_llvm(egal_any54109, Tuple{Torture1_54109, Any}),
+    ir2 = get_llvm(egal_any54109, Tuple{Torture2_54109, Any})
+
+    # We can't really do timing on CI, so instead, let's look at the length of
+    # the optimized IR. The original version had tens of thousands of lines and
+    # was slower, so just check here that we only have < 500 lines. If somebody
+    # implements a better comparison that's larger than that, just re-benchmark
+    # this and adjust the threshold.
+
+    @test count(==('\n'), ir1) < 500
+    @test count(==('\n'), ir2) < 500
+end
+
+## Regression test for egal of a struct of this size without padding, but with
+## non-bitsegal fields, to make sure that it doesn't accidentally go down the
+## accelerated path.
+@eval struct BigStructAnyInt
+    $((Expr(:(::), Symbol("x$i"), Tuple{Any, Int}) for i = 1:33)...)
+end
+BigStructAnyInt() = BigStructAnyInt(((Union{Base.inferencebarrier(Float64), Int}, i) for i = 1:33)...)
+@test egal_any54109(BigStructAnyInt(), BigStructAnyInt())
+
+## For completeness, also test correctness, since we don't have a lot of
+## large-struct tests.
+
+# The two allocations of the same struct will likely have different padding,
+# and we want to make sure we find them egal anyway - a naive memcmp would
+# accidentally look at the padding.
+@test egal_any54109(Torture1_54109(), Torture1_54109())
+@test egal_any54109(Torture2_54109(), Torture2_54109())
+@test !egal_any54109(Torture1_54109(), Torture1_54109((DefaultOr54109(2.0, false) for i = 1:897)...))
