@@ -2988,6 +2988,60 @@ class CIRRotateOpLowering
2988
2988
}
2989
2989
};
2990
2990
2991
+ class CIRSelectOpLowering
2992
+ : public mlir::OpConversionPattern<mlir::cir::SelectOp> {
2993
+ public:
2994
+ using OpConversionPattern<mlir::cir::SelectOp>::OpConversionPattern;
2995
+
2996
+ mlir::LogicalResult
2997
+ matchAndRewrite (mlir::cir::SelectOp op, OpAdaptor adaptor,
2998
+ mlir::ConversionPatternRewriter &rewriter) const override {
2999
+ auto getConstantBool = [](mlir::Value value) -> std::optional<bool > {
3000
+ auto definingOp = mlir::dyn_cast_if_present<mlir::cir::ConstantOp>(
3001
+ value.getDefiningOp ());
3002
+ if (!definingOp)
3003
+ return std::nullopt;
3004
+
3005
+ auto constValue =
3006
+ mlir::dyn_cast<mlir::cir::BoolAttr>(definingOp.getValue ());
3007
+ if (!constValue)
3008
+ return std::nullopt;
3009
+
3010
+ return constValue.getValue ();
3011
+ };
3012
+
3013
+ // Two special cases in the LLVMIR codegen of select op:
3014
+ // - select %0, %1, false => and %0, %1
3015
+ // - select %0, true, %1 => or %0, %1
3016
+ auto trueValue = op.getTrueValue ();
3017
+ auto falseValue = op.getFalseValue ();
3018
+ if (mlir::isa<mlir::cir::BoolType>(trueValue.getType ())) {
3019
+ if (std::optional<bool > falseValueBool = getConstantBool (falseValue);
3020
+ falseValueBool.has_value () && !*falseValueBool) {
3021
+ // select %0, %1, false => and %0, %1
3022
+ rewriter.replaceOpWithNewOp <mlir::LLVM::AndOp>(
3023
+ op, adaptor.getCondition (), adaptor.getTrueValue ());
3024
+ return mlir::success ();
3025
+ }
3026
+ if (std::optional<bool > trueValueBool = getConstantBool (trueValue);
3027
+ trueValueBool.has_value () && *trueValueBool) {
3028
+ // select %0, true, %1 => or %0, %1
3029
+ rewriter.replaceOpWithNewOp <mlir::LLVM::OrOp>(
3030
+ op, adaptor.getCondition (), adaptor.getFalseValue ());
3031
+ return mlir::success ();
3032
+ }
3033
+ }
3034
+
3035
+ auto llvmCondition = rewriter.create <mlir::LLVM::TruncOp>(
3036
+ op.getLoc (), mlir::IntegerType::get (op->getContext (), 1 ),
3037
+ adaptor.getCondition ());
3038
+ rewriter.replaceOpWithNewOp <mlir::LLVM::SelectOp>(
3039
+ op, llvmCondition, adaptor.getTrueValue (), adaptor.getFalseValue ());
3040
+
3041
+ return mlir::success ();
3042
+ }
3043
+ };
3044
+
2991
3045
class CIRBrOpLowering : public mlir ::OpConversionPattern<mlir::cir::BrOp> {
2992
3046
public:
2993
3047
using OpConversionPattern<mlir::cir::BrOp>::OpConversionPattern;
@@ -3836,20 +3890,20 @@ void populateCIRToLLVMConversionPatterns(mlir::RewritePatternSet &patterns,
3836
3890
patterns.add <CIRReturnLowering>(patterns.getContext ());
3837
3891
patterns.add <CIRAllocaLowering>(converter, dataLayout, patterns.getContext ());
3838
3892
patterns.add <
3839
- CIRCmpOpLowering, CIRBitClrsbOpLowering, CIRBitClzOpLowering ,
3840
- CIRBitCtzOpLowering, CIRBitFfsOpLowering, CIRBitParityOpLowering ,
3841
- CIRBitPopcountOpLowering, CIRAtomicCmpXchgLowering, CIRAtomicXchgLowering ,
3842
- CIRAtomicFetchLowering, CIRByteswapOpLowering, CIRRotateOpLowering ,
3843
- CIRBrCondOpLowering, CIRPtrStrideOpLowering, CIRCallLowering ,
3844
- CIRTryCallLowering, CIREhInflightOpLowering, CIRUnaryOpLowering ,
3845
- CIRBinOpLowering, CIRBinOpOverflowOpLowering, CIRShiftOpLowering ,
3846
- CIRLoadLowering, CIRConstantLowering, CIRStoreLowering, CIRFuncLowering ,
3847
- CIRCastOpLowering, CIRGlobalOpLowering, CIRGetGlobalOpLowering ,
3848
- CIRComplexCreateOpLowering, CIRComplexRealOpLowering ,
3849
- CIRComplexImagOpLowering, CIRComplexRealPtrOpLowering ,
3850
- CIRComplexImagPtrOpLowering, CIRVAStartLowering, CIRVAEndLowering ,
3851
- CIRVACopyLowering, CIRVAArgLowering, CIRBrOpLowering ,
3852
- CIRGetMemberOpLowering, CIRGetRuntimeMemberOpLowering,
3893
+ CIRCmpOpLowering, CIRSelectOpLowering, CIRBitClrsbOpLowering ,
3894
+ CIRBitClzOpLowering, CIRBitCtzOpLowering, CIRBitFfsOpLowering ,
3895
+ CIRBitParityOpLowering, CIRBitPopcountOpLowering ,
3896
+ CIRAtomicCmpXchgLowering, CIRAtomicXchgLowering, CIRAtomicFetchLowering ,
3897
+ CIRByteswapOpLowering, CIRRotateOpLowering, CIRBrCondOpLowering ,
3898
+ CIRPtrStrideOpLowering, CIRCallLowering, CIRTryCallLowering ,
3899
+ CIREhInflightOpLowering, CIRUnaryOpLowering, CIRBinOpLowering ,
3900
+ CIRBinOpOverflowOpLowering, CIRShiftOpLowering, CIRLoadLowering ,
3901
+ CIRConstantLowering, CIRStoreLowering, CIRFuncLowering, CIRCastOpLowering ,
3902
+ CIRGlobalOpLowering, CIRGetGlobalOpLowering, CIRComplexCreateOpLowering ,
3903
+ CIRComplexRealOpLowering, CIRComplexImagOpLowering ,
3904
+ CIRComplexRealPtrOpLowering, CIRComplexImagPtrOpLowering ,
3905
+ CIRVAStartLowering, CIRVAEndLowering, CIRVACopyLowering, CIRVAArgLowering ,
3906
+ CIRBrOpLowering, CIRGetMemberOpLowering, CIRGetRuntimeMemberOpLowering,
3853
3907
CIRSwitchFlatOpLowering, CIRPtrDiffOpLowering, CIRCopyOpLowering,
3854
3908
CIRMemCpyOpLowering, CIRFAbsOpLowering, CIRExpectOpLowering,
3855
3909
CIRVTableAddrPointOpLowering, CIRVectorCreateLowering,
0 commit comments