Skip to content

Commit 5ff3403

Browse files
seven-milelanza
authored andcommitted
[CIR][Dialect] Make addrspace in pointer types to model LangAS (llvm#692)
This PR implements the solution B as discussed in llvm#682. * Use the syntax `cir.ptr<T>` `cir.ptr<T, addrspace(target<0>)` `cir.ptr<T, addrspace(opencl_private)>` * Add a new `AddressSpaceAttr`, which is used as the new type of addrspace parameter in `PointerType` * `AddressSpaceAttr` itself takes one single `int64_t $value` as the parameter * TableGen templates to generate the conversion between `clang::LangAS -> int64_t $value <-> text-form CIR`
1 parent 2c3b8d9 commit 5ff3403

17 files changed

+341
-57
lines changed

clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h

+10-6
Original file line numberDiff line numberDiff line change
@@ -75,14 +75,18 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
7575
return mlir::cir::IntType::get(getContext(), N, true);
7676
}
7777

78-
mlir::cir::PointerType getPointerTo(mlir::Type ty,
79-
unsigned addressSpace = 0) {
80-
assert(!addressSpace && "address space is NYI");
81-
return mlir::cir::PointerType::get(getContext(), ty);
78+
mlir::cir::PointerType
79+
getPointerTo(mlir::Type ty, clang::LangAS langAS = clang::LangAS::Default) {
80+
mlir::cir::AddressSpaceAttr addrSpaceAttr;
81+
if (langAS != clang::LangAS::Default)
82+
addrSpaceAttr = mlir::cir::AddressSpaceAttr::get(getContext(), langAS);
83+
84+
return mlir::cir::PointerType::get(getContext(), ty, addrSpaceAttr);
8285
}
8386

84-
mlir::cir::PointerType getVoidPtrTy(unsigned addressSpace = 0) {
85-
return getPointerTo(::mlir::cir::VoidType::get(getContext()), addressSpace);
87+
mlir::cir::PointerType
88+
getVoidPtrTy(clang::LangAS langAS = clang::LangAS::Default) {
89+
return getPointerTo(::mlir::cir::VoidType::get(getContext()), langAS);
8690
}
8791

8892
mlir::Value createLoad(mlir::Location loc, mlir::Value ptr,

clang/include/clang/CIR/Dialect/IR/CIRAttrs.td

+137
Original file line numberDiff line numberDiff line change
@@ -617,6 +617,143 @@ def DynamicCastInfoAttr
617617
}];
618618
}
619619

620+
//===----------------------------------------------------------------------===//
621+
// AddressSpaceAttr
622+
//===----------------------------------------------------------------------===//
623+
624+
// TODO: other CIR AS cases
625+
def AS_Target : I32EnumAttrCase<"target", 21>;
626+
627+
def AddressSpaceAttr : CIR_Attr<"AddressSpace", "addrspace"> {
628+
629+
let summary = "Address space attribute for pointer types";
630+
let description = [{
631+
The address space attribute models `clang::LangAS` rather than the LLVM
632+
address space, which means it's not yet converted by the address space map
633+
to carry target-specific semantics.
634+
635+
The representation is one-to-one except for `LangAS::Default`, which
636+
corresponds to a null attribute instead.
637+
}];
638+
639+
let parameters = (ins "int32_t":$value);
640+
641+
let assemblyFormat = [{
642+
`<` $value `>`
643+
}];
644+
645+
let builders = [
646+
AttrBuilder<(ins "clang::LangAS":$langAS), [{
647+
assert(langAS != clang::LangAS::Default &&
648+
"Default address space is encoded as null attribute");
649+
return $_get($_ctxt, getValueFromLangAS(langAS).value());
650+
}]>
651+
];
652+
653+
let cppNamespace = "::mlir::cir";
654+
655+
// The following codes implement these conversions:
656+
// clang::LangAS -> int32_t <-> text-form CIR
657+
658+
// CIR_PointerType manipulates the parse- and stringify- methods to provide
659+
// simplified assembly format `custom<PointerAddrSpace>`.
660+
661+
list<I32EnumAttrCase> langASCases = [
662+
// TODO: includes all non-target CIR AS cases here
663+
];
664+
665+
I32EnumAttrCase targetASCase = AS_Target;
666+
667+
let extraClassDeclaration = [{
668+
static constexpr char kTargetKeyword[] = "}]#targetASCase.symbol#[{";
669+
static constexpr int32_t kFirstTargetASValue = }]#targetASCase.value#[{;
670+
671+
bool isLang() const;
672+
bool isTarget() const;
673+
unsigned getTargetValue() const;
674+
675+
static std::optional<int32_t> parseValueFromString(llvm::StringRef s);
676+
static std::optional<int32_t> getValueFromLangAS(clang::LangAS v);
677+
static std::optional<llvm::StringRef> stringifyValue(int32_t v);
678+
}];
679+
680+
let extraClassDefinition = [{
681+
bool $cppClass::isLang() const {
682+
return !isTarget();
683+
}
684+
685+
bool $cppClass::isTarget() const {
686+
return getValue() >= kFirstTargetASValue;
687+
}
688+
689+
unsigned $cppClass::getTargetValue() const {
690+
assert(isTarget() && "Not a target address space");
691+
return getValue() - kFirstTargetASValue;
692+
}
693+
694+
std::optional<int32_t>
695+
$cppClass::parseValueFromString(llvm::StringRef str) {
696+
return llvm::StringSwitch<::std::optional<int32_t>>(str)
697+
}]
698+
#
699+
!interleave(
700+
!foreach(case, langASCases,
701+
".Case(\""#case.symbol# "\", "#case.value # ")\n"
702+
),
703+
"\n"
704+
)
705+
#
706+
[{
707+
// Target address spaces are not parsed here
708+
.Default(std::nullopt);
709+
}
710+
711+
std::optional<llvm::StringRef>
712+
$cppClass::stringifyValue(int32_t value) {
713+
switch (value) {
714+
}]
715+
#
716+
!interleave(
717+
!foreach(case, langASCases,
718+
"case "#case.value
719+
# ": return \""#case.symbol # "\";" ),
720+
"\n"
721+
)
722+
#
723+
[{
724+
default:
725+
// Target address spaces are not processed here
726+
return std::nullopt;
727+
}
728+
}
729+
730+
std::optional<int32_t>
731+
$cppClass::getValueFromLangAS(clang::LangAS langAS) {
732+
assert((langAS == clang::LangAS::Default ||
733+
clang::isTargetAddressSpace(langAS)) &&
734+
"Language-specific address spaces are not supported");
735+
switch (langAS) {
736+
}]
737+
#
738+
!interleave(
739+
!foreach(case, langASCases,
740+
"case clang::LangAS::"#case.symbol
741+
# [{: llvm_unreachable("Not Yet Supported");}] ),
742+
"\n"
743+
)
744+
#
745+
[{
746+
case clang::LangAS::Default:
747+
// Default address space should be encoded as a null attribute.
748+
return std::nullopt;
749+
default:
750+
// Target address space offset arithmetics
751+
return clang::toTargetAddressSpace(langAS) + kFirstTargetASValue;
752+
}
753+
}
754+
}];
755+
}
756+
620757
//===----------------------------------------------------------------------===//
621758
// AST Wrappers
622759
//===----------------------------------------------------------------------===//

clang/include/clang/CIR/Dialect/IR/CIRTypes.h

+2
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020

2121
#include "clang/CIR/Interfaces/ASTAttrInterfaces.h"
2222

23+
#include "clang/CIR/Dialect/IR/CIROpsEnums.h"
24+
2325
//===----------------------------------------------------------------------===//
2426
// CIR StructType
2527
//

clang/include/clang/CIR/Dialect/IR/CIRTypes.td

+18-8
Original file line numberDiff line numberDiff line change
@@ -208,24 +208,34 @@ def CIR_PointerType : CIR_Type<"Pointer", "ptr",
208208
`CIR.ptr` is a type returned by any op generating a pointer in C++.
209209
}];
210210

211-
let parameters = (ins "mlir::Type":$pointee,
212-
DefaultValuedParameter<"unsigned", "0">:$addrSpace);
211+
let parameters = (ins
212+
"mlir::Type":$pointee,
213+
// FIXME(cir): Currently unable to directly use AddressSpaceAttr because of
214+
// cyclic dep. Workaround with the top type and verifier.
215+
OptionalParameter<"mlir::Attribute">:$addrSpace
216+
);
213217

214218
let builders = [
215219
TypeBuilderWithInferredContext<(ins
216-
"mlir::Type":$pointee, CArg<"unsigned", "0">:$addrSpace), [{
217-
return Base::get(pointee.getContext(), pointee, addrSpace);
220+
"mlir::Type":$pointee,
221+
CArg<"mlir::Attribute", "{}">:$addrSpace), [{
222+
return $_get(pointee.getContext(), pointee, addrSpace);
218223
}]>,
219224
TypeBuilder<(ins
220-
"mlir::Type":$pointee, CArg<"unsigned", "0">:$addrSpace), [{
221-
return Base::get($_ctxt, pointee, addrSpace);
222-
}]>,
225+
"mlir::Type":$pointee,
226+
CArg<"mlir::Attribute", "{}">:$addrSpace), [{
227+
return $_get($_ctxt, pointee, addrSpace);
228+
}]>
223229
];
224230

225231
let assemblyFormat = [{
226-
`<` $pointee ( `,` `addrspace` `(` $addrSpace^ `)` )? `>`
232+
`<` $pointee ( `,` `addrspace` `(`
233+
custom<PointerAddrSpace>($addrSpace)^
234+
`)` )? `>`
227235
}];
228236

237+
let genVerifyDecl = 1;
238+
229239
let skipDefaultBuilders = 1;
230240

231241
let extraClassDeclaration = [{

clang/include/clang/CIR/MissingFeatures.h

+1
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ struct MissingFeatures {
156156
static bool constantFoldsToSimpleInteger() { return false; }
157157
static bool checkFunctionCallABI() { return false; }
158158
static bool zeroInitializer() { return false; }
159+
static bool targetLoweringInfoAddressSpaceMap() { return false; }
159160
static bool targetCodeGenInfoIsProtoCallVariadic() { return false; }
160161
static bool targetCodeGenInfoGetNullPointer() { return false; }
161162
static bool operandBundles() { return false; }

clang/lib/CIR/CodeGen/CIRGenTypes.cpp

+2-6
Original file line numberDiff line numberDiff line change
@@ -605,9 +605,7 @@ mlir::Type CIRGenTypes::ConvertType(QualType T) {
605605
const ReferenceType *RTy = cast<ReferenceType>(Ty);
606606
QualType ETy = RTy->getPointeeType();
607607
auto PointeeType = convertTypeForMem(ETy);
608-
ResultType = ::mlir::cir::PointerType::get(
609-
Builder.getContext(), PointeeType,
610-
Context.getTargetAddressSpace(ETy.getAddressSpace()));
608+
ResultType = Builder.getPointerTo(PointeeType, ETy.getAddressSpace());
611609
assert(ResultType && "Cannot get pointer type?");
612610
break;
613611
}
@@ -622,9 +620,7 @@ mlir::Type CIRGenTypes::ConvertType(QualType T) {
622620
// if (PointeeType->isVoidTy())
623621
// PointeeType = Builder.getI8Type();
624622

625-
ResultType = ::mlir::cir::PointerType::get(
626-
Builder.getContext(), PointeeType,
627-
Context.getTargetAddressSpace(ETy.getAddressSpace()));
623+
ResultType = Builder.getPointerTo(PointeeType, ETy.getAddressSpace());
628624
assert(ResultType && "Cannot get pointer type?");
629625
break;
630626
}

clang/lib/CIR/Dialect/IR/CIRTypes.cpp

+64
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,11 @@ parseFuncTypeArgs(mlir::AsmParser &p, llvm::SmallVector<mlir::Type> &params,
4545
static void printFuncTypeArgs(mlir::AsmPrinter &p,
4646
mlir::ArrayRef<mlir::Type> params, bool isVarArg);
4747

48+
static mlir::ParseResult parsePointerAddrSpace(mlir::AsmParser &p,
49+
mlir::Attribute &addrSpaceAttr);
50+
static void printPointerAddrSpace(mlir::AsmPrinter &p,
51+
mlir::Attribute addrSpaceAttr);
52+
4853
//===----------------------------------------------------------------------===//
4954
// Get autogenerated stuff
5055
//===----------------------------------------------------------------------===//
@@ -872,6 +877,65 @@ llvm::ArrayRef<mlir::Type> FuncType::getReturnTypes() const {
872877

873878
bool FuncType::isVoid() const { return mlir::isa<VoidType>(getReturnType()); }
874879

880+
//===----------------------------------------------------------------------===//
881+
// PointerType Definitions
882+
//===----------------------------------------------------------------------===//
883+
884+
mlir::LogicalResult
885+
PointerType::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
886+
mlir::Type pointee, mlir::Attribute addrSpace) {
887+
if (addrSpace && !mlir::isa<mlir::cir::AddressSpaceAttr>(addrSpace)) {
888+
emitError() << "unexpected addrspace attribute type";
889+
return mlir::failure();
890+
}
891+
return mlir::success();
892+
}
893+
894+
mlir::ParseResult parsePointerAddrSpace(mlir::AsmParser &p,
895+
mlir::Attribute &addrSpaceAttr) {
896+
using mlir::cir::AddressSpaceAttr;
897+
auto attrLoc = p.getCurrentLocation();
898+
899+
llvm::StringRef addrSpaceKind;
900+
if (mlir::failed(p.parseOptionalKeyword(&addrSpaceKind))) {
901+
p.emitError(attrLoc, "expected keyword for addrspace kind");
902+
return mlir::failure();
903+
}
904+
905+
if (addrSpaceKind == AddressSpaceAttr::kTargetKeyword) {
906+
int64_t targetValue = -1;
907+
if (p.parseLess() || p.parseInteger(targetValue) || p.parseGreater()) {
908+
return mlir::failure();
909+
}
910+
addrSpaceAttr = AddressSpaceAttr::get(
911+
p.getContext(), AddressSpaceAttr::kFirstTargetASValue + targetValue);
912+
} else {
913+
std::optional<int64_t> value =
914+
AddressSpaceAttr::parseValueFromString(addrSpaceKind);
915+
// not target AS, must be wrong keyword if no value
916+
if (!value.has_value()) {
917+
p.emitError(attrLoc, "invalid addrspace kind keyword: " + addrSpaceKind);
918+
return mlir::failure();
919+
}
920+
921+
addrSpaceAttr = AddressSpaceAttr::get(p.getContext(), *value);
922+
}
923+
924+
return mlir::success();
925+
}
926+
927+
void printPointerAddrSpace(mlir::AsmPrinter &p,
928+
mlir::Attribute rawAddrSpaceAttr) {
929+
using mlir::cir::AddressSpaceAttr;
930+
auto addrSpaceAttr = mlir::cast<AddressSpaceAttr>(rawAddrSpaceAttr);
931+
if (addrSpaceAttr.isTarget()) {
932+
p << AddressSpaceAttr::kTargetKeyword << "<"
933+
<< addrSpaceAttr.getTargetValue() << ">";
934+
} else {
935+
p << AddressSpaceAttr::stringifyValue(addrSpaceAttr.getValue());
936+
}
937+
}
938+
875939
//===----------------------------------------------------------------------===//
876940
// CIR Dialect
877941
//===----------------------------------------------------------------------===//

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

+13-1
Original file line numberDiff line numberDiff line change
@@ -3337,8 +3337,20 @@ void prepareTypeConverter(mlir::LLVMTypeConverter &converter,
33373337
mlir::DataLayout &dataLayout) {
33383338
converter.addConversion([&](mlir::cir::PointerType type) -> mlir::Type {
33393339
// Drop pointee type since LLVM dialect only allows opaque pointers.
3340+
3341+
auto addrSpace =
3342+
mlir::cast_if_present<mlir::cir::AddressSpaceAttr>(type.getAddrSpace());
3343+
// null addrspace attribute indicates the default addrspace
3344+
if (!addrSpace)
3345+
return mlir::LLVM::LLVMPointerType::get(type.getContext());
3346+
3347+
// TODO(cir): Query the target-specific address space map to lower other ASs
3348+
// like `opencl_private`.
3349+
assert(!MissingFeatures::targetLoweringInfoAddressSpaceMap());
3350+
assert(addrSpace.isTarget() && "NYI");
3351+
33403352
return mlir::LLVM::LLVMPointerType::get(type.getContext(),
3341-
type.getAddrSpace());
3353+
addrSpace.getTargetValue());
33423354
});
33433355
converter.addConversion([&](mlir::cir::DataMemberType type) -> mlir::Type {
33443356
return mlir::IntegerType::get(type.getContext(),

clang/test/CIR/CodeGen/OpenCL/addrspace-alloca.cl

+7-4
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,25 @@
33
// RUN: %clang_cc1 -cl-std=CL3.0 -O0 -fclangir -emit-llvm -triple spirv64-unknown-unknown %s -o %t.ll
44
// RUN: FileCheck --input-file=%t.ll %s --check-prefix=LLVM
55

6-
// CIR: cir.func @func(%arg0: !cir.ptr<!s32i, addrspace(3)>
6+
// Lowering of language-specific AS not supported
7+
// XFAIL: *
8+
9+
// CIR: cir.func @func(%arg0: !cir.ptr<!s32i, addrspace(target<3>)>
710
// LLVM: @func(ptr addrspace(3)
811
kernel void func(local int *p) {
9-
// CIR-NEXT: %[[#ALLOCA_P:]] = cir.alloca !cir.ptr<!s32i, addrspace(3)>, !cir.ptr<!cir.ptr<!s32i, addrspace(3)>>, ["p", init] {alignment = 8 : i64}
12+
// CIR-NEXT: %[[#ALLOCA_P:]] = cir.alloca !cir.ptr<!s32i, addrspace(target<3>)>, !cir.ptr<!cir.ptr<!s32i, addrspace(target<3>)>>, ["p", init] {alignment = 8 : i64}
1013
// LLVM-NEXT: %[[#ALLOCA_P:]] = alloca ptr addrspace(3), i64 1, align 8
1114

1215
int x;
1316
// CIR-NEXT: %[[#ALLOCA_X:]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["x"] {alignment = 4 : i64}
1417
// LLVM-NEXT: %[[#ALLOCA_X:]] = alloca i32, i64 1, align 4
1518

1619
global char *b;
17-
// CIR-NEXT: %[[#ALLOCA_B:]] = cir.alloca !cir.ptr<!s8i, addrspace(1)>, !cir.ptr<!cir.ptr<!s8i, addrspace(1)>>, ["b"] {alignment = 8 : i64}
20+
// CIR-NEXT: %[[#ALLOCA_B:]] = cir.alloca !cir.ptr<!s8i, addrspace(target<1>)>, !cir.ptr<!cir.ptr<!s8i, addrspace(target<1>)>>, ["b"] {alignment = 8 : i64}
1821
// LLVM-NEXT: %[[#ALLOCA_B:]] = alloca ptr addrspace(1), i64 1, align 8
1922

2023
// Store of the argument `p`
21-
// CIR-NEXT: cir.store %arg0, %[[#ALLOCA_P]] : !cir.ptr<!s32i, addrspace(3)>, !cir.ptr<!cir.ptr<!s32i, addrspace(3)>>
24+
// CIR-NEXT: cir.store %arg0, %[[#ALLOCA_P]] : !cir.ptr<!s32i, addrspace(target<3>)>, !cir.ptr<!cir.ptr<!s32i, addrspace(target<3>)>>
2225
// LLVM-NEXT: store ptr addrspace(3) %{{[0-9]+}}, ptr %[[#ALLOCA_P]], align 8
2326

2427
return;

clang/test/CIR/CodeGen/OpenCL/spirv-target.cl

+3
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44
// RUN: %clang_cc1 -cl-std=CL3.0 -fclangir -emit-llvm -triple spirv64-unknown-unknown %s -o %t_64.ll
55
// RUN: FileCheck --input-file=%t_64.ll %s --check-prefix=LLVM-SPIRV64
66

7+
// Lowering of language-specific AS not supported
8+
// XFAIL: *
9+
710
// CIR-SPIRV64: cir.triple = "spirv64-unknown-unknown"
811
// LLVM-SPIRV64: target triple = "spirv64-unknown-unknown"
912

0 commit comments

Comments
 (0)