Skip to content

Commit ee760e0

Browse files
committed
Make emitted egal code more loopy
The strategy here is to look at (data, padding) pairs and run-length encode (RLE) them into loops, so that repeated adjacent patterns use a loop rather than getting unrolled. On the test case from #54109, this makes compilation essentially instant, while also being faster at runtime (it turns out LLVM spends a massive amount of time AND the resulting code is bad). There are some obvious further enhancements possible here: 1. The `memcmp` constant is small. LLVM has a pass to inline these with better code. However, we don't have it turned on. We should consider vendoring it, though we may want to add some shortcutting to it to avoid having it iterate through each function. 2. This only does one level of sequence matching. It could be recursed to turn things into nested loops. However, this solves the immediate issue, so hopefully it's a useful start. Fixes #54109.
1 parent 7ba1b33 commit ee760e0

File tree

8 files changed

+246
-23
lines changed

8 files changed

+246
-23
lines changed

base/reflection.jl

+17-2
Original file line numberDiff line numberDiff line change
@@ -488,8 +488,8 @@ gc_alignment(T::Type) = gc_alignment(Core.sizeof(T))
488488
Base.datatype_haspadding(dt::DataType) -> Bool
489489
490490
Return whether the fields of instances of this type are packed in memory,
491-
with no intervening padding bits (defined as bits whose value does not uniquely
492-
impact the egal test when applied to the struct fields).
491+
with no intervening padding bits (defined as bits whose value does not impact
492+
the semantic value of the instance itself).
493493
Can be called on any `isconcretetype`.
494494
"""
495495
function datatype_haspadding(dt::DataType)
@@ -499,6 +499,21 @@ function datatype_haspadding(dt::DataType)
499499
return flags & 1 == 1
500500
end
501501

502+
"""
503+
Base.datatype_isbitsegal(dt::DataType) -> Bool
504+
505+
Return whether egality of the (non-padding bits of the) in-memory representation
506+
of an instance of this type implies semantic egality of the instance itself.
507+
This may not be the case if the type contains references to other values whose egality is
508+
independent of their identity (e.g. immutable structs, some types, etc.).
509+
"""
510+
function datatype_isbitsegal(dt::DataType)
511+
@_foldable_meta
512+
dt.layout == C_NULL && throw(UndefRefError())
513+
flags = unsafe_load(convert(Ptr{DataTypeLayout}, dt.layout)).flags
514+
return (flags & (1<<5)) != 0
515+
end
516+
502517
"""
503518
Base.datatype_nfields(dt::DataType) -> UInt32
504519

src/builtins.c

+3-3
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ static int NOINLINE compare_fields(const jl_value_t *a, const jl_value_t *b, jl_
115115
continue; // skip this field (it is #undef)
116116
}
117117
}
118-
if (!ft->layout->flags.haspadding) {
118+
if (!ft->layout->flags.haspadding && ft->layout->flags.isbitsegal) {
119119
if (!bits_equal(ao, bo, ft->layout->size))
120120
return 0;
121121
}
@@ -284,7 +284,7 @@ inline int jl_egal__bits(const jl_value_t *a JL_MAYBE_UNROOTED, const jl_value_t
284284
if (sz == 0)
285285
return 1;
286286
size_t nf = jl_datatype_nfields(dt);
287-
if (nf == 0 || !dt->layout->flags.haspadding)
287+
if (nf == 0 || (!dt->layout->flags.haspadding && dt->layout->flags.isbitsegal))
288288
return bits_equal(a, b, sz);
289289
return compare_fields(a, b, dt);
290290
}
@@ -394,7 +394,7 @@ static uintptr_t immut_id_(jl_datatype_t *dt, jl_value_t *v, uintptr_t h) JL_NOT
394394
if (sz == 0)
395395
return ~h;
396396
size_t f, nf = jl_datatype_nfields(dt);
397-
if (nf == 0 || (!dt->layout->flags.haspadding && dt->layout->npointers == 0)) {
397+
if (nf == 0 || (!dt->layout->flags.haspadding && dt->layout->flags.isbitsegal && dt->layout->npointers == 0)) {
398398
// operate element-wise if there are unused bits inside,
399399
// otherwise just take the whole data block at once
400400
// a few select pointers (notably symbol) also have special hash values

src/cgutils.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -2200,7 +2200,8 @@ static jl_cgval_t typed_store(jl_codectx_t &ctx,
22002200
}
22012201
else if (!isboxed) {
22022202
assert(jl_is_concrete_type(jltype));
2203-
needloop = ((jl_datatype_t*)jltype)->layout->flags.haspadding;
2203+
needloop = ((jl_datatype_t*)jltype)->layout->flags.haspadding ||
2204+
!((jl_datatype_t*)jltype)->layout->flags.isbitsegal;
22042205
Value *SameType = emit_isa(ctx, cmp, jltype, Twine()).first;
22052206
if (SameType != ConstantInt::getTrue(ctx.builder.getContext())) {
22062207
BasicBlock *SkipBB = BasicBlock::Create(ctx.builder.getContext(), "skip_xchg", ctx.f);

src/codegen.cpp

+136-1
Original file line numberDiff line numberDiff line change
@@ -3358,6 +3358,58 @@ static Value *emit_bitsunion_compare(jl_codectx_t &ctx, const jl_cgval_t &arg1,
33583358
return phi;
33593359
}
33603360

3361+
struct egal_desc {
3362+
size_t offset;
3363+
size_t nrepeats;
3364+
size_t data_bytes;
3365+
size_t padding_bytes;
3366+
};
3367+
3368+
template <typename callback>
3369+
static size_t emit_masked_bits_compare(callback &emit_desc, jl_datatype_t *aty, egal_desc &current_desc)
3370+
{
3371+
// Memcmp, but with masked padding
3372+
size_t data_bytes = 0;
3373+
size_t padding_bytes = 0;
3374+
size_t nfields = jl_datatype_nfields(aty);
3375+
size_t total_size = jl_datatype_size(aty);
3376+
for (size_t i = 0; i < nfields; ++i) {
3377+
size_t offset = jl_field_offset(aty, i);
3378+
size_t fend = i == nfields - 1 ? total_size : jl_field_offset(aty, i + 1);
3379+
size_t fsz = jl_field_size(aty, i);
3380+
jl_datatype_t *fty = (jl_datatype_t*)jl_field_type(aty, i);
3381+
if (jl_field_isptr(aty, i) || !fty->layout->flags.haspadding) {
3382+
// The field has no internal padding
3383+
data_bytes += fsz;
3384+
if (offset + fsz == fend) {
3385+
// The field has no padding after. Merge this into the current
3386+
// comparison range and go to next field.
3387+
} else {
3388+
padding_bytes = fend - offset - fsz;
3389+
// Found padding. Either merge this into the current comparison
3390+
// range, or emit the old one and start a new one.
3391+
if (current_desc.data_bytes == data_bytes &&
3392+
current_desc.padding_bytes == padding_bytes) {
3393+
// Same as the previous range, just note that down, so we
3394+
// emit this as a loop.
3395+
current_desc.nrepeats += 1;
3396+
} else {
3397+
if (current_desc.nrepeats != 0)
3398+
emit_desc(current_desc);
3399+
current_desc.nrepeats = 1;
3400+
current_desc.data_bytes = data_bytes;
3401+
current_desc.padding_bytes = padding_bytes;
3402+
}
3403+
data_bytes = 0;
3404+
}
3405+
} else {
3406+
// The field may have internal padding. Recurse this.
3407+
data_bytes += emit_masked_bits_compare(emit_desc, fty, current_desc);
3408+
}
3409+
}
3410+
return data_bytes;
3411+
}
3412+
33613413
static Value *emit_bits_compare(jl_codectx_t &ctx, jl_cgval_t arg1, jl_cgval_t arg2)
33623414
{
33633415
++EmittedBitsCompares;
@@ -3396,7 +3448,7 @@ static Value *emit_bits_compare(jl_codectx_t &ctx, jl_cgval_t arg1, jl_cgval_t a
33963448
if (at->isAggregateType()) { // Struct or Array
33973449
jl_datatype_t *sty = (jl_datatype_t*)arg1.typ;
33983450
size_t sz = jl_datatype_size(sty);
3399-
if (sz > 512 && !sty->layout->flags.haspadding) {
3451+
if (sz > 512 && !sty->layout->flags.haspadding && sty->layout->flags.isbitsegal) {
34003452
Value *varg1 = arg1.ispointer() ? data_pointer(ctx, arg1) :
34013453
value_to_pointer(ctx, arg1).V;
34023454
Value *varg2 = arg2.ispointer() ? data_pointer(ctx, arg2) :
@@ -3433,6 +3485,89 @@ static Value *emit_bits_compare(jl_codectx_t &ctx, jl_cgval_t arg1, jl_cgval_t a
34333485
}
34343486
return ctx.builder.CreateICmpEQ(answer, ConstantInt::get(getInt32Ty(ctx.builder.getContext()), 0));
34353487
}
3488+
else if (sz > 512 && jl_struct_try_layout(sty) && sty->layout->flags.isbitsegal) {
3489+
Type *TInt8 = getInt8Ty(ctx.builder.getContext());
3490+
Type *TpInt8 = getInt8PtrTy(ctx.builder.getContext());
3491+
Type *TInt1 = getInt1Ty(ctx.builder.getContext());
3492+
Value *varg1 = arg1.ispointer() ? data_pointer(ctx, arg1) :
3493+
value_to_pointer(ctx, arg1).V;
3494+
Value *varg2 = arg2.ispointer() ? data_pointer(ctx, arg2) :
3495+
value_to_pointer(ctx, arg2).V;
3496+
varg1 = emit_pointer_from_objref(ctx, varg1);
3497+
varg2 = emit_pointer_from_objref(ctx, varg2);
3498+
varg1 = emit_bitcast(ctx, varg1, TpInt8);
3499+
varg2 = emit_bitcast(ctx, varg2, TpInt8);
3500+
3501+
Value *answer = nullptr;
3502+
auto emit_desc = [&](egal_desc desc) {
3503+
Value *ptr1 = varg1;
3504+
Value *ptr2 = varg2;
3505+
if (desc.offset != 0) {
3506+
ptr1 = ctx.builder.CreateConstInBoundsGEP1_32(TInt8, ptr1, desc.offset);
3507+
ptr2 = ctx.builder.CreateConstInBoundsGEP1_32(TInt8, ptr2, desc.offset);
3508+
}
3509+
3510+
Value *new_ptr1 = ptr1;
3511+
Value *endptr1 = nullptr;
3512+
BasicBlock *postBB = nullptr;
3513+
BasicBlock *loopBB = nullptr;
3514+
PHINode *answerphi = nullptr;
3515+
if (desc.nrepeats != 1) {
3516+
// Set up loop
3517+
endptr1 = ctx.builder.CreateConstInBoundsGEP1_32(TInt8, ptr1, desc.nrepeats * (desc.data_bytes + desc.padding_bytes));;
3518+
3519+
BasicBlock *currBB = ctx.builder.GetInsertBlock();
3520+
loopBB = BasicBlock::Create(ctx.builder.getContext(), "egal_loop", ctx.f);
3521+
postBB = BasicBlock::Create(ctx.builder.getContext(), "post", ctx.f);
3522+
ctx.builder.CreateBr(loopBB);
3523+
3524+
ctx.builder.SetInsertPoint(loopBB);
3525+
answerphi = ctx.builder.CreatePHI(TInt1, 2);
3526+
answerphi->addIncoming(answer ? answer : ConstantInt::get(TInt1, 1), currBB);
3527+
answer = answerphi;
3528+
3529+
PHINode *itr1 = ctx.builder.CreatePHI(ptr1->getType(), 2);
3530+
PHINode *itr2 = ctx.builder.CreatePHI(ptr2->getType(), 2);
3531+
3532+
new_ptr1 = ctx.builder.CreateConstInBoundsGEP1_32(TInt8, itr1, desc.data_bytes + desc.padding_bytes);
3533+
itr1->addIncoming(ptr1, currBB);
3534+
itr1->addIncoming(new_ptr1, loopBB);
3535+
3536+
Value *new_ptr2 = ctx.builder.CreateConstInBoundsGEP1_32(TInt8, itr2, desc.data_bytes + desc.padding_bytes);
3537+
itr2->addIncoming(ptr2, currBB);
3538+
itr2->addIncoming(new_ptr2, loopBB);
3539+
3540+
ptr1 = itr1;
3541+
ptr2 = itr2;
3542+
}
3543+
3544+
// Emit memcmp. TODO: LLVM has a pass to expand this for additional
3545+
// performance.
3546+
Value *this_answer = ctx.builder.CreateCall(prepare_call(memcmp_func),
3547+
{ ptr1,
3548+
ptr2,
3549+
ConstantInt::get(ctx.types().T_size, desc.data_bytes) });
3550+
this_answer = ctx.builder.CreateICmpEQ(this_answer, ConstantInt::get(getInt32Ty(ctx.builder.getContext()), 0));
3551+
answer = answer ? ctx.builder.CreateAnd(answer, this_answer) : this_answer;
3552+
if (endptr1) {
3553+
answerphi->addIncoming(answer, loopBB);
3554+
Value *loopend = ctx.builder.CreateICmpEQ(new_ptr1, endptr1);
3555+
ctx.builder.CreateCondBr(loopend, postBB, loopBB);
3556+
ctx.builder.SetInsertPoint(postBB);
3557+
}
3558+
};
3559+
egal_desc current_desc = {0};
3560+
size_t trailing_data_bytes = emit_masked_bits_compare(emit_desc, sty, current_desc);
3561+
assert(current_desc.nrepeats != 0);
3562+
emit_desc(current_desc);
3563+
if (trailing_data_bytes != 0) {
3564+
current_desc.nrepeats = 1;
3565+
current_desc.data_bytes = trailing_data_bytes;
3566+
current_desc.padding_bytes = 0;
3567+
emit_desc(current_desc);
3568+
}
3569+
return answer;
3570+
}
34363571
else {
34373572
jl_svec_t *types = sty->types;
34383573
Value *answer = ConstantInt::get(getInt1Ty(ctx.builder.getContext()), 1);

src/datatype.c

+27-13
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@ static jl_datatype_layout_t *jl_get_layout(uint32_t sz,
180180
uint32_t npointers,
181181
uint32_t alignment,
182182
int haspadding,
183+
int isbitsegal,
183184
int arrayelem,
184185
jl_fielddesc32_t desc[],
185186
uint32_t pointers[]) JL_NOTSAFEPOINT
@@ -226,6 +227,7 @@ static jl_datatype_layout_t *jl_get_layout(uint32_t sz,
226227
flddesc->nfields = nfields;
227228
flddesc->alignment = alignment;
228229
flddesc->flags.haspadding = haspadding;
230+
flddesc->flags.isbitsegal = isbitsegal;
229231
flddesc->flags.fielddesc_type = fielddesc_type;
230232
flddesc->flags.arrayelem_isboxed = arrayelem == 1;
231233
flddesc->flags.arrayelem_isunion = arrayelem == 2;
@@ -504,6 +506,7 @@ void jl_get_genericmemory_layout(jl_datatype_t *st)
504506
int isunboxed = jl_islayout_inline(eltype, &elsz, &al) && (kind != (jl_value_t*)jl_atomic_sym || jl_is_datatype(eltype));
505507
int isunion = isunboxed && jl_is_uniontype(eltype);
506508
int haspadding = 1; // we may want to eventually actually compute this more precisely
509+
int isbitsegal = 0;
507510
int nfields = 0; // aka jl_is_layout_opaque
508511
int npointers = 1;
509512
int zi;
@@ -562,7 +565,7 @@ void jl_get_genericmemory_layout(jl_datatype_t *st)
562565
else
563566
arrayelem = 0;
564567
assert(!st->layout);
565-
st->layout = jl_get_layout(elsz, nfields, npointers, al, haspadding, arrayelem, NULL, pointers);
568+
st->layout = jl_get_layout(elsz, nfields, npointers, al, haspadding, isbitsegal, arrayelem, NULL, pointers);
566569
st->zeroinit = zi;
567570
//st->has_concrete_subtype = 1;
568571
//st->isbitstype = 0;
@@ -621,18 +624,17 @@ void jl_compute_field_offsets(jl_datatype_t *st)
621624
// if we have no fields, we can trivially skip the rest
622625
if (st == jl_symbol_type || st == jl_string_type) {
623626
// opaque layout - heap-allocated blob
624-
static const jl_datatype_layout_t opaque_byte_layout = {0, 0, 1, -1, 1, {0}};
627+
static const jl_datatype_layout_t opaque_byte_layout = {0, 0, 1, -1, 1, { .haspadding = 0, .fielddesc_type=0, .isbitsegal=1, .arrayelem_isboxed=0, .arrayelem_isunion=0 }};
625628
st->layout = &opaque_byte_layout;
626629
return;
627630
}
628631
else if (st == jl_simplevector_type || st == jl_module_type) {
629-
static const jl_datatype_layout_t opaque_ptr_layout = {0, 0, 1, -1, sizeof(void*), {0}};
632+
static const jl_datatype_layout_t opaque_ptr_layout = {0, 0, 1, -1, sizeof(void*), { .haspadding = 0, .fielddesc_type=0, .isbitsegal=1, .arrayelem_isboxed=0, .arrayelem_isunion=0 }};
630633
st->layout = &opaque_ptr_layout;
631634
return;
632635
}
633636
else {
634-
// reuse the same layout for all singletons
635-
static const jl_datatype_layout_t singleton_layout = {0, 0, 0, -1, 1, {0}};
637+
static const jl_datatype_layout_t singleton_layout = {0, 0, 0, -1, 1, { .haspadding = 0, .fielddesc_type=0, .isbitsegal=1, .arrayelem_isboxed=0, .arrayelem_isunion=0 }};
636638
st->layout = &singleton_layout;
637639
}
638640
}
@@ -673,6 +675,7 @@ void jl_compute_field_offsets(jl_datatype_t *st)
673675
size_t alignm = 1;
674676
int zeroinit = 0;
675677
int haspadding = 0;
678+
int isbitsegal = 1;
676679
int homogeneous = 1;
677680
int needlock = 0;
678681
uint32_t npointers = 0;
@@ -687,19 +690,30 @@ void jl_compute_field_offsets(jl_datatype_t *st)
687690
throw_ovf(should_malloc, desc, st, fsz);
688691
desc[i].isptr = 0;
689692
if (jl_is_uniontype(fld)) {
690-
haspadding = 1;
691693
fsz += 1; // selector byte
692694
zeroinit = 1;
695+
// TODO: Some unions could be bits comparable.
696+
isbitsegal = 0;
693697
}
694698
else {
695699
uint32_t fld_npointers = ((jl_datatype_t*)fld)->layout->npointers;
696700
if (((jl_datatype_t*)fld)->layout->flags.haspadding)
697701
haspadding = 1;
702+
if (!((jl_datatype_t*)fld)->layout->flags.isbitsegal)
703+
isbitsegal = 0;
698704
if (i >= nfields - st->name->n_uninitialized && fld_npointers &&
699705
fld_npointers * sizeof(void*) != fsz) {
700-
// field may be undef (may be uninitialized and contains pointer),
701-
// and contains non-pointer fields of non-zero sizes.
702-
haspadding = 1;
706+
// For field types that contain pointers, we allow inlinealloc
707+
// as long as the field type itself is always fully initialized.
708+
// In such a case, we use the first pointer in the inlined field
709+
// as the #undef marker (if it is zero, we treat the whole inline
710+
// struct as #undef). However, we do not zero-initialize the whole
711+
// struct, so the non-pointer parts of the inline allocation may
712+
be arbitrary, but still need to compare egal (because all #undef
713+
representations are egal). Because of this, we cannot bitscompare
714+
// them.
715+
// TODO: Consider zero-initializing the whole struct.
716+
isbitsegal = 0;
703717
}
704718
if (!zeroinit)
705719
zeroinit = ((jl_datatype_t*)fld)->zeroinit;
@@ -715,8 +729,7 @@ void jl_compute_field_offsets(jl_datatype_t *st)
715729
zeroinit = 1;
716730
npointers++;
717731
if (!jl_pointer_egal(fld)) {
718-
// this somewhat poorly named flag says whether some of the bits can be non-unique
719-
haspadding = 1;
732+
isbitsegal = 0;
720733
}
721734
}
722735
if (isatomic && fsz > MAX_ATOMIC_SIZE)
@@ -777,7 +790,7 @@ void jl_compute_field_offsets(jl_datatype_t *st)
777790
}
778791
}
779792
assert(ptr_i == npointers);
780-
st->layout = jl_get_layout(sz, nfields, npointers, alignm, haspadding, 0, desc, pointers);
793+
st->layout = jl_get_layout(sz, nfields, npointers, alignm, haspadding, isbitsegal, 0, desc, pointers);
781794
if (should_malloc) {
782795
free(desc);
783796
if (npointers)
@@ -931,7 +944,7 @@ JL_DLLEXPORT jl_datatype_t *jl_new_primitivetype(jl_value_t *name, jl_module_t *
931944
bt->ismutationfree = 1;
932945
bt->isidentityfree = 1;
933946
bt->isbitstype = (parameters == jl_emptysvec);
934-
bt->layout = jl_get_layout(nbytes, 0, 0, alignm, 0, 0, NULL, NULL);
947+
bt->layout = jl_get_layout(nbytes, 0, 0, alignm, 0, 1, 0, NULL, NULL);
935948
bt->instance = NULL;
936949
return bt;
937950
}
@@ -954,6 +967,7 @@ JL_DLLEXPORT jl_datatype_t * jl_new_foreign_type(jl_sym_t *name,
954967
layout->alignment = sizeof(void *);
955968
layout->npointers = haspointers;
956969
layout->flags.haspadding = 1;
970+
layout->flags.isbitsegal = 0;
957971
layout->flags.fielddesc_type = 3;
958972
layout->flags.padding = 0;
959973
layout->flags.arrayelem_isboxed = 0;

src/julia.h

+4-1
Original file line numberDiff line numberDiff line change
@@ -574,7 +574,10 @@ typedef struct {
574574
// metadata bit only for GenericMemory eltype layout
575575
uint16_t arrayelem_isboxed : 1;
576576
uint16_t arrayelem_isunion : 1;
577-
uint16_t padding : 11;
577+
// If set, this type's egality can be determined entirely by comparing
578+
// the non-padding bits of this datatype.
579+
uint16_t isbitsegal : 1;
580+
uint16_t padding : 10;
578581
} flags;
579582
// union {
580583
// jl_fielddesc8_t field8[nfields];

0 commit comments

Comments
 (0)