diff --git a/Analysis/include/Luau/Connective.h b/Analysis/include/Luau/Connective.h new file mode 100644 index 000000000..c9daa0f9e --- /dev/null +++ b/Analysis/include/Luau/Connective.h @@ -0,0 +1,68 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Def.h" +#include "Luau/TypedAllocator.h" +#include "Luau/TypeVar.h" +#include "Luau/Variant.h" + +#include + +namespace Luau +{ + +struct Negation; +struct Conjunction; +struct Disjunction; +struct Equivalence; +struct Proposition; +using Connective = Variant; +using ConnectiveId = Connective*; // Can and most likely is nullptr. + +struct Negation +{ + ConnectiveId connective; +}; + +struct Conjunction +{ + ConnectiveId lhs; + ConnectiveId rhs; +}; + +struct Disjunction +{ + ConnectiveId lhs; + ConnectiveId rhs; +}; + +struct Equivalence +{ + ConnectiveId lhs; + ConnectiveId rhs; +}; + +struct Proposition +{ + DefId def; + TypeId discriminantTy; +}; + +template +const T* get(ConnectiveId connective) +{ + return get_if(connective); +} + +struct ConnectiveArena +{ + TypedAllocator allocator; + + ConnectiveId negation(ConnectiveId connective); + ConnectiveId conjunction(ConnectiveId lhs, ConnectiveId rhs); + ConnectiveId disjunction(ConnectiveId lhs, ConnectiveId rhs); + ConnectiveId equivalence(ConnectiveId lhs, ConnectiveId rhs); + ConnectiveId proposition(DefId def, TypeId discriminantTy); +}; + +} // namespace Luau diff --git a/Analysis/include/Luau/Constraint.h b/Analysis/include/Luau/Constraint.h index 7f092f5b2..4370d0cf4 100644 --- a/Analysis/include/Luau/Constraint.h +++ b/Analysis/include/Luau/Constraint.h @@ -132,15 +132,16 @@ struct HasPropConstraint std::string prop; }; -struct RefinementConstraint +// result ~ if isSingleton D then ~D else unknown where D = discriminantType +struct SingletonOrTopTypeConstraint { - DefId def; + TypeId resultType; TypeId discriminantType; }; using ConstraintV = Variant; + HasPropConstraint, SingletonOrTopTypeConstraint>; struct Constraint { diff --git a/Analysis/include/Luau/ConstraintGraphBuilder.h b/Analysis/include/Luau/ConstraintGraphBuilder.h index 6106717c5..cb5900ea9 100644 --- a/Analysis/include/Luau/ConstraintGraphBuilder.h +++ b/Analysis/include/Luau/ConstraintGraphBuilder.h @@ -2,6 +2,7 @@ #pragma once #include "Luau/Ast.h" +#include "Luau/Connective.h" #include "Luau/Constraint.h" #include "Luau/DataFlowGraphBuilder.h" #include "Luau/Module.h" @@ -26,11 +27,13 @@ struct DcrLogger; struct Inference { TypeId ty = nullptr; + ConnectiveId connective = nullptr; Inference() = default; - explicit Inference(TypeId ty) + explicit Inference(TypeId ty, ConnectiveId connective = nullptr) : ty(ty) + , connective(connective) { } }; @@ -38,11 +41,13 @@ struct Inference struct InferencePack { TypePackId tp = nullptr; + std::vector connectives; InferencePack() = default; - explicit InferencePack(TypePackId tp) + explicit InferencePack(TypePackId tp, const std::vector& connectives = {}) : tp(tp) + , connectives(connectives) { } }; @@ -73,6 +78,7 @@ struct ConstraintGraphBuilder // Defining scopes for AST nodes. DenseHashMap astTypeAliasDefiningScopes{nullptr}; NotNull dfg; + ConnectiveArena connectiveArena; int recursionCount = 0; @@ -126,6 +132,8 @@ struct ConstraintGraphBuilder */ NotNull addConstraint(const ScopePtr& scope, std::unique_ptr c); + void applyRefinements(const ScopePtr& scope, Location location, ConnectiveId connective); + /** * The entry point to the ConstraintGraphBuilder. This will construct a set * of scopes, constraints, and free types that can be solved later. @@ -167,10 +175,10 @@ struct ConstraintGraphBuilder * surrounding context. Used to implement bidirectional type checking. * @return the type of the expression. */ - Inference check(const ScopePtr& scope, AstExpr* expr, std::optional expectedType = {}); + Inference check(const ScopePtr& scope, AstExpr* expr, std::optional expectedType = {}, bool forceSingleton = false); - Inference check(const ScopePtr& scope, AstExprConstantString* string, std::optional expectedType); - Inference check(const ScopePtr& scope, AstExprConstantBool* bool_, std::optional expectedType); + Inference check(const ScopePtr& scope, AstExprConstantString* string, std::optional expectedType, bool forceSingleton); + Inference check(const ScopePtr& scope, AstExprConstantBool* bool_, std::optional expectedType, bool forceSingleton); Inference check(const ScopePtr& scope, AstExprLocal* local); Inference check(const ScopePtr& scope, AstExprGlobal* global); Inference check(const ScopePtr& scope, AstExprIndexName* indexName); @@ -180,6 +188,7 @@ struct ConstraintGraphBuilder Inference check(const ScopePtr& scope, AstExprIfElse* ifElse, std::optional expectedType); Inference check(const ScopePtr& scope, AstExprTypeAssertion* typeAssert); Inference check(const ScopePtr& scope, AstExprTable* expr, std::optional expectedType); + std::tuple checkBinary(const ScopePtr& scope, AstExprBinary* binary, std::optional expectedType); TypePackId checkLValues(const ScopePtr& scope, AstArray exprs); diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index 5cc63e656..07f027ad2 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -110,7 +110,7 @@ struct ConstraintSolver bool tryDispatch(const FunctionCallConstraint& c, NotNull constraint); bool tryDispatch(const PrimitiveTypeConstraint& c, NotNull constraint); bool tryDispatch(const HasPropConstraint& c, NotNull constraint); - bool tryDispatch(const RefinementConstraint& c, NotNull constraint); + bool tryDispatch(const SingletonOrTopTypeConstraint& c, NotNull constraint); // for a, ... in some_table do // also handles __iter metamethod diff --git a/Analysis/include/Luau/Normalize.h b/Analysis/include/Luau/Normalize.h index f98442dd1..b28c06a58 100644 --- a/Analysis/include/Luau/Normalize.h +++ b/Analysis/include/Luau/Normalize.h @@ -17,10 +17,8 @@ struct SingletonTypes; using ModulePtr = std::shared_ptr; -bool isSubtype( - TypeId subTy, TypeId superTy, NotNull scope, NotNull singletonTypes, InternalErrorReporter& ice, bool anyIsTop = true); -bool isSubtype(TypePackId subTy, TypePackId superTy, NotNull scope, NotNull singletonTypes, InternalErrorReporter& ice, - bool anyIsTop = true); +bool isSubtype(TypeId subTy, TypeId superTy, NotNull scope, NotNull singletonTypes, InternalErrorReporter& ice); +bool isSubtype(TypePackId subTy, TypePackId superTy, NotNull scope, NotNull singletonTypes, InternalErrorReporter& ice); class TypeIds { @@ -169,12 +167,26 @@ struct NormalizedStringType bool isSubtype(const NormalizedStringType& subStr, const NormalizedStringType& superStr); -// A normalized function type is either `never` (represented by `nullopt`) +// A normalized function type can be `never`, the top function type `function`, // or an intersection of function types. -// NOTE: type normalization can fail on function types with generics -// (e.g. because we do not support unions and intersections of generic type packs), -// so this type may contain `error`. -using NormalizedFunctionType = std::optional; +// +// NOTE: type normalization can fail on function types with generics (e.g. +// because we do not support unions and intersections of generic type packs), so +// this type may contain `error`. +struct NormalizedFunctionType +{ + NormalizedFunctionType(); + + bool isTop = false; + // TODO: Remove this wrapping optional when clipping + // FFlagLuauNegatedFunctionTypes. + std::optional parts; + + void resetToNever(); + void resetToTop(); + + bool isNever() const; +}; // A normalized generic/free type is a union, where each option is of the form (X & T) where // * X is either a free type or a generic @@ -234,12 +246,14 @@ struct NormalizedType NormalizedType(NotNull singletonTypes); - NormalizedType(const NormalizedType&) = delete; - NormalizedType(NormalizedType&&) = default; NormalizedType() = delete; ~NormalizedType() = default; + + NormalizedType(const NormalizedType&) = delete; + NormalizedType& operator=(const NormalizedType&) = delete; + + NormalizedType(NormalizedType&&) = default; NormalizedType& operator=(NormalizedType&&) = default; - NormalizedType& operator=(NormalizedType&) = delete; }; class Normalizer @@ -291,7 +305,7 @@ class Normalizer bool unionNormalWithTy(NormalizedType& here, TypeId there, int ignoreSmallerTyvars = -1); // ------- Negations - NormalizedType negateNormal(const NormalizedType& here); + std::optional negateNormal(const NormalizedType& here); TypeIds negateAll(const TypeIds& theres); TypeId negate(TypeId there); void subtractPrimitive(NormalizedType& here, TypeId ty); diff --git a/Analysis/include/Luau/TypeUtils.h b/Analysis/include/Luau/TypeUtils.h index 7409dbe74..085ee21b0 100644 --- a/Analysis/include/Luau/TypeUtils.h +++ b/Analysis/include/Luau/TypeUtils.h @@ -35,7 +35,7 @@ std::vector flatten(TypeArena& arena, NotNull singletonT * identity) types. * @param types the input type list to reduce. * @returns the reduced type list. -*/ + */ std::vector reduceUnion(const std::vector& types); /** @@ -45,7 +45,7 @@ std::vector reduceUnion(const std::vector& types); * @param arena the type arena to allocate the new type in, if necessary * @param ty the type to remove nil from * @returns a type with nil removed, or nil itself if that were the only option. -*/ + */ TypeId stripNil(NotNull singletonTypes, TypeArena& arena, TypeId ty); } // namespace Luau diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index 70c12cb9d..0ab4d4749 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -115,6 +115,7 @@ struct PrimitiveTypeVar Number, String, Thread, + Function, }; Type type; @@ -504,14 +505,6 @@ struct NeverTypeVar { }; -// Invariant 1: there should never be a reason why such UseTypeVar exists without it mapping to another type. -// Invariant 2: UseTypeVar should always disappear across modules. -struct UseTypeVar -{ - DefId def; - NotNull scope; -}; - // ~T // TODO: Some simplification step that overwrites the type graph to make sure negation // types disappear from the user's view, and (?) a debug flag to disable that @@ -522,9 +515,9 @@ struct NegationTypeVar using ErrorTypeVar = Unifiable::Error; -using TypeVariant = Unifiable::Variant; +using TypeVariant = + Unifiable::Variant; struct TypeVar final { @@ -644,13 +637,14 @@ struct SingletonTypes const TypeId stringType; const TypeId booleanType; const TypeId threadType; + const TypeId functionType; const TypeId trueType; const TypeId falseType; const TypeId anyType; const TypeId unknownType; const TypeId neverType; const TypeId errorType; - const TypeId falsyType; // No type binding! + const TypeId falsyType; // No type binding! const TypeId truthyType; // No type binding! const TypePackId anyTypePack; diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index 7bf4d50b7..b5f58d3c6 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -61,7 +61,6 @@ struct Unifier ErrorVec errors; Location location; Variance variance = Covariant; - bool anyIsTop = false; // If true, we consider any to be a top type. If false, it is a familiar but weird mix of top and bottom all at once. bool normalize; // Normalize unions and intersections if necessary bool useScopes = false; // If true, we use the scope hierarchy rather than TypeLevels CountMismatch::Context ctx = CountMismatch::Arg; @@ -131,6 +130,7 @@ struct Unifier Unifier makeChildUnifier(); void reportError(TypeError err); + LUAU_NOINLINE void reportError(Location location, TypeErrorData data); private: bool isNonstrictMode() const; diff --git a/Analysis/include/Luau/Variant.h b/Analysis/include/Luau/Variant.h index 76812c9bf..016c51f62 100644 --- a/Analysis/include/Luau/Variant.h +++ b/Analysis/include/Luau/Variant.h @@ -58,13 +58,15 @@ class Variant constexpr int tid = getTypeId(); typeId = tid; - new (&storage) TT(value); + new (&storage) TT(std::forward(value)); } Variant(const Variant& other) { + static constexpr FnCopy table[sizeof...(Ts)] = {&fnCopy...}; + typeId = other.typeId; - tableCopy[typeId](&storage, &other.storage); + table[typeId](&storage, &other.storage); } Variant(Variant&& other) @@ -192,7 +194,6 @@ class Variant return *static_cast(lhs) == *static_cast(rhs); } - static constexpr FnCopy tableCopy[sizeof...(Ts)] = {&fnCopy...}; static constexpr FnMove tableMove[sizeof...(Ts)] = {&fnMove...}; static constexpr FnDtor tableDtor[sizeof...(Ts)] = {&fnDtor...}; diff --git a/Analysis/include/Luau/VisitTypeVar.h b/Analysis/include/Luau/VisitTypeVar.h index d4f8528ff..3dcddba19 100644 --- a/Analysis/include/Luau/VisitTypeVar.h +++ b/Analysis/include/Luau/VisitTypeVar.h @@ -155,10 +155,6 @@ struct GenericTypeVarVisitor { return visit(ty); } - virtual bool visit(TypeId ty, const UseTypeVar& utv) - { - return visit(ty); - } virtual bool visit(TypeId ty, const NegationTypeVar& ntv) { return visit(ty); @@ -321,8 +317,6 @@ struct GenericTypeVarVisitor traverse(a); } } - else if (auto utv = get(ty)) - visit(ty, *utv); else if (auto ntv = get(ty)) visit(ty, *ntv); else if (!FFlag::LuauCompleteVisitor) diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index 6051e117a..ee53ae6b4 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -714,8 +714,7 @@ static bool dcrMagicFunctionPack(MagicFunctionCallContext context) result = arena->addType(UnionTypeVar{std::move(options)}); TypeId numberType = context.solver->singletonTypes->numberType; - TypeId packedTable = arena->addType( - TableTypeVar{{{"n", {numberType}}}, TableIndexer(numberType, result), {}, TableState::Sealed}); + TypeId packedTable = arena->addType(TableTypeVar{{{"n", {numberType}}}, TableIndexer(numberType, result), {}, TableState::Sealed}); TypePackId tableTypePack = arena->addTypePack({packedTable}); asMutable(context.result)->ty.emplace(tableTypePack); diff --git a/Analysis/src/Clone.cpp b/Analysis/src/Clone.cpp index 85408919b..86e1c7fc9 100644 --- a/Analysis/src/Clone.cpp +++ b/Analysis/src/Clone.cpp @@ -62,7 +62,6 @@ struct TypeCloner void operator()(const LazyTypeVar& t); void operator()(const UnknownTypeVar& t); void operator()(const NeverTypeVar& t); - void operator()(const UseTypeVar& t); void operator()(const NegationTypeVar& t); }; @@ -338,12 +337,6 @@ void TypeCloner::operator()(const NeverTypeVar& t) defaultClone(t); } -void TypeCloner::operator()(const UseTypeVar& t) -{ - TypeId result = dest.addType(BoundTypeVar{follow(typeId)}); - seenTypes[typeId] = result; -} - void TypeCloner::operator()(const NegationTypeVar& t) { TypeId result = dest.addType(AnyTypeVar{}); diff --git a/Analysis/src/Connective.cpp b/Analysis/src/Connective.cpp new file mode 100644 index 000000000..114b5f2f7 --- /dev/null +++ b/Analysis/src/Connective.cpp @@ -0,0 +1,32 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Connective.h" + +namespace Luau +{ + +ConnectiveId ConnectiveArena::negation(ConnectiveId connective) +{ + return NotNull{allocator.allocate(Negation{connective})}; +} + +ConnectiveId ConnectiveArena::conjunction(ConnectiveId lhs, ConnectiveId rhs) +{ + return NotNull{allocator.allocate(Conjunction{lhs, rhs})}; +} + +ConnectiveId ConnectiveArena::disjunction(ConnectiveId lhs, ConnectiveId rhs) +{ + return NotNull{allocator.allocate(Disjunction{lhs, rhs})}; +} + +ConnectiveId ConnectiveArena::equivalence(ConnectiveId lhs, ConnectiveId rhs) +{ + return NotNull{allocator.allocate(Equivalence{lhs, rhs})}; +} + +ConnectiveId ConnectiveArena::proposition(DefId def, TypeId discriminantTy) +{ + return NotNull{allocator.allocate(Proposition{def, discriminantTy})}; +} + +} // namespace Luau diff --git a/Analysis/src/ConstraintGraphBuilder.cpp b/Analysis/src/ConstraintGraphBuilder.cpp index 455fc221d..79a69ca47 100644 --- a/Analysis/src/ConstraintGraphBuilder.cpp +++ b/Analysis/src/ConstraintGraphBuilder.cpp @@ -107,6 +107,101 @@ NotNull ConstraintGraphBuilder::addConstraint(const ScopePtr& scope, return NotNull{scope->constraints.emplace_back(std::move(c)).get()}; } +static void unionRefinements(const std::unordered_map& lhs, const std::unordered_map& rhs, + std::unordered_map& dest, NotNull arena) +{ + for (auto [def, ty] : lhs) + { + auto rhsIt = rhs.find(def); + if (rhsIt == rhs.end()) + continue; + + std::vector discriminants{{ty, rhsIt->second}}; + + if (auto destIt = dest.find(def); destIt != dest.end()) + discriminants.push_back(destIt->second); + + dest[def] = arena->addType(UnionTypeVar{std::move(discriminants)}); + } +} + +static void computeRefinement(const ScopePtr& scope, ConnectiveId connective, std::unordered_map* refis, bool sense, + NotNull arena, bool eq, std::vector* constraints) +{ + using RefinementMap = std::unordered_map; + + if (!connective) + return; + else if (auto negation = get(connective)) + return computeRefinement(scope, negation->connective, refis, !sense, arena, eq, constraints); + else if (auto conjunction = get(connective)) + { + RefinementMap lhsRefis; + RefinementMap rhsRefis; + + computeRefinement(scope, conjunction->lhs, sense ? refis : &lhsRefis, sense, arena, eq, constraints); + computeRefinement(scope, conjunction->rhs, sense ? refis : &rhsRefis, sense, arena, eq, constraints); + + if (!sense) + unionRefinements(lhsRefis, rhsRefis, *refis, arena); + } + else if (auto disjunction = get(connective)) + { + RefinementMap lhsRefis; + RefinementMap rhsRefis; + + computeRefinement(scope, disjunction->lhs, sense ? &lhsRefis : refis, sense, arena, eq, constraints); + computeRefinement(scope, disjunction->rhs, sense ? &rhsRefis : refis, sense, arena, eq, constraints); + + if (sense) + unionRefinements(lhsRefis, rhsRefis, *refis, arena); + } + else if (auto equivalence = get(connective)) + { + computeRefinement(scope, equivalence->lhs, refis, sense, arena, true, constraints); + computeRefinement(scope, equivalence->rhs, refis, sense, arena, true, constraints); + } + else if (auto proposition = get(connective)) + { + TypeId discriminantTy = proposition->discriminantTy; + if (!sense && !eq) + discriminantTy = arena->addType(NegationTypeVar{proposition->discriminantTy}); + else if (!sense && eq) + { + discriminantTy = arena->addType(BlockedTypeVar{}); + constraints->push_back(SingletonOrTopTypeConstraint{discriminantTy, proposition->discriminantTy}); + } + + if (auto it = refis->find(proposition->def); it != refis->end()) + (*refis)[proposition->def] = arena->addType(IntersectionTypeVar{{discriminantTy, it->second}}); + else + (*refis)[proposition->def] = discriminantTy; + } +} + +void ConstraintGraphBuilder::applyRefinements(const ScopePtr& scope, Location location, ConnectiveId connective) +{ + if (!connective) + return; + + std::unordered_map refinements; + std::vector constraints; + computeRefinement(scope, connective, &refinements, /*sense*/ true, arena, /*eq*/ false, &constraints); + + for (auto [def, discriminantTy] : refinements) + { + std::optional defTy = scope->lookup(def); + if (!defTy) + ice->ice("Every DefId must map to a type!"); + + TypeId resultTy = arena->addType(IntersectionTypeVar{{*defTy, discriminantTy}}); + scope->dcrRefinements[def] = resultTy; + } + + for (auto& c : constraints) + addConstraint(scope, location, c); +} + void ConstraintGraphBuilder::visit(AstStatBlock* block) { LUAU_ASSERT(scopes.empty()); @@ -250,14 +345,33 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local) if (value->is()) { - // HACK: we leave nil-initialized things floating under the assumption that they will later be populated. - // See the test TypeInfer/infer_locals_with_nil_value. - // Better flow awareness should make this obsolete. + // HACK: we leave nil-initialized things floating under the + // assumption that they will later be populated. + // + // See the test TypeInfer/infer_locals_with_nil_value. Better flow + // awareness should make this obsolete. if (!varTypes[i]) varTypes[i] = freshType(scope); } - else if (i == local->values.size - 1) + // Only function calls and vararg expressions can produce packs. All + // other expressions produce exactly one value. + else if (i != local->values.size - 1 || (!value->is() && !value->is())) + { + std::optional expectedType; + if (hasAnnotation) + expectedType = varTypes.at(i); + + TypeId exprType = check(scope, value, expectedType).ty; + if (i < varTypes.size()) + { + if (varTypes[i]) + addConstraint(scope, local->location, SubtypeConstraint{exprType, varTypes[i]}); + else + varTypes[i] = exprType; + } + } + else { std::vector expectedTypes; if (hasAnnotation) @@ -286,21 +400,6 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local) addConstraint(scope, local->location, PackSubtypeConstraint{exprPack, tailPack}); } } - else - { - std::optional expectedType; - if (hasAnnotation) - expectedType = varTypes.at(i); - - TypeId exprType = check(scope, value, expectedType).ty; - if (i < varTypes.size()) - { - if (varTypes[i]) - addConstraint(scope, local->location, SubtypeConstraint{varTypes[i], exprType}); - else - varTypes[i] = exprType; - } - } } for (size_t i = 0; i < local->vars.size; ++i) @@ -569,14 +668,16 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatIf* ifStatement // TODO: Optimization opportunity, the interior scope of the condition could be // reused for the then body, so we don't need to refine twice. ScopePtr condScope = childScope(ifStatement->condition, scope); - check(condScope, ifStatement->condition, std::nullopt); + auto [_, connective] = check(condScope, ifStatement->condition, std::nullopt); ScopePtr thenScope = childScope(ifStatement->thenbody, scope); + applyRefinements(thenScope, Location{}, connective); visit(thenScope, ifStatement->thenbody); if (ifStatement->elsebody) { ScopePtr elseScope = childScope(ifStatement->elsebody, scope); + applyRefinements(elseScope, Location{}, connectiveArena.negation(connective)); visit(elseScope, ifStatement->elsebody); } } @@ -925,7 +1026,7 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa } } -Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr, std::optional expectedType) +Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr, std::optional expectedType, bool forceSingleton) { RecursionCounter counter{&recursionCount}; @@ -938,13 +1039,13 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr, st Inference result; if (auto group = expr->as()) - result = check(scope, group->expr, expectedType); + result = check(scope, group->expr, expectedType, forceSingleton); else if (auto stringExpr = expr->as()) - result = check(scope, stringExpr, expectedType); + result = check(scope, stringExpr, expectedType, forceSingleton); else if (expr->is()) result = Inference{singletonTypes->numberType}; else if (auto boolExpr = expr->as()) - result = check(scope, boolExpr, expectedType); + result = check(scope, boolExpr, expectedType, forceSingleton); else if (expr->is()) result = Inference{singletonTypes->nilType}; else if (auto local = expr->as()) @@ -999,8 +1100,11 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr, st return result; } -Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprConstantString* string, std::optional expectedType) +Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprConstantString* string, std::optional expectedType, bool forceSingleton) { + if (forceSingleton) + return Inference{arena->addType(SingletonTypeVar{StringSingleton{std::string{string->value.data, string->value.size}}})}; + if (expectedType) { const TypeId expectedTy = follow(*expectedType); @@ -1020,12 +1124,15 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprConstantSt return Inference{singletonTypes->stringType}; } -Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprConstantBool* boolExpr, std::optional expectedType) +Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprConstantBool* boolExpr, std::optional expectedType, bool forceSingleton) { + const TypeId singletonType = boolExpr->value ? singletonTypes->trueType : singletonTypes->falseType; + if (forceSingleton) + return Inference{singletonType}; + if (expectedType) { const TypeId expectedTy = follow(*expectedType); - const TypeId singletonType = boolExpr->value ? singletonTypes->trueType : singletonTypes->falseType; if (get(expectedTy) || get(expectedTy)) { @@ -1045,8 +1152,8 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprConstantBo Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprLocal* local) { std::optional resultTy; - - if (auto def = dfg->getDef(local)) + auto def = dfg->getDef(local); + if (def) resultTy = scope->lookup(*def); if (!resultTy) @@ -1058,7 +1165,10 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprLocal* loc if (!resultTy) return Inference{singletonTypes->errorRecoveryType()}; // TODO: replace with ice, locals should never exist before its definition. - return Inference{*resultTy}; + if (def) + return Inference{*resultTy, connectiveArena.proposition(*def, singletonTypes->truthyType)}; + else + return Inference{*resultTy}; } Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprGlobal* global) @@ -1107,20 +1217,23 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexExpr* Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprUnary* unary) { - TypeId operandType = check(scope, unary->expr).ty; + auto [operandType, connective] = check(scope, unary->expr); TypeId resultType = arena->addType(BlockedTypeVar{}); addConstraint(scope, unary->location, UnaryConstraint{unary->op, operandType, resultType}); - return Inference{resultType}; + + if (unary->op == AstExprUnary::Not) + return Inference{resultType, connectiveArena.negation(connective)}; + else + return Inference{resultType}; } Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprBinary* binary, std::optional expectedType) { - TypeId leftType = check(scope, binary->left, expectedType).ty; - TypeId rightType = check(scope, binary->right, expectedType).ty; + auto [leftType, rightType, connective] = checkBinary(scope, binary, expectedType); TypeId resultType = arena->addType(BlockedTypeVar{}); addConstraint(scope, binary->location, BinaryConstraint{binary->op, leftType, rightType, resultType}); - return Inference{resultType}; + return Inference{resultType, std::move(connective)}; } Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIfElse* ifElse, std::optional expectedType) @@ -1147,6 +1260,58 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTypeAssert return Inference{resolveType(scope, typeAssert->annotation)}; } +std::tuple ConstraintGraphBuilder::checkBinary( + const ScopePtr& scope, AstExprBinary* binary, std::optional expectedType) +{ + if (binary->op == AstExprBinary::And) + { + auto [leftType, leftConnective] = check(scope, binary->left, expectedType); + + ScopePtr rightScope = childScope(binary->right, scope); + applyRefinements(rightScope, binary->right->location, leftConnective); + auto [rightType, rightConnective] = check(rightScope, binary->right, expectedType); + + return {leftType, rightType, connectiveArena.conjunction(leftConnective, rightConnective)}; + } + else if (binary->op == AstExprBinary::Or) + { + auto [leftType, leftConnective] = check(scope, binary->left, expectedType); + + ScopePtr rightScope = childScope(binary->right, scope); + applyRefinements(rightScope, binary->right->location, connectiveArena.negation(leftConnective)); + auto [rightType, rightConnective] = check(rightScope, binary->right, expectedType); + + return {leftType, rightType, connectiveArena.disjunction(leftConnective, rightConnective)}; + } + else if (binary->op == AstExprBinary::CompareEq || binary->op == AstExprBinary::CompareNe) + { + TypeId leftType = check(scope, binary->left, expectedType, true).ty; + TypeId rightType = check(scope, binary->right, expectedType, true).ty; + + ConnectiveId leftConnective = nullptr; + if (auto def = dfg->getDef(binary->left)) + leftConnective = connectiveArena.proposition(*def, rightType); + + ConnectiveId rightConnective = nullptr; + if (auto def = dfg->getDef(binary->right)) + rightConnective = connectiveArena.proposition(*def, leftType); + + if (binary->op == AstExprBinary::CompareNe) + { + leftConnective = connectiveArena.negation(leftConnective); + rightConnective = connectiveArena.negation(rightConnective); + } + + return {leftType, rightType, connectiveArena.equivalence(leftConnective, rightConnective)}; + } + else + { + TypeId leftType = check(scope, binary->left, expectedType).ty; + TypeId rightType = check(scope, binary->right, expectedType).ty; + return {leftType, rightType, nullptr}; + } +} + TypePackId ConstraintGraphBuilder::checkLValues(const ScopePtr& scope, AstArray exprs) { std::vector types; @@ -1841,9 +2006,13 @@ std::vector> ConstraintGraphBuilder:: Inference ConstraintGraphBuilder::flattenPack(const ScopePtr& scope, Location location, InferencePack pack) { - auto [tp] = pack; + const auto& [tp, connectives] = pack; + ConnectiveId connective = nullptr; + if (!connectives.empty()) + connective = connectives[0]; + if (auto f = first(tp)) - return Inference{*f}; + return Inference{*f, connective}; TypeId typeResult = freshType(scope); TypePack onePack{{typeResult}, freshTypePack(scope)}; @@ -1851,7 +2020,7 @@ Inference ConstraintGraphBuilder::flattenPack(const ScopePtr& scope, Location lo addConstraint(scope, location, PackSubtypeConstraint{tp, oneTypePack}); - return Inference{typeResult}; + return Inference{typeResult, connective}; } void ConstraintGraphBuilder::reportError(Location location, TypeErrorData err) diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index 5e43be0f8..c53ac659a 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -440,8 +440,8 @@ bool ConstraintSolver::tryDispatch(NotNull constraint, bool fo success = tryDispatch(*fcc, constraint); else if (auto hpc = get(*constraint)) success = tryDispatch(*hpc, constraint); - else if (auto rc = get(*constraint)) - success = tryDispatch(*rc, constraint); + else if (auto sottc = get(*constraint)) + success = tryDispatch(*sottc, constraint); else LUAU_ASSERT(false); @@ -1274,25 +1274,18 @@ bool ConstraintSolver::tryDispatch(const HasPropConstraint& c, NotNull constraint) +bool ConstraintSolver::tryDispatch(const SingletonOrTopTypeConstraint& c, NotNull constraint) { - // TODO: Figure out exact details on when refinements need to be blocked. - // It's possible that it never needs to be, since we can just use intersection types with the discriminant type? - - if (!constraint->scope->parent) - iceReporter.ice("No parent scope"); - - std::optional previousTy = constraint->scope->parent->lookup(c.def); - if (!previousTy) - iceReporter.ice("No previous type"); + if (isBlocked(c.discriminantType)) + return false; - std::optional useTy = constraint->scope->lookup(c.def); - if (!useTy) - iceReporter.ice("The def is not bound to a type"); + TypeId followed = follow(c.discriminantType); - TypeId resultTy = follow(*useTy); - std::vector parts{*previousTy, c.discriminantType}; - asMutable(resultTy)->ty.emplace(std::move(parts)); + // `nil` is a singleton type too! There's only one value of type `nil`. + if (get(followed) || isNil(followed)) + *asMutable(c.resultType) = NegationTypeVar{c.discriminantType}; + else + *asMutable(c.resultType) = BoundTypeVar{singletonTypes->unknownType}; return true; } diff --git a/Analysis/src/EmbeddedBuiltinDefinitions.cpp b/Analysis/src/EmbeddedBuiltinDefinitions.cpp index 67abbff1f..339de9755 100644 --- a/Analysis/src/EmbeddedBuiltinDefinitions.cpp +++ b/Analysis/src/EmbeddedBuiltinDefinitions.cpp @@ -13,16 +13,16 @@ declare bit32: { bor: (...number) -> number, bxor: (...number) -> number, btest: (number, ...number) -> boolean, - rrotate: (number, number) -> number, - lrotate: (number, number) -> number, - lshift: (number, number) -> number, - arshift: (number, number) -> number, - rshift: (number, number) -> number, - bnot: (number) -> number, - extract: (number, number, number?) -> number, - replace: (number, number, number, number?) -> number, - countlz: (number) -> number, - countrz: (number) -> number, + rrotate: (x: number, disp: number) -> number, + lrotate: (x: number, disp: number) -> number, + lshift: (x: number, disp: number) -> number, + arshift: (x: number, disp: number) -> number, + rshift: (x: number, disp: number) -> number, + bnot: (x: number) -> number, + extract: (n: number, field: number, width: number?) -> number, + replace: (n: number, v: number, field: number, width: number?) -> number, + countlz: (n: number) -> number, + countrz: (n: number) -> number, } declare math: { @@ -93,9 +93,9 @@ type DateTypeResult = { } declare os: { - time: (DateTypeArg?) -> number, - date: (string?, number?) -> DateTypeResult | string, - difftime: (DateTypeResult | number, DateTypeResult | number) -> number, + time: (time: DateTypeArg?) -> number, + date: (formatString: string?, time: number?) -> DateTypeResult | string, + difftime: (t2: DateTypeResult | number, t1: DateTypeResult | number) -> number, clock: () -> number, } @@ -145,51 +145,51 @@ declare function loadstring(src: string, chunkname: string?): (((A...) -> declare function newproxy(mt: boolean?): any declare coroutine: { - create: ((A...) -> R...) -> thread, - resume: (thread, A...) -> (boolean, R...), + create: (f: (A...) -> R...) -> thread, + resume: (co: thread, A...) -> (boolean, R...), running: () -> thread, - status: (thread) -> "dead" | "running" | "normal" | "suspended", + status: (co: thread) -> "dead" | "running" | "normal" | "suspended", -- FIXME: This technically returns a function, but we can't represent this yet. - wrap: ((A...) -> R...) -> any, + wrap: (f: (A...) -> R...) -> any, yield: (A...) -> R..., isyieldable: () -> boolean, - close: (thread) -> (boolean, any) + close: (co: thread) -> (boolean, any) } declare table: { - concat: ({V}, string?, number?, number?) -> string, - insert: (({V}, V) -> ()) & (({V}, number, V) -> ()), - maxn: ({V}) -> number, - remove: ({V}, number?) -> V?, - sort: ({V}, ((V, V) -> boolean)?) -> (), - create: (number, V?) -> {V}, - find: ({V}, V, number?) -> number?, - - unpack: ({V}, number?, number?) -> ...V, + concat: (t: {V}, sep: string?, i: number?, j: number?) -> string, + insert: ((t: {V}, value: V) -> ()) & ((t: {V}, pos: number, value: V) -> ()), + maxn: (t: {V}) -> number, + remove: (t: {V}, number?) -> V?, + sort: (t: {V}, comp: ((V, V) -> boolean)?) -> (), + create: (count: number, value: V?) -> {V}, + find: (haystack: {V}, needle: V, init: number?) -> number?, + + unpack: (list: {V}, i: number?, j: number?) -> ...V, pack: (...V) -> { n: number, [number]: V }, - getn: ({V}) -> number, - foreach: ({[K]: V}, (K, V) -> ()) -> (), + getn: (t: {V}) -> number, + foreach: (t: {[K]: V}, f: (K, V) -> ()) -> (), foreachi: ({V}, (number, V) -> ()) -> (), - move: ({V}, number, number, number, {V}?) -> {V}, - clear: ({[K]: V}) -> (), + move: (src: {V}, a: number, b: number, t: number, dst: {V}?) -> {V}, + clear: (table: {[K]: V}) -> (), - isfrozen: ({[K]: V}) -> boolean, + isfrozen: (t: {[K]: V}) -> boolean, } declare debug: { - info: ((thread, number, string) -> R...) & ((number, string) -> R...) & (((A...) -> R1..., string) -> R2...), - traceback: ((string?, number?) -> string) & ((thread, string?, number?) -> string), + info: ((thread: thread, level: number, options: string) -> R...) & ((level: number, options: string) -> R...) & ((func: (A...) -> R1..., options: string) -> R2...), + traceback: ((message: string?, level: number?) -> string) & ((thread: thread, message: string?, level: number?) -> string), } declare utf8: { char: (...number) -> string, charpattern: string, - codes: (string) -> ((string, number) -> (number, number), string, number), - codepoint: (string, number?, number?) -> ...number, - len: (string, number?, number?) -> (number?, number?), - offset: (string, number?, number?) -> number, + codes: (str: string) -> ((string, number) -> (number, number), string, number), + codepoint: (str: string, i: number?, j: number?) -> ...number, + len: (s: string, i: number?, j: number?) -> (number?, number?), + offset: (s: string, n: number?, i: number?) -> number, } -- Cannot use `typeof` here because it will produce a polytype when we expect a monotype. diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index 5ef4b7e7c..21e9f7874 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -7,9 +7,9 @@ #include "Luau/Clone.h" #include "Luau/Common.h" +#include "Luau/RecursionCounter.h" #include "Luau/TypeVar.h" #include "Luau/Unifier.h" -#include "Luau/VisitTypeVar.h" LUAU_FASTFLAGVARIABLE(DebugLuauCopyBeforeNormalizing, false) LUAU_FASTFLAGVARIABLE(DebugLuauCheckNormalizeInvariant, false) @@ -20,6 +20,7 @@ LUAU_FASTINTVARIABLE(LuauNormalizeCacheLimit, 100000); LUAU_FASTFLAGVARIABLE(LuauNormalizeCombineTableFix, false); LUAU_FASTFLAGVARIABLE(LuauTypeNormalization2, false); LUAU_FASTFLAGVARIABLE(LuauNegatedStringSingletons, false); +LUAU_FASTFLAGVARIABLE(LuauNegatedFunctionTypes, false); LUAU_FASTFLAG(LuauUnknownAndNeverType) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAG(LuauOverloadedFunctionSubtypingPerf); @@ -206,6 +207,28 @@ bool isSubtype(const NormalizedStringType& subStr, const NormalizedStringType& s return true; } +NormalizedFunctionType::NormalizedFunctionType() + : parts(FFlag::LuauNegatedFunctionTypes ? std::optional{TypeIds{}} : std::nullopt) +{ +} + +void NormalizedFunctionType::resetToTop() +{ + isTop = true; + parts.emplace(); +} + +void NormalizedFunctionType::resetToNever() +{ + isTop = false; + parts.emplace(); +} + +bool NormalizedFunctionType::isNever() const +{ + return !isTop && (!parts || parts->empty()); +} + NormalizedType::NormalizedType(NotNull singletonTypes) : tops(singletonTypes->neverType) , booleans(singletonTypes->neverType) @@ -220,8 +243,8 @@ NormalizedType::NormalizedType(NotNull singletonTypes) static bool isInhabited(const NormalizedType& norm) { return !get(norm.tops) || !get(norm.booleans) || !norm.classes.empty() || !get(norm.errors) || - !get(norm.nils) || !get(norm.numbers) || !norm.strings.isNever() || - !get(norm.threads) || norm.functions || !norm.tables.empty() || !norm.tyvars.empty(); + !get(norm.nils) || !get(norm.numbers) || !norm.strings.isNever() || !get(norm.threads) || + !norm.functions.isNever() || !norm.tables.empty() || !norm.tyvars.empty(); } static int tyvarIndex(TypeId ty) @@ -317,10 +340,14 @@ static bool isNormalizedThread(TypeId ty) static bool areNormalizedFunctions(const NormalizedFunctionType& tys) { - if (tys) - for (TypeId ty : *tys) + if (tys.parts) + { + for (TypeId ty : *tys.parts) + { if (!get(ty) && !get(ty)) return false; + } + } return true; } @@ -420,7 +447,7 @@ void Normalizer::clearNormal(NormalizedType& norm) norm.strings.resetToNever(); norm.threads = singletonTypes->neverType; norm.tables.clear(); - norm.functions = std::nullopt; + norm.functions.resetToNever(); norm.tyvars.clear(); } @@ -809,20 +836,28 @@ std::optional Normalizer::unionOfFunctions(TypeId here, TypeId there) void Normalizer::unionFunctions(NormalizedFunctionType& heres, const NormalizedFunctionType& theres) { - if (!theres) + if (FFlag::LuauNegatedFunctionTypes) + { + if (heres.isTop) + return; + if (theres.isTop) + heres.resetToTop(); + } + + if (theres.isNever()) return; TypeIds tmps; - if (!heres) + if (heres.isNever()) { - tmps.insert(theres->begin(), theres->end()); - heres = std::move(tmps); + tmps.insert(theres.parts->begin(), theres.parts->end()); + heres.parts = std::move(tmps); return; } - for (TypeId here : *heres) - for (TypeId there : *theres) + for (TypeId here : *heres.parts) + for (TypeId there : *theres.parts) { if (std::optional fun = unionOfFunctions(here, there)) tmps.insert(*fun); @@ -830,28 +865,28 @@ void Normalizer::unionFunctions(NormalizedFunctionType& heres, const NormalizedF tmps.insert(singletonTypes->errorRecoveryType(there)); } - heres = std::move(tmps); + heres.parts = std::move(tmps); } void Normalizer::unionFunctionsWithFunction(NormalizedFunctionType& heres, TypeId there) { - if (!heres) + if (heres.isNever()) { TypeIds tmps; tmps.insert(there); - heres = std::move(tmps); + heres.parts = std::move(tmps); return; } TypeIds tmps; - for (TypeId here : *heres) + for (TypeId here : *heres.parts) { if (std::optional fun = unionOfFunctions(here, there)) tmps.insert(*fun); else tmps.insert(singletonTypes->errorRecoveryType(there)); } - heres = std::move(tmps); + heres.parts = std::move(tmps); } void Normalizer::unionTablesWithTable(TypeIds& heres, TypeId there) @@ -1004,6 +1039,11 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignor here.strings.resetToString(); else if (ptv->type == PrimitiveTypeVar::Thread) here.threads = there; + else if (ptv->type == PrimitiveTypeVar::Function) + { + LUAU_ASSERT(FFlag::LuauNegatedFunctionTypes); + here.functions.resetToTop(); + } else LUAU_ASSERT(!"Unreachable"); } @@ -1036,8 +1076,11 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignor else if (const NegationTypeVar* ntv = get(there)) { const NormalizedType* thereNormal = normalize(ntv->ty); - NormalizedType tn = negateNormal(*thereNormal); - if (!unionNormals(here, tn)) + std::optional tn = negateNormal(*thereNormal); + if (!tn) + return false; + + if (!unionNormals(here, *tn)) return false; } else @@ -1053,7 +1096,7 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignor // ------- Negations -NormalizedType Normalizer::negateNormal(const NormalizedType& here) +std::optional Normalizer::negateNormal(const NormalizedType& here) { NormalizedType result{singletonTypes}; if (!get(here.tops)) @@ -1092,10 +1135,24 @@ NormalizedType Normalizer::negateNormal(const NormalizedType& here) result.threads = get(here.threads) ? singletonTypes->threadType : singletonTypes->neverType; + /* + * Things get weird and so, so complicated if we allow negations of + * arbitrary function types. Ordinary code can never form these kinds of + * types, so we decline to negate them. + */ + if (FFlag::LuauNegatedFunctionTypes) + { + if (here.functions.isNever()) + result.functions.resetToTop(); + else if (here.functions.isTop) + result.functions.resetToNever(); + else + return std::nullopt; + } + // TODO: negating tables - // TODO: negating functions // TODO: negating tyvars? - + return result; } @@ -1142,21 +1199,25 @@ void Normalizer::subtractPrimitive(NormalizedType& here, TypeId ty) LUAU_ASSERT(ptv); switch (ptv->type) { - case PrimitiveTypeVar::NilType: - here.nils = singletonTypes->neverType; - break; - case PrimitiveTypeVar::Boolean: - here.booleans = singletonTypes->neverType; - break; - case PrimitiveTypeVar::Number: - here.numbers = singletonTypes->neverType; - break; - case PrimitiveTypeVar::String: - here.strings.resetToNever(); - break; - case PrimitiveTypeVar::Thread: - here.threads = singletonTypes->neverType; - break; + case PrimitiveTypeVar::NilType: + here.nils = singletonTypes->neverType; + break; + case PrimitiveTypeVar::Boolean: + here.booleans = singletonTypes->neverType; + break; + case PrimitiveTypeVar::Number: + here.numbers = singletonTypes->neverType; + break; + case PrimitiveTypeVar::String: + here.strings.resetToNever(); + break; + case PrimitiveTypeVar::Thread: + here.threads = singletonTypes->neverType; + break; + case PrimitiveTypeVar::Function: + LUAU_ASSERT(FFlag::LuauNegatedStringSingletons); + here.functions.resetToNever(); + break; } } @@ -1589,7 +1650,7 @@ std::optional Normalizer::intersectionOfFunctions(TypeId here, TypeId th TypePackId argTypes; TypePackId retTypes; - + if (hftv->retTypes == tftv->retTypes) { std::optional argTypesOpt = unionOfTypePacks(hftv->argTypes, tftv->argTypes); @@ -1598,7 +1659,7 @@ std::optional Normalizer::intersectionOfFunctions(TypeId here, TypeId th argTypes = *argTypesOpt; retTypes = hftv->retTypes; } - else if (FFlag::LuauOverloadedFunctionSubtypingPerf && hftv->argTypes == tftv->argTypes) + else if (FFlag::LuauOverloadedFunctionSubtypingPerf && hftv->argTypes == tftv->argTypes) { std::optional retTypesOpt = intersectionOfTypePacks(hftv->argTypes, tftv->argTypes); if (!retTypesOpt) @@ -1738,18 +1799,20 @@ std::optional Normalizer::unionSaturatedFunctions(TypeId here, TypeId th void Normalizer::intersectFunctionsWithFunction(NormalizedFunctionType& heres, TypeId there) { - if (!heres) + if (heres.isNever()) return; - for (auto it = heres->begin(); it != heres->end();) + heres.isTop = false; + + for (auto it = heres.parts->begin(); it != heres.parts->end();) { TypeId here = *it; if (get(here)) it++; else if (std::optional tmp = intersectionOfFunctions(here, there)) { - heres->erase(it); - heres->insert(*tmp); + heres.parts->erase(it); + heres.parts->insert(*tmp); return; } else @@ -1757,27 +1820,27 @@ void Normalizer::intersectFunctionsWithFunction(NormalizedFunctionType& heres, T } TypeIds tmps; - for (TypeId here : *heres) + for (TypeId here : *heres.parts) { if (std::optional tmp = unionSaturatedFunctions(here, there)) tmps.insert(*tmp); } - heres->insert(there); - heres->insert(tmps.begin(), tmps.end()); + heres.parts->insert(there); + heres.parts->insert(tmps.begin(), tmps.end()); } void Normalizer::intersectFunctions(NormalizedFunctionType& heres, const NormalizedFunctionType& theres) { - if (!heres) + if (heres.isNever()) return; - else if (!theres) + else if (theres.isNever()) { - heres = std::nullopt; + heres.resetToNever(); return; } else { - for (TypeId there : *theres) + for (TypeId there : *theres.parts) intersectFunctionsWithFunction(heres, there); } } @@ -1935,6 +1998,7 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there) TypeId nils = here.nils; TypeId numbers = here.numbers; NormalizedStringType strings = std::move(here.strings); + NormalizedFunctionType functions = std::move(here.functions); TypeId threads = here.threads; clearNormal(here); @@ -1949,6 +2013,11 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there) here.strings = std::move(strings); else if (ptv->type == PrimitiveTypeVar::Thread) here.threads = threads; + else if (ptv->type == PrimitiveTypeVar::Function) + { + LUAU_ASSERT(FFlag::LuauNegatedFunctionTypes); + here.functions = std::move(functions); + } else LUAU_ASSERT(!"Unreachable"); } @@ -1981,8 +2050,10 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there) for (TypeId part : itv->options) { const NormalizedType* normalPart = normalize(part); - NormalizedType negated = negateNormal(*normalPart); - intersectNormals(here, negated); + std::optional negated = negateNormal(*normalPart); + if (!negated) + return false; + intersectNormals(here, *negated); } } else @@ -2016,14 +2087,16 @@ TypeId Normalizer::typeFromNormal(const NormalizedType& norm) result.insert(result.end(), norm.classes.begin(), norm.classes.end()); if (!get(norm.errors)) result.push_back(norm.errors); - if (norm.functions) + if (FFlag::LuauNegatedFunctionTypes && norm.functions.isTop) + result.push_back(singletonTypes->functionType); + else if (!norm.functions.isNever()) { - if (norm.functions->size() == 1) - result.push_back(*norm.functions->begin()); + if (norm.functions.parts->size() == 1) + result.push_back(*norm.functions.parts->begin()); else { std::vector parts; - parts.insert(parts.end(), norm.functions->begin(), norm.functions->end()); + parts.insert(parts.end(), norm.functions.parts->begin(), norm.functions.parts->end()); result.push_back(arena->addType(IntersectionTypeVar{std::move(parts)})); } } @@ -2070,62 +2143,24 @@ TypeId Normalizer::typeFromNormal(const NormalizedType& norm) return arena->addType(UnionTypeVar{std::move(result)}); } -namespace -{ - -struct Replacer -{ - TypeArena* arena; - TypeId sourceType; - TypeId replacedType; - DenseHashMap newTypes; - - Replacer(TypeArena* arena, TypeId sourceType, TypeId replacedType) - : arena(arena) - , sourceType(sourceType) - , replacedType(replacedType) - , newTypes(nullptr) - { - } - - TypeId smartClone(TypeId t) - { - t = follow(t); - TypeId* res = newTypes.find(t); - if (res) - return *res; - - TypeId result = shallowClone(t, *arena, TxnLog::empty()); - newTypes[t] = result; - newTypes[result] = result; - - return result; - } -}; - -} // anonymous namespace - -bool isSubtype(TypeId subTy, TypeId superTy, NotNull scope, NotNull singletonTypes, InternalErrorReporter& ice, bool anyIsTop) +bool isSubtype(TypeId subTy, TypeId superTy, NotNull scope, NotNull singletonTypes, InternalErrorReporter& ice) { UnifierSharedState sharedState{&ice}; TypeArena arena; Normalizer normalizer{&arena, singletonTypes, NotNull{&sharedState}}; Unifier u{NotNull{&normalizer}, Mode::Strict, scope, Location{}, Covariant}; - u.anyIsTop = anyIsTop; u.tryUnify(subTy, superTy); const bool ok = u.errors.empty() && u.log.empty(); return ok; } -bool isSubtype( - TypePackId subPack, TypePackId superPack, NotNull scope, NotNull singletonTypes, InternalErrorReporter& ice, bool anyIsTop) +bool isSubtype(TypePackId subPack, TypePackId superPack, NotNull scope, NotNull singletonTypes, InternalErrorReporter& ice) { UnifierSharedState sharedState{&ice}; TypeArena arena; Normalizer normalizer{&arena, singletonTypes, NotNull{&sharedState}}; Unifier u{NotNull{&normalizer}, Mode::Strict, scope, Location{}, Covariant}; - u.anyIsTop = anyIsTop; u.tryUnify(subPack, superPack); const bool ok = u.errors.empty() && u.log.empty(); diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 7d0fb22d1..903e156b8 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -10,12 +10,12 @@ #include #include -LUAU_FASTFLAG(LuauUnknownAndNeverType) +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAG(LuauLvaluelessPath) -LUAU_FASTFLAGVARIABLE(LuauSpecialTypesAsterisked, false) +LUAU_FASTFLAG(LuauUnknownAndNeverType) LUAU_FASTFLAGVARIABLE(LuauFixNameMaps, false) -LUAU_FASTFLAGVARIABLE(LuauUnseeArrayTtv, false) LUAU_FASTFLAGVARIABLE(LuauFunctionReturnStringificationFixup, false) +LUAU_FASTFLAGVARIABLE(LuauUnseeArrayTtv, false) /* * Prefix generic typenames with gen- @@ -225,6 +225,20 @@ struct StringifierState result.name += s; } + void emitLevel(Scope* scope) + { + size_t count = 0; + for (Scope* s = scope; s; s = s->parent.get()) + ++count; + + emit(count); + emit("-"); + char buffer[16]; + uint32_t s = uint32_t(intptr_t(scope) & 0xFFFFFF); + snprintf(buffer, sizeof(buffer), "0x%x", s); + emit(buffer); + } + void emit(TypeLevel level) { emit(std::to_string(level.level)); @@ -296,10 +310,7 @@ struct TypeVarStringifier if (tv->ty.valueless_by_exception()) { state.result.error = true; - if (FFlag::LuauSpecialTypesAsterisked) - state.emit("* VALUELESS BY EXCEPTION *"); - else - state.emit("< VALUELESS BY EXCEPTION >"); + state.emit("* VALUELESS BY EXCEPTION *"); return; } @@ -377,7 +388,10 @@ struct TypeVarStringifier if (FFlag::DebugLuauVerboseTypeNames) { state.emit("-"); - state.emit(ftv.level); + if (FFlag::DebugLuauDeferredConstraintResolution) + state.emitLevel(ftv.scope); + else + state.emit(ftv.level); } } @@ -399,6 +413,15 @@ struct TypeVarStringifier } else state.emit(state.getName(ty)); + + if (FFlag::DebugLuauVerboseTypeNames) + { + state.emit("-"); + if (FFlag::DebugLuauDeferredConstraintResolution) + state.emitLevel(gtv.scope); + else + state.emit(gtv.level); + } } void operator()(TypeId, const BlockedTypeVar& btv) @@ -434,6 +457,9 @@ struct TypeVarStringifier case PrimitiveTypeVar::Thread: state.emit("thread"); return; + case PrimitiveTypeVar::Function: + state.emit("function"); + return; default: LUAU_ASSERT(!"Unknown primitive type"); throwRuntimeError("Unknown primitive type " + std::to_string(ptv.type)); @@ -462,10 +488,7 @@ struct TypeVarStringifier if (state.hasSeen(&ftv)) { state.result.cycle = true; - if (FFlag::LuauSpecialTypesAsterisked) - state.emit("*CYCLE*"); - else - state.emit(""); + state.emit("*CYCLE*"); return; } @@ -573,10 +596,7 @@ struct TypeVarStringifier if (state.hasSeen(&ttv)) { state.result.cycle = true; - if (FFlag::LuauSpecialTypesAsterisked) - state.emit("*CYCLE*"); - else - state.emit(""); + state.emit("*CYCLE*"); return; } @@ -710,10 +730,7 @@ struct TypeVarStringifier if (state.hasSeen(&uv)) { state.result.cycle = true; - if (FFlag::LuauSpecialTypesAsterisked) - state.emit("*CYCLE*"); - else - state.emit(""); + state.emit("*CYCLE*"); return; } @@ -780,10 +797,7 @@ struct TypeVarStringifier if (state.hasSeen(&uv)) { state.result.cycle = true; - if (FFlag::LuauSpecialTypesAsterisked) - state.emit("*CYCLE*"); - else - state.emit(""); + state.emit("*CYCLE*"); return; } @@ -828,10 +842,7 @@ struct TypeVarStringifier void operator()(TypeId, const ErrorTypeVar& tv) { state.result.error = true; - if (FFlag::LuauSpecialTypesAsterisked) - state.emit(FFlag::LuauUnknownAndNeverType ? "*error-type*" : "*unknown*"); - else - state.emit(FFlag::LuauUnknownAndNeverType ? "" : "*unknown*"); + state.emit(FFlag::LuauUnknownAndNeverType ? "*error-type*" : "*unknown*"); } void operator()(TypeId, const LazyTypeVar& ltv) @@ -850,11 +861,6 @@ struct TypeVarStringifier state.emit("never"); } - void operator()(TypeId ty, const UseTypeVar&) - { - stringify(follow(ty)); - } - void operator()(TypeId, const NegationTypeVar& ntv) { state.emit("~"); @@ -907,10 +913,7 @@ struct TypePackStringifier if (tp->ty.valueless_by_exception()) { state.result.error = true; - if (FFlag::LuauSpecialTypesAsterisked) - state.emit("* VALUELESS TP BY EXCEPTION *"); - else - state.emit("< VALUELESS TP BY EXCEPTION >"); + state.emit("* VALUELESS TP BY EXCEPTION *"); return; } @@ -934,10 +937,7 @@ struct TypePackStringifier if (state.hasSeen(&tp)) { state.result.cycle = true; - if (FFlag::LuauSpecialTypesAsterisked) - state.emit("*CYCLETP*"); - else - state.emit(""); + state.emit("*CYCLETP*"); return; } @@ -982,10 +982,7 @@ struct TypePackStringifier void operator()(TypePackId, const Unifiable::Error& error) { state.result.error = true; - if (FFlag::LuauSpecialTypesAsterisked) - state.emit(FFlag::LuauUnknownAndNeverType ? "*error-type*" : "*unknown*"); - else - state.emit(FFlag::LuauUnknownAndNeverType ? "" : "*unknown*"); + state.emit(FFlag::LuauUnknownAndNeverType ? "*error-type*" : "*unknown*"); } void operator()(TypePackId, const VariadicTypePack& pack) @@ -993,10 +990,7 @@ struct TypePackStringifier state.emit("..."); if (FFlag::DebugLuauVerboseTypeNames && pack.hidden) { - if (FFlag::LuauSpecialTypesAsterisked) - state.emit("*hidden*"); - else - state.emit(""); + state.emit("*hidden*"); } stringify(pack.ty); } @@ -1031,7 +1025,10 @@ struct TypePackStringifier if (FFlag::DebugLuauVerboseTypeNames) { state.emit("-"); - state.emit(pack.level); + if (FFlag::DebugLuauDeferredConstraintResolution) + state.emitLevel(pack.scope); + else + state.emit(pack.level); } state.emit("..."); @@ -1204,10 +1201,7 @@ ToStringResult toStringDetailed(TypeId ty, ToStringOptions& opts) { result.truncated = true; - if (FFlag::LuauSpecialTypesAsterisked) - result.name += "... *TRUNCATED*"; - else - result.name += "... "; + result.name += "... *TRUNCATED*"; } return result; @@ -1280,10 +1274,7 @@ ToStringResult toStringDetailed(TypePackId tp, ToStringOptions& opts) if (opts.maxTypeLength > 0 && result.name.length() > opts.maxTypeLength) { - if (FFlag::LuauSpecialTypesAsterisked) - result.name += "... *TRUNCATED*"; - else - result.name += "... "; + result.name += "... *TRUNCATED*"; } return result; @@ -1526,9 +1517,12 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts) { return tos(c.resultType, opts) + " ~ hasProp " + tos(c.subjectType, opts) + ", \"" + c.prop + "\""; } - else if constexpr (std::is_same_v) + else if constexpr (std::is_same_v) { - return "TODO"; + std::string result = tos(c.resultType, opts); + std::string discriminant = tos(c.discriminantType, opts); + + return result + " ~ if isSingleton D then ~D else unknown where D = " + discriminant; } else static_assert(always_false_v, "Non-exhaustive constraint switch"); diff --git a/Analysis/src/TypeAttach.cpp b/Analysis/src/TypeAttach.cpp index 179846d7c..c97ed05d2 100644 --- a/Analysis/src/TypeAttach.cpp +++ b/Analysis/src/TypeAttach.cpp @@ -338,12 +338,6 @@ class TypeRehydrationVisitor { return allocator->alloc(Location(), std::nullopt, AstName{"never"}); } - AstType* operator()(const UseTypeVar& utv) - { - std::optional ty = utv.scope->lookup(utv.def); - LUAU_ASSERT(ty); - return Luau::visit(*this, (*ty)->ty); - } AstType* operator()(const NegationTypeVar& ntv) { // FIXME: do the same thing we do with ErrorTypeVar diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index a26731586..dde41a65f 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -301,7 +301,6 @@ struct TypeChecker2 UnifierSharedState sharedState{&ice}; Normalizer normalizer{&arena, singletonTypes, NotNull{&sharedState}}; Unifier u{NotNull{&normalizer}, Mode::Strict, stack.back(), ret->location, Covariant}; - u.anyIsTop = true; u.tryUnify(actualRetType, expectedRetType); const bool ok = u.errors.empty() && u.log.empty(); @@ -331,16 +330,21 @@ struct TypeChecker2 if (value) visit(value); - if (i != local->values.size - 1) + TypeId* maybeValueType = value ? module->astTypes.find(value) : nullptr; + if (i != local->values.size - 1 || maybeValueType) { AstLocal* var = i < local->vars.size ? local->vars.data[i] : nullptr; if (var && var->annotation) { - TypeId varType = lookupAnnotation(var->annotation); + TypeId annotationType = lookupAnnotation(var->annotation); TypeId valueType = value ? lookupType(value) : nullptr; - if (valueType && !isSubtype(varType, valueType, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) - reportError(TypeMismatch{varType, valueType}, value->location); + if (valueType) + { + ErrorVec errors = tryUnify(stack.back(), value->location, valueType, annotationType); + if (!errors.empty()) + reportErrors(std::move(errors)); + } } } else @@ -606,7 +610,7 @@ struct TypeChecker2 visit(rhs); TypeId rhsType = lookupType(rhs); - if (!isSubtype(rhsType, lhsType, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) + if (!isSubtype(rhsType, lhsType, stack.back(), singletonTypes, ice)) { reportError(TypeMismatch{lhsType, rhsType}, rhs->location); } @@ -757,7 +761,7 @@ struct TypeChecker2 TypeId actualType = lookupType(number); TypeId numberType = singletonTypes->numberType; - if (!isSubtype(numberType, actualType, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) + if (!isSubtype(numberType, actualType, stack.back(), singletonTypes, ice)) { reportError(TypeMismatch{actualType, numberType}, number->location); } @@ -768,7 +772,7 @@ struct TypeChecker2 TypeId actualType = lookupType(string); TypeId stringType = singletonTypes->stringType; - if (!isSubtype(actualType, stringType, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) + if (!isSubtype(actualType, stringType, stack.back(), singletonTypes, ice)) { reportError(TypeMismatch{actualType, stringType}, string->location); } @@ -857,7 +861,7 @@ struct TypeChecker2 FunctionTypeVar ftv{argsTp, expectedRetType}; TypeId expectedType = arena.addType(ftv); - if (!isSubtype(testFunctionType, expectedType, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) + if (!isSubtype(testFunctionType, expectedType, stack.back(), singletonTypes, ice)) { CloneState cloneState; expectedType = clone(expectedType, module->internalTypes, cloneState); @@ -876,7 +880,7 @@ struct TypeChecker2 getIndexTypeFromType(module->getModuleScope(), leftType, indexName->index.value, indexName->location, /* addErrors */ true); if (ty) { - if (!isSubtype(resultType, *ty, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) + if (!isSubtype(resultType, *ty, stack.back(), singletonTypes, ice)) { reportError(TypeMismatch{resultType, *ty}, indexName->location); } @@ -909,7 +913,7 @@ struct TypeChecker2 TypeId inferredArgTy = *argIt; TypeId annotatedArgTy = lookupAnnotation(arg->annotation); - if (!isSubtype(annotatedArgTy, inferredArgTy, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) + if (!isSubtype(annotatedArgTy, inferredArgTy, stack.back(), singletonTypes, ice)) { reportError(TypeMismatch{annotatedArgTy, inferredArgTy}, arg->location); } @@ -1203,10 +1207,10 @@ struct TypeChecker2 TypeId computedType = lookupType(expr->expr); // Note: As an optimization, we try 'number <: number | string' first, as that is the more likely case. - if (isSubtype(annotationType, computedType, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) + if (isSubtype(annotationType, computedType, stack.back(), singletonTypes, ice)) return; - if (isSubtype(computedType, annotationType, stack.back(), singletonTypes, ice, /* anyIsTop */ false)) + if (isSubtype(computedType, annotationType, stack.back(), singletonTypes, ice)) return; reportError(TypesAreUnrelated{computedType, annotationType}, expr->location); @@ -1507,7 +1511,6 @@ struct TypeChecker2 UnifierSharedState sharedState{&ice}; Normalizer normalizer{&module->internalTypes, singletonTypes, NotNull{&sharedState}}; Unifier u{NotNull{&normalizer}, Mode::Strict, scope, location, Covariant}; - u.anyIsTop = true; u.tryUnify(subTy, superTy); return std::move(u.errors); diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index 94d633c78..de0890e18 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -57,13 +57,6 @@ TypeId follow(TypeId t, std::function mapper) return btv->boundTo; else if (auto ttv = get(mapper(ty))) return ttv->boundTo; - else if (auto utv = get(mapper(ty))) - { - std::optional ty = utv->scope->lookup(utv->def); - if (!ty) - throwRuntimeError("UseTypeVar must map to another TypeId"); - return *ty; - } else return std::nullopt; }; @@ -761,6 +754,7 @@ SingletonTypes::SingletonTypes() , stringType(arena->addType(TypeVar{PrimitiveTypeVar{PrimitiveTypeVar::String}, /*persistent*/ true})) , booleanType(arena->addType(TypeVar{PrimitiveTypeVar{PrimitiveTypeVar::Boolean}, /*persistent*/ true})) , threadType(arena->addType(TypeVar{PrimitiveTypeVar{PrimitiveTypeVar::Thread}, /*persistent*/ true})) + , functionType(arena->addType(TypeVar{PrimitiveTypeVar{PrimitiveTypeVar::Function}, /*persistent*/ true})) , trueType(arena->addType(TypeVar{SingletonTypeVar{BooleanSingleton{true}}, /*persistent*/ true})) , falseType(arena->addType(TypeVar{SingletonTypeVar{BooleanSingleton{false}}, /*persistent*/ true})) , anyType(arena->addType(TypeVar{AnyTypeVar{}, /*persistent*/ true})) @@ -946,7 +940,8 @@ void persist(TypeId ty) queue.push_back(mtv->table); queue.push_back(mtv->metatable); } - else if (get(t) || get(t) || get(t) || get(t) || get(t) || get(t)) + else if (get(t) || get(t) || get(t) || get(t) || get(t) || + get(t)) { } else diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index b5eba9803..df5d86f1e 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -8,6 +8,7 @@ #include "Luau/TypePack.h" #include "Luau/TypeUtils.h" #include "Luau/TimeTrace.h" +#include "Luau/TypeVar.h" #include "Luau/VisitTypeVar.h" #include "Luau/ToString.h" @@ -23,6 +24,7 @@ LUAU_FASTFLAGVARIABLE(LuauInstantiateInSubtyping, false) LUAU_FASTFLAGVARIABLE(LuauOverloadedFunctionSubtypingPerf, false); LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) +LUAU_FASTFLAG(LuauNegatedFunctionTypes) namespace Luau { @@ -363,7 +365,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool if (sharedState.counters.iterationLimit > 0 && sharedState.counters.iterationLimit < sharedState.counters.iterationCount) { - reportError(TypeError{location, UnificationTooComplex{}}); + reportError(location, UnificationTooComplex{}); return; } @@ -404,7 +406,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool if (subGeneric && !subsumes(useScopes, subGeneric, superFree)) { // TODO: a more informative error message? CLI-39912 - reportError(TypeError{location, GenericError{"Generic subtype escaping scope"}}); + reportError(location, GenericError{"Generic subtype escaping scope"}); return; } @@ -433,7 +435,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool if (superGeneric && !subsumes(useScopes, superGeneric, subFree)) { // TODO: a more informative error message? CLI-39912 - reportError(TypeError{location, GenericError{"Generic supertype escaping scope"}}); + reportError(location, GenericError{"Generic supertype escaping scope"}); return; } @@ -450,15 +452,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool return tryUnifyWithAny(subTy, superTy); if (get(subTy)) - { - if (anyIsTop) - { - reportError(TypeError{location, TypeMismatch{superTy, subTy}}); - return; - } - else - return tryUnifyWithAny(superTy, subTy); - } + return tryUnifyWithAny(superTy, subTy); if (log.get(subTy)) return tryUnifyWithAny(superTy, subTy); @@ -478,7 +472,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool if (auto error = sharedState.cachedUnifyError.find({subTy, superTy})) { - reportError(TypeError{location, *error}); + reportError(location, *error); return; } } @@ -520,6 +514,12 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool else if ((log.getMutable(superTy) || log.getMutable(superTy)) && log.getMutable(subTy)) tryUnifySingletons(subTy, superTy); + else if (auto ptv = get(superTy); + FFlag::LuauNegatedFunctionTypes && ptv && ptv->type == PrimitiveTypeVar::Function && get(subTy)) + { + // Ok. Do nothing. forall functions F, F <: function + } + else if (log.getMutable(superTy) && log.getMutable(subTy)) tryUnifyFunctions(subTy, superTy, isFunctionCall); @@ -559,7 +559,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool tryUnifyNegationWithType(subTy, superTy); else - reportError(TypeError{location, TypeMismatch{superTy, subTy}}); + reportError(location, TypeMismatch{superTy, subTy}); if (cacheEnabled) cacheResult(subTy, superTy, errorCount); @@ -633,9 +633,9 @@ void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionTypeVar* subUnion, else if (failed) { if (firstFailedOption) - reportError(TypeError{location, TypeMismatch{superTy, subTy, "Not all union options are compatible.", *firstFailedOption}}); + reportError(location, TypeMismatch{superTy, subTy, "Not all union options are compatible.", *firstFailedOption}); else - reportError(TypeError{location, TypeMismatch{superTy, subTy}}); + reportError(location, TypeMismatch{superTy, subTy}); } } @@ -734,7 +734,7 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp const NormalizedType* subNorm = normalizer->normalize(subTy); const NormalizedType* superNorm = normalizer->normalize(superTy); if (!subNorm || !superNorm) - reportError(TypeError{location, UnificationTooComplex{}}); + reportError(location, UnificationTooComplex{}); else if ((failedOptionCount == 1 || foundHeuristic) && failedOption) tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "None of the union options are compatible. For example:", *failedOption); else @@ -743,9 +743,9 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp else if (!found) { if ((failedOptionCount == 1 || foundHeuristic) && failedOption) - reportError(TypeError{location, TypeMismatch{superTy, subTy, "None of the union options are compatible. For example:", *failedOption}}); + reportError(location, TypeMismatch{superTy, subTy, "None of the union options are compatible. For example:", *failedOption}); else - reportError(TypeError{location, TypeMismatch{superTy, subTy, "none of the union options are compatible"}}); + reportError(location, TypeMismatch{superTy, subTy, "none of the union options are compatible"}); } } @@ -774,7 +774,7 @@ void Unifier::tryUnifyTypeWithIntersection(TypeId subTy, TypeId superTy, const I if (unificationTooComplex) reportError(*unificationTooComplex); else if (firstFailedOption) - reportError(TypeError{location, TypeMismatch{superTy, subTy, "Not all intersection parts are compatible.", *firstFailedOption}}); + reportError(location, TypeMismatch{superTy, subTy, "Not all intersection parts are compatible.", *firstFailedOption}); } void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionTypeVar* uv, TypeId superTy, bool cacheEnabled, bool isFunctionCall) @@ -832,11 +832,11 @@ void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionTypeV if (subNorm && superNorm) tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the intersection parts are compatible"); else - reportError(TypeError{location, UnificationTooComplex{}}); + reportError(location, UnificationTooComplex{}); } else if (!found) { - reportError(TypeError{location, TypeMismatch{superTy, subTy, "none of the intersection parts are compatible"}}); + reportError(location, TypeMismatch{superTy, subTy, "none of the intersection parts are compatible"}); } } @@ -848,37 +848,37 @@ void Unifier::tryUnifyNormalizedTypes( if (get(superNorm.tops) || get(superNorm.tops) || get(subNorm.tops)) return; else if (get(subNorm.tops)) - return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + return reportError(location, TypeMismatch{superTy, subTy, reason, error}); if (get(subNorm.errors)) if (!get(superNorm.errors)) - return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + return reportError(location, TypeMismatch{superTy, subTy, reason, error}); if (get(subNorm.booleans)) { if (!get(superNorm.booleans)) - return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + return reportError(location, TypeMismatch{superTy, subTy, reason, error}); } else if (const SingletonTypeVar* stv = get(subNorm.booleans)) { if (!get(superNorm.booleans) && stv != get(superNorm.booleans)) - return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + return reportError(location, TypeMismatch{superTy, subTy, reason, error}); } if (get(subNorm.nils)) if (!get(superNorm.nils)) - return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + return reportError(location, TypeMismatch{superTy, subTy, reason, error}); if (get(subNorm.numbers)) if (!get(superNorm.numbers)) - return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + return reportError(location, TypeMismatch{superTy, subTy, reason, error}); if (!isSubtype(subNorm.strings, superNorm.strings)) - return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + return reportError(location, TypeMismatch{superTy, subTy, reason, error}); if (get(subNorm.threads)) if (!get(superNorm.errors)) - return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + return reportError(location, TypeMismatch{superTy, subTy, reason, error}); for (TypeId subClass : subNorm.classes) { @@ -894,7 +894,7 @@ void Unifier::tryUnifyNormalizedTypes( } } if (!found) - return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + return reportError(location, TypeMismatch{superTy, subTy, reason, error}); } for (TypeId subTable : subNorm.tables) @@ -919,21 +919,19 @@ void Unifier::tryUnifyNormalizedTypes( return reportError(*e); } if (!found) - return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + return reportError(location, TypeMismatch{superTy, subTy, reason, error}); } - if (subNorm.functions) + if (!subNorm.functions.isNever()) { - if (!superNorm.functions) - return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); - if (superNorm.functions->empty()) - return; - for (TypeId superFun : *superNorm.functions) + if (superNorm.functions.isNever()) + return reportError(location, TypeMismatch{superTy, subTy, reason, error}); + for (TypeId superFun : *superNorm.functions.parts) { Unifier innerState = makeChildUnifier(); const FunctionTypeVar* superFtv = get(superFun); if (!superFtv) - return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + return reportError(location, TypeMismatch{superTy, subTy, reason, error}); TypePackId tgt = innerState.tryApplyOverloadedFunction(subTy, subNorm.functions, superFtv->argTypes); innerState.tryUnify_(tgt, superFtv->retTypes); if (innerState.errors.empty()) @@ -941,7 +939,7 @@ void Unifier::tryUnifyNormalizedTypes( else if (auto e = hasUnificationTooComplex(innerState.errors)) return reportError(*e); else - return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + return reportError(location, TypeMismatch{superTy, subTy, reason, error}); } } @@ -959,15 +957,15 @@ void Unifier::tryUnifyNormalizedTypes( TypePackId Unifier::tryApplyOverloadedFunction(TypeId function, const NormalizedFunctionType& overloads, TypePackId args) { - if (!overloads || overloads->empty()) + if (overloads.isNever()) { - reportError(TypeError{location, CannotCallNonFunction{function}}); + reportError(location, CannotCallNonFunction{function}); return singletonTypes->errorRecoveryTypePack(); } std::optional result; const FunctionTypeVar* firstFun = nullptr; - for (TypeId overload : *overloads) + for (TypeId overload : *overloads.parts) { if (const FunctionTypeVar* ftv = get(overload)) { @@ -1015,12 +1013,12 @@ TypePackId Unifier::tryApplyOverloadedFunction(TypeId function, const Normalized // TODO: better error reporting? // The logic for error reporting overload resolution // is currently over in TypeInfer.cpp, should we move it? - reportError(TypeError{location, GenericError{"No matching overload."}}); + reportError(location, GenericError{"No matching overload."}); return singletonTypes->errorRecoveryTypePack(firstFun->retTypes); } else { - reportError(TypeError{location, CannotCallNonFunction{function}}); + reportError(location, CannotCallNonFunction{function}); return singletonTypes->errorRecoveryTypePack(); } } @@ -1199,7 +1197,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal if (sharedState.counters.iterationLimit > 0 && sharedState.counters.iterationLimit < sharedState.counters.iterationCount) { - reportError(TypeError{location, UnificationTooComplex{}}); + reportError(location, UnificationTooComplex{}); return; } @@ -1372,7 +1370,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal size_t actualSize = size(subTp); if (ctx == CountMismatch::FunctionResult || ctx == CountMismatch::ExprListResult) std::swap(expectedSize, actualSize); - reportError(TypeError{location, CountMismatch{expectedSize, std::nullopt, actualSize, ctx}}); + reportError(location, CountMismatch{expectedSize, std::nullopt, actualSize, ctx}); while (superIter.good()) { @@ -1394,9 +1392,9 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal else { if (FFlag::LuauReportTypeMismatchForTypePackUnificationFailure) - reportError(TypeError{location, TypePackMismatch{subTp, superTp}}); + reportError(location, TypePackMismatch{subTp, superTp}); else - reportError(TypeError{location, GenericError{"Failed to unify type packs"}}); + reportError(location, GenericError{"Failed to unify type packs"}); } } @@ -1408,7 +1406,7 @@ void Unifier::tryUnifyPrimitives(TypeId subTy, TypeId superTy) ice("passed non primitive types to unifyPrimitives"); if (superPrim->type != subPrim->type) - reportError(TypeError{location, TypeMismatch{superTy, subTy}}); + reportError(location, TypeMismatch{superTy, subTy}); } void Unifier::tryUnifySingletons(TypeId subTy, TypeId superTy) @@ -1429,7 +1427,7 @@ void Unifier::tryUnifySingletons(TypeId subTy, TypeId superTy) if (superPrim && superPrim->type == PrimitiveTypeVar::String && get(subSingleton) && variance == Covariant) return; - reportError(TypeError{location, TypeMismatch{superTy, subTy}}); + reportError(location, TypeMismatch{superTy, subTy}); } void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCall) @@ -1465,21 +1463,21 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal } else { - reportError(TypeError{location, UnificationTooComplex{}}); + reportError(location, UnificationTooComplex{}); } } else if (numGenerics != subFunction->generics.size()) { numGenerics = std::min(superFunction->generics.size(), subFunction->generics.size()); - reportError(TypeError{location, TypeMismatch{superTy, subTy, "different number of generic type parameters"}}); + reportError(location, TypeMismatch{superTy, subTy, "different number of generic type parameters"}); } if (numGenericPacks != subFunction->genericPacks.size()) { numGenericPacks = std::min(superFunction->genericPacks.size(), subFunction->genericPacks.size()); - reportError(TypeError{location, TypeMismatch{superTy, subTy, "different number of generic type pack parameters"}}); + reportError(location, TypeMismatch{superTy, subTy, "different number of generic type pack parameters"}); } for (size_t i = 0; i < numGenerics; i++) @@ -1506,11 +1504,10 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal if (auto e = hasUnificationTooComplex(innerState.errors)) reportError(*e); else if (!innerState.errors.empty() && innerState.firstPackErrorPos) - reportError( - TypeError{location, TypeMismatch{superTy, subTy, format("Argument #%d type is not compatible.", *innerState.firstPackErrorPos), - innerState.errors.front()}}); + reportError(location, TypeMismatch{superTy, subTy, format("Argument #%d type is not compatible.", *innerState.firstPackErrorPos), + innerState.errors.front()}); else if (!innerState.errors.empty()) - reportError(TypeError{location, TypeMismatch{superTy, subTy, "", innerState.errors.front()}}); + reportError(location, TypeMismatch{superTy, subTy, "", innerState.errors.front()}); innerState.ctx = CountMismatch::FunctionResult; innerState.tryUnify_(subFunction->retTypes, superFunction->retTypes); @@ -1520,13 +1517,12 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal if (auto e = hasUnificationTooComplex(innerState.errors)) reportError(*e); else if (!innerState.errors.empty() && size(superFunction->retTypes) == 1 && finite(superFunction->retTypes)) - reportError(TypeError{location, TypeMismatch{superTy, subTy, "Return type is not compatible.", innerState.errors.front()}}); + reportError(location, TypeMismatch{superTy, subTy, "Return type is not compatible.", innerState.errors.front()}); else if (!innerState.errors.empty() && innerState.firstPackErrorPos) - reportError( - TypeError{location, TypeMismatch{superTy, subTy, format("Return #%d type is not compatible.", *innerState.firstPackErrorPos), - innerState.errors.front()}}); + reportError(location, TypeMismatch{superTy, subTy, format("Return #%d type is not compatible.", *innerState.firstPackErrorPos), + innerState.errors.front()}); else if (!innerState.errors.empty()) - reportError(TypeError{location, TypeMismatch{superTy, subTy, "", innerState.errors.front()}}); + reportError(location, TypeMismatch{superTy, subTy, "", innerState.errors.front()}); } log.concat(std::move(innerState.log)); @@ -1608,7 +1604,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) } else { - reportError(TypeError{location, UnificationTooComplex{}}); + reportError(location, UnificationTooComplex{}); } } } @@ -1626,7 +1622,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) if (!missingProperties.empty()) { - reportError(TypeError{location, MissingProperties{superTy, subTy, std::move(missingProperties)}}); + reportError(location, MissingProperties{superTy, subTy, std::move(missingProperties)}); return; } } @@ -1644,7 +1640,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) if (!extraProperties.empty()) { - reportError(TypeError{location, MissingProperties{superTy, subTy, std::move(extraProperties), MissingProperties::Extra}}); + reportError(location, MissingProperties{superTy, subTy, std::move(extraProperties), MissingProperties::Extra}); return; } } @@ -1825,13 +1821,13 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) if (!missingProperties.empty()) { - reportError(TypeError{location, MissingProperties{superTy, subTy, std::move(missingProperties)}}); + reportError(location, MissingProperties{superTy, subTy, std::move(missingProperties)}); return; } if (!extraProperties.empty()) { - reportError(TypeError{location, MissingProperties{superTy, subTy, std::move(extraProperties), MissingProperties::Extra}}); + reportError(location, MissingProperties{superTy, subTy, std::move(extraProperties), MissingProperties::Extra}); return; } @@ -1867,14 +1863,14 @@ void Unifier::tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed) std::swap(subTy, superTy); if (auto ttv = log.get(superTy); !ttv || ttv->state != TableState::Free) - return reportError(TypeError{location, TypeMismatch{osuperTy, osubTy}}); + return reportError(location, TypeMismatch{osuperTy, osubTy}); auto fail = [&](std::optional e) { std::string reason = "The former's metatable does not satisfy the requirements."; if (e) - reportError(TypeError{location, TypeMismatch{osuperTy, osubTy, reason, *e}}); + reportError(location, TypeMismatch{osuperTy, osubTy, reason, *e}); else - reportError(TypeError{location, TypeMismatch{osuperTy, osubTy, reason}}); + reportError(location, TypeMismatch{osuperTy, osubTy, reason}); }; // Given t1 where t1 = { lower: (t1) -> (a, b...) } @@ -1906,7 +1902,7 @@ void Unifier::tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed) } } - reportError(TypeError{location, TypeMismatch{osuperTy, osubTy}}); + reportError(location, TypeMismatch{osuperTy, osubTy}); return; } @@ -1947,7 +1943,7 @@ void Unifier::tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed) if (auto e = hasUnificationTooComplex(innerState.errors)) reportError(*e); else if (!innerState.errors.empty()) - reportError(TypeError{location, TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy, "", innerState.errors.front()}}); + reportError(location, TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy, "", innerState.errors.front()}); log.concat(std::move(innerState.log)); } @@ -2024,9 +2020,9 @@ void Unifier::tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed) auto fail = [&]() { if (!reversed) - reportError(TypeError{location, TypeMismatch{superTy, subTy}}); + reportError(location, TypeMismatch{superTy, subTy}); else - reportError(TypeError{location, TypeMismatch{subTy, superTy}}); + reportError(location, TypeMismatch{subTy, superTy}); }; const ClassTypeVar* superClass = get(superTy); @@ -2071,7 +2067,7 @@ void Unifier::tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed) if (!classProp) { ok = false; - reportError(TypeError{location, UnknownProperty{superTy, propName}}); + reportError(location, UnknownProperty{superTy, propName}); } else { @@ -2095,7 +2091,7 @@ void Unifier::tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed) { ok = false; std::string msg = "Class " + superClass->name + " does not have an indexer"; - reportError(TypeError{location, GenericError{msg}}); + reportError(location, GenericError{msg}); } if (!ok) @@ -2116,13 +2112,13 @@ void Unifier::tryUnifyTypeWithNegation(TypeId subTy, TypeId superTy) const NormalizedType* subNorm = normalizer->normalize(subTy); const NormalizedType* superNorm = normalizer->normalize(superTy); if (!subNorm || !superNorm) - return reportError(TypeError{location, UnificationTooComplex{}}); + return reportError(location, UnificationTooComplex{}); // T & queue, DenseHashSet& seenTypePacks, Unifier& state, TypePackId a, TypePackId anyTypePack) @@ -2195,7 +2191,7 @@ void Unifier::tryUnifyVariadics(TypePackId subTp, TypePackId superTp, bool rever } else if (get(tail)) { - reportError(TypeError{location, GenericError{"Cannot unify variadic and generic packs"}}); + reportError(location, GenericError{"Cannot unify variadic and generic packs"}); } else if (get(tail)) { @@ -2209,7 +2205,7 @@ void Unifier::tryUnifyVariadics(TypePackId subTp, TypePackId superTp, bool rever } else { - reportError(TypeError{location, GenericError{"Failed to unify variadic packs"}}); + reportError(location, GenericError{"Failed to unify variadic packs"}); } } @@ -2351,7 +2347,7 @@ bool Unifier::occursCheck(DenseHashSet& seen, TypeId needle, TypeId hays if (needle == haystack) { - reportError(TypeError{location, OccursCheckFailed{}}); + reportError(location, OccursCheckFailed{}); log.replace(needle, *singletonTypes->errorRecoveryType()); return true; @@ -2402,7 +2398,7 @@ bool Unifier::occursCheck(DenseHashSet& seen, TypePackId needle, Typ { if (needle == haystack) { - reportError(TypeError{location, OccursCheckFailed{}}); + reportError(location, OccursCheckFailed{}); log.replace(needle, *singletonTypes->errorRecoveryTypePack()); return true; @@ -2423,18 +2419,31 @@ bool Unifier::occursCheck(DenseHashSet& seen, TypePackId needle, Typ Unifier Unifier::makeChildUnifier() { Unifier u = Unifier{normalizer, mode, scope, location, variance, &log}; - u.anyIsTop = anyIsTop; u.normalize = normalize; + u.useScopes = useScopes; return u; } // A utility function that appends the given error to the unifier's error log. // This allows setting a breakpoint wherever the unifier reports an error. +// +// Note: report error accepts its arguments by value intentionally to reduce the stack usage of functions which call `reportError`. +void Unifier::reportError(Location location, TypeErrorData data) +{ + errors.emplace_back(std::move(location), std::move(data)); +} + +// A utility function that appends the given error to the unifier's error log. +// This allows setting a breakpoint wherever the unifier reports an error. +// +// Note: to conserve stack space in calling functions it is generally preferred to call `Unifier::reportError(Location location, TypeErrorData data)` +// instead of this method. void Unifier::reportError(TypeError err) { errors.push_back(std::move(err)); } + bool Unifier::isNonstrictMode() const { return (mode == Mode::Nonstrict) || (mode == Mode::NoCheck); @@ -2445,7 +2454,7 @@ void Unifier::checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, TypeId if (auto e = hasUnificationTooComplex(innerErrors)) reportError(*e); else if (!innerErrors.empty()) - reportError(TypeError{location, TypeMismatch{wantedType, givenType}}); + reportError(location, TypeMismatch{wantedType, givenType}); } void Unifier::checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, const std::string& prop, TypeId wantedType, TypeId givenType) diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 85c5f5c60..4c0cc1251 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -25,6 +25,7 @@ LUAU_DYNAMIC_FASTFLAGVARIABLE(LuaReportParseIntegerIssues, false) LUAU_FASTFLAGVARIABLE(LuauInterpolatedStringBaseSupport, false) LUAU_FASTFLAGVARIABLE(LuauCommaParenWarnings, false) +LUAU_FASTFLAGVARIABLE(LuauTableConstructorRecovery, false) bool lua_telemetry_parsed_out_of_range_bin_integer = false; bool lua_telemetry_parsed_out_of_range_hex_integer = false; @@ -2310,9 +2311,13 @@ AstExpr* Parser::parseTableConstructor() MatchLexeme matchBrace = lexer.current(); expectAndConsume('{', "table literal"); + unsigned lastElementIndent = 0; while (lexer.current().type != '}') { + if (FFlag::LuauTableConstructorRecovery) + lastElementIndent = lexer.current().location.begin.column; + if (lexer.current().type == '[') { MatchLexeme matchLocationBracket = lexer.current(); @@ -2357,10 +2362,14 @@ AstExpr* Parser::parseTableConstructor() { nextLexeme(); } - else + else if (FFlag::LuauTableConstructorRecovery && (lexer.current().type == '[' || lexer.current().type == Lexeme::Name) && + lexer.current().location.begin.column == lastElementIndent) { - if (lexer.current().type != '}') - break; + report(lexer.current().location, "Expected ',' after table constructor element"); + } + else if (lexer.current().type != '}') + { + break; } } diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index 87e19db8b..e567725e5 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -978,7 +978,8 @@ int replMain(int argc, char** argv) if (compileFormat == CompileFormat::Null) printf("Compiled %d KLOC into %d KB bytecode\n", int(stats.lines / 1000), int(stats.bytecode / 1024)); else if (compileFormat == CompileFormat::CodegenNull) - printf("Compiled %d KLOC into %d KB bytecode => %d KB native code\n", int(stats.lines / 1000), int(stats.bytecode / 1024), int(stats.codegen / 1024)); + printf("Compiled %d KLOC into %d KB bytecode => %d KB native code\n", int(stats.lines / 1000), int(stats.bytecode / 1024), + int(stats.codegen / 1024)); return failed ? 1 : 0; } diff --git a/CodeGen/include/Luau/AddressA64.h b/CodeGen/include/Luau/AddressA64.h new file mode 100644 index 000000000..351e67151 --- /dev/null +++ b/CodeGen/include/Luau/AddressA64.h @@ -0,0 +1,52 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/RegisterA64.h" + +namespace Luau +{ +namespace CodeGen +{ + +enum class AddressKindA64 : uint8_t +{ + imm, // reg + imm + reg, // reg + reg + + // TODO: + // reg + reg << shift + // reg + sext(reg) << shift + // reg + uext(reg) << shift + // pc + offset +}; + +struct AddressA64 +{ + AddressA64(RegisterA64 base, int off = 0) + : kind(AddressKindA64::imm) + , base(base) + , offset(xzr) + , data(off) + { + LUAU_ASSERT(base.kind == KindA64::x); + LUAU_ASSERT(off >= 0 && off < 4096); + } + + AddressA64(RegisterA64 base, RegisterA64 offset) + : kind(AddressKindA64::reg) + , base(base) + , offset(offset) + , data(0) + { + LUAU_ASSERT(base.kind == KindA64::x); + LUAU_ASSERT(offset.kind == KindA64::x); + } + + AddressKindA64 kind; + RegisterA64 base; + RegisterA64 offset; + int data; +}; + +} // namespace CodeGen +} // namespace Luau diff --git a/CodeGen/include/Luau/AssemblyBuilderA64.h b/CodeGen/include/Luau/AssemblyBuilderA64.h new file mode 100644 index 000000000..9a1402bec --- /dev/null +++ b/CodeGen/include/Luau/AssemblyBuilderA64.h @@ -0,0 +1,144 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/RegisterA64.h" +#include "Luau/AddressA64.h" +#include "Luau/ConditionA64.h" +#include "Luau/Label.h" + +#include +#include + +namespace Luau +{ +namespace CodeGen +{ + +class AssemblyBuilderA64 +{ +public: + explicit AssemblyBuilderA64(bool logText); + ~AssemblyBuilderA64(); + + // Moves + void mov(RegisterA64 dst, RegisterA64 src); + void mov(RegisterA64 dst, uint16_t src, int shift = 0); + void movk(RegisterA64 dst, uint16_t src, int shift = 0); + + // Arithmetics + void add(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, int shift = 0); + void add(RegisterA64 dst, RegisterA64 src1, int src2); + void sub(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, int shift = 0); + void sub(RegisterA64 dst, RegisterA64 src1, int src2); + void neg(RegisterA64 dst, RegisterA64 src); + + // Comparisons + // Note: some arithmetic instructions also have versions that update flags (ADDS etc) but we aren't using them atm + // TODO: add cmp + + // Binary + // Note: shifted-register support and bitfield operations are omitted for simplicity + // TODO: support immediate arguments (they have odd encoding and forbid many values) + // TODO: support not variants for and/or/eor (required to support not...) + void and_(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); + void orr(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); + void eor(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); + void lsl(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); + void lsr(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); + void asr(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); + void ror(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); + void clz(RegisterA64 dst, RegisterA64 src); + void rbit(RegisterA64 dst, RegisterA64 src); + + // Load + // Note: paired loads are currently omitted for simplicity + void ldr(RegisterA64 dst, AddressA64 src); + void ldrb(RegisterA64 dst, AddressA64 src); + void ldrh(RegisterA64 dst, AddressA64 src); + void ldrsb(RegisterA64 dst, AddressA64 src); + void ldrsh(RegisterA64 dst, AddressA64 src); + void ldrsw(RegisterA64 dst, AddressA64 src); + + // Store + void str(RegisterA64 src, AddressA64 dst); + void strb(RegisterA64 src, AddressA64 dst); + void strh(RegisterA64 src, AddressA64 dst); + + // Control flow + // Note: tbz/tbnz are currently not supported because they have 15-bit offsets and we don't support branch thunks + void b(ConditionA64 cond, Label& label); + void cbz(RegisterA64 src, Label& label); + void cbnz(RegisterA64 src, Label& label); + void ret(); + + // Run final checks + bool finalize(); + + // Places a label at current location and returns it + Label setLabel(); + + // Assigns label position to the current location + void setLabel(Label& label); + + void logAppend(const char* fmt, ...) LUAU_PRINTF_ATTR(2, 3); + + uint32_t getCodeSize() const; + + // Resulting data and code that need to be copied over one after the other + // The *end* of 'data' has to be aligned to 16 bytes, this will also align 'code' + std::vector data; + std::vector code; + + std::string text; + + const bool logText = false; + +private: + // Instruction archetypes + void place0(const char* name, uint32_t word); + void placeSR3(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, uint8_t op, int shift = 0); + void placeSR2(const char* name, RegisterA64 dst, RegisterA64 src, uint8_t op); + void placeR3(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, uint8_t op, uint8_t op2); + void placeR1(const char* name, RegisterA64 dst, RegisterA64 src, uint32_t op); + void placeI12(const char* name, RegisterA64 dst, RegisterA64 src1, int src2, uint8_t op); + void placeI16(const char* name, RegisterA64 dst, int src, uint8_t op, int shift = 0); + void placeA(const char* name, RegisterA64 dst, AddressA64 src, uint8_t op, uint8_t size); + void placeBC(const char* name, Label& label, uint8_t op, uint8_t cond); + void placeBR(const char* name, Label& label, uint8_t op, RegisterA64 cond); + + void place(uint32_t word); + void placeLabel(Label& label); + + void commit(); + LUAU_NOINLINE void extend(); + + // Data + size_t allocateData(size_t size, size_t align); + + // Logging of assembly in text form + LUAU_NOINLINE void log(const char* opcode); + LUAU_NOINLINE void log(const char* opcode, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, int shift = 0); + LUAU_NOINLINE void log(const char* opcode, RegisterA64 dst, RegisterA64 src1, int src2); + LUAU_NOINLINE void log(const char* opcode, RegisterA64 dst, RegisterA64 src); + LUAU_NOINLINE void log(const char* opcode, RegisterA64 dst, int src, int shift = 0); + LUAU_NOINLINE void log(const char* opcode, RegisterA64 dst, AddressA64 src); + LUAU_NOINLINE void log(const char* opcode, RegisterA64 src, Label label); + LUAU_NOINLINE void log(const char* opcode, Label label); + LUAU_NOINLINE void log(Label label); + LUAU_NOINLINE void log(RegisterA64 reg); + LUAU_NOINLINE void log(AddressA64 addr); + + uint32_t nextLabel = 1; + std::vector