@@ -2987,6 +2987,60 @@ class CIRRotateOpLowering
2987
2987
}
2988
2988
};
2989
2989
2990
+ class CIRSelectOpLowering
2991
+ : public mlir::OpConversionPattern<mlir::cir::SelectOp> {
2992
+ public:
2993
+ using OpConversionPattern<mlir::cir::SelectOp>::OpConversionPattern;
2994
+
2995
+ mlir::LogicalResult
2996
+ matchAndRewrite (mlir::cir::SelectOp op, OpAdaptor adaptor,
2997
+ mlir::ConversionPatternRewriter &rewriter) const override {
2998
+ auto getConstantBool = [](mlir::Value value) -> std::optional<bool > {
2999
+ auto definingOp = mlir::dyn_cast_if_present<mlir::cir::ConstantOp>(
3000
+ value.getDefiningOp ());
3001
+ if (!definingOp)
3002
+ return std::nullopt;
3003
+
3004
+ auto constValue =
3005
+ mlir::dyn_cast<mlir::cir::BoolAttr>(definingOp.getValue ());
3006
+ if (!constValue)
3007
+ return std::nullopt;
3008
+
3009
+ return constValue.getValue ();
3010
+ };
3011
+
3012
+ // Two special cases in the LLVMIR codegen of select op:
3013
+ // - select %0, %1, false => and %0, %1
3014
+ // - select %0, true, %1 => or %0, %1
3015
+ auto trueValue = op.getTrueValue ();
3016
+ auto falseValue = op.getFalseValue ();
3017
+ if (mlir::isa<mlir::cir::BoolType>(trueValue.getType ())) {
3018
+ if (std::optional<bool > falseValueBool = getConstantBool (falseValue);
3019
+ falseValueBool.has_value () && !*falseValueBool) {
3020
+ // select %0, %1, false => and %0, %1
3021
+ rewriter.replaceOpWithNewOp <mlir::LLVM::AndOp>(
3022
+ op, adaptor.getCondition (), adaptor.getTrueValue ());
3023
+ return mlir::success ();
3024
+ }
3025
+ if (std::optional<bool > trueValueBool = getConstantBool (trueValue);
3026
+ trueValueBool.has_value () && *trueValueBool) {
3027
+ // select %0, true, %1 => or %0, %1
3028
+ rewriter.replaceOpWithNewOp <mlir::LLVM::OrOp>(
3029
+ op, adaptor.getCondition (), adaptor.getFalseValue ());
3030
+ return mlir::success ();
3031
+ }
3032
+ }
3033
+
3034
+ auto llvmCondition = rewriter.create <mlir::LLVM::TruncOp>(
3035
+ op.getLoc (), mlir::IntegerType::get (op->getContext (), 1 ),
3036
+ adaptor.getCondition ());
3037
+ rewriter.replaceOpWithNewOp <mlir::LLVM::SelectOp>(
3038
+ op, llvmCondition, adaptor.getTrueValue (), adaptor.getFalseValue ());
3039
+
3040
+ return mlir::success ();
3041
+ }
3042
+ };
3043
+
2990
3044
class CIRBrOpLowering : public mlir ::OpConversionPattern<mlir::cir::BrOp> {
2991
3045
public:
2992
3046
using OpConversionPattern<mlir::cir::BrOp>::OpConversionPattern;
@@ -3835,20 +3889,20 @@ void populateCIRToLLVMConversionPatterns(mlir::RewritePatternSet &patterns,
3835
3889
patterns.add <CIRReturnLowering>(patterns.getContext ());
3836
3890
patterns.add <CIRAllocaLowering>(converter, dataLayout, patterns.getContext ());
3837
3891
patterns.add <
3838
- CIRCmpOpLowering, CIRBitClrsbOpLowering, CIRBitClzOpLowering ,
3839
- CIRBitCtzOpLowering, CIRBitFfsOpLowering, CIRBitParityOpLowering ,
3840
- CIRBitPopcountOpLowering, CIRAtomicCmpXchgLowering, CIRAtomicXchgLowering ,
3841
- CIRAtomicFetchLowering, CIRByteswapOpLowering, CIRRotateOpLowering ,
3842
- CIRBrCondOpLowering, CIRPtrStrideOpLowering, CIRCallLowering ,
3843
- CIRTryCallLowering, CIREhInflightOpLowering, CIRUnaryOpLowering ,
3844
- CIRBinOpLowering, CIRBinOpOverflowOpLowering, CIRShiftOpLowering ,
3845
- CIRLoadLowering, CIRConstantLowering, CIRStoreLowering, CIRFuncLowering ,
3846
- CIRCastOpLowering, CIRGlobalOpLowering, CIRGetGlobalOpLowering ,
3847
- CIRComplexCreateOpLowering, CIRComplexRealOpLowering ,
3848
- CIRComplexImagOpLowering, CIRComplexRealPtrOpLowering ,
3849
- CIRComplexImagPtrOpLowering, CIRVAStartLowering, CIRVAEndLowering ,
3850
- CIRVACopyLowering, CIRVAArgLowering, CIRBrOpLowering ,
3851
- CIRGetMemberOpLowering, CIRGetRuntimeMemberOpLowering,
3892
+ CIRCmpOpLowering, CIRSelectOpLowering, CIRBitClrsbOpLowering ,
3893
+ CIRBitClzOpLowering, CIRBitCtzOpLowering, CIRBitFfsOpLowering ,
3894
+ CIRBitParityOpLowering, CIRBitPopcountOpLowering ,
3895
+ CIRAtomicCmpXchgLowering, CIRAtomicXchgLowering, CIRAtomicFetchLowering ,
3896
+ CIRByteswapOpLowering, CIRRotateOpLowering, CIRBrCondOpLowering ,
3897
+ CIRPtrStrideOpLowering, CIRCallLowering, CIRTryCallLowering ,
3898
+ CIREhInflightOpLowering, CIRUnaryOpLowering, CIRBinOpLowering ,
3899
+ CIRBinOpOverflowOpLowering, CIRShiftOpLowering, CIRLoadLowering ,
3900
+ CIRConstantLowering, CIRStoreLowering, CIRFuncLowering, CIRCastOpLowering ,
3901
+ CIRGlobalOpLowering, CIRGetGlobalOpLowering, CIRComplexCreateOpLowering ,
3902
+ CIRComplexRealOpLowering, CIRComplexImagOpLowering ,
3903
+ CIRComplexRealPtrOpLowering, CIRComplexImagPtrOpLowering ,
3904
+ CIRVAStartLowering, CIRVAEndLowering, CIRVACopyLowering, CIRVAArgLowering ,
3905
+ CIRBrOpLowering, CIRGetMemberOpLowering, CIRGetRuntimeMemberOpLowering,
3852
3906
CIRSwitchFlatOpLowering, CIRPtrDiffOpLowering, CIRCopyOpLowering,
3853
3907
CIRMemCpyOpLowering, CIRFAbsOpLowering, CIRExpectOpLowering,
3854
3908
CIRVTableAddrPointOpLowering, CIRVectorCreateLowering,
0 commit comments