Skip to content

Commit

Permalink
Add support for immutable strings.
Browse files Browse the repository at this point in the history
We want a way to store strings in a global variable by passing in a
StringRef to an op, but we don't want to generate a setter for it, since
we currently don't inject the builder into the setter. So, add an
immutable string type based on the existing isImmutable option for
attributes.
  • Loading branch information
Thomas Symalla authored and tsymalla-AMD committed May 23, 2024
1 parent ae1b86b commit ed4b46e
Show file tree
Hide file tree
Showing 7 changed files with 156 additions and 21 deletions.
10 changes: 10 additions & 0 deletions example/ExampleDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -317,3 +317,13 @@ def ImmutableOp : Op<ExampleDialect, "immutable.op", [WillReturn]> {
Make an argument immutable
}];
}

def StringAttrOp : Op<ExampleDialect, "string.attr.op", [WillReturn]> {
let results = (outs);
let arguments = (ins ImmutableStringAttr:$val);

let summary = "demonstrate an argument that takes in a StringRef";
let description = [{
The argument should not have a setter method
}];
}
6 changes: 6 additions & 0 deletions example/ExampleMain.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,8 @@ void createFunctionExample(Module &module, const Twine &name) {
moreVarArgs.push_back(b.getInt32(4));
b.create<xd::InstNameConflictVarargsOp>(moreVarArgs, "four.varargs");

b.create<xd::StringAttrOp>("Hello world!");

b.CreateRetVoid();
}

Expand Down Expand Up @@ -242,6 +244,10 @@ template <bool rpot> const Visitor<VisitorContainer> &getExampleVisitor() {
for (Value *arg : op.getArgs())
out << " " << *arg << '\n';
});
b.add<xd::StringAttrOp>(
[](raw_ostream &out, xd::StringAttrOp &op) {
out << "visiting StringAttrOp: " << op.getVal() << '\n';
});
b.add<ReturnInst>([](raw_ostream &out, ReturnInst &ret) {
out << "visiting ReturnInst: " << ret << '\n';
});
Expand Down
9 changes: 9 additions & 0 deletions include/llvm-dialects/Dialect/Dialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,15 @@ def : AttrLlvmType<AttrI16, I16>;
def : AttrLlvmType<AttrI32, I32>;
def : AttrLlvmType<AttrI64, I64>;

def ImmutableStringAttr : Attr<"::llvm::StringRef"> {
let toLlvmValue = [{ $_builder.CreateGlobalString($0) }];
let fromLlvmValue = [{ ::llvm::cast<::llvm::ConstantDataArray>(::llvm::cast<::llvm::GlobalVariable>($0)->getInitializer())->getAsString() }];
let isImmutable = true;
}

// Global string variables are essentially pointers in addrspace(0).
def : AttrLlvmType<ImmutableStringAttr, Ptr>;

// ============================================================================
/// More general attributes
// ============================================================================
Expand Down
120 changes: 101 additions & 19 deletions test/example/generated/ExampleDialect.cpp.inc
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,11 @@ namespace xd {
state.setError();
});

builder.add<StringAttrOp>([](::llvm_dialects::VerifierState &state, StringAttrOp &op) {
if (!op.verifier(state.out()))
state.setError();
});

builder.add<WriteOp>([](::llvm_dialects::VerifierState &state, WriteOp &op) {
if (!op.verifier(state.out()))
state.setError();
Expand All @@ -154,21 +159,21 @@ namespace xd {
::llvm::AttrBuilder attrBuilder{context};
attrBuilder.addAttribute(::llvm::Attribute::NoUnwind);
attrBuilder.addAttribute(::llvm::Attribute::WillReturn);
attrBuilder.addMemoryAttr(::llvm::MemoryEffects(::llvm::ModRefInfo::Ref));
attrBuilder.addMemoryAttr(::llvm::MemoryEffects::none());
m_attributeLists[0] = ::llvm::AttributeList::get(context, ::llvm::AttributeList::FunctionIndex, attrBuilder);
}
{
::llvm::AttrBuilder attrBuilder{context};
attrBuilder.addAttribute(::llvm::Attribute::NoUnwind);
attrBuilder.addAttribute(::llvm::Attribute::WillReturn);
attrBuilder.addMemoryAttr(::llvm::MemoryEffects(::llvm::MemoryEffects::Location::InaccessibleMem, ::llvm::ModRefInfo::Mod));
attrBuilder.addMemoryAttr(::llvm::MemoryEffects(::llvm::ModRefInfo::Ref));
m_attributeLists[1] = ::llvm::AttributeList::get(context, ::llvm::AttributeList::FunctionIndex, attrBuilder);
}
{
::llvm::AttrBuilder attrBuilder{context};
attrBuilder.addAttribute(::llvm::Attribute::NoUnwind);
attrBuilder.addAttribute(::llvm::Attribute::WillReturn);
attrBuilder.addMemoryAttr(::llvm::MemoryEffects::none());
attrBuilder.addMemoryAttr(::llvm::MemoryEffects(::llvm::MemoryEffects::Location::InaccessibleMem, ::llvm::ModRefInfo::Mod));
m_attributeLists[2] = ::llvm::AttributeList::get(context, ::llvm::AttributeList::FunctionIndex, attrBuilder);
}
{
Expand Down Expand Up @@ -329,7 +334,7 @@ return true;


const ::llvm::AttributeList attrs
= ExampleDialect::get(context).getAttributeList(2);
= ExampleDialect::get(context).getAttributeList(0);
auto fnType = ::llvm::FunctionType::get(::llvm::IntegerType::get(context, 32), {
lhs->getType(),
rhs->getType(),
Expand Down Expand Up @@ -451,7 +456,7 @@ uint32_t const extra = getExtra();


const ::llvm::AttributeList attrs
= ExampleDialect::get(context).getAttributeList(2);
= ExampleDialect::get(context).getAttributeList(0);

std::string mangledName =
::llvm_dialects::getMangledName(s_name, {lhs->getType()});
Expand Down Expand Up @@ -546,7 +551,7 @@ rhs


const ::llvm::AttributeList attrs
= ExampleDialect::get(context).getAttributeList(2);
= ExampleDialect::get(context).getAttributeList(0);

std::string mangledName =
::llvm_dialects::getMangledName(s_name, {::llvm::cast<XdVectorType>(vector->getType())->getElementType()});
Expand Down Expand Up @@ -650,7 +655,7 @@ index


const ::llvm::AttributeList attrs
= ExampleDialect::get(context).getAttributeList(2);
= ExampleDialect::get(context).getAttributeList(0);

std::string mangledName =
::llvm_dialects::getMangledName(s_name, {resultType});
Expand Down Expand Up @@ -820,7 +825,7 @@ source


const ::llvm::AttributeList attrs
= ExampleDialect::get(context).getAttributeList(2);
= ExampleDialect::get(context).getAttributeList(0);
auto fnType = ::llvm::FunctionType::get(XdHandleType::get(context), {
}, false);

Expand Down Expand Up @@ -882,7 +887,7 @@ source


const ::llvm::AttributeList attrs
= ExampleDialect::get(context).getAttributeList(2);
= ExampleDialect::get(context).getAttributeList(0);

std::string mangledName =
::llvm_dialects::getMangledName(s_name, {resultType});
Expand Down Expand Up @@ -980,7 +985,7 @@ source


const ::llvm::AttributeList attrs
= ExampleDialect::get(context).getAttributeList(2);
= ExampleDialect::get(context).getAttributeList(0);

std::string mangledName =
::llvm_dialects::getMangledName(s_name, {resultType});
Expand Down Expand Up @@ -1113,7 +1118,7 @@ source
(void)context;

using ::llvm_dialects::printable;

if (arg_size() != 1) {
errs << " wrong number of arguments: " << arg_size()
<< ", expected 1\n";
Expand Down Expand Up @@ -1147,7 +1152,7 @@ source


const ::llvm::AttributeList attrs
= ExampleDialect::get(context).getAttributeList(2);
= ExampleDialect::get(context).getAttributeList(0);

std::string mangledName =
::llvm_dialects::getMangledName(s_name, {vector->getType()});
Expand Down Expand Up @@ -1607,7 +1612,7 @@ instName


const ::llvm::AttributeList attrs
= ExampleDialect::get(context).getAttributeList(1);
= ExampleDialect::get(context).getAttributeList(2);
auto fnType = ::llvm::FunctionType::get(::llvm::Type::getVoidTy(context), true);

auto fn = module.getOrInsertFunction(s_name, fnType, attrs);
Expand Down Expand Up @@ -1670,7 +1675,7 @@ data


const ::llvm::AttributeList attrs
= ExampleDialect::get(context).getAttributeList(2);
= ExampleDialect::get(context).getAttributeList(0);
auto fnType = ::llvm::FunctionType::get(::llvm::IntegerType::get(context, 64), true);

auto fn = module.getOrInsertFunction(s_name, fnType, attrs);
Expand Down Expand Up @@ -1744,7 +1749,7 @@ data


const ::llvm::AttributeList attrs
= ExampleDialect::get(context).getAttributeList(0);
= ExampleDialect::get(context).getAttributeList(1);

std::string mangledName =
::llvm_dialects::getMangledName(s_name, {initial->getType()});
Expand Down Expand Up @@ -1836,7 +1841,7 @@ initial


const ::llvm::AttributeList attrs
= ExampleDialect::get(context).getAttributeList(0);
= ExampleDialect::get(context).getAttributeList(1);

std::string mangledName =
::llvm_dialects::getMangledName(s_name, {initial->getType()});
Expand Down Expand Up @@ -1928,7 +1933,7 @@ initial


const ::llvm::AttributeList attrs
= ExampleDialect::get(context).getAttributeList(0);
= ExampleDialect::get(context).getAttributeList(1);

std::string mangledName =
::llvm_dialects::getMangledName(s_name, {initial->getType()});
Expand Down Expand Up @@ -2011,6 +2016,75 @@ initial



const ::llvm::StringLiteral StringAttrOp::s_name{"xd.string.attr.op"};

StringAttrOp* StringAttrOp::create(llvm_dialects::Builder& b, ::llvm::StringRef val, const llvm::Twine &instName) {
::llvm::LLVMContext& context = b.getContext();
(void)context;
::llvm::Module& module = *b.GetInsertBlock()->getModule();


const ::llvm::AttributeList attrs
= ExampleDialect::get(context).getAttributeList(4);
auto fnType = ::llvm::FunctionType::get(::llvm::Type::getVoidTy(context), {
::llvm::PointerType::get(::llvm::Type::getInt8Ty(context), 0),
}, false);

auto fn = module.getOrInsertFunction(s_name, fnType, attrs);
::llvm::SmallString<32> newName;
for (unsigned i = 0; !::llvm::isa<::llvm::Function>(fn.getCallee()) ||
::llvm::cast<::llvm::Function>(fn.getCallee())->getFunctionType() != fn.getFunctionType(); i++) {
// If a function with the same name but a different types already exists,
// we get a bitcast of a function or a function with the wrong type.
// Try new names until we get one with the correct type.
newName = "";
::llvm::raw_svector_ostream newNameStream(newName);
newNameStream << s_name << "_" << i;
fn = module.getOrInsertFunction(newNameStream.str(), fnType, attrs);
}
assert(::llvm::isa<::llvm::Function>(fn.getCallee()));
assert(fn.getFunctionType() == fnType);
assert(::llvm::cast<::llvm::Function>(fn.getCallee())->getFunctionType() == fn.getFunctionType());


::llvm::SmallVector<::llvm::Value*, 1> args = {
b.CreateGlobalString(val)
};

return ::llvm::cast<StringAttrOp>(b.CreateCall(fn, args, instName));
}


bool StringAttrOp::verifier(::llvm::raw_ostream &errs) {
::llvm::LLVMContext &context = getModule()->getContext();
(void)context;

using ::llvm_dialects::printable;

if (arg_size() != 1) {
errs << " wrong number of arguments: " << arg_size()
<< ", expected 1\n";
return false;
}

if (getArgOperand(0)->getType() != ::llvm::PointerType::get(::llvm::Type::getInt8Ty(context), 0)) {
errs << " argument 0 (val) has type: "
<< *getArgOperand(0)->getType() << '\n';
errs << " expected: " << *::llvm::PointerType::get(::llvm::Type::getInt8Ty(context), 0) << '\n';
return false;
}
::llvm::StringRef const val = getVal();
(void)val;
return true;
}


::llvm::StringRef StringAttrOp::getVal() {
return ::llvm::cast<::llvm::ConstantDataArray>(::llvm::cast<::llvm::GlobalVariable>(getArgOperand(0))->getInitializer())->getAsString() ;
}



const ::llvm::StringLiteral WriteOp::s_name{"xd.write"};

WriteOp* WriteOp::create(llvm_dialects::Builder& b, ::llvm::Value * data, const llvm::Twine &instName) {
Expand All @@ -2020,7 +2094,7 @@ initial


const ::llvm::AttributeList attrs
= ExampleDialect::get(context).getAttributeList(1);
= ExampleDialect::get(context).getAttributeList(2);
auto fnType = ::llvm::FunctionType::get(::llvm::Type::getVoidTy(context), true);

auto fn = module.getOrInsertFunction(s_name, fnType, attrs);
Expand Down Expand Up @@ -2083,7 +2157,7 @@ data


const ::llvm::AttributeList attrs
= ExampleDialect::get(context).getAttributeList(1);
= ExampleDialect::get(context).getAttributeList(2);
auto fnType = ::llvm::FunctionType::get(::llvm::Type::getVoidTy(context), true);

auto fn = module.getOrInsertFunction(s_name, fnType, attrs);
Expand Down Expand Up @@ -2297,6 +2371,14 @@ data
}


template <>
const ::llvm_dialects::OpDescription &
::llvm_dialects::OpDescription::get<xd::StringAttrOp>() {
static const ::llvm_dialects::OpDescription desc{false, "xd.string.attr.op"};
return desc;
}


template <>
const ::llvm_dialects::OpDescription &
::llvm_dialects::OpDescription::get<xd::WriteOp>() {
Expand Down
20 changes: 20 additions & 0 deletions test/example/generated/ExampleDialect.h.inc
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,26 @@ bool verifier(::llvm::raw_ostream &errs);
::llvm::Value * getResult();


};

class StringAttrOp : public ::llvm::CallInst {
static const ::llvm::StringLiteral s_name; //{"xd.string.attr.op"};

public:
static bool classof(const ::llvm::CallInst* i) {
return ::llvm_dialects::detail::isSimpleOperation(i, s_name);
}
static bool classof(const ::llvm::Value* v) {
return ::llvm::isa<::llvm::CallInst>(v) &&
classof(::llvm::cast<::llvm::CallInst>(v));
}
static StringAttrOp* create(::llvm_dialects::Builder& b, ::llvm::StringRef val, const llvm::Twine &instName = "");

bool verifier(::llvm::raw_ostream &errs);

::llvm::StringRef getVal();


};

class WriteOp : public ::llvm::CallInst {
Expand Down
6 changes: 5 additions & 1 deletion test/example/test-builder.test
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --tool llvm-dialects-example --include-generated-funcs
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --tool llvm-dialects-example --include-generated-funcs --check-globals
; NOTE: stdin isn't used by the example program, but the redirect makes the UTC tool happy.
; RUN: llvm-dialects-example - | FileCheck --check-prefixes=CHECK %s

;.
; CHECK: @[[GLOB0:[0-9]+]] = private unnamed_addr constant [13 x i8] c"Hello world!\00", align 1
;.
; CHECK-LABEL: @example(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[TMP0:%.*]] = call i32 @xd.read__i32()
Expand Down Expand Up @@ -42,5 +45,6 @@
; CHECK-NEXT: [[TWO_VARARGS:%.*]] = call i32 (...) @xd.inst.name.conflict.varargs(ptr [[P1]], i8 [[P2]])
; CHECK-NEXT: [[THREE_VARARGS:%.*]] = call i32 (...) @xd.inst.name.conflict.varargs(ptr [[P1]], i8 [[P2]], i32 3)
; CHECK-NEXT: [[FOUR_VARARGS:%.*]] = call i32 (...) @xd.inst.name.conflict.varargs(ptr [[P1]], i8 [[P2]], i32 3, i32 4)
; CHECK-NEXT: call void @xd.string.attr.op(ptr @[[GLOB0:[0-9]+]])
; CHECK-NEXT: ret void
;
6 changes: 5 additions & 1 deletion test/example/visitor-basic.ll
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,13 @@
; DEFAULT-NEXT: %v2 =
; DEFAULT-NEXT: %q =
; DEFAULT-NEXT: visiting umin (set): %vm = call i32 @llvm.umin.i32(i32 %v1, i32 %q)
; DEFAULT-NEXT: visiting StringAttrOp: Hello world!
; DEFAULT-NEXT: visiting Ret (set): ret void
; DEFAULT-NEXT: visiting ReturnInst: ret void
; DEFAULT-NEXT: inner.counter = 1

@0 = private unnamed_addr constant [13 x i8] c"Hello world!\00", align 1

define void @test1(ptr %p) {
entry:
%v = call i32 @xd.read__i32()
Expand All @@ -36,6 +39,7 @@ entry:
call void (...) @xd.set.write(i8 %v.2)
call void (...) @xd.write.vararg(i8 %t, i32 %v2, i32 %q)
%vm = call i32 @llvm.umin.i32(i32 %v1, i32 %q)
call void @xd.string.attr.op(ptr @0)
ret void
}

Expand All @@ -46,6 +50,6 @@ declare void @xd.write(...)
declare void @xd.set.write(...)
declare void @xd.write.vararg(...)
declare i8 @xd.itrunc__i8(...)

declare void @xd.string.attr.op(ptr)
declare i32 @llvm.umax.i32(i32, i32)
declare i32 @llvm.umin.i32(i32, i32)

0 comments on commit ed4b46e

Please sign in to comment.