Skip to content

Commit 8e06d6a

Browse files
Lancernlanza
authored andcommittedMar 14, 2025
[CIR] Add select operation (#796)
This PR adds a new `cir.select` operation. This operation won't be generated directly by CIRGen but it is useful during further CIR to CIR transformations. This PR addresses #785 .
1 parent 029a126 commit 8e06d6a

File tree

7 files changed

+216
-15
lines changed

7 files changed

+216
-15
lines changed
 

‎clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h

+18
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,24 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
274274
return createBinop(lhs, mlir::cir::BinOpKind::Mul, val);
275275
}
276276

277+
mlir::Value createSelect(mlir::Location loc, mlir::Value condition,
278+
mlir::Value trueValue, mlir::Value falseValue) {
279+
assert(trueValue.getType() == falseValue.getType() &&
280+
"trueValue and falseValue should have the same type");
281+
return create<mlir::cir::SelectOp>(loc, trueValue.getType(), condition,
282+
trueValue, falseValue);
283+
}
284+
285+
mlir::Value createLogicalAnd(mlir::Location loc, mlir::Value lhs,
286+
mlir::Value rhs) {
287+
return createSelect(loc, lhs, rhs, getBool(false, loc));
288+
}
289+
290+
mlir::Value createLogicalOr(mlir::Location loc, mlir::Value lhs,
291+
mlir::Value rhs) {
292+
return createSelect(loc, lhs, getBool(true, loc), rhs);
293+
}
294+
277295
mlir::Value createComplexCreate(mlir::Location loc, mlir::Value real,
278296
mlir::Value imag) {
279297
auto resultComplexTy =

‎clang/include/clang/CIR/Dialect/IR/CIROps.td

+40
Original file line numberDiff line numberDiff line change
@@ -761,6 +761,46 @@ def TernaryOp : CIR_Op<"ternary",
761761
}];
762762
}
763763

764+
//===----------------------------------------------------------------------===//
765+
// SelectOp
766+
//===----------------------------------------------------------------------===//
767+
768+
def SelectOp : CIR_Op<"select", [Pure,
769+
AllTypesMatch<["true_value", "false_value", "result"]>]> {
770+
let summary = "Yield one of two values based on a boolean value";
771+
let description = [{
772+
The `cir.select` operation takes three operands. The first operand
773+
`condition` is a boolean value of type `!cir.bool`. The second and the third
774+
operand can be of any CIR types, but their types must be the same. If the
775+
first operand is `true`, the operation yields its second operand. Otherwise,
776+
the operation yields its third operand.
777+
778+
Example:
779+
780+
```mlir
781+
%0 = cir.const #cir.bool<true> : !cir.bool
782+
%1 = cir.const #cir.int<42> : !s32i
783+
%2 = cir.const #cir.int<72> : !s32i
784+
%3 = cir.select if %0 then %1 else %2 : (!cir.bool, !s32i, !s32i) -> !s32i
785+
```
786+
}];
787+
788+
let arguments = (ins CIR_BoolType:$condition, CIR_AnyType:$true_value,
789+
CIR_AnyType:$false_value);
790+
let results = (outs CIR_AnyType:$result);
791+
792+
let assemblyFormat = [{
793+
`if` $condition `then` $true_value `else` $false_value
794+
`:` `(`
795+
qualified(type($condition)) `,`
796+
qualified(type($true_value)) `,`
797+
qualified(type($false_value))
798+
`)` `->` qualified(type($result)) attr-dict
799+
}];
800+
801+
let hasFolder = 1;
802+
}
803+
764804
//===----------------------------------------------------------------------===//
765805
// ConditionOp
766806
//===----------------------------------------------------------------------===//

‎clang/lib/CIR/Dialect/IR/CIRDialect.cpp

+13
Original file line numberDiff line numberDiff line change
@@ -1382,6 +1382,19 @@ void TernaryOp::build(OpBuilder &builder, OperationState &result, Value cond,
13821382
result.addTypes(TypeRange{yield.getOperandTypes().front()});
13831383
}
13841384

1385+
//===----------------------------------------------------------------------===//
1386+
// SelectOp
1387+
//===----------------------------------------------------------------------===//
1388+
1389+
OpFoldResult SelectOp::fold(FoldAdaptor adaptor) {
1390+
auto condition = adaptor.getCondition();
1391+
if (!condition)
1392+
return nullptr;
1393+
1394+
auto conditionValue = mlir::cast<mlir::cir::BoolAttr>(condition).getValue();
1395+
return conditionValue ? getTrueValue() : getFalseValue();
1396+
}
1397+
13851398
//===----------------------------------------------------------------------===//
13861399
// BrOp
13871400
//===----------------------------------------------------------------------===//

‎clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ void CIRSimplifyPass::runOnOperation() {
146146
getOperation()->walk([&](Operation *op) {
147147
// CastOp here is to perform a manual `fold` in
148148
// applyOpPatternsAndFold
149-
if (isa<BrOp, BrCondOp, ScopeOp, SwitchOp, CastOp, TryOp, UnaryOp,
149+
if (isa<BrOp, BrCondOp, ScopeOp, SwitchOp, CastOp, TryOp, UnaryOp, SelectOp,
150150
ComplexCreateOp, ComplexRealOp, ComplexImagOp>(op))
151151
ops.push_back(op);
152152
});

‎clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

+68-14
Original file line numberDiff line numberDiff line change
@@ -2988,6 +2988,60 @@ class CIRRotateOpLowering
29882988
}
29892989
};
29902990

2991+
class CIRSelectOpLowering
2992+
: public mlir::OpConversionPattern<mlir::cir::SelectOp> {
2993+
public:
2994+
using OpConversionPattern<mlir::cir::SelectOp>::OpConversionPattern;
2995+
2996+
mlir::LogicalResult
2997+
matchAndRewrite(mlir::cir::SelectOp op, OpAdaptor adaptor,
2998+
mlir::ConversionPatternRewriter &rewriter) const override {
2999+
auto getConstantBool = [](mlir::Value value) -> std::optional<bool> {
3000+
auto definingOp = mlir::dyn_cast_if_present<mlir::cir::ConstantOp>(
3001+
value.getDefiningOp());
3002+
if (!definingOp)
3003+
return std::nullopt;
3004+
3005+
auto constValue =
3006+
mlir::dyn_cast<mlir::cir::BoolAttr>(definingOp.getValue());
3007+
if (!constValue)
3008+
return std::nullopt;
3009+
3010+
return constValue.getValue();
3011+
};
3012+
3013+
// Two special cases in the LLVMIR codegen of select op:
3014+
// - select %0, %1, false => and %0, %1
3015+
// - select %0, true, %1 => or %0, %1
3016+
auto trueValue = op.getTrueValue();
3017+
auto falseValue = op.getFalseValue();
3018+
if (mlir::isa<mlir::cir::BoolType>(trueValue.getType())) {
3019+
if (std::optional<bool> falseValueBool = getConstantBool(falseValue);
3020+
falseValueBool.has_value() && !*falseValueBool) {
3021+
// select %0, %1, false => and %0, %1
3022+
rewriter.replaceOpWithNewOp<mlir::LLVM::AndOp>(
3023+
op, adaptor.getCondition(), adaptor.getTrueValue());
3024+
return mlir::success();
3025+
}
3026+
if (std::optional<bool> trueValueBool = getConstantBool(trueValue);
3027+
trueValueBool.has_value() && *trueValueBool) {
3028+
// select %0, true, %1 => or %0, %1
3029+
rewriter.replaceOpWithNewOp<mlir::LLVM::OrOp>(
3030+
op, adaptor.getCondition(), adaptor.getFalseValue());
3031+
return mlir::success();
3032+
}
3033+
}
3034+
3035+
auto llvmCondition = rewriter.create<mlir::LLVM::TruncOp>(
3036+
op.getLoc(), mlir::IntegerType::get(op->getContext(), 1),
3037+
adaptor.getCondition());
3038+
rewriter.replaceOpWithNewOp<mlir::LLVM::SelectOp>(
3039+
op, llvmCondition, adaptor.getTrueValue(), adaptor.getFalseValue());
3040+
3041+
return mlir::success();
3042+
}
3043+
};
3044+
29913045
class CIRBrOpLowering : public mlir::OpConversionPattern<mlir::cir::BrOp> {
29923046
public:
29933047
using OpConversionPattern<mlir::cir::BrOp>::OpConversionPattern;
@@ -3836,20 +3890,20 @@ void populateCIRToLLVMConversionPatterns(mlir::RewritePatternSet &patterns,
38363890
patterns.add<CIRReturnLowering>(patterns.getContext());
38373891
patterns.add<CIRAllocaLowering>(converter, dataLayout, patterns.getContext());
38383892
patterns.add<
3839-
CIRCmpOpLowering, CIRBitClrsbOpLowering, CIRBitClzOpLowering,
3840-
CIRBitCtzOpLowering, CIRBitFfsOpLowering, CIRBitParityOpLowering,
3841-
CIRBitPopcountOpLowering, CIRAtomicCmpXchgLowering, CIRAtomicXchgLowering,
3842-
CIRAtomicFetchLowering, CIRByteswapOpLowering, CIRRotateOpLowering,
3843-
CIRBrCondOpLowering, CIRPtrStrideOpLowering, CIRCallLowering,
3844-
CIRTryCallLowering, CIREhInflightOpLowering, CIRUnaryOpLowering,
3845-
CIRBinOpLowering, CIRBinOpOverflowOpLowering, CIRShiftOpLowering,
3846-
CIRLoadLowering, CIRConstantLowering, CIRStoreLowering, CIRFuncLowering,
3847-
CIRCastOpLowering, CIRGlobalOpLowering, CIRGetGlobalOpLowering,
3848-
CIRComplexCreateOpLowering, CIRComplexRealOpLowering,
3849-
CIRComplexImagOpLowering, CIRComplexRealPtrOpLowering,
3850-
CIRComplexImagPtrOpLowering, CIRVAStartLowering, CIRVAEndLowering,
3851-
CIRVACopyLowering, CIRVAArgLowering, CIRBrOpLowering,
3852-
CIRGetMemberOpLowering, CIRGetRuntimeMemberOpLowering,
3893+
CIRCmpOpLowering, CIRSelectOpLowering, CIRBitClrsbOpLowering,
3894+
CIRBitClzOpLowering, CIRBitCtzOpLowering, CIRBitFfsOpLowering,
3895+
CIRBitParityOpLowering, CIRBitPopcountOpLowering,
3896+
CIRAtomicCmpXchgLowering, CIRAtomicXchgLowering, CIRAtomicFetchLowering,
3897+
CIRByteswapOpLowering, CIRRotateOpLowering, CIRBrCondOpLowering,
3898+
CIRPtrStrideOpLowering, CIRCallLowering, CIRTryCallLowering,
3899+
CIREhInflightOpLowering, CIRUnaryOpLowering, CIRBinOpLowering,
3900+
CIRBinOpOverflowOpLowering, CIRShiftOpLowering, CIRLoadLowering,
3901+
CIRConstantLowering, CIRStoreLowering, CIRFuncLowering, CIRCastOpLowering,
3902+
CIRGlobalOpLowering, CIRGetGlobalOpLowering, CIRComplexCreateOpLowering,
3903+
CIRComplexRealOpLowering, CIRComplexImagOpLowering,
3904+
CIRComplexRealPtrOpLowering, CIRComplexImagPtrOpLowering,
3905+
CIRVAStartLowering, CIRVAEndLowering, CIRVACopyLowering, CIRVAArgLowering,
3906+
CIRBrOpLowering, CIRGetMemberOpLowering, CIRGetRuntimeMemberOpLowering,
38533907
CIRSwitchFlatOpLowering, CIRPtrDiffOpLowering, CIRCopyOpLowering,
38543908
CIRMemCpyOpLowering, CIRFAbsOpLowering, CIRExpectOpLowering,
38553909
CIRVTableAddrPointOpLowering, CIRVectorCreateLowering,

‎clang/test/CIR/Lowering/select.cir

+50
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
// RUN: cir-translate -cir-to-llvmir -o %t.ll %s
2+
// RUN: FileCheck --input-file=%t.ll -check-prefix=LLVM %s
3+
4+
!s32i = !cir.int<s, 32>
5+
6+
module {
7+
cir.func @select_int(%arg0 : !cir.bool, %arg1 : !s32i, %arg2 : !s32i) -> !s32i {
8+
%0 = cir.select if %arg0 then %arg1 else %arg2 : (!cir.bool, !s32i, !s32i) -> !s32i
9+
cir.return %0 : !s32i
10+
}
11+
12+
// LLVM: define i32 @select_int(i8 %[[#COND:]], i32 %[[#TV:]], i32 %[[#FV:]])
13+
// LLVM-NEXT: %[[#CONDF:]] = trunc i8 %[[#COND]] to i1
14+
// LLVM-NEXT: %[[#RES:]] = select i1 %[[#CONDF]], i32 %[[#TV]], i32 %[[#FV]]
15+
// LLVM-NEXT: ret i32 %[[#RES]]
16+
// LLVM-NEXT: }
17+
18+
cir.func @select_bool(%arg0 : !cir.bool, %arg1 : !cir.bool, %arg2 : !cir.bool) -> !cir.bool {
19+
%0 = cir.select if %arg0 then %arg1 else %arg2 : (!cir.bool, !cir.bool, !cir.bool) -> !cir.bool
20+
cir.return %0 : !cir.bool
21+
}
22+
23+
// LLVM: define i8 @select_bool(i8 %[[#COND:]], i8 %[[#TV:]], i8 %[[#FV:]])
24+
// LLVM-NEXT: %[[#CONDF:]] = trunc i8 %[[#COND]] to i1
25+
// LLVM-NEXT: %[[#RES:]] = select i1 %[[#CONDF]], i8 %[[#TV]], i8 %[[#FV]]
26+
// LLVM-NEXT: ret i8 %[[#RES]]
27+
// LLVM-NEXT: }
28+
29+
cir.func @logical_and(%arg0 : !cir.bool, %arg1 : !cir.bool) -> !cir.bool {
30+
%0 = cir.const #cir.bool<false> : !cir.bool
31+
%1 = cir.select if %arg0 then %arg1 else %0 : (!cir.bool, !cir.bool, !cir.bool) -> !cir.bool
32+
cir.return %1 : !cir.bool
33+
}
34+
35+
// LLVM: define i8 @logical_and(i8 %[[#ARG0:]], i8 %[[#ARG1:]])
36+
// LLVM-NEXT: %[[#RES:]] = and i8 %[[#ARG0]], %[[#ARG1]]
37+
// LLVM-NEXT: ret i8 %[[#RES]]
38+
// LLVM-NEXT: }
39+
40+
cir.func @logical_or(%arg0 : !cir.bool, %arg1 : !cir.bool) -> !cir.bool {
41+
%0 = cir.const #cir.bool<true> : !cir.bool
42+
%1 = cir.select if %arg0 then %0 else %arg1 : (!cir.bool, !cir.bool, !cir.bool) -> !cir.bool
43+
cir.return %1 : !cir.bool
44+
}
45+
46+
// LLVM: define i8 @logical_or(i8 %[[#ARG0:]], i8 %[[#ARG1:]])
47+
// LLVM-NEXT: %[[#RES:]] = or i8 %[[#ARG0]], %[[#ARG1]]
48+
// LLVM-NEXT: ret i8 %[[#RES]]
49+
// LLVM-NEXT: }
50+
}

‎clang/test/CIR/Transforms/select.cir

+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
// RUN: cir-opt --canonicalize -o %t.cir %s
2+
// RUN: FileCheck --input-file=%t.cir %s
3+
4+
!s32i = !cir.int<s, 32>
5+
6+
module {
7+
cir.func @fold_true(%arg0 : !s32i, %arg1 : !s32i) -> !s32i {
8+
%0 = cir.const #cir.bool<true> : !cir.bool
9+
%1 = cir.select if %0 then %arg0 else %arg1 : (!cir.bool, !s32i, !s32i) -> !s32i
10+
cir.return %1 : !s32i
11+
}
12+
13+
// CHECK: cir.func @fold_true(%[[ARG0:.+]]: !s32i, %[[ARG1:.+]]: !s32i) -> !s32i {
14+
// CHECK-NEXT: cir.return %[[ARG0]] : !s32i
15+
// CHECK-NEXT: }
16+
17+
cir.func @fold_false(%arg0 : !s32i, %arg1 : !s32i) -> !s32i {
18+
%0 = cir.const #cir.bool<false> : !cir.bool
19+
%1 = cir.select if %0 then %arg0 else %arg1 : (!cir.bool, !s32i, !s32i) -> !s32i
20+
cir.return %1 : !s32i
21+
}
22+
23+
// CHECK: cir.func @fold_false(%[[ARG0:.+]]: !s32i, %[[ARG1:.+]]: !s32i) -> !s32i {
24+
// CHECK-NEXT: cir.return %[[ARG1]] : !s32i
25+
// CHECK-NEXT: }
26+
}

0 commit comments

Comments
 (0)
Please sign in to comment.