diff --git a/Analysis/include/Luau/Constraint.h b/Analysis/include/Luau/Constraint.h index 8a41c9e87..dcfb14b4b 100644 --- a/Analysis/include/Luau/Constraint.h +++ b/Analysis/include/Luau/Constraint.h @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include "Luau/Ast.h" // Used for some of the enumerations #include "Luau/NotNull.h" #include "Luau/Variant.h" @@ -47,6 +48,21 @@ struct InstantiationConstraint TypeId superType; }; +struct UnaryConstraint +{ + AstExprUnary::Op op; + TypeId operandType; + TypeId resultType; +}; + +struct BinaryConstraint +{ + AstExprBinary::Op op; + TypeId leftType; + TypeId rightType; + TypeId resultType; +}; + // name(namedType) = name struct NameConstraint { @@ -54,7 +70,8 @@ struct NameConstraint std::string name; }; -using ConstraintV = Variant; +using ConstraintV = Variant; using ConstraintPtr = std::unique_ptr; struct Constraint diff --git a/Analysis/include/Luau/ConstraintGraphBuilder.h b/Analysis/include/Luau/ConstraintGraphBuilder.h index 9b118691d..a49e85941 100644 --- a/Analysis/include/Luau/ConstraintGraphBuilder.h +++ b/Analysis/include/Luau/ConstraintGraphBuilder.h @@ -25,9 +25,12 @@ struct ConstraintGraphBuilder // scope pointers; the scopes themselves borrow pointers to other scopes to // define the scope hierarchy. std::vector>> scopes; + + ModuleName moduleName; SingletonTypes& singletonTypes; - TypeArena* const arena; + const NotNull arena; // The root scope of the module we're generating constraints for. + // This is null when the CGB is initially constructed. Scope2* rootScope; // A mapping of AST node to TypeId. DenseHashMap astTypes{nullptr}; @@ -39,40 +42,50 @@ struct ConstraintGraphBuilder // Type packs resolved from type annotations. Analogous to astTypePacks. DenseHashMap astResolvedTypePacks{nullptr}; - explicit ConstraintGraphBuilder(TypeArena* arena); + int recursionCount = 0; + + // It is pretty uncommon for constraint generation to itself produce errors, but it can happen. + std::vector errors; + + // Occasionally constraint generation needs to produce an ICE. + const NotNull ice; + + NotNull globalScope; + + ConstraintGraphBuilder(const ModuleName& moduleName, TypeArena* arena, NotNull ice, NotNull globalScope); /** * Fabricates a new free type belonging to a given scope. - * @param scope the scope the free type belongs to. Must not be null. + * @param scope the scope the free type belongs to. */ - TypeId freshType(Scope2* scope); + TypeId freshType(NotNull scope); /** * Fabricates a new free type pack belonging to a given scope. - * @param scope the scope the free type pack belongs to. Must not be null. + * @param scope the scope the free type pack belongs to. */ - TypePackId freshTypePack(Scope2* scope); + TypePackId freshTypePack(NotNull scope); /** * Fabricates a scope that is a child of another scope. * @param location the lexical extent of the scope in the source code. * @param parent the parent scope of the new scope. Must not be null. */ - Scope2* childScope(Location location, Scope2* parent); + NotNull childScope(Location location, NotNull parent); /** * Adds a new constraint with no dependencies to a given scope. - * @param scope the scope to add the constraint to. Must not be null. + * @param scope the scope to add the constraint to. * @param cv the constraint variant to add. */ - void addConstraint(Scope2* scope, ConstraintV cv); + void addConstraint(NotNull scope, ConstraintV cv); /** * Adds a constraint to a given scope. * @param scope the scope to add the constraint to. Must not be null. * @param c the constraint to add. */ - void addConstraint(Scope2* scope, std::unique_ptr c); + void addConstraint(NotNull scope, std::unique_ptr c); /** * The entry point to the ConstraintGraphBuilder. This will construct a set @@ -81,20 +94,22 @@ struct ConstraintGraphBuilder */ void visit(AstStatBlock* block); - void visit(Scope2* scope, AstStat* stat); - void visit(Scope2* scope, AstStatBlock* block); - void visit(Scope2* scope, AstStatLocal* local); - void visit(Scope2* scope, AstStatLocalFunction* function); - void visit(Scope2* scope, AstStatFunction* function); - void visit(Scope2* scope, AstStatReturn* ret); - void visit(Scope2* scope, AstStatAssign* assign); - void visit(Scope2* scope, AstStatIf* ifStatement); - void visit(Scope2* scope, AstStatTypeAlias* alias); + void visitBlockWithoutChildScope(NotNull scope, AstStatBlock* block); + + void visit(NotNull scope, AstStat* stat); + void visit(NotNull scope, AstStatBlock* block); + void visit(NotNull scope, AstStatLocal* local); + void visit(NotNull scope, AstStatLocalFunction* function); + void visit(NotNull scope, AstStatFunction* function); + void visit(NotNull scope, AstStatReturn* ret); + void visit(NotNull scope, AstStatAssign* assign); + void visit(NotNull scope, AstStatIf* ifStatement); + void visit(NotNull scope, AstStatTypeAlias* alias); - TypePackId checkExprList(Scope2* scope, const AstArray& exprs); + TypePackId checkExprList(NotNull scope, const AstArray& exprs); - TypePackId checkPack(Scope2* scope, AstArray exprs); - TypePackId checkPack(Scope2* scope, AstExpr* expr); + TypePackId checkPack(NotNull scope, AstArray exprs); + TypePackId checkPack(NotNull scope, AstExpr* expr); /** * Checks an expression that is expected to evaluate to one type. @@ -102,19 +117,35 @@ struct ConstraintGraphBuilder * @param expr the expression to check. * @return the type of the expression. */ - TypeId check(Scope2* scope, AstExpr* expr); - - TypeId checkExprTable(Scope2* scope, AstExprTable* expr); - TypeId check(Scope2* scope, AstExprIndexName* indexName); - - std::pair checkFunctionSignature(Scope2* parent, AstExprFunction* fn); + TypeId check(NotNull scope, AstExpr* expr); + + TypeId checkExprTable(NotNull scope, AstExprTable* expr); + TypeId check(NotNull scope, AstExprIndexName* indexName); + TypeId check(NotNull scope, AstExprIndexExpr* indexExpr); + TypeId check(NotNull scope, AstExprUnary* unary); + TypeId check(NotNull scope, AstExprBinary* binary); + + struct FunctionSignature + { + // The type of the function. + TypeId signature; + // The scope that encompasses the function's signature. May be nullptr + // if there was no need for a signature scope (the function has no + // generics). + Scope2* signatureScope; + // The scope that encompasses the function's body. Is a child scope of + // signatureScope, if present. + NotNull bodyScope; + }; + + FunctionSignature checkFunctionSignature(NotNull parent, AstExprFunction* fn); /** * Checks the body of a function expression. * @param scope the interior scope of the body of the function. * @param fn the function expression to check. */ - void checkFunctionBody(Scope2* scope, AstExprFunction* fn); + void checkFunctionBody(NotNull scope, AstExprFunction* fn); /** * Resolves a type from its AST annotation. @@ -122,7 +153,7 @@ struct ConstraintGraphBuilder * @param ty the AST annotation to resolve. * @return the type of the AST annotation. **/ - TypeId resolveType(Scope2* scope, AstType* ty); + TypeId resolveType(NotNull scope, AstType* ty); /** * Resolves a type pack from its AST annotation. @@ -130,9 +161,25 @@ struct ConstraintGraphBuilder * @param tp the AST annotation to resolve. * @return the type pack of the AST annotation. **/ - TypePackId resolveTypePack(Scope2* scope, AstTypePack* tp); + TypePackId resolveTypePack(NotNull scope, AstTypePack* tp); + + TypePackId resolveTypePack(NotNull scope, const AstTypeList& list); + + std::vector> createGenerics(NotNull scope, AstArray generics); + std::vector> createGenericPacks(NotNull scope, AstArray packs); - TypePackId resolveTypePack(Scope2* scope, const AstTypeList& list); + TypeId flattenPack(NotNull scope, Location location, TypePackId tp); + + void reportError(Location location, TypeErrorData err); + void reportCodeTooComplex(Location location); + + /** Scan the program for global definitions. + * + * ConstraintGraphBuilder needs to differentiate between globals and accesses to undefined symbols. Doing this "for + * real" in a general way is going to be pretty hard, so we are choosing not to tackle that yet. For now, we do an + * initial scan of the AST and note what globals are defined. + */ + void prepopulateGlobalScope(NotNull globalScope, AstStatBlock* program); }; /** @@ -145,6 +192,6 @@ struct ConstraintGraphBuilder * @return a list of pointers to constraints contained within the scope graph. * None of these pointers should be null. */ -std::vector> collectConstraints(Scope2* rootScope); +std::vector> collectConstraints(NotNull rootScope); } // namespace Luau diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index 4870157fb..cf88efb6e 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -25,7 +25,7 @@ struct ConstraintSolver // is important to not add elements to this vector, lest the underlying // storage that we retain pointers to be mutated underneath us. const std::vector> constraints; - Scope2* rootScope; + NotNull rootScope; // This includes every constraint that has not been fully solved. // A constraint can be both blocked and unsolved, for instance. @@ -40,7 +40,7 @@ struct ConstraintSolver ConstraintSolverLogger logger; - explicit ConstraintSolver(TypeArena* arena, Scope2* rootScope); + explicit ConstraintSolver(TypeArena* arena, NotNull rootScope); /** * Attempts to dispatch all pending constraints and reach a type solution @@ -50,11 +50,17 @@ struct ConstraintSolver bool done(); + /** Attempt to dispatch a constraint. Returns true if it was successful. + * If tryDispatch() returns false, the constraint remains in the unsolved set and will be retried later. + */ bool tryDispatch(NotNull c, bool force); + bool tryDispatch(const SubtypeConstraint& c, NotNull constraint, bool force); bool tryDispatch(const PackSubtypeConstraint& c, NotNull constraint, bool force); bool tryDispatch(const GeneralizationConstraint& c, NotNull constraint, bool force); bool tryDispatch(const InstantiationConstraint& c, NotNull constraint, bool force); + bool tryDispatch(const UnaryConstraint& c, NotNull constraint, bool force); + bool tryDispatch(const BinaryConstraint& c, NotNull constraint, bool force); bool tryDispatch(const NameConstraint& c, NotNull constraint); void block(NotNull target, NotNull constraint); @@ -115,6 +121,6 @@ struct ConstraintSolver void unblock_(BlockedConstraintId progressed); }; -void dump(Scope2* rootScope, struct ToStringOptions& opts); +void dump(NotNull rootScope, struct ToStringOptions& opts); } // namespace Luau diff --git a/Analysis/include/Luau/Error.h b/Analysis/include/Luau/Error.h index a13239609..4c81d33d2 100644 --- a/Analysis/include/Luau/Error.h +++ b/Analysis/include/Luau/Error.h @@ -369,24 +369,25 @@ struct InternalErrorReporter [[noreturn]] void ice(const std::string& message); }; -class InternalCompilerError : public std::exception { +class InternalCompilerError : public std::exception +{ public: - explicit InternalCompilerError(const std::string& message, const std::string& moduleName) - : message(message) - , moduleName(moduleName) - { - } - explicit InternalCompilerError(const std::string& message, const std::string& moduleName, const Location& location) - : message(message) - , moduleName(moduleName) - , location(location) - { - } - virtual const char* what() const throw(); - - const std::string message; - const std::string moduleName; - const std::optional location; + explicit InternalCompilerError(const std::string& message, const std::string& moduleName) + : message(message) + , moduleName(moduleName) + { + } + explicit InternalCompilerError(const std::string& message, const std::string& moduleName, const Location& location) + : message(message) + , moduleName(moduleName) + , location(location) + { + } + virtual const char* what() const throw(); + + const std::string message; + const std::string moduleName; + const std::optional location; }; } // namespace Luau diff --git a/Analysis/include/Luau/Frontend.h b/Analysis/include/Luau/Frontend.h index f4226cc1f..f0d430904 100644 --- a/Analysis/include/Luau/Frontend.h +++ b/Analysis/include/Luau/Frontend.h @@ -5,6 +5,7 @@ #include "Luau/Module.h" #include "Luau/ModuleResolver.h" #include "Luau/RequireTracer.h" +#include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/Variant.h" @@ -158,6 +159,8 @@ struct Frontend void registerBuiltinDefinition(const std::string& name, std::function); void applyBuiltinDefinitionToEnvironment(const std::string& environmentName, const std::string& definitionName); + NotNull getGlobalScope2(); + private: ModulePtr check(const SourceModule& sourceModule, Mode mode, const ScopePtr& environmentScope); @@ -173,6 +176,8 @@ struct Frontend std::unordered_map environments; std::unordered_map> builtinDefinitions; + std::unique_ptr globalScope2; + public: FileResolver* fileResolver; FrontendModuleResolver moduleResolver; diff --git a/Analysis/include/Luau/Module.h b/Analysis/include/Luau/Module.h index 39f8dfb74..b3105b786 100644 --- a/Analysis/include/Luau/Module.h +++ b/Analysis/include/Luau/Module.h @@ -68,7 +68,7 @@ struct Module std::shared_ptr allocator; std::shared_ptr names; - std::vector> scopes; // never empty + std::vector> scopes; // never empty std::vector>> scope2s; // never empty DenseHashMap astTypes{nullptr}; diff --git a/Analysis/include/Luau/NotNull.h b/Analysis/include/Luau/NotNull.h index f6043e9c6..714fa1437 100644 --- a/Analysis/include/Luau/NotNull.h +++ b/Analysis/include/Luau/NotNull.h @@ -26,7 +26,7 @@ namespace Luau * The explicit delete statement is permitted (but not recommended) on a * NotNull through this implicit conversion. */ -template +template struct NotNull { explicit NotNull(T* t) @@ -38,10 +38,11 @@ struct NotNull explicit NotNull(std::nullptr_t) = delete; void operator=(std::nullptr_t) = delete; - template + template NotNull(NotNull other) : ptr(other.get()) - {} + { + } operator T*() const noexcept { @@ -72,12 +73,13 @@ struct NotNull T* ptr; }; -} +} // namespace Luau namespace std { -template struct hash> +template +struct hash> { size_t operator()(const Luau::NotNull& p) const { @@ -85,4 +87,4 @@ template struct hash> } }; -} +} // namespace std diff --git a/Analysis/include/Luau/Scope.h b/Analysis/include/Luau/Scope.h index cef4b94f3..0eaecf1d8 100644 --- a/Analysis/include/Luau/Scope.h +++ b/Analysis/include/Luau/Scope.h @@ -3,6 +3,7 @@ #include "Luau/Constraint.h" #include "Luau/Location.h" +#include "Luau/NotNull.h" #include "Luau/TypeVar.h" #include @@ -71,15 +72,18 @@ struct Scope2 // is the module-level scope). Scope2* parent = nullptr; // All the children of this scope. - std::vector children; + std::vector> children; std::unordered_map bindings; // TODO: I think this can be a DenseHashMap std::unordered_map typeBindings; + std::unordered_map typePackBindings; TypePackId returnType; + std::optional varargPack; // All constraints belonging to this scope. std::vector constraints; std::optional lookup(Symbol sym); std::optional lookupTypeBinding(const Name& name); + std::optional lookupTypePackBinding(const Name& name); }; } // namespace Luau diff --git a/Analysis/include/Luau/TypeArena.h b/Analysis/include/Luau/TypeArena.h index 559c55c8c..be36f19c7 100644 --- a/Analysis/include/Luau/TypeArena.h +++ b/Analysis/include/Luau/TypeArena.h @@ -34,6 +34,12 @@ struct TypeArena TypePackId addTypePack(std::vector types); TypePackId addTypePack(TypePack pack); TypePackId addTypePack(TypePackVar pack); + + template + TypePackId addTypePack(T tp) + { + return addTypePack(TypePackVar(std::move(tp))); + } }; void freeze(TypeArena& arena); diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 28adc9d9a..455654d95 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -173,7 +173,7 @@ struct TypeChecker TypeId checkFunctionName(const ScopePtr& scope, AstExpr& funName, TypeLevel level); std::pair checkFunctionSignature(const ScopePtr& scope, int subLevel, const AstExprFunction& expr, - std::optional originalNameLoc, std::optional expectedType); + std::optional originalNameLoc, std::optional selfType, std::optional expectedType); void checkFunctionBody(const ScopePtr& scope, TypeId type, const AstExprFunction& function); void checkArgumentList( @@ -424,6 +424,8 @@ struct TypeChecker * (exported, name) to properly deal with the case where the two duplicates do not have the same export status. */ DenseHashSet, HashBoolNamePair> duplicateTypeAliases; + + std::vector> deferredQuantification; }; // Unit test hook diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index 20f4107ce..6ad6b9270 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -357,6 +357,9 @@ struct TableTypeVar std::optional boundTo; Tags tags; + + // Methods of this table that have an untyped self will use the same shared self type. + std::optional selfTy; }; // Represents a metatable attached to a table typevar. Somewhat analogous to a bound typevar. diff --git a/Analysis/src/ConstraintGraphBuilder.cpp b/Analysis/src/ConstraintGraphBuilder.cpp index d9e8d238b..3b9000cd8 100644 --- a/Analysis/src/ConstraintGraphBuilder.cpp +++ b/Analysis/src/ConstraintGraphBuilder.cpp @@ -1,6 +1,10 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/ConstraintGraphBuilder.h" +#include "Luau/RecursionCounter.h" +#include "Luau/ToString.h" + +LUAU_FASTINT(LuauCheckRecursionLimit); #include "Luau/Scope.h" @@ -9,32 +13,33 @@ namespace Luau const AstStat* getFallthrough(const AstStat* node); // TypeInfer.cpp -ConstraintGraphBuilder::ConstraintGraphBuilder(TypeArena* arena) - : singletonTypes(getSingletonTypes()) +ConstraintGraphBuilder::ConstraintGraphBuilder( + const ModuleName& moduleName, TypeArena* arena, NotNull ice, NotNull globalScope) + : moduleName(moduleName) + , singletonTypes(getSingletonTypes()) , arena(arena) , rootScope(nullptr) + , ice(ice) + , globalScope(globalScope) { LUAU_ASSERT(arena); } -TypeId ConstraintGraphBuilder::freshType(Scope2* scope) +TypeId ConstraintGraphBuilder::freshType(NotNull scope) { - LUAU_ASSERT(scope); return arena->addType(FreeTypeVar{scope}); } -TypePackId ConstraintGraphBuilder::freshTypePack(Scope2* scope) +TypePackId ConstraintGraphBuilder::freshTypePack(NotNull scope) { - LUAU_ASSERT(scope); FreeTypePack f{scope}; return arena->addTypePack(TypePackVar{std::move(f)}); } -Scope2* ConstraintGraphBuilder::childScope(Location location, Scope2* parent) +NotNull ConstraintGraphBuilder::childScope(Location location, NotNull parent) { - LUAU_ASSERT(parent); auto scope = std::make_unique(); - Scope2* borrow = scope.get(); + NotNull borrow = NotNull(scope.get()); scopes.emplace_back(location, std::move(scope)); borrow->parent = parent; @@ -44,15 +49,13 @@ Scope2* ConstraintGraphBuilder::childScope(Location location, Scope2* parent) return borrow; } -void ConstraintGraphBuilder::addConstraint(Scope2* scope, ConstraintV cv) +void ConstraintGraphBuilder::addConstraint(NotNull scope, ConstraintV cv) { - LUAU_ASSERT(scope); scope->constraints.emplace_back(new Constraint{std::move(cv)}); } -void ConstraintGraphBuilder::addConstraint(Scope2* scope, std::unique_ptr c) +void ConstraintGraphBuilder::addConstraint(NotNull scope, std::unique_ptr c) { - LUAU_ASSERT(scope); scope->constraints.emplace_back(std::move(c)); } @@ -62,7 +65,11 @@ void ConstraintGraphBuilder::visit(AstStatBlock* block) LUAU_ASSERT(rootScope == nullptr); scopes.emplace_back(block->location, std::make_unique()); rootScope = scopes.back().second.get(); - rootScope->returnType = freshTypePack(rootScope); + NotNull borrow = NotNull(rootScope); + + rootScope->returnType = freshTypePack(borrow); + + prepopulateGlobalScope(borrow, block); // TODO: We should share the global scope. rootScope->typeBindings["nil"] = singletonTypes.nilType; @@ -71,12 +78,26 @@ void ConstraintGraphBuilder::visit(AstStatBlock* block) rootScope->typeBindings["boolean"] = singletonTypes.booleanType; rootScope->typeBindings["thread"] = singletonTypes.threadType; - visit(rootScope, block); + visitBlockWithoutChildScope(borrow, block); +} + +void ConstraintGraphBuilder::visitBlockWithoutChildScope(NotNull scope, AstStatBlock* block) +{ + RecursionCounter counter{&recursionCount}; + + if (recursionCount >= FInt::LuauCheckRecursionLimit) + { + reportCodeTooComplex(block->location); + return; + } + + for (AstStat* stat : block->body) + visit(scope, stat); } -void ConstraintGraphBuilder::visit(Scope2* scope, AstStat* stat) +void ConstraintGraphBuilder::visit(NotNull scope, AstStat* stat) { - LUAU_ASSERT(scope); + RecursionLimiter limiter{&recursionCount, FInt::LuauCheckRecursionLimit}; if (auto s = stat->as()) visit(scope, s); @@ -100,10 +121,8 @@ void ConstraintGraphBuilder::visit(Scope2* scope, AstStat* stat) LUAU_ASSERT(0); } -void ConstraintGraphBuilder::visit(Scope2* scope, AstStatLocal* local) +void ConstraintGraphBuilder::visit(NotNull scope, AstStatLocal* local) { - LUAU_ASSERT(scope); - std::vector varTypes; for (AstLocal* local : local->vars) @@ -148,23 +167,19 @@ void ConstraintGraphBuilder::visit(Scope2* scope, AstStatLocal* local) } } -void addConstraints(Constraint* constraint, Scope2* scope) +void addConstraints(Constraint* constraint, NotNull scope) { - LUAU_ASSERT(scope); - scope->constraints.reserve(scope->constraints.size() + scope->constraints.size()); for (const auto& c : scope->constraints) constraint->dependencies.push_back(NotNull{c.get()}); - for (Scope2* childScope : scope->children) + for (NotNull childScope : scope->children) addConstraints(constraint, childScope); } -void ConstraintGraphBuilder::visit(Scope2* scope, AstStatLocalFunction* function) +void ConstraintGraphBuilder::visit(NotNull scope, AstStatLocalFunction* function) { - LUAU_ASSERT(scope); - // Local // Global // Dotted path @@ -172,36 +187,31 @@ void ConstraintGraphBuilder::visit(Scope2* scope, AstStatLocalFunction* function TypeId functionType = nullptr; auto ty = scope->lookup(function->name); - if (ty.has_value()) - { - // TODO: This is duplicate definition of a local function. Is this allowed? - functionType = *ty; - } - else - { - functionType = arena->addType(BlockedTypeVar{}); - scope->bindings[function->name] = functionType; - } + LUAU_ASSERT(!ty.has_value()); // The parser ensures that every local function has a distinct Symbol for its name. + + functionType = arena->addType(BlockedTypeVar{}); + scope->bindings[function->name] = functionType; - auto [actualFunctionType, innerScope] = checkFunctionSignature(scope, function->func); - innerScope->bindings[function->name] = actualFunctionType; + FunctionSignature sig = checkFunctionSignature(scope, function->func); + sig.bodyScope->bindings[function->name] = sig.signature; - checkFunctionBody(innerScope, function->func); + checkFunctionBody(sig.bodyScope, function->func); - std::unique_ptr c{new Constraint{GeneralizationConstraint{functionType, actualFunctionType, innerScope}}}; - addConstraints(c.get(), innerScope); + std::unique_ptr c{ + new Constraint{GeneralizationConstraint{functionType, sig.signature, sig.signatureScope ? sig.signatureScope : sig.bodyScope}}}; + addConstraints(c.get(), sig.bodyScope); addConstraint(scope, std::move(c)); } -void ConstraintGraphBuilder::visit(Scope2* scope, AstStatFunction* function) +void ConstraintGraphBuilder::visit(NotNull scope, AstStatFunction* function) { // Name could be AstStatLocal, AstStatGlobal, AstStatIndexName. // With or without self TypeId functionType = nullptr; - auto [actualFunctionType, innerScope] = checkFunctionSignature(scope, function->func); + FunctionSignature sig = checkFunctionSignature(scope, function->func); if (AstExprLocal* localName = function->name->as()) { @@ -216,7 +226,7 @@ void ConstraintGraphBuilder::visit(Scope2* scope, AstStatFunction* function) functionType = arena->addType(BlockedTypeVar{}); scope->bindings[localName->local] = functionType; } - innerScope->bindings[localName->local] = actualFunctionType; + sig.bodyScope->bindings[localName->local] = sig.signature; } else if (AstExprGlobal* globalName = function->name->as()) { @@ -231,32 +241,48 @@ void ConstraintGraphBuilder::visit(Scope2* scope, AstStatFunction* function) functionType = arena->addType(BlockedTypeVar{}); rootScope->bindings[globalName->name] = functionType; } - innerScope->bindings[globalName->name] = actualFunctionType; + sig.bodyScope->bindings[globalName->name] = sig.signature; } else if (AstExprIndexName* indexName = function->name->as()) { - LUAU_ASSERT(0); // not yet implemented + TypeId containingTableType = check(scope, indexName->expr); + + functionType = arena->addType(BlockedTypeVar{}); + TypeId prospectiveTableType = + arena->addType(TableTypeVar{}); // TODO look into stack utilization. This is probably ok because it scales with AST depth. + NotNull prospectiveTable{getMutable(prospectiveTableType)}; + + Property& prop = prospectiveTable->props[indexName->index.value]; + prop.type = functionType; + prop.location = function->name->location; + + addConstraint(scope, SubtypeConstraint{containingTableType, prospectiveTableType}); + } + else if (AstExprError* err = function->name->as()) + { + functionType = singletonTypes.errorRecoveryType(); } - checkFunctionBody(innerScope, function->func); + LUAU_ASSERT(functionType != nullptr); - std::unique_ptr c{new Constraint{GeneralizationConstraint{functionType, actualFunctionType, innerScope}}}; - addConstraints(c.get(), innerScope); + checkFunctionBody(sig.bodyScope, function->func); + + std::unique_ptr c{ + new Constraint{GeneralizationConstraint{functionType, sig.signature, sig.signatureScope ? sig.signatureScope : sig.bodyScope}}}; + addConstraints(c.get(), sig.bodyScope); addConstraint(scope, std::move(c)); } -void ConstraintGraphBuilder::visit(Scope2* scope, AstStatReturn* ret) +void ConstraintGraphBuilder::visit(NotNull scope, AstStatReturn* ret) { - LUAU_ASSERT(scope); - TypePackId exprTypes = checkPack(scope, ret->list); addConstraint(scope, PackSubtypeConstraint{exprTypes, scope->returnType}); } -void ConstraintGraphBuilder::visit(Scope2* scope, AstStatBlock* block) +void ConstraintGraphBuilder::visit(NotNull scope, AstStatBlock* block) { - LUAU_ASSERT(scope); + NotNull innerScope = childScope(block->location, scope); // In order to enable mutually-recursive type aliases, we need to // populate the type bindings before we actually check any of the @@ -271,11 +297,10 @@ void ConstraintGraphBuilder::visit(Scope2* scope, AstStatBlock* block) } } - for (AstStat* stat : block->body) - visit(scope, stat); + visitBlockWithoutChildScope(innerScope, block); } -void ConstraintGraphBuilder::visit(Scope2* scope, AstStatAssign* assign) +void ConstraintGraphBuilder::visit(NotNull scope, AstStatAssign* assign) { TypePackId varPackId = checkExprList(scope, assign->vars); TypePackId valuePack = checkPack(scope, assign->values); @@ -283,21 +308,21 @@ void ConstraintGraphBuilder::visit(Scope2* scope, AstStatAssign* assign) addConstraint(scope, PackSubtypeConstraint{valuePack, varPackId}); } -void ConstraintGraphBuilder::visit(Scope2* scope, AstStatIf* ifStatement) +void ConstraintGraphBuilder::visit(NotNull scope, AstStatIf* ifStatement) { check(scope, ifStatement->condition); - Scope2* thenScope = childScope(ifStatement->thenbody->location, scope); + NotNull thenScope = childScope(ifStatement->thenbody->location, scope); visit(thenScope, ifStatement->thenbody); if (ifStatement->elsebody) { - Scope2* elseScope = childScope(ifStatement->elsebody->location, scope); + NotNull elseScope = childScope(ifStatement->elsebody->location, scope); visit(elseScope, ifStatement->elsebody); } } -void ConstraintGraphBuilder::visit(Scope2* scope, AstStatTypeAlias* alias) +void ConstraintGraphBuilder::visit(NotNull scope, AstStatTypeAlias* alias) { // TODO: Exported type aliases // TODO: Generic type aliases @@ -307,6 +332,10 @@ void ConstraintGraphBuilder::visit(Scope2* scope, AstStatTypeAlias* alias) // AST to set up typeBindings. If it's not, we've somehow skipped // this alias in that first pass. LUAU_ASSERT(it != scope->typeBindings.end()); + if (it == scope->typeBindings.end()) + { + ice->ice("Type alias does not have a pre-populated binding", alias->location); + } TypeId ty = resolveType(scope, alias->type); @@ -319,10 +348,8 @@ void ConstraintGraphBuilder::visit(Scope2* scope, AstStatTypeAlias* alias) addConstraint(scope, NameConstraint{ty, alias->name.value}); } -TypePackId ConstraintGraphBuilder::checkPack(Scope2* scope, AstArray exprs) +TypePackId ConstraintGraphBuilder::checkPack(NotNull scope, AstArray exprs) { - LUAU_ASSERT(scope); - if (exprs.size == 0) return arena->addTypePack({}); @@ -342,7 +369,7 @@ TypePackId ConstraintGraphBuilder::checkPack(Scope2* scope, AstArray e return arena->addTypePack(TypePack{std::move(types), last}); } -TypePackId ConstraintGraphBuilder::checkExprList(Scope2* scope, const AstArray& exprs) +TypePackId ConstraintGraphBuilder::checkExprList(NotNull scope, const AstArray& exprs) { TypePackId result = arena->addTypePack({}); TypePack* resultPack = getMutable(result); @@ -363,9 +390,15 @@ TypePackId ConstraintGraphBuilder::checkExprList(Scope2* scope, const AstArray scope, AstExpr* expr) { - LUAU_ASSERT(scope); + RecursionCounter counter{&recursionCount}; + + if (recursionCount >= FInt::LuauCheckRecursionLimit) + { + reportCodeTooComplex(expr->location); + return singletonTypes.errorRecoveryTypePack(); + } TypePackId result = nullptr; @@ -384,7 +417,7 @@ TypePackId ConstraintGraphBuilder::checkPack(Scope2* scope, AstExpr* expr) astOriginalCallTypes[call->func] = fnType; - TypeId instantiatedType = freshType(scope); + TypeId instantiatedType = arena->addType(BlockedTypeVar{}); addConstraint(scope, InstantiationConstraint{instantiatedType, fnType}); TypePackId rets = freshTypePack(scope); @@ -394,6 +427,13 @@ TypePackId ConstraintGraphBuilder::checkPack(Scope2* scope, AstExpr* expr) addConstraint(scope, SubtypeConstraint{inferredFnType, instantiatedType}); result = rets; } + else if (AstExprVarargs* varargs = expr->as()) + { + if (scope->varargPack) + result = *scope->varargPack; + else + result = singletonTypes.errorRecoveryTypePack(); + } else { TypeId t = check(scope, expr); @@ -405,9 +445,15 @@ TypePackId ConstraintGraphBuilder::checkPack(Scope2* scope, AstExpr* expr) return result; } -TypeId ConstraintGraphBuilder::check(Scope2* scope, AstExpr* expr) +TypeId ConstraintGraphBuilder::check(NotNull scope, AstExpr* expr) { - LUAU_ASSERT(scope); + RecursionCounter counter{&recursionCount}; + + if (recursionCount >= FInt::LuauCheckRecursionLimit) + { + reportCodeTooComplex(expr->location); + return singletonTypes.errorRecoveryType(); + } TypeId result = nullptr; @@ -435,37 +481,38 @@ TypeId ConstraintGraphBuilder::check(Scope2* scope, AstExpr* expr) if (ty) result = *ty; else - result = singletonTypes.errorRecoveryType(); // FIXME? Record an error at this point? - } - else if (auto a = expr->as()) - { - TypePackId packResult = checkPack(scope, expr); - if (auto f = first(packResult)) - return *f; - else if (get(packResult)) { - TypeId typeResult = freshType(scope); - TypePack onePack{{typeResult}, freshTypePack(scope)}; - TypePackId oneTypePack = arena->addTypePack(std::move(onePack)); - - addConstraint(scope, PackSubtypeConstraint{packResult, oneTypePack}); - - return typeResult; + /* prepopulateGlobalScope() has already added all global functions to the environment by this point, so any + * global that is not already in-scope is definitely an unknown symbol. + */ + reportError(g->location, UnknownSymbol{g->name.value}); + result = singletonTypes.errorRecoveryType(); // FIXME? Record an error at this point? } } + else if (expr->is()) + result = flattenPack(scope, expr->location, checkPack(scope, expr)); + else if (expr->is()) + result = flattenPack(scope, expr->location, checkPack(scope, expr)); else if (auto a = expr->as()) { - auto [fnType, functionScope] = checkFunctionSignature(scope, a); - checkFunctionBody(functionScope, a); - return fnType; + FunctionSignature sig = checkFunctionSignature(scope, a); + checkFunctionBody(sig.bodyScope, a); + return sig.signature; } else if (auto indexName = expr->as()) - { result = check(scope, indexName); - } + else if (auto indexExpr = expr->as()) + result = check(scope, indexExpr); else if (auto table = expr->as()) - { result = checkExprTable(scope, table); + else if (auto unary = expr->as()) + result = check(scope, unary); + else if (auto binary = expr->as()) + result = check(scope, binary); + else if (auto err = expr->as()) + { + // Open question: Should we traverse into this? + result = singletonTypes.errorRecoveryType(); } else { @@ -478,7 +525,7 @@ TypeId ConstraintGraphBuilder::check(Scope2* scope, AstExpr* expr) return result; } -TypeId ConstraintGraphBuilder::check(Scope2* scope, AstExprIndexName* indexName) +TypeId ConstraintGraphBuilder::check(NotNull scope, AstExprIndexName* indexName) { TypeId obj = check(scope, indexName->expr); TypeId result = freshType(scope); @@ -494,7 +541,67 @@ TypeId ConstraintGraphBuilder::check(Scope2* scope, AstExprIndexName* indexName) return result; } -TypeId ConstraintGraphBuilder::checkExprTable(Scope2* scope, AstExprTable* expr) +TypeId ConstraintGraphBuilder::check(NotNull scope, AstExprIndexExpr* indexExpr) +{ + TypeId obj = check(scope, indexExpr->expr); + TypeId indexType = check(scope, indexExpr->index); + + TypeId result = freshType(scope); + + TableIndexer indexer{indexType, result}; + TypeId tableType = arena->addType(TableTypeVar{TableTypeVar::Props{}, TableIndexer{indexType, result}, TypeLevel{}, TableState::Free}); + + addConstraint(scope, SubtypeConstraint{obj, tableType}); + + return result; +} + +TypeId ConstraintGraphBuilder::check(NotNull scope, AstExprUnary* unary) +{ + TypeId operandType = check(scope, unary->expr); + + switch (unary->op) + { + case AstExprUnary::Minus: + { + TypeId resultType = arena->addType(BlockedTypeVar{}); + addConstraint(scope, UnaryConstraint{AstExprUnary::Minus, operandType, resultType}); + return resultType; + } + default: + LUAU_ASSERT(0); + } + + LUAU_UNREACHABLE(); + return singletonTypes.errorRecoveryType(); +} + +TypeId ConstraintGraphBuilder::check(NotNull scope, AstExprBinary* binary) +{ + TypeId leftType = check(scope, binary->left); + TypeId rightType = check(scope, binary->right); + switch (binary->op) + { + case AstExprBinary::Or: + { + addConstraint(scope, SubtypeConstraint{leftType, rightType}); + return leftType; + } + case AstExprBinary::Sub: + { + TypeId resultType = arena->addType(BlockedTypeVar{}); + addConstraint(scope, BinaryConstraint{AstExprBinary::Sub, leftType, rightType, resultType}); + return resultType; + } + default: + LUAU_ASSERT(0); + } + + LUAU_ASSERT(0); + return nullptr; +} + +TypeId ConstraintGraphBuilder::checkExprTable(NotNull scope, AstExprTable* expr) { TypeId ty = arena->addType(TableTypeVar{}); TableTypeVar* ttv = getMutable(ty); @@ -515,6 +622,8 @@ TypeId ConstraintGraphBuilder::checkExprTable(Scope2* scope, AstExprTable* expr) for (const AstExprTable::Item& item : expr->items) { TypeId itemTy = check(scope, item.value); + if (get(follow(itemTy))) + return ty; if (item.key) { @@ -542,47 +651,111 @@ TypeId ConstraintGraphBuilder::checkExprTable(Scope2* scope, AstExprTable* expr) return ty; } -std::pair ConstraintGraphBuilder::checkFunctionSignature(Scope2* parent, AstExprFunction* fn) +ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionSignature(NotNull parent, AstExprFunction* fn) { - Scope2* innerScope = childScope(fn->body->location, parent); - TypePackId returnType = freshTypePack(innerScope); - innerScope->returnType = returnType; + Scope2* signatureScope = nullptr; + Scope2* bodyScope = nullptr; + TypePackId returnType = nullptr; + + std::vector genericTypes; + std::vector genericTypePacks; + + bool hasGenerics = fn->generics.size > 0 || fn->genericPacks.size > 0; + + // If we don't have any generics, we can save some memory and compute by not + // creating the signatureScope, which is only used to scope the declared + // generics properly. + if (hasGenerics) + { + NotNull signatureBorrow = childScope(fn->location, parent); + signatureScope = signatureBorrow.get(); + + // We need to assign returnType before creating bodyScope so that the + // return type gets propogated to bodyScope. + returnType = freshTypePack(signatureBorrow); + signatureScope->returnType = returnType; + + bodyScope = childScope(fn->body->location, signatureBorrow).get(); + + std::vector> genericDefinitions = createGenerics(signatureBorrow, fn->generics); + std::vector> genericPackDefinitions = createGenericPacks(signatureBorrow, fn->genericPacks); + + // We do not support default values on function generics, so we only + // care about the types involved. + for (const auto& [name, g] : genericDefinitions) + { + genericTypes.push_back(g.ty); + signatureScope->typeBindings[name] = g.ty; + } + + for (const auto& [name, g] : genericPackDefinitions) + { + genericTypePacks.push_back(g.tp); + signatureScope->typePackBindings[name] = g.tp; + } + } + else + { + NotNull bodyBorrow = childScope(fn->body->location, parent); + bodyScope = bodyBorrow.get(); + + returnType = freshTypePack(bodyBorrow); + bodyBorrow->returnType = returnType; + + // To eliminate the need to branch on hasGenerics below, we say that the + // signature scope is the body scope when there is no real signature + // scope. + signatureScope = bodyScope; + } + + NotNull bodyBorrow = NotNull(bodyScope); + NotNull signatureBorrow = NotNull(signatureScope); if (fn->returnAnnotation) { - TypePackId annotatedRetType = resolveTypePack(innerScope, *fn->returnAnnotation); - addConstraint(innerScope, PackSubtypeConstraint{returnType, annotatedRetType}); + TypePackId annotatedRetType = resolveTypePack(signatureBorrow, *fn->returnAnnotation); + addConstraint(signatureBorrow, PackSubtypeConstraint{returnType, annotatedRetType}); } std::vector argTypes; for (AstLocal* local : fn->args) { - TypeId t = freshType(innerScope); + TypeId t = freshType(signatureBorrow); argTypes.push_back(t); - innerScope->bindings[local] = t; + signatureScope->bindings[local] = t; if (local->annotation) { - TypeId argAnnotation = resolveType(innerScope, local->annotation); - addConstraint(innerScope, SubtypeConstraint{t, argAnnotation}); + TypeId argAnnotation = resolveType(signatureBorrow, local->annotation); + addConstraint(signatureBorrow, SubtypeConstraint{t, argAnnotation}); } } // TODO: Vararg annotation. + // TODO: Preserve argument names in the function's type. FunctionTypeVar actualFunction{arena->addTypePack(argTypes), returnType}; + actualFunction.hasNoGenerics = !hasGenerics; + actualFunction.generics = std::move(genericTypes); + actualFunction.genericPacks = std::move(genericTypePacks); + TypeId actualFunctionType = arena->addType(std::move(actualFunction)); LUAU_ASSERT(actualFunctionType); astTypes[fn] = actualFunctionType; - return {actualFunctionType, innerScope}; + return { + /* signature */ actualFunctionType, + // Undo the workaround we made above: if there's no signature scope, + // don't report it. + /* signatureScope */ hasGenerics ? signatureScope : nullptr, + /* bodyScope */ bodyBorrow, + }; } -void ConstraintGraphBuilder::checkFunctionBody(Scope2* scope, AstExprFunction* fn) +void ConstraintGraphBuilder::checkFunctionBody(NotNull scope, AstExprFunction* fn) { - for (AstStat* stat : fn->body->body) - visit(scope, stat); + visitBlockWithoutChildScope(scope, fn->body); // If it is possible for execution to reach the end of the function, the return type must be compatible with () @@ -593,7 +766,7 @@ void ConstraintGraphBuilder::checkFunctionBody(Scope2* scope, AstExprFunction* f } } -TypeId ConstraintGraphBuilder::resolveType(Scope2* scope, AstType* ty) +TypeId ConstraintGraphBuilder::resolveType(NotNull scope, AstType* ty) { TypeId result = nullptr; @@ -636,29 +809,73 @@ TypeId ConstraintGraphBuilder::resolveType(Scope2* scope, AstType* ty) } else if (auto fn = ty->as()) { - // TODO: Generic functions. - // TODO: Scope (though it may not be needed). // TODO: Recursion limit. - TypePackId argTypes = resolveTypePack(scope, fn->argTypes); - TypePackId returnTypes = resolveTypePack(scope, fn->returnTypes); + bool hasGenerics = fn->generics.size > 0 || fn->genericPacks.size > 0; + Scope2* signatureScope = nullptr; + + std::vector genericTypes; + std::vector genericTypePacks; - // TODO: Is this the right constructor to use? - result = arena->addType(FunctionTypeVar{argTypes, returnTypes}); + // If we don't have generics, we do not need to generate a child scope + // for the generic bindings to live on. + if (hasGenerics) + { + NotNull signatureBorrow = childScope(fn->location, scope); + signatureScope = signatureBorrow.get(); + + std::vector> genericDefinitions = createGenerics(signatureBorrow, fn->generics); + std::vector> genericPackDefinitions = createGenericPacks(signatureBorrow, fn->genericPacks); - FunctionTypeVar* ftv = getMutable(result); - ftv->argNames.reserve(fn->argNames.size); + for (const auto& [name, g] : genericDefinitions) + { + genericTypes.push_back(g.ty); + signatureBorrow->typeBindings[name] = g.ty; + } + + for (const auto& [name, g] : genericPackDefinitions) + { + genericTypePacks.push_back(g.tp); + signatureBorrow->typePackBindings[name] = g.tp; + } + } + else + { + // To eliminate the need to branch on hasGenerics below, we say that + // the signature scope is the parent scope if we don't have + // generics. + signatureScope = scope.get(); + } + + NotNull signatureBorrow(signatureScope); + + TypePackId argTypes = resolveTypePack(signatureBorrow, fn->argTypes); + TypePackId returnTypes = resolveTypePack(signatureBorrow, fn->returnTypes); + + // TODO: FunctionTypeVar needs a pointer to the scope so that we know + // how to quantify/instantiate it. + FunctionTypeVar ftv{argTypes, returnTypes}; + + // This replicates the behavior of the appropriate FunctionTypeVar + // constructors. + ftv.hasNoGenerics = !hasGenerics; + ftv.generics = std::move(genericTypes); + ftv.genericPacks = std::move(genericTypePacks); + + ftv.argNames.reserve(fn->argNames.size); for (const auto& el : fn->argNames) { if (el) { const auto& [name, location] = *el; - ftv->argNames.push_back(FunctionArgument{name.value, location}); + ftv.argNames.push_back(FunctionArgument{name.value, location}); } else { - ftv->argNames.push_back(std::nullopt); + ftv.argNames.push_back(std::nullopt); } } + + result = arena->addType(std::move(ftv)); } else if (auto tof = ty->as()) { @@ -710,7 +927,7 @@ TypeId ConstraintGraphBuilder::resolveType(Scope2* scope, AstType* ty) return result; } -TypePackId ConstraintGraphBuilder::resolveTypePack(Scope2* scope, AstTypePack* tp) +TypePackId ConstraintGraphBuilder::resolveTypePack(NotNull scope, AstTypePack* tp) { TypePackId result; if (auto expl = tp->as()) @@ -736,7 +953,7 @@ TypePackId ConstraintGraphBuilder::resolveTypePack(Scope2* scope, AstTypePack* t return result; } -TypePackId ConstraintGraphBuilder::resolveTypePack(Scope2* scope, const AstTypeList& list) +TypePackId ConstraintGraphBuilder::resolveTypePack(NotNull scope, const AstTypeList& list) { std::vector head; @@ -754,16 +971,108 @@ TypePackId ConstraintGraphBuilder::resolveTypePack(Scope2* scope, const AstTypeL return arena->addTypePack(TypePack{head, tail}); } -void collectConstraints(std::vector>& result, Scope2* scope) +std::vector> ConstraintGraphBuilder::createGenerics(NotNull scope, AstArray generics) +{ + std::vector> result; + for (const auto& generic : generics) + { + TypeId genericTy = arena->addType(GenericTypeVar{scope, generic.name.value}); + std::optional defaultTy = std::nullopt; + + if (generic.defaultValue) + defaultTy = resolveType(scope, generic.defaultValue); + + result.push_back({generic.name.value, GenericTypeDefinition{ + genericTy, + defaultTy, + }}); + } + + return result; +} + +std::vector> ConstraintGraphBuilder::createGenericPacks( + NotNull scope, AstArray generics) +{ + std::vector> result; + for (const auto& generic : generics) + { + TypePackId genericTy = arena->addTypePack(TypePackVar{GenericTypePack{scope, generic.name.value}}); + std::optional defaultTy = std::nullopt; + + if (generic.defaultValue) + defaultTy = resolveTypePack(scope, generic.defaultValue); + + result.push_back({generic.name.value, GenericTypePackDefinition{ + genericTy, + defaultTy, + }}); + } + + return result; +} + +TypeId ConstraintGraphBuilder::flattenPack(NotNull scope, Location location, TypePackId tp) +{ + if (auto f = first(tp)) + return *f; + + TypeId typeResult = freshType(scope); + TypePack onePack{{typeResult}, freshTypePack(scope)}; + TypePackId oneTypePack = arena->addTypePack(std::move(onePack)); + + addConstraint(scope, PackSubtypeConstraint{tp, oneTypePack}); + + return typeResult; +} + +void ConstraintGraphBuilder::reportError(Location location, TypeErrorData err) +{ + errors.push_back(TypeError{location, moduleName, std::move(err)}); +} + +void ConstraintGraphBuilder::reportCodeTooComplex(Location location) +{ + errors.push_back(TypeError{location, moduleName, CodeTooComplex{}}); +} + +struct GlobalPrepopulator : AstVisitor +{ + const NotNull globalScope; + const NotNull arena; + + GlobalPrepopulator(NotNull globalScope, NotNull arena) + : globalScope(globalScope) + , arena(arena) + { + } + + bool visit(AstStatFunction* function) override + { + if (AstExprGlobal* g = function->name->as()) + globalScope->bindings[g->name] = arena->addType(BlockedTypeVar{}); + + return true; + } +}; + +void ConstraintGraphBuilder::prepopulateGlobalScope(NotNull globalScope, AstStatBlock* program) +{ + GlobalPrepopulator gp{NotNull{globalScope}, arena}; + + program->visit(&gp); +} + +void collectConstraints(std::vector>& result, NotNull scope) { for (const auto& c : scope->constraints) result.push_back(NotNull{c.get()}); - for (Scope2* child : scope->children) + for (NotNull child : scope->children) collectConstraints(result, child); } -std::vector> collectConstraints(Scope2* rootScope) +std::vector> collectConstraints(NotNull rootScope) { std::vector> result; collectConstraints(result, rootScope); diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index 9e3552360..077a4e282 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -13,7 +13,7 @@ LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson, false); namespace Luau { -[[maybe_unused]] static void dumpBindings(Scope2* scope, ToStringOptions& opts) +[[maybe_unused]] static void dumpBindings(NotNull scope, ToStringOptions& opts) { for (const auto& [k, v] : scope->bindings) { @@ -22,22 +22,22 @@ namespace Luau printf("\t%s : %s\n", k.c_str(), d.name.c_str()); } - for (Scope2* child : scope->children) + for (NotNull child : scope->children) dumpBindings(child, opts); } -static void dumpConstraints(Scope2* scope, ToStringOptions& opts) +static void dumpConstraints(NotNull scope, ToStringOptions& opts) { for (const ConstraintPtr& c : scope->constraints) { printf("\t%s\n", toString(*c, opts).c_str()); } - for (Scope2* child : scope->children) + for (NotNull child : scope->children) dumpConstraints(child, opts); } -void dump(Scope2* rootScope, ToStringOptions& opts) +void dump(NotNull rootScope, ToStringOptions& opts) { printf("constraints:\n"); dumpConstraints(rootScope, opts); @@ -55,7 +55,7 @@ void dump(ConstraintSolver* cs, ToStringOptions& opts) } } -ConstraintSolver::ConstraintSolver(TypeArena* arena, Scope2* rootScope) +ConstraintSolver::ConstraintSolver(TypeArena* arena, NotNull rootScope) : arena(arena) , constraints(collectConstraints(rootScope)) , rootScope(rootScope) @@ -180,6 +180,10 @@ bool ConstraintSolver::tryDispatch(NotNull constraint, bool fo success = tryDispatch(*gc, constraint, force); else if (auto ic = get(*constraint)) success = tryDispatch(*ic, constraint, force); + else if (auto uc = get(*constraint)) + success = tryDispatch(*uc, constraint, force); + else if (auto bc = get(*constraint)) + success = tryDispatch(*bc, constraint, force); else if (auto nc = get(*constraint)) success = tryDispatch(*nc, constraint); else @@ -246,12 +250,65 @@ bool ConstraintSolver::tryDispatch(const InstantiationConstraint& c, NotNull instantiated = inst.substitute(c.superType); LUAU_ASSERT(instantiated); // TODO FIXME HANDLE THIS - unify(c.subType, *instantiated); + if (isBlocked(c.subType)) + asMutable(c.subType)->ty.emplace(*instantiated); + else + unify(c.subType, *instantiated); + unblock(c.subType); return true; } +bool ConstraintSolver::tryDispatch(const UnaryConstraint& c, NotNull constraint, bool force) +{ + TypeId operandType = follow(c.operandType); + + if (isBlocked(operandType)) + return block(operandType, constraint); + + if (get(operandType)) + return block(operandType, constraint); + + LUAU_ASSERT(get(c.resultType)); + + if (isNumber(operandType) || get(operandType) || get(operandType)) + { + asMutable(c.resultType)->ty.emplace(c.operandType); + return true; + } + + LUAU_ASSERT(0); // TODO metatable handling + return false; +} + +bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNull constraint, bool force) +{ + TypeId leftType = follow(c.leftType); + TypeId rightType = follow(c.rightType); + + if (isBlocked(leftType) || isBlocked(rightType)) + { + block(leftType, constraint); + block(rightType, constraint); + return false; + } + + if (isNumber(leftType)) + { + unify(leftType, rightType); + asMutable(c.resultType)->ty.emplace(leftType); + return true; + } + + if (get(leftType) && !force) + return block(leftType, constraint); + + // TODO metatables, classes + + return true; +} + bool ConstraintSolver::tryDispatch(const NameConstraint& c, NotNull constraint) { if (isBlocked(c.namedType)) diff --git a/Analysis/src/EmbeddedBuiltinDefinitions.cpp b/Analysis/src/EmbeddedBuiltinDefinitions.cpp index 2407e3ef0..1b5275fdd 100644 --- a/Analysis/src/EmbeddedBuiltinDefinitions.cpp +++ b/Analysis/src/EmbeddedBuiltinDefinitions.cpp @@ -1,6 +1,8 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/BuiltinDefinitions.h" +LUAU_FASTFLAG(LuauCheckLenMT) + namespace Luau { @@ -202,7 +204,13 @@ declare function unpack(tab: {V}, i: number?, j: number?): ...V std::string getBuiltinDefinitionSource() { - return kBuiltinDefinitionLuaSrc; + std::string result = kBuiltinDefinitionLuaSrc; + + // TODO: move this into kBuiltinDefinitionLuaSrc + if (FFlag::LuauCheckLenMT) + result += "declare function rawlen(obj: {[K]: V} | string): number\n"; + + return result; } } // namespace Luau diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 85c5dbc8a..4cfaa112a 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -787,14 +787,32 @@ const SourceModule* Frontend::getSourceModule(const ModuleName& moduleName) cons return const_cast(this)->getSourceModule(moduleName); } +NotNull Frontend::getGlobalScope2() +{ + if (!globalScope2) + { + const SingletonTypes& singletonTypes = getSingletonTypes(); + + globalScope2 = std::make_unique(); + globalScope2->typeBindings["nil"] = singletonTypes.nilType; + globalScope2->typeBindings["number"] = singletonTypes.numberType; + globalScope2->typeBindings["string"] = singletonTypes.stringType; + globalScope2->typeBindings["boolean"] = singletonTypes.booleanType; + globalScope2->typeBindings["thread"] = singletonTypes.threadType; + } + + return NotNull(globalScope2.get()); +} + ModulePtr Frontend::check(const SourceModule& sourceModule, Mode mode, const ScopePtr& environmentScope) { ModulePtr result = std::make_shared(); - ConstraintGraphBuilder cgb{&result->internalTypes}; + ConstraintGraphBuilder cgb{sourceModule.name, &result->internalTypes, NotNull(&iceHandler), getGlobalScope2()}; cgb.visit(sourceModule.root); + result->errors = std::move(cgb.errors); - ConstraintSolver cs{&result->internalTypes, cgb.rootScope}; + ConstraintSolver cs{&result->internalTypes, NotNull(cgb.rootScope)}; cs.run(); result->scope2s = std::move(cgb.scopes); diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index d36665e27..8ce7f7423 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -5,7 +5,6 @@ #include #include "Luau/Clone.h" -#include "Luau/Substitution.h" #include "Luau/Unifier.h" #include "Luau/VisitTypeVar.h" @@ -16,7 +15,6 @@ LUAU_FASTINTVARIABLE(LuauNormalizeIterationLimit, 1200); LUAU_FASTFLAGVARIABLE(LuauNormalizeCombineTableFix, false); LUAU_FASTFLAGVARIABLE(LuauNormalizeFlagIsConservative, false); LUAU_FASTFLAGVARIABLE(LuauNormalizeCombineEqFix, false); -LUAU_FASTFLAGVARIABLE(LuauReplaceReplacer, false); LUAU_FASTFLAG(LuauQuantifyConstrained) namespace Luau @@ -25,238 +23,33 @@ namespace Luau namespace { -struct Replacer : Substitution +struct Replacer { + TypeArena* arena; TypeId sourceType; TypeId replacedType; - DenseHashMap replacedTypes{nullptr}; - DenseHashMap replacedPacks{nullptr}; + DenseHashMap newTypes; Replacer(TypeArena* arena, TypeId sourceType, TypeId replacedType) - : Substitution(TxnLog::empty(), arena) + : arena(arena) , sourceType(sourceType) , replacedType(replacedType) + , newTypes(nullptr) { } - bool isDirty(TypeId ty) override - { - if (!sourceType) - return false; - - auto vecHasSourceType = [sourceType = sourceType](const auto& vec) { - return end(vec) != std::find(begin(vec), end(vec), sourceType); - }; - - // Walk every kind of TypeVar and find pointers to sourceType - if (auto t = get(ty)) - return false; - else if (auto t = get(ty)) - return false; - else if (auto t = get(ty)) - return false; - else if (auto t = get(ty)) - return false; - else if (auto t = get(ty)) - return vecHasSourceType(t->parts); - else if (auto t = get(ty)) - return false; - else if (auto t = get(ty)) - { - if (vecHasSourceType(t->generics)) - return true; - - return false; - } - else if (auto t = get(ty)) - { - if (t->boundTo) - return *t->boundTo == sourceType; - - for (const auto& [_name, prop] : t->props) - { - if (prop.type == sourceType) - return true; - } - - if (auto indexer = t->indexer) - { - if (indexer->indexType == sourceType || indexer->indexResultType == sourceType) - return true; - } - - if (vecHasSourceType(t->instantiatedTypeParams)) - return true; - - return false; - } - else if (auto t = get(ty)) - return t->table == sourceType || t->metatable == sourceType; - else if (auto t = get(ty)) - return false; - else if (auto t = get(ty)) - return false; - else if (auto t = get(ty)) - return vecHasSourceType(t->options); - else if (auto t = get(ty)) - return vecHasSourceType(t->parts); - else if (auto t = get(ty)) - return false; - - LUAU_ASSERT(!"Luau::Replacer::isDirty internal error: Unknown TypeVar type"); - LUAU_UNREACHABLE(); - } - - bool isDirty(TypePackId tp) override - { - if (auto it = replacedPacks.find(tp)) - return false; - - if (auto pack = get(tp)) - { - for (TypeId ty : pack->head) - { - if (ty == sourceType) - return true; - } - return false; - } - else if (auto vtp = get(tp)) - return vtp->ty == sourceType; - else - return false; - } - - TypeId clean(TypeId ty) override - { - LUAU_ASSERT(sourceType && replacedType); - - // Walk every kind of TypeVar and create a copy with sourceType replaced by replacedType - // Before returning, memoize the result for later use. - - // Helpfully, Substitution::clone() only shallow-clones the kinds of types that we care to work with. This - // function returns the identity for things like primitives. - TypeId res = clone(ty); - - if (auto t = get(res)) - LUAU_ASSERT(!"Impossible"); - else if (auto t = get(res)) - LUAU_ASSERT(!"Impossible"); - else if (auto t = get(res)) - LUAU_ASSERT(!"Impossible"); - else if (auto t = get(res)) - LUAU_ASSERT(!"Impossible"); - else if (auto t = getMutable(res)) - { - for (TypeId& part : t->parts) - { - if (part == sourceType) - part = replacedType; - } - } - else if (auto t = get(res)) - LUAU_ASSERT(!"Impossible"); - else if (auto t = getMutable(res)) - { - // The constituent typepacks are cleaned separately. We just need to walk the generics array. - for (TypeId& g : t->generics) - { - if (g == sourceType) - g = replacedType; - } - } - else if (auto t = getMutable(res)) - { - for (auto& [_key, prop] : t->props) - { - if (prop.type == sourceType) - prop.type = replacedType; - } - } - else if (auto t = getMutable(res)) - { - if (t->table == sourceType) - t->table = replacedType; - if (t->metatable == sourceType) - t->table = replacedType; - } - else if (auto t = get(res)) - LUAU_ASSERT(!"Impossible"); - else if (auto t = get(res)) - LUAU_ASSERT(!"Impossible"); - else if (auto t = getMutable(res)) - { - for (TypeId& option : t->options) - { - if (option == sourceType) - option = replacedType; - } - } - else if (auto t = getMutable(res)) - { - for (TypeId& part : t->parts) - { - if (part == sourceType) - part = replacedType; - } - } - else if (auto t = get(res)) - LUAU_ASSERT(!"Impossible"); - else - LUAU_ASSERT(!"Luau::Replacer::clean internal error: Unknown TypeVar type"); - - replacedTypes[ty] = res; - return res; - } - - TypePackId clean(TypePackId tp) override - { - TypePackId res = clone(tp); - - if (auto pack = getMutable(res)) - { - for (TypeId& type : pack->head) - { - if (type == sourceType) - type = replacedType; - } - } - else if (auto vtp = getMutable(res)) - { - if (vtp->ty == sourceType) - vtp->ty = replacedType; - } - - replacedPacks[tp] = res; - return res; - } - TypeId smartClone(TypeId t) { - if (FFlag::LuauReplaceReplacer) - { - // The new smartClone is just a memoized clone() - // TODO: Remove the Substitution base class and all other methods from this struct. - // Add DenseHashMap newTypes; - t = log->follow(t); - TypeId* res = newTypes.find(t); - if (res) - return *res; - - TypeId result = shallowClone(t, *arena, TxnLog::empty()); - newTypes[t] = result; - newTypes[result] = result; - - return result; - } - else - { - std::optional res = replace(t); - LUAU_ASSERT(res.has_value()); // TODO think about this - if (*res == t) - return clone(t); + t = follow(t); + TypeId* res = newTypes.find(t); + if (res) return *res; - } + + TypeId result = shallowClone(t, *arena, TxnLog::empty()); + newTypes[t] = result; + newTypes[result] = result; + + return result; } }; diff --git a/Analysis/src/Quantify.cpp b/Analysis/src/Quantify.cpp index 40e14c689..294c479de 100644 --- a/Analysis/src/Quantify.cpp +++ b/Analysis/src/Quantify.cpp @@ -8,6 +8,7 @@ #include "Luau/VisitTypeVar.h" LUAU_FASTFLAG(LuauAlwaysQuantify); +LUAU_FASTFLAG(DebugLuauSharedSelf) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); LUAU_FASTFLAGVARIABLE(LuauQuantifyConstrained, false) @@ -158,24 +159,61 @@ struct Quantifier final : TypeVarOnceVisitor void quantify(TypeId ty, TypeLevel level) { - Quantifier q{level}; - q.traverse(ty); - - FunctionTypeVar* ftv = getMutable(ty); - LUAU_ASSERT(ftv); - if (FFlag::LuauAlwaysQuantify) + if (FFlag::DebugLuauSharedSelf) { - ftv->generics.insert(ftv->generics.end(), q.generics.begin(), q.generics.end()); - ftv->genericPacks.insert(ftv->genericPacks.end(), q.genericPacks.begin(), q.genericPacks.end()); + ty = follow(ty); + + if (auto ttv = getTableType(ty); ttv && ttv->selfTy) + { + Quantifier selfQ{level}; + selfQ.traverse(*ttv->selfTy); + + Quantifier q{level}; + q.traverse(ty); + + for (const auto& [_, prop] : ttv->props) + { + auto ftv = getMutable(follow(prop.type)); + if (!ftv || !ftv->hasSelf) + continue; + + if (Luau::first(ftv->argTypes) == ttv->selfTy) + { + ftv->generics.insert(ftv->generics.end(), selfQ.generics.begin(), selfQ.generics.end()); + ftv->genericPacks.insert(ftv->genericPacks.end(), selfQ.genericPacks.begin(), selfQ.genericPacks.end()); + } + } + } + else if (auto ftv = getMutable(ty)) + { + Quantifier q{level}; + q.traverse(ty); + + ftv->generics.insert(ftv->generics.end(), q.generics.begin(), q.generics.end()); + ftv->genericPacks.insert(ftv->genericPacks.end(), q.genericPacks.begin(), q.genericPacks.end()); + + if (ftv->generics.empty() && ftv->genericPacks.empty() && !q.seenMutableType && !q.seenGenericType) + ftv->hasNoGenerics = true; + } } else { - ftv->generics = q.generics; - ftv->genericPacks = q.genericPacks; - } + Quantifier q{level}; + q.traverse(ty); - if (ftv->generics.empty() && ftv->genericPacks.empty() && !q.seenMutableType && !q.seenGenericType) - ftv->hasNoGenerics = true; + FunctionTypeVar* ftv = getMutable(ty); + LUAU_ASSERT(ftv); + if (FFlag::LuauAlwaysQuantify) + { + ftv->generics.insert(ftv->generics.end(), q.generics.begin(), q.generics.end()); + ftv->genericPacks.insert(ftv->genericPacks.end(), q.genericPacks.begin(), q.genericPacks.end()); + } + else + { + ftv->generics = q.generics; + ftv->genericPacks = q.genericPacks; + } + } } void quantify(TypeId ty, Scope2* scope) @@ -206,8 +244,8 @@ struct PureQuantifier : Substitution std::vector insertedGenerics; std::vector insertedGenericPacks; - PureQuantifier(const TxnLog* log, TypeArena* arena, Scope2* scope) - : Substitution(log, arena) + PureQuantifier(TypeArena* arena, Scope2* scope) + : Substitution(TxnLog::empty(), arena) , scope(scope) { } @@ -286,7 +324,7 @@ struct PureQuantifier : Substitution TypeId quantify(TypeArena* arena, TypeId ty, Scope2* scope) { - PureQuantifier quantifier{TxnLog::empty(), arena, scope}; + PureQuantifier quantifier{arena, scope}; std::optional result = quantifier.substitute(ty); LUAU_ASSERT(result); @@ -294,8 +332,7 @@ TypeId quantify(TypeArena* arena, TypeId ty, Scope2* scope) LUAU_ASSERT(ftv); ftv->generics.insert(ftv->generics.end(), quantifier.insertedGenerics.begin(), quantifier.insertedGenerics.end()); ftv->genericPacks.insert(ftv->genericPacks.end(), quantifier.insertedGenericPacks.begin(), quantifier.insertedGenericPacks.end()); - - // TODO: Set hasNoGenerics. + ftv->hasNoGenerics = ftv->generics.empty() && ftv->genericPacks.empty(); return *result; } diff --git a/Analysis/src/Scope.cpp b/Analysis/src/Scope.cpp index 66aaee1f8..247a9dd6f 100644 --- a/Analysis/src/Scope.cpp +++ b/Analysis/src/Scope.cpp @@ -153,4 +153,19 @@ std::optional Scope2::lookupTypeBinding(const Name& name) return std::nullopt; } +std::optional Scope2::lookupTypePackBinding(const Name& name) +{ + Scope2* s = this; + while (s) + { + auto it = s->typePackBindings.find(name); + if (it != s->typePackBindings.end()) + return it->second; + + s = s->parent; + } + + return std::nullopt; +} + } // namespace Luau diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index eb7b9cd6b..fe940d5ec 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -1367,51 +1367,74 @@ std::string generateName(size_t i) return n; } -std::string toString(const Constraint& c, ToStringOptions& opts) +std::string toString(const Constraint& constraint, ToStringOptions& opts) { - if (const SubtypeConstraint* sc = Luau::get_if(&c.c)) - { - ToStringResult subStr = toStringDetailed(sc->subType, opts); - opts.nameMap = std::move(subStr.nameMap); - ToStringResult superStr = toStringDetailed(sc->superType, opts); - opts.nameMap = std::move(superStr.nameMap); - return subStr.name + " <: " + superStr.name; - } - else if (const PackSubtypeConstraint* psc = Luau::get_if(&c.c)) - { - ToStringResult subStr = toStringDetailed(psc->subPack, opts); - opts.nameMap = std::move(subStr.nameMap); - ToStringResult superStr = toStringDetailed(psc->superPack, opts); - opts.nameMap = std::move(superStr.nameMap); - return subStr.name + " <: " + superStr.name; - } - else if (const GeneralizationConstraint* gc = Luau::get_if(&c.c)) - { - ToStringResult subStr = toStringDetailed(gc->generalizedType, opts); - opts.nameMap = std::move(subStr.nameMap); - ToStringResult superStr = toStringDetailed(gc->sourceType, opts); - opts.nameMap = std::move(superStr.nameMap); - return subStr.name + " ~ gen " + superStr.name; - } - else if (const InstantiationConstraint* ic = Luau::get_if(&c.c)) - { - ToStringResult subStr = toStringDetailed(ic->subType, opts); - opts.nameMap = std::move(subStr.nameMap); - ToStringResult superStr = toStringDetailed(ic->superType, opts); - opts.nameMap = std::move(superStr.nameMap); - return subStr.name + " ~ inst " + superStr.name; - } - else if (const NameConstraint* nc = Luau::get(c)) - { - ToStringResult namedStr = toStringDetailed(nc->namedType, opts); - opts.nameMap = std::move(namedStr.nameMap); - return "@name(" + namedStr.name + ") = " + nc->name; - } - else - { - LUAU_ASSERT(false); - return ""; - } + auto go = [&opts](auto&& c) { + using T = std::decay_t; + + if constexpr (std::is_same_v) + { + ToStringResult subStr = toStringDetailed(c.subType, opts); + opts.nameMap = std::move(subStr.nameMap); + ToStringResult superStr = toStringDetailed(c.superType, opts); + opts.nameMap = std::move(superStr.nameMap); + return subStr.name + " <: " + superStr.name; + } + else if constexpr (std::is_same_v) + { + ToStringResult subStr = toStringDetailed(c.subPack, opts); + opts.nameMap = std::move(subStr.nameMap); + ToStringResult superStr = toStringDetailed(c.superPack, opts); + opts.nameMap = std::move(superStr.nameMap); + return subStr.name + " <: " + superStr.name; + } + else if constexpr (std::is_same_v) + { + ToStringResult subStr = toStringDetailed(c.generalizedType, opts); + opts.nameMap = std::move(subStr.nameMap); + ToStringResult superStr = toStringDetailed(c.sourceType, opts); + opts.nameMap = std::move(superStr.nameMap); + return subStr.name + " ~ gen " + superStr.name; + } + else if constexpr (std::is_same_v) + { + ToStringResult subStr = toStringDetailed(c.subType, opts); + opts.nameMap = std::move(subStr.nameMap); + ToStringResult superStr = toStringDetailed(c.superType, opts); + opts.nameMap = std::move(superStr.nameMap); + return subStr.name + " ~ inst " + superStr.name; + } + else if constexpr (std::is_same_v) + { + ToStringResult resultStr = toStringDetailed(c.resultType, opts); + opts.nameMap = std::move(resultStr.nameMap); + ToStringResult operandStr = toStringDetailed(c.operandType, opts); + opts.nameMap = std::move(operandStr.nameMap); + + return resultStr.name + " ~ Unary<" + toString(c.op) + ", " + operandStr.name + ">"; + } + else if constexpr (std::is_same_v) + { + ToStringResult resultStr = toStringDetailed(c.resultType); + opts.nameMap = std::move(resultStr.nameMap); + ToStringResult leftStr = toStringDetailed(c.leftType); + opts.nameMap = std::move(leftStr.nameMap); + ToStringResult rightStr = toStringDetailed(c.rightType); + opts.nameMap = std::move(rightStr.nameMap); + + return resultStr.name + " ~ Binary<" + toString(c.op) + ", " + leftStr.name + ", " + rightStr.name + ">"; + } + else if constexpr (std::is_same_v) + { + ToStringResult namedStr = toStringDetailed(c.namedType, opts); + opts.nameMap = std::move(namedStr.nameMap); + return "@name(" + namedStr.name + ") = " + c.name; + } + else + static_assert(always_false_v, "Non-exhaustive constraint switch"); + }; + + return visit(go, constraint.c); } std::string dump(const Constraint& c) diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index 63e5800fc..30e498af1 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -6,8 +6,10 @@ #include "Luau/Ast.h" #include "Luau/AstQuery.h" #include "Luau/Clone.h" +#include "Luau/Instantiation.h" #include "Luau/Normalize.h" -#include "Luau/ConstraintGraphBuilder.h" // FIXME move Scope2 into its own header +#include "Luau/TxnLog.h" +#include "Luau/TypeUtils.h" #include "Luau/Unifier.h" #include "Luau/ToString.h" @@ -19,10 +21,12 @@ struct TypeChecker2 : public AstVisitor const SourceModule* sourceModule; Module* module; InternalErrorReporter ice; // FIXME accept a pointer from Frontend + SingletonTypes& singletonTypes; TypeChecker2(const SourceModule* sourceModule, Module* module) : sourceModule(sourceModule) , module(module) + , singletonTypes(getSingletonTypes()) { } @@ -30,16 +34,30 @@ struct TypeChecker2 : public AstVisitor TypePackId lookupPack(AstExpr* expr) { + // If a type isn't in the type graph, it probably means that a recursion limit was exceeded. + // We'll just return anyType in these cases. Typechecking against any is very fast and this + // allows us not to think about this very much in the actual typechecking logic. TypePackId* tp = module->astTypePacks.find(expr); - LUAU_ASSERT(tp); - return follow(*tp); + if (tp) + return follow(*tp); + else + return singletonTypes.anyTypePack; } TypeId lookupType(AstExpr* expr) { + // If a type isn't in the type graph, it probably means that a recursion limit was exceeded. + // We'll just return anyType in these cases. Typechecking against any is very fast and this + // allows us not to think about this very much in the actual typechecking logic. TypeId* ty = module->astTypes.find(expr); - LUAU_ASSERT(ty); - return follow(*ty); + if (ty) + return follow(*ty); + + TypePackId* tp = module->astTypePacks.find(expr); + if (tp) + return flattenPack(*tp); + + return singletonTypes.anyType; } TypeId lookupAnnotation(AstType* annotation) @@ -78,7 +96,7 @@ struct TypeChecker2 : public AstVisitor bestLocation = scopeBounds; } } - else + else if (scopeBounds.begin > location.end) { // TODO: Is this sound? This relies on the fact that scopes are inserted // into the scope list in the order that they appear in the AST. @@ -147,16 +165,14 @@ struct TypeChecker2 : public AstVisitor for (size_t i = 0; i < count; ++i) { AstExpr* lhs = assign->vars.data[i]; - TypeId* lhsType = module->astTypes.find(lhs); - LUAU_ASSERT(lhsType); + TypeId lhsType = lookupType(lhs); AstExpr* rhs = assign->values.data[i]; - TypeId* rhsType = module->astTypes.find(rhs); - LUAU_ASSERT(rhsType); + TypeId rhsType = lookupType(rhs); - if (!isSubtype(*rhsType, *lhsType, ice)) + if (!isSubtype(rhsType, lhsType, ice)) { - reportError(TypeMismatch{*lhsType, *rhsType}, rhs->location); + reportError(TypeMismatch{lhsType, rhsType}, rhs->location); } } @@ -181,7 +197,7 @@ struct TypeChecker2 : public AstVisitor if (!ok) { for (const TypeError& e : u.errors) - module->errors.push_back(e); + reportError(e); } return true; @@ -189,10 +205,14 @@ struct TypeChecker2 : public AstVisitor bool visit(AstExprCall* call) override { + TypeArena arena; + Instantiation instantiation{TxnLog::empty(), &arena, TypeLevel{}}; + TypePackId expectedRetType = lookupPack(call); TypeId functionType = lookupType(call->func); + TypeId instantiatedFunctionType = instantiation.substitute(functionType).value_or(nullptr); + LUAU_ASSERT(functionType); - TypeArena arena; TypePack args; for (const auto& arg : call->args) { @@ -204,7 +224,7 @@ struct TypeChecker2 : public AstVisitor TypePackId argsTp = arena.addTypePack(args); FunctionTypeVar ftv{argsTp, expectedRetType}; TypeId expectedType = arena.addType(ftv); - if (!isSubtype(expectedType, functionType, ice)) + if (!isSubtype(expectedType, instantiatedFunctionType, ice)) { unfreeze(module->interfaceTypes); CloneState cloneState; @@ -252,16 +272,12 @@ struct TypeChecker2 : public AstVisitor // leftType must have a property called indexName->index - if (auto ttv = get(leftType)) + std::optional t = findTablePropertyRespectingMeta(module->errors, leftType, indexName->index.value, indexName->location); + if (t) { - auto it = ttv->props.find(indexName->index.value); - if (it == ttv->props.end()) + if (!isSubtype(resultType, *t, ice)) { - reportError(UnknownProperty{leftType, indexName->index.value}, indexName->location); - } - else if (!isSubtype(resultType, it->second.type, ice)) - { - reportError(TypeMismatch{resultType, it->second.type}, indexName->location); + reportError(TypeMismatch{resultType, *t}, indexName->location); } } else @@ -277,7 +293,7 @@ struct TypeChecker2 : public AstVisitor TypeId actualType = lookupType(number); TypeId numberType = getSingletonTypes().numberType; - if (!isSubtype(actualType, numberType, ice)) + if (!isSubtype(numberType, actualType, ice)) { reportError(TypeMismatch{actualType, numberType}, number->location); } @@ -290,7 +306,7 @@ struct TypeChecker2 : public AstVisitor TypeId actualType = lookupType(string); TypeId stringType = getSingletonTypes().stringType; - if (!isSubtype(actualType, stringType, ice)) + if (!isSubtype(stringType, actualType, ice)) { reportError(TypeMismatch{actualType, stringType}, string->location); } @@ -298,6 +314,41 @@ struct TypeChecker2 : public AstVisitor return true; } + /** Extract a TypeId for the first type of the provided pack. + * + * Note that this may require modifying some types. I hope this doesn't cause problems! + */ + TypeId flattenPack(TypePackId pack) + { + pack = follow(pack); + + while (auto tp = get(pack)) + { + if (tp->head.empty() && tp->tail) + pack = *tp->tail; + } + + if (auto ty = first(pack)) + return *ty; + else if (auto vtp = get(pack)) + return vtp->ty; + else if (auto ftp = get(pack)) + { + TypeId result = module->internalTypes.addType(FreeTypeVar{ftp->scope}); + TypePackId freeTail = module->internalTypes.addTypePack(FreeTypePack{ftp->scope}); + + TypePack& resultPack = asMutable(pack)->ty.emplace(); + resultPack.head.assign(1, result); + resultPack.tail = freeTail; + + return result; + } + else if (get(pack)) + return singletonTypes.errorRecoveryType(); + else + ice.ice("flattenPack got a weird pack!"); + } + bool visit(AstType* ty) override { return true; @@ -321,6 +372,11 @@ struct TypeChecker2 : public AstVisitor { module->errors.emplace_back(location, sourceModule->name, std::move(data)); } + + void reportError(TypeError e) + { + module->errors.emplace_back(std::move(e)); + } }; void check(const SourceModule& sourceModule, Module* module) diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 44635e88e..d9486a4fd 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -38,11 +38,13 @@ LUAU_FASTFLAGVARIABLE(LuauReduceUnionRecursion, false) LUAU_FASTFLAGVARIABLE(LuauReturnAnyInsteadOfICE, false) // Eventually removed as false. LUAU_FASTFLAG(LuauNormalizeFlagIsConservative) LUAU_FASTFLAGVARIABLE(LuauReturnTypeInferenceInNonstrict, false) +LUAU_FASTFLAGVARIABLE(DebugLuauSharedSelf, false); LUAU_FASTFLAGVARIABLE(LuauAlwaysQuantify, false); LUAU_FASTFLAGVARIABLE(LuauReportErrorsOnIndexerKeyMismatch, false) LUAU_FASTFLAG(LuauQuantifyConstrained) LUAU_FASTFLAGVARIABLE(LuauFalsyPredicateReturnsNilInstead, false) LUAU_FASTFLAGVARIABLE(LuauNonCopyableTypeVarFields, false) +LUAU_FASTFLAGVARIABLE(LuauCheckLenMT, false) namespace Luau { @@ -238,7 +240,7 @@ static bool isMetamethod(const Name& name) { return name == "__index" || name == "__newindex" || name == "__call" || name == "__concat" || name == "__unm" || name == "__add" || name == "__sub" || name == "__mul" || name == "__div" || name == "__mod" || name == "__pow" || name == "__tostring" || - name == "__metatable" || name == "__eq" || name == "__lt" || name == "__le" || name == "__mode"; + name == "__metatable" || name == "__eq" || name == "__lt" || name == "__le" || name == "__mode" || name == "__iter" || name == "__len"; } size_t HashBoolNamePair::operator()(const std::pair& pair) const @@ -327,10 +329,19 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo currentModule->timeout = true; } + if (FFlag::DebugLuauSharedSelf) + { + for (auto& [ty, scope] : deferredQuantification) + Luau::quantify(ty, scope->level); + deferredQuantification.clear(); + } + if (get(follow(moduleScope->returnType))) moduleScope->returnType = addTypePack(TypePack{{}, std::nullopt}); else + { moduleScope->returnType = anyify(moduleScope, moduleScope->returnType, Location{}); + } for (auto& [_, typeFun] : moduleScope->exportedTypeBindings) typeFun.type = anyify(moduleScope, typeFun.type, Location{}); @@ -537,18 +548,43 @@ void TypeChecker::checkBlockWithoutRecursionCheck(const ScopePtr& scope, const A } else if (auto fun = (*protoIter)->as()) { + std::optional selfType; std::optional expectedType; - if (!fun->func->self) + if (FFlag::DebugLuauSharedSelf) { if (auto name = fun->name->as()) { - TypeId exprTy = checkExpr(scope, *name->expr).type; - expectedType = getIndexTypeFromType(scope, exprTy, name->index.value, name->indexLocation, false); + TypeId baseTy = checkExpr(scope, *name->expr).type; + tablify(baseTy); + + if (!fun->func->self) + expectedType = getIndexTypeFromType(scope, baseTy, name->index.value, name->indexLocation, false); + else if (auto ttv = getMutableTableType(baseTy)) + { + if (!baseTy->persistent && ttv->state != TableState::Sealed && !ttv->selfTy) + { + ttv->selfTy = anyIfNonstrict(freshType(ttv->level)); + deferredQuantification.push_back({baseTy, scope}); + } + + selfType = ttv->selfTy; + } + } + } + else + { + if (!fun->func->self) + { + if (auto name = fun->name->as()) + { + TypeId exprTy = checkExpr(scope, *name->expr).type; + expectedType = getIndexTypeFromType(scope, exprTy, name->index.value, name->indexLocation, false); + } } } - auto pair = checkFunctionSignature(scope, subLevel, *fun->func, fun->name->location, expectedType); + auto pair = checkFunctionSignature(scope, subLevel, *fun->func, fun->name->location, selfType, expectedType); auto [funTy, funScope] = pair; functionDecls[*protoIter] = pair; @@ -560,7 +596,7 @@ void TypeChecker::checkBlockWithoutRecursionCheck(const ScopePtr& scope, const A } else if (auto fun = (*protoIter)->as()) { - auto pair = checkFunctionSignature(scope, subLevel, *fun->func, fun->name->location, std::nullopt); + auto pair = checkFunctionSignature(scope, subLevel, *fun->func, fun->name->location, std::nullopt, std::nullopt); auto [funTy, funScope] = pair; functionDecls[*protoIter] = pair; @@ -2076,7 +2112,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprFunction& expr, std::optional expectedType) { - auto [funTy, funScope] = checkFunctionSignature(scope, 0, expr, std::nullopt, expectedType); + auto [funTy, funScope] = checkFunctionSignature(scope, 0, expr, std::nullopt, std::nullopt, expectedType); checkFunctionBody(funScope, funTy, expr); @@ -2296,6 +2332,8 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp state.tryUnify(actualFunctionType, expectedFunctionType, /*isFunctionCall*/ true); state.log.commit(); + reportErrors(state.errors); + TypeId retType = first(retTypePack).value_or(nilType); if (!state.errors.empty()) retType = errorRecoveryType(retType); @@ -2322,6 +2360,23 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp DenseHashSet seen{nullptr}; + if (FFlag::LuauCheckLenMT && typeCouldHaveMetatable(operandType)) + { + if (auto fnt = findMetatableEntry(operandType, "__len", expr.location)) + { + TypeId actualFunctionType = instantiate(scope, *fnt, expr.location); + TypePackId arguments = addTypePack({operandType}); + TypePackId retTypePack = addTypePack({numberType}); + TypeId expectedFunctionType = addType(FunctionTypeVar(scope->level, arguments, retTypePack)); + + Unifier state = mkUnifier(expr.location); + state.tryUnify(actualFunctionType, expectedFunctionType, /*isFunctionCall*/ true); + state.log.commit(); + + reportErrors(state.errors); + } + } + if (!hasLength(operandType, seen, &recursionCount)) reportError(TypeError{expr.location, NotATable{operandType}}); @@ -2530,17 +2585,15 @@ TypeId TypeChecker::checkRelationalOperation( } } - if (!matches) { reportError( expr.location, GenericError{format("Types %s and %s cannot be compared with %s because they do not have the same metatable", - toString(lhsType).c_str(), toString(rhsType).c_str(), toString(expr.op).c_str())}); + toString(lhsType).c_str(), toString(rhsType).c_str(), toString(expr.op).c_str())}); return errorRecoveryType(booleanType); } } - if (leftMetatable) { std::optional metamethod = findMetatableEntry(lhsType, metamethodName, expr.location); @@ -3139,8 +3192,8 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName, T // `(X) -> Y...`, but after typechecking the body, we cam unify `Y...` with `X` // to get type `(X) -> X`, then we quantify the free types to get the final // generic type `(a) -> a`. -std::pair TypeChecker::checkFunctionSignature( - const ScopePtr& scope, int subLevel, const AstExprFunction& expr, std::optional originalName, std::optional expectedType) +std::pair TypeChecker::checkFunctionSignature(const ScopePtr& scope, int subLevel, const AstExprFunction& expr, + std::optional originalName, std::optional selfType, std::optional expectedType) { ScopePtr funScope = childFunctionScope(scope, expr.location, subLevel); @@ -3241,12 +3294,25 @@ std::pair TypeChecker::checkFunctionSignature( funScope->returnType = retPack; - if (expr.self) + if (FFlag::DebugLuauSharedSelf) { - // TODO: generic self types: CLI-39906 - TypeId selfType = anyIfNonstrict(freshType(funScope)); - funScope->bindings[expr.self] = {selfType, expr.self->location}; - argTypes.push_back(selfType); + if (expr.self) + { + // TODO: generic self types: CLI-39906 + TypeId selfTy = anyIfNonstrict(selfType ? *selfType : freshType(funScope)); + funScope->bindings[expr.self] = {selfTy, expr.self->location}; + argTypes.push_back(selfTy); + } + } + else + { + if (expr.self) + { + // TODO: generic self types: CLI-39906 + TypeId selfType = anyIfNonstrict(freshType(funScope)); + funScope->bindings[expr.self] = {selfType, expr.self->location}; + argTypes.push_back(selfType); + } } // Prepare expected argument type iterators if we have an expected function type @@ -4457,25 +4523,43 @@ TypeId TypeChecker::quantify(const ScopePtr& scope, TypeId ty, Location location { ty = follow(ty); - const FunctionTypeVar* ftv = get(ty); - - if (FFlag::LuauAlwaysQuantify) + if (FFlag::DebugLuauSharedSelf) { - if (ftv) + if (auto ftv = get(ty)) Luau::quantify(ty, scope->level); + else if (auto ttv = getTableType(ty); ttv && ttv->selfTy) + Luau::quantify(ty, scope->level); + + if (FFlag::LuauLowerBoundsCalculation) + { + auto [t, ok] = Luau::normalize(ty, currentModule, *iceHandler); + if (!ok) + reportError(location, NormalizationTooComplex{}); + return t; + } } else { - if (ftv && ftv->generics.empty() && ftv->genericPacks.empty()) - Luau::quantify(ty, scope->level); - } + const FunctionTypeVar* ftv = get(ty); - if (FFlag::LuauLowerBoundsCalculation && ftv) - { - auto [t, ok] = Luau::normalize(ty, currentModule, *iceHandler); - if (!ok) - reportError(location, NormalizationTooComplex{}); - return t; + if (FFlag::LuauAlwaysQuantify) + { + if (ftv) + Luau::quantify(ty, scope->level); + } + else + { + if (ftv && ftv->generics.empty() && ftv->genericPacks.empty()) + Luau::quantify(ty, scope->level); + } + + if (FFlag::LuauLowerBoundsCalculation && ftv) + { + auto [t, ok] = Luau::normalize(ty, currentModule, *iceHandler); + if (!ok) + reportError(location, NormalizationTooComplex{}); + return t; + } } return ty; diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 6147e118c..0792a3503 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -740,7 +740,7 @@ void Unifier::tryUnifyTypeWithIntersection(TypeId subTy, TypeId superTy, const I std::optional unificationTooComplex; std::optional firstFailedOption; - // T <: A & B if A <: T and B <: T + // T <: A & B if T <: A and T <: B for (TypeId type : uv->parts) { Unifier innerState = makeChildUnifier(); @@ -765,7 +765,7 @@ void Unifier::tryUnifyTypeWithIntersection(TypeId subTy, TypeId superTy, const I void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionTypeVar* uv, TypeId superTy, bool cacheEnabled, bool isFunctionCall) { - // A & B <: T if T <: A or T <: B + // A & B <: T if A <: T or B <: T bool found = false; std::optional unificationTooComplex; diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 95bce3ee2..70c925559 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -5,6 +5,9 @@ #include +#include +#include + // Warning: If you are introducing new syntax, ensure that it is behind a separate // flag so that we don't break production games by reverting syntax changes. // See docs/SyntaxChanges.md for an explanation. @@ -14,6 +17,18 @@ LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) LUAU_FASTFLAGVARIABLE(LuauParserFunctionKeywordAsTypeHelp, false) LUAU_FASTFLAGVARIABLE(LuauReturnTypeTokenConfusion, false) +LUAU_FASTFLAGVARIABLE(LuauFixNamedFunctionParse, false) +LUAU_DYNAMIC_FASTFLAGVARIABLE(LuaReportParseWrongNamedType, false) + +bool lua_telemetry_parsed_named_non_function_type = false; + +LUAU_FASTFLAGVARIABLE(LuauErrorParseIntegerIssues, false) +LUAU_DYNAMIC_FASTFLAGVARIABLE(LuaReportParseIntegerIssues, false) + +bool lua_telemetry_parsed_out_of_range_bin_integer = false; +bool lua_telemetry_parsed_out_of_range_hex_integer = false; +bool lua_telemetry_parsed_double_prefix_hex_integer = false; + namespace Luau { @@ -1330,7 +1345,7 @@ AstTypeOrPack Parser::parseFunctionTypeAnnotation(bool allowPack) { incrementRecursionCounter("type annotation"); - bool monomorphic = lexer.current().type != '<'; + bool forceFunctionType = lexer.current().type == '<'; Lexeme begin = lexer.current(); @@ -1355,21 +1370,33 @@ AstTypeOrPack Parser::parseFunctionTypeAnnotation(bool allowPack) AstArray paramTypes = copy(params); + if (FFlag::LuauFixNamedFunctionParse && !names.empty()) + forceFunctionType = true; + bool returnTypeIntroducer = FFlag::LuauReturnTypeTokenConfusion ? lexer.current().type == Lexeme::SkinnyArrow || lexer.current().type == ':' : false; // Not a function at all. Just a parenthesized type. Or maybe a type pack with a single element - if (params.size() == 1 && !varargAnnotation && monomorphic && + if (params.size() == 1 && !varargAnnotation && !forceFunctionType && (FFlag::LuauReturnTypeTokenConfusion ? !returnTypeIntroducer : lexer.current().type != Lexeme::SkinnyArrow)) { + if (DFFlag::LuaReportParseWrongNamedType && !names.empty()) + lua_telemetry_parsed_named_non_function_type = true; + if (allowPack) return {{}, allocator.alloc(begin.location, AstTypeList{paramTypes, nullptr})}; else return {params[0], {}}; } - if ((FFlag::LuauReturnTypeTokenConfusion ? !returnTypeIntroducer : lexer.current().type != Lexeme::SkinnyArrow) && monomorphic && allowPack) + if ((FFlag::LuauReturnTypeTokenConfusion ? !returnTypeIntroducer : lexer.current().type != Lexeme::SkinnyArrow) && !forceFunctionType && + allowPack) + { + if (DFFlag::LuaReportParseWrongNamedType && !names.empty()) + lua_telemetry_parsed_named_non_function_type = true; + return {{}, allocator.alloc(begin.location, AstTypeList{paramTypes, varargAnnotation})}; + } AstArray> paramNames = copy(names); @@ -2010,7 +2037,63 @@ AstExpr* Parser::parseAssertionExpr() return expr; } -static bool parseNumber(double& result, const char* data) +static const char* parseInteger(double& result, const char* data, int base) +{ + char* end = nullptr; + unsigned long long value = strtoull(data, &end, base); + + if (value == ULLONG_MAX && errno == ERANGE) + { + // 'errno' might have been set before we called 'strtoull', but we don't want the overhead of resetting a TLS variable on each call + // so we only reset it when we get a result that might be an out-of-range error and parse again to make sure + errno = 0; + value = strtoull(data, &end, base); + + if (errno == ERANGE) + { + if (DFFlag::LuaReportParseIntegerIssues) + { + if (base == 2) + lua_telemetry_parsed_out_of_range_bin_integer = true; + else + lua_telemetry_parsed_out_of_range_hex_integer = true; + } + + if (FFlag::LuauErrorParseIntegerIssues) + return "Integer number value is out of range"; + } + } + + result = double(value); + return *end == 0 ? nullptr : "Malformed number"; +} + +static const char* parseNumber(double& result, const char* data) +{ + // binary literal + if (data[0] == '0' && (data[1] == 'b' || data[1] == 'B') && data[2]) + return parseInteger(result, data + 2, 2); + + // hexadecimal literal + if (data[0] == '0' && (data[1] == 'x' || data[1] == 'X') && data[2]) + { + if (DFFlag::LuaReportParseIntegerIssues && data[2] == '0' && (data[3] == 'x' || data[3] == 'X')) + lua_telemetry_parsed_double_prefix_hex_integer = true; + + if (FFlag::LuauErrorParseIntegerIssues) + return parseInteger(result, data, 16); // keep prefix, it's handled by 'strtoull' + else + return parseInteger(result, data + 2, 16); + } + + char* end = nullptr; + double value = strtod(data, &end); + + result = value; + return *end == 0 ? nullptr : "Malformed number"; +} + +static bool parseNumber_DEPRECATED(double& result, const char* data) { // binary literal if (data[0] == '0' && (data[1] == 'b' || data[1] == 'B') && data[2]) @@ -2080,18 +2163,37 @@ AstExpr* Parser::parseSimpleExpr() scratchData.erase(std::remove(scratchData.begin(), scratchData.end(), '_'), scratchData.end()); } - double value = 0; - if (parseNumber(value, scratchData.c_str())) + if (DFFlag::LuaReportParseIntegerIssues || FFlag::LuauErrorParseIntegerIssues) { - nextLexeme(); + double value = 0; + if (const char* error = parseNumber(value, scratchData.c_str())) + { + nextLexeme(); - return allocator.alloc(start, value); + return reportExprError(start, {}, "%s", error); + } + else + { + nextLexeme(); + + return allocator.alloc(start, value); + } } else { - nextLexeme(); + double value = 0; + if (parseNumber_DEPRECATED(value, scratchData.c_str())) + { + nextLexeme(); - return reportExprError(start, {}, "Malformed number"); + return allocator.alloc(start, value); + } + else + { + nextLexeme(); + + return reportExprError(start, {}, "Malformed number"); + } } } else if (lexer.current().type == Lexeme::RawString || lexer.current().type == Lexeme::QuotedString) diff --git a/Common/include/Luau/Bytecode.h b/Common/include/Luau/Bytecode.h index 218bb5d5d..0cb7e1d92 100644 --- a/Common/include/Luau/Bytecode.h +++ b/Common/include/Luau/Bytecode.h @@ -276,7 +276,7 @@ enum LuauOpcode // FORGLOOP: adjust loop variables for one iteration of a generic for loop, jump back to the loop header if loop needs to continue // A: target register; generic for loops assume a register layout [generator, state, index, variables...] // D: jump offset (-32768..32767) - // AUX: variable count (1..255) + // AUX: variable count (1..255) in the low 8 bits, high bit indicates whether to use ipairs-style traversal in the fast path // loop variables are adjusted by calling generator(state, index) and expecting it to return a tuple that's copied to the user variables // the first variable is then copied into index; generator/state are immutable, index isn't visible to user code LOP_FORGLOOP, @@ -490,6 +490,9 @@ enum LuauBuiltinFunction // select(_, ...) LBF_SELECT_VARARG, + + // rawlen + LBF_RAWLEN, }; // Capture type, used in LOP_CAPTURE diff --git a/Compiler/src/Builtins.cpp b/Compiler/src/Builtins.cpp index ff7531128..6bd24b6d1 100644 --- a/Compiler/src/Builtins.cpp +++ b/Compiler/src/Builtins.cpp @@ -4,6 +4,8 @@ #include "Luau/Bytecode.h" #include "Luau/Compiler.h" +LUAU_FASTFLAGVARIABLE(LuauCompileRawlen, false) + namespace Luau { namespace Compile @@ -58,6 +60,8 @@ int getBuiltinFunctionId(const Builtin& builtin, const CompileOptions& options) return LBF_RAWGET; if (builtin.isGlobal("rawequal")) return LBF_RAWEQUAL; + if (FFlag::LuauCompileRawlen && builtin.isGlobal("rawlen")) + return LBF_RAWLEN; if (builtin.isGlobal("unpack")) return LBF_TABLE_UNPACK; diff --git a/Compiler/src/BytecodeBuilder.cpp b/Compiler/src/BytecodeBuilder.cpp index 301cf255b..5e2669b2b 100644 --- a/Compiler/src/BytecodeBuilder.cpp +++ b/Compiler/src/BytecodeBuilder.cpp @@ -1302,20 +1302,22 @@ void BytecodeBuilder::validate() const case LOP_FORNPREP: case LOP_FORNLOOP: - VREG(LUAU_INSN_A(insn) + 2); // for loop protocol: A, A+1, A+2 are used for iteration + // for loop protocol: A, A+1, A+2 are used for iteration + VREG(LUAU_INSN_A(insn) + 2); VJUMP(LUAU_INSN_D(insn)); break; case LOP_FORGPREP: - VREG(LUAU_INSN_A(insn) + 2 + 1); // forg loop protocol: A, A+1, A+2 are used for iteration protocol; A+3, ... are loop variables + // forg loop protocol: A, A+1, A+2 are used for iteration protocol; A+3, ... are loop variables + VREG(LUAU_INSN_A(insn) + 2 + 1); VJUMP(LUAU_INSN_D(insn)); break; case LOP_FORGLOOP: - VREG( - LUAU_INSN_A(insn) + 2 + insns[i + 1]); // forg loop protocol: A, A+1, A+2 are used for iteration protocol; A+3, ... are loop variables + // forg loop protocol: A, A+1, A+2 are used for iteration protocol; A+3, ... are loop variables + VREG(LUAU_INSN_A(insn) + 2 + uint8_t(insns[i + 1])); VJUMP(LUAU_INSN_D(insn)); - LUAU_ASSERT(insns[i + 1] >= 1); + LUAU_ASSERT(uint8_t(insns[i + 1]) >= 1); break; case LOP_FORGPREP_INEXT: @@ -1679,7 +1681,8 @@ void BytecodeBuilder::dumpInstruction(const uint32_t* code, std::string& result, break; case LOP_FORGLOOP: - formatAppend(result, "FORGLOOP R%d L%d %d\n", LUAU_INSN_A(insn), targetLabel, *code++); + formatAppend(result, "FORGLOOP R%d L%d %d%s\n", LUAU_INSN_A(insn), targetLabel, uint8_t(*code), int(*code) < 0 ? " [inext]" : ""); + code++; break; case LOP_FORGPREP_INEXT: diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index e732256b9..d7c8155c7 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -23,6 +23,8 @@ LUAU_FASTINTVARIABLE(LuauCompileInlineThreshold, 25) LUAU_FASTINTVARIABLE(LuauCompileInlineThresholdMaxBoost, 300) LUAU_FASTINTVARIABLE(LuauCompileInlineDepth, 5) +LUAU_FASTFLAGVARIABLE(LuauCompileNoIpairs, false) + namespace Luau { @@ -2665,7 +2667,7 @@ struct Compiler if (builtin.isGlobal("ipairs")) // for .. in ipairs(t) { skipOp = LOP_FORGPREP_INEXT; - loopOp = LOP_FORGLOOP_INEXT; + loopOp = FFlag::LuauCompileNoIpairs ? LOP_FORGLOOP : LOP_FORGLOOP_INEXT; } else if (builtin.isGlobal("pairs")) // for .. in pairs(t) { @@ -2709,8 +2711,16 @@ struct Compiler bytecode.emitAD(loopOp, regs, 0); + if (FFlag::LuauCompileNoIpairs) + { + // TODO: remove loopOp as it's a constant now + LUAU_ASSERT(loopOp == LOP_FORGLOOP); + + // FORGLOOP uses aux to encode variable count and fast path flag for ipairs traversal in the high bit + bytecode.emitAux((skipOp == LOP_FORGPREP_INEXT ? 0x80000000 : 0) | uint32_t(stat->vars.size)); + } // note: FORGLOOP needs variable count encoded in AUX field, other loop instructions assume a fixed variable count - if (loopOp == LOP_FORGLOOP) + else if (loopOp == LOP_FORGLOOP) bytecode.emitAux(uint32_t(stat->vars.size)); size_t endLabel = bytecode.emitLabel(); @@ -3341,7 +3351,7 @@ struct Compiler std::vector upvals; }; - struct ReturnVisitor: AstVisitor + struct ReturnVisitor : AstVisitor { Compiler* self; bool returnsOne = true; diff --git a/VM/src/lbaselib.cpp b/VM/src/lbaselib.cpp index f79176111..4fc5033e3 100644 --- a/VM/src/lbaselib.cpp +++ b/VM/src/lbaselib.cpp @@ -11,6 +11,8 @@ #include #include +LUAU_FASTFLAG(LuauLenTM) + static void writestring(const char* s, size_t l) { fwrite(s, 1, l, stdout); @@ -178,6 +180,18 @@ static int luaB_rawset(lua_State* L) return 1; } +static int luaB_rawlen(lua_State* L) +{ + if (!FFlag::LuauLenTM) + luaL_error(L, "'rawlen' is not available"); + + int tt = lua_type(L, 1); + luaL_argcheck(L, tt == LUA_TTABLE || tt == LUA_TSTRING, 1, "table or string expected"); + int len = lua_objlen(L, 1); + lua_pushinteger(L, len); + return 1; +} + static int luaB_gcinfo(lua_State* L) { lua_pushinteger(L, lua_gc(L, LUA_GCCOUNT, 0)); @@ -428,6 +442,7 @@ static const luaL_Reg base_funcs[] = { {"rawequal", luaB_rawequal}, {"rawget", luaB_rawget}, {"rawset", luaB_rawset}, + {"rawlen", luaB_rawlen}, {"select", luaB_select}, {"setfenv", luaB_setfenv}, {"setmetatable", luaB_setmetatable}, diff --git a/VM/src/lbuiltins.cpp b/VM/src/lbuiltins.cpp index deaf14075..e98660a78 100644 --- a/VM/src/lbuiltins.cpp +++ b/VM/src/lbuiltins.cpp @@ -1117,6 +1117,27 @@ static int luauF_select(lua_State* L, StkId res, TValue* arg0, int nresults, Stk return -1; } +static int luauF_rawlen(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 1 && nresults <= 1) + { + if (ttistable(arg0)) + { + Table* h = hvalue(arg0); + setnvalue(res, double(luaH_getn(h))); + return 1; + } + else if (ttisstring(arg0)) + { + TString* ts = tsvalue(arg0); + setnvalue(res, double(ts->len)); + return 1; + } + } + + return -1; +} + luau_FastFunction luauF_table[256] = { NULL, luauF_assert, @@ -1188,4 +1209,6 @@ luau_FastFunction luauF_table[256] = { luauF_countrz, luauF_select, + + luauF_rawlen, }; diff --git a/VM/src/ldebug.h b/VM/src/ldebug.h index cf905e9fc..75bb8dcc5 100644 --- a/VM/src/ldebug.h +++ b/VM/src/ldebug.h @@ -19,6 +19,7 @@ LUAI_FUNC l_noret luaG_concaterror(lua_State* L, StkId p1, StkId p2); LUAI_FUNC l_noret luaG_aritherror(lua_State* L, const TValue* p1, const TValue* p2, TMS op); LUAI_FUNC l_noret luaG_ordererror(lua_State* L, const TValue* p1, const TValue* p2, TMS op); LUAI_FUNC l_noret luaG_indexerror(lua_State* L, const TValue* p1, const TValue* p2); + LUAI_FUNC LUA_PRINTF_ATTR(2, 3) l_noret luaG_runerrorL(lua_State* L, const char* fmt, ...); LUAI_FUNC void luaG_pusherror(lua_State* L, const char* error); diff --git a/VM/src/ltm.cpp b/VM/src/ltm.cpp index e7df4e533..49982b286 100644 --- a/VM/src/ltm.cpp +++ b/VM/src/ltm.cpp @@ -39,6 +39,7 @@ const char* const luaT_eventname[] = { "__namecall", "__call", "__iter", + "__len", "__eq", @@ -52,7 +53,6 @@ const char* const luaT_eventname[] = { "__unm", - "__len", "__lt", "__le", "__concat", diff --git a/VM/src/ltm.h b/VM/src/ltm.h index a5223941d..e11ddb3a5 100644 --- a/VM/src/ltm.h +++ b/VM/src/ltm.h @@ -18,6 +18,7 @@ typedef enum TM_NAMECALL, TM_CALL, TM_ITER, + TM_LEN, TM_EQ, /* last tag method with `fast' access */ @@ -31,7 +32,6 @@ typedef enum TM_UNM, - TM_LEN, TM_LT, TM_LE, TM_CONCAT, diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index e0a96474a..85829ca13 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -16,6 +16,8 @@ #include +LUAU_FASTFLAGVARIABLE(LuauLenTM, false) + // Disable c99-designator to avoid the warning in CGOTO dispatch table #ifdef __clang__ #if __has_warning("-Wc99-designator") @@ -2082,13 +2084,25 @@ static void luau_execute(lua_State* L) // fast-path #1: tables if (ttistable(rb)) { - setnvalue(ra, cast_num(luaH_getn(hvalue(rb)))); - VM_NEXT(); + Table* h = hvalue(rb); + + if (!FFlag::LuauLenTM || fastnotm(h->metatable, TM_LEN)) + { + setnvalue(ra, cast_num(luaH_getn(h))); + VM_NEXT(); + } + else + { + // slow-path, may invoke C/Lua via metamethods + VM_PROTECT(luaV_dolen(L, ra, rb)); + VM_NEXT(); + } } // fast-path #2: strings (not very important but easy to do) else if (ttisstring(rb)) { - setnvalue(ra, cast_num(tsvalue(rb)->len)); + TString* ts = tsvalue(rb); + setnvalue(ra, cast_num(ts->len)); VM_NEXT(); } else @@ -2226,6 +2240,15 @@ static void luau_execute(lua_State* L) VM_PROTECT(luaD_call(L, ra, 3)); L->top = L->ci->top; + + /* recompute ra since stack might have been reallocated */ + ra = VM_REG(LUAU_INSN_A(insn)); + + /* protect against __iter returning nil, since nil is used as a marker for builtin iteration in FORGLOOP */ + if (ttisnil(ra)) + { + VM_PROTECT(luaG_typeerror(L, ra, "call")); + } } else if (fasttm(L, mt, TM_CALL)) { @@ -2258,27 +2281,38 @@ static void luau_execute(lua_State* L) uint32_t aux = *pc; // fast-path: builtin table iteration - if (ttisnil(ra) && ttistable(ra + 1) && ttislightuserdata(ra + 2)) + // note: ra=nil guarantees ra+1=table and ra+2=userdata because of the setup by FORGPREP* opcodes + // TODO: remove the table check per guarantee above + if (ttisnil(ra) && ttistable(ra + 1)) { Table* h = hvalue(ra + 1); int index = int(reinterpret_cast(pvalue(ra + 2))); int sizearray = h->sizearray; - int sizenode = 1 << h->lsizenode; // clear extra variables since we might have more than two - if (LUAU_UNLIKELY(aux > 2)) + // note: while aux encodes ipairs bit, when set we always use 2 variables, so it's safe to check this via a signed comparison + if (LUAU_UNLIKELY(int(aux) > 2)) for (int i = 2; i < int(aux); ++i) setnilvalue(ra + 3 + i); + // terminate ipairs-style traversal early when encountering nil + if (int(aux) < 0 && (unsigned(index) >= unsigned(sizearray) || ttisnil(&h->array[index]))) + { + pc++; + VM_NEXT(); + } + // first we advance index through the array portion while (unsigned(index) < unsigned(sizearray)) { - if (!ttisnil(&h->array[index])) + TValue* e = &h->array[index]; + + if (!ttisnil(e)) { setpvalue(ra + 2, reinterpret_cast(uintptr_t(index + 1))); setnvalue(ra + 3, double(index + 1)); - setobj2s(L, ra + 4, &h->array[index]); + setobj2s(L, ra + 4, e); pc += LUAU_INSN_D(insn); LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); @@ -2288,6 +2322,8 @@ static void luau_execute(lua_State* L) index++; } + int sizenode = 1 << h->lsizenode; + // then we advance index through the hash portion while (unsigned(index - sizearray) < unsigned(sizenode)) { @@ -2321,7 +2357,7 @@ static void luau_execute(lua_State* L) L->top = ra + 3 + 3; /* func + 2 args (state and index) */ LUAU_ASSERT(L->top <= L->stack_last); - VM_PROTECT(luaD_call(L, ra + 3, aux)); + VM_PROTECT(luaD_call(L, ra + 3, uint8_t(aux))); L->top = L->ci->top; // recompute ra since stack might have been reallocated diff --git a/VM/src/lvmutils.cpp b/VM/src/lvmutils.cpp index 8a18a4d46..b9e762ebc 100644 --- a/VM/src/lvmutils.cpp +++ b/VM/src/lvmutils.cpp @@ -10,6 +10,9 @@ #include "lnumutils.h" #include +#include + +LUAU_FASTFLAG(LuauLenTM) /* limit for table tag-method chains (to avoid loops) */ #define MAXTAGLOOP 100 @@ -51,7 +54,7 @@ const float* luaV_tovector(const TValue* obj) return nullptr; } -static void callTMres(lua_State* L, StkId res, const TValue* f, const TValue* p1, const TValue* p2) +static StkId callTMres(lua_State* L, StkId res, const TValue* f, const TValue* p1, const TValue* p2) { ptrdiff_t result = savestack(L, res); // using stack room beyond top is technically safe here, but for very complicated reasons: @@ -71,6 +74,7 @@ static void callTMres(lua_State* L, StkId res, const TValue* f, const TValue* p1 res = restorestack(L, result); L->top--; setobjs2s(L, res, L->top); + return res; } static void callTM(lua_State* L, const TValue* f, const TValue* p1, const TValue* p2, const TValue* p3) @@ -472,22 +476,56 @@ void luaV_doarith(lua_State* L, StkId ra, const TValue* rb, const TValue* rc, TM void luaV_dolen(lua_State* L, StkId ra, const TValue* rb) { + if (!FFlag::LuauLenTM) + { + switch (ttype(rb)) + { + case LUA_TTABLE: + { + setnvalue(ra, cast_num(luaH_getn(hvalue(rb)))); + break; + } + case LUA_TSTRING: + { + setnvalue(ra, cast_num(tsvalue(rb)->len)); + break; + } + default: + { /* try metamethod */ + if (!call_binTM(L, rb, luaO_nilobject, ra, TM_LEN)) + luaG_typeerror(L, rb, "get length of"); + } + } + return; + } + + const TValue* tm = NULL; switch (ttype(rb)) { case LUA_TTABLE: { - setnvalue(ra, cast_num(luaH_getn(hvalue(rb)))); + Table* h = hvalue(rb); + if ((tm = fasttm(L, h->metatable, TM_LEN)) == NULL) + { + setnvalue(ra, cast_num(luaH_getn(h))); + return; + } break; } case LUA_TSTRING: { - setnvalue(ra, cast_num(tsvalue(rb)->len)); - break; + TString* ts = tsvalue(rb); + setnvalue(ra, cast_num(ts->len)); + return; } default: - { /* try metamethod */ - if (!call_binTM(L, rb, luaO_nilobject, ra, TM_LEN)) - luaG_typeerror(L, rb, "get length of"); - } + tm = luaT_gettmbyobj(L, rb, TM_LEN); } + + if (ttisnil(tm)) + luaG_typeerror(L, rb, "get length of"); + + StkId res = callTMres(L, ra, tm, rb, luaO_nilobject); + if (!ttisnumber(res)) + luaG_runerror(L, "'__len' must return a number"); /* note, we can't access rb since stack may have been reallocated */ } diff --git a/fuzz/protoprint.cpp b/fuzz/protoprint.cpp index 66a89f243..d4d522765 100644 --- a/fuzz/protoprint.cpp +++ b/fuzz/protoprint.cpp @@ -11,6 +11,7 @@ static const std::string kNames[] = { "__div", "__eq", "__index", + "__iter", "__le", "__len", "__lt", @@ -41,13 +42,18 @@ static const std::string kNames[] = { "ceil", "char", "charpattern", + "clamp", "clock", + "clone", + "close", "codepoint", "codes", "concat", "coroutine", "cos", "cosh", + "countlz", + "countrz", "create", "date", "debug", @@ -63,6 +69,7 @@ static const std::string kNames[] = { "foreachi", "format", "frexp", + "freeze", "function", "gcinfo", "getfenv", @@ -72,8 +79,10 @@ static const std::string kNames[] = { "gmatch", "gsub", "huge", + "info", "insert", "ipairs", + "isfrozen", "isyieldable", "ldexp", "len", @@ -93,6 +102,7 @@ static const std::string kNames[] = { "newproxy", "next", "nil", + "noise", "number", "offset", "os", @@ -121,6 +131,7 @@ static const std::string kNames[] = { "select", "setfenv", "setmetatable", + "sign", "sin", "sinh", "sort", diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index 655e48cb6..fafafd716 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -261,6 +261,8 @@ L1: RETURN R0 0 TEST_CASE("ForBytecode") { + ScopedFastFlag sff("LuauCompileNoIpairs", true); + // basic for loop: variable directly refers to internal iteration index (R2) CHECK_EQ("\n" + compileFunction0("for i=1,5 do print(i) end"), R"( LOADN R2 1 @@ -313,7 +315,7 @@ L0: GETIMPORT R5 3 MOVE R6 R3 MOVE R7 R4 CALL R5 2 0 -L1: FORGLOOP_INEXT R0 L0 +L1: FORGLOOP R0 L0 2 [inext] RETURN R0 0 )"); @@ -347,13 +349,15 @@ RETURN R0 0 TEST_CASE("ForBytecodeBuiltin") { + ScopedFastFlag sff("LuauCompileNoIpairs", true); + // we generally recognize builtins like pairs/ipairs and emit special opcodes CHECK_EQ("\n" + compileFunction0("for k,v in ipairs({}) do end"), R"( GETIMPORT R0 1 NEWTABLE R1 0 0 CALL R0 1 3 FORGPREP_INEXT R0 L0 -L0: FORGLOOP_INEXT R0 L0 +L0: FORGLOOP R0 L0 2 [inext] RETURN R0 0 )"); @@ -364,7 +368,7 @@ MOVE R1 R0 NEWTABLE R2 0 0 CALL R1 1 3 FORGPREP_INEXT R1 L0 -L0: FORGLOOP_INEXT R1 L0 +L0: FORGLOOP R1 L0 2 [inext] RETURN R0 0 )"); @@ -374,7 +378,7 @@ GETUPVAL R0 0 NEWTABLE R1 0 0 CALL R0 1 3 FORGPREP_INEXT R0 L0 -L0: FORGLOOP_INEXT R0 L0 +L0: FORGLOOP R0 L0 2 [inext] RETURN R0 0 )"); @@ -2107,6 +2111,8 @@ RETURN R3 -1 TEST_CASE("UpvaluesLoopsBytecode") { + ScopedFastFlag sff("LuauCompileNoIpairs", true); + CHECK_EQ("\n" + compileFunction(R"( function test() for i=1,10 do @@ -2169,7 +2175,7 @@ JUMPIFNOT R5 L1 CLOSEUPVALS R3 JUMP L3 L1: CLOSEUPVALS R3 -L2: FORGLOOP_INEXT R0 L0 +L2: FORGLOOP R0 L0 1 [inext] L3: LOADN R0 0 RETURN R0 1 )"); diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index 96a2775f6..3f415149c 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -231,6 +231,8 @@ TEST_CASE("Assert") TEST_CASE("Basic") { + ScopedFastFlag sff("LuauLenTM", true); + runConformance("basic.lua"); } @@ -301,6 +303,8 @@ TEST_CASE("Errors") TEST_CASE("Events") { + ScopedFastFlag sff("LuauLenTM", true); + runConformance("events.lua"); } @@ -475,6 +479,8 @@ static void populateRTTI(lua_State* L, Luau::TypeId type) TEST_CASE("Types") { + ScopedFastFlag sff("LuauCheckLenMT", true); + runConformance("types.lua", [](lua_State* L) { Luau::NullModuleResolver moduleResolver; Luau::InternalErrorReporter iceHandler; diff --git a/tests/ConstraintGraphBuilder.test.cpp b/tests/ConstraintGraphBuilder.test.cpp index 96b21613a..00c3309ca 100644 --- a/tests/ConstraintGraphBuilder.test.cpp +++ b/tests/ConstraintGraphBuilder.test.cpp @@ -17,7 +17,7 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "hello_world") )"); cgb.visit(block); - auto constraints = collectConstraints(cgb.rootScope); + auto constraints = collectConstraints(NotNull(cgb.rootScope)); REQUIRE(2 == constraints.size()); @@ -36,7 +36,7 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "primitives") )"); cgb.visit(block); - auto constraints = collectConstraints(cgb.rootScope); + auto constraints = collectConstraints(NotNull(cgb.rootScope)); REQUIRE(3 == constraints.size()); @@ -54,15 +54,15 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "nil_primitive") )"); cgb.visit(block); - auto constraints = collectConstraints(cgb.rootScope); + auto constraints = collectConstraints(NotNull(cgb.rootScope)); ToStringOptions opts; REQUIRE(5 <= constraints.size()); CHECK("*blocked-1* ~ gen () -> (a...)" == toString(*constraints[0], opts)); - CHECK("b ~ inst *blocked-1*" == toString(*constraints[1], opts)); - CHECK("() -> (c...) <: b" == toString(*constraints[2], opts)); - CHECK("c... <: d" == toString(*constraints[3], opts)); + CHECK("*blocked-2* ~ inst *blocked-1*" == toString(*constraints[1], opts)); + CHECK("() -> (b...) <: *blocked-2*" == toString(*constraints[2], opts)); + CHECK("b... <: c" == toString(*constraints[3], opts)); CHECK("nil <: a..." == toString(*constraints[4], opts)); } @@ -74,15 +74,15 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "function_application") )"); cgb.visit(block); - auto constraints = collectConstraints(cgb.rootScope); + auto constraints = collectConstraints(NotNull(cgb.rootScope)); REQUIRE(4 == constraints.size()); ToStringOptions opts; CHECK("string <: a" == toString(*constraints[0], opts)); - CHECK("b ~ inst a" == toString(*constraints[1], opts)); - CHECK("(string) -> (c...) <: b" == toString(*constraints[2], opts)); - CHECK("c... <: d" == toString(*constraints[3], opts)); + CHECK("*blocked-1* ~ inst a" == toString(*constraints[1], opts)); + CHECK("(string) -> (b...) <: *blocked-1*" == toString(*constraints[2], opts)); + CHECK("b... <: c" == toString(*constraints[3], opts)); } TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "local_function_definition") @@ -94,7 +94,7 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "local_function_definition") )"); cgb.visit(block); - auto constraints = collectConstraints(cgb.rootScope); + auto constraints = collectConstraints(NotNull(cgb.rootScope)); REQUIRE(2 == constraints.size()); @@ -112,15 +112,15 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "recursive_function") )"); cgb.visit(block); - auto constraints = collectConstraints(cgb.rootScope); + auto constraints = collectConstraints(NotNull(cgb.rootScope)); REQUIRE(4 == constraints.size()); ToStringOptions opts; CHECK("*blocked-1* ~ gen (a) -> (b...)" == toString(*constraints[0], opts)); - CHECK("c ~ inst (a) -> (b...)" == toString(*constraints[1], opts)); - CHECK("(a) -> (d...) <: c" == toString(*constraints[2], opts)); - CHECK("d... <: b..." == toString(*constraints[3], opts)); + CHECK("*blocked-2* ~ inst (a) -> (b...)" == toString(*constraints[1], opts)); + CHECK("(a) -> (c...) <: *blocked-2*" == toString(*constraints[2], opts)); + CHECK("c... <: b..." == toString(*constraints[3], opts)); } TEST_SUITE_END(); diff --git a/tests/ConstraintSolver.test.cpp b/tests/ConstraintSolver.test.cpp index 5959f55c1..f521c6675 100644 --- a/tests/ConstraintSolver.test.cpp +++ b/tests/ConstraintSolver.test.cpp @@ -9,7 +9,7 @@ using namespace Luau; -static TypeId requireBinding(Scope2* scope, const char* name) +static TypeId requireBinding(NotNull scope, const char* name) { auto b = linearSearchForBinding(scope, name); LUAU_ASSERT(b.has_value()); @@ -26,12 +26,13 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "hello") )"); cgb.visit(block); + NotNull rootScope = NotNull(cgb.rootScope); - ConstraintSolver cs{&arena, cgb.rootScope}; + ConstraintSolver cs{&arena, rootScope}; cs.run(); - TypeId bType = requireBinding(cgb.rootScope, "b"); + TypeId bType = requireBinding(rootScope, "b"); CHECK("number" == toString(bType)); } @@ -45,12 +46,13 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "generic_function") )"); cgb.visit(block); + NotNull rootScope = NotNull(cgb.rootScope); - ConstraintSolver cs{&arena, cgb.rootScope}; + ConstraintSolver cs{&arena, rootScope}; cs.run(); - TypeId idType = requireBinding(cgb.rootScope, "id"); + TypeId idType = requireBinding(rootScope, "id"); CHECK("(a) -> a" == toString(idType)); } @@ -71,14 +73,15 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "proper_let_generalization") )"); cgb.visit(block); + NotNull rootScope = NotNull(cgb.rootScope); ToStringOptions opts; - ConstraintSolver cs{&arena, cgb.rootScope}; + ConstraintSolver cs{&arena, rootScope}; cs.run(); - TypeId idType = requireBinding(cgb.rootScope, "b"); + TypeId idType = requireBinding(rootScope, "b"); CHECK("(a) -> number" == toString(idType, opts)); } diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index ac22f65b7..c92c44579 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -195,12 +195,15 @@ ParseResult Fixture::matchParseError(const std::string& source, const std::strin sourceModule.reset(new SourceModule); ParseResult result = Parser::parse(source.c_str(), source.length(), *sourceModule->names, *sourceModule->allocator, options); - REQUIRE_MESSAGE(!result.errors.empty(), "Expected a parse error in '" << source << "'"); + CHECK_MESSAGE(!result.errors.empty(), "Expected a parse error in '" << source << "'"); - CHECK_EQ(result.errors.front().getMessage(), message); + if (!result.errors.empty()) + { + CHECK_EQ(result.errors.front().getMessage(), message); - if (location) - CHECK_EQ(result.errors.front().getLocation(), *location); + if (location) + CHECK_EQ(result.errors.front().getLocation(), *location); + } return result; } @@ -213,11 +216,14 @@ ParseResult Fixture::matchParseErrorPrefix(const std::string& source, const std: sourceModule.reset(new SourceModule); ParseResult result = Parser::parse(source.c_str(), source.length(), *sourceModule->names, *sourceModule->allocator, options); - REQUIRE_MESSAGE(!result.errors.empty(), "Expected a parse error in '" << source << "'"); + CHECK_MESSAGE(!result.errors.empty(), "Expected a parse error in '" << source << "'"); - const std::string& message = result.errors.front().getMessage(); - CHECK_GE(message.length(), prefix.length()); - CHECK_EQ(prefix, message.substr(0, prefix.size())); + if (!result.errors.empty()) + { + const std::string& message = result.errors.front().getMessage(); + CHECK_GE(message.length(), prefix.length()); + CHECK_EQ(prefix, message.substr(0, prefix.size())); + } return result; } @@ -428,6 +434,7 @@ BuiltinsFixture::BuiltinsFixture(bool freeze, bool prepareAutocomplete) ConstraintGraphBuilderFixture::ConstraintGraphBuilderFixture() : Fixture() + , cgb(mainModuleName, &arena, NotNull(&ice), frontend.getGlobalScope2()) , forceTheFlag{"DebugLuauDeferredConstraintResolution", true} { BlockedTypeVar::nextIndex = 0; diff --git a/tests/Fixture.h b/tests/Fixture.h index 0e3735f67..1bc573da0 100644 --- a/tests/Fixture.h +++ b/tests/Fixture.h @@ -133,6 +133,7 @@ struct Fixture TestConfigResolver configResolver; std::unique_ptr sourceModule; Frontend frontend; + InternalErrorReporter ice; TypeChecker& typeChecker; std::string decorateWithTypes(const std::string& code); @@ -160,7 +161,7 @@ struct BuiltinsFixture : Fixture struct ConstraintGraphBuilderFixture : Fixture { TypeArena arena; - ConstraintGraphBuilder cgb{&arena}; + ConstraintGraphBuilder cgb; ScopedFastFlag forceTheFlag; diff --git a/tests/NotNull.test.cpp b/tests/NotNull.test.cpp index ed1c25ec2..e77ba78ac 100644 --- a/tests/NotNull.test.cpp +++ b/tests/NotNull.test.cpp @@ -30,23 +30,23 @@ struct Test int Test::count = 0; -} +} // namespace int foo(NotNull p) { return *p; } -void bar(int* q) -{} +void bar(int* q) {} TEST_SUITE_BEGIN("NotNull"); TEST_CASE("basic_stuff") { - NotNull a = NotNull{new int(55)}; // Does runtime test - NotNull b{new int(55)}; // As above - // NotNull c = new int(55); // Nope. Mildly regrettable, but implicit conversion from T* to NotNull in the general case is not good. + NotNull a = NotNull{new int(55)}; // Does runtime test + NotNull b{new int(55)}; // As above + // NotNull c = new int(55); // Nope. Mildly regrettable, but implicit conversion from T* to NotNull in the general case is not + // good. // a = nullptr; // nope diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index 878023e35..c3c759989 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -6,6 +6,8 @@ #include "doctest.h" +#include + using namespace Luau; namespace @@ -786,33 +788,46 @@ TEST_CASE_FIXTURE(Fixture, "parse_numbers_decimal") TEST_CASE_FIXTURE(Fixture, "parse_numbers_hexadecimal") { - AstStat* stat = parse("return 0xab, 0XAB05, 0xff_ff"); + AstStat* stat = parse("return 0xab, 0XAB05, 0xff_ff, 0xffffffffffffffff"); REQUIRE(stat != nullptr); AstStatReturn* str = stat->as()->body.data[0]->as(); - CHECK(str->list.size == 3); + CHECK(str->list.size == 4); CHECK_EQ(str->list.data[0]->as()->value, 0xab); CHECK_EQ(str->list.data[1]->as()->value, 0xAB05); CHECK_EQ(str->list.data[2]->as()->value, 0xFFFF); + CHECK_EQ(str->list.data[3]->as()->value, double(ULLONG_MAX)); } TEST_CASE_FIXTURE(Fixture, "parse_numbers_binary") { - AstStat* stat = parse("return 0b1, 0b0, 0b101010"); + AstStat* stat = parse("return 0b1, 0b0, 0b101010, 0b1111111111111111111111111111111111111111111111111111111111111111"); REQUIRE(stat != nullptr); AstStatReturn* str = stat->as()->body.data[0]->as(); - CHECK(str->list.size == 3); + CHECK(str->list.size == 4); CHECK_EQ(str->list.data[0]->as()->value, 1); CHECK_EQ(str->list.data[1]->as()->value, 0); CHECK_EQ(str->list.data[2]->as()->value, 42); + CHECK_EQ(str->list.data[3]->as()->value, double(ULLONG_MAX)); } TEST_CASE_FIXTURE(Fixture, "parse_numbers_error") { + ScopedFastFlag luauErrorParseIntegerIssues{"LuauErrorParseIntegerIssues", true}; + CHECK_EQ(getParseError("return 0b123"), "Malformed number"); CHECK_EQ(getParseError("return 123x"), "Malformed number"); CHECK_EQ(getParseError("return 0xg"), "Malformed number"); + CHECK_EQ(getParseError("return 0x0x123"), "Malformed number"); +} + +TEST_CASE_FIXTURE(Fixture, "parse_numbers_range_error") +{ + ScopedFastFlag luauErrorParseIntegerIssues{"LuauErrorParseIntegerIssues", true}; + + CHECK_EQ(getParseError("return 0x10000000000000000"), "Integer number value is out of range"); + CHECK_EQ(getParseError("return 0b10000000000000000000000000000000000000000000000000000000000000000"), "Integer number value is out of range"); } TEST_CASE_FIXTURE(Fixture, "break_return_not_last_error") @@ -2111,6 +2126,15 @@ type C = Packed<(number, X...)> REQUIRE(stat != nullptr); } +TEST_CASE_FIXTURE(Fixture, "invalid_type_forms") +{ + ScopedFastFlag luauFixNamedFunctionParse{"LuauFixNamedFunctionParse", true}; + + matchParseError("type A = (b: number)", "Expected '->' when parsing function type, got "); + matchParseError("type P = () -> T... type B = P<(x: number, y: string)>", "Expected '->' when parsing function type, got '>'"); + matchParseError("type F = (T...) -> ()", "Expected '->' when parsing function type, got '>'"); +} + TEST_SUITE_END(); TEST_SUITE_BEGIN("ParseErrorRecovery"); diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index e03069a9b..387e07cd0 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -409,6 +409,8 @@ TEST_CASE_FIXTURE(Fixture, "toStringDetailed") TEST_CASE_FIXTURE(BuiltinsFixture, "toStringDetailed2") { + ScopedFastFlag sff2{"DebugLuauSharedSelf", true}; + CheckResult result = check(R"( local base = {} function base:one() return 1 end @@ -424,7 +426,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "toStringDetailed2") TypeId tType = requireType("inst"); ToStringResult r = toStringDetailed(tType); - CHECK_EQ("{ @metatable { __index: { @metatable { __index: base }, child } }, inst }", r.name); + CHECK_EQ("{ @metatable { __index: { @metatable {| __index: base |}, child } }, inst }", r.name); CHECK_EQ(0, r.nameMap.typeVars.size()); ToStringOptions opts; @@ -455,11 +457,10 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "toStringDetailed2") std::string twoResult = toString(tMeta6->props["two"].type, opts); - REQUIRE_EQ("(a) -> number", oneResult.name); - REQUIRE_EQ("(b) -> number", twoResult); + CHECK_EQ("(a) -> number", oneResult.name); + CHECK_EQ("(b) -> number", twoResult); } - TEST_CASE_FIXTURE(Fixture, "toStringErrorPack") { CheckResult result = check(R"( @@ -688,6 +689,10 @@ TEST_CASE_FIXTURE(Fixture, "pick_distinct_names_for_mixed_explicit_and_implicit_ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_include_self_param") { + ScopedFastFlag sff[]{ + {"DebugLuauSharedSelf", true}, + }; + CheckResult result = check(R"( local foo = {} function foo:method(arg: string): () @@ -701,9 +706,12 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_include_self_param") CHECK_EQ("foo:method(self: a, arg: string): ()", toStringNamedFunction("foo:method", *ftv)); } - TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_hide_self_param") { + ScopedFastFlag sff[]{ + {"DebugLuauSharedSelf", true}, + }; + CheckResult result = check(R"( local foo = {} function foo:method(arg: string): () diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index d6f0a0c8e..bdd4d6fd4 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -73,24 +73,24 @@ TEST_CASE_FIXTURE(Fixture, "cannot_steal_hoisted_type_alias") if (FFlag::DebugLuauDeferredConstraintResolution) { CHECK(result.errors[0] == TypeError{ - Location{{1, 21}, {1, 26}}, - getMainSourceModule()->name, - TypeMismatch{ - getSingletonTypes().numberType, - getSingletonTypes().stringType, - }, - }); + Location{{1, 21}, {1, 26}}, + getMainSourceModule()->name, + TypeMismatch{ + getSingletonTypes().numberType, + getSingletonTypes().stringType, + }, + }); } else { CHECK(result.errors[0] == TypeError{ - Location{{1, 8}, {1, 26}}, - getMainSourceModule()->name, - TypeMismatch{ - getSingletonTypes().numberType, - getSingletonTypes().stringType, - }, - }); + Location{{1, 8}, {1, 26}}, + getMainSourceModule()->name, + TypeMismatch{ + getSingletonTypes().numberType, + getSingletonTypes().stringType, + }, + }); } } @@ -716,6 +716,10 @@ TEST_CASE_FIXTURE(Fixture, "forward_declared_alias_is_not_clobbered_by_prior_uni TEST_CASE_FIXTURE(Fixture, "forward_declared_alias_is_not_clobbered_by_prior_unification_with_any_2") { + ScopedFastFlag sff[] = { + {"DebugLuauSharedSelf", true}, + }; + CheckResult result = check(R"( local B = {} B.bar = 4 @@ -737,7 +741,8 @@ TEST_CASE_FIXTURE(Fixture, "forward_declared_alias_is_not_clobbered_by_prior_uni type FutureIntersection = A & B )"); - LUAU_REQUIRE_NO_ERRORS(result); + // TODO: shared self causes this test to break in bizarre ways. + LUAU_REQUIRE_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "recursive_types_restriction_ok") diff --git a/tests/TypeInfer.annotations.test.cpp b/tests/TypeInfer.annotations.test.cpp index 3e2ad6dc1..8a86ee5fd 100644 --- a/tests/TypeInfer.annotations.test.cpp +++ b/tests/TypeInfer.annotations.test.cpp @@ -134,13 +134,13 @@ TEST_CASE_FIXTURE(Fixture, "unknown_type_reference_generates_error") LUAU_REQUIRE_ERROR_COUNT(1, result); CHECK(result.errors[0] == TypeError{ - Location{{1, 17}, {1, 28}}, - getMainSourceModule()->name, - UnknownSymbol{ - "IDoNotExist", - UnknownSymbol::Context::Type, - }, - }); + Location{{1, 17}, {1, 28}}, + getMainSourceModule()->name, + UnknownSymbol{ + "IDoNotExist", + UnknownSymbol::Context::Type, + }, + }); } TEST_CASE_FIXTURE(Fixture, "typeof_variable_type_annotation_should_return_its_type") diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index 036a667a4..401a6c640 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -37,6 +37,27 @@ TEST_CASE_FIXTURE(Fixture, "check_function_bodies") }})); } +TEST_CASE_FIXTURE(Fixture, "cannot_hoist_interior_defns_into_signature") +{ + // This test verifies that the signature does not have access to types + // declared within the body. Under DCR, if the function's inner scope + // encompasses the entire function expression, it would be possible for this + // to type check (but the solver output is somewhat undefined). This test + // ensures that this isn't the case. + CheckResult result = check(R"( + local function f(x: T) + type T = number + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(result.errors[0] == TypeError{Location{{1, 28}, {1, 29}}, getMainSourceModule()->name, + UnknownSymbol{ + "T", + UnknownSymbol::Context::Type, + }}); +} + TEST_CASE_FIXTURE(Fixture, "infer_return_type") { CheckResult result = check("function take_five() return 5 end"); diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index 97ba0808b..e9e94cfbc 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -271,13 +271,16 @@ TEST_CASE_FIXTURE(Fixture, "infer_nested_generic_function") TEST_CASE_FIXTURE(Fixture, "infer_generic_methods") { + ScopedFastFlag sff{"DebugLuauSharedSelf", true}; + CheckResult result = check(R"( local x = {} function x:id(x) return x end function x:f(): string return self:id("hello") end function x:g(): number return self:id(37) end )"); - LUAU_REQUIRE_NO_ERRORS(result); + // TODO: Quantification should be doing the conversion, not normalization. + LUAU_REQUIRE_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "calling_self_generic_methods") diff --git a/tests/TypeInfer.operators.test.cpp b/tests/TypeInfer.operators.test.cpp index fd9b1dd41..e6174df2e 100644 --- a/tests/TypeInfer.operators.test.cpp +++ b/tests/TypeInfer.operators.test.cpp @@ -461,6 +461,61 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "typecheck_unary_minus") REQUIRE_EQ(gen->message, "Unary operator '-' not supported by type 'bar'"); } +TEST_CASE_FIXTURE(BuiltinsFixture, "typecheck_unary_minus_error") +{ + CheckResult result = check(R"( + --!strict + local foo = { + value = 10 + } + + local mt = {} + setmetatable(foo, mt) + + mt.__unm = function(val: boolean): string + return "test" + end + + local a = -foo + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ("string", toString(requireType("a"))); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE_EQ(*tm->wantedType, *typeChecker.booleanType); + // given type is the typeof(foo) which is complex to compare against +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "typecheck_unary_len_error") +{ + ScopedFastFlag sff("LuauCheckLenMT", true); + + CheckResult result = check(R"( + --!strict + local foo = { + value = 10 + } + local mt = {} + setmetatable(foo, mt) + + mt.__len = function(val: any): string + return "test" + end + + local a = #foo + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ("number", toString(requireType("a"))); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE_EQ(*tm->wantedType, *typeChecker.numberType); + REQUIRE_EQ(*tm->givenType, *typeChecker.stringType); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "unary_not_is_boolean") { CheckResult result = check(R"( diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index 487e5979d..059aed2e9 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -499,4 +499,26 @@ TEST_CASE_FIXTURE(Fixture, "constrained_is_level_dependent") CHECK_EQ("(t1) -> {| [t1]: boolean |} where t1 = t2 ; t2 = {+ m1: (t1) -> (a...), m2: (t2) -> (b...) +}", toString(requireType("f"))); } +TEST_CASE_FIXTURE(BuiltinsFixture, "greedy_inference_with_shared_self_triggers_function_with_no_returns") +{ + ScopedFastFlag sff{"DebugLuauSharedSelf", true}; + + CheckResult result = check(R"( + local T = {} + T.__index = T + + function T.new() + local self = setmetatable({}, T) + return self:ctor() or self + end + + function T:ctor() + -- oops, no return! + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Not all codepaths in this function return '{ @metatable T, {| |} }, a...'.", toString(result.errors[0])); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 77a2928c9..eead5b30a 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -1863,6 +1863,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "quantifying_a_bound_var_works") TEST_CASE_FIXTURE(BuiltinsFixture, "less_exponential_blowup_please") { + ScopedFastFlag sff{"DebugLuauSharedSelf", true}; + CheckResult result = check(R"( --!strict @@ -1890,7 +1892,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "less_exponential_blowup_please") newData:First() )"); - LUAU_REQUIRE_ERROR_COUNT(1, result); + LUAU_REQUIRE_ERROR_COUNT(2, result); } TEST_CASE_FIXTURE(Fixture, "common_table_element_union_in_call") @@ -2868,6 +2870,7 @@ TEST_CASE_FIXTURE(Fixture, "inferred_return_type_of_free_table") { ScopedFastFlag sff[] = { {"LuauLowerBoundsCalculation", true}, + {"DebugLuauSharedSelf", true}, }; check(R"( @@ -2887,7 +2890,7 @@ TEST_CASE_FIXTURE(Fixture, "inferred_return_type_of_free_table") end )"); - CHECK_EQ("(t1) -> {| Byte: (b) -> (a...), PeekByte: (c) -> (a...) |} where t1 = {+ byte: (t1, number) -> (a...) +}", + CHECK_EQ("(t1) -> {| Byte: (a) -> (b...), PeekByte: (a) -> (b...) |} where t1 = {+ byte: (t1, number) -> (b...) +}", toString(requireType("Base64FileReader"))); } @@ -2904,6 +2907,66 @@ TEST_CASE_FIXTURE(Fixture, "mixed_tables_with_implicit_numbered_keys") CHECK_EQ("Type 'number' could not be converted into 'string'", toString(result.errors[2])); } +TEST_CASE_FIXTURE(Fixture, "shared_selfs") +{ + ScopedFastFlag sff{"DebugLuauSharedSelf", true}; + + CheckResult result = check(R"( + local t = {} + t.x = 5 + + function t:m1() return self.x end + function t:m2() return self.y end + + return t + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + ToStringOptions opts; + opts.exhaustive = true; + CHECK_EQ("{| m1: ({+ x: a, y: b +}) -> a, m2: ({+ x: a, y: b +}) -> b, x: number |}", toString(requireType("t"), opts)); +} + +TEST_CASE_FIXTURE(Fixture, "shared_selfs_from_free_param") +{ + ScopedFastFlag sff{"DebugLuauSharedSelf", true}; + + CheckResult result = check(R"( + local function f(t) + function t:m1() return self.x end + function t:m2() return self.y end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("({+ m1: ({+ x: a, y: b +}) -> a, m2: ({+ x: a, y: b +}) -> b +}) -> ()", toString(requireType("f"))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "shared_selfs_through_metatables") +{ + ScopedFastFlag sff{"DebugLuauSharedSelf", true}; + + CheckResult result = check(R"( + local t = {} + t.__index = t + setmetatable({}, t) + + function t:m1() return self.x end + function t:m2() return self.y end + + return t + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + ToStringOptions opts; + opts.exhaustive = true; + CHECK_EQ( + toString(requireType("t"), opts), "t1 where t1 = {| __index: t1, m1: ({+ x: a, y: b +}) -> a, m2: ({+ x: a, y: b +}) -> b |}"); +} + TEST_CASE_FIXTURE(Fixture, "expected_indexer_value_type_extra") { CheckResult result = check(R"( @@ -2953,4 +3016,58 @@ TEST_CASE_FIXTURE(Fixture, "prop_access_on_unions_of_indexers_where_key_whose_ty CHECK_EQ("Type '{number} | {| [boolean]: number |}' does not have key 'x'", toString(result.errors[0])); } +TEST_CASE_FIXTURE(BuiltinsFixture, "quantify_metatables_of_metatables_of_table") +{ + ScopedFastFlag sff[]{ + {"DebugLuauSharedSelf", true}, + }; + + CheckResult result = check(R"( + local T = {} + + function T:m() + return self.x, self.y + end + + function T:n() + end + + local U = setmetatable({}, {__index = T}) + + local V = setmetatable({}, {__index = U}) + + return V + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + ToStringOptions opts; + opts.exhaustive = true; + CHECK_EQ(toString(requireType("V"), opts), "{ @metatable { __index: { @metatable { __index: {| m: ({+ x: a, y: b +}) -> (a, b), n: ({+ x: a, y: b +}) -> () |} }, { } } }, { } }"); +} + +TEST_CASE_FIXTURE(Fixture, "quantify_even_that_table_was_never_exported_at_all") +{ + ScopedFastFlag sff{"DebugLuauSharedSelf", true}; + + CheckResult result = check(R"( + local T = {} + + function T:m() + return self.x + end + + function T:n() + return self.y + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + ToStringOptions opts; + opts.exhaustive = true; + CHECK_EQ("{| m: ({+ x: a, y: b +}) -> a, n: ({+ x: a, y: b +}) -> b |}", toString(requireType("T"), opts)); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 6a048b264..efdfe0b1d 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -369,14 +369,14 @@ TEST_CASE_FIXTURE(Fixture, "globals_are_banned_in_strict_mode") CHECK_EQ("foo", us->name); } -TEST_CASE_FIXTURE(BuiltinsFixture, "correctly_scope_locals_do") +TEST_CASE_FIXTURE(Fixture, "correctly_scope_locals_do") { CheckResult result = check(R"( do local a = 1 end - print(a) -- oops! + local b = a -- oops! )"); LUAU_REQUIRE_ERROR_COUNT(1, result); diff --git a/tests/conformance/basic.lua b/tests/conformance/basic.lua index f803c3193..385a04504 100644 --- a/tests/conformance/basic.lua +++ b/tests/conformance/basic.lua @@ -118,10 +118,12 @@ assert((function() return #_G end)() == 0) assert((function() return #{1,2} end)() == 2) assert((function() return #'g' end)() == 1) -assert((function() local ud = newproxy(true) getmetatable(ud).__len = function() return 42 end return #ud end)() == 42) - assert((function() local a = 1 a = -a return a end)() == -1) +-- __len metamethod +assert((function() local ud = newproxy(true) getmetatable(ud).__len = function() return 42 end return #ud end)() == 42) +assert((function() local t = {} setmetatable(t, { __len = function() return 42 end }) return #t end)() == 42) + -- while/repeat assert((function() local a = 10 local b = 1 while a > 1 do b = b * 2 a = a - 1 end return b end)() == 512) assert((function() local a = 10 local b = 1 repeat b = b * 2 a = a - 1 until a == 1 return b end)() == 512) @@ -889,6 +891,10 @@ assert((function() return table.concat(res, ',') end)() == "6,8,10") +-- typeof and type require an argument +assert(pcall(typeof) == false) +assert(pcall(type) == false) + -- typeof == type in absence of custom userdata assert(concat(typeof(5), typeof(nil), typeof({}), typeof(newproxy())) == "number,nil,table,userdata") diff --git a/tests/conformance/events.lua b/tests/conformance/events.lua index 32e090acc..42f1beda2 100644 --- a/tests/conformance/events.lua +++ b/tests/conformance/events.lua @@ -386,4 +386,42 @@ do assert(t.X) -- fails if table flags are set incorrectly end +do + -- verify __len behavior & error handling + local t = {1} + + setmetatable(t, {}) + assert(#t == 1) + + setmetatable(t, { __len = rawlen }) + assert(#t == 1) + + setmetatable(t, { __len = function() return 42 end }) + assert(#t == 42) + + setmetatable(t, { __len = 42 }) + local ok, err = pcall(function() return #t end) + assert(not ok and err:match("attempt to call a number value")) + + setmetatable(t, { __len = function() end }) + local ok, err = pcall(function() return #t end) + assert(not ok and err:match("'__len' must return a number")) + + setmetatable(t, { __len = error }) + local ok, err = pcall(function() return #t end) + assert(not ok and err == t) +end + +-- verify rawlen behavior +do + local t = {1} + setmetatable(t, { __len = 42 }) + + assert(rawlen(t) == 1) + assert(rawlen("foo") == 3) + + local ok, err = pcall(function() return rawlen(42) end) + assert(not ok and err:match("table or string expected")) +end + return 'OK'