Skip to content

Commit 80e1a10

Browse files
authored
[CIR][LLVMLowering] Add LLVM lowering for complex operations (#723)
This PR adds LLVM lowering for the following operations related to complex numbers: - `cir.complex.create`, - `cir.complex.real_ptr`, and - `cir.complex.imag_ptr`. The LLVM IR generated for `cir.complex.create` is a bit ugly since it includes the `insertvalue` instruction, which typically is not generated in upstream CodeGen. Later we may need further CIR canonicalization passes to try folding `cir.complex.create`.
1 parent 58d5f0b commit 80e1a10

File tree

2 files changed

+142
-3
lines changed

2 files changed

+142
-3
lines changed

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

+86-3
Original file line numberDiff line numberDiff line change
@@ -1633,6 +1633,80 @@ class CIRGetGlobalOpLowering
16331633
}
16341634
};
16351635

1636+
class CIRComplexCreateOpLowering
1637+
: public mlir::OpConversionPattern<mlir::cir::ComplexCreateOp> {
1638+
public:
1639+
using OpConversionPattern<mlir::cir::ComplexCreateOp>::OpConversionPattern;
1640+
1641+
mlir::LogicalResult
1642+
matchAndRewrite(mlir::cir::ComplexCreateOp op, OpAdaptor adaptor,
1643+
mlir::ConversionPatternRewriter &rewriter) const override {
1644+
auto complexLLVMTy =
1645+
getTypeConverter()->convertType(op.getResult().getType());
1646+
auto initialComplex =
1647+
rewriter.create<mlir::LLVM::UndefOp>(op->getLoc(), complexLLVMTy);
1648+
1649+
int64_t position[1]{0};
1650+
auto realComplex = rewriter.create<mlir::LLVM::InsertValueOp>(
1651+
op->getLoc(), initialComplex, adaptor.getReal(), position);
1652+
1653+
position[0] = 1;
1654+
auto complex = rewriter.create<mlir::LLVM::InsertValueOp>(
1655+
op->getLoc(), realComplex, adaptor.getImag(), position);
1656+
1657+
rewriter.replaceOp(op, complex);
1658+
return mlir::success();
1659+
}
1660+
};
1661+
1662+
class CIRComplexRealPtrOPLowering
1663+
: public mlir::OpConversionPattern<mlir::cir::ComplexRealPtrOp> {
1664+
public:
1665+
using OpConversionPattern<mlir::cir::ComplexRealPtrOp>::OpConversionPattern;
1666+
1667+
mlir::LogicalResult
1668+
matchAndRewrite(mlir::cir::ComplexRealPtrOp op, OpAdaptor adaptor,
1669+
mlir::ConversionPatternRewriter &rewriter) const override {
1670+
auto operandTy =
1671+
mlir::cast<mlir::cir::PointerType>(op.getOperand().getType());
1672+
auto resultLLVMTy =
1673+
getTypeConverter()->convertType(op.getResult().getType());
1674+
auto elementLLVMTy =
1675+
getTypeConverter()->convertType(operandTy.getPointee());
1676+
1677+
mlir::LLVM::GEPArg gepIndices[2]{{0}, {0}};
1678+
rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>(
1679+
op, resultLLVMTy, elementLLVMTy, adaptor.getOperand(), gepIndices,
1680+
/*inbounds=*/true);
1681+
1682+
return mlir::success();
1683+
}
1684+
};
1685+
1686+
class CIRComplexImagPtrOpLowering
1687+
: public mlir::OpConversionPattern<mlir::cir::ComplexImagPtrOp> {
1688+
public:
1689+
using OpConversionPattern<mlir::cir::ComplexImagPtrOp>::OpConversionPattern;
1690+
1691+
mlir::LogicalResult
1692+
matchAndRewrite(mlir::cir::ComplexImagPtrOp op, OpAdaptor adaptor,
1693+
mlir::ConversionPatternRewriter &rewriter) const override {
1694+
auto operandTy =
1695+
mlir::cast<mlir::cir::PointerType>(op.getOperand().getType());
1696+
auto resultLLVMTy =
1697+
getTypeConverter()->convertType(op.getResult().getType());
1698+
auto elementLLVMTy =
1699+
getTypeConverter()->convertType(operandTy.getPointee());
1700+
1701+
mlir::LLVM::GEPArg gepIndices[2]{{0}, {1}};
1702+
rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>(
1703+
op, resultLLVMTy, elementLLVMTy, adaptor.getOperand(), gepIndices,
1704+
/*inbounds=*/true);
1705+
1706+
return mlir::success();
1707+
}
1708+
};
1709+
16361710
class CIRSwitchFlatOpLowering
16371711
: public mlir::OpConversionPattern<mlir::cir::SwitchFlatOp> {
16381712
public:
@@ -3365,9 +3439,10 @@ void populateCIRToLLVMConversionPatterns(mlir::RewritePatternSet &patterns,
33653439
CIRUnaryOpLowering, CIRBinOpLowering, CIRBinOpOverflowOpLowering,
33663440
CIRShiftOpLowering, CIRLoadLowering, CIRConstantLowering,
33673441
CIRStoreLowering, CIRAllocaLowering, CIRFuncLowering, CIRCastOpLowering,
3368-
CIRGlobalOpLowering, CIRGetGlobalOpLowering, CIRVAStartLowering,
3369-
CIRVAEndLowering, CIRVACopyLowering, CIRVAArgLowering, CIRBrOpLowering,
3370-
CIRGetMemberOpLowering, CIRGetRuntimeMemberOpLowering,
3442+
CIRGlobalOpLowering, CIRGetGlobalOpLowering, CIRComplexCreateOpLowering,
3443+
CIRComplexRealPtrOPLowering, CIRComplexImagPtrOpLowering,
3444+
CIRVAStartLowering, CIRVAEndLowering, CIRVACopyLowering, CIRVAArgLowering,
3445+
CIRBrOpLowering, CIRGetMemberOpLowering, CIRGetRuntimeMemberOpLowering,
33713446
CIRSwitchFlatOpLowering, CIRPtrDiffOpLowering, CIRCopyOpLowering,
33723447
CIRMemCpyOpLowering, CIRFAbsOpLowering, CIRExpectOpLowering,
33733448
CIRVTableAddrPointOpLowering, CIRVectorCreateLowering,
@@ -3444,6 +3519,14 @@ void prepareTypeConverter(mlir::LLVMTypeConverter &converter,
34443519
converter.addConversion([&](mlir::cir::BF16Type type) -> mlir::Type {
34453520
return mlir::FloatType::getBF16(type.getContext());
34463521
});
3522+
converter.addConversion([&](mlir::cir::ComplexType type) -> mlir::Type {
3523+
// A complex type is lowered to an LLVM struct that contains the real and
3524+
// imaginary part as data fields.
3525+
mlir::Type elementTy = converter.convertType(type.getElementTy());
3526+
mlir::Type structFields[2] = {elementTy, elementTy};
3527+
return mlir::LLVM::LLVMStructType::getLiteral(type.getContext(),
3528+
structFields);
3529+
});
34473530
converter.addConversion([&](mlir::cir::FuncType type) -> mlir::Type {
34483531
auto result = converter.convertType(type.getReturnType());
34493532
llvm::SmallVector<mlir::Type> arguments;

clang/test/CIR/CodeGen/complex.c

+56
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
// RUN: FileCheck --input-file=%t.cir --check-prefixes=C,CHECK %s
33
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -x c++ -fclangir -emit-cir -o %t.cir %s
44
// RUN: FileCheck --input-file=%t.cir --check-prefixes=CPP,CHECK %s
5+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-llvm -o %t.ll %s
6+
// RUN: FileCheck --input-file=%t.ll --check-prefixes=LLVM %s
57

68
double _Complex c, c2;
79
int _Complex ci, ci2;
@@ -24,6 +26,10 @@ void list_init() {
2426
// CHECK-NEXT: %{{.+}} = cir.complex.create %[[#REAL]], %[[#IMAG]] : !s32i -> !cir.complex<!s32i>
2527
// CHECK: }
2628

29+
// LLVM: define void @list_init()
30+
// LLVM: store { double, double } { double 1.000000e+00, double 2.000000e+00 }, ptr %{{.+}}, align 8
31+
// LLVM: }
32+
2733
void list_init_2(double r, double i) {
2834
double _Complex c1 = {r, i};
2935
}
@@ -36,6 +42,12 @@ void list_init_2(double r, double i) {
3642
// CHECK-NEXT: cir.store %[[#C]], %{{.+}} : !cir.complex<!cir.double>, !cir.ptr<!cir.complex<!cir.double>>
3743
// CHECK: }
3844

45+
// LLVM: define void @list_init_2(double %{{.+}}, double %{{.+}})
46+
// LLVM: %[[#A:]] = insertvalue { double, double } undef, double %{{.+}}, 0
47+
// LLVM-NEXT: %[[#B:]] = insertvalue { double, double } %[[#A]], double %{{.+}}, 1
48+
// LLVM-NEXT: store { double, double } %[[#B]], ptr %5, align 8
49+
// LLVM: }
50+
3951
void imag_literal() {
4052
c = 3.0i;
4153
ci = 3i;
@@ -51,6 +63,11 @@ void imag_literal() {
5163
// CHECK-NEXT: %{{.+}} = cir.complex.create %[[#REAL]], %[[#IMAG]] : !s32i -> !cir.complex<!s32i>
5264
// CHECK: }
5365

66+
// LLVM: define void @imag_literal()
67+
// LLVM: store { double, double } { double 0.000000e+00, double 3.000000e+00 }, ptr @c, align 8
68+
// LLVM: store { i32, i32 } { i32 0, i32 3 }, ptr @ci, align 4
69+
// LLVM: }
70+
5471
void load_store() {
5572
c = c2;
5673
ci = ci2;
@@ -68,6 +85,13 @@ void load_store() {
6885
// CHECK-NEXT: cir.store %[[#CI2]], %[[#CI_PTR]] : !cir.complex<!s32i>, !cir.ptr<!cir.complex<!s32i>>
6986
// CHECK: }
7087

88+
// LLVM: define void @load_store()
89+
// LLVM: %[[#A:]] = load { double, double }, ptr @c2, align 8
90+
// LLVM-NEXT: store { double, double } %[[#A]], ptr @c, align 8
91+
// LLVM-NEXT: %[[#B:]] = load { i32, i32 }, ptr @ci2, align 4
92+
// LLVM-NEXT: store { i32, i32 } %[[#B]], ptr @ci, align 4
93+
// LLVM: }
94+
7195
void load_store_volatile() {
7296
vc = vc2;
7397
vci = vci2;
@@ -85,6 +109,13 @@ void load_store_volatile() {
85109
// CHECK-NEXT: cir.store volatile %[[#VCI2]], %[[#VCI_PTR]] : !cir.complex<!s32i>, !cir.ptr<!cir.complex<!s32i>>
86110
// CHECK: }
87111

112+
// LLVM: define void @load_store_volatile()
113+
// LLVM: %[[#A:]] = load volatile { double, double }, ptr @vc2, align 8
114+
// LLVM-NEXT: store volatile { double, double } %[[#A]], ptr @vc, align 8
115+
// LLVM-NEXT: %[[#B:]] = load volatile { i32, i32 }, ptr @vci2, align 4
116+
// LLVM-NEXT: store volatile { i32, i32 } %[[#B]], ptr @vci, align 4
117+
// LLVM: }
118+
88119
void real_ptr() {
89120
double *r1 = &__real__ c;
90121
int *r2 = &__real__ ci;
@@ -98,6 +129,11 @@ void real_ptr() {
98129
// CHECK-NEXT: %{{.+}} = cir.complex.real_ptr %[[#CI_PTR]] : !cir.ptr<!cir.complex<!s32i>> -> !cir.ptr<!s32i>
99130
// CHECK: }
100131

132+
// LLVM: define void @real_ptr()
133+
// LLVM: store ptr @c, ptr %{{.+}}, align 8
134+
// LLVM-NEXT: store ptr @ci, ptr %{{.+}}, align 8
135+
// LLVM: }
136+
101137
void real_ptr_local() {
102138
double _Complex c1 = {1.0, 2.0};
103139
double *r3 = &__real__ c1;
@@ -109,6 +145,11 @@ void real_ptr_local() {
109145
// CHECK: %{{.+}} = cir.complex.real_ptr %[[#C]] : !cir.ptr<!cir.complex<!cir.double>> -> !cir.ptr<!cir.double>
110146
// CHECK: }
111147

148+
// LLVM: define void @real_ptr_local()
149+
// LLVM: store { double, double } { double 1.000000e+00, double 2.000000e+00 }, ptr %{{.+}}, align 8
150+
// LLVM-NEXT: %{{.+}} = getelementptr inbounds { double, double }, ptr %{{.+}}, i32 0, i32 0
151+
// LLVM: }
152+
112153
void extract_real() {
113154
double r1 = __real__ c;
114155
int r2 = __real__ ci;
@@ -124,6 +165,11 @@ void extract_real() {
124165
// CHECK-NEXT: %{{.+}} = cir.load %[[#REAL_PTR]] : !cir.ptr<!s32i>, !s32i
125166
// CHECK: }
126167

168+
// LLVM: define void @extract_real()
169+
// LLVM: %{{.+}} = load double, ptr @c, align 8
170+
// LLVM: %{{.+}} = load i32, ptr @ci, align 4
171+
// LLVM: }
172+
127173
void imag_ptr() {
128174
double *i1 = &__imag__ c;
129175
int *i2 = &__imag__ ci;
@@ -137,6 +183,11 @@ void imag_ptr() {
137183
// CHECK-NEXT: %{{.+}} = cir.complex.imag_ptr %[[#CI_PTR]] : !cir.ptr<!cir.complex<!s32i>> -> !cir.ptr<!s32i>
138184
// CHECK: }
139185

186+
// LLVM: define void @imag_ptr()
187+
// LLVM: store ptr getelementptr inbounds ({ double, double }, ptr @c, i32 0, i32 1), ptr %{{.+}}, align 8
188+
// LLVM: store ptr getelementptr inbounds ({ i32, i32 }, ptr @ci, i32 0, i32 1), ptr %{{.+}}, align 8
189+
// LLVM: }
190+
140191
void extract_imag() {
141192
double i1 = __imag__ c;
142193
int i2 = __imag__ ci;
@@ -151,3 +202,8 @@ void extract_imag() {
151202
// CHECK-NEXT: %[[#IMAG_PTR:]] = cir.complex.imag_ptr %[[#CI_PTR]] : !cir.ptr<!cir.complex<!s32i>> -> !cir.ptr<!s32i>
152203
// CHECK-NEXT: %{{.+}} = cir.load %[[#IMAG_PTR]] : !cir.ptr<!s32i>, !s32i
153204
// CHECK: }
205+
206+
// LLVM: define void @extract_imag()
207+
// LLVM: %{{.+}} = load double, ptr getelementptr inbounds ({ double, double }, ptr @c, i32 0, i32 1), align 8
208+
// LLVM: %{{.+}} = load i32, ptr getelementptr inbounds ({ i32, i32 }, ptr @ci, i32 0, i32 1), align 4
209+
// LLVM: }

0 commit comments

Comments
 (0)