Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 62d11b2

Browse files
committedMar 2, 2023
Revert "Revert "[SCEV] Add SCEVType to represent vscale.""
Relanding after fixing Polly related build error. This reverts commit 7b26dca.
1 parent c396073 commit 62d11b2

File tree

12 files changed

+97
-27
lines changed

12 files changed

+97
-27
lines changed
 

‎llvm/include/llvm/Analysis/ScalarEvolution.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -566,6 +566,7 @@ class ScalarEvolution {
566566
const SCEV *getLosslessPtrToIntExpr(const SCEV *Op, unsigned Depth = 0);
567567
const SCEV *getPtrToIntExpr(const SCEV *Op, Type *Ty);
568568
const SCEV *getTruncateExpr(const SCEV *Op, Type *Ty, unsigned Depth = 0);
569+
const SCEV *getVScale(Type *Ty);
569570
const SCEV *getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth = 0);
570571
const SCEV *getZeroExtendExprImpl(const SCEV *Op, Type *Ty,
571572
unsigned Depth = 0);

‎llvm/include/llvm/Analysis/ScalarEvolutionDivision.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ struct SCEVDivision : public SCEVVisitor<SCEVDivision, void> {
4848

4949
void visitConstant(const SCEVConstant *Numerator);
5050

51+
void visitVScale(const SCEVVScale *Numerator);
52+
5153
void visitAddRecExpr(const SCEVAddRecExpr *Numerator);
5254

5355
void visitAddExpr(const SCEVAddExpr *Numerator);

‎llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ enum SCEVTypes : unsigned short {
3939
// These should be ordered in terms of increasing complexity to make the
4040
// folders simpler.
4141
scConstant,
42+
scVScale,
4243
scTruncate,
4344
scZeroExtend,
4445
scSignExtend,
@@ -75,6 +76,23 @@ class SCEVConstant : public SCEV {
7576
static bool classof(const SCEV *S) { return S->getSCEVType() == scConstant; }
7677
};
7778

79+
/// This class represents the value of vscale, as used when defining the length
80+
/// of a scalable vector or returned by the llvm.vscale() intrinsic.
81+
class SCEVVScale : public SCEV {
82+
friend class ScalarEvolution;
83+
84+
SCEVVScale(const FoldingSetNodeIDRef ID, Type *ty)
85+
: SCEV(ID, scVScale, 0), Ty(ty) {}
86+
87+
Type *Ty;
88+
89+
public:
90+
Type *getType() const { return Ty; }
91+
92+
/// Methods for support type inquiry through isa, cast, and dyn_cast:
93+
static bool classof(const SCEV *S) { return S->getSCEVType() == scVScale; }
94+
};
95+
7896
inline unsigned short computeExpressionSize(ArrayRef<const SCEV *> Args) {
7997
APInt Size(16, 1);
8098
for (const auto *Arg : Args)
@@ -579,9 +597,6 @@ class SCEVUnknown final : public SCEV, private CallbackVH {
579597
public:
580598
Value *getValue() const { return getValPtr(); }
581599

582-
/// Check whether this represents vscale.
583-
bool isVScale() const;
584-
585600
Type *getType() const { return getValPtr()->getType(); }
586601

587602
/// Methods for support type inquiry through isa, cast, and dyn_cast:
@@ -595,6 +610,8 @@ template <typename SC, typename RetVal = void> struct SCEVVisitor {
595610
switch (S->getSCEVType()) {
596611
case scConstant:
597612
return ((SC *)this)->visitConstant((const SCEVConstant *)S);
613+
case scVScale:
614+
return ((SC *)this)->visitVScale((const SCEVVScale *)S);
598615
case scPtrToInt:
599616
return ((SC *)this)->visitPtrToIntExpr((const SCEVPtrToIntExpr *)S);
600617
case scTruncate:
@@ -662,6 +679,7 @@ template <typename SV> class SCEVTraversal {
662679

663680
switch (S->getSCEVType()) {
664681
case scConstant:
682+
case scVScale:
665683
case scUnknown:
666684
continue;
667685
case scPtrToInt:
@@ -751,6 +769,8 @@ class SCEVRewriteVisitor : public SCEVVisitor<SC, const SCEV *> {
751769

752770
const SCEV *visitConstant(const SCEVConstant *Constant) { return Constant; }
753771

772+
const SCEV *visitVScale(const SCEVVScale *VScale) { return VScale; }
773+
754774
const SCEV *visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) {
755775
const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand());
756776
return Operand == Expr->getOperand()

‎llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,8 @@ class SCEVExpander : public SCEVVisitor<SCEVExpander, Value *> {
457457

458458
Value *visitConstant(const SCEVConstant *S) { return S->getValue(); }
459459

460+
Value *visitVScale(const SCEVVScale *S);
461+
460462
Value *visitPtrToIntExpr(const SCEVPtrToIntExpr *S);
461463

462464
Value *visitTruncateExpr(const SCEVTruncateExpr *S);

‎llvm/lib/Analysis/ScalarEvolution.cpp

Lines changed: 43 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,9 @@ void SCEV::print(raw_ostream &OS) const {
271271
case scConstant:
272272
cast<SCEVConstant>(this)->getValue()->printAsOperand(OS, false);
273273
return;
274+
case scVScale:
275+
OS << "vscale";
276+
return;
274277
case scPtrToInt: {
275278
const SCEVPtrToIntExpr *PtrToInt = cast<SCEVPtrToIntExpr>(this);
276279
const SCEV *Op = PtrToInt->getOperand();
@@ -366,17 +369,9 @@ void SCEV::print(raw_ostream &OS) const {
366369
OS << "(" << *UDiv->getLHS() << " /u " << *UDiv->getRHS() << ")";
367370
return;
368371
}
369-
case scUnknown: {
370-
const SCEVUnknown *U = cast<SCEVUnknown>(this);
371-
if (U->isVScale()) {
372-
OS << "vscale";
373-
return;
374-
}
375-
376-
// Otherwise just print it normally.
377-
U->getValue()->printAsOperand(OS, false);
372+
case scUnknown:
373+
cast<SCEVUnknown>(this)->getValue()->printAsOperand(OS, false);
378374
return;
379-
}
380375
case scCouldNotCompute:
381376
OS << "***COULDNOTCOMPUTE***";
382377
return;
@@ -388,6 +383,8 @@ Type *SCEV::getType() const {
388383
switch (getSCEVType()) {
389384
case scConstant:
390385
return cast<SCEVConstant>(this)->getType();
386+
case scVScale:
387+
return cast<SCEVVScale>(this)->getType();
391388
case scPtrToInt:
392389
case scTruncate:
393390
case scZeroExtend:
@@ -419,6 +416,7 @@ Type *SCEV::getType() const {
419416
ArrayRef<const SCEV *> SCEV::operands() const {
420417
switch (getSCEVType()) {
421418
case scConstant:
419+
case scVScale:
422420
case scUnknown:
423421
return {};
424422
case scPtrToInt:
@@ -501,6 +499,18 @@ ScalarEvolution::getConstant(Type *Ty, uint64_t V, bool isSigned) {
501499
return getConstant(ConstantInt::get(ITy, V, isSigned));
502500
}
503501

502+
const SCEV *ScalarEvolution::getVScale(Type *Ty) {
503+
FoldingSetNodeID ID;
504+
ID.AddInteger(scVScale);
505+
ID.AddPointer(Ty);
506+
void *IP = nullptr;
507+
if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
508+
return S;
509+
SCEV *S = new (SCEVAllocator) SCEVVScale(ID.Intern(SCEVAllocator), Ty);
510+
UniqueSCEVs.InsertNode(S, IP);
511+
return S;
512+
}
513+
504514
SCEVCastExpr::SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy,
505515
const SCEV *op, Type *ty)
506516
: SCEV(ID, SCEVTy, computeExpressionSize(op)), Op(op), Ty(ty) {}
@@ -560,10 +570,6 @@ void SCEVUnknown::allUsesReplacedWith(Value *New) {
560570
setValPtr(New);
561571
}
562572

563-
bool SCEVUnknown::isVScale() const {
564-
return match(getValue(), m_VScale());
565-
}
566-
567573
//===----------------------------------------------------------------------===//
568574
// SCEV Utilities
569575
//===----------------------------------------------------------------------===//
@@ -714,6 +720,12 @@ CompareSCEVComplexity(EquivalenceClasses<const SCEV *> &EqCacheSCEV,
714720
return LA.ult(RA) ? -1 : 1;
715721
}
716722

723+
case scVScale: {
724+
const auto *LTy = cast<IntegerType>(cast<SCEVVScale>(LHS)->getType());
725+
const auto *RTy = cast<IntegerType>(cast<SCEVVScale>(RHS)->getType());
726+
return LTy->getBitWidth() - RTy->getBitWidth();
727+
}
728+
717729
case scAddRecExpr: {
718730
const SCEVAddRecExpr *LA = cast<SCEVAddRecExpr>(LHS);
719731
const SCEVAddRecExpr *RA = cast<SCEVAddRecExpr>(RHS);
@@ -4015,6 +4027,8 @@ class SCEVSequentialMinMaxDeduplicatingVisitor final
40154027

40164028
RetVal visitConstant(const SCEVConstant *Constant) { return Constant; }
40174029

4030+
RetVal visitVScale(const SCEVVScale *VScale) { return VScale; }
4031+
40184032
RetVal visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) { return Expr; }
40194033

40204034
RetVal visitTruncateExpr(const SCEVTruncateExpr *Expr) { return Expr; }
@@ -4061,6 +4075,7 @@ class SCEVSequentialMinMaxDeduplicatingVisitor final
40614075
static bool scevUnconditionallyPropagatesPoisonFromOperands(SCEVTypes Kind) {
40624076
switch (Kind) {
40634077
case scConstant:
4078+
case scVScale:
40644079
case scTruncate:
40654080
case scZeroExtend:
40664081
case scSignExtend:
@@ -4104,6 +4119,7 @@ static bool impliesPoison(const SCEV *AssumedPoison, const SCEV *S) {
41044119
if (!scevUnconditionallyPropagatesPoisonFromOperands(S->getSCEVType())) {
41054120
switch (S->getSCEVType()) {
41064121
case scConstant:
4122+
case scVScale:
41074123
case scTruncate:
41084124
case scZeroExtend:
41094125
case scSignExtend:
@@ -4315,15 +4331,8 @@ const SCEV *ScalarEvolution::getUMinExpr(SmallVectorImpl<const SCEV *> &Ops,
43154331
const SCEV *
43164332
ScalarEvolution::getSizeOfExpr(Type *IntTy, TypeSize Size) {
43174333
const SCEV *Res = getConstant(IntTy, Size.getKnownMinValue());
4318-
if (Size.isScalable()) {
4319-
// TODO: Why is there no ConstantExpr::getVScale()?
4320-
Type *SrcElemTy = ScalableVectorType::get(Type::getInt8Ty(getContext()), 1);
4321-
Constant *NullPtr = Constant::getNullValue(SrcElemTy->getPointerTo());
4322-
Constant *One = ConstantInt::get(IntTy, 1);
4323-
Constant *GEP = ConstantExpr::getGetElementPtr(SrcElemTy, NullPtr, One);
4324-
Constant *VScale = ConstantExpr::getPtrToInt(GEP, IntTy);
4325-
Res = getMulExpr(Res, getUnknown(VScale));
4326-
}
4334+
if (Size.isScalable())
4335+
Res = getMulExpr(Res, getVScale(IntTy));
43274336
return Res;
43284337
}
43294338

@@ -5887,6 +5896,7 @@ static bool IsAvailableOnEntry(const Loop *L, DominatorTree &DT, const SCEV *S,
58875896
bool follow(const SCEV *S) {
58885897
switch (S->getSCEVType()) {
58895898
case scConstant:
5899+
case scVScale:
58905900
case scPtrToInt:
58915901
case scTruncate:
58925902
case scZeroExtend:
@@ -6274,6 +6284,8 @@ uint32_t ScalarEvolution::GetMinTrailingZerosImpl(const SCEV *S) {
62746284
switch (S->getSCEVType()) {
62756285
case scConstant:
62766286
return cast<SCEVConstant>(S)->getAPInt().countr_zero();
6287+
case scVScale:
6288+
return 0;
62776289
case scTruncate: {
62786290
const SCEVTruncateExpr *T = cast<SCEVTruncateExpr>(S);
62796291
return std::min(GetMinTrailingZeros(T->getOperand()),
@@ -6504,6 +6516,7 @@ ScalarEvolution::getRangeRefIter(const SCEV *S,
65046516
break;
65056517
[[fallthrough]];
65066518
case scConstant:
6519+
case scVScale:
65076520
case scTruncate:
65086521
case scZeroExtend:
65096522
case scSignExtend:
@@ -6607,6 +6620,8 @@ const ConstantRange &ScalarEvolution::getRangeRef(
66076620
switch (S->getSCEVType()) {
66086621
case scConstant:
66096622
llvm_unreachable("Already handled above.");
6623+
case scVScale:
6624+
return setRange(S, SignHint, std::move(ConservativeResult));
66106625
case scTruncate: {
66116626
const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(S);
66126627
ConstantRange X = getRangeRef(Trunc->getOperand(), SignHint, Depth + 1);
@@ -9711,6 +9726,7 @@ static Constant *BuildConstantFromSCEV(const SCEV *V) {
97119726
switch (V->getSCEVType()) {
97129727
case scCouldNotCompute:
97139728
case scAddRecExpr:
9729+
case scVScale:
97149730
return nullptr;
97159731
case scConstant:
97169732
return cast<SCEVConstant>(V)->getValue();
@@ -9794,6 +9810,7 @@ static Constant *BuildConstantFromSCEV(const SCEV *V) {
97949810
const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
97959811
switch (V->getSCEVType()) {
97969812
case scConstant:
9813+
case scVScale:
97979814
return V;
97989815
case scAddRecExpr: {
97999816
// If this is a loop recurrence for a loop that does not contain L, then we
@@ -9892,6 +9909,7 @@ const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
98929909
case scSequentialUMinExpr:
98939910
return getSequentialMinMaxExpr(V->getSCEVType(), NewOps);
98949911
case scConstant:
9912+
case scVScale:
98959913
case scAddRecExpr:
98969914
case scUnknown:
98979915
case scCouldNotCompute:
@@ -13677,6 +13695,7 @@ ScalarEvolution::LoopDisposition
1367713695
ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) {
1367813696
switch (S->getSCEVType()) {
1367913697
case scConstant:
13698+
case scVScale:
1368013699
return LoopInvariant;
1368113700
case scAddRecExpr: {
1368213701
const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
@@ -13775,6 +13794,7 @@ ScalarEvolution::BlockDisposition
1377513794
ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) {
1377613795
switch (S->getSCEVType()) {
1377713796
case scConstant:
13797+
case scVScale:
1377813798
return ProperlyDominatesBlock;
1377913799
case scAddRecExpr: {
1378013800
// This uses a "dominates" query instead of "properly dominates" query

‎llvm/lib/Analysis/ScalarEvolutionDivision.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,10 @@ void SCEVDivision::visitConstant(const SCEVConstant *Numerator) {
126126
}
127127
}
128128

129+
void SCEVDivision::visitVScale(const SCEVVScale *Numerator) {
130+
return cannotDivide(Numerator);
131+
}
132+
129133
void SCEVDivision::visitAddRecExpr(const SCEVAddRecExpr *Numerator) {
130134
const SCEV *StartQ, *StartR, *StepQ, *StepR;
131135
if (!Numerator->isAffine())

‎llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -976,6 +976,7 @@ static bool isHighCostExpansion(const SCEV *S,
976976
switch (S->getSCEVType()) {
977977
case scUnknown:
978978
case scConstant:
979+
case scVScale:
979980
return false;
980981
case scTruncate:
981982
return isHighCostExpansion(cast<SCEVTruncateExpr>(S)->getOperand(),
@@ -2812,9 +2813,10 @@ static bool isCompatibleIVType(Value *LVal, Value *RVal) {
28122813
/// SCEVUnknown, we simply return the rightmost SCEV operand.
28132814
static const SCEV *getExprBase(const SCEV *S) {
28142815
switch (S->getSCEVType()) {
2815-
default: // uncluding scUnknown.
2816+
default: // including scUnknown.
28162817
return S;
28172818
case scConstant:
2819+
case scVScale:
28182820
return nullptr;
28192821
case scTruncate:
28202822
return getExprBase(cast<SCEVTruncateExpr>(S)->getOperand());

‎llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -680,6 +680,7 @@ const Loop *SCEVExpander::getRelevantLoop(const SCEV *S) {
680680

681681
switch (S->getSCEVType()) {
682682
case scConstant:
683+
case scVScale:
683684
return nullptr; // A constant has no relevant loops.
684685
case scTruncate:
685686
case scZeroExtend:
@@ -1744,6 +1745,10 @@ Value *SCEVExpander::visitSequentialUMinExpr(const SCEVSequentialUMinExpr *S) {
17441745
return expandMinMaxExpr(S, Intrinsic::umin, "umin", /*IsSequential*/true);
17451746
}
17461747

1748+
Value *SCEVExpander::visitVScale(const SCEVVScale *S) {
1749+
return Builder.CreateVScale(ConstantInt::get(S->getType(), 1));
1750+
}
1751+
17471752
Value *SCEVExpander::expandCodeForImpl(const SCEV *SH, Type *Ty,
17481753
Instruction *IP) {
17491754
setInsertPoint(IP);
@@ -2124,6 +2129,7 @@ template<typename T> static InstructionCost costAndCollectOperands(
21242129
llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
21252130
case scUnknown:
21262131
case scConstant:
2132+
case scVScale:
21272133
return 0;
21282134
case scPtrToInt:
21292135
Cost = CastCost(Instruction::PtrToInt);
@@ -2260,6 +2266,7 @@ bool SCEVExpander::isHighCostExpansionHelper(
22602266
case scCouldNotCompute:
22612267
llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
22622268
case scUnknown:
2269+
case scVScale:
22632270
// Assume to be zero-cost.
22642271
return false;
22652272
case scConstant: {

‎polly/include/polly/Support/SCEVAffinator.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ class SCEVAffinator final : public llvm::SCEVVisitor<SCEVAffinator, PWACtx> {
9999

100100
PWACtx visit(const llvm::SCEV *E);
101101
PWACtx visitConstant(const llvm::SCEVConstant *E);
102+
PWACtx visitVScale(const llvm::SCEVVScale *E);
102103
PWACtx visitPtrToIntExpr(const llvm::SCEVPtrToIntExpr *E);
103104
PWACtx visitTruncateExpr(const llvm::SCEVTruncateExpr *E);
104105
PWACtx visitZeroExtendExpr(const llvm::SCEVZeroExtendExpr *E);

‎polly/lib/Support/SCEVAffinator.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,10 @@ PWACtx SCEVAffinator::visitConstant(const SCEVConstant *Expr) {
266266
isl::manage(isl_pw_aff_from_aff(isl_aff_val_on_domain(ls, v))));
267267
}
268268

269+
PWACtx SCEVAffinator::visitVScale(const SCEVVScale *VScale) {
270+
llvm_unreachable("SCEVVScale not yet supported");
271+
}
272+
269273
PWACtx SCEVAffinator::visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) {
270274
return visit(Expr->getOperand(0));
271275
}

‎polly/lib/Support/SCEVValidator.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,12 @@ class SCEVValidator : public SCEVVisitor<SCEVValidator, ValidatorResult> {
134134
return ValidatorResult(SCEVType::INT);
135135
}
136136

137+
ValidatorResult visitVScale(const SCEVVScale *VScale) {
138+
// We do not support VScale constants.
139+
LLVM_DEBUG(dbgs() << "INVALID: VScale is not supported");
140+
return ValidatorResult(SCEVType::INVALID);
141+
}
142+
137143
ValidatorResult visitZeroExtendOrTruncateExpr(const SCEV *Expr,
138144
const SCEV *Operand) {
139145
ValidatorResult Op = visit(Operand);

0 commit comments

Comments
 (0)
Please sign in to comment.