diff --git a/mlir/lib/TableGen/AttrOrTypeDef.cpp b/mlir/lib/TableGen/AttrOrTypeDef.cpp index e72ca155bcf765..9e8f789d71b5ea 100644 --- a/mlir/lib/TableGen/AttrOrTypeDef.cpp +++ b/mlir/lib/TableGen/AttrOrTypeDef.cpp @@ -17,6 +17,12 @@ using namespace mlir; using namespace mlir::tblgen; +using llvm::DefInit; +using llvm::Init; +using llvm::ListInit; +using llvm::Record; +using llvm::RecordVal; +using llvm::StringInit; //===----------------------------------------------------------------------===// // AttrOrTypeBuilder @@ -35,14 +41,13 @@ bool AttrOrTypeBuilder::hasInferredContextParameter() const { // AttrOrTypeDef //===----------------------------------------------------------------------===// -AttrOrTypeDef::AttrOrTypeDef(const llvm::Record *def) : def(def) { +AttrOrTypeDef::AttrOrTypeDef(const Record *def) : def(def) { // Populate the builders. - auto *builderList = - dyn_cast_or_null(def->getValueInit("builders")); + const auto *builderList = + dyn_cast_or_null(def->getValueInit("builders")); if (builderList && !builderList->empty()) { - for (const llvm::Init *init : builderList->getValues()) { - AttrOrTypeBuilder builder(cast(init)->getDef(), - def->getLoc()); + for (const Init *init : builderList->getValues()) { + AttrOrTypeBuilder builder(cast(init)->getDef(), def->getLoc()); // Ensure that all parameters have names. for (const AttrOrTypeBuilder::Parameter ¶m : @@ -56,16 +61,16 @@ AttrOrTypeDef::AttrOrTypeDef(const llvm::Record *def) : def(def) { // Populate the traits. if (auto *traitList = def->getValueAsListInit("traits")) { - SmallPtrSet traitSet; + SmallPtrSet traitSet; traits.reserve(traitSet.size()); - llvm::unique_function processTraitList = - [&](const llvm::ListInit *traitList) { + llvm::unique_function processTraitList = + [&](const ListInit *traitList) { for (auto *traitInit : *traitList) { if (!traitSet.insert(traitInit).second) continue; // If this is an interface, add any bases to the trait list. - auto *traitDef = cast(traitInit)->getDef(); + auto *traitDef = cast(traitInit)->getDef(); if (traitDef->isSubClassOf("Interface")) { if (auto *bases = traitDef->getValueAsListInit("baseInterfaces")) processTraitList(bases); @@ -111,7 +116,7 @@ AttrOrTypeDef::AttrOrTypeDef(const llvm::Record *def) : def(def) { } Dialect AttrOrTypeDef::getDialect() const { - auto *dialect = dyn_cast(def->getValue("dialect")->getValue()); + const auto *dialect = dyn_cast(def->getValue("dialect")->getValue()); return Dialect(dialect ? dialect->getDef() : nullptr); } @@ -126,8 +131,8 @@ StringRef AttrOrTypeDef::getCppBaseClassName() const { } bool AttrOrTypeDef::hasDescription() const { - const llvm::RecordVal *desc = def->getValue("description"); - return desc && isa(desc->getValue()); + const RecordVal *desc = def->getValue("description"); + return desc && isa(desc->getValue()); } StringRef AttrOrTypeDef::getDescription() const { @@ -135,8 +140,8 @@ StringRef AttrOrTypeDef::getDescription() const { } bool AttrOrTypeDef::hasSummary() const { - const llvm::RecordVal *summary = def->getValue("summary"); - return summary && isa(summary->getValue()); + const RecordVal *summary = def->getValue("summary"); + return summary && isa(summary->getValue()); } StringRef AttrOrTypeDef::getSummary() const { @@ -249,9 +254,9 @@ StringRef TypeDef::getTypeName() const { template auto AttrOrTypeParameter::getDefValue(StringRef name) const { std::optional().getValue())> result; - if (auto *param = dyn_cast(getDef())) - if (auto *init = param->getDef()->getValue(name)) - if (auto *value = dyn_cast_or_null(init->getValue())) + if (const auto *param = dyn_cast(getDef())) + if (const auto *init = param->getDef()->getValue(name)) + if (const auto *value = dyn_cast_or_null(init->getValue())) result = value->getValue(); return result; } @@ -270,20 +275,20 @@ std::string AttrOrTypeParameter::getAccessorName() const { } std::optional AttrOrTypeParameter::getAllocator() const { - return getDefValue("allocator"); + return getDefValue("allocator"); } StringRef AttrOrTypeParameter::getComparator() const { - return getDefValue("comparator").value_or("$_lhs == $_rhs"); + return getDefValue("comparator").value_or("$_lhs == $_rhs"); } StringRef AttrOrTypeParameter::getCppType() const { - if (auto *stringType = dyn_cast(getDef())) + if (auto *stringType = dyn_cast(getDef())) return stringType->getValue(); - auto cppType = getDefValue("cppType"); + auto cppType = getDefValue("cppType"); if (cppType) return *cppType; - if (auto *init = dyn_cast(getDef())) + if (const auto *init = dyn_cast(getDef())) llvm::PrintFatalError( init->getDef()->getLoc(), Twine("Missing `cppType` field in Attribute/Type parameter: ") + @@ -295,34 +300,33 @@ StringRef AttrOrTypeParameter::getCppType() const { } StringRef AttrOrTypeParameter::getCppAccessorType() const { - return getDefValue("cppAccessorType") - .value_or(getCppType()); + return getDefValue("cppAccessorType").value_or(getCppType()); } StringRef AttrOrTypeParameter::getCppStorageType() const { - return getDefValue("cppStorageType").value_or(getCppType()); + return getDefValue("cppStorageType").value_or(getCppType()); } StringRef AttrOrTypeParameter::getConvertFromStorage() const { - return getDefValue("convertFromStorage").value_or("$_self"); + return getDefValue("convertFromStorage").value_or("$_self"); } std::optional AttrOrTypeParameter::getParser() const { - return getDefValue("parser"); + return getDefValue("parser"); } std::optional AttrOrTypeParameter::getPrinter() const { - return getDefValue("printer"); + return getDefValue("printer"); } std::optional AttrOrTypeParameter::getSummary() const { - return getDefValue("summary"); + return getDefValue("summary"); } StringRef AttrOrTypeParameter::getSyntax() const { - if (auto *stringType = dyn_cast(getDef())) + if (auto *stringType = dyn_cast(getDef())) return stringType->getValue(); - return getDefValue("syntax").value_or(getCppType()); + return getDefValue("syntax").value_or(getCppType()); } bool AttrOrTypeParameter::isOptional() const { @@ -330,17 +334,14 @@ bool AttrOrTypeParameter::isOptional() const { } std::optional AttrOrTypeParameter::getDefaultValue() const { - std::optional result = - getDefValue("defaultValue"); + std::optional result = getDefValue("defaultValue"); return result && !result->empty() ? result : std::nullopt; } -const llvm::Init *AttrOrTypeParameter::getDef() const { - return def->getArg(index); -} +const Init *AttrOrTypeParameter::getDef() const { return def->getArg(index); } std::optional AttrOrTypeParameter::getConstraint() const { - if (auto *param = dyn_cast(getDef())) + if (const auto *param = dyn_cast(getDef())) if (param->getDef()->isSubClassOf("Constraint")) return Constraint(param->getDef()); return std::nullopt; @@ -351,8 +352,8 @@ std::optional AttrOrTypeParameter::getConstraint() const { //===----------------------------------------------------------------------===// bool AttributeSelfTypeParameter::classof(const AttrOrTypeParameter *param) { - const llvm::Init *paramDef = param->getDef(); - if (auto *paramDefInit = dyn_cast(paramDef)) + const Init *paramDef = param->getDef(); + if (const auto *paramDefInit = dyn_cast(paramDef)) return paramDefInit->getDef()->isSubClassOf("AttributeSelfTypeParameter"); return false; } diff --git a/mlir/lib/TableGen/Attribute.cpp b/mlir/lib/TableGen/Attribute.cpp index 887553bca66102..f9fc58a40f334c 100644 --- a/mlir/lib/TableGen/Attribute.cpp +++ b/mlir/lib/TableGen/Attribute.cpp @@ -71,7 +71,7 @@ StringRef Attribute::getReturnType() const { // Return the type constraint corresponding to the type of this attribute, or // std::nullopt if this is not a TypedAttr. std::optional Attribute::getValueType() const { - if (auto *defInit = dyn_cast(def->getValueInit("valueType"))) + if (const auto *defInit = dyn_cast(def->getValueInit("valueType"))) return Type(defInit->getDef()); return std::nullopt; } @@ -92,8 +92,7 @@ StringRef Attribute::getConstBuilderTemplate() const { } Attribute Attribute::getBaseAttr() const { - if (const auto *defInit = - llvm::dyn_cast(def->getValueInit("baseAttr"))) { + if (const auto *defInit = dyn_cast(def->getValueInit("baseAttr"))) { return Attribute(defInit).getBaseAttr(); } return *this; @@ -132,7 +131,7 @@ Dialect Attribute::getDialect() const { return Dialect(nullptr); } -const llvm::Record &Attribute::getDef() const { return *def; } +const Record &Attribute::getDef() const { return *def; } ConstantAttr::ConstantAttr(const DefInit *init) : def(init->getDef()) { assert(def->isSubClassOf("ConstantAttr") && @@ -147,12 +146,12 @@ StringRef ConstantAttr::getConstantValue() const { return def->getValueAsString("value"); } -EnumAttrCase::EnumAttrCase(const llvm::Record *record) : Attribute(record) { +EnumAttrCase::EnumAttrCase(const Record *record) : Attribute(record) { assert(isSubClassOf("EnumAttrCaseInfo") && "must be subclass of TableGen 'EnumAttrInfo' class"); } -EnumAttrCase::EnumAttrCase(const llvm::DefInit *init) +EnumAttrCase::EnumAttrCase(const DefInit *init) : EnumAttrCase(init->getDef()) {} StringRef EnumAttrCase::getSymbol() const { @@ -163,16 +162,16 @@ StringRef EnumAttrCase::getStr() const { return def->getValueAsString("str"); } int64_t EnumAttrCase::getValue() const { return def->getValueAsInt("value"); } -const llvm::Record &EnumAttrCase::getDef() const { return *def; } +const Record &EnumAttrCase::getDef() const { return *def; } -EnumAttr::EnumAttr(const llvm::Record *record) : Attribute(record) { +EnumAttr::EnumAttr(const Record *record) : Attribute(record) { assert(isSubClassOf("EnumAttrInfo") && "must be subclass of TableGen 'EnumAttr' class"); } -EnumAttr::EnumAttr(const llvm::Record &record) : Attribute(&record) {} +EnumAttr::EnumAttr(const Record &record) : Attribute(&record) {} -EnumAttr::EnumAttr(const llvm::DefInit *init) : EnumAttr(init->getDef()) {} +EnumAttr::EnumAttr(const DefInit *init) : EnumAttr(init->getDef()) {} bool EnumAttr::classof(const Attribute *attr) { return attr->isSubClassOf("EnumAttrInfo"); @@ -218,8 +217,8 @@ std::vector EnumAttr::getAllCases() const { std::vector cases; cases.reserve(inits->size()); - for (const llvm::Init *init : *inits) { - cases.emplace_back(cast(init)); + for (const Init *init : *inits) { + cases.emplace_back(cast(init)); } return cases; @@ -229,7 +228,7 @@ bool EnumAttr::genSpecializedAttr() const { return def->getValueAsBit("genSpecializedAttr"); } -const llvm::Record *EnumAttr::getBaseAttrClass() const { +const Record *EnumAttr::getBaseAttrClass() const { return def->getValueAsDef("baseAttrClass"); } diff --git a/mlir/lib/TableGen/Builder.cpp b/mlir/lib/TableGen/Builder.cpp index 044765c726019d..a94e1cca5fc59e 100644 --- a/mlir/lib/TableGen/Builder.cpp +++ b/mlir/lib/TableGen/Builder.cpp @@ -12,6 +12,11 @@ using namespace mlir; using namespace mlir::tblgen; +using llvm::DagInit; +using llvm::DefInit; +using llvm::Init; +using llvm::Record; +using llvm::StringInit; //===----------------------------------------------------------------------===// // Builder::Parameter @@ -19,9 +24,9 @@ using namespace mlir::tblgen; /// Return a string containing the C++ type of this parameter. StringRef Builder::Parameter::getCppType() const { - if (const auto *stringInit = dyn_cast(def)) + if (const auto *stringInit = dyn_cast(def)) return stringInit->getValue(); - const llvm::Record *record = cast(def)->getDef(); + const Record *record = cast(def)->getDef(); // Inlining the first part of `Record::getValueAsString` to give better // error messages. const llvm::RecordVal *type = record->getValue("type"); @@ -35,9 +40,9 @@ StringRef Builder::Parameter::getCppType() const { /// Return an optional string containing the default value to use for this /// parameter. std::optional Builder::Parameter::getDefaultValue() const { - if (isa(def)) + if (isa(def)) return std::nullopt; - const llvm::Record *record = cast(def)->getDef(); + const Record *record = cast(def)->getDef(); std::optional value = record->getValueAsOptionalString("defaultValue"); return value && !value->empty() ? value : std::nullopt; @@ -47,18 +52,17 @@ std::optional Builder::Parameter::getDefaultValue() const { // Builder //===----------------------------------------------------------------------===// -Builder::Builder(const llvm::Record *record, ArrayRef loc) - : def(record) { +Builder::Builder(const Record *record, ArrayRef loc) : def(record) { // Initialize the parameters of the builder. - const llvm::DagInit *dag = def->getValueAsDag("dagParams"); - auto *defInit = dyn_cast(dag->getOperator()); + const DagInit *dag = def->getValueAsDag("dagParams"); + auto *defInit = dyn_cast(dag->getOperator()); if (!defInit || defInit->getDef()->getName() != "ins") PrintFatalError(def->getLoc(), "expected 'ins' in builders"); bool seenDefaultValue = false; for (unsigned i = 0, e = dag->getNumArgs(); i < e; ++i) { - const llvm::StringInit *paramName = dag->getArgName(i); - const llvm::Init *paramValue = dag->getArg(i); + const StringInit *paramName = dag->getArgName(i); + const Init *paramValue = dag->getArg(i); Parameter param(paramName ? paramName->getValue() : std::optional(), paramValue); diff --git a/mlir/lib/TableGen/CodeGenHelpers.cpp b/mlir/lib/TableGen/CodeGenHelpers.cpp index 2f13887aa0bbeb..747af1ce5a4d3d 100644 --- a/mlir/lib/TableGen/CodeGenHelpers.cpp +++ b/mlir/lib/TableGen/CodeGenHelpers.cpp @@ -24,32 +24,32 @@ using namespace mlir::tblgen; /// Generate a unique label based on the current file name to prevent name /// collisions if multiple generated files are included at once. -static std::string getUniqueOutputLabel(const llvm::RecordKeeper &records, +static std::string getUniqueOutputLabel(const RecordKeeper &records, StringRef tag) { // Use the input file name when generating a unique name. std::string inputFilename = records.getInputFilename(); // Drop all but the base filename. - StringRef nameRef = llvm::sys::path::filename(inputFilename); + StringRef nameRef = sys::path::filename(inputFilename); nameRef.consume_back(".td"); // Sanitize any invalid characters. std::string uniqueName(tag); for (char c : nameRef) { - if (llvm::isAlnum(c) || c == '_') + if (isAlnum(c) || c == '_') uniqueName.push_back(c); else - uniqueName.append(llvm::utohexstr((unsigned char)c)); + uniqueName.append(utohexstr((unsigned char)c)); } return uniqueName; } StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter( - raw_ostream &os, const llvm::RecordKeeper &records, StringRef tag) + raw_ostream &os, const RecordKeeper &records, StringRef tag) : os(os), uniqueOutputLabel(getUniqueOutputLabel(records, tag)) {} void StaticVerifierFunctionEmitter::emitOpConstraints( - ArrayRef opDefs) { + ArrayRef opDefs) { NamespaceEmitter namespaceEmitter(os, Operator(*opDefs[0]).getCppNamespace()); emitTypeConstraints(); emitAttrConstraints(); @@ -58,7 +58,7 @@ void StaticVerifierFunctionEmitter::emitOpConstraints( } void StaticVerifierFunctionEmitter::emitPatternConstraints( - const llvm::ArrayRef constraints) { + const ArrayRef constraints) { collectPatternConstraints(constraints); emitPatternConstraints(); } @@ -298,7 +298,7 @@ void StaticVerifierFunctionEmitter::collectOpConstraints( } void StaticVerifierFunctionEmitter::collectPatternConstraints( - const llvm::ArrayRef constraints) { + const ArrayRef constraints) { for (auto &leaf : constraints) { assert(leaf.isOperandMatcher() || leaf.isAttrMatcher()); collectConstraint( @@ -313,7 +313,7 @@ void StaticVerifierFunctionEmitter::collectPatternConstraints( std::string mlir::tblgen::escapeString(StringRef value) { std::string ret; - llvm::raw_string_ostream os(ret); + raw_string_ostream os(ret); os.write_escaped(value); return ret; } diff --git a/mlir/lib/TableGen/Interfaces.cpp b/mlir/lib/TableGen/Interfaces.cpp index 4a6709a43d0a8f..dc9a74c4e8a90a 100644 --- a/mlir/lib/TableGen/Interfaces.cpp +++ b/mlir/lib/TableGen/Interfaces.cpp @@ -16,17 +16,22 @@ using namespace mlir; using namespace mlir::tblgen; +using llvm::DagInit; +using llvm::DefInit; +using llvm::Init; +using llvm::ListInit; +using llvm::Record; +using llvm::StringInit; //===----------------------------------------------------------------------===// // InterfaceMethod //===----------------------------------------------------------------------===// -InterfaceMethod::InterfaceMethod(const llvm::Record *def) : def(def) { - const llvm::DagInit *args = def->getValueAsDag("arguments"); +InterfaceMethod::InterfaceMethod(const Record *def) : def(def) { + const DagInit *args = def->getValueAsDag("arguments"); for (unsigned i = 0, e = args->getNumArgs(); i != e; ++i) { - arguments.push_back( - {llvm::cast(args->getArg(i))->getValue(), - args->getArgNameStr(i)}); + arguments.push_back({cast(args->getArg(i))->getValue(), + args->getArgNameStr(i)}); } } @@ -72,18 +77,17 @@ bool InterfaceMethod::arg_empty() const { return arguments.empty(); } // Interface //===----------------------------------------------------------------------===// -Interface::Interface(const llvm::Record *def) : def(def) { +Interface::Interface(const Record *def) : def(def) { assert(def->isSubClassOf("Interface") && "must be subclass of TableGen 'Interface' class"); // Initialize the interface methods. - auto *listInit = dyn_cast(def->getValueInit("methods")); - for (const llvm::Init *init : listInit->getValues()) - methods.emplace_back(cast(init)->getDef()); + auto *listInit = dyn_cast(def->getValueInit("methods")); + for (const Init *init : listInit->getValues()) + methods.emplace_back(cast(init)->getDef()); // Initialize the interface base classes. - auto *basesInit = - dyn_cast(def->getValueInit("baseInterfaces")); + auto *basesInit = dyn_cast(def->getValueInit("baseInterfaces")); // Chained inheritance will produce duplicates in the base interface set. StringSet<> basesAdded; llvm::unique_function addBaseInterfaceFn = @@ -98,8 +102,8 @@ Interface::Interface(const llvm::Record *def) : def(def) { baseInterfaces.push_back(std::make_unique(baseInterface)); basesAdded.insert(baseInterface.getName()); }; - for (const llvm::Init *init : basesInit->getValues()) - addBaseInterfaceFn(Interface(cast(init)->getDef())); + for (const Init *init : basesInit->getValues()) + addBaseInterfaceFn(Interface(cast(init)->getDef())); } // Return the name of this interface. diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp index 86670e9f87127c..904cc6637d53ff 100644 --- a/mlir/lib/TableGen/Operator.cpp +++ b/mlir/lib/TableGen/Operator.cpp @@ -35,9 +35,12 @@ using namespace mlir::tblgen; using llvm::DagInit; using llvm::DefInit; +using llvm::Init; +using llvm::ListInit; using llvm::Record; +using llvm::StringInit; -Operator::Operator(const llvm::Record &def) +Operator::Operator(const Record &def) : dialect(def.getValueAsDef("opDialect")), def(def) { // The first `_` in the op's TableGen def name is treated as separating the // dialect prefix and the op class name. The dialect prefix will be ignored if @@ -179,7 +182,7 @@ StringRef Operator::getExtraClassDefinition() const { return def.getValueAsString(attr); } -const llvm::Record &Operator::getDef() const { return def; } +const Record &Operator::getDef() const { return def; } bool Operator::skipDefaultBuilders() const { return def.getValueAsBit("skipDefaultBuilders"); @@ -429,7 +432,7 @@ void Operator::populateTypeInferenceInfo( // Use `AllTypesMatch` and `TypesMatchWith` operation traits to build the // result type inference graph. for (const Trait &trait : traits) { - const llvm::Record &def = trait.getDef(); + const Record &def = trait.getDef(); // If the infer type op interface was manually added, then treat it as // intention that the op needs special handling. @@ -614,9 +617,8 @@ void Operator::populateOpStructure() { def.getLoc(), "unsupported attribute modelling, only single class expected"); } - attributes.push_back( - {cast(val.getNameInit())->getValue(), - Attribute(cast(val.getValue()))}); + attributes.push_back({cast(val.getNameInit())->getValue(), + Attribute(cast(val.getValue()))}); } } @@ -701,7 +703,7 @@ void Operator::populateOpStructure() { // tablegen is easy, making them unique less so, so dedupe here. if (auto *traitList = def.getValueAsListInit("traits")) { // This is uniquing based on pointers of the trait. - SmallPtrSet traitSet; + SmallPtrSet traitSet; traits.reserve(traitSet.size()); // The declaration order of traits imply the verification order of traits. @@ -721,8 +723,8 @@ void Operator::populateOpStructure() { " to precede it in traits list"); }; - std::function insert; - insert = [&](const llvm::ListInit *traitList) { + std::function insert; + insert = [&](const ListInit *traitList) { for (auto *traitInit : *traitList) { auto *def = cast(traitInit)->getDef(); if (def->isSubClassOf("TraitList")) { @@ -777,11 +779,10 @@ void Operator::populateOpStructure() { } // Populate the builders. - auto *builderList = - dyn_cast_or_null(def.getValueInit("builders")); + auto *builderList = dyn_cast_or_null(def.getValueInit("builders")); if (builderList && !builderList->empty()) { - for (const llvm::Init *init : builderList->getValues()) - builders.emplace_back(cast(init)->getDef(), def.getLoc()); + for (const Init *init : builderList->getValues()) + builders.emplace_back(cast(init)->getDef(), def.getLoc()); } else if (skipDefaultBuilders()) { PrintFatalError( def.getLoc(), @@ -814,13 +815,12 @@ StringRef Operator::getSummary() const { bool Operator::hasAssemblyFormat() const { auto *valueInit = def.getValueInit("assemblyFormat"); - return isa(valueInit); + return isa(valueInit); } StringRef Operator::getAssemblyFormat() const { - return TypeSwitch( - def.getValueInit("assemblyFormat")) - .Case([&](auto *init) { return init->getValue(); }); + return TypeSwitch(def.getValueInit("assemblyFormat")) + .Case([&](auto *init) { return init->getValue(); }); } void Operator::print(llvm::raw_ostream &os) const { @@ -833,9 +833,9 @@ void Operator::print(llvm::raw_ostream &os) const { } } -auto Operator::VariableDecoratorIterator::unwrap(const llvm::Init *init) +auto Operator::VariableDecoratorIterator::unwrap(const Init *init) -> VariableDecorator { - return VariableDecorator(cast(init)->getDef()); + return VariableDecorator(cast(init)->getDef()); } auto Operator::getArgToOperandOrAttribute(int index) const diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp index bee20354387fd6..ffa0c067b02858 100644 --- a/mlir/lib/TableGen/Pattern.cpp +++ b/mlir/lib/TableGen/Pattern.cpp @@ -26,7 +26,12 @@ using namespace mlir; using namespace tblgen; +using llvm::DagInit; +using llvm::dbgs; +using llvm::DefInit; using llvm::formatv; +using llvm::IntInit; +using llvm::Record; //===----------------------------------------------------------------------===// // DagLeaf @@ -61,31 +66,31 @@ bool DagLeaf::isStringAttr() const { return isa(def); } Constraint DagLeaf::getAsConstraint() const { assert((isOperandMatcher() || isAttrMatcher()) && "the DAG leaf must be operand or attribute"); - return Constraint(cast(def)->getDef()); + return Constraint(cast(def)->getDef()); } ConstantAttr DagLeaf::getAsConstantAttr() const { assert(isConstantAttr() && "the DAG leaf must be constant attribute"); - return ConstantAttr(cast(def)); + return ConstantAttr(cast(def)); } EnumAttrCase DagLeaf::getAsEnumAttrCase() const { assert(isEnumAttrCase() && "the DAG leaf must be an enum attribute case"); - return EnumAttrCase(cast(def)); + return EnumAttrCase(cast(def)); } std::string DagLeaf::getConditionTemplate() const { return getAsConstraint().getConditionTemplate(); } -llvm::StringRef DagLeaf::getNativeCodeTemplate() const { +StringRef DagLeaf::getNativeCodeTemplate() const { assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall"); - return cast(def)->getDef()->getValueAsString("expression"); + return cast(def)->getDef()->getValueAsString("expression"); } int DagLeaf::getNumReturnsOfNativeCode() const { assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall"); - return cast(def)->getDef()->getValueAsInt("numReturns"); + return cast(def)->getDef()->getValueAsInt("numReturns"); } std::string DagLeaf::getStringAttr() const { @@ -93,7 +98,7 @@ std::string DagLeaf::getStringAttr() const { return def->getAsUnquotedString(); } bool DagLeaf::isSubClassOf(StringRef superclass) const { - if (auto *defInit = dyn_cast_or_null(def)) + if (auto *defInit = dyn_cast_or_null(def)) return defInit->getDef()->isSubClassOf(superclass); return false; } @@ -108,7 +113,7 @@ void DagLeaf::print(raw_ostream &os) const { //===----------------------------------------------------------------------===// bool DagNode::isNativeCodeCall() const { - if (auto *defInit = dyn_cast_or_null(node->getOperator())) + if (auto *defInit = dyn_cast_or_null(node->getOperator())) return defInit->getDef()->isSubClassOf("NativeCodeCall"); return false; } @@ -119,25 +124,24 @@ bool DagNode::isOperation() const { !isVariadic(); } -llvm::StringRef DagNode::getNativeCodeTemplate() const { +StringRef DagNode::getNativeCodeTemplate() const { assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall"); - return cast(node->getOperator()) + return cast(node->getOperator()) ->getDef() ->getValueAsString("expression"); } int DagNode::getNumReturnsOfNativeCode() const { assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall"); - return cast(node->getOperator()) + return cast(node->getOperator()) ->getDef() ->getValueAsInt("numReturns"); } -llvm::StringRef DagNode::getSymbol() const { return node->getNameStr(); } +StringRef DagNode::getSymbol() const { return node->getNameStr(); } Operator &DagNode::getDialectOp(RecordOperatorMap *mapper) const { - const llvm::Record *opDef = - cast(node->getOperator())->getDef(); + const Record *opDef = cast(node->getOperator())->getDef(); auto [it, inserted] = mapper->try_emplace(opDef); if (inserted) it->second = std::make_unique(opDef); @@ -158,11 +162,11 @@ int DagNode::getNumOps() const { int DagNode::getNumArgs() const { return node->getNumArgs(); } bool DagNode::isNestedDagArg(unsigned index) const { - return isa(node->getArg(index)); + return isa(node->getArg(index)); } DagNode DagNode::getArgAsNestedDag(unsigned index) const { - return DagNode(dyn_cast_or_null(node->getArg(index))); + return DagNode(dyn_cast_or_null(node->getArg(index))); } DagLeaf DagNode::getArgAsLeaf(unsigned index) const { @@ -175,27 +179,27 @@ StringRef DagNode::getArgName(unsigned index) const { } bool DagNode::isReplaceWithValue() const { - auto *dagOpDef = cast(node->getOperator())->getDef(); + auto *dagOpDef = cast(node->getOperator())->getDef(); return dagOpDef->getName() == "replaceWithValue"; } bool DagNode::isLocationDirective() const { - auto *dagOpDef = cast(node->getOperator())->getDef(); + auto *dagOpDef = cast(node->getOperator())->getDef(); return dagOpDef->getName() == "location"; } bool DagNode::isReturnTypeDirective() const { - auto *dagOpDef = cast(node->getOperator())->getDef(); + auto *dagOpDef = cast(node->getOperator())->getDef(); return dagOpDef->getName() == "returnType"; } bool DagNode::isEither() const { - auto *dagOpDef = cast(node->getOperator())->getDef(); + auto *dagOpDef = cast(node->getOperator())->getDef(); return dagOpDef->getName() == "either"; } bool DagNode::isVariadic() const { - auto *dagOpDef = cast(node->getOperator())->getDef(); + auto *dagOpDef = cast(node->getOperator())->getDef(); return dagOpDef->getName() == "variadic"; } @@ -246,7 +250,7 @@ std::string SymbolInfoMap::SymbolInfo::getVarName(StringRef name) const { } std::string SymbolInfoMap::SymbolInfo::getVarTypeStr(StringRef name) const { - LLVM_DEBUG(llvm::dbgs() << "getVarTypeStr for '" << name << "': "); + LLVM_DEBUG(dbgs() << "getVarTypeStr for '" << name << "': "); switch (kind) { case Kind::Attr: { if (op) @@ -277,26 +281,26 @@ std::string SymbolInfoMap::SymbolInfo::getVarTypeStr(StringRef name) const { } std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const { - LLVM_DEBUG(llvm::dbgs() << "getVarDecl for '" << name << "': "); + LLVM_DEBUG(dbgs() << "getVarDecl for '" << name << "': "); std::string varInit = kind == Kind::Operand ? "(op0->getOperands())" : ""; return std::string( formatv("{0} {1}{2};\n", getVarTypeStr(name), getVarName(name), varInit)); } std::string SymbolInfoMap::SymbolInfo::getArgDecl(StringRef name) const { - LLVM_DEBUG(llvm::dbgs() << "getArgDecl for '" << name << "': "); + LLVM_DEBUG(dbgs() << "getArgDecl for '" << name << "': "); return std::string( formatv("{0} &{1}", getVarTypeStr(name), getVarName(name))); } std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse( StringRef name, int index, const char *fmt, const char *separator) const { - LLVM_DEBUG(llvm::dbgs() << "getValueAndRangeUse for '" << name << "': "); + LLVM_DEBUG(dbgs() << "getValueAndRangeUse for '" << name << "': "); switch (kind) { case Kind::Attr: { assert(index < 0); auto repl = formatv(fmt, name); - LLVM_DEBUG(llvm::dbgs() << repl << " (Attr)\n"); + LLVM_DEBUG(dbgs() << repl << " (Attr)\n"); return std::string(repl); } case Kind::Operand: { @@ -307,11 +311,11 @@ std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse( // the value itself. if (operand->isVariableLength() && !getVariadicSubIndex().has_value()) { auto repl = formatv(fmt, name); - LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicOperand)\n"); + LLVM_DEBUG(dbgs() << repl << " (VariadicOperand)\n"); return std::string(repl); } auto repl = formatv(fmt, formatv("(*{0}.begin())", name)); - LLVM_DEBUG(llvm::dbgs() << repl << " (SingleOperand)\n"); + LLVM_DEBUG(dbgs() << repl << " (SingleOperand)\n"); return std::string(repl); } case Kind::Result: { @@ -323,14 +327,14 @@ std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse( if (!op->getResult(index).isVariadic()) v = std::string(formatv("(*{0}.begin())", v)); auto repl = formatv(fmt, v); - LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n"); + LLVM_DEBUG(dbgs() << repl << " (SingleResult)\n"); return std::string(repl); } // If this op has no result at all but still we bind a symbol to it, it // means we want to capture the op itself. if (op->getNumResults() == 0) { - LLVM_DEBUG(llvm::dbgs() << name << " (Op)\n"); + LLVM_DEBUG(dbgs() << name << " (Op)\n"); return formatv(fmt, name); } @@ -347,14 +351,14 @@ std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse( values.push_back(std::string(formatv(fmt, v))); } auto repl = llvm::join(values, separator); - LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n"); + LLVM_DEBUG(dbgs() << repl << " (VariadicResult)\n"); return repl; } case Kind::Value: { assert(index < 0); assert(op == nullptr); auto repl = formatv(fmt, name); - LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n"); + LLVM_DEBUG(dbgs() << repl << " (Value)\n"); return std::string(repl); } case Kind::MultipleValues: { @@ -363,13 +367,13 @@ std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse( if (index >= 0) { std::string repl = formatv(fmt, std::string(formatv("{0}[{1}]", name, index))); - LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n"); + LLVM_DEBUG(dbgs() << repl << " (MultipleValues)\n"); return repl; } // If it doesn't specify certain element, unpack them all. auto repl = formatv(fmt, std::string(formatv("{0}.begin(), {0}.end()", name))); - LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n"); + LLVM_DEBUG(dbgs() << repl << " (MultipleValues)\n"); return std::string(repl); } } @@ -378,19 +382,19 @@ std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse( std::string SymbolInfoMap::SymbolInfo::getAllRangeUse( StringRef name, int index, const char *fmt, const char *separator) const { - LLVM_DEBUG(llvm::dbgs() << "getAllRangeUse for '" << name << "': "); + LLVM_DEBUG(dbgs() << "getAllRangeUse for '" << name << "': "); switch (kind) { case Kind::Attr: case Kind::Operand: { assert(index < 0 && "only allowed for symbol bound to result"); auto repl = formatv(fmt, name); - LLVM_DEBUG(llvm::dbgs() << repl << " (Operand/Attr)\n"); + LLVM_DEBUG(dbgs() << repl << " (Operand/Attr)\n"); return std::string(repl); } case Kind::Result: { if (index >= 0) { auto repl = formatv(fmt, formatv("{0}.getODSResults({1})", name, index)); - LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n"); + LLVM_DEBUG(dbgs() << repl << " (SingleResult)\n"); return std::string(repl); } @@ -404,14 +408,14 @@ std::string SymbolInfoMap::SymbolInfo::getAllRangeUse( formatv(fmt, formatv("{0}.getODSResults({1})", name, i)))); } auto repl = llvm::join(values, separator); - LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n"); + LLVM_DEBUG(dbgs() << repl << " (VariadicResult)\n"); return repl; } case Kind::Value: { assert(index < 0 && "only allowed for symbol bound to result"); assert(op == nullptr); auto repl = formatv(fmt, formatv("{{{0}}", name)); - LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n"); + LLVM_DEBUG(dbgs() << repl << " (Value)\n"); return std::string(repl); } case Kind::MultipleValues: { @@ -420,12 +424,12 @@ std::string SymbolInfoMap::SymbolInfo::getAllRangeUse( if (index >= 0) { std::string repl = formatv(fmt, std::string(formatv("{0}[{1}]", name, index))); - LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n"); + LLVM_DEBUG(dbgs() << repl << " (MultipleValues)\n"); return repl; } auto repl = formatv(fmt, std::string(formatv("{0}.begin(), {0}.end()", name))); - LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n"); + LLVM_DEBUG(dbgs() << repl << " (MultipleValues)\n"); return std::string(repl); } } @@ -614,7 +618,7 @@ void SymbolInfoMap::assignUniqueAlternativeNames() { // Pattern //==----------------------------------------------------------------------===// -Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper) +Pattern::Pattern(const Record *def, RecordOperatorMap *mapper) : def(*def), recordOpMap(mapper) {} DagNode Pattern::getSourcePattern() const { @@ -628,26 +632,26 @@ int Pattern::getNumResultPatterns() const { DagNode Pattern::getResultPattern(unsigned index) const { auto *results = def.getValueAsListInit("resultPatterns"); - return DagNode(cast(results->getElement(index))); + return DagNode(cast(results->getElement(index))); } void Pattern::collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap) { - LLVM_DEBUG(llvm::dbgs() << "start collecting source pattern bound symbols\n"); + LLVM_DEBUG(dbgs() << "start collecting source pattern bound symbols\n"); collectBoundSymbols(getSourcePattern(), infoMap, /*isSrcPattern=*/true); - LLVM_DEBUG(llvm::dbgs() << "done collecting source pattern bound symbols\n"); + LLVM_DEBUG(dbgs() << "done collecting source pattern bound symbols\n"); - LLVM_DEBUG(llvm::dbgs() << "start assigning alternative names for symbols\n"); + LLVM_DEBUG(dbgs() << "start assigning alternative names for symbols\n"); infoMap.assignUniqueAlternativeNames(); - LLVM_DEBUG(llvm::dbgs() << "done assigning alternative names for symbols\n"); + LLVM_DEBUG(dbgs() << "done assigning alternative names for symbols\n"); } void Pattern::collectResultPatternBoundSymbols(SymbolInfoMap &infoMap) { - LLVM_DEBUG(llvm::dbgs() << "start collecting result pattern bound symbols\n"); + LLVM_DEBUG(dbgs() << "start collecting result pattern bound symbols\n"); for (int i = 0, e = getNumResultPatterns(); i < e; ++i) { auto pattern = getResultPattern(i); collectBoundSymbols(pattern, infoMap, /*isSrcPattern=*/false); } - LLVM_DEBUG(llvm::dbgs() << "done collecting result pattern bound symbols\n"); + LLVM_DEBUG(dbgs() << "done collecting result pattern bound symbols\n"); } const Operator &Pattern::getSourceRootOp() { @@ -664,7 +668,7 @@ std::vector Pattern::getConstraints() const { ret.reserve(listInit->size()); for (auto *it : *listInit) { - auto *dagInit = dyn_cast(it); + auto *dagInit = dyn_cast(it); if (!dagInit) PrintFatalError(&def, "all elements in Pattern multi-entity " "constraints should be DAG nodes"); @@ -680,7 +684,7 @@ std::vector Pattern::getConstraints() const { entities.emplace_back(argName->getValue()); } - ret.emplace_back(cast(dagInit->getOperator())->getDef(), + ret.emplace_back(cast(dagInit->getOperator())->getDef(), dagInit->getNameStr(), std::move(entities)); } return ret; @@ -693,19 +697,19 @@ int Pattern::getNumSupplementalPatterns() const { DagNode Pattern::getSupplementalPattern(unsigned index) const { auto *results = def.getValueAsListInit("supplementalPatterns"); - return DagNode(cast(results->getElement(index))); + return DagNode(cast(results->getElement(index))); } int Pattern::getBenefit() const { // The initial benefit value is a heuristic with number of ops in the source // pattern. int initBenefit = getSourcePattern().getNumOps(); - const llvm::DagInit *delta = def.getValueAsDag("benefitDelta"); - if (delta->getNumArgs() != 1 || !isa(delta->getArg(0))) { + const DagInit *delta = def.getValueAsDag("benefitDelta"); + if (delta->getNumArgs() != 1 || !isa(delta->getArg(0))) { PrintFatalError(&def, "The 'addBenefit' takes and only takes one integer value"); } - return initBenefit + dyn_cast(delta->getArg(0))->getValue(); + return initBenefit + dyn_cast(delta->getArg(0))->getValue(); } std::vector Pattern::getLocation() const { @@ -736,8 +740,8 @@ void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap, if (tree.isNativeCodeCall()) { if (!treeName.empty()) { if (!isSrcPattern) { - LLVM_DEBUG(llvm::dbgs() << "found symbol bound to NativeCodeCall: " - << treeName << '\n'); + LLVM_DEBUG(dbgs() << "found symbol bound to NativeCodeCall: " + << treeName << '\n'); verifyBind( infoMap.bindValues(treeName, tree.getNumReturnsOfNativeCode()), treeName); @@ -820,8 +824,8 @@ void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap, // The name attached to the DAG node's operator is for representing the // results generated from this op. It should be remembered as bound results. if (!treeName.empty()) { - LLVM_DEBUG(llvm::dbgs() - << "found symbol bound to op result: " << treeName << '\n'); + LLVM_DEBUG(dbgs() << "found symbol bound to op result: " << treeName + << '\n'); verifyBind(infoMap.bindOpResult(treeName, op), treeName); } @@ -896,8 +900,8 @@ void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap, auto treeArgName = tree.getArgName(i); // `$_` is a special symbol meaning ignore the current argument. if (!treeArgName.empty() && treeArgName != "_") { - LLVM_DEBUG(llvm::dbgs() << "found symbol bound to op argument: " - << treeArgName << '\n'); + LLVM_DEBUG(dbgs() << "found symbol bound to op argument: " + << treeArgName << '\n'); verifyBind(infoMap.bindOpArgument(tree, treeArgName, op, opArgIdx), treeArgName); } diff --git a/mlir/lib/TableGen/Predicate.cpp b/mlir/lib/TableGen/Predicate.cpp index 0e38dab8491c07..f71dd0bd35f86c 100644 --- a/mlir/lib/TableGen/Predicate.cpp +++ b/mlir/lib/TableGen/Predicate.cpp @@ -20,15 +20,18 @@ using namespace mlir; using namespace tblgen; +using llvm::Init; +using llvm::Record; +using llvm::SpecificBumpPtrAllocator; // Construct a Predicate from a record. -Pred::Pred(const llvm::Record *record) : def(record) { +Pred::Pred(const Record *record) : def(record) { assert(def->isSubClassOf("Pred") && "must be a subclass of TableGen 'Pred' class"); } // Construct a Predicate from an initializer. -Pred::Pred(const llvm::Init *init) { +Pred::Pred(const Init *init) { if (const auto *defInit = dyn_cast_or_null(init)) def = defInit->getDef(); } @@ -48,12 +51,12 @@ bool Pred::isCombined() const { ArrayRef Pred::getLoc() const { return def->getLoc(); } -CPred::CPred(const llvm::Record *record) : Pred(record) { +CPred::CPred(const Record *record) : Pred(record) { assert(def->isSubClassOf("CPred") && "must be a subclass of Tablegen 'CPred' class"); } -CPred::CPred(const llvm::Init *init) : Pred(init) { +CPred::CPred(const Init *init) : Pred(init) { assert((!def || def->isSubClassOf("CPred")) && "must be a subclass of Tablegen 'CPred' class"); } @@ -64,22 +67,22 @@ std::string CPred::getConditionImpl() const { return std::string(def->getValueAsString("predExpr")); } -CombinedPred::CombinedPred(const llvm::Record *record) : Pred(record) { +CombinedPred::CombinedPred(const Record *record) : Pred(record) { assert(def->isSubClassOf("CombinedPred") && "must be a subclass of Tablegen 'CombinedPred' class"); } -CombinedPred::CombinedPred(const llvm::Init *init) : Pred(init) { +CombinedPred::CombinedPred(const Init *init) : Pred(init) { assert((!def || def->isSubClassOf("CombinedPred")) && "must be a subclass of Tablegen 'CombinedPred' class"); } -const llvm::Record *CombinedPred::getCombinerDef() const { +const Record *CombinedPred::getCombinerDef() const { assert(def->getValue("kind") && "CombinedPred must have a value 'kind'"); return def->getValueAsDef("kind"); } -std::vector CombinedPred::getChildren() const { +std::vector CombinedPred::getChildren() const { assert(def->getValue("children") && "CombinedPred must have a value 'children'"); return def->getValueAsListOfDefs("children"); @@ -156,7 +159,7 @@ static void performSubstitutions(std::string &str, // All nodes are created within "allocator". static PredNode * buildPredicateTree(const Pred &root, - llvm::SpecificBumpPtrAllocator &allocator, + SpecificBumpPtrAllocator &allocator, ArrayRef substitutions) { auto *rootNode = allocator.Allocate(); new (rootNode) PredNode; @@ -351,7 +354,7 @@ static std::string getCombinedCondition(const PredNode &root) { } std::string CombinedPred::getConditionImpl() const { - llvm::SpecificBumpPtrAllocator allocator; + SpecificBumpPtrAllocator allocator; auto *predicateTree = buildPredicateTree(*this, allocator, {}); predicateTree = propagateGroundTruth(predicateTree, diff --git a/mlir/lib/TableGen/Type.cpp b/mlir/lib/TableGen/Type.cpp index c3b813ec598d0a..4f74056947abe1 100644 --- a/mlir/lib/TableGen/Type.cpp +++ b/mlir/lib/TableGen/Type.cpp @@ -18,6 +18,7 @@ using namespace mlir; using namespace mlir::tblgen; +using llvm::Record; TypeConstraint::TypeConstraint(const llvm::DefInit *init) : TypeConstraint(init->getDef()) {} @@ -42,7 +43,7 @@ StringRef TypeConstraint::getVariadicOfVariadicSegmentSizeAttr() const { // Returns the builder call for this constraint if this is a buildable type, // returns std::nullopt otherwise. std::optional TypeConstraint::getBuilderCall() const { - const llvm::Record *baseType = def; + const Record *baseType = def; if (isVariableLength()) baseType = baseType->getValueAsDef("baseType"); @@ -64,7 +65,7 @@ StringRef TypeConstraint::getCppType() const { return def->getValueAsString("cppType"); } -Type::Type(const llvm::Record *record) : TypeConstraint(record) {} +Type::Type(const Record *record) : TypeConstraint(record) {} Dialect Type::getDialect() const { return Dialect(def->getValueAsDef("dialect"));