Skip to content

File tree

7 files changed

+216
-15
lines changed

7 files changed

+216
-15
lines changed
 

‎clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h

+18
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,24 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
274274
return createBinop(lhs, mlir::cir::BinOpKind::Mul, val);
275275
}
276276

277+
mlir::Value createSelect(mlir::Location loc, mlir::Value condition,
278+
mlir::Value trueValue, mlir::Value falseValue) {
279+
assert(trueValue.getType() == falseValue.getType() &&
280+
"trueValue and falseValue should have the same type");
281+
return create<mlir::cir::SelectOp>(loc, trueValue.getType(), condition,
282+
trueValue, falseValue);
283+
}
284+
285+
mlir::Value createLogicalAnd(mlir::Location loc, mlir::Value lhs,
286+
mlir::Value rhs) {
287+
return createSelect(loc, lhs, rhs, getBool(false, loc));
288+
}
289+
290+
mlir::Value createLogicalOr(mlir::Location loc, mlir::Value lhs,
291+
mlir::Value rhs) {
292+
return createSelect(loc, lhs, getBool(true, loc), rhs);
293+
}
294+
277295
mlir::Value createComplexCreate(mlir::Location loc, mlir::Value real,
278296
mlir::Value imag) {
279297
auto resultComplexTy =

‎clang/include/clang/CIR/Dialect/IR/CIROps.td

+40
Original file line numberDiff line numberDiff line change
@@ -761,6 +761,46 @@ def TernaryOp : CIR_Op<"ternary",
761761
}];
762762
}
763763

764+
//===----------------------------------------------------------------------===//
765+
// SelectOp
766+
//===----------------------------------------------------------------------===//
767+
768+
def SelectOp : CIR_Op<"select", [Pure,
769+
AllTypesMatch<["true_value", "false_value", "result"]>]> {
770+
let summary = "Yield one of two values based on a boolean value";
771+
let description = [{
772+
The `cir.select` operation takes three operands. The first operand
773+
`condition` is a boolean value of type `!cir.bool`. The second and the third
774+
operand can be of any CIR types, but their types must be the same. If the
775+
first operand is `true`, the operation yields its second operand. Otherwise,
776+
the operation yields its third operand.
777+
778+
Example:
779+
780+
```mlir
781+
%0 = cir.const #cir.bool<true> : !cir.bool
782+
%1 = cir.const #cir.int<42> : !s32i
783+
%2 = cir.const #cir.int<72> : !s32i
784+
%3 = cir.select if %0 then %1 else %2 : (!cir.bool, !s32i, !s32i) -> !s32i
785+
```
786+
}];
787+
788+
let arguments = (ins CIR_BoolType:$condition, CIR_AnyType:$true_value,
789+
CIR_AnyType:$false_value);
790+
let results = (outs CIR_AnyType:$result);
791+
792+
let assemblyFormat = [{
793+
`if` $condition `then` $true_value `else` $false_value
794+
`:` `(`
795+
qualified(type($condition)) `,`
796+
qualified(type($true_value)) `,`
797+
qualified(type($false_value))
798+
`)` `->` qualified(type($result)) attr-dict
799+
}];
800+
801+
let hasFolder = 1;
802+
}
803+
764804
//===----------------------------------------------------------------------===//
765805
// ConditionOp
766806
//===----------------------------------------------------------------------===//

‎clang/lib/CIR/Dialect/IR/CIRDialect.cpp

+13
Original file line numberDiff line numberDiff line change
@@ -1382,6 +1382,19 @@ void TernaryOp::build(OpBuilder &builder, OperationState &result, Value cond,
13821382
result.addTypes(TypeRange{yield.getOperandTypes().front()});
13831383
}
13841384

1385+
//===----------------------------------------------------------------------===//
1386+
// SelectOp
1387+
//===----------------------------------------------------------------------===//
1388+
1389+
OpFoldResult SelectOp::fold(FoldAdaptor adaptor) {
1390+
auto condition = adaptor.getCondition();
1391+
if (!condition)
1392+
return nullptr;
1393+
1394+
auto conditionValue = mlir::cast<mlir::cir::BoolAttr>(condition).getValue();
1395+
return conditionValue ? getTrueValue() : getFalseValue();
1396+
}
1397+
13851398
//===----------------------------------------------------------------------===//
13861399
// BrOp
13871400
//===----------------------------------------------------------------------===//

‎clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ void CIRSimplifyPass::runOnOperation() {
146146
getOperation()->walk([&](Operation *op) {
147147
// CastOp here is to perform a manual `fold` in
148148
// applyOpPatternsAndFold
149-
if (isa<BrOp, BrCondOp, ScopeOp, SwitchOp, CastOp, TryOp, UnaryOp,
149+
if (isa<BrOp, BrCondOp, ScopeOp, SwitchOp, CastOp, TryOp, UnaryOp, SelectOp,
150150
ComplexCreateOp, ComplexRealOp, ComplexImagOp>(op))
151151
ops.push_back(op);
152152
});

‎clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

+68-14
Original file line numberDiff line numberDiff line change
@@ -2987,6 +2987,60 @@ class CIRRotateOpLowering
29872987
}
29882988
};
29892989

2990+
class CIRSelectOpLowering
2991+
: public mlir::OpConversionPattern<mlir::cir::SelectOp> {
2992+
public:
2993+
using OpConversionPattern<mlir::cir::SelectOp>::OpConversionPattern;
2994+
2995+
mlir::LogicalResult
2996+
matchAndRewrite(mlir::cir::SelectOp op, OpAdaptor adaptor,
2997+
mlir::ConversionPatternRewriter &rewriter) const override {
2998+
auto getConstantBool = [](mlir::Value value) -> std::optional<bool> {
2999+
auto definingOp = mlir::dyn_cast_if_present<mlir::cir::ConstantOp>(
3000+
value.getDefiningOp());
3001+
if (!definingOp)
3002+
return std::nullopt;
3003+
3004+
auto constValue =
3005+
mlir::dyn_cast<mlir::cir::BoolAttr>(definingOp.getValue());
3006+
if (!constValue)
3007+
return std::nullopt;
3008+
3009+
return constValue.getValue();
3010+
};
3011+
3012+
// Two special cases in the LLVMIR codegen of select op:
3013+
// - select %0, %1, false => and %0, %1
3014+
// - select %0, true, %1 => or %0, %1
3015+
auto trueValue = op.getTrueValue();
3016+
auto falseValue = op.getFalseValue();
3017+
if (mlir::isa<mlir::cir::BoolType>(trueValue.getType())) {
3018+
if (std::optional<bool> falseValueBool = getConstantBool(falseValue);
3019+
falseValueBool.has_value() && !*falseValueBool) {
3020+
// select %0, %1, false => and %0, %1
3021+
rewriter.replaceOpWithNewOp<mlir::LLVM::AndOp>(
3022+
op, adaptor.getCondition(), adaptor.getTrueValue());
3023+
return mlir::success();
3024+
}
3025+
if (std::optional<bool> trueValueBool = getConstantBool(trueValue);
3026+
trueValueBool.has_value() && *trueValueBool) {
3027+
// select %0, true, %1 => or %0, %1
3028+
rewriter.replaceOpWithNewOp<mlir::LLVM::OrOp>(
3029+
op, adaptor.getCondition(), adaptor.getFalseValue());
3030+
return mlir::success();
3031+
}
3032+
}
3033+
3034+
auto llvmCondition = rewriter.create<mlir::LLVM::TruncOp>(
3035+
op.getLoc(), mlir::IntegerType::get(op->getContext(), 1),
3036+
adaptor.getCondition());
3037+
rewriter.replaceOpWithNewOp<mlir::LLVM::SelectOp>(
3038+
op, llvmCondition, adaptor.getTrueValue(), adaptor.getFalseValue());
3039+
3040+
return mlir::success();
3041+
}
3042+
};
3043+
29903044
class CIRBrOpLowering : public mlir::OpConversionPattern<mlir::cir::BrOp> {
29913045
public:
29923046
using OpConversionPattern<mlir::cir::BrOp>::OpConversionPattern;
@@ -3835,20 +3889,20 @@ void populateCIRToLLVMConversionPatterns(mlir::RewritePatternSet &patterns,
38353889
patterns.add<CIRReturnLowering>(patterns.getContext());
38363890
patterns.add<CIRAllocaLowering>(converter, dataLayout, patterns.getContext());
38373891
patterns.add<
3838-
CIRCmpOpLowering, CIRBitClrsbOpLowering, CIRBitClzOpLowering,
3839-
CIRBitCtzOpLowering, CIRBitFfsOpLowering, CIRBitParityOpLowering,
3840-
CIRBitPopcountOpLowering, CIRAtomicCmpXchgLowering, CIRAtomicXchgLowering,
3841-
CIRAtomicFetchLowering, CIRByteswapOpLowering, CIRRotateOpLowering,
3842-
CIRBrCondOpLowering, CIRPtrStrideOpLowering, CIRCallLowering,
3843-
CIRTryCallLowering, CIREhInflightOpLowering, CIRUnaryOpLowering,
3844-
CIRBinOpLowering, CIRBinOpOverflowOpLowering, CIRShiftOpLowering,
3845-
CIRLoadLowering, CIRConstantLowering, CIRStoreLowering, CIRFuncLowering,
3846-
CIRCastOpLowering, CIRGlobalOpLowering, CIRGetGlobalOpLowering,
3847-
CIRComplexCreateOpLowering, CIRComplexRealOpLowering,
3848-
CIRComplexImagOpLowering, CIRComplexRealPtrOpLowering,
3849-
CIRComplexImagPtrOpLowering, CIRVAStartLowering, CIRVAEndLowering,
3850-
CIRVACopyLowering, CIRVAArgLowering, CIRBrOpLowering,
3851-
CIRGetMemberOpLowering, CIRGetRuntimeMemberOpLowering,
3892+
CIRCmpOpLowering, CIRSelectOpLowering, CIRBitClrsbOpLowering,
3893+
CIRBitClzOpLowering, CIRBitCtzOpLowering, CIRBitFfsOpLowering,
3894+
CIRBitParityOpLowering, CIRBitPopcountOpLowering,
3895+
CIRAtomicCmpXchgLowering, CIRAtomicXchgLowering, CIRAtomicFetchLowering,
3896+
CIRByteswapOpLowering, CIRRotateOpLowering, CIRBrCondOpLowering,
3897+
CIRPtrStrideOpLowering, CIRCallLowering, CIRTryCallLowering,
3898+
CIREhInflightOpLowering, CIRUnaryOpLowering, CIRBinOpLowering,
3899+
CIRBinOpOverflowOpLowering, CIRShiftOpLowering, CIRLoadLowering,
3900+
CIRConstantLowering, CIRStoreLowering, CIRFuncLowering, CIRCastOpLowering,
3901+
CIRGlobalOpLowering, CIRGetGlobalOpLowering, CIRComplexCreateOpLowering,
3902+
CIRComplexRealOpLowering, CIRComplexImagOpLowering,
3903+
CIRComplexRealPtrOpLowering, CIRComplexImagPtrOpLowering,
3904+
CIRVAStartLowering, CIRVAEndLowering, CIRVACopyLowering, CIRVAArgLowering,
3905+
CIRBrOpLowering, CIRGetMemberOpLowering, CIRGetRuntimeMemberOpLowering,
38523906
CIRSwitchFlatOpLowering, CIRPtrDiffOpLowering, CIRCopyOpLowering,
38533907
CIRMemCpyOpLowering, CIRFAbsOpLowering, CIRExpectOpLowering,
38543908
CIRVTableAddrPointOpLowering, CIRVectorCreateLowering,

‎clang/test/CIR/Lowering/select.cir

+50
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
// RUN: cir-translate -cir-to-llvmir -o %t.ll %s
2+
// RUN: FileCheck --input-file=%t.ll -check-prefix=LLVM %s
3+
4+
!s32i = !cir.int<s, 32>
5+
6+
module {
7+
cir.func @select_int(%arg0 : !cir.bool, %arg1 : !s32i, %arg2 : !s32i) -> !s32i {
8+
%0 = cir.select if %arg0 then %arg1 else %arg2 : (!cir.bool, !s32i, !s32i) -> !s32i
9+
cir.return %0 : !s32i
10+
}
11+
12+
// LLVM: define i32 @select_int(i8 %[[#COND:]], i32 %[[#TV:]], i32 %[[#FV:]])
13+
// LLVM-NEXT: %[[#CONDF:]] = trunc i8 %[[#COND]] to i1
14+
// LLVM-NEXT: %[[#RES:]] = select i1 %[[#CONDF]], i32 %[[#TV]], i32 %[[#FV]]
15+
// LLVM-NEXT: ret i32 %[[#RES]]
16+
// LLVM-NEXT: }
17+
18+
cir.func @select_bool(%arg0 : !cir.bool, %arg1 : !cir.bool, %arg2 : !cir.bool) -> !cir.bool {
19+
%0 = cir.select if %arg0 then %arg1 else %arg2 : (!cir.bool, !cir.bool, !cir.bool) -> !cir.bool
20+
cir.return %0 : !cir.bool
21+
}
22+
23+
// LLVM: define i8 @select_bool(i8 %[[#COND:]], i8 %[[#TV:]], i8 %[[#FV:]])
24+
// LLVM-NEXT: %[[#CONDF:]] = trunc i8 %[[#COND]] to i1
25+
// LLVM-NEXT: %[[#RES:]] = select i1 %[[#CONDF]], i8 %[[#TV]], i8 %[[#FV]]
26+
// LLVM-NEXT: ret i8 %[[#RES]]
27+
// LLVM-NEXT: }
28+
29+
cir.func @logical_and(%arg0 : !cir.bool, %arg1 : !cir.bool) -> !cir.bool {
30+
%0 = cir.const #cir.bool<false> : !cir.bool
31+
%1 = cir.select if %arg0 then %arg1 else %0 : (!cir.bool, !cir.bool, !cir.bool) -> !cir.bool
32+
cir.return %1 : !cir.bool
33+
}
34+
35+
// LLVM: define i8 @logical_and(i8 %[[#ARG0:]], i8 %[[#ARG1:]])
36+
// LLVM-NEXT: %[[#RES:]] = and i8 %[[#ARG0]], %[[#ARG1]]
37+
// LLVM-NEXT: ret i8 %[[#RES]]
38+
// LLVM-NEXT: }
39+
40+
cir.func @logical_or(%arg0 : !cir.bool, %arg1 : !cir.bool) -> !cir.bool {
41+
%0 = cir.const #cir.bool<true> : !cir.bool
42+
%1 = cir.select if %arg0 then %0 else %arg1 : (!cir.bool, !cir.bool, !cir.bool) -> !cir.bool
43+
cir.return %1 : !cir.bool
44+
}
45+
46+
// LLVM: define i8 @logical_or(i8 %[[#ARG0:]], i8 %[[#ARG1:]])
47+
// LLVM-NEXT: %[[#RES:]] = or i8 %[[#ARG0]], %[[#ARG1]]
48+
// LLVM-NEXT: ret i8 %[[#RES]]
49+
// LLVM-NEXT: }
50+
}

‎clang/test/CIR/Transforms/select.cir

+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
// RUN: cir-opt --canonicalize -o %t.cir %s
2+
// RUN: FileCheck --input-file=%t.cir %s
3+
4+
!s32i = !cir.int<s, 32>
5+
6+
module {
7+
cir.func @fold_true(%arg0 : !s32i, %arg1 : !s32i) -> !s32i {
8+
%0 = cir.const #cir.bool<true> : !cir.bool
9+
%1 = cir.select if %0 then %arg0 else %arg1 : (!cir.bool, !s32i, !s32i) -> !s32i
10+
cir.return %1 : !s32i
11+
}
12+
13+
// CHECK: cir.func @fold_true(%[[ARG0:.+]]: !s32i, %[[ARG1:.+]]: !s32i) -> !s32i {
14+
// CHECK-NEXT: cir.return %[[ARG0]] : !s32i
15+
// CHECK-NEXT: }
16+
17+
cir.func @fold_false(%arg0 : !s32i, %arg1 : !s32i) -> !s32i {
18+
%0 = cir.const #cir.bool<false> : !cir.bool
19+
%1 = cir.select if %0 then %arg0 else %arg1 : (!cir.bool, !s32i, !s32i) -> !s32i
20+
cir.return %1 : !s32i
21+
}
22+
23+
// CHECK: cir.func @fold_false(%[[ARG0:.+]]: !s32i, %[[ARG1:.+]]: !s32i) -> !s32i {
24+
// CHECK-NEXT: cir.return %[[ARG1]] : !s32i
25+
// CHECK-NEXT: }
26+
}

0 commit comments

Comments
 (0)
Please sign in to comment.