Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CIR] Add support for complex cast operations #758

Merged
merged 1 commit into from
Aug 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,22 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
return create<mlir::cir::ConstantOp>(loc, attr.getType(), attr);
}

// Creates constant null value for integral type ty.
mlir::cir::ConstantOp getNullValue(mlir::Type ty, mlir::Location loc) {
return create<mlir::cir::ConstantOp>(loc, ty, getZeroInitAttr(ty));
}

mlir::cir::ConstantOp getBool(bool state, mlir::Location loc) {
return create<mlir::cir::ConstantOp>(loc, getBoolTy(),
getCIRBoolAttr(state));
}
mlir::cir::ConstantOp getFalse(mlir::Location loc) {
return getBool(false, loc);
}
mlir::cir::ConstantOp getTrue(mlir::Location loc) {
return getBool(true, loc);
}

mlir::cir::BoolType getBoolTy() {
return ::mlir::cir::BoolType::get(getContext());
}
Expand Down Expand Up @@ -110,12 +126,16 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
return mlir::cir::FPAttr::getZero(fltType);
if (auto fltType = mlir::dyn_cast<mlir::cir::DoubleType>(ty))
return mlir::cir::FPAttr::getZero(fltType);
if (auto fltType = mlir::dyn_cast<mlir::cir::FP16Type>(ty))
return mlir::cir::FPAttr::getZero(fltType);
if (auto fltType = mlir::dyn_cast<mlir::cir::BF16Type>(ty))
return mlir::cir::FPAttr::getZero(fltType);
if (auto complexType = mlir::dyn_cast<mlir::cir::ComplexType>(ty))
return getZeroAttr(complexType);
if (auto arrTy = mlir::dyn_cast<mlir::cir::ArrayType>(ty))
return getZeroAttr(arrTy);
if (auto ptrTy = mlir::dyn_cast<mlir::cir::PointerType>(ty))
return getConstPtrAttr(ptrTy, 0);
return getConstNullPtrAttr(ptrTy);
if (auto structTy = mlir::dyn_cast<mlir::cir::StructType>(ty))
return getZeroAttr(structTy);
if (mlir::isa<mlir::cir::BoolType>(ty)) {
Expand Down Expand Up @@ -548,6 +568,11 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
getContext(), mlir::cast<mlir::cir::PointerType>(t), val);
}

mlir::TypedAttr getConstNullPtrAttr(mlir::Type t) {
assert(mlir::isa<mlir::cir::PointerType>(t) && "expected cir.ptr");
return getConstPtrAttr(t, 0);
}

// Creates constant nullptr for pointer type ty.
mlir::cir::ConstantOp getNullPtr(mlir::Type ty, mlir::Location loc) {
assert(!MissingFeatures::targetCodeGenInfoGetNullPointer());
Expand Down
28 changes: 27 additions & 1 deletion clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,18 @@ def CK_BooleanToIntegral : I32EnumAttrCase<"bool_to_int", 11>;
def CK_IntegralToFloat : I32EnumAttrCase<"int_to_float", 12>;
def CK_BooleanToFloat : I32EnumAttrCase<"bool_to_float", 13>;
def CK_AddressSpaceConversion : I32EnumAttrCase<"address_space", 14>;
def CK_FloatToComplex : I32EnumAttrCase<"float_to_complex", 15>;
def CK_IntegralToComplex : I32EnumAttrCase<"int_to_complex", 16>;
def CK_FloatComplexToReal : I32EnumAttrCase<"float_complex_to_real", 17>;
def CK_IntegralComplexToReal : I32EnumAttrCase<"int_complex_to_real", 18>;
def CK_FloatComplexToBoolean : I32EnumAttrCase<"float_complex_to_bool", 19>;
def CK_IntegralComplexToBoolean : I32EnumAttrCase<"int_complex_to_bool", 20>;
def CK_FloatComplexCast : I32EnumAttrCase<"float_complex", 21>;
def CK_FloatComplexToIntegralComplex
: I32EnumAttrCase<"float_complex_to_int_complex", 22>;
def CK_IntegralComplexCast : I32EnumAttrCase<"int_complex", 23>;
def CK_IntegralComplexToFloatComplex
: I32EnumAttrCase<"int_complex_to_float_complex", 24>;

def CastKind : I32EnumAttr<
"CastKind",
Expand All @@ -79,7 +91,11 @@ def CastKind : I32EnumAttr<
CK_BitCast, CK_FloatingCast, CK_PtrToBoolean, CK_FloatToIntegral,
CK_IntegralToPointer, CK_PointerToIntegral, CK_FloatToBoolean,
CK_BooleanToIntegral, CK_IntegralToFloat, CK_BooleanToFloat,
CK_AddressSpaceConversion]> {
CK_AddressSpaceConversion, CK_FloatToComplex, CK_IntegralToComplex,
CK_FloatComplexToReal, CK_IntegralComplexToReal, CK_FloatComplexToBoolean,
CK_IntegralComplexToBoolean, CK_FloatComplexCast,
CK_FloatComplexToIntegralComplex, CK_IntegralComplexCast,
CK_IntegralComplexToFloatComplex]> {
let cppNamespace = "::mlir::cir";
}

Expand All @@ -104,6 +120,16 @@ def CastOp : CIR_Op<"cast",
- `bool_to_int`
- `bool_to_float`
- `address_space`
- `float_to_complex`
- `int_to_complex`
- `float_complex_to_real`
- `int_complex_to_real`
- `float_complex_to_bool`
- `int_complex_to_bool`
- `float_complex`
- `float_complex_to_int_complex`
- `int_complex`
- `int_complex_to_float_complex`

This is effectively a subset of the rules from
`llvm-project/clang/include/clang/AST/OperationKinds.def`; but note that some
Expand Down
46 changes: 0 additions & 46 deletions clang/lib/CIR/CodeGen/CIRGenBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,11 +136,6 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
return mlir::cir::GlobalViewAttr::get(type, symbol, indices);
}

mlir::TypedAttr getConstNullPtrAttr(mlir::Type t) {
assert(mlir::isa<mlir::cir::PointerType>(t) && "expected cir.ptr");
return getConstPtrAttr(t, 0);
}

mlir::Attribute getString(llvm::StringRef str, mlir::Type eltTy,
unsigned size = 0) {
unsigned finalSize = size ? size : str.size();
Expand Down Expand Up @@ -246,31 +241,6 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
return mlir::cir::DataMemberAttr::get(getContext(), ty, std::nullopt);
}

mlir::TypedAttr getZeroInitAttr(mlir::Type ty) {
if (mlir::isa<mlir::cir::IntType>(ty))
return mlir::cir::IntAttr::get(ty, 0);
if (auto fltType = mlir::dyn_cast<mlir::cir::SingleType>(ty))
return mlir::cir::FPAttr::getZero(fltType);
if (auto fltType = mlir::dyn_cast<mlir::cir::DoubleType>(ty))
return mlir::cir::FPAttr::getZero(fltType);
if (auto fltType = mlir::dyn_cast<mlir::cir::FP16Type>(ty))
return mlir::cir::FPAttr::getZero(fltType);
if (auto fltType = mlir::dyn_cast<mlir::cir::BF16Type>(ty))
return mlir::cir::FPAttr::getZero(fltType);
if (auto complexType = mlir::dyn_cast<mlir::cir::ComplexType>(ty))
return getZeroAttr(complexType);
if (auto arrTy = mlir::dyn_cast<mlir::cir::ArrayType>(ty))
return getZeroAttr(arrTy);
if (auto ptrTy = mlir::dyn_cast<mlir::cir::PointerType>(ty))
return getConstNullPtrAttr(ptrTy);
if (auto structTy = mlir::dyn_cast<mlir::cir::StructType>(ty))
return getZeroAttr(structTy);
if (mlir::isa<mlir::cir::BoolType>(ty)) {
return getCIRBoolAttr(false);
}
llvm_unreachable("Zero initializer for given type is NYI");
}

// TODO(cir): Once we have CIR float types, replace this by something like a
// NullableValueInterface to allow for type-independent queries.
bool isNullValue(mlir::Attribute attr) const {
Expand Down Expand Up @@ -554,28 +524,12 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
mlir::cir::IntAttr::get(t, C));
}

mlir::cir::ConstantOp getBool(bool state, mlir::Location loc) {
return create<mlir::cir::ConstantOp>(loc, getBoolTy(),
getCIRBoolAttr(state));
}
mlir::cir::ConstantOp getFalse(mlir::Location loc) {
return getBool(false, loc);
}
mlir::cir::ConstantOp getTrue(mlir::Location loc) {
return getBool(true, loc);
}

/// Create constant nullptr for pointer-to-data-member type ty.
mlir::cir::ConstantOp getNullDataMemberPtr(mlir::cir::DataMemberType ty,
mlir::Location loc) {
return create<mlir::cir::ConstantOp>(loc, ty, getNullDataMemberAttr(ty));
}

// Creates constant null value for integral type ty.
mlir::cir::ConstantOp getNullValue(mlir::Type ty, mlir::Location loc) {
return create<mlir::cir::ConstantOp>(loc, ty, getZeroInitAttr(ty));
}

mlir::cir::ConstantOp getZero(mlir::Location loc, mlir::Type ty) {
// TODO: dispatch creation for primitive types.
assert((mlir::isa<mlir::cir::StructType>(ty) ||
Expand Down
47 changes: 38 additions & 9 deletions clang/lib/CIR/CodeGen/CIRGenExprComplex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -385,20 +385,43 @@ mlir::Value ComplexExprEmitter::buildComplexToComplexCast(mlir::Value Val,
QualType SrcType,
QualType DestType,
SourceLocation Loc) {
// Get the src/dest element type.
SrcType = SrcType->castAs<ComplexType>()->getElementType();
DestType = DestType->castAs<ComplexType>()->getElementType();
if (SrcType == DestType)
return Val;

llvm_unreachable("complex cast is NYI");
// Get the src/dest element type.
QualType SrcElemTy = SrcType->castAs<ComplexType>()->getElementType();
QualType DestElemTy = DestType->castAs<ComplexType>()->getElementType();

mlir::cir::CastKind CastOpKind;
if (SrcElemTy->isFloatingType() && DestElemTy->isFloatingType())
CastOpKind = mlir::cir::CastKind::float_complex;
else if (SrcElemTy->isFloatingType() && DestElemTy->isIntegerType())
CastOpKind = mlir::cir::CastKind::float_complex_to_int_complex;
else if (SrcElemTy->isIntegerType() && DestElemTy->isFloatingType())
CastOpKind = mlir::cir::CastKind::int_complex_to_float_complex;
else if (SrcElemTy->isIntegerType() && DestElemTy->isIntegerType())
CastOpKind = mlir::cir::CastKind::int_complex;
else
llvm_unreachable("unexpected src type or dest type");

return Builder.createCast(CGF.getLoc(Loc), CastOpKind, Val,
CGF.ConvertType(DestType));
}

mlir::Value ComplexExprEmitter::buildScalarToComplexCast(mlir::Value Val,
QualType SrcType,
QualType DestType,
SourceLocation Loc) {
llvm_unreachable("complex cast is NYI");
mlir::cir::CastKind CastOpKind;
if (SrcType->isFloatingType())
CastOpKind = mlir::cir::CastKind::float_to_complex;
else if (SrcType->isIntegerType())
CastOpKind = mlir::cir::CastKind::int_to_complex;
else
llvm_unreachable("unexpected src type");

return Builder.createCast(CGF.getLoc(Loc), CastOpKind, Val,
CGF.ConvertType(DestType));
}

mlir::Value ComplexExprEmitter::buildCast(CastKind CK, Expr *Op,
Expand Down Expand Up @@ -480,14 +503,20 @@ mlir::Value ComplexExprEmitter::buildCast(CastKind CK, Expr *Op,
llvm_unreachable("invalid cast kind for complex value");

case CK_FloatingRealToComplex:
case CK_IntegralRealToComplex:
llvm_unreachable("NYI");
case CK_IntegralRealToComplex: {
assert(!MissingFeatures::CGFPOptionsRAII());
return buildScalarToComplexCast(CGF.buildScalarExpr(Op), Op->getType(),
DestTy, Op->getExprLoc());
}

case CK_FloatingComplexCast:
case CK_FloatingComplexToIntegralComplex:
case CK_IntegralComplexCast:
case CK_IntegralComplexToFloatingComplex:
llvm_unreachable("NYI");
case CK_IntegralComplexToFloatingComplex: {
assert(!MissingFeatures::CGFPOptionsRAII());
return buildComplexToComplexCast(Visit(Op), Op->getType(), DestTy,
Op->getExprLoc());
}
}

llvm_unreachable("unknown cast resulting in complex value");
Expand Down
36 changes: 31 additions & 5 deletions clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,9 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
return CGF.buildCheckedLValue(E, TCK);
}

mlir::Value buildComplexToScalarConversion(mlir::Location Loc, mlir::Value V,
CastKind Kind, QualType DestTy);

/// Emit a value that corresponds to null for the given type.
mlir::Value buildNullValue(QualType Ty, mlir::Location loc);

Expand Down Expand Up @@ -1797,13 +1800,13 @@ mlir::Value ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
case CK_MemberPointerToBoolean:
llvm_unreachable("NYI");
case CK_FloatingComplexToReal:
llvm_unreachable("NYI");
case CK_IntegralComplexToReal:
llvm_unreachable("NYI");
case CK_FloatingComplexToBoolean:
llvm_unreachable("NYI");
case CK_IntegralComplexToBoolean:
llvm_unreachable("NYI");
case CK_IntegralComplexToBoolean: {
mlir::Value V = CGF.buildComplexExpr(E);
return buildComplexToScalarConversion(CGF.getLoc(CE->getExprLoc()), V, Kind,
DestTy);
}
case CK_ZeroToOCLOpaqueType:
llvm_unreachable("NYI");
case CK_IntToOCLSampler:
Expand Down Expand Up @@ -2161,6 +2164,29 @@ LValue ScalarExprEmitter::buildCompoundAssignLValue(
return LHSLV;
}

mlir::Value ScalarExprEmitter::buildComplexToScalarConversion(
mlir::Location Loc, mlir::Value V, CastKind Kind, QualType DestTy) {
mlir::cir::CastKind CastOpKind;
switch (Kind) {
case CK_FloatingComplexToReal:
CastOpKind = mlir::cir::CastKind::float_complex_to_real;
break;
case CK_IntegralComplexToReal:
CastOpKind = mlir::cir::CastKind::int_complex_to_real;
break;
case CK_FloatingComplexToBoolean:
CastOpKind = mlir::cir::CastKind::float_complex_to_bool;
break;
case CK_IntegralComplexToBoolean:
CastOpKind = mlir::cir::CastKind::int_complex_to_bool;
break;
default:
llvm_unreachable("invalid complex-to-scalar cast kind");
}

return Builder.createCast(Loc, CastOpKind, V, CGF.ConvertType(DestTy));
}

mlir::Value ScalarExprEmitter::buildNullValue(QualType Ty, mlir::Location loc) {
return CGF.buildFromMemory(CGF.CGM.buildNullConstant(Ty, loc), Ty);
}
Expand Down
Loading
Loading