Skip to content

Commit

Permalink
[NFC][MLIR][TableGen] Eliminate llvm:: for commonly used types (llv…
Browse files Browse the repository at this point in the history
…m#112456)

Eliminate `llvm::` namespace qualifier for commonly used types in MLIR
TableGen backends to reduce code clutter.
  • Loading branch information
jurahul authored Oct 18, 2024
1 parent 6e02e19 commit 659192b
Show file tree
Hide file tree
Showing 9 changed files with 193 additions and 177 deletions.
81 changes: 41 additions & 40 deletions mlir/lib/TableGen/AttrOrTypeDef.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@

using namespace mlir;
using namespace mlir::tblgen;
using llvm::DefInit;
using llvm::Init;
using llvm::ListInit;
using llvm::Record;
using llvm::RecordVal;
using llvm::StringInit;

//===----------------------------------------------------------------------===//
// AttrOrTypeBuilder
Expand All @@ -35,14 +41,13 @@ bool AttrOrTypeBuilder::hasInferredContextParameter() const {
// AttrOrTypeDef
//===----------------------------------------------------------------------===//

AttrOrTypeDef::AttrOrTypeDef(const llvm::Record *def) : def(def) {
AttrOrTypeDef::AttrOrTypeDef(const Record *def) : def(def) {
// Populate the builders.
auto *builderList =
dyn_cast_or_null<llvm::ListInit>(def->getValueInit("builders"));
const auto *builderList =
dyn_cast_or_null<ListInit>(def->getValueInit("builders"));
if (builderList && !builderList->empty()) {
for (const llvm::Init *init : builderList->getValues()) {
AttrOrTypeBuilder builder(cast<llvm::DefInit>(init)->getDef(),
def->getLoc());
for (const Init *init : builderList->getValues()) {
AttrOrTypeBuilder builder(cast<DefInit>(init)->getDef(), def->getLoc());

// Ensure that all parameters have names.
for (const AttrOrTypeBuilder::Parameter &param :
Expand All @@ -56,16 +61,16 @@ AttrOrTypeDef::AttrOrTypeDef(const llvm::Record *def) : def(def) {

// Populate the traits.
if (auto *traitList = def->getValueAsListInit("traits")) {
SmallPtrSet<const llvm::Init *, 32> traitSet;
SmallPtrSet<const Init *, 32> traitSet;
traits.reserve(traitSet.size());
llvm::unique_function<void(const llvm::ListInit *)> processTraitList =
[&](const llvm::ListInit *traitList) {
llvm::unique_function<void(const ListInit *)> processTraitList =
[&](const ListInit *traitList) {
for (auto *traitInit : *traitList) {
if (!traitSet.insert(traitInit).second)
continue;

// If this is an interface, add any bases to the trait list.
auto *traitDef = cast<llvm::DefInit>(traitInit)->getDef();
auto *traitDef = cast<DefInit>(traitInit)->getDef();
if (traitDef->isSubClassOf("Interface")) {
if (auto *bases = traitDef->getValueAsListInit("baseInterfaces"))
processTraitList(bases);
Expand Down Expand Up @@ -111,7 +116,7 @@ AttrOrTypeDef::AttrOrTypeDef(const llvm::Record *def) : def(def) {
}

Dialect AttrOrTypeDef::getDialect() const {
auto *dialect = dyn_cast<llvm::DefInit>(def->getValue("dialect")->getValue());
const auto *dialect = dyn_cast<DefInit>(def->getValue("dialect")->getValue());
return Dialect(dialect ? dialect->getDef() : nullptr);
}

Expand All @@ -126,17 +131,17 @@ StringRef AttrOrTypeDef::getCppBaseClassName() const {
}

bool AttrOrTypeDef::hasDescription() const {
const llvm::RecordVal *desc = def->getValue("description");
return desc && isa<llvm::StringInit>(desc->getValue());
const RecordVal *desc = def->getValue("description");
return desc && isa<StringInit>(desc->getValue());
}

StringRef AttrOrTypeDef::getDescription() const {
return def->getValueAsString("description");
}

bool AttrOrTypeDef::hasSummary() const {
const llvm::RecordVal *summary = def->getValue("summary");
return summary && isa<llvm::StringInit>(summary->getValue());
const RecordVal *summary = def->getValue("summary");
return summary && isa<StringInit>(summary->getValue());
}

StringRef AttrOrTypeDef::getSummary() const {
Expand Down Expand Up @@ -249,9 +254,9 @@ StringRef TypeDef::getTypeName() const {
template <typename InitT>
auto AttrOrTypeParameter::getDefValue(StringRef name) const {
std::optional<decltype(std::declval<InitT>().getValue())> result;
if (auto *param = dyn_cast<llvm::DefInit>(getDef()))
if (auto *init = param->getDef()->getValue(name))
if (auto *value = dyn_cast_or_null<InitT>(init->getValue()))
if (const auto *param = dyn_cast<DefInit>(getDef()))
if (const auto *init = param->getDef()->getValue(name))
if (const auto *value = dyn_cast_or_null<InitT>(init->getValue()))
result = value->getValue();
return result;
}
Expand All @@ -270,20 +275,20 @@ std::string AttrOrTypeParameter::getAccessorName() const {
}

std::optional<StringRef> AttrOrTypeParameter::getAllocator() const {
return getDefValue<llvm::StringInit>("allocator");
return getDefValue<StringInit>("allocator");
}

StringRef AttrOrTypeParameter::getComparator() const {
return getDefValue<llvm::StringInit>("comparator").value_or("$_lhs == $_rhs");
return getDefValue<StringInit>("comparator").value_or("$_lhs == $_rhs");
}

StringRef AttrOrTypeParameter::getCppType() const {
if (auto *stringType = dyn_cast<llvm::StringInit>(getDef()))
if (auto *stringType = dyn_cast<StringInit>(getDef()))
return stringType->getValue();
auto cppType = getDefValue<llvm::StringInit>("cppType");
auto cppType = getDefValue<StringInit>("cppType");
if (cppType)
return *cppType;
if (auto *init = dyn_cast<llvm::DefInit>(getDef()))
if (const auto *init = dyn_cast<DefInit>(getDef()))
llvm::PrintFatalError(
init->getDef()->getLoc(),
Twine("Missing `cppType` field in Attribute/Type parameter: ") +
Expand All @@ -295,52 +300,48 @@ StringRef AttrOrTypeParameter::getCppType() const {
}

StringRef AttrOrTypeParameter::getCppAccessorType() const {
return getDefValue<llvm::StringInit>("cppAccessorType")
.value_or(getCppType());
return getDefValue<StringInit>("cppAccessorType").value_or(getCppType());
}

StringRef AttrOrTypeParameter::getCppStorageType() const {
return getDefValue<llvm::StringInit>("cppStorageType").value_or(getCppType());
return getDefValue<StringInit>("cppStorageType").value_or(getCppType());
}

StringRef AttrOrTypeParameter::getConvertFromStorage() const {
return getDefValue<llvm::StringInit>("convertFromStorage").value_or("$_self");
return getDefValue<StringInit>("convertFromStorage").value_or("$_self");
}

std::optional<StringRef> AttrOrTypeParameter::getParser() const {
return getDefValue<llvm::StringInit>("parser");
return getDefValue<StringInit>("parser");
}

std::optional<StringRef> AttrOrTypeParameter::getPrinter() const {
return getDefValue<llvm::StringInit>("printer");
return getDefValue<StringInit>("printer");
}

std::optional<StringRef> AttrOrTypeParameter::getSummary() const {
return getDefValue<llvm::StringInit>("summary");
return getDefValue<StringInit>("summary");
}

StringRef AttrOrTypeParameter::getSyntax() const {
if (auto *stringType = dyn_cast<llvm::StringInit>(getDef()))
if (auto *stringType = dyn_cast<StringInit>(getDef()))
return stringType->getValue();
return getDefValue<llvm::StringInit>("syntax").value_or(getCppType());
return getDefValue<StringInit>("syntax").value_or(getCppType());
}

bool AttrOrTypeParameter::isOptional() const {
return getDefaultValue().has_value();
}

std::optional<StringRef> AttrOrTypeParameter::getDefaultValue() const {
std::optional<StringRef> result =
getDefValue<llvm::StringInit>("defaultValue");
std::optional<StringRef> result = getDefValue<StringInit>("defaultValue");
return result && !result->empty() ? result : std::nullopt;
}

const llvm::Init *AttrOrTypeParameter::getDef() const {
return def->getArg(index);
}
const Init *AttrOrTypeParameter::getDef() const { return def->getArg(index); }

std::optional<Constraint> AttrOrTypeParameter::getConstraint() const {
if (auto *param = dyn_cast<llvm::DefInit>(getDef()))
if (const auto *param = dyn_cast<DefInit>(getDef()))
if (param->getDef()->isSubClassOf("Constraint"))
return Constraint(param->getDef());
return std::nullopt;
Expand All @@ -351,8 +352,8 @@ std::optional<Constraint> AttrOrTypeParameter::getConstraint() const {
//===----------------------------------------------------------------------===//

bool AttributeSelfTypeParameter::classof(const AttrOrTypeParameter *param) {
const llvm::Init *paramDef = param->getDef();
if (auto *paramDefInit = dyn_cast<llvm::DefInit>(paramDef))
const Init *paramDef = param->getDef();
if (const auto *paramDefInit = dyn_cast<DefInit>(paramDef))
return paramDefInit->getDef()->isSubClassOf("AttributeSelfTypeParameter");
return false;
}
25 changes: 12 additions & 13 deletions mlir/lib/TableGen/Attribute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ StringRef Attribute::getReturnType() const {
// Return the type constraint corresponding to the type of this attribute, or
// std::nullopt if this is not a TypedAttr.
std::optional<Type> Attribute::getValueType() const {
if (auto *defInit = dyn_cast<llvm::DefInit>(def->getValueInit("valueType")))
if (const auto *defInit = dyn_cast<DefInit>(def->getValueInit("valueType")))
return Type(defInit->getDef());
return std::nullopt;
}
Expand All @@ -92,8 +92,7 @@ StringRef Attribute::getConstBuilderTemplate() const {
}

Attribute Attribute::getBaseAttr() const {
if (const auto *defInit =
llvm::dyn_cast<llvm::DefInit>(def->getValueInit("baseAttr"))) {
if (const auto *defInit = dyn_cast<DefInit>(def->getValueInit("baseAttr"))) {
return Attribute(defInit).getBaseAttr();
}
return *this;
Expand Down Expand Up @@ -132,7 +131,7 @@ Dialect Attribute::getDialect() const {
return Dialect(nullptr);
}

const llvm::Record &Attribute::getDef() const { return *def; }
const Record &Attribute::getDef() const { return *def; }

ConstantAttr::ConstantAttr(const DefInit *init) : def(init->getDef()) {
assert(def->isSubClassOf("ConstantAttr") &&
Expand All @@ -147,12 +146,12 @@ StringRef ConstantAttr::getConstantValue() const {
return def->getValueAsString("value");
}

EnumAttrCase::EnumAttrCase(const llvm::Record *record) : Attribute(record) {
EnumAttrCase::EnumAttrCase(const Record *record) : Attribute(record) {
assert(isSubClassOf("EnumAttrCaseInfo") &&
"must be subclass of TableGen 'EnumAttrInfo' class");
}

EnumAttrCase::EnumAttrCase(const llvm::DefInit *init)
EnumAttrCase::EnumAttrCase(const DefInit *init)
: EnumAttrCase(init->getDef()) {}

StringRef EnumAttrCase::getSymbol() const {
Expand All @@ -163,16 +162,16 @@ StringRef EnumAttrCase::getStr() const { return def->getValueAsString("str"); }

int64_t EnumAttrCase::getValue() const { return def->getValueAsInt("value"); }

const llvm::Record &EnumAttrCase::getDef() const { return *def; }
const Record &EnumAttrCase::getDef() const { return *def; }

EnumAttr::EnumAttr(const llvm::Record *record) : Attribute(record) {
EnumAttr::EnumAttr(const Record *record) : Attribute(record) {
assert(isSubClassOf("EnumAttrInfo") &&
"must be subclass of TableGen 'EnumAttr' class");
}

EnumAttr::EnumAttr(const llvm::Record &record) : Attribute(&record) {}
EnumAttr::EnumAttr(const Record &record) : Attribute(&record) {}

EnumAttr::EnumAttr(const llvm::DefInit *init) : EnumAttr(init->getDef()) {}
EnumAttr::EnumAttr(const DefInit *init) : EnumAttr(init->getDef()) {}

bool EnumAttr::classof(const Attribute *attr) {
return attr->isSubClassOf("EnumAttrInfo");
Expand Down Expand Up @@ -218,8 +217,8 @@ std::vector<EnumAttrCase> EnumAttr::getAllCases() const {
std::vector<EnumAttrCase> cases;
cases.reserve(inits->size());

for (const llvm::Init *init : *inits) {
cases.emplace_back(cast<llvm::DefInit>(init));
for (const Init *init : *inits) {
cases.emplace_back(cast<DefInit>(init));
}

return cases;
Expand All @@ -229,7 +228,7 @@ bool EnumAttr::genSpecializedAttr() const {
return def->getValueAsBit("genSpecializedAttr");
}

const llvm::Record *EnumAttr::getBaseAttrClass() const {
const Record *EnumAttr::getBaseAttrClass() const {
return def->getValueAsDef("baseAttrClass");
}

Expand Down
24 changes: 14 additions & 10 deletions mlir/lib/TableGen/Builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,21 @@

using namespace mlir;
using namespace mlir::tblgen;
using llvm::DagInit;
using llvm::DefInit;
using llvm::Init;
using llvm::Record;
using llvm::StringInit;

//===----------------------------------------------------------------------===//
// Builder::Parameter
//===----------------------------------------------------------------------===//

/// Return a string containing the C++ type of this parameter.
StringRef Builder::Parameter::getCppType() const {
if (const auto *stringInit = dyn_cast<llvm::StringInit>(def))
if (const auto *stringInit = dyn_cast<StringInit>(def))
return stringInit->getValue();
const llvm::Record *record = cast<llvm::DefInit>(def)->getDef();
const Record *record = cast<DefInit>(def)->getDef();
// Inlining the first part of `Record::getValueAsString` to give better
// error messages.
const llvm::RecordVal *type = record->getValue("type");
Expand All @@ -35,9 +40,9 @@ StringRef Builder::Parameter::getCppType() const {
/// Return an optional string containing the default value to use for this
/// parameter.
std::optional<StringRef> Builder::Parameter::getDefaultValue() const {
if (isa<llvm::StringInit>(def))
if (isa<StringInit>(def))
return std::nullopt;
const llvm::Record *record = cast<llvm::DefInit>(def)->getDef();
const Record *record = cast<DefInit>(def)->getDef();
std::optional<StringRef> value =
record->getValueAsOptionalString("defaultValue");
return value && !value->empty() ? value : std::nullopt;
Expand All @@ -47,18 +52,17 @@ std::optional<StringRef> Builder::Parameter::getDefaultValue() const {
// Builder
//===----------------------------------------------------------------------===//

Builder::Builder(const llvm::Record *record, ArrayRef<SMLoc> loc)
: def(record) {
Builder::Builder(const Record *record, ArrayRef<SMLoc> loc) : def(record) {
// Initialize the parameters of the builder.
const llvm::DagInit *dag = def->getValueAsDag("dagParams");
auto *defInit = dyn_cast<llvm::DefInit>(dag->getOperator());
const DagInit *dag = def->getValueAsDag("dagParams");
auto *defInit = dyn_cast<DefInit>(dag->getOperator());
if (!defInit || defInit->getDef()->getName() != "ins")
PrintFatalError(def->getLoc(), "expected 'ins' in builders");

bool seenDefaultValue = false;
for (unsigned i = 0, e = dag->getNumArgs(); i < e; ++i) {
const llvm::StringInit *paramName = dag->getArgName(i);
const llvm::Init *paramValue = dag->getArg(i);
const StringInit *paramName = dag->getArgName(i);
const Init *paramValue = dag->getArg(i);
Parameter param(paramName ? paramName->getValue()
: std::optional<StringRef>(),
paramValue);
Expand Down
Loading

0 comments on commit 659192b

Please sign in to comment.