diff --git a/compiler/passes/convert-typed-uast.cpp b/compiler/passes/convert-typed-uast.cpp index 3bcccea919f7..9cccd04a9f1d 100644 --- a/compiler/passes/convert-typed-uast.cpp +++ b/compiler/passes/convert-typed-uast.cpp @@ -1098,9 +1098,7 @@ void TConverter::createMainFunctions() { // TODO: add converter or QualifiedType methods to more // easily construct a QualifiedType for common values like param false. - auto falseQt = types::QualifiedType(types::QualifiedType::PARAM, - types::BoolType::get(context), - types::BoolParam::get(context, false)); + auto falseQt = types::QualifiedType::makeParamBool(context, false); auto ci = resolution::CallInfo(UniqueString::get(context, "_endCountAlloc"), /* calledType */ types::QualifiedType(), /* isMethodCall */ false, diff --git a/frontend/include/chpl/types/DomainType.h b/frontend/include/chpl/types/DomainType.h index 738ad7d8560f..49e006f517fe 100644 --- a/frontend/include/chpl/types/DomainType.h +++ b/frontend/include/chpl/types/DomainType.h @@ -99,6 +99,7 @@ class DomainType final : public CompositeType { /** Return an associative domain type */ static const DomainType* getAssociativeType(Context* context, + const QualifiedType& instance, const QualifiedType& idxType, const QualifiedType& parSafe); diff --git a/frontend/include/chpl/types/QualifiedType.h b/frontend/include/chpl/types/QualifiedType.h index ccdaca35dc6d..bba647775571 100644 --- a/frontend/include/chpl/types/QualifiedType.h +++ b/frontend/include/chpl/types/QualifiedType.h @@ -71,6 +71,12 @@ class QualifiedType final { static const char* kindToString(Kind k); + // Convenience functions to construct param types + static QualifiedType makeParamBool(Context* context, bool b); + static QualifiedType makeParamInt(Context* context, int64_t i); + static QualifiedType makeParamString(Context* context, UniqueString s); + static QualifiedType makeParamString(Context* context, std::string s); + private: Kind kind_ = UNKNOWN; const Type* type_ = nullptr; diff --git a/frontend/include/chpl/uast/prim-ops-list.h b/frontend/include/chpl/uast/prim-ops-list.h index 0111b01105c2..009ae064e9e0 100644 --- a/frontend/include/chpl/uast/prim-ops-list.h +++ b/frontend/include/chpl/uast/prim-ops-list.h @@ -139,8 +139,6 @@ PRIMITIVE_R(QUERY, "query") PRIMITIVE_R(QUERY_PARAM_FIELD, "query param field") PRIMITIVE_R(QUERY_TYPE_FIELD, "query type field") -PRIMITIVE_R(STATIC_DOMAIN_TYPE, "static domain type") - PRIMITIVE_R(STATIC_FUNCTION_VAR, "static function var") PRIMITIVE_R(STATIC_FUNCTION_VAR_VALIDATE_TYPE, "static function validate type") PRIMITIVE_R(STATIC_FUNCTION_VAR_WRAPPER, "static function var wrapper") diff --git a/frontend/lib/resolution/InitResolver.cpp b/frontend/lib/resolution/InitResolver.cpp index 89a95538ccd1..5da6f7ea6918 100644 --- a/frontend/lib/resolution/InitResolver.cpp +++ b/frontend/lib/resolution/InitResolver.cpp @@ -295,9 +295,9 @@ static const DomainType* domainTypeFromSubsHelper( if (auto instanceBct = instanceCt->basicClassType()) { // Get BaseRectangularDom parent subs for rectangular domain info if (auto baseDom = instanceBct->parentClassType()) { - auto& rf = fieldsForTypeDecl(context, baseDom, - DefaultsPolicy::IGNORE_DEFAULTS); if (baseDom->id().symbolPath() == "ChapelDistribution.BaseRectangularDom") { + auto& rf = fieldsForTypeDecl(context, baseDom, + DefaultsPolicy::IGNORE_DEFAULTS); CHPL_ASSERT(rf.numFields() == 3); QualifiedType rank; QualifiedType idxType; @@ -315,7 +315,24 @@ static const DomainType* domainTypeFromSubsHelper( return DomainType::getRectangularType(context, instanceQt, rank, idxType, strides); } else if (baseDom->id().symbolPath() == "ChapelDistribution.BaseAssociativeDom") { - // TODO: support associative domains + // Currently the relevant associative domain fields are defined + // on all the children of BaseAssociativeDom, so get information + // from there. + auto& rf = fieldsForTypeDecl(context, instanceBct, + DefaultsPolicy::IGNORE_DEFAULTS); + CHPL_ASSERT(rf.numFields() >= 2); + QualifiedType idxType; + QualifiedType parSafe; + for (int i = 0; i < rf.numFields(); i++) { + if (rf.fieldName(i) == "idxType") { + idxType = rf.fieldType(i); + } else if (rf.fieldName(i) == "parSafe") { + parSafe = rf.fieldType(i); + } + } + + return DomainType::getAssociativeType(context, instanceQt, idxType, + parSafe); } else if (baseDom->id().symbolPath() == "ChapelDistribution.BaseSparseDom") { // TODO: support sparse domains } else { diff --git a/frontend/lib/resolution/Resolver.cpp b/frontend/lib/resolution/Resolver.cpp index 5c5b8d89f834..021b5df38b64 100644 --- a/frontend/lib/resolution/Resolver.cpp +++ b/frontend/lib/resolution/Resolver.cpp @@ -1056,9 +1056,7 @@ void Resolver::resolveTypeQueries(const AstNode* formalTypeExpr, if (isNonStarVarArg) { varArgTypeQueryError(context, call->actual(0), resolvedWidth); } else { - auto p = IntParam::get(context, pt->bitwidth()); - auto it = IntType::get(context, 0); - auto qt = QualifiedType(QualifiedType::PARAM, it, p); + auto qt = QualifiedType::makeParamInt(context, pt->bitwidth()); resolvedWidth.setType(qt); } } @@ -1218,10 +1216,8 @@ void Resolver::resolveTypeQueriesFromFormalType(const VarLikeDecl* formal, // args...?n if (auto countQuery = varargs->count()) { - auto intType = IntType::get(context, 0); - auto val = IntParam::get(context, tuple->numElements()); ResolvedExpression& result = byPostorder.byAst(countQuery); - result.setType(QualifiedType(QualifiedType::PARAM, intType, val)); + result.setType(QualifiedType::makeParamInt(context, tuple->numElements())); } if (auto typeExpr = formal->typeExpression()) { @@ -2453,8 +2449,7 @@ bool Resolver::resolveSpecialPrimitiveCall(const Call* call) { resultBool = false; } - result = QualifiedType(QualifiedType::PARAM, BoolType::get(context), - BoolParam::get(context, resultBool)); + result = QualifiedType::makeParamBool(context, resultBool); } byPostorder.byAst(primCall).setType(result); @@ -2748,9 +2743,7 @@ QualifiedType Resolver::typeForId(const ID& id, bool localGenericToUnknown) { if (field && field->name() == USTR("size")) { // Tuples don't store a 'size' in their substitutions map, so // manually take care of things here. - auto intType = IntType::get(context, 0); - auto val = IntParam::get(context, tup->numElements()); - return QualifiedType(QualifiedType::PARAM, intType, val); + return QualifiedType::makeParamInt(context, tup->numElements()); } } @@ -4174,10 +4167,8 @@ void Resolver::exit(const uast::Domain* decl) { // Add definedConst actual if appropriate if (decl->usedCurlyBraces()) { - actuals.emplace_back( - QualifiedType(QualifiedType::PARAM, BoolType::get(context), - BoolParam::get(context, true)), - UniqueString()); + actuals.emplace_back(QualifiedType::makeParamBool(context, true)), + UniqueString(); } auto ci = CallInfo(/* name */ UniqueString::get(context, domainBuilderProc), @@ -4235,9 +4226,7 @@ types::QualifiedType Resolver::typeForBooleanOp(const uast::OpCall* op) { // preserve param-ness // this case is only hit when the result is false (for &&) // or when the result is true (for ||), so return !isAnd. - return QualifiedType(QualifiedType::PARAM, - BoolType::get(context), - BoolParam::get(context, !isAnd)); + return QualifiedType::makeParamBool(context, !isAnd); } else { // otherwise just return a Bool value return QualifiedType(QualifiedType::CONST_VAR, diff --git a/frontend/lib/resolution/VarScopeVisitor.cpp b/frontend/lib/resolution/VarScopeVisitor.cpp index 0fe8d139dedb..625c4f600f94 100644 --- a/frontend/lib/resolution/VarScopeVisitor.cpp +++ b/frontend/lib/resolution/VarScopeVisitor.cpp @@ -386,6 +386,20 @@ void VarScopeVisitor::exitAst(const uast::AstNode* ast) { inAstStack.pop_back(); } +bool VarScopeVisitor::enter(const TupleDecl* ast, RV& rv) { + enterAst(ast); + enterScope(ast, rv); + + // TODO: handle tuple decls + return false; +} +void VarScopeVisitor::exit(const TupleDecl* ast, RV& rv) { + exitScope(ast, rv); + exitAst(ast); + + return; +} + bool VarScopeVisitor::enter(const NamedDecl* ast, RV& rv) { if (ast->id().isSymbolDefiningScope()) { diff --git a/frontend/lib/resolution/VarScopeVisitor.h b/frontend/lib/resolution/VarScopeVisitor.h index 25c48310316e..0bfabe2482dd 100644 --- a/frontend/lib/resolution/VarScopeVisitor.h +++ b/frontend/lib/resolution/VarScopeVisitor.h @@ -213,6 +213,9 @@ class VarScopeVisitor { bool enter(const NamedDecl* ast, RV& rv); void exit(const NamedDecl* ast, RV& rv); + bool enter(const TupleDecl* ast, RV& rv); + void exit(const TupleDecl* ast, RV& rv); + bool enter(const OpCall* ast, RV& rv); void exit(const OpCall* ast, RV& rv); diff --git a/frontend/lib/resolution/call-init-deinit.cpp b/frontend/lib/resolution/call-init-deinit.cpp index f1fc81c12ab8..c49e11b40f6f 100644 --- a/frontend/lib/resolution/call-init-deinit.cpp +++ b/frontend/lib/resolution/call-init-deinit.cpp @@ -788,8 +788,8 @@ void CallInitDeinit::handleDeclaration(const VarLikeDecl* ast, RV& rv) { VarFrame* frame = currentFrame(); // check for use of deinited variables in type or init exprs - if (auto init = ast->initExpression()) { - processMentions(init, rv); + if (auto type = ast->typeExpression()) { + processMentions(type, rv); } if (auto init = ast->initExpression()) { processMentions(init, rv); diff --git a/frontend/lib/resolution/prims.cpp b/frontend/lib/resolution/prims.cpp index 764e31aad742..29b84583fe6a 100644 --- a/frontend/lib/resolution/prims.cpp +++ b/frontend/lib/resolution/prims.cpp @@ -101,26 +101,6 @@ static bool toParamStringActual(const QualifiedType& type, UniqueString& into) { return paramStringBytesHelper(type, into, true); } -static QualifiedType makeParamBool(Context* context, bool b) { - return { QualifiedType::PARAM, BoolType::get(context), - BoolParam::get(context, b) }; -} - -static QualifiedType makeParamInt(Context* context, int64_t i) { - return { QualifiedType::PARAM, IntType::get(context, 0), - IntParam::get(context, i) }; -} - -static QualifiedType makeParamString(Context* context, UniqueString s) { - return { QualifiedType::PARAM, RecordType::getStringType(context), - StringParam::get(context, s) }; -} - -static QualifiedType makeParamString(Context* context, std::string s) { - auto ustr = UniqueString::get(context, s); - return makeParamString(context, ustr); -} - static QualifiedType primIsBound(Context* context, const CallInfo& ci) { auto type = QualifiedType(); if (ci.numActuals() != 2) return type; @@ -140,7 +120,7 @@ static QualifiedType primIsBound(Context* context, const CallInfo& ci) { // will only return true if the field's type is concrete. auto isBound = fields->fieldType(i).genericity() == Type::Genericity::CONCRETE; - type = makeParamBool(context, isBound); + type = QualifiedType::makeParamBool(context, isBound); break; } } @@ -154,9 +134,7 @@ static QualifiedType primNumFields(Context* context, const CallInfo& ci) { auto firstActual = ci.actual(0).type(); if (auto fields = toCompositeTypeActualFields(context, firstActual)) { int64_t numFields = fields->numFields(); - type = QualifiedType(QualifiedType::PARAM, - IntType::get(context, 64), - IntParam::get(context, numFields)); + type = QualifiedType::makeParamInt(context, numFields); } return type; } @@ -174,9 +152,7 @@ static QualifiedType primFieldNumToName(Context* context, const CallInfo& ci) { if (fieldNum > fields->numFields() || fieldNum < 1) return type; auto fieldName = fields->fieldName(fieldNum - 1); - type = QualifiedType(QualifiedType::PARAM, - RecordType::getStringType(context), - StringParam::get(context, fieldName)); + type = QualifiedType::makeParamString(context, fieldName); } return type; } @@ -204,9 +180,7 @@ static QualifiedType primFieldNameToNum(Context* context, const CallInfo& ci) { } if (!foundField) return type; - type = QualifiedType(QualifiedType::PARAM, - IntType::get(context, 64), - IntParam::get(context, field)); + type = QualifiedType::makeParamInt(context, field); } return type; } @@ -283,9 +257,7 @@ static QualifiedType primCallResolves(ResolutionContext* rc, } } - return QualifiedType(QualifiedType::PARAM, - BoolType::get(context), - BoolParam::get(context, callAndFnResolved)); + return QualifiedType::makeParamBool(context, callAndFnResolved); } static QualifiedType primImplementsInterface(Context* context, @@ -312,7 +284,7 @@ static QualifiedType primImplementsInterface(Context* context, findMatchingImplementationPoint(&rc, instantiatedIft, inScopes); if (witness) { - return makeParamInt(context, 0); + return QualifiedType::makeParamInt(context, 0); } // try automatically satisfy the interface if it's in the standard modules. @@ -323,26 +295,7 @@ static QualifiedType primImplementsInterface(Context* context, witness = runResult.result(); } - return makeParamInt(context, witness ? 1 : 2); -} - -static QualifiedType computeDomainType(Context* context, const CallInfo& ci) { - if (ci.numActuals() == 3) { - auto type = DomainType::getRectangularType(context, - QualifiedType(), - ci.actual(0).type(), - ci.actual(1).type(), - ci.actual(2).type()); - return QualifiedType(QualifiedType::TYPE, type); - } else if (ci.numActuals() == 2) { - auto type = DomainType::getAssociativeType(context, - ci.actual(0).type(), - ci.actual(1).type()); - return QualifiedType(QualifiedType::TYPE, type); - } else { - CHPL_ASSERT(false && "unhandled domain type?"); - } - return QualifiedType(); + return QualifiedType::makeParamInt(context, witness ? 1 : 2); } static QualifiedType primAddrOf(Context* context, const CallInfo& ci) { @@ -555,9 +508,7 @@ static QualifiedType primGatherTests(Context* context, const CallInfo& ci) { QUERY_STORE_INPUT_RESULT(gatheredTestsQuery, context, finder.fns); auto numFoundFns = (int) gatheredTestsQuery(context).size(); - return QualifiedType(QualifiedType::PARAM, - IntType::get(context, 0), - IntParam::get(context, numFoundFns)); + return QualifiedType::makeParamInt(context, numFoundFns); } static QualifiedType primIsTuple(Context* context, @@ -567,9 +518,7 @@ static QualifiedType primIsTuple(Context* context, if (actualType.kind() != QualifiedType::TYPE) return QualifiedType(); bool isTupleType = actualType.type() && actualType.type()->isTupleType(); - return QualifiedType(QualifiedType::PARAM, - BoolType::get(context), - BoolParam::get(context, isTupleType)); + return QualifiedType::makeParamBool(context, isTupleType); } static QualifiedType primCast(Context* context, @@ -754,8 +703,7 @@ static QualifiedType primFamilyIsSubtype(Context* context, } } - return QualifiedType(QualifiedType::PARAM, BoolType::get(context), - BoolParam::get(context, result)); + return QualifiedType::makeParamBool(context, result); } static QualifiedType primToNilableClass(Context* context, @@ -883,8 +831,7 @@ static QualifiedType primFamilyCopyableAssignable(Context* context, const bool isCopyableOrAssignable = info.isFromConst() || (info.isFromRef() && isFromRefOk); - return QualifiedType(QualifiedType::PARAM, BoolType::get(context), - BoolParam::get(context, isCopyableOrAssignable)); + return QualifiedType::makeParamBool(context, isCopyableOrAssignable); } // TODO: What should be done for 'MAYBE_GENERIC', if anything? @@ -915,7 +862,7 @@ static QualifiedType primIsGenericType(Context* context, const CallInfo& ci) { // Both cases are considered 'generic' for this primitive. eval = (g == Type::GENERIC || g == Type::GENERIC_WITH_DEFAULTS); } - return makeParamBool(context, eval); + return QualifiedType::makeParamBool(context, eval); } @@ -932,7 +879,7 @@ static QualifiedType primIsClassType(Context* context, const CallInfo& ci) { const bool isDdata = t->hasPragma(context, PRAGMA_DATA_CLASS); eval = isClassLike && !isExtern && !isDdata; } - return makeParamBool(context, eval); + return QualifiedType::makeParamBool(context, eval); } template @@ -941,7 +888,7 @@ actualTypeHasProperty(Context* context, const CallInfo& ci, F&& f) { if (ci.numActuals() < 1) return QualifiedType(); bool eval = false; if (auto t = ci.actual(0).type().type()) eval = f(t); - return makeParamBool(context, eval); + return QualifiedType::makeParamBool(context, eval); } static QualifiedType @@ -968,7 +915,7 @@ static QualifiedType primIsRecordType(Context* context, const CallInfo& ci) { static QualifiedType primIsFcfType(Context* context, const CallInfo& ci) { CHPL_UNIMPL("PRIM_IS_FCF_TYPE"); - return makeParamBool(context, false); + return QualifiedType::makeParamBool(context, false); } static QualifiedType primIsUnionType(Context* context, const CallInfo& ci) { @@ -990,7 +937,7 @@ primIsExternUnionType(Context* context, const CallInfo& ci) { auto qt1 = primIsExternType(context, ci); auto qt2 = primIsUnionType(context, ci); const bool eval = qt1.isParamTrue() && qt2.isParamTrue(); - return makeParamBool(context, eval); + return QualifiedType::makeParamBool(context, eval); } static QualifiedType @@ -1037,7 +984,7 @@ primIsCoercible(Context* context, const CallInfo& ci) { bool eval = canPass.passes() && (canPass.instantiates() || canPass.converts()) && !canPass.promotes(); - return makeParamBool(context, eval); + return QualifiedType::makeParamBool(context, eval); } static std::string typeToString(Context* context, const Type* t); @@ -1117,7 +1064,7 @@ primTypeToString(Context* context, const CallInfo& ci) { if (auto t = ci.actual(0).type().type()) { eval = typeToString(context, t); } - return makeParamString(context, eval); + return QualifiedType::makeParamString(context, eval); } static QualifiedType @@ -1131,7 +1078,7 @@ primSimpleTypeName(Context* context, const CallInfo& ci) { } eval = typeToString(context, root); } - return makeParamString(context, eval); + return QualifiedType::makeParamString(context, eval); } CallResolutionResult resolvePrimCall(ResolutionContext* rc, @@ -1319,9 +1266,7 @@ CallResolutionResult resolvePrimCall(ResolutionContext* rc, if (auto tt = t->toTupleType()) result = tt->isStarTuple(); - type = QualifiedType(QualifiedType::PARAM, - BoolType::get(context), - BoolParam::get(context, result)); + type = QualifiedType::makeParamBool(context, result); } break; @@ -1418,9 +1363,7 @@ CallResolutionResult resolvePrimCall(ResolutionContext* rc, if (toParamStringActual(actualType, sParam)|| toParamBytesActual(actualType, sParam)) { const size_t s = sParam.length(); - type = QualifiedType(QualifiedType::PARAM, - IntType::get(context, 0), - IntParam::get(context, s)); + type = QualifiedType::makeParamInt(context, s); break; } else if (actualType.type()->isStringType() || actualType.type()->isBytesType() || @@ -1445,7 +1388,7 @@ CallResolutionResult resolvePrimCall(ResolutionContext* rc, auto lstr = lhs.param()->toStringParam()->value(); auto rstr = rhs.param()->toStringParam()->value(); auto concat = UniqueString::getConcat(context, lstr.c_str(), rstr.c_str()); - type = QualifiedType(QualifiedType::PARAM, lhs.type(), StringParam::get(context, concat)); + type = QualifiedType::makeParamString(context, concat); } } break; @@ -1462,9 +1405,7 @@ CallResolutionResult resolvePrimCall(ResolutionContext* rc, case PRIM_EQUAL: if (ci.actual(0).type().isType() && ci.actual(1).type().isType()) { bool isEqual = ci.actual(0).type().type() == ci.actual(1).type().type(); - type = QualifiedType(QualifiedType::PARAM, - BoolType::get(context), - BoolParam::get(context, isEqual)); + type = QualifiedType::makeParamBool(context, isEqual); break; } case PRIM_IS_WIDE_PTR: @@ -1631,19 +1572,13 @@ CallResolutionResult resolvePrimCall(ResolutionContext* rc, VoidType::get(context)); break; - case PRIM_STATIC_DOMAIN_TYPE: - type = computeDomainType(context, ci); - break; - case PRIM_GET_COMPILER_VAR: { auto chplenv = context->getChplEnv(); auto varName = ci.actual(0).type().param()->toStringParam()->value().str(); auto it = chplenv->find(varName); auto ret = (it != chplenv->end()) ? it->second : ""; - auto st = CompositeType::getStringType(context); - auto sp = StringParam::get(context, UniqueString::get(context, ret)); - type = QualifiedType(QualifiedType::PARAM, st, sp); + type = QualifiedType::makeParamString(context, ret); } break; /* primitives that return real parts from a complex */ @@ -1820,18 +1755,15 @@ CallResolutionResult resolvePrimCall(ResolutionContext* rc, break; case PRIM_VERSION_MAJOR: - type = QualifiedType(QualifiedType::PARAM, IntType::get(context, 0), - IntParam::get(context, getMajorVersion())); + type = QualifiedType::makeParamInt(context, getMajorVersion()); break; case PRIM_VERSION_MINOR: - type = QualifiedType(QualifiedType::PARAM, IntType::get(context, 0), - IntParam::get(context, getMinorVersion())); + type = QualifiedType::makeParamInt(context, getMinorVersion()); break; case PRIM_VERSION_UPDATE: - type = QualifiedType(QualifiedType::PARAM, IntType::get(context, 0), - IntParam::get(context, getUpdateVersion())); + type = QualifiedType::makeParamInt(context, getUpdateVersion()); break; case PRIM_VERSION_SHA: { @@ -1840,8 +1772,7 @@ CallResolutionResult resolvePrimCall(ResolutionContext* rc, versionHash = UniqueString::get(context, getCommitHash()); } - type = QualifiedType(QualifiedType::PARAM, RecordType::getStringType(context), - StringParam::get(context, versionHash)); + type = QualifiedType::makeParamString(context, versionHash); break; } diff --git a/frontend/lib/resolution/resolution-queries.cpp b/frontend/lib/resolution/resolution-queries.cpp index af9f6836196e..b4d4190da8c2 100644 --- a/frontend/lib/resolution/resolution-queries.cpp +++ b/frontend/lib/resolution/resolution-queries.cpp @@ -1804,9 +1804,7 @@ computeNumericValuesOfEnumElements(Context* context, ID node) { Resolver res = Resolver::createForEnumElements(context, enumNode, byPostorder); // The constant 'one' for adding - auto one = QualifiedType(QualifiedType::PARAM, - IntType::get(context, 0), - IntParam::get(context, 1)); + auto one = QualifiedType::makeParamInt(context, 1); // A type to track what kind of signedness a value needs. enum RequiredSignedness { @@ -1944,9 +1942,7 @@ computeNumericValuesOfEnumElements(Context* context, ID node) { UintType::get(context, 0), UintParam::get(context, (uint64_t) *signedValue)); } else { - resultType = QualifiedType(QualifiedType::PARAM, - IntType::get(context, 0), - IntParam::get(context, *signedValue)); + resultType = QualifiedType::makeParamInt(context, *signedValue); } } @@ -2828,9 +2824,14 @@ helpResolveFunction(ResolutionContext* rc, const TypedFnSignature* sig, // same function twice when working with inferred 'out' formals) sig = sig->inferredFrom(); - if (!sig->isInitializer() && sig->needsInstantiation()) { - CHPL_ASSERT(false && "Should only be called on concrete or fully " - "instantiated functions"); + // Signature should be concrete by now, except in the case of an initializer + // or type constructor in which case we may still have generic formals. + // For example, range(?) will reach this point. + if (!sig->isInitializer() && !sig->untyped()->isTypeConstructor() && + sig->needsInstantiation()) { + CHPL_ASSERT(false && + "Should only be called on concrete or fully " + "instantiated functions"); return nullptr; } @@ -3964,10 +3965,7 @@ static bool resolveFnCallSpecial(Context* context, if (srcQtEnumType && dstTy->isStringType()) { std::ostringstream oss; srcQt.param()->stringify(oss, chpl::StringifyKind::CHPL_SYNTAX); - auto ustr = UniqueString::get(context, oss.str()); - exprTypeOut = QualifiedType(QualifiedType::PARAM, - RecordType::getStringType(context), - StringParam::get(context, ustr)); + exprTypeOut = QualifiedType::makeParamString(context, oss.str()); return true; } @@ -3992,10 +3990,7 @@ static bool resolveFnCallSpecial(Context* context, // handle casting a type name to a string std::ostringstream oss; srcTy->stringify(oss, chpl::StringifyKind::CHPL_SYNTAX); - auto ustr = UniqueString::get(context, oss.str()); - exprTypeOut = QualifiedType(QualifiedType::PARAM, - RecordType::getStringType(context), - StringParam::get(context, ustr)); + exprTypeOut = QualifiedType::makeParamString(context, oss.str()); return true; } else if (srcTy->isClassType() && dstTy->isClassType()) { // cast (borrowed class) : unmanaged @@ -4030,13 +4025,9 @@ static bool resolveFnCallSpecial(Context* context, } if ((ci.name() == USTR("==") || ci.name() == USTR("!="))) { - if (ci.numActuals() == 2 || ci.hasQuestionArg()) { + if (ci.numActuals() == 2) { auto lhs = ci.actual(0).type(); - - // support comparisons with '?' - auto rhs = ci.hasQuestionArg() ? - QualifiedType(QualifiedType::TYPE, AnyType::get(context)) : - ci.actual(1).type(); + auto rhs = ci.actual(1).type(); bool bothType = lhs.kind() == QualifiedType::TYPE && rhs.kind() == QualifiedType::TYPE; @@ -4045,8 +4036,26 @@ static bool resolveFnCallSpecial(Context* context, if (bothType || bothParam) { bool result = lhs == rhs; result = ci.name() == USTR("==") ? result : !result; - exprTypeOut = QualifiedType(QualifiedType::PARAM, BoolType::get(context), - BoolParam::get(context, result)); + exprTypeOut = QualifiedType::makeParamBool(context, result); + return true; + } + } else if (ci.numActuals() == 1 && ci.hasQuestionArg()) { + // support type and param comparisons with '?' + // TODO: will likely need adjustment once we are able to compare a + // partially-instantiated type's fields with '?' + auto arg = ci.actual(0).type(); + bool result = false; + bool haveResult = true; + if (arg.isType()) { + result = arg.type()->isAnyType(); + } else if (arg.isParam()) { + result = arg.param() == nullptr; + } else { + haveResult = false; + } + result = ci.name() == USTR("==") ? result : !result; + if (haveResult) { + exprTypeOut = QualifiedType::makeParamBool(context, result); return true; } } @@ -4076,8 +4085,7 @@ static bool resolveFnCallSpecial(Context* context, } auto got = canPassScalar(context, ci.actual(0).type(), ci.actual(1).type()); bool result = got.passes(); - exprTypeOut = QualifiedType(QualifiedType::PARAM, BoolType::get(context), - BoolParam::get(context, result)); + exprTypeOut = QualifiedType::makeParamBool(context, result); return true; } @@ -6082,9 +6090,7 @@ QualifiedType paramTypeFromValue(Context* context, T value); template <> QualifiedType paramTypeFromValue(Context* context, bool value) { - return QualifiedType(QualifiedType::PARAM, - BoolType::get(context), - BoolParam::get(context, value)); + return QualifiedType::makeParamBool(context, value); } const std::unordered_map& diff --git a/frontend/lib/resolution/return-type-inference.cpp b/frontend/lib/resolution/return-type-inference.cpp index b5446ff6d864..27d89664f253 100644 --- a/frontend/lib/resolution/return-type-inference.cpp +++ b/frontend/lib/resolution/return-type-inference.cpp @@ -1205,8 +1205,7 @@ static bool helpComputeCompilerGeneratedReturnType(Context* context, auto ast = parsing::idToAst(context, enumType->id())->toEnum(); CHPL_ASSERT(ast); int numElts = ast->numElements(); - result = QualifiedType(QualifiedType::PARAM, IntType::get(context, 0), - IntParam::get(context, numElts)); + result = QualifiedType::makeParamInt(context, numElts); return true; } CHPL_ASSERT(false && "unhandled compiler-generated enum method"); @@ -1299,8 +1298,7 @@ static bool helpComputeReturnType(ResolutionContext* rc, } else if (untyped->isMethod() && sig->formalType(0).type()->isTupleType() && untyped->name() == "size") { auto tup = sig->formalType(0).type()->toTupleType(); - result = QualifiedType(QualifiedType::PARAM, IntType::get(context, 0), - IntParam::get(context, tup->numElements())); + result = QualifiedType::makeParamInt(context, tup->numElements()); return true; // if method call and the receiver points to a composite type definition, diff --git a/frontend/lib/types/DomainType.cpp b/frontend/lib/types/DomainType.cpp index e1165a2bb833..4fa01772aef5 100644 --- a/frontend/lib/types/DomainType.cpp +++ b/frontend/lib/types/DomainType.cpp @@ -118,17 +118,35 @@ DomainType::getRectangularType(Context* context, const DomainType* DomainType::getAssociativeType(Context* context, + const QualifiedType& instance, const QualifiedType& idxType, const QualifiedType& parSafe) { + auto genericDomain = getGenericDomainType(context); + SubstitutionsMap subs; - // TODO: assert validity of sub types subs.emplace(ID(UniqueString(), 0, 0), idxType); + CHPL_ASSERT(idxType.isType()); subs.emplace(ID(UniqueString(), 1, 0), parSafe); + CHPL_ASSERT(parSafe.isParam() && parSafe.param() && + parSafe.param()->isBoolParam()); + + // Add substitution for _instance field + auto& rf = fieldsForTypeDecl(context, genericDomain, + resolution::DefaultsPolicy::IGNORE_DEFAULTS, + /* syntaxOnly */ true); + ID instanceFieldId; + for (int i = 0; i < rf.numFields(); i++) { + if (rf.fieldName(i) == USTR("_instance")) { + instanceFieldId = rf.fieldDeclId(i); + break; + } + } + subs.emplace(instanceFieldId, instance); + auto name = UniqueString::get(context, "_domain"); auto id = getDomainID(context); - auto instantiatedFrom = getGenericDomainType(context); - return getDomainType(context, id, name, instantiatedFrom, subs, - DomainType::Kind::Associative).get(); + return getDomainType(context, id, name, /* instantiatedFrom */ genericDomain, + subs, DomainType::Kind::Associative).get(); } const QualifiedType& DomainType::getDefaultDistType(Context* context) { diff --git a/frontend/lib/types/QualifiedType.cpp b/frontend/lib/types/QualifiedType.cpp index 4762f9206d25..d064c1d4d57d 100644 --- a/frontend/lib/types/QualifiedType.cpp +++ b/frontend/lib/types/QualifiedType.cpp @@ -20,6 +20,7 @@ #include "chpl/types/QualifiedType.h" #include "chpl/resolution/resolution-queries.h" +#include "chpl/types/all-types.h" #include "chpl/types/Param.h" #include "chpl/types/Type.h" #include "chpl/types/TupleType.h" @@ -52,6 +53,25 @@ bool QualifiedType::isParamKnownTuple() const { return false; } +QualifiedType QualifiedType::makeParamBool(Context* context, bool b) { + return {QualifiedType::PARAM, BoolType::get(context), + BoolParam::get(context, b)}; +} + +QualifiedType QualifiedType::makeParamInt(Context* context, int64_t i) { + return {QualifiedType::PARAM, IntType::get(context, 0), + IntParam::get(context, i)}; +} + +QualifiedType QualifiedType::makeParamString(Context* context, UniqueString s) { + return {QualifiedType::PARAM, RecordType::getStringType(context), + StringParam::get(context, s)}; +} + +QualifiedType QualifiedType::makeParamString(Context* context, std::string s) { + return makeParamString(context, UniqueString::get(context, s)); +} + bool QualifiedType::needsSplitInitTypeInfo(Context* context) const { return (isParam() && !hasParamPtr()) || isUnknownKindOrType() || diff --git a/frontend/test/resolution/testDomains.cpp b/frontend/test/resolution/testDomains.cpp index d37e62030fd3..c301df66c4fe 100644 --- a/frontend/test/resolution/testDomains.cpp +++ b/frontend/test/resolution/testDomains.cpp @@ -41,6 +41,8 @@ static void testRectangular(Context* context, int rank, std::string idxType, std::string strides) { + printf("Testing: %s\n", domainType.c_str()); + context->advanceToNextRevision(false); setupModuleSearchPaths(context, false, false, {}, {}); ErrorGuard guard(context); @@ -105,11 +107,10 @@ module M { assert(aa.action() == AssociatedAction::RUNTIME_TYPE); QualifiedType fullIndexType = findVarType(m, rr, "fullIndex"); - (void)fullIndexType; auto rankVarTy = findVarType(m, rr, "r"); assert(rankVarTy == dType->rank()); - assert(rankVarTy.param()->toIntParam()->value() == rank); + ensureParamInt(rankVarTy, rank); auto idxTypeVarTy = findVarType(m, rr, "i"); assert(idxTypeVarTy == dType->idxType()); @@ -119,11 +120,11 @@ module M { assert(stridesVarTy == dType->strides()); assert(stridesVarTy.param()->toEnumParam()->value().str == strides); - assert(findVarType(m, rr, "rk").param()->toBoolParam()->value() == true); + ensureParamBool(findVarType(m, rr, "rk"), true); - assert(findVarType(m, rr, "ak").param()->toBoolParam()->value() == false); + ensureParamBool(findVarType(m, rr, "ak"), false); - assert(findVarType(m, rr, "p").type() == IntType::get(context, 0)); + assert(findVarType(m, rr, "p").type()->isIntType()); assert(findVarType(m, rr, "z").type() == fullIndexType.type()); @@ -154,109 +155,155 @@ module M { } assert(guard.realizeErrors() == 0); +} + +static void testDomainLiteral(Context* context, + std::string domainLiteral, + DomainType::Kind domainKind) { + printf("Testing: %s\n", domainLiteral.c_str()); + + context->advanceToNextRevision(false); + setupModuleSearchPaths(context, false, false, {}, {}); + ErrorGuard guard(context); - printf("Success: %s\n", domainType.c_str()); + std::string program = +R"""( +module M { + var d = )""" + domainLiteral + R"""(; + + type i = d.idxType; + param rk = d.isRectangular(); + param ak = d.isAssociative(); } +)"""; -// static void testAssociative(Context* context, -// std::string domainType, -// std::string idxType, -// bool parSafe) { -// context->advanceToNextRevision(false); -// setupModuleSearchPaths(context, false, false, {}, {}); -// ErrorGuard guard(context); + auto path = UniqueString::get(context, "input.chpl"); + setFileText(context, path, std::move(program)); -// std::string program = -// R"""( -// module M { -// var d : )""" + domainType + R"""(; -// type ig = )""" + idxType + R"""(; + const ModuleVec& vec = parseToplevel(context, path); + const Module* m = vec[0]; -// type i = d.idxType; -// param s = d.parSafe; -// param rk = d.isRectangular(); -// param ak = d.isAssociative(); + const ResolutionResultByPostorderID& rr = resolveModule(context, m->id()); -// var p = d.pid(); + const Variable* d = m->stmt(0)->toVariable(); + assert(d); + assert(d->name() == "d"); -// for loopI in d { -// var z = loopI; -// } + QualifiedType dQt = rr.byAst(d).type(); + assert(dQt.type()); + auto dType = dQt.type()->toDomainType(); + assert(dType); -// proc generic(arg: domain) { -// type GT = arg.type; -// return 42; -// } + assert(findVarType(m, rr, "i") == dType->idxType()); -// proc concrete(arg: )""" + domainType + R"""() { -// type CT = arg.type; -// return 42; -// } + assert(dType->kind() == domainKind); + bool isRectangular = domainKind == DomainType::Kind::Rectangular; + ensureParamBool(findVarType(m, rr, "rk"), isRectangular); + ensureParamBool(findVarType(m, rr, "ak"), !isRectangular); -// var g_ret = generic(d); -// var c_ret = concrete(d); -// } -// )"""; -// // TODO: generic checks + assert(guard.realizeErrors() == 0); +} -// auto path = UniqueString::get(context, "input.chpl"); -// setFileText(context, path, std::move(program)); +static void testAssociative(Context* context, + std::string domainType, + std::string idxType, + bool parSafe) { + printf("Testing: %s\n", domainType.c_str()); -// const ModuleVec& vec = parseToplevel(context, path); -// const Module* m = vec[1]; + context->advanceToNextRevision(false); + setupModuleSearchPaths(context, false, false, {}, {}); + ErrorGuard guard(context); -// const ResolutionResultByPostorderID& rr = resolveModule(context, m->id()); + std::string program = +R"""( +module M { + var d : )""" + domainType + R"""(; + type ig = )""" + idxType + R"""(; -// QualifiedType dType = findVarType(m, rr, "d"); -// assert(dType.type()->isDomainType()); + type i = d.idxType; + param s = d.parSafe; + param rk = d.isRectangular(); + param ak = d.isAssociative(); -// auto fullIndexType = findVarType(m, rr, "i"); -// assert(findVarType(m, rr, "ig") == fullIndexType); + var p = d.pid; -// assert(findVarType(m, rr, "s").param()->toBoolParam()->value() == parSafe); + for loopI in d { + var z = loopI; + } -// assert(findVarType(m, rr, "rk").param()->toBoolParam()->value() == false); + proc generic(arg: domain) { + type GT = arg.type; + return 42; + } -// assert(findVarType(m, rr, "ak").param()->toBoolParam()->value() == true); + proc concrete(arg: )""" + domainType + R"""() { + type CT = arg.type; + return 42; + } + + var g_ret = generic(d); + var c_ret = concrete(d); +} +)"""; + + auto path = UniqueString::get(context, "input.chpl"); + setFileText(context, path, std::move(program)); + + const ModuleVec& vec = parseToplevel(context, path); + const Module* m = vec[0]; -// assert(findVarType(m, rr, "p").type() == IntType::get(context, 0)); + const ResolutionResultByPostorderID& rr = resolveModule(context, m->id()); -// assert(findVarType(m, rr, "z").type() == fullIndexType.type()); + QualifiedType dType = findVarType(m, rr, "d"); + assert(dType.type()->isDomainType()); -// { -// const Variable* g_ret = findOnlyNamed(m, "g_ret")->toVariable(); -// auto res = rr.byAst(g_ret); -// assert(res.type().type()->isIntType()); + auto fullIndexType = findVarType(m, rr, "i"); + assert(findVarType(m, rr, "ig") == fullIndexType); -// auto call = resolveOnlyCandidate(context, rr.byAst(g_ret->initExpression())); -// // Generic function, should have been instantiated -// assert(call->signature()->instantiatedFrom() != nullptr); + ensureParamBool(findVarType(m, rr, "s"), parSafe); -// const Variable* GT = findOnlyNamed(m, "GT")->toVariable(); -// assert(call->byAst(GT).type().type() == dType.type()); -// } + ensureParamBool(findVarType(m, rr, "rk"), false); -// { -// const Variable* c_ret = findOnlyNamed(m, "c_ret")->toVariable(); -// auto res = rr.byAst(c_ret); -// assert(res.type().type()->isIntType()); + ensureParamBool(findVarType(m, rr, "ak"), true); -// auto call = resolveOnlyCandidate(context, rr.byAst(c_ret->initExpression())); -// // Concrete function, should not be instantiated -// assert(call->signature()->instantiatedFrom() == nullptr); + assert(findVarType(m, rr, "p").type()->isIntType()); -// const Variable* CT = findOnlyNamed(m, "CT")->toVariable(); -// assert(call->byAst(CT).type().type() == dType.type()); -// } + assert(findVarType(m, rr, "z").type() == fullIndexType.type()); -// assert(guard.errors().size() == 0); + { + const Variable* g_ret = findOnlyNamed(m, "g_ret")->toVariable(); + auto res = rr.byAst(g_ret); + assert(res.type().type()->isIntType()); -// printf("Success: %s\n", domainType.c_str()); -// } + auto call = resolveOnlyCandidate(context, rr.byAst(g_ret->initExpression())); + // Generic function, should have been instantiated + assert(call->signature()->instantiatedFrom() != nullptr); + + const Variable* GT = findOnlyNamed(m, "GT")->toVariable(); + assert(call->byAst(GT).type().type() == dType.type()); + } + { + const Variable* c_ret = findOnlyNamed(m, "c_ret")->toVariable(); + auto res = rr.byAst(c_ret); + assert(res.type().type()->isIntType()); + + auto call = resolveOnlyCandidate(context, rr.byAst(c_ret->initExpression())); + // Concrete function, should not be instantiated + assert(call->signature()->instantiatedFrom() == nullptr); + + const Variable* CT = findOnlyNamed(m, "CT")->toVariable(); + assert(call->byAst(CT).type().type() == dType.type()); + } + + assert(guard.errors().size() == 0); +} + +// Ensure that we can't, e.g., pass a domain(1) to a domain(2) static void testBadPass(Context* context, std::string argType, std::string actualType) { - // Ensure that we can't, e.g., pass a domain(1) to a domain(2) + printf("Testing: cannot pass %s to %s\n", actualType.c_str(), argType.c_str()); + context->advanceToNextRevision(false); setupModuleSearchPaths(context, false, false, {}, {}); ErrorGuard guard(context); @@ -288,8 +335,6 @@ module M { auto& e = guard.errors()[0]; assert(e->type() == chpl::NoMatchingCandidates); - printf("Success: cannot pass %s to %s\n", actualType.c_str(), argType.c_str()); - // 'clear' rather than 'realize' to simplify test output guard.clearErrors(); } @@ -297,6 +342,9 @@ module M { static void testIndex(Context* context, std::string domainType, std::string expectedType) { + printf("Testing: index(%s) == %s\n", domainType.c_str(), + expectedType.c_str()); + context->advanceToNextRevision(false); setupModuleSearchPaths(context, false, false, {}, {}); ErrorGuard guard(context); @@ -324,12 +372,9 @@ module M { assert(!findVarType(m, rr, "t").isUnknownOrErroneous()); assert(!findVarType(m, rr, "i").isUnknownOrErroneous()); - assert(findVarType(m, rr, "equal").isParamTrue()); - - // assert(guard.realizeErrors() == 0); + ensureParamBool(findVarType(m, rr, "equal"), true); - printf("Success: index(%s) == %s\n", domainType.c_str(), - expectedType.c_str()); + assert(guard.realizeErrors() == 0); } static void testBadDomainHelper(std::string domainType, Context* context, @@ -365,6 +410,9 @@ module M { // Ensure we gracefully error for bad domain type expressions, with or without // the standard modules available. static void testBadDomain(Context* contextWithStd, std::string domainType) { + printf("Testing: cannot resolve %s\n", + domainType.c_str()); + // With standard modules { contextWithStd->advanceToNextRevision(false); @@ -373,9 +421,6 @@ static void testBadDomain(Context* contextWithStd, std::string domainType) { testBadDomainHelper(domainType, contextWithStd, guard); } - - printf("Success: cannot resolve %s\n", - domainType.c_str()); } int main() { @@ -389,31 +434,27 @@ int main() { testRectangular(context, "domain(2, int(8))", 2, "int(8)", "one"); testRectangular(context, "domain(3, int(16), strideKind.negOne)", 3, "int(16)", "negOne"); testRectangular(context, "domain(strides=strideKind.negative, idxType=int, rank=1)", 1, "int", "negative"); - context->collectGarbage(); - // TODO: re-enable associative - // testAssociative(context, "domain(int)", "int", true); - // testAssociative(context, "domain(int, false)", "int", false); - // testAssociative(context, "domain(string)", "string", true); - // context->collectGarbage(); + testDomainLiteral(context, "{1..10}", DomainType::Kind::Rectangular); + testDomainLiteral(context, "{1..10, 1..10}", DomainType::Kind::Rectangular); + + testAssociative(context, "domain(int)", "int", true); + testAssociative(context, "domain(int, false)", "int", false); + testAssociative(context, "domain(string)", "string", true); testBadPass(context, "domain(1)", "domain(2)"); testBadPass(context, "domain(1, int(16))", "domain(1, int(8))"); testBadPass(context, "domain(1, int(8))", "domain(1, int(16))"); testBadPass(context, "domain(1, strides=strideKind.negOne)", "domain(1, strides=strideKind.one)"); - // TODO: re-enable associative badPass - // testBadPass(context, "domain(int)", "domain(string)"); - // testBadPass(context, "domain(1)", "domain(int)"); - context->collectGarbage(); + testBadPass(context, "domain(int)", "domain(string)"); + testBadPass(context, "domain(1)", "domain(int)"); testIndex(context, "domain(1)", "int"); testIndex(context, "domain(2)", "2*int"); testIndex(context, "domain(1, bool)", "bool"); testIndex(context, "domain(2, bool)", "2*bool"); - // TODO: re-enable associative indexes - // testIndex(context, "domain(int)", "int"); - // testIndex(context, "domain(string)", "string"); - context->collectGarbage(); + testIndex(context, "domain(int)", "int"); + testIndex(context, "domain(string)", "string"); testBadDomain(context, "domain()"); testBadDomain(context, "domain(1, 2, 3, 4)");