diff --git a/Analysis/include/Luau/BuiltinDefinitions.h b/Analysis/include/Luau/BuiltinDefinitions.h index 162139581..d44576385 100644 --- a/Analysis/include/Luau/BuiltinDefinitions.h +++ b/Analysis/include/Luau/BuiltinDefinitions.h @@ -16,9 +16,7 @@ struct TypeArena; void registerBuiltinTypes(GlobalTypes& globals); -void registerBuiltinGlobals(TypeChecker& typeChecker, GlobalTypes& globals); -void registerBuiltinGlobals(Frontend& frontend); - +void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeCheckForAutocomplete = false); TypeId makeUnion(TypeArena& arena, std::vector&& types); TypeId makeIntersection(TypeArena& arena, std::vector&& types); diff --git a/Analysis/include/Luau/Frontend.h b/Analysis/include/Luau/Frontend.h index 68ba8ff5d..82251378e 100644 --- a/Analysis/include/Luau/Frontend.h +++ b/Analysis/include/Luau/Frontend.h @@ -8,7 +8,6 @@ #include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/Variant.h" - #include #include #include @@ -36,9 +35,6 @@ struct LoadDefinitionFileResult ModulePtr module; }; -LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, GlobalTypes& globals, ScopePtr targetScope, std::string_view definition, - const std::string& packageName, bool captureComments); - std::optional parseMode(const std::vector& hotcomments); std::vector parsePathExpr(const AstExpr& pathExpr); @@ -55,7 +51,9 @@ std::optional pathExprToModuleName(const ModuleName& currentModuleNa * error when we try during typechecking. */ std::optional pathExprToModuleName(const ModuleName& currentModuleName, const AstExpr& expr); - +// TODO: Deprecate this code path when we move away from the old solver +LoadDefinitionFileResult loadDefinitionFileNoDCR(TypeChecker& typeChecker, GlobalTypes& globals, ScopePtr targetScope, std::string_view definition, + const std::string& packageName, bool captureComments); struct SourceNode { bool hasDirtySourceModule() const @@ -140,10 +138,6 @@ struct Frontend CheckResult check(const ModuleName& name, std::optional optionOverride = {}); // new shininess - // Use 'check' with 'runLintChecks' set to true in FrontendOptions (enabledLintWarnings be set there as well) - LintResult lint_DEPRECATED(const ModuleName& name, std::optional enabledLintWarnings = {}); - LintResult lint_DEPRECATED(const SourceModule& module, std::optional enabledLintWarnings = {}); - bool isDirty(const ModuleName& name, bool forAutocomplete = false) const; void markDirty(const ModuleName& name, std::vector* markedDirty = nullptr); @@ -164,10 +158,11 @@ struct Frontend ScopePtr addEnvironment(const std::string& environmentName); ScopePtr getEnvironmentScope(const std::string& environmentName) const; - void registerBuiltinDefinition(const std::string& name, std::function); + void registerBuiltinDefinition(const std::string& name, std::function); void applyBuiltinDefinitionToEnvironment(const std::string& environmentName, const std::string& definitionName); - LoadDefinitionFileResult loadDefinitionFile(std::string_view source, const std::string& packageName, bool captureComments); + LoadDefinitionFileResult loadDefinitionFile(GlobalTypes& globals, ScopePtr targetScope, std::string_view source, const std::string& packageName, + bool captureComments, bool typeCheckForAutocomplete = false); private: ModulePtr check(const SourceModule& sourceModule, Mode mode, std::vector requireCycles, bool forAutocomplete = false, bool recordJsonLog = false); @@ -182,7 +177,7 @@ struct Frontend ScopePtr getModuleEnvironment(const SourceModule& module, const Config& config, bool forAutocomplete) const; std::unordered_map environments; - std::unordered_map> builtinDefinitions; + std::unordered_map> builtinDefinitions; BuiltinTypes builtinTypes_; diff --git a/Analysis/include/Luau/Type.h b/Analysis/include/Luau/Type.h index dba2a8de2..cff86df42 100644 --- a/Analysis/include/Luau/Type.h +++ b/Analysis/include/Luau/Type.h @@ -75,13 +75,44 @@ using TypeId = const Type*; using Name = std::string; // A free type var is one whose exact shape has yet to be fully determined. -using FreeType = Unifiable::Free; +struct FreeType +{ + explicit FreeType(TypeLevel level); + explicit FreeType(Scope* scope); + FreeType(Scope* scope, TypeLevel level); -// When a free type var is unified with any other, it is then "bound" -// to that type var, indicating that the two types are actually the same type. -using BoundType = Unifiable::Bound; + int index; + TypeLevel level; + Scope* scope = nullptr; + + // True if this free type variable is part of a mutually + // recursive type alias whose definitions haven't been + // resolved yet. + bool forwardedTypeAlias = false; +}; + +struct GenericType +{ + // By default, generics are global, with a synthetic name + GenericType(); -using GenericType = Unifiable::Generic; + explicit GenericType(TypeLevel level); + explicit GenericType(const Name& name); + explicit GenericType(Scope* scope); + + GenericType(TypeLevel level, const Name& name); + GenericType(Scope* scope, const Name& name); + + int index; + TypeLevel level; + Scope* scope = nullptr; + Name name; + bool explicitName = false; +}; + +// When an equality constraint is found, it is then "bound" to that type, +// indicating that the two types are actually the same type. +using BoundType = Unifiable::Bound; using Tags = std::vector; @@ -395,9 +426,11 @@ struct TableType // Represents a metatable attached to a table type. Somewhat analogous to a bound type. struct MetatableType { - // Always points to a TableType. + // Should always be a TableType. TypeId table; - // Always points to either a TableType or a MetatableType. + // Should almost always either be a TableType or another MetatableType, + // though it is possible for other types (like AnyType and ErrorType) to + // find their way here sometimes. TypeId metatable; std::optional syntheticName; @@ -536,8 +569,8 @@ struct NegationType using ErrorType = Unifiable::Error; -using TypeVariant = Unifiable::Variant; +using TypeVariant = Unifiable::Variant; struct Type final { diff --git a/Analysis/include/Luau/TypePack.h b/Analysis/include/Luau/TypePack.h index 4831f2338..2ae56e5f0 100644 --- a/Analysis/include/Luau/TypePack.h +++ b/Analysis/include/Luau/TypePack.h @@ -12,20 +12,48 @@ namespace Luau { struct TypeArena; +struct TxnLog; struct TypePack; struct VariadicTypePack; struct BlockedTypePack; struct TypePackVar; +using TypePackId = const TypePackVar*; -struct TxnLog; +struct FreeTypePack +{ + explicit FreeTypePack(TypeLevel level); + explicit FreeTypePack(Scope* scope); + FreeTypePack(Scope* scope, TypeLevel level); + + int index; + TypeLevel level; + Scope* scope = nullptr; +}; + +struct GenericTypePack +{ + // By default, generics are global, with a synthetic name + GenericTypePack(); + explicit GenericTypePack(TypeLevel level); + explicit GenericTypePack(const Name& name); + explicit GenericTypePack(Scope* scope); + GenericTypePack(TypeLevel level, const Name& name); + GenericTypePack(Scope* scope, const Name& name); + + int index; + TypeLevel level; + Scope* scope = nullptr; + Name name; + bool explicitName = false; +}; -using TypePackId = const TypePackVar*; -using FreeTypePack = Unifiable::Free; using BoundTypePack = Unifiable::Bound; -using GenericTypePack = Unifiable::Generic; -using TypePackVariant = Unifiable::Variant; + +using ErrorTypePack = Unifiable::Error; + +using TypePackVariant = Unifiable::Variant; /* A TypePack is a rope-like string of TypeIds. We use this structure to encode * notions like packs of unknown length and packs of any length, as well as more diff --git a/Analysis/include/Luau/Unifiable.h b/Analysis/include/Luau/Unifiable.h index ae55f3734..79b3b7dea 100644 --- a/Analysis/include/Luau/Unifiable.h +++ b/Analysis/include/Luau/Unifiable.h @@ -83,24 +83,6 @@ using Name = std::string; int freshIndex(); -struct Free -{ - explicit Free(TypeLevel level); - explicit Free(Scope* scope); - explicit Free(Scope* scope, TypeLevel level); - - int index; - TypeLevel level; - Scope* scope = nullptr; - // True if this free type variable is part of a mutually - // recursive type alias whose definitions haven't been - // resolved yet. - bool forwardedTypeAlias = false; - -private: - static int DEPRECATED_nextIndex; -}; - template struct Bound { @@ -112,26 +94,6 @@ struct Bound Id boundTo; }; -struct Generic -{ - // By default, generics are global, with a synthetic name - Generic(); - explicit Generic(TypeLevel level); - explicit Generic(const Name& name); - explicit Generic(Scope* scope); - Generic(TypeLevel level, const Name& name); - Generic(Scope* scope, const Name& name); - - int index; - TypeLevel level; - Scope* scope = nullptr; - Name name; - bool explicitName = false; - -private: - static int DEPRECATED_nextIndex; -}; - struct Error { // This constructor has to be public, since it's used in Type and TypePack, @@ -145,6 +107,6 @@ struct Error }; template -using Variant = Luau::Variant, Generic, Error, Value...>; +using Variant = Luau::Variant, Error, Value...>; } // namespace Luau::Unifiable diff --git a/Analysis/include/Luau/VisitType.h b/Analysis/include/Luau/VisitType.h index ff4dfc3c3..95b2b0507 100644 --- a/Analysis/include/Luau/VisitType.h +++ b/Analysis/include/Luau/VisitType.h @@ -341,10 +341,10 @@ struct GenericTypeVisitor traverse(btv->boundTo); } - else if (auto ftv = get(tp)) + else if (auto ftv = get(tp)) visit(tp, *ftv); - else if (auto gtv = get(tp)) + else if (auto gtv = get(tp)) visit(tp, *gtv); else if (auto etv = get(tp)) diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 3fdd93190..42fc9a717 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -13,8 +13,6 @@ #include #include -LUAU_FASTFLAGVARIABLE(LuauAutocompleteSkipNormalization, false); - static const std::unordered_set kStatementStartingKeywords = { "while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"}; @@ -143,12 +141,9 @@ static bool checkTypeMatch(TypeId subTy, TypeId superTy, NotNull scope, T Normalizer normalizer{typeArena, builtinTypes, NotNull{&unifierState}}; Unifier unifier(NotNull{&normalizer}, Mode::Strict, scope, Location(), Variance::Covariant); - if (FFlag::LuauAutocompleteSkipNormalization) - { - // Cost of normalization can be too high for autocomplete response time requirements - unifier.normalize = false; - unifier.checkInhabited = false; - } + // Cost of normalization can be too high for autocomplete response time requirements + unifier.normalize = false; + unifier.checkInhabited = false; return unifier.canUnify(subTy, superTy).empty(); } diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index 2108b160f..7ed92fb41 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -212,7 +212,7 @@ void registerBuiltinTypes(GlobalTypes& globals) globals.globalScope->addBuiltinTypeBinding("never", TypeFun{{}, globals.builtinTypes->neverType}); } -void registerBuiltinGlobals(TypeChecker& typeChecker, GlobalTypes& globals) +void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeCheckForAutocomplete) { LUAU_ASSERT(!globals.globalTypes.types.isFrozen()); LUAU_ASSERT(!globals.globalTypes.typePacks.isFrozen()); @@ -220,8 +220,8 @@ void registerBuiltinGlobals(TypeChecker& typeChecker, GlobalTypes& globals) TypeArena& arena = globals.globalTypes; NotNull builtinTypes = globals.builtinTypes; - LoadDefinitionFileResult loadResult = - Luau::loadDefinitionFile(typeChecker, globals, globals.globalScope, getBuiltinDefinitionSource(), "@luau", /* captureComments */ false); + LoadDefinitionFileResult loadResult = frontend.loadDefinitionFile( + globals, globals.globalScope, getBuiltinDefinitionSource(), "@luau", /* captureComments */ false, typeCheckForAutocomplete); LUAU_ASSERT(loadResult.success); TypeId genericK = arena.addType(GenericType{"K"}); @@ -309,106 +309,6 @@ void registerBuiltinGlobals(TypeChecker& typeChecker, GlobalTypes& globals) attachDcrMagicFunction(getGlobalBinding(globals, "require"), dcrMagicFunctionRequire); } -void registerBuiltinGlobals(Frontend& frontend) -{ - GlobalTypes& globals = frontend.globals; - - LUAU_ASSERT(!globals.globalTypes.types.isFrozen()); - LUAU_ASSERT(!globals.globalTypes.typePacks.isFrozen()); - - registerBuiltinTypes(globals); - - TypeArena& arena = globals.globalTypes; - NotNull builtinTypes = globals.builtinTypes; - - LoadDefinitionFileResult loadResult = frontend.loadDefinitionFile(getBuiltinDefinitionSource(), "@luau", /* captureComments */ false); - LUAU_ASSERT(loadResult.success); - - TypeId genericK = arena.addType(GenericType{"K"}); - TypeId genericV = arena.addType(GenericType{"V"}); - TypeId mapOfKtoV = arena.addType(TableType{{}, TableIndexer(genericK, genericV), globals.globalScope->level, TableState::Generic}); - - std::optional stringMetatableTy = getMetatable(builtinTypes->stringType, builtinTypes); - LUAU_ASSERT(stringMetatableTy); - const TableType* stringMetatableTable = get(follow(*stringMetatableTy)); - LUAU_ASSERT(stringMetatableTable); - - auto it = stringMetatableTable->props.find("__index"); - LUAU_ASSERT(it != stringMetatableTable->props.end()); - - addGlobalBinding(globals, "string", it->second.type, "@luau"); - - // next(t: Table, i: K?) -> (K?, V) - TypePackId nextArgsTypePack = arena.addTypePack(TypePack{{mapOfKtoV, makeOption(builtinTypes, arena, genericK)}}); - TypePackId nextRetsTypePack = arena.addTypePack(TypePack{{makeOption(builtinTypes, arena, genericK), genericV}}); - addGlobalBinding(globals, "next", arena.addType(FunctionType{{genericK, genericV}, {}, nextArgsTypePack, nextRetsTypePack}), "@luau"); - - TypePackId pairsArgsTypePack = arena.addTypePack({mapOfKtoV}); - - TypeId pairsNext = arena.addType(FunctionType{nextArgsTypePack, nextRetsTypePack}); - TypePackId pairsReturnTypePack = arena.addTypePack(TypePack{{pairsNext, mapOfKtoV, builtinTypes->nilType}}); - - // pairs(t: Table) -> ((Table, K?) -> (K?, V), Table, nil) - addGlobalBinding(globals, "pairs", arena.addType(FunctionType{{genericK, genericV}, {}, pairsArgsTypePack, pairsReturnTypePack}), "@luau"); - - TypeId genericMT = arena.addType(GenericType{"MT"}); - - TableType tab{TableState::Generic, globals.globalScope->level}; - TypeId tabTy = arena.addType(tab); - - TypeId tableMetaMT = arena.addType(MetatableType{tabTy, genericMT}); - - addGlobalBinding(globals, "getmetatable", makeFunction(arena, std::nullopt, {genericMT}, {}, {tableMetaMT}, {genericMT}), "@luau"); - - // clang-format off - // setmetatable(T, MT) -> { @metatable MT, T } - addGlobalBinding(globals, "setmetatable", - arena.addType( - FunctionType{ - {genericMT}, - {}, - arena.addTypePack(TypePack{{tabTy, genericMT}}), - arena.addTypePack(TypePack{{tableMetaMT}}) - } - ), "@luau" - ); - // clang-format on - - for (const auto& pair : globals.globalScope->bindings) - { - persist(pair.second.typeId); - - if (TableType* ttv = getMutable(pair.second.typeId)) - { - if (!ttv->name) - ttv->name = "typeof(" + toString(pair.first) + ")"; - } - } - - attachMagicFunction(getGlobalBinding(globals, "assert"), magicFunctionAssert); - attachMagicFunction(getGlobalBinding(globals, "setmetatable"), magicFunctionSetMetaTable); - attachMagicFunction(getGlobalBinding(globals, "select"), magicFunctionSelect); - attachDcrMagicFunction(getGlobalBinding(globals, "select"), dcrMagicFunctionSelect); - - if (TableType* ttv = getMutable(getGlobalBinding(globals, "table"))) - { - // tabTy is a generic table type which we can't express via declaration syntax yet - ttv->props["freeze"] = makeProperty(makeFunction(arena, std::nullopt, {tabTy}, {tabTy}), "@luau/global/table.freeze"); - ttv->props["clone"] = makeProperty(makeFunction(arena, std::nullopt, {tabTy}, {tabTy}), "@luau/global/table.clone"); - - ttv->props["getn"].deprecated = true; - ttv->props["getn"].deprecatedSuggestion = "#"; - ttv->props["foreach"].deprecated = true; - ttv->props["foreachi"].deprecated = true; - - attachMagicFunction(ttv->props["pack"].type, magicFunctionPack); - attachDcrMagicFunction(ttv->props["pack"].type, dcrMagicFunctionPack); - } - - attachMagicFunction(getGlobalBinding(globals, "require"), magicFunctionRequire); - attachDcrMagicFunction(getGlobalBinding(globals, "require"), dcrMagicFunctionRequire); -} - static std::optional> magicFunctionSelect( TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) { diff --git a/Analysis/src/Clone.cpp b/Analysis/src/Clone.cpp index 2645209d5..ac73622d3 100644 --- a/Analysis/src/Clone.cpp +++ b/Analysis/src/Clone.cpp @@ -44,10 +44,10 @@ struct TypeCloner template void defaultClone(const T& t); - void operator()(const Unifiable::Free& t); - void operator()(const Unifiable::Generic& t); - void operator()(const Unifiable::Bound& t); - void operator()(const Unifiable::Error& t); + void operator()(const FreeType& t); + void operator()(const GenericType& t); + void operator()(const BoundType& t); + void operator()(const ErrorType& t); void operator()(const BlockedType& t); void operator()(const PendingExpansionType& t); void operator()(const PrimitiveType& t); @@ -89,15 +89,15 @@ struct TypePackCloner seenTypePacks[typePackId] = cloned; } - void operator()(const Unifiable::Free& t) + void operator()(const FreeTypePack& t) { defaultClone(t); } - void operator()(const Unifiable::Generic& t) + void operator()(const GenericTypePack& t) { defaultClone(t); } - void operator()(const Unifiable::Error& t) + void operator()(const ErrorTypePack& t) { defaultClone(t); } @@ -145,12 +145,12 @@ void TypeCloner::defaultClone(const T& t) seenTypes[typeId] = cloned; } -void TypeCloner::operator()(const Unifiable::Free& t) +void TypeCloner::operator()(const FreeType& t) { defaultClone(t); } -void TypeCloner::operator()(const Unifiable::Generic& t) +void TypeCloner::operator()(const GenericType& t) { defaultClone(t); } diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 191e94f4d..98022d862 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -29,7 +29,6 @@ LUAU_FASTINT(LuauTypeInferRecursionLimit) LUAU_FASTINT(LuauTarjanChildLimit) LUAU_FASTFLAG(LuauInferInNoCheckMode) LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false) -LUAU_FASTFLAGVARIABLE(LuauLintInTypecheck, false) LUAU_FASTINTVARIABLE(LuauAutocompleteCheckTimeoutMs, 100) LUAU_FASTFLAGVARIABLE(DebugLuauDeferredConstraintResolution, false) LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson, false); @@ -84,32 +83,20 @@ static void generateDocumentationSymbols(TypeId ty, const std::string& rootName) } } -LoadDefinitionFileResult Frontend::loadDefinitionFile(std::string_view source, const std::string& packageName, bool captureComments) +static ParseResult parseSourceForModule(std::string_view source, Luau::SourceModule& sourceModule, bool captureComments) { - if (!FFlag::DebugLuauDeferredConstraintResolution) - return Luau::loadDefinitionFile(typeChecker, globals, globals.globalScope, source, packageName, captureComments); - - LUAU_TIMETRACE_SCOPE("loadDefinitionFile", "Frontend"); - - Luau::SourceModule sourceModule; - ParseOptions options; options.allowDeclarationSyntax = true; options.captureComments = captureComments; Luau::ParseResult parseResult = Luau::Parser::parse(source.data(), source.size(), *sourceModule.names, *sourceModule.allocator, options); - - if (parseResult.errors.size() > 0) - return LoadDefinitionFileResult{false, parseResult, sourceModule, nullptr}; - sourceModule.root = parseResult.root; sourceModule.mode = Mode::Definition; + return parseResult; +} - ModulePtr checkedModule = check(sourceModule, Mode::Definition, {}); - - if (checkedModule->errors.size() > 0) - return LoadDefinitionFileResult{false, parseResult, sourceModule, checkedModule}; - +static void persistCheckedTypes(ModulePtr checkedModule, GlobalTypes& globals, ScopePtr targetScope, const std::string& packageName) +{ CloneState cloneState; std::vector typesToPersist; @@ -120,7 +107,7 @@ LoadDefinitionFileResult Frontend::loadDefinitionFile(std::string_view source, c TypeId globalTy = clone(ty, globals.globalTypes, cloneState); std::string documentationSymbol = packageName + "/global/" + name; generateDocumentationSymbols(globalTy, documentationSymbol); - globals.globalScope->bindings[globals.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol}; + targetScope->bindings[globals.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol}; typesToPersist.push_back(globalTy); } @@ -130,7 +117,7 @@ LoadDefinitionFileResult Frontend::loadDefinitionFile(std::string_view source, c TypeFun globalTy = clone(ty, globals.globalTypes, cloneState); std::string documentationSymbol = packageName + "/globaltype/" + name; generateDocumentationSymbols(globalTy.type, documentationSymbol); - globals.globalScope->exportedTypeBindings[name] = globalTy; + targetScope->exportedTypeBindings[name] = globalTy; typesToPersist.push_back(globalTy.type); } @@ -139,63 +126,49 @@ LoadDefinitionFileResult Frontend::loadDefinitionFile(std::string_view source, c { persist(ty); } - - return LoadDefinitionFileResult{true, parseResult, sourceModule, checkedModule}; } -LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, GlobalTypes& globals, ScopePtr targetScope, std::string_view source, - const std::string& packageName, bool captureComments) +LoadDefinitionFileResult Frontend::loadDefinitionFile(GlobalTypes& globals, ScopePtr targetScope, std::string_view source, + const std::string& packageName, bool captureComments, bool typeCheckForAutocomplete) { + if (!FFlag::DebugLuauDeferredConstraintResolution) + return Luau::loadDefinitionFileNoDCR(typeCheckForAutocomplete ? typeCheckerForAutocomplete : typeChecker, + typeCheckForAutocomplete ? globalsForAutocomplete : globals, targetScope, source, packageName, captureComments); + LUAU_TIMETRACE_SCOPE("loadDefinitionFile", "Frontend"); Luau::SourceModule sourceModule; - - ParseOptions options; - options.allowDeclarationSyntax = true; - options.captureComments = captureComments; - - Luau::ParseResult parseResult = Luau::Parser::parse(source.data(), source.size(), *sourceModule.names, *sourceModule.allocator, options); - + Luau::ParseResult parseResult = parseSourceForModule(source, sourceModule, captureComments); if (parseResult.errors.size() > 0) return LoadDefinitionFileResult{false, parseResult, sourceModule, nullptr}; - sourceModule.root = parseResult.root; - sourceModule.mode = Mode::Definition; - - ModulePtr checkedModule = typeChecker.check(sourceModule, Mode::Definition); + ModulePtr checkedModule = check(sourceModule, Mode::Definition, {}); if (checkedModule->errors.size() > 0) return LoadDefinitionFileResult{false, parseResult, sourceModule, checkedModule}; - CloneState cloneState; + persistCheckedTypes(checkedModule, globals, targetScope, packageName); - std::vector typesToPersist; - typesToPersist.reserve(checkedModule->declaredGlobals.size() + checkedModule->exportedTypeBindings.size()); + return LoadDefinitionFileResult{true, parseResult, sourceModule, checkedModule}; +} - for (const auto& [name, ty] : checkedModule->declaredGlobals) - { - TypeId globalTy = clone(ty, globals.globalTypes, cloneState); - std::string documentationSymbol = packageName + "/global/" + name; - generateDocumentationSymbols(globalTy, documentationSymbol); - targetScope->bindings[globals.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol}; +LoadDefinitionFileResult loadDefinitionFileNoDCR(TypeChecker& typeChecker, GlobalTypes& globals, ScopePtr targetScope, std::string_view source, + const std::string& packageName, bool captureComments) +{ + LUAU_TIMETRACE_SCOPE("loadDefinitionFile", "Frontend"); - typesToPersist.push_back(globalTy); - } + Luau::SourceModule sourceModule; + Luau::ParseResult parseResult = parseSourceForModule(source, sourceModule, captureComments); - for (const auto& [name, ty] : checkedModule->exportedTypeBindings) - { - TypeFun globalTy = clone(ty, globals.globalTypes, cloneState); - std::string documentationSymbol = packageName + "/globaltype/" + name; - generateDocumentationSymbols(globalTy.type, documentationSymbol); - targetScope->exportedTypeBindings[name] = globalTy; + if (parseResult.errors.size() > 0) + return LoadDefinitionFileResult{false, parseResult, sourceModule, nullptr}; - typesToPersist.push_back(globalTy.type); - } + ModulePtr checkedModule = typeChecker.check(sourceModule, Mode::Definition); - for (TypeId ty : typesToPersist) - { - persist(ty); - } + if (checkedModule->errors.size() > 0) + return LoadDefinitionFileResult{false, parseResult, sourceModule, checkedModule}; + + persistCheckedTypes(checkedModule, globals, targetScope, packageName); return LoadDefinitionFileResult{true, parseResult, sourceModule, checkedModule}; } @@ -316,8 +289,6 @@ static ErrorVec accumulateErrors( static void filterLintOptions(LintOptions& lintOptions, const std::vector& hotcomments, Mode mode) { - LUAU_ASSERT(FFlag::LuauLintInTypecheck); - uint64_t ignoreLints = LintWarning::parseMask(hotcomments); lintOptions.warningMask &= ~ignoreLints; @@ -472,24 +443,16 @@ CheckResult Frontend::check(const ModuleName& name, std::optional& modules = - frontendOptions.forAutocomplete ? moduleResolverForAutocomplete.modules : moduleResolver.modules; + std::unordered_map& modules = + frontendOptions.forAutocomplete ? moduleResolverForAutocomplete.modules : moduleResolver.modules; - checkResult.errors = accumulateErrors(sourceNodes, modules, name); + checkResult.errors = accumulateErrors(sourceNodes, modules, name); - // Get lint result only for top checked module - if (auto it = modules.find(name); it != modules.end()) - checkResult.lintResult = it->second->lintResult; + // Get lint result only for top checked module + if (auto it = modules.find(name); it != modules.end()) + checkResult.lintResult = it->second->lintResult; - return checkResult; - } - else - { - return CheckResult{accumulateErrors( - sourceNodes, frontendOptions.forAutocomplete ? moduleResolverForAutocomplete.modules : moduleResolver.modules, name)}; - } + return checkResult; } std::vector buildQueue; @@ -553,9 +516,10 @@ CheckResult Frontend::check(const ModuleName& name, std::optional& modules = - frontendOptions.forAutocomplete ? moduleResolverForAutocomplete.modules : moduleResolver.modules; + // Get lint result only for top checked module + std::unordered_map& modules = + frontendOptions.forAutocomplete ? moduleResolverForAutocomplete.modules : moduleResolver.modules; - if (auto it = modules.find(name); it != modules.end()) - checkResult.lintResult = it->second->lintResult; - } + if (auto it = modules.find(name); it != modules.end()) + checkResult.lintResult = it->second->lintResult; return checkResult; } @@ -800,59 +759,6 @@ ScopePtr Frontend::getModuleEnvironment(const SourceModule& module, const Config return result; } -LintResult Frontend::lint_DEPRECATED(const ModuleName& name, std::optional enabledLintWarnings) -{ - LUAU_ASSERT(!FFlag::LuauLintInTypecheck); - - LUAU_TIMETRACE_SCOPE("Frontend::lint", "Frontend"); - LUAU_TIMETRACE_ARGUMENT("name", name.c_str()); - - auto [_sourceNode, sourceModule] = getSourceNode(name); - - if (!sourceModule) - return LintResult{}; // FIXME: We really should do something a bit more obvious when a file is too broken to lint. - - return lint_DEPRECATED(*sourceModule, enabledLintWarnings); -} - -LintResult Frontend::lint_DEPRECATED(const SourceModule& module, std::optional enabledLintWarnings) -{ - LUAU_ASSERT(!FFlag::LuauLintInTypecheck); - - LUAU_TIMETRACE_SCOPE("Frontend::lint", "Frontend"); - LUAU_TIMETRACE_ARGUMENT("module", module.name.c_str()); - - const Config& config = configResolver->getConfig(module.name); - - uint64_t ignoreLints = LintWarning::parseMask(module.hotcomments); - - LintOptions options = enabledLintWarnings.value_or(config.enabledLint); - options.warningMask &= ~ignoreLints; - - Mode mode = module.mode.value_or(config.mode); - if (mode != Mode::NoCheck) - { - options.disableWarning(Luau::LintWarning::Code_UnknownGlobal); - } - - if (mode == Mode::Strict) - { - options.disableWarning(Luau::LintWarning::Code_ImplicitReturn); - } - - ScopePtr environmentScope = getModuleEnvironment(module, config, /*forAutocomplete*/ false); - - ModulePtr modulePtr = moduleResolver.getModule(module.name); - - double timestamp = getTimestamp(); - - std::vector warnings = Luau::lint(module.root, *module.names, environmentScope, modulePtr.get(), module.hotcomments, options); - - stats.timeLint += getTimestamp() - timestamp; - - return classifyLints(warnings, config); -} - bool Frontend::isDirty(const ModuleName& name, bool forAutocomplete) const { auto it = sourceNodes.find(name); @@ -1195,7 +1101,7 @@ ScopePtr Frontend::getEnvironmentScope(const std::string& environmentName) const return {}; } -void Frontend::registerBuiltinDefinition(const std::string& name, std::function applicator) +void Frontend::registerBuiltinDefinition(const std::string& name, std::function applicator) { LUAU_ASSERT(builtinDefinitions.count(name) == 0); @@ -1208,7 +1114,7 @@ void Frontend::applyBuiltinDefinitionToEnvironment(const std::string& environmen LUAU_ASSERT(builtinDefinitions.count(definitionName) > 0); if (builtinDefinitions.count(definitionName) > 0) - builtinDefinitions[definitionName](typeChecker, globals, getEnvironmentScope(environmentName)); + builtinDefinitions[definitionName](*this, globals, getEnvironmentScope(environmentName)); } LintResult Frontend::classifyLints(const std::vector& warnings, const Config& config) diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index 7c56a4b8f..46595b702 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -20,6 +20,7 @@ LUAU_FASTINTVARIABLE(LuauNormalizeCacheLimit, 100000); LUAU_FASTFLAGVARIABLE(LuauNegatedClassTypes, false); LUAU_FASTFLAGVARIABLE(LuauNegatedTableTypes, false); LUAU_FASTFLAGVARIABLE(LuauNormalizeBlockedTypes, false); +LUAU_FASTFLAGVARIABLE(LuauNormalizeMetatableFixes, false); LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAG(LuauUninhabitedSubAnything2) LUAU_FASTFLAG(LuauTransitiveSubtyping) @@ -2062,6 +2063,18 @@ std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there else if (isPrim(there, PrimitiveType::Table)) return here; + if (FFlag::LuauNormalizeMetatableFixes) + { + if (get(here)) + return there; + else if (get(there)) + return here; + else if (get(here)) + return there; + else if (get(there)) + return here; + } + TypeId htable = here; TypeId hmtable = nullptr; if (const MetatableType* hmtv = get(here)) @@ -2078,9 +2091,23 @@ std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there } const TableType* httv = get(htable); - LUAU_ASSERT(httv); + if (FFlag::LuauNormalizeMetatableFixes) + { + if (!httv) + return std::nullopt; + } + else + LUAU_ASSERT(httv); + const TableType* tttv = get(ttable); - LUAU_ASSERT(tttv); + if (FFlag::LuauNormalizeMetatableFixes) + { + if (!tttv) + return std::nullopt; + } + else + LUAU_ASSERT(tttv); + if (httv->state == TableState::Free || tttv->state == TableState::Free) return std::nullopt; diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 5c0f48fae..fe09ef11a 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -14,7 +14,6 @@ #include LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) -LUAU_FASTFLAGVARIABLE(LuauFunctionReturnStringificationFixup, false) /* * Prefix generic typenames with gen- @@ -369,7 +368,7 @@ struct TypeStringifier state.emit(">"); } - void operator()(TypeId ty, const Unifiable::Free& ftv) + void operator()(TypeId ty, const FreeType& ftv) { state.result.invalid = true; if (FFlag::DebugLuauVerboseTypeNames) diff --git a/Analysis/src/Type.cpp b/Analysis/src/Type.cpp index 021d95285..d70f17f57 100644 --- a/Analysis/src/Type.cpp +++ b/Analysis/src/Type.cpp @@ -430,6 +430,69 @@ bool hasLength(TypeId ty, DenseHashSet& seen, int* recursionCount) return false; } +FreeType::FreeType(TypeLevel level) + : index(Unifiable::freshIndex()) + , level(level) + , scope(nullptr) +{ +} + +FreeType::FreeType(Scope* scope) + : index(Unifiable::freshIndex()) + , level{} + , scope(scope) +{ +} + +FreeType::FreeType(Scope* scope, TypeLevel level) + : index(Unifiable::freshIndex()) + , level(level) + , scope(scope) +{ +} + +GenericType::GenericType() + : index(Unifiable::freshIndex()) + , name("g" + std::to_string(index)) +{ +} + +GenericType::GenericType(TypeLevel level) + : index(Unifiable::freshIndex()) + , level(level) + , name("g" + std::to_string(index)) +{ +} + +GenericType::GenericType(const Name& name) + : index(Unifiable::freshIndex()) + , name(name) + , explicitName(true) +{ +} + +GenericType::GenericType(Scope* scope) + : index(Unifiable::freshIndex()) + , scope(scope) +{ +} + +GenericType::GenericType(TypeLevel level, const Name& name) + : index(Unifiable::freshIndex()) + , level(level) + , name(name) + , explicitName(true) +{ +} + +GenericType::GenericType(Scope* scope, const Name& name) + : index(Unifiable::freshIndex()) + , scope(scope) + , name(name) + , explicitName(true) +{ +} + BlockedType::BlockedType() : index(FFlag::LuauNormalizeBlockedTypes ? Unifiable::freshIndex() : ++DEPRECATED_nextIndex) { @@ -971,7 +1034,7 @@ const TypeLevel* getLevel(TypeId ty) { ty = follow(ty); - if (auto ftv = get(ty)) + if (auto ftv = get(ty)) return &ftv->level; else if (auto ttv = get(ty)) return &ttv->level; @@ -990,7 +1053,7 @@ std::optional getLevel(TypePackId tp) { tp = follow(tp); - if (auto ftv = get(tp)) + if (auto ftv = get(tp)) return ftv->level; else return std::nullopt; diff --git a/Analysis/src/TypeAttach.cpp b/Analysis/src/TypeAttach.cpp index f9a162056..d6494edfd 100644 --- a/Analysis/src/TypeAttach.cpp +++ b/Analysis/src/TypeAttach.cpp @@ -35,7 +35,21 @@ using SyntheticNames = std::unordered_map; namespace Luau { -static const char* getName(Allocator* allocator, SyntheticNames* syntheticNames, const Unifiable::Generic& gen) +static const char* getName(Allocator* allocator, SyntheticNames* syntheticNames, const GenericType& gen) +{ + size_t s = syntheticNames->size(); + char*& n = (*syntheticNames)[&gen]; + if (!n) + { + std::string str = gen.explicitName ? gen.name : generateName(s); + n = static_cast(allocator->allocate(str.size() + 1)); + strcpy(n, str.c_str()); + } + + return n; +} + +static const char* getName(Allocator* allocator, SyntheticNames* syntheticNames, const GenericTypePack& gen) { size_t s = syntheticNames->size(); char*& n = (*syntheticNames)[&gen]; @@ -237,7 +251,7 @@ class TypeRehydrationVisitor size_t numGenericPacks = 0; for (auto it = ftv.genericPacks.begin(); it != ftv.genericPacks.end(); ++it) { - if (auto gtv = get(*it)) + if (auto gtv = get(*it)) genericPacks.data[numGenericPacks++] = {AstName(gtv->name.c_str()), Location(), nullptr}; } diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index f47815588..acf70fec1 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -1020,7 +1020,7 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatAssign& assig right = errorRecoveryType(scope); else if (auto vtp = get(tailPack)) right = vtp->ty; - else if (get(tailPack)) + else if (get(tailPack)) { *asMutable(tailPack) = TypePack{{left}}; growingPack = getMutable(tailPack); @@ -1281,7 +1281,7 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) callRetPack = checkExprPack(scope, *exprCall).type; callRetPack = follow(callRetPack); - if (get(callRetPack)) + if (get(callRetPack)) { iterTy = freshType(scope); unify(callRetPack, addTypePack({{iterTy}, freshTypePack(scope)}), scope, forin.location); @@ -1951,7 +1951,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp return WithPredicate{errorRecoveryType(scope)}; else if (auto vtp = get(varargPack)) return WithPredicate{vtp->ty}; - else if (get(varargPack)) + else if (get(varargPack)) { // TODO: Better error? reportError(expr.location, GenericError{"Trying to get a type from a variadic type parameter"}); @@ -1970,7 +1970,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp { return {pack->head.empty() ? nilType : pack->head[0], std::move(result.predicates)}; } - else if (const FreeTypePack* ftp = get(retPack)) + else if (const FreeTypePack* ftp = get(retPack)) { TypeId head = freshType(scope->level); TypePackId pack = addTypePack(TypePackVar{TypePack{{head}, freshTypePack(scope->level)}}); @@ -1981,7 +1981,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp return {errorRecoveryType(scope), std::move(result.predicates)}; else if (auto vtp = get(retPack)) return {vtp->ty, std::move(result.predicates)}; - else if (get(retPack)) + else if (get(retPack)) { if (FFlag::LuauReturnAnyInsteadOfICE) return {anyType, std::move(result.predicates)}; @@ -3838,7 +3838,7 @@ void TypeChecker::checkArgumentList(const ScopePtr& scope, const AstExpr& funNam if (argTail) { - if (state.log.getMutable(state.log.follow(*argTail))) + if (state.log.getMutable(state.log.follow(*argTail))) { if (paramTail) state.tryUnify(*paramTail, *argTail); @@ -3853,7 +3853,7 @@ void TypeChecker::checkArgumentList(const ScopePtr& scope, const AstExpr& funNam else if (paramTail) { // argTail is definitely empty - if (state.log.getMutable(state.log.follow(*paramTail))) + if (state.log.getMutable(state.log.follow(*paramTail))) state.log.replace(*paramTail, TypePackVar(TypePack{{}})); } @@ -5570,7 +5570,7 @@ GenericTypeDefinitions TypeChecker::createGenericTypes(const ScopePtr& scope, st } else { - g = addType(Unifiable::Generic{level, n}); + g = addType(GenericType{level, n}); } generics.push_back({g, defaultValue}); @@ -5598,7 +5598,7 @@ GenericTypeDefinitions TypeChecker::createGenericTypes(const ScopePtr& scope, st TypePackId& cached = scope->parent->typeAliasTypePackParameters[n]; if (!cached) - cached = addTypePack(TypePackVar{Unifiable::Generic{level, n}}); + cached = addTypePack(TypePackVar{GenericTypePack{level, n}}); genericPacks.push_back({cached, defaultValue}); scope->privateTypePackBindings[n] = cached; diff --git a/Analysis/src/TypePack.cpp b/Analysis/src/TypePack.cpp index ccea604ff..6873820a7 100644 --- a/Analysis/src/TypePack.cpp +++ b/Analysis/src/TypePack.cpp @@ -9,6 +9,69 @@ namespace Luau { +FreeTypePack::FreeTypePack(TypeLevel level) + : index(Unifiable::freshIndex()) + , level(level) + , scope(nullptr) +{ +} + +FreeTypePack::FreeTypePack(Scope* scope) + : index(Unifiable::freshIndex()) + , level{} + , scope(scope) +{ +} + +FreeTypePack::FreeTypePack(Scope* scope, TypeLevel level) + : index(Unifiable::freshIndex()) + , level(level) + , scope(scope) +{ +} + +GenericTypePack::GenericTypePack() + : index(Unifiable::freshIndex()) + , name("g" + std::to_string(index)) +{ +} + +GenericTypePack::GenericTypePack(TypeLevel level) + : index(Unifiable::freshIndex()) + , level(level) + , name("g" + std::to_string(index)) +{ +} + +GenericTypePack::GenericTypePack(const Name& name) + : index(Unifiable::freshIndex()) + , name(name) + , explicitName(true) +{ +} + +GenericTypePack::GenericTypePack(Scope* scope) + : index(Unifiable::freshIndex()) + , scope(scope) +{ +} + +GenericTypePack::GenericTypePack(TypeLevel level, const Name& name) + : index(Unifiable::freshIndex()) + , level(level) + , name(name) + , explicitName(true) +{ +} + +GenericTypePack::GenericTypePack(Scope* scope, const Name& name) + : index(Unifiable::freshIndex()) + , scope(scope) + , name(name) + , explicitName(true) +{ +} + BlockedTypePack::BlockedTypePack() : index(++nextIndex) { @@ -160,8 +223,8 @@ bool areEqual(SeenSet& seen, const TypePackVar& lhs, const TypePackVar& rhs) TypePackId rhsTail = *rhsIter.tail(); { - const Unifiable::Free* lf = get_if(&lhsTail->ty); - const Unifiable::Free* rf = get_if(&rhsTail->ty); + const FreeTypePack* lf = get_if(&lhsTail->ty); + const FreeTypePack* rf = get_if(&rhsTail->ty); if (lf && rf) return lf->index == rf->index; } @@ -174,8 +237,8 @@ bool areEqual(SeenSet& seen, const TypePackVar& lhs, const TypePackVar& rhs) } { - const Unifiable::Generic* lg = get_if(&lhsTail->ty); - const Unifiable::Generic* rg = get_if(&rhsTail->ty); + const GenericTypePack* lg = get_if(&lhsTail->ty); + const GenericTypePack* rg = get_if(&rhsTail->ty); if (lg && rg) return lg->index == rg->index; } diff --git a/Analysis/src/Unifiable.cpp b/Analysis/src/Unifiable.cpp index abdc6c329..2ceb97aae 100644 --- a/Analysis/src/Unifiable.cpp +++ b/Analysis/src/Unifiable.cpp @@ -13,71 +13,6 @@ int freshIndex() return ++nextIndex; } -Free::Free(TypeLevel level) - : index(++nextIndex) - , level(level) -{ -} - -Free::Free(Scope* scope) - : index(++nextIndex) - , scope(scope) -{ -} - -Free::Free(Scope* scope, TypeLevel level) - : index(++nextIndex) - , level(level) - , scope(scope) -{ -} - -int Free::DEPRECATED_nextIndex = 0; - -Generic::Generic() - : index(++nextIndex) - , name("g" + std::to_string(index)) -{ -} - -Generic::Generic(TypeLevel level) - : index(++nextIndex) - , level(level) - , name("g" + std::to_string(index)) -{ -} - -Generic::Generic(const Name& name) - : index(++nextIndex) - , name(name) - , explicitName(true) -{ -} - -Generic::Generic(Scope* scope) - : index(++nextIndex) - , scope(scope) -{ -} - -Generic::Generic(TypeLevel level, const Name& name) - : index(++nextIndex) - , level(level) - , name(name) - , explicitName(true) -{ -} - -Generic::Generic(Scope* scope, const Name& name) - : index(++nextIndex) - , scope(scope) - , name(name) - , explicitName(true) -{ -} - -int Generic::DEPRECATED_nextIndex = 0; - Error::Error() : index(++nextIndex) { diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index b748d115f..642aa399f 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -1489,7 +1489,7 @@ struct WeirdIter bool canGrow() const { - return nullptr != log.getMutable(packId); + return nullptr != log.getMutable(packId); } void grow(TypePackId newTail) @@ -1497,7 +1497,7 @@ struct WeirdIter LUAU_ASSERT(canGrow()); LUAU_ASSERT(log.getMutable(newTail)); - auto freePack = log.getMutable(packId); + auto freePack = log.getMutable(packId); level = freePack->level; if (FFlag::LuauMaintainScopesInUnifier && freePack->scope != nullptr) @@ -1591,7 +1591,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal if (log.haveSeen(superTp, subTp)) return; - if (log.getMutable(superTp)) + if (log.getMutable(superTp)) { if (!occursCheck(superTp, subTp)) { @@ -1599,7 +1599,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal log.replace(superTp, Unifiable::Bound(widen(subTp))); } } - else if (log.getMutable(subTp)) + else if (log.getMutable(subTp)) { if (!occursCheck(subTp, superTp)) { @@ -2567,9 +2567,9 @@ static void queueTypePack(std::vector& queue, DenseHashSet& break; seenTypePacks.insert(a); - if (state.log.getMutable(a)) + if (state.log.getMutable(a)) { - state.log.replace(a, Unifiable::Bound{anyTypePack}); + state.log.replace(a, BoundTypePack{anyTypePack}); } else if (auto tp = state.log.getMutable(a)) { @@ -2617,7 +2617,7 @@ void Unifier::tryUnifyVariadics(TypePackId subTp, TypePackId superTp, bool rever { tryUnify_(vtp->ty, superVariadic->ty); } - else if (get(tail)) + else if (get(tail)) { reportError(location, GenericError{"Cannot unify variadic and generic packs"}); } @@ -2777,10 +2777,10 @@ bool Unifier::occursCheck(DenseHashSet& seen, TypeId needle, TypeId hays seen.insert(haystack); - if (log.getMutable(needle)) + if (log.getMutable(needle)) return false; - if (!log.getMutable(needle)) + if (!log.getMutable(needle)) ice("Expected needle to be free"); if (needle == haystack) @@ -2824,10 +2824,10 @@ bool Unifier::occursCheck(DenseHashSet& seen, TypePackId needle, Typ seen.insert(haystack); - if (log.getMutable(needle)) + if (log.getMutable(needle)) return false; - if (!log.getMutable(needle)) + if (!log.getMutable(needle)) ice("Expected needle pack to be free"); RecursionLimiter _ra(&sharedState.counters.recursionCount, sharedState.counters.recursionLimit); diff --git a/CLI/Analyze.cpp b/CLI/Analyze.cpp index 4fdb04439..6d1f54514 100644 --- a/CLI/Analyze.cpp +++ b/CLI/Analyze.cpp @@ -14,7 +14,6 @@ #endif LUAU_FASTFLAG(DebugLuauTimeTracing) -LUAU_FASTFLAG(LuauLintInTypecheck) enum class ReportFormat { @@ -81,12 +80,10 @@ static bool analyzeFile(Luau::Frontend& frontend, const char* name, ReportFormat for (auto& error : cr.errors) reportError(frontend, format, error); - Luau::LintResult lr = FFlag::LuauLintInTypecheck ? cr.lintResult : frontend.lint_DEPRECATED(name); - std::string humanReadableName = frontend.fileResolver->getHumanReadableModuleName(name); - for (auto& error : lr.errors) + for (auto& error : cr.lintResult.errors) reportWarning(format, humanReadableName.c_str(), error); - for (auto& warning : lr.warnings) + for (auto& warning : cr.lintResult.warnings) reportWarning(format, humanReadableName.c_str(), warning); if (annotate) @@ -101,7 +98,7 @@ static bool analyzeFile(Luau::Frontend& frontend, const char* name, ReportFormat printf("%s", annotated.c_str()); } - return cr.errors.empty() && lr.errors.empty(); + return cr.errors.empty() && cr.lintResult.errors.empty(); } static void displayHelp(const char* argv0) @@ -264,13 +261,13 @@ int main(int argc, char** argv) Luau::FrontendOptions frontendOptions; frontendOptions.retainFullTypeGraphs = annotate; - frontendOptions.runLintChecks = FFlag::LuauLintInTypecheck; + frontendOptions.runLintChecks = true; CliFileResolver fileResolver; CliConfigResolver configResolver(mode); Luau::Frontend frontend(&fileResolver, &configResolver, frontendOptions); - Luau::registerBuiltinGlobals(frontend.typeChecker, frontend.globals); + Luau::registerBuiltinGlobals(frontend, frontend.globals); Luau::freeze(frontend.globals.globalTypes); #ifdef CALLGRIND diff --git a/CodeGen/include/Luau/AssemblyBuilderA64.h b/CodeGen/include/Luau/AssemblyBuilderA64.h index 0c7387128..def4d0c0c 100644 --- a/CodeGen/include/Luau/AssemblyBuilderA64.h +++ b/CodeGen/include/Luau/AssemblyBuilderA64.h @@ -37,6 +37,7 @@ class AssemblyBuilderA64 void movk(RegisterA64 dst, uint16_t src, int shift = 0); // Arithmetics + // TODO: support various kinds of shifts void add(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, int shift = 0); void add(RegisterA64 dst, RegisterA64 src1, uint16_t src2); void sub(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, int shift = 0); @@ -50,8 +51,10 @@ class AssemblyBuilderA64 void csel(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, ConditionA64 cond); // Bitwise - // Note: shifted-register support and bitfield operations are omitted for simplicity // TODO: support immediate arguments (they have odd encoding and forbid many values) + // TODO: support bic (andnot) + // TODO: support shifts + // TODO: support bitfield ops void and_(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); void orr(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); void eor(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); @@ -82,7 +85,7 @@ class AssemblyBuilderA64 void stp(RegisterA64 src1, RegisterA64 src2, AddressA64 dst); // Control flow - // Note: tbz/tbnz are currently not supported because they have 15-bit offsets and we don't support branch thunks + // TODO: support tbz/tbnz; they have 15-bit offsets but they can be useful in constrained cases void b(Label& label); void b(ConditionA64 cond, Label& label); void cbz(RegisterA64 src, Label& label); diff --git a/CodeGen/include/Luau/AssemblyBuilderX64.h b/CodeGen/include/Luau/AssemblyBuilderX64.h index 2b2a849c6..467be4664 100644 --- a/CodeGen/include/Luau/AssemblyBuilderX64.h +++ b/CodeGen/include/Luau/AssemblyBuilderX64.h @@ -121,6 +121,7 @@ class AssemblyBuilderX64 void vcvttsd2si(OperandX64 dst, OperandX64 src); void vcvtsi2sd(OperandX64 dst, OperandX64 src1, OperandX64 src2); + void vcvtsd2ss(OperandX64 dst, OperandX64 src1, OperandX64 src2); void vroundsd(OperandX64 dst, OperandX64 src1, OperandX64 src2, RoundingModeX64 roundingMode); // inexact diff --git a/CodeGen/include/Luau/IrAnalysis.h b/CodeGen/include/Luau/IrAnalysis.h index 470690b95..75b4940a6 100644 --- a/CodeGen/include/Luau/IrAnalysis.h +++ b/CodeGen/include/Luau/IrAnalysis.h @@ -19,6 +19,8 @@ void updateUseCounts(IrFunction& function); void updateLastUseLocations(IrFunction& function); +uint32_t getNextInstUse(IrFunction& function, uint32_t targetInstIdx, uint32_t startInstIdx); + // Returns how many values are coming into the block (live in) and how many are coming out of the block (live out) std::pair getLiveInOutValueCount(IrFunction& function, IrBlock& block); uint32_t getLiveInValueCount(IrFunction& function, IrBlock& block); diff --git a/CodeGen/include/Luau/IrCallWrapperX64.h b/CodeGen/include/Luau/IrCallWrapperX64.h index b70c8da62..724d46243 100644 --- a/CodeGen/include/Luau/IrCallWrapperX64.h +++ b/CodeGen/include/Luau/IrCallWrapperX64.h @@ -17,10 +17,6 @@ namespace CodeGen namespace X64 { -// When IrInst operands are used, current instruction index is required to track lifetime -// In all other calls it is ok to omit the argument -constexpr uint32_t kInvalidInstIdx = ~0u; - struct IrRegAllocX64; struct ScopedRegX64; @@ -61,6 +57,7 @@ class IrCallWrapperX64 void renameRegister(RegisterX64& target, RegisterX64 reg, RegisterX64 replacement); void renameSourceRegisters(RegisterX64 reg, RegisterX64 replacement); RegisterX64 findConflictingTarget() const; + void renameConflictingRegister(RegisterX64 conflict); int getRegisterUses(RegisterX64 reg) const; void addRegisterUse(RegisterX64 reg); diff --git a/CodeGen/include/Luau/IrData.h b/CodeGen/include/Luau/IrData.h index 752160817..fcf29adb1 100644 --- a/CodeGen/include/Luau/IrData.h +++ b/CodeGen/include/Luau/IrData.h @@ -62,11 +62,12 @@ enum class IrCmd : uint8_t // Get pointer (LuaNode) to table node element at the active cached slot index // A: pointer (Table) + // B: unsigned int (pcpos) GET_SLOT_NODE_ADDR, // Get pointer (LuaNode) to table node element at the main position of the specified key hash // A: pointer (Table) - // B: unsigned int + // B: unsigned int (hash) GET_HASH_NODE_ADDR, // Store a tag into TValue @@ -89,6 +90,13 @@ enum class IrCmd : uint8_t // B: int STORE_INT, + // Store a vector into TValue + // A: Rn + // B: double (x) + // C: double (y) + // D: double (z) + STORE_VECTOR, + // Store a TValue into memory // A: Rn or pointer (TValue) // B: TValue @@ -438,15 +446,6 @@ enum class IrCmd : uint8_t // C: block (forgloop location) FORGPREP_XNEXT_FALLBACK, - // Perform `and` or `or` operation (selecting lhs or rhs based on whether the lhs is truthy) and put the result into target register - // A: Rn (target) - // B: Rn (lhs) - // C: Rn or Kn (rhs) - AND, - ANDK, - OR, - ORK, - // Increment coverage data (saturating 24 bit add) // A: unsigned int (bytecode instruction index) COVERAGE, @@ -622,6 +621,17 @@ struct IrOp static_assert(sizeof(IrOp) == 4); +enum class IrValueKind : uint8_t +{ + Unknown, // Used by SUBSTITUTE, argument has to be checked to get type + None, + Tag, + Int, + Pointer, + Double, + Tvalue, +}; + struct IrInst { IrCmd cmd; @@ -641,8 +651,12 @@ struct IrInst X64::RegisterX64 regX64 = X64::noreg; A64::RegisterA64 regA64 = A64::noreg; bool reusedReg = false; + bool spilled = false; }; +// When IrInst operands are used, current instruction index is often required to track lifetime +constexpr uint32_t kInvalidInstIdx = ~0u; + enum class IrBlockKind : uint8_t { Bytecode, @@ -821,6 +835,13 @@ struct IrFunction LUAU_ASSERT(&block >= blocks.data() && &block <= blocks.data() + blocks.size()); return uint32_t(&block - blocks.data()); } + + uint32_t getInstIndex(const IrInst& inst) + { + // Can only be called with instructions from our vector + LUAU_ASSERT(&inst >= instructions.data() && &inst <= instructions.data() + instructions.size()); + return uint32_t(&inst - instructions.data()); + } }; inline IrCondition conditionOp(IrOp op) diff --git a/CodeGen/include/Luau/IrRegAllocX64.h b/CodeGen/include/Luau/IrRegAllocX64.h index c2486faf8..dc7b48c6b 100644 --- a/CodeGen/include/Luau/IrRegAllocX64.h +++ b/CodeGen/include/Luau/IrRegAllocX64.h @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include "Luau/AssemblyBuilderX64.h" #include "Luau/IrData.h" #include "Luau/RegisterX64.h" @@ -14,33 +15,66 @@ namespace CodeGen namespace X64 { +constexpr uint8_t kNoStackSlot = 0xff; + +struct IrSpillX64 +{ + uint32_t instIdx = 0; + bool useDoubleSlot = 0; + + // Spill location can be a stack location or be empty + // When it's empty, it means that instruction value can be rematerialized + uint8_t stackSlot = kNoStackSlot; + + RegisterX64 originalLoc = noreg; +}; + struct IrRegAllocX64 { - IrRegAllocX64(IrFunction& function); + IrRegAllocX64(AssemblyBuilderX64& build, IrFunction& function); - RegisterX64 allocGprReg(SizeX64 preferredSize); - RegisterX64 allocXmmReg(); + RegisterX64 allocGprReg(SizeX64 preferredSize, uint32_t instIdx); + RegisterX64 allocXmmReg(uint32_t instIdx); - RegisterX64 allocGprRegOrReuse(SizeX64 preferredSize, uint32_t index, std::initializer_list oprefs); - RegisterX64 allocXmmRegOrReuse(uint32_t index, std::initializer_list oprefs); + RegisterX64 allocGprRegOrReuse(SizeX64 preferredSize, uint32_t instIdx, std::initializer_list oprefs); + RegisterX64 allocXmmRegOrReuse(uint32_t instIdx, std::initializer_list oprefs); - RegisterX64 takeReg(RegisterX64 reg); + RegisterX64 takeReg(RegisterX64 reg, uint32_t instIdx); void freeReg(RegisterX64 reg); - void freeLastUseReg(IrInst& target, uint32_t index); - void freeLastUseRegs(const IrInst& inst, uint32_t index); + void freeLastUseReg(IrInst& target, uint32_t instIdx); + void freeLastUseRegs(const IrInst& inst, uint32_t instIdx); - bool isLastUseReg(const IrInst& target, uint32_t index) const; + bool isLastUseReg(const IrInst& target, uint32_t instIdx) const; bool shouldFreeGpr(RegisterX64 reg) const; + // Register used by instruction is about to be freed, have to find a way to restore value later + void preserve(IrInst& inst); + + void restore(IrInst& inst, bool intoOriginalLocation); + + void preserveAndFreeInstValues(); + + uint32_t findInstructionWithFurthestNextUse(const std::array& regInstUsers) const; + void assertFree(RegisterX64 reg) const; void assertAllFree() const; + void assertNoSpills() const; + AssemblyBuilderX64& build; IrFunction& function; + uint32_t currInstIdx = ~0u; + std::array freeGprMap; + std::array gprInstUsers; std::array freeXmmMap; + std::array xmmInstUsers; + + std::bitset<256> usedSpillSlots; + unsigned maxUsedSlot = 0; + std::vector spills; }; struct ScopedRegX64 @@ -62,6 +96,23 @@ struct ScopedRegX64 RegisterX64 reg; }; +// When IR instruction makes a call under a condition that's not reflected as a real branch in IR, +// spilled values have to be restored to their exact original locations, so that both after a call +// and after the skip, values are found in the same place +struct ScopedSpills +{ + explicit ScopedSpills(IrRegAllocX64& owner); + ~ScopedSpills(); + + ScopedSpills(const ScopedSpills&) = delete; + ScopedSpills& operator=(const ScopedSpills&) = delete; + + bool wasSpilledBefore(const IrSpillX64& spill) const; + + IrRegAllocX64& owner; + std::vector snapshot; +}; + } // namespace X64 } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/include/Luau/IrUtils.h b/CodeGen/include/Luau/IrUtils.h index 6e73e47a6..09c55c799 100644 --- a/CodeGen/include/Luau/IrUtils.h +++ b/CodeGen/include/Luau/IrUtils.h @@ -175,6 +175,8 @@ inline bool isPseudo(IrCmd cmd) return cmd == IrCmd::NOP || cmd == IrCmd::SUBSTITUTE; } +IrValueKind getCmdValueKind(IrCmd cmd); + bool isGCO(uint8_t tag); // Manually add or remove use of an operand diff --git a/CodeGen/include/Luau/RegisterA64.h b/CodeGen/include/Luau/RegisterA64.h index 242e8b793..99e62958d 100644 --- a/CodeGen/include/Luau/RegisterA64.h +++ b/CodeGen/include/Luau/RegisterA64.h @@ -37,6 +37,15 @@ struct RegisterA64 } }; +constexpr RegisterA64 castReg(KindA64 kind, RegisterA64 reg) +{ + LUAU_ASSERT(kind != reg.kind); + LUAU_ASSERT(kind != KindA64::none && reg.kind != KindA64::none); + LUAU_ASSERT((kind == KindA64::w || kind == KindA64::x) == (reg.kind == KindA64::w || reg.kind == KindA64::x)); + + return RegisterA64{kind, reg.index}; +} + constexpr RegisterA64 noreg{KindA64::none, 0}; constexpr RegisterA64 w0{KindA64::w, 0}; diff --git a/CodeGen/src/AssemblyBuilderX64.cpp b/CodeGen/src/AssemblyBuilderX64.cpp index 0285c2a16..d86a37c6e 100644 --- a/CodeGen/src/AssemblyBuilderX64.cpp +++ b/CodeGen/src/AssemblyBuilderX64.cpp @@ -676,6 +676,16 @@ void AssemblyBuilderX64::vcvtsi2sd(OperandX64 dst, OperandX64 src1, OperandX64 s placeAvx("vcvtsi2sd", dst, src1, src2, 0x2a, (src2.cat == CategoryX64::reg ? src2.base.size : src2.memSize) == SizeX64::qword, AVX_0F, AVX_F2); } +void AssemblyBuilderX64::vcvtsd2ss(OperandX64 dst, OperandX64 src1, OperandX64 src2) +{ + if (src2.cat == CategoryX64::reg) + LUAU_ASSERT(src2.base.size == SizeX64::xmmword); + else + LUAU_ASSERT(src2.memSize == SizeX64::qword); + + placeAvx("vcvtsd2ss", dst, src1, src2, 0x5a, (src2.cat == CategoryX64::reg ? src2.base.size : src2.memSize) == SizeX64::qword, AVX_0F, AVX_F2); +} + void AssemblyBuilderX64::vroundsd(OperandX64 dst, OperandX64 src1, OperandX64 src2, RoundingModeX64 roundingMode) { placeAvx("vroundsd", dst, src1, src2, uint8_t(roundingMode) | kRoundingPrecisionInexact, 0x0b, false, AVX_0F3A, AVX_66); diff --git a/CodeGen/src/CodeGen.cpp b/CodeGen/src/CodeGen.cpp index b0cc8d9cd..8e6e94933 100644 --- a/CodeGen/src/CodeGen.cpp +++ b/CodeGen/src/CodeGen.cpp @@ -74,7 +74,7 @@ static NativeProto* createNativeProto(Proto* proto, const IrBuilder& ir) } template -static void lowerImpl(AssemblyBuilder& build, IrLowering& lowering, IrFunction& function, int bytecodeid, AssemblyOptions options) +static bool lowerImpl(AssemblyBuilder& build, IrLowering& lowering, IrFunction& function, int bytecodeid, AssemblyOptions options) { // While we will need a better block ordering in the future, right now we want to mostly preserve build order with fallbacks outlined std::vector sortedBlocks; @@ -193,6 +193,9 @@ static void lowerImpl(AssemblyBuilder& build, IrLowering& lowering, IrFunction& IrBlock& next = i + 1 < sortedBlocks.size() ? function.blocks[sortedBlocks[i + 1]] : dummy; lowering.lowerInst(inst, index, next); + + if (lowering.hasError()) + return false; } if (options.includeIr) @@ -213,6 +216,8 @@ static void lowerImpl(AssemblyBuilder& build, IrLowering& lowering, IrFunction& if (irLocation != ~0u) asmLocation = bcLocations[irLocation]; } + + return true; } [[maybe_unused]] static bool lowerIr( @@ -226,9 +231,7 @@ static void lowerImpl(AssemblyBuilder& build, IrLowering& lowering, IrFunction& X64::IrLoweringX64 lowering(build, helpers, data, ir.function); - lowerImpl(build, lowering, ir.function, proto->bytecodeid, options); - - return true; + return lowerImpl(build, lowering, ir.function, proto->bytecodeid, options); } [[maybe_unused]] static bool lowerIr( @@ -239,9 +242,7 @@ static void lowerImpl(AssemblyBuilder& build, IrLowering& lowering, IrFunction& A64::IrLoweringA64 lowering(build, helpers, data, proto, ir.function); - lowerImpl(build, lowering, ir.function, proto->bytecodeid, options); - - return true; + return lowerImpl(build, lowering, ir.function, proto->bytecodeid, options); } template diff --git a/CodeGen/src/EmitBuiltinsX64.cpp b/CodeGen/src/EmitBuiltinsX64.cpp index 2e745cbf2..b010ce627 100644 --- a/CodeGen/src/EmitBuiltinsX64.cpp +++ b/CodeGen/src/EmitBuiltinsX64.cpp @@ -11,8 +11,6 @@ #include "lstate.h" -// TODO: LBF_MATH_FREXP and LBF_MATH_MODF can work for 1 result case if second store is removed - namespace Luau { namespace CodeGen @@ -176,8 +174,11 @@ void emitBuiltinMathFrexp(IrRegAllocX64& regs, AssemblyBuilderX64& build, int np build.vmovsd(luauRegValue(ra), xmm0); - build.vcvtsi2sd(xmm0, xmm0, dword[sTemporarySlot + 0]); - build.vmovsd(luauRegValue(ra + 1), xmm0); + if (nresults > 1) + { + build.vcvtsi2sd(xmm0, xmm0, dword[sTemporarySlot + 0]); + build.vmovsd(luauRegValue(ra + 1), xmm0); + } } void emitBuiltinMathModf(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) @@ -190,7 +191,8 @@ void emitBuiltinMathModf(IrRegAllocX64& regs, AssemblyBuilderX64& build, int npa build.vmovsd(xmm1, qword[sTemporarySlot + 0]); build.vmovsd(luauRegValue(ra), xmm1); - build.vmovsd(luauRegValue(ra + 1), xmm0); + if (nresults > 1) + build.vmovsd(luauRegValue(ra + 1), xmm0); } void emitBuiltinMathSign(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) @@ -248,9 +250,9 @@ void emitBuiltin(IrRegAllocX64& regs, AssemblyBuilderX64& build, int bfid, int r OperandX64 argsOp = 0; if (args.kind == IrOpKind::VmReg) - argsOp = luauRegAddress(args.index); + argsOp = luauRegAddress(vmRegOp(args)); else if (args.kind == IrOpKind::VmConst) - argsOp = luauConstantAddress(args.index); + argsOp = luauConstantAddress(vmConstOp(args)); switch (bfid) { diff --git a/CodeGen/src/EmitCommonA64.cpp b/CodeGen/src/EmitCommonA64.cpp index 2b4bbaba1..1758e4fb1 100644 --- a/CodeGen/src/EmitCommonA64.cpp +++ b/CodeGen/src/EmitCommonA64.cpp @@ -101,6 +101,30 @@ void emitReentry(AssemblyBuilderA64& build, ModuleHelpers& helpers) build.br(x1); } +void emitFallback(AssemblyBuilderA64& build, int op, int pcpos) +{ + // fallback(L, instruction, base, k) + build.mov(x0, rState); + + // TODO: refactor into a common helper + if (pcpos * sizeof(Instruction) <= AssemblyBuilderA64::kMaxImmediate) + { + build.add(x1, rCode, uint16_t(pcpos * sizeof(Instruction))); + } + else + { + build.mov(x1, pcpos * sizeof(Instruction)); + build.add(x1, rCode, x1); + } + + build.mov(x2, rBase); + build.mov(x3, rConstants); + build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, fallback) + op * sizeof(NativeFallback) + offsetof(NativeFallback, fallback))); + build.blr(x4); + + emitUpdateBase(build); +} + } // namespace A64 } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/EmitCommonA64.h b/CodeGen/src/EmitCommonA64.h index 5ca9c5586..2a65afa8f 100644 --- a/CodeGen/src/EmitCommonA64.h +++ b/CodeGen/src/EmitCommonA64.h @@ -46,6 +46,7 @@ void emitUpdateBase(AssemblyBuilderA64& build); void emitExit(AssemblyBuilderA64& build, bool continueInVm); void emitInterrupt(AssemblyBuilderA64& build); void emitReentry(AssemblyBuilderA64& build, ModuleHelpers& helpers); +void emitFallback(AssemblyBuilderA64& build, int op, int pcpos); } // namespace A64 } // namespace CodeGen diff --git a/CodeGen/src/EmitCommonX64.cpp b/CodeGen/src/EmitCommonX64.cpp index 7db4068d0..9136add85 100644 --- a/CodeGen/src/EmitCommonX64.cpp +++ b/CodeGen/src/EmitCommonX64.cpp @@ -196,33 +196,51 @@ void checkObjectBarrierConditions(AssemblyBuilderX64& build, RegisterX64 tmp, Re build.jcc(ConditionX64::Zero, skip); } -void callBarrierObject(IrRegAllocX64& regs, AssemblyBuilderX64& build, RegisterX64 object, IrOp objectOp, int ra, Label& skip) +void callBarrierObject(IrRegAllocX64& regs, AssemblyBuilderX64& build, RegisterX64 object, IrOp objectOp, int ra) { + Label skip; + ScopedRegX64 tmp{regs, SizeX64::qword}; checkObjectBarrierConditions(build, tmp.reg, object, ra, skip); - IrCallWrapperX64 callWrap(regs, build); - callWrap.addArgument(SizeX64::qword, rState); - callWrap.addArgument(SizeX64::qword, object, objectOp); - callWrap.addArgument(SizeX64::qword, tmp); - callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaC_barrierf)]); + { + ScopedSpills spillGuard(regs); + + IrCallWrapperX64 callWrap(regs, build); + callWrap.addArgument(SizeX64::qword, rState); + callWrap.addArgument(SizeX64::qword, object, objectOp); + callWrap.addArgument(SizeX64::qword, tmp); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaC_barrierf)]); + } + + build.setLabel(skip); } -void callBarrierTableFast(IrRegAllocX64& regs, AssemblyBuilderX64& build, RegisterX64 table, IrOp tableOp, Label& skip) +void callBarrierTableFast(IrRegAllocX64& regs, AssemblyBuilderX64& build, RegisterX64 table, IrOp tableOp) { + Label skip; + // isblack(obj2gco(t)) build.test(byte[table + offsetof(GCheader, marked)], bitmask(BLACKBIT)); build.jcc(ConditionX64::Zero, skip); - IrCallWrapperX64 callWrap(regs, build); - callWrap.addArgument(SizeX64::qword, rState); - callWrap.addArgument(SizeX64::qword, table, tableOp); - callWrap.addArgument(SizeX64::qword, addr[table + offsetof(Table, gclist)]); - callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaC_barrierback)]); + { + ScopedSpills spillGuard(regs); + + IrCallWrapperX64 callWrap(regs, build); + callWrap.addArgument(SizeX64::qword, rState); + callWrap.addArgument(SizeX64::qword, table, tableOp); + callWrap.addArgument(SizeX64::qword, addr[table + offsetof(Table, gclist)]); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaC_barrierback)]); + } + + build.setLabel(skip); } -void callCheckGc(IrRegAllocX64& regs, AssemblyBuilderX64& build, Label& skip) +void callStepGc(IrRegAllocX64& regs, AssemblyBuilderX64& build) { + Label skip; + { ScopedRegX64 tmp1{regs, SizeX64::qword}; ScopedRegX64 tmp2{regs, SizeX64::qword}; @@ -233,11 +251,17 @@ void callCheckGc(IrRegAllocX64& regs, AssemblyBuilderX64& build, Label& skip) build.jcc(ConditionX64::Below, skip); } - IrCallWrapperX64 callWrap(regs, build); - callWrap.addArgument(SizeX64::qword, rState); - callWrap.addArgument(SizeX64::dword, 1); - callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaC_step)]); - emitUpdateBase(build); + { + ScopedSpills spillGuard(regs); + + IrCallWrapperX64 callWrap(regs, build); + callWrap.addArgument(SizeX64::qword, rState); + callWrap.addArgument(SizeX64::dword, 1); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaC_step)]); + emitUpdateBase(build); + } + + build.setLabel(skip); } void emitExit(AssemblyBuilderX64& build, bool continueInVm) @@ -256,7 +280,7 @@ void emitUpdateBase(AssemblyBuilderX64& build) } // Note: only uses rax/rdx, the caller may use other registers -void emitSetSavedPc(AssemblyBuilderX64& build, int pcpos) +static void emitSetSavedPc(AssemblyBuilderX64& build, int pcpos) { build.mov(rdx, sCode); build.add(rdx, pcpos * sizeof(Instruction)); @@ -298,9 +322,6 @@ void emitInterrupt(AssemblyBuilderX64& build, int pcpos) void emitFallback(AssemblyBuilderX64& build, NativeState& data, int op, int pcpos) { - if (op == LOP_CAPTURE) - return; - NativeFallback& opinfo = data.context.fallback[op]; LUAU_ASSERT(opinfo.fallback); diff --git a/CodeGen/src/EmitCommonX64.h b/CodeGen/src/EmitCommonX64.h index 85045ad5b..6aac5a1ec 100644 --- a/CodeGen/src/EmitCommonX64.h +++ b/CodeGen/src/EmitCommonX64.h @@ -42,12 +42,14 @@ constexpr RegisterX64 rConstants = r12; // TValue* k // Native code is as stackless as the interpreter, so we can place some data on the stack once and have it accessible at any point // See CodeGenX64.cpp for layout -constexpr unsigned kStackSize = 32 + 16; // 4 home locations for registers, 16 bytes for additional function call arguments -constexpr unsigned kLocalsSize = 24; // 3 extra slots for our custom locals (also aligns the stack to 16 byte boundary) +constexpr unsigned kStackSize = 32 + 16; // 4 home locations for registers, 16 bytes for additional function call arguments +constexpr unsigned kSpillSlots = 4; // locations for register allocator to spill data into +constexpr unsigned kLocalsSize = 24 + 8 * kSpillSlots; // 3 extra slots for our custom locals (also aligns the stack to 16 byte boundary) constexpr OperandX64 sClosure = qword[rsp + kStackSize + 0]; // Closure* cl constexpr OperandX64 sCode = qword[rsp + kStackSize + 8]; // Instruction* code constexpr OperandX64 sTemporarySlot = addr[rsp + kStackSize + 16]; +constexpr OperandX64 sSpillArea = addr[rsp + kStackSize + 24]; // TODO: These should be replaced with a portable call function that checks the ABI at runtime and reorders moves accordingly to avoid conflicts #if defined(_WIN32) @@ -99,6 +101,11 @@ inline OperandX64 luauRegValueInt(int ri) return dword[rBase + ri * sizeof(TValue) + offsetof(TValue, value)]; } +inline OperandX64 luauRegValueVector(int ri, int index) +{ + return dword[rBase + ri * sizeof(TValue) + offsetof(TValue, value) + (sizeof(float) * index)]; +} + inline OperandX64 luauConstant(int ki) { return xmmword[rConstants + ki * sizeof(TValue)]; @@ -247,13 +254,12 @@ void callPrepareForN(IrRegAllocX64& regs, AssemblyBuilderX64& build, int limit, void callGetTable(IrRegAllocX64& regs, AssemblyBuilderX64& build, int rb, OperandX64 c, int ra); void callSetTable(IrRegAllocX64& regs, AssemblyBuilderX64& build, int rb, OperandX64 c, int ra); void checkObjectBarrierConditions(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 object, int ra, Label& skip); -void callBarrierObject(IrRegAllocX64& regs, AssemblyBuilderX64& build, RegisterX64 object, IrOp objectOp, int ra, Label& skip); -void callBarrierTableFast(IrRegAllocX64& regs, AssemblyBuilderX64& build, RegisterX64 table, IrOp tableOp, Label& skip); -void callCheckGc(IrRegAllocX64& regs, AssemblyBuilderX64& build, Label& skip); +void callBarrierObject(IrRegAllocX64& regs, AssemblyBuilderX64& build, RegisterX64 object, IrOp objectOp, int ra); +void callBarrierTableFast(IrRegAllocX64& regs, AssemblyBuilderX64& build, RegisterX64 table, IrOp tableOp); +void callStepGc(IrRegAllocX64& regs, AssemblyBuilderX64& build); void emitExit(AssemblyBuilderX64& build, bool continueInVm); void emitUpdateBase(AssemblyBuilderX64& build); -void emitSetSavedPc(AssemblyBuilderX64& build, int pcpos); // Note: only uses rax/rdx, the caller may use other registers void emitInterrupt(AssemblyBuilderX64& build, int pcpos); void emitFallback(AssemblyBuilderX64& build, NativeState& data, int op, int pcpos); diff --git a/CodeGen/src/EmitInstructionX64.cpp b/CodeGen/src/EmitInstructionX64.cpp index b645f9f7a..c0a64274a 100644 --- a/CodeGen/src/EmitInstructionX64.cpp +++ b/CodeGen/src/EmitInstructionX64.cpp @@ -316,7 +316,7 @@ void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, int ra, i build.jmp(qword[rdx + rax * 2]); } -void emitInstSetList(IrRegAllocX64& regs, AssemblyBuilderX64& build, Label& next, int ra, int rb, int count, uint32_t index) +void emitInstSetList(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int rb, int count, uint32_t index) { OperandX64 last = index + count - 1; @@ -347,7 +347,7 @@ void emitInstSetList(IrRegAllocX64& regs, AssemblyBuilderX64& build, Label& next Label skipResize; - RegisterX64 table = regs.takeReg(rax); + RegisterX64 table = regs.takeReg(rax, kInvalidInstIdx); build.mov(table, luauRegValue(ra)); @@ -412,7 +412,7 @@ void emitInstSetList(IrRegAllocX64& regs, AssemblyBuilderX64& build, Label& next build.setLabel(endLoop); } - callBarrierTableFast(regs, build, table, {}, next); + callBarrierTableFast(regs, build, table, {}); } void emitinstForGLoop(AssemblyBuilderX64& build, int ra, int aux, Label& loopRepeat, Label& loopExit) @@ -504,82 +504,6 @@ void emitInstForGPrepXnextFallback(AssemblyBuilderX64& build, int pcpos, int ra, build.jmp(target); } -static void emitInstAndX(AssemblyBuilderX64& build, int ra, int rb, OperandX64 c) -{ - Label target, fallthrough; - jumpIfFalsy(build, rb, target, fallthrough); - - build.setLabel(fallthrough); - - build.vmovups(xmm0, c); - build.vmovups(luauReg(ra), xmm0); - - if (ra == rb) - { - build.setLabel(target); - } - else - { - Label exit; - build.jmp(exit); - - build.setLabel(target); - - build.vmovups(xmm0, luauReg(rb)); - build.vmovups(luauReg(ra), xmm0); - - build.setLabel(exit); - } -} - -void emitInstAnd(AssemblyBuilderX64& build, int ra, int rb, int rc) -{ - emitInstAndX(build, ra, rb, luauReg(rc)); -} - -void emitInstAndK(AssemblyBuilderX64& build, int ra, int rb, int kc) -{ - emitInstAndX(build, ra, rb, luauConstant(kc)); -} - -static void emitInstOrX(AssemblyBuilderX64& build, int ra, int rb, OperandX64 c) -{ - Label target, fallthrough; - jumpIfTruthy(build, rb, target, fallthrough); - - build.setLabel(fallthrough); - - build.vmovups(xmm0, c); - build.vmovups(luauReg(ra), xmm0); - - if (ra == rb) - { - build.setLabel(target); - } - else - { - Label exit; - build.jmp(exit); - - build.setLabel(target); - - build.vmovups(xmm0, luauReg(rb)); - build.vmovups(luauReg(ra), xmm0); - - build.setLabel(exit); - } -} - -void emitInstOr(AssemblyBuilderX64& build, int ra, int rb, int rc) -{ - emitInstOrX(build, ra, rb, luauReg(rc)); -} - -void emitInstOrK(AssemblyBuilderX64& build, int ra, int rb, int kc) -{ - emitInstOrX(build, ra, rb, luauConstant(kc)); -} - void emitInstGetImportFallback(AssemblyBuilderX64& build, int ra, uint32_t aux) { build.mov(rax, sClosure); diff --git a/CodeGen/src/EmitInstructionX64.h b/CodeGen/src/EmitInstructionX64.h index cc1b86456..d58e13310 100644 --- a/CodeGen/src/EmitInstructionX64.h +++ b/CodeGen/src/EmitInstructionX64.h @@ -19,14 +19,10 @@ struct IrRegAllocX64; void emitInstCall(AssemblyBuilderX64& build, ModuleHelpers& helpers, int ra, int nparams, int nresults); void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, int ra, int actualResults); -void emitInstSetList(IrRegAllocX64& regs, AssemblyBuilderX64& build, Label& next, int ra, int rb, int count, uint32_t index); +void emitInstSetList(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int rb, int count, uint32_t index); void emitinstForGLoop(AssemblyBuilderX64& build, int ra, int aux, Label& loopRepeat, Label& loopExit); void emitinstForGLoopFallback(AssemblyBuilderX64& build, int ra, int aux, Label& loopRepeat); void emitInstForGPrepXnextFallback(AssemblyBuilderX64& build, int pcpos, int ra, Label& target); -void emitInstAnd(AssemblyBuilderX64& build, int ra, int rb, int rc); -void emitInstAndK(AssemblyBuilderX64& build, int ra, int rb, int kc); -void emitInstOr(AssemblyBuilderX64& build, int ra, int rb, int rc); -void emitInstOrK(AssemblyBuilderX64& build, int ra, int rb, int kc); void emitInstGetImportFallback(AssemblyBuilderX64& build, int ra, uint32_t aux); void emitInstCoverage(AssemblyBuilderX64& build, int pcpos); diff --git a/CodeGen/src/IrAnalysis.cpp b/CodeGen/src/IrAnalysis.cpp index b248b97d5..2246e5c5e 100644 --- a/CodeGen/src/IrAnalysis.cpp +++ b/CodeGen/src/IrAnalysis.cpp @@ -69,6 +69,9 @@ void updateLastUseLocations(IrFunction& function) instructions[op.index].lastUse = uint32_t(instIdx); }; + if (isPseudo(inst.cmd)) + continue; + checkOp(inst.a); checkOp(inst.b); checkOp(inst.c); @@ -78,6 +81,42 @@ void updateLastUseLocations(IrFunction& function) } } +uint32_t getNextInstUse(IrFunction& function, uint32_t targetInstIdx, uint32_t startInstIdx) +{ + LUAU_ASSERT(startInstIdx < function.instructions.size()); + IrInst& targetInst = function.instructions[targetInstIdx]; + + for (uint32_t i = startInstIdx; i <= targetInst.lastUse; i++) + { + IrInst& inst = function.instructions[i]; + + if (isPseudo(inst.cmd)) + continue; + + if (inst.a.kind == IrOpKind::Inst && inst.a.index == targetInstIdx) + return i; + + if (inst.b.kind == IrOpKind::Inst && inst.b.index == targetInstIdx) + return i; + + if (inst.c.kind == IrOpKind::Inst && inst.c.index == targetInstIdx) + return i; + + if (inst.d.kind == IrOpKind::Inst && inst.d.index == targetInstIdx) + return i; + + if (inst.e.kind == IrOpKind::Inst && inst.e.index == targetInstIdx) + return i; + + if (inst.f.kind == IrOpKind::Inst && inst.f.index == targetInstIdx) + return i; + } + + // There must be a next use since there is the last use location + LUAU_ASSERT(!"failed to find next use"); + return targetInst.lastUse; +} + std::pair getLiveInOutValueCount(IrFunction& function, IrBlock& block) { uint32_t liveIns = 0; @@ -97,6 +136,9 @@ std::pair getLiveInOutValueCount(IrFunction& function, IrBlo { IrInst& inst = function.instructions[instIdx]; + if (isPseudo(inst.cmd)) + continue; + liveOuts += inst.useCount; checkOp(inst.a); @@ -149,26 +191,24 @@ static RegisterSet computeBlockLiveInRegSet(IrFunction& function, const IrBlock& RegisterSet inRs; auto def = [&](IrOp op, int offset = 0) { - LUAU_ASSERT(op.kind == IrOpKind::VmReg); - defRs.regs.set(op.index + offset, true); + defRs.regs.set(vmRegOp(op) + offset, true); }; auto use = [&](IrOp op, int offset = 0) { - LUAU_ASSERT(op.kind == IrOpKind::VmReg); - if (!defRs.regs.test(op.index + offset)) - inRs.regs.set(op.index + offset, true); + if (!defRs.regs.test(vmRegOp(op) + offset)) + inRs.regs.set(vmRegOp(op) + offset, true); }; auto maybeDef = [&](IrOp op) { if (op.kind == IrOpKind::VmReg) - defRs.regs.set(op.index, true); + defRs.regs.set(vmRegOp(op), true); }; auto maybeUse = [&](IrOp op) { if (op.kind == IrOpKind::VmReg) { - if (!defRs.regs.test(op.index)) - inRs.regs.set(op.index, true); + if (!defRs.regs.test(vmRegOp(op))) + inRs.regs.set(vmRegOp(op), true); } }; @@ -230,6 +270,7 @@ static RegisterSet computeBlockLiveInRegSet(IrFunction& function, const IrBlock& case IrCmd::STORE_POINTER: case IrCmd::STORE_DOUBLE: case IrCmd::STORE_INT: + case IrCmd::STORE_VECTOR: case IrCmd::STORE_TVALUE: maybeDef(inst.a); // Argument can also be a pointer value break; @@ -264,9 +305,9 @@ static RegisterSet computeBlockLiveInRegSet(IrFunction& function, const IrBlock& def(inst.a); break; case IrCmd::CONCAT: - useRange(inst.a.index, function.uintOp(inst.b)); + useRange(vmRegOp(inst.a), function.uintOp(inst.b)); - defRange(inst.a.index, function.uintOp(inst.b)); + defRange(vmRegOp(inst.a), function.uintOp(inst.b)); break; case IrCmd::GET_UPVALUE: def(inst.a); @@ -298,20 +339,20 @@ static RegisterSet computeBlockLiveInRegSet(IrFunction& function, const IrBlock& maybeUse(inst.a); if (function.boolOp(inst.b)) - capturedRegs.set(inst.a.index, true); + capturedRegs.set(vmRegOp(inst.a), true); break; case IrCmd::SETLIST: use(inst.b); - useRange(inst.c.index, function.intOp(inst.d)); + useRange(vmRegOp(inst.c), function.intOp(inst.d)); break; case IrCmd::CALL: use(inst.a); - useRange(inst.a.index + 1, function.intOp(inst.b)); + useRange(vmRegOp(inst.a) + 1, function.intOp(inst.b)); - defRange(inst.a.index, function.intOp(inst.c)); + defRange(vmRegOp(inst.a), function.intOp(inst.c)); break; case IrCmd::RETURN: - useRange(inst.a.index, function.intOp(inst.b)); + useRange(vmRegOp(inst.a), function.intOp(inst.b)); break; case IrCmd::FASTCALL: case IrCmd::INVOKE_FASTCALL: @@ -319,9 +360,9 @@ static RegisterSet computeBlockLiveInRegSet(IrFunction& function, const IrBlock& { if (count >= 3) { - LUAU_ASSERT(inst.d.kind == IrOpKind::VmReg && inst.d.index == inst.c.index + 1); + LUAU_ASSERT(inst.d.kind == IrOpKind::VmReg && vmRegOp(inst.d) == vmRegOp(inst.c) + 1); - useRange(inst.c.index, count); + useRange(vmRegOp(inst.c), count); } else { @@ -334,12 +375,12 @@ static RegisterSet computeBlockLiveInRegSet(IrFunction& function, const IrBlock& } else { - useVarargs(inst.c.index); + useVarargs(vmRegOp(inst.c)); } // Multiple return sequences (count == -1) are defined by ADJUST_STACK_TO_REG if (int count = function.intOp(inst.f); count != -1) - defRange(inst.b.index, count); + defRange(vmRegOp(inst.b), count); break; case IrCmd::FORGLOOP: // First register is not used by instruction, we check that it's still 'nil' with CHECK_TAG @@ -347,32 +388,17 @@ static RegisterSet computeBlockLiveInRegSet(IrFunction& function, const IrBlock& use(inst.a, 2); def(inst.a, 2); - defRange(inst.a.index + 3, function.intOp(inst.b)); + defRange(vmRegOp(inst.a) + 3, function.intOp(inst.b)); break; case IrCmd::FORGLOOP_FALLBACK: - useRange(inst.a.index, 3); + useRange(vmRegOp(inst.a), 3); def(inst.a, 2); - defRange(inst.a.index + 3, uint8_t(function.intOp(inst.b))); // ignore most significant bit + defRange(vmRegOp(inst.a) + 3, uint8_t(function.intOp(inst.b))); // ignore most significant bit break; case IrCmd::FORGPREP_XNEXT_FALLBACK: use(inst.b); break; - // A <- B, C - case IrCmd::AND: - case IrCmd::OR: - use(inst.b); - use(inst.c); - - def(inst.a); - break; - // A <- B - case IrCmd::ANDK: - case IrCmd::ORK: - use(inst.b); - - def(inst.a); - break; case IrCmd::FALLBACK_GETGLOBAL: def(inst.b); break; @@ -391,13 +417,13 @@ static RegisterSet computeBlockLiveInRegSet(IrFunction& function, const IrBlock& case IrCmd::FALLBACK_NAMECALL: use(inst.c); - defRange(inst.b.index, 2); + defRange(vmRegOp(inst.b), 2); break; case IrCmd::FALLBACK_PREPVARARGS: // No effect on explicitly referenced registers break; case IrCmd::FALLBACK_GETVARARGS: - defRange(inst.b.index, function.intOp(inst.c)); + defRange(vmRegOp(inst.b), function.intOp(inst.c)); break; case IrCmd::FALLBACK_NEWCLOSURE: def(inst.b); @@ -408,10 +434,10 @@ static RegisterSet computeBlockLiveInRegSet(IrFunction& function, const IrBlock& case IrCmd::FALLBACK_FORGPREP: use(inst.b); - defRange(inst.b.index, 3); + defRange(vmRegOp(inst.b), 3); break; case IrCmd::ADJUST_STACK_TO_REG: - defRange(inst.a.index, -1); + defRange(vmRegOp(inst.a), -1); break; case IrCmd::ADJUST_STACK_TO_TOP: // While this can be considered to be a vararg consumer, it is already handled in fastcall instructions diff --git a/CodeGen/src/IrBuilder.cpp b/CodeGen/src/IrBuilder.cpp index 4fee080ba..48c0e25c0 100644 --- a/CodeGen/src/IrBuilder.cpp +++ b/CodeGen/src/IrBuilder.cpp @@ -364,16 +364,16 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i) translateInstForGPrepInext(*this, pc, i); break; case LOP_AND: - inst(IrCmd::AND, vmReg(LUAU_INSN_A(*pc)), vmReg(LUAU_INSN_B(*pc)), vmReg(LUAU_INSN_C(*pc))); + translateInstAndX(*this, pc, i, vmReg(LUAU_INSN_C(*pc))); break; case LOP_ANDK: - inst(IrCmd::ANDK, vmReg(LUAU_INSN_A(*pc)), vmReg(LUAU_INSN_B(*pc)), vmConst(LUAU_INSN_C(*pc))); + translateInstAndX(*this, pc, i, vmConst(LUAU_INSN_C(*pc))); break; case LOP_OR: - inst(IrCmd::OR, vmReg(LUAU_INSN_A(*pc)), vmReg(LUAU_INSN_B(*pc)), vmReg(LUAU_INSN_C(*pc))); + translateInstOrX(*this, pc, i, vmReg(LUAU_INSN_C(*pc))); break; case LOP_ORK: - inst(IrCmd::ORK, vmReg(LUAU_INSN_A(*pc)), vmReg(LUAU_INSN_B(*pc)), vmConst(LUAU_INSN_C(*pc))); + translateInstOrX(*this, pc, i, vmConst(LUAU_INSN_C(*pc))); break; case LOP_COVERAGE: inst(IrCmd::COVERAGE, constUint(i)); diff --git a/CodeGen/src/IrCallWrapperX64.cpp b/CodeGen/src/IrCallWrapperX64.cpp index 4f0c0cf66..8ac5f8bcf 100644 --- a/CodeGen/src/IrCallWrapperX64.cpp +++ b/CodeGen/src/IrCallWrapperX64.cpp @@ -58,14 +58,17 @@ void IrCallWrapperX64::call(const OperandX64& func) { CallArgument& arg = args[i]; - // If source is the last use of IrInst, clear the register - // Source registers are recorded separately in CallArgument if (arg.sourceOp.kind != IrOpKind::None) { if (IrInst* inst = regs.function.asInstOp(arg.sourceOp)) { + // Source registers are recorded separately from source operands in CallArgument + // If source is the last use of IrInst, clear the register from the operand if (regs.isLastUseReg(*inst, instIdx)) inst->regX64 = noreg; + // If it's not the last use and register is volatile, register ownership is taken, which also spills the operand + else if (inst->regX64.size == SizeX64::xmmword || regs.shouldFreeGpr(inst->regX64)) + regs.takeReg(inst->regX64, kInvalidInstIdx); } } @@ -83,7 +86,11 @@ void IrCallWrapperX64::call(const OperandX64& func) freeSourceRegisters(arg); - build.mov(tmp.reg, arg.source); + if (arg.source.memSize == SizeX64::none) + build.lea(tmp.reg, arg.source); + else + build.mov(tmp.reg, arg.source); + build.mov(arg.target, tmp.reg); } else @@ -102,7 +109,7 @@ void IrCallWrapperX64::call(const OperandX64& func) // If target is not used as source in other arguments, prevent register allocator from giving it out if (getRegisterUses(arg.target.base) == 0) - regs.takeReg(arg.target.base); + regs.takeReg(arg.target.base, kInvalidInstIdx); else // Otherwise, make sure we won't free it when last source use is completed addRegisterUse(arg.target.base); @@ -122,7 +129,7 @@ void IrCallWrapperX64::call(const OperandX64& func) freeSourceRegisters(*candidate); LUAU_ASSERT(getRegisterUses(candidate->target.base) == 0); - regs.takeReg(candidate->target.base); + regs.takeReg(candidate->target.base, kInvalidInstIdx); moveToTarget(*candidate); @@ -131,15 +138,7 @@ void IrCallWrapperX64::call(const OperandX64& func) // If all registers cross-interfere (rcx <- rdx, rdx <- rcx), one has to be renamed else if (RegisterX64 conflict = findConflictingTarget(); conflict != noreg) { - // Get a fresh register - RegisterX64 freshReg = conflict.size == SizeX64::xmmword ? regs.allocXmmReg() : regs.allocGprReg(conflict.size); - - if (conflict.size == SizeX64::xmmword) - build.vmovsd(freshReg, conflict, conflict); - else - build.mov(freshReg, conflict); - - renameSourceRegisters(conflict, freshReg); + renameConflictingRegister(conflict); } else { @@ -156,10 +155,18 @@ void IrCallWrapperX64::call(const OperandX64& func) if (arg.source.cat == CategoryX64::imm) { + // There could be a conflict with the function source register, make this argument a candidate to find it + arg.candidate = true; + + if (RegisterX64 conflict = findConflictingTarget(); conflict != noreg) + renameConflictingRegister(conflict); + if (arg.target.cat == CategoryX64::reg) - regs.takeReg(arg.target.base); + regs.takeReg(arg.target.base, kInvalidInstIdx); moveToTarget(arg); + + arg.candidate = false; } } @@ -176,6 +183,10 @@ void IrCallWrapperX64::call(const OperandX64& func) regs.freeReg(arg.target.base); } + regs.preserveAndFreeInstValues(); + + regs.assertAllFree(); + build.call(funcOp); } @@ -362,6 +373,19 @@ RegisterX64 IrCallWrapperX64::findConflictingTarget() const return noreg; } +void IrCallWrapperX64::renameConflictingRegister(RegisterX64 conflict) +{ + // Get a fresh register + RegisterX64 freshReg = conflict.size == SizeX64::xmmword ? regs.allocXmmReg(kInvalidInstIdx) : regs.allocGprReg(conflict.size, kInvalidInstIdx); + + if (conflict.size == SizeX64::xmmword) + build.vmovsd(freshReg, conflict, conflict); + else + build.mov(freshReg, conflict); + + renameSourceRegisters(conflict, freshReg); +} + int IrCallWrapperX64::getRegisterUses(RegisterX64 reg) const { return reg.size == SizeX64::xmmword ? xmmUses[reg.index] : (reg.size != SizeX64::none ? gprUses[reg.index] : 0); diff --git a/CodeGen/src/IrDump.cpp b/CodeGen/src/IrDump.cpp index fb56df8c5..8f299520b 100644 --- a/CodeGen/src/IrDump.cpp +++ b/CodeGen/src/IrDump.cpp @@ -100,6 +100,8 @@ const char* getCmdName(IrCmd cmd) return "STORE_DOUBLE"; case IrCmd::STORE_INT: return "STORE_INT"; + case IrCmd::STORE_VECTOR: + return "STORE_VECTOR"; case IrCmd::STORE_TVALUE: return "STORE_TVALUE"; case IrCmd::STORE_NODE_VALUE_TV: @@ -238,14 +240,6 @@ const char* getCmdName(IrCmd cmd) return "FORGLOOP_FALLBACK"; case IrCmd::FORGPREP_XNEXT_FALLBACK: return "FORGPREP_XNEXT_FALLBACK"; - case IrCmd::AND: - return "AND"; - case IrCmd::ANDK: - return "ANDK"; - case IrCmd::OR: - return "OR"; - case IrCmd::ORK: - return "ORK"; case IrCmd::COVERAGE: return "COVERAGE"; case IrCmd::FALLBACK_GETGLOBAL: @@ -345,13 +339,13 @@ void toString(IrToStringContext& ctx, IrOp op) append(ctx.result, "%s_%u", getBlockKindName(ctx.blocks[op.index].kind), op.index); break; case IrOpKind::VmReg: - append(ctx.result, "R%u", op.index); + append(ctx.result, "R%d", vmRegOp(op)); break; case IrOpKind::VmConst: - append(ctx.result, "K%u", op.index); + append(ctx.result, "K%d", vmConstOp(op)); break; case IrOpKind::VmUpvalue: - append(ctx.result, "U%u", op.index); + append(ctx.result, "U%d", vmUpvalueOp(op)); break; } } diff --git a/CodeGen/src/IrLoweringA64.cpp b/CodeGen/src/IrLoweringA64.cpp index 37f381572..7f0305cc2 100644 --- a/CodeGen/src/IrLoweringA64.cpp +++ b/CodeGen/src/IrLoweringA64.cpp @@ -12,6 +12,7 @@ #include "NativeState.h" #include "lstate.h" +#include "lgc.h" // TODO: Eventually this can go away // #define TRACE @@ -32,7 +33,7 @@ struct LoweringStatsA64 ~LoweringStatsA64() { if (total) - printf("A64 lowering succeded for %.1f%% functions (%d/%d)\n", double(can) / double(total) * 100, int(can), int(total)); + printf("A64 lowering succeeded for %.1f%% functions (%d/%d)\n", double(can) / double(total) * 100, int(can), int(total)); } } gStatsA64; #endif @@ -77,6 +78,34 @@ inline ConditionA64 getConditionFP(IrCondition cond) } } +// TODO: instead of temp1/temp2 we can take a register that we will use for ra->value; that way callers to this function will be able to use it when +// calling luaC_barrier* +static void checkObjectBarrierConditions(AssemblyBuilderA64& build, RegisterA64 object, RegisterA64 temp1, RegisterA64 temp2, int ra, Label& skip) +{ + RegisterA64 temp1w = castReg(KindA64::w, temp1); + RegisterA64 temp2w = castReg(KindA64::w, temp2); + + // iscollectable(ra) + build.ldr(temp1w, mem(rBase, ra * sizeof(TValue) + offsetof(TValue, tt))); + build.cmp(temp1w, LUA_TSTRING); + build.b(ConditionA64::Less, skip); + + // isblack(obj2gco(o)) + // TODO: conditional bit test with BLACKBIT + build.ldrb(temp1w, mem(object, offsetof(GCheader, marked))); + build.mov(temp2w, bitmask(BLACKBIT)); + build.and_(temp1w, temp1w, temp2w); + build.cbz(temp1w, skip); + + // iswhite(gcvalue(ra)) + // TODO: tst with bitmask(WHITE0BIT, WHITE1BIT) + build.ldr(temp1, mem(rBase, ra * sizeof(TValue) + offsetof(TValue, value))); + build.ldrb(temp1w, mem(temp1, offsetof(GCheader, marked))); + build.mov(temp2w, bit2mask(WHITE0BIT, WHITE1BIT)); + build.and_(temp1w, temp1w, temp2w); + build.cbz(temp1w, skip); +} + IrLoweringA64::IrLoweringA64(AssemblyBuilderA64& build, ModuleHelpers& helpers, NativeState& data, Proto* proto, IrFunction& function) : build(build) , helpers(helpers) @@ -108,37 +137,89 @@ bool IrLoweringA64::canLower(const IrFunction& function) case IrCmd::LOAD_TVALUE: case IrCmd::LOAD_NODE_VALUE_TV: case IrCmd::LOAD_ENV: + case IrCmd::GET_ARR_ADDR: + case IrCmd::GET_SLOT_NODE_ADDR: + case IrCmd::GET_HASH_NODE_ADDR: case IrCmd::STORE_TAG: case IrCmd::STORE_POINTER: case IrCmd::STORE_DOUBLE: case IrCmd::STORE_INT: case IrCmd::STORE_TVALUE: case IrCmd::STORE_NODE_VALUE_TV: + case IrCmd::ADD_INT: + case IrCmd::SUB_INT: case IrCmd::ADD_NUM: case IrCmd::SUB_NUM: case IrCmd::MUL_NUM: case IrCmd::DIV_NUM: case IrCmd::MOD_NUM: + case IrCmd::POW_NUM: + case IrCmd::MIN_NUM: + case IrCmd::MAX_NUM: case IrCmd::UNM_NUM: + case IrCmd::FLOOR_NUM: + case IrCmd::CEIL_NUM: + case IrCmd::ROUND_NUM: + case IrCmd::SQRT_NUM: + case IrCmd::ABS_NUM: case IrCmd::JUMP: + case IrCmd::JUMP_IF_TRUTHY: + case IrCmd::JUMP_IF_FALSY: case IrCmd::JUMP_EQ_TAG: + case IrCmd::JUMP_EQ_INT: + case IrCmd::JUMP_EQ_POINTER: case IrCmd::JUMP_CMP_NUM: case IrCmd::JUMP_CMP_ANY: + case IrCmd::TABLE_LEN: + case IrCmd::NEW_TABLE: + case IrCmd::DUP_TABLE: + case IrCmd::TRY_NUM_TO_INDEX: + case IrCmd::INT_TO_NUM: + case IrCmd::ADJUST_STACK_TO_REG: + case IrCmd::ADJUST_STACK_TO_TOP: + case IrCmd::INVOKE_FASTCALL: + case IrCmd::CHECK_FASTCALL_RES: case IrCmd::DO_ARITH: + case IrCmd::DO_LEN: + case IrCmd::GET_TABLE: + case IrCmd::SET_TABLE: case IrCmd::GET_IMPORT: + case IrCmd::CONCAT: case IrCmd::GET_UPVALUE: + case IrCmd::SET_UPVALUE: + case IrCmd::PREPARE_FORN: case IrCmd::CHECK_TAG: case IrCmd::CHECK_READONLY: case IrCmd::CHECK_NO_METATABLE: case IrCmd::CHECK_SAFE_ENV: + case IrCmd::CHECK_ARRAY_SIZE: + case IrCmd::CHECK_SLOT_MATCH: case IrCmd::INTERRUPT: + case IrCmd::CHECK_GC: + case IrCmd::BARRIER_OBJ: + case IrCmd::BARRIER_TABLE_BACK: + case IrCmd::BARRIER_TABLE_FORWARD: case IrCmd::SET_SAVEDPC: + case IrCmd::CLOSE_UPVALS: + case IrCmd::CAPTURE: case IrCmd::CALL: case IrCmd::RETURN: + case IrCmd::FALLBACK_GETGLOBAL: + case IrCmd::FALLBACK_SETGLOBAL: + case IrCmd::FALLBACK_GETTABLEKS: + case IrCmd::FALLBACK_SETTABLEKS: + case IrCmd::FALLBACK_NAMECALL: + case IrCmd::FALLBACK_PREPVARARGS: + case IrCmd::FALLBACK_GETVARARGS: + case IrCmd::FALLBACK_NEWCLOSURE: + case IrCmd::FALLBACK_DUPCLOSURE: case IrCmd::SUBSTITUTE: continue; default: +#ifdef TRACE + printf("A64 lowering missing %s\n", getCmdName(inst.cmd)); +#endif return false; } } @@ -199,6 +280,64 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) inst.regA64 = regs.allocReg(KindA64::x); build.ldr(inst.regA64, mem(rClosure, offsetof(Closure, env))); break; + case IrCmd::GET_ARR_ADDR: + { + inst.regA64 = regs.allocReg(KindA64::x); + build.ldr(inst.regA64, mem(regOp(inst.a), offsetof(Table, array))); + + if (inst.b.kind == IrOpKind::Inst) + { + // TODO: This is a temporary hack that reads wN register as if it was xN. This should use unsigned extension shift once we support it. + build.add(inst.regA64, inst.regA64, castReg(KindA64::x, regOp(inst.b)), kTValueSizeLog2); + } + else if (inst.b.kind == IrOpKind::Constant) + { + LUAU_ASSERT(size_t(intOp(inst.b)) <= AssemblyBuilderA64::kMaxImmediate >> kTValueSizeLog2); // TODO: handle out of range values + build.add(inst.regA64, inst.regA64, uint16_t(intOp(inst.b) << kTValueSizeLog2)); + } + else + LUAU_ASSERT(!"Unsupported instruction form"); + break; + } + case IrCmd::GET_SLOT_NODE_ADDR: + { + inst.regA64 = regs.allocReuse(KindA64::x, index, {inst.a}); + RegisterA64 temp1 = regs.allocTemp(KindA64::x); + RegisterA64 temp1w = castReg(KindA64::w, temp1); + RegisterA64 temp2 = regs.allocTemp(KindA64::w); + + // TODO: this can use a slightly more efficient sequence with a 4b load + and-with-right-shift for pcpos<1024 but we don't support it yet. + build.mov(temp1, uintOp(inst.b) * sizeof(Instruction) + kOffsetOfInstructionC); + build.ldrb(temp1w, mem(rCode, temp1)); + build.ldrb(temp2, mem(regOp(inst.a), offsetof(Table, nodemask8))); + build.and_(temp2, temp2, temp1w); + + // note: this may clobber inst.a, so it's important that we don't use it after this + build.ldr(inst.regA64, mem(regOp(inst.a), offsetof(Table, node))); + // TODO: This is a temporary hack that reads wN register as if it was xN. This should use unsigned extension shift once we support it. + build.add(inst.regA64, inst.regA64, castReg(KindA64::x, temp2), kLuaNodeSizeLog2); + break; + } + case IrCmd::GET_HASH_NODE_ADDR: + { + inst.regA64 = regs.allocReuse(KindA64::x, index, {inst.a}); + RegisterA64 temp1 = regs.allocTemp(KindA64::w); + RegisterA64 temp2 = regs.allocTemp(KindA64::w); + + // TODO: this can use bic (andnot) to do hash & ~(-1 << lsizenode) instead but we don't support it yet + build.mov(temp1, 1); + build.ldrb(temp2, mem(regOp(inst.a), offsetof(Table, lsizenode))); + build.lsl(temp1, temp1, temp2); + build.sub(temp1, temp1, 1); + build.mov(temp2, uintOp(inst.b)); + build.and_(temp2, temp2, temp1); + + // note: this may clobber inst.a, so it's important that we don't use it after this + build.ldr(inst.regA64, mem(regOp(inst.a), offsetof(Table, node))); + // TODO: This is a temporary hack that reads wN register as if it was xN. This should use unsigned extension shift once we support it. + build.add(inst.regA64, inst.regA64, castReg(KindA64::x, temp2), kLuaNodeSizeLog2); + break; + } case IrCmd::STORE_TAG: { RegisterA64 temp = regs.allocTemp(KindA64::w); @@ -236,6 +375,16 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) case IrCmd::STORE_NODE_VALUE_TV: build.str(regOp(inst.b), mem(regOp(inst.a), offsetof(LuaNode, val))); break; + case IrCmd::ADD_INT: + LUAU_ASSERT(unsigned(intOp(inst.b)) <= AssemblyBuilderA64::kMaxImmediate); // TODO: handle out of range values + inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.a}); + build.add(inst.regA64, regOp(inst.a), uint16_t(intOp(inst.b))); + break; + case IrCmd::SUB_INT: + LUAU_ASSERT(unsigned(intOp(inst.b)) <= AssemblyBuilderA64::kMaxImmediate); // TODO: handle out of range values + inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.a}); + build.sub(inst.regA64, regOp(inst.a), uint16_t(intOp(inst.b))); + break; case IrCmd::ADD_NUM: { inst.regA64 = regs.allocReuse(KindA64::d, index, {inst.a, inst.b}); @@ -270,7 +419,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) } case IrCmd::MOD_NUM: { - inst.regA64 = regs.allocReg(KindA64::d); + inst.regA64 = regs.allocReg(KindA64::d); // can't allocReuse because both A and B are used twice RegisterA64 temp1 = tempDouble(inst.a); RegisterA64 temp2 = tempDouble(inst.b); build.fdiv(inst.regA64, temp1, temp2); @@ -279,6 +428,37 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.fsub(inst.regA64, temp1, inst.regA64); break; } + case IrCmd::POW_NUM: + { + // TODO: this instruction clobbers all registers because of a call but it's unclear how to assert that cleanly atm + inst.regA64 = regs.allocReg(KindA64::d); + RegisterA64 temp1 = tempDouble(inst.a); + RegisterA64 temp2 = tempDouble(inst.b); + build.fmov(d0, temp1); // TODO: aliasing hazard + build.fmov(d1, temp2); // TODO: aliasing hazard + build.ldr(x0, mem(rNativeContext, offsetof(NativeContext, libm_pow))); + build.blr(x0); + build.fmov(inst.regA64, d0); + break; + } + case IrCmd::MIN_NUM: + { + inst.regA64 = regs.allocReuse(KindA64::d, index, {inst.a, inst.b}); + RegisterA64 temp1 = tempDouble(inst.a); + RegisterA64 temp2 = tempDouble(inst.b); + build.fcmp(temp1, temp2); + build.fcsel(inst.regA64, temp1, temp2, getConditionFP(IrCondition::Less)); + break; + } + case IrCmd::MAX_NUM: + { + inst.regA64 = regs.allocReuse(KindA64::d, index, {inst.a, inst.b}); + RegisterA64 temp1 = tempDouble(inst.a); + RegisterA64 temp2 = tempDouble(inst.b); + build.fcmp(temp1, temp2); + build.fcsel(inst.regA64, temp1, temp2, getConditionFP(IrCondition::Greater)); + break; + } case IrCmd::UNM_NUM: { inst.regA64 = regs.allocReuse(KindA64::d, index, {inst.a}); @@ -286,9 +466,76 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.fneg(inst.regA64, temp); break; } + case IrCmd::FLOOR_NUM: + { + inst.regA64 = regs.allocReuse(KindA64::d, index, {inst.a}); + RegisterA64 temp = tempDouble(inst.a); + build.frintm(inst.regA64, temp); + break; + } + case IrCmd::CEIL_NUM: + { + inst.regA64 = regs.allocReuse(KindA64::d, index, {inst.a}); + RegisterA64 temp = tempDouble(inst.a); + build.frintp(inst.regA64, temp); + break; + } + case IrCmd::ROUND_NUM: + { + inst.regA64 = regs.allocReuse(KindA64::d, index, {inst.a}); + RegisterA64 temp = tempDouble(inst.a); + build.frinta(inst.regA64, temp); + break; + } + case IrCmd::SQRT_NUM: + { + inst.regA64 = regs.allocReuse(KindA64::d, index, {inst.a}); + RegisterA64 temp = tempDouble(inst.a); + build.fsqrt(inst.regA64, temp); + break; + } + case IrCmd::ABS_NUM: + { + inst.regA64 = regs.allocReuse(KindA64::d, index, {inst.a}); + RegisterA64 temp = tempDouble(inst.a); + build.fabs(inst.regA64, temp); + break; + } case IrCmd::JUMP: jumpOrFallthrough(blockOp(inst.a), next); break; + case IrCmd::JUMP_IF_TRUTHY: + { + RegisterA64 temp = regs.allocTemp(KindA64::w); + build.ldr(temp, mem(rBase, vmRegOp(inst.a) * sizeof(TValue) + offsetof(TValue, tt))); + // nil => falsy + LUAU_ASSERT(LUA_TNIL == 0); + build.cbz(temp, labelOp(inst.c)); + // not boolean => truthy + build.cmp(temp, LUA_TBOOLEAN); + build.b(ConditionA64::NotEqual, labelOp(inst.b)); + // compare boolean value + build.ldr(temp, mem(rBase, vmRegOp(inst.a) * sizeof(TValue) + offsetof(TValue, value))); + build.cbnz(temp, labelOp(inst.b)); + jumpOrFallthrough(blockOp(inst.c), next); + break; + } + case IrCmd::JUMP_IF_FALSY: + { + RegisterA64 temp = regs.allocTemp(KindA64::w); + build.ldr(temp, mem(rBase, vmRegOp(inst.a) * sizeof(TValue) + offsetof(TValue, tt))); + // nil => falsy + LUAU_ASSERT(LUA_TNIL == 0); + build.cbz(temp, labelOp(inst.b)); + // not boolean => truthy + build.cmp(temp, LUA_TBOOLEAN); + build.b(ConditionA64::NotEqual, labelOp(inst.c)); + // compare boolean value + build.ldr(temp, mem(rBase, vmRegOp(inst.a) * sizeof(TValue) + offsetof(TValue, value))); + build.cbz(temp, labelOp(inst.b)); + jumpOrFallthrough(blockOp(inst.c), next); + break; + } case IrCmd::JUMP_EQ_TAG: if (inst.b.kind == IrOpKind::Constant) build.cmp(regOp(inst.a), tagOp(inst.b)); @@ -308,6 +555,17 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) jumpOrFallthrough(blockOp(inst.c), next); } break; + case IrCmd::JUMP_EQ_INT: + LUAU_ASSERT(unsigned(intOp(inst.b)) <= AssemblyBuilderA64::kMaxImmediate); + build.cmp(regOp(inst.a), uint16_t(intOp(inst.b))); + build.b(ConditionA64::Equal, labelOp(inst.c)); + jumpOrFallthrough(blockOp(inst.d), next); + break; + case IrCmd::JUMP_EQ_POINTER: + build.cmp(regOp(inst.a), regOp(inst.b)); + build.b(ConditionA64::Equal, labelOp(inst.c)); + jumpOrFallthrough(blockOp(inst.d), next); + break; case IrCmd::JUMP_CMP_NUM: { IrCondition cond = conditionOp(inst.c); @@ -349,6 +607,150 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) jumpOrFallthrough(blockOp(inst.e), next); break; } + case IrCmd::TABLE_LEN: + { + regs.assertAllFreeExcept(regOp(inst.a)); + build.mov(x0, regOp(inst.a)); // TODO: minor aliasing hazard + build.ldr(x1, mem(rNativeContext, offsetof(NativeContext, luaH_getn))); + build.blr(x1); + + inst.regA64 = regs.allocReg(KindA64::d); + build.scvtf(inst.regA64, x0); + break; + } + case IrCmd::NEW_TABLE: + { + regs.assertAllFree(); + build.mov(x0, rState); + build.mov(x1, uintOp(inst.a)); + build.mov(x2, uintOp(inst.b)); + build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, luaH_new))); + build.blr(x3); + // TODO: we could takeReg x0 but it's unclear if we will be able to keep x0 allocatable due to aliasing concerns + inst.regA64 = regs.allocReg(KindA64::x); + build.mov(inst.regA64, x0); + break; + } + case IrCmd::DUP_TABLE: + { + regs.assertAllFreeExcept(regOp(inst.a)); + build.mov(x0, rState); + build.mov(x1, regOp(inst.a)); // TODO: aliasing hazard + build.ldr(x2, mem(rNativeContext, offsetof(NativeContext, luaH_clone))); + build.blr(x2); + // TODO: we could takeReg x0 but it's unclear if we will be able to keep x0 allocatable due to aliasing concerns + inst.regA64 = regs.allocReg(KindA64::x); + build.mov(inst.regA64, x0); + break; + } + case IrCmd::TRY_NUM_TO_INDEX: + { + inst.regA64 = regs.allocReg(KindA64::w); + RegisterA64 temp1 = tempDouble(inst.a); + + if (build.features & Feature_JSCVT) + { + build.fjcvtzs(inst.regA64, temp1); // fjcvtzs sets PSTATE.Z (equal) iff conversion is exact + build.b(ConditionA64::NotEqual, labelOp(inst.b)); + } + else + { + RegisterA64 temp2 = regs.allocTemp(KindA64::d); + + build.fcvtzs(inst.regA64, temp1); + build.scvtf(temp2, inst.regA64); + build.fcmp(temp1, temp2); + build.b(ConditionA64::NotEqual, labelOp(inst.b)); + } + break; + } + case IrCmd::INT_TO_NUM: + { + inst.regA64 = regs.allocReg(KindA64::d); + RegisterA64 temp = tempInt(inst.a); + build.scvtf(inst.regA64, temp); + break; + } + case IrCmd::ADJUST_STACK_TO_REG: + { + RegisterA64 temp = regs.allocTemp(KindA64::x); + + if (inst.b.kind == IrOpKind::Constant) + { + build.add(temp, rBase, uint16_t((vmRegOp(inst.a) + intOp(inst.b)) * sizeof(TValue))); + build.str(temp, mem(rState, offsetof(lua_State, top))); + } + else if (inst.b.kind == IrOpKind::Inst) + { + build.add(temp, rBase, uint16_t(vmRegOp(inst.a) * sizeof(TValue))); + // TODO: This is a temporary hack that reads wN register as if it was xN. This should use unsigned extension shift once we support it. + build.add(temp, temp, castReg(KindA64::x, regOp(inst.b)), kTValueSizeLog2); + build.str(temp, mem(rState, offsetof(lua_State, top))); + } + else + LUAU_ASSERT(!"Unsupported instruction form"); + break; + } + case IrCmd::ADJUST_STACK_TO_TOP: + { + RegisterA64 temp = regs.allocTemp(KindA64::x); + build.ldr(temp, mem(rState, offsetof(lua_State, ci))); + build.ldr(temp, mem(temp, offsetof(CallInfo, top))); + build.str(temp, mem(rState, offsetof(lua_State, top))); + break; + } + case IrCmd::INVOKE_FASTCALL: + { + regs.assertAllFree(); + build.mov(x0, rState); + build.add(x1, rBase, uint16_t(vmRegOp(inst.b) * sizeof(TValue))); + build.add(x2, rBase, uint16_t(vmRegOp(inst.c) * sizeof(TValue))); + build.mov(w3, intOp(inst.f)); // nresults + + if (inst.d.kind == IrOpKind::VmReg) + build.add(x4, rBase, uint16_t(vmRegOp(inst.d) * sizeof(TValue))); + else if (inst.d.kind == IrOpKind::VmConst) + { + // TODO: refactor into a common helper + if (vmConstOp(inst.d) * sizeof(TValue) <= AssemblyBuilderA64::kMaxImmediate) + { + build.add(x4, rConstants, uint16_t(vmConstOp(inst.d) * sizeof(TValue))); + } + else + { + build.mov(x4, vmConstOp(inst.d) * sizeof(TValue)); + build.add(x4, rConstants, x4); + } + } + else + LUAU_ASSERT(boolOp(inst.d) == false); + + // nparams + if (intOp(inst.e) == LUA_MULTRET) + { + // L->top - (ra + 1) + build.ldr(x5, mem(rState, offsetof(lua_State, top))); + build.sub(x5, x5, rBase); + build.sub(x5, x5, uint16_t((vmRegOp(inst.b) + 1) * sizeof(TValue))); + // TODO: this can use immediate shift right or maybe add/sub with shift right but we don't implement them yet + build.mov(x6, kTValueSizeLog2); + build.lsr(x5, x5, x6); + } + else + build.mov(w5, intOp(inst.e)); + + build.ldr(x6, mem(rNativeContext, offsetof(NativeContext, luauF_table) + uintOp(inst.a) * sizeof(luau_FastFunction))); + build.blr(x6); + + // TODO: we could takeReg w0 but it's unclear if we will be able to keep x0 allocatable due to aliasing concerns + inst.regA64 = regs.allocReg(KindA64::w); + build.mov(inst.regA64, w0); + break; + } + case IrCmd::CHECK_FASTCALL_RES: + build.cmp(regOp(inst.a), 0); + build.b(ConditionA64::Less, labelOp(inst.b)); + break; case IrCmd::DO_ARITH: regs.assertAllFree(); build.mov(x0, rState); @@ -375,12 +777,76 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.ldr(x5, mem(rNativeContext, offsetof(NativeContext, luaV_doarith))); build.blr(x5); + emitUpdateBase(build); + break; + case IrCmd::DO_LEN: + regs.assertAllFree(); + build.mov(x0, rState); + build.add(x1, rBase, uint16_t(vmRegOp(inst.a) * sizeof(TValue))); + build.add(x2, rBase, uint16_t(vmRegOp(inst.b) * sizeof(TValue))); + build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, luaV_dolen))); + build.blr(x3); + + emitUpdateBase(build); + break; + case IrCmd::GET_TABLE: + regs.assertAllFree(); + build.mov(x0, rState); + build.add(x1, rBase, uint16_t(vmRegOp(inst.b) * sizeof(TValue))); + + if (inst.c.kind == IrOpKind::VmReg) + build.add(x2, rBase, uint16_t(vmRegOp(inst.c) * sizeof(TValue))); + else if (inst.c.kind == IrOpKind::Constant) + { + TValue n; + setnvalue(&n, uintOp(inst.c)); + build.adr(x2, &n, sizeof(n)); + } + else + LUAU_ASSERT(!"Unsupported instruction form"); + + build.add(x3, rBase, uint16_t(vmRegOp(inst.a) * sizeof(TValue))); + build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, luaV_gettable))); + build.blr(x4); + + emitUpdateBase(build); + break; + case IrCmd::SET_TABLE: + regs.assertAllFree(); + build.mov(x0, rState); + build.add(x1, rBase, uint16_t(vmRegOp(inst.b) * sizeof(TValue))); + + if (inst.c.kind == IrOpKind::VmReg) + build.add(x2, rBase, uint16_t(vmRegOp(inst.c) * sizeof(TValue))); + else if (inst.c.kind == IrOpKind::Constant) + { + TValue n; + setnvalue(&n, uintOp(inst.c)); + build.adr(x2, &n, sizeof(n)); + } + else + LUAU_ASSERT(!"Unsupported instruction form"); + + build.add(x3, rBase, uint16_t(vmRegOp(inst.a) * sizeof(TValue))); + build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, luaV_settable))); + build.blr(x4); + emitUpdateBase(build); break; case IrCmd::GET_IMPORT: regs.assertAllFree(); emitInstGetImport(build, vmRegOp(inst.a), uintOp(inst.b)); break; + case IrCmd::CONCAT: + regs.assertAllFree(); + build.mov(x0, rState); + build.mov(x1, uintOp(inst.b)); + build.mov(x2, vmRegOp(inst.a) + uintOp(inst.b) - 1); + build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, luaV_concat))); + build.blr(x3); + + emitUpdateBase(build); + break; case IrCmd::GET_UPVALUE: { RegisterA64 temp1 = regs.allocTemp(KindA64::x); @@ -405,6 +871,44 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.str(temp2, mem(rBase, vmRegOp(inst.a) * sizeof(TValue))); break; } + case IrCmd::SET_UPVALUE: + { + regs.assertAllFree(); + RegisterA64 temp1 = regs.allocTemp(KindA64::x); + RegisterA64 temp2 = regs.allocTemp(KindA64::x); + RegisterA64 temp3 = regs.allocTemp(KindA64::q); + RegisterA64 temp4 = regs.allocTemp(KindA64::x); + + // UpVal* + build.ldr(temp1, mem(rClosure, offsetof(Closure, l.uprefs) + sizeof(TValue) * vmUpvalueOp(inst.a) + offsetof(TValue, value.gc))); + + build.ldr(temp2, mem(temp1, offsetof(UpVal, v))); + build.ldr(temp3, mem(rBase, vmRegOp(inst.b) * sizeof(TValue))); + build.str(temp3, temp2); + + Label skip; + checkObjectBarrierConditions(build, temp1, temp2, temp4, vmRegOp(inst.b), skip); + + build.mov(x0, rState); + build.mov(x1, temp1); // TODO: aliasing hazard + build.ldr(x2, mem(rBase, vmRegOp(inst.b) * sizeof(TValue) + offsetof(TValue, value))); + build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, luaC_barrierf))); + build.blr(x3); + + // note: no emitUpdateBase necessary because luaC_ barriers do not reallocate stack + build.setLabel(skip); + break; + } + case IrCmd::PREPARE_FORN: + regs.assertAllFree(); + build.mov(x0, rState); + build.add(x1, rBase, uint16_t(vmRegOp(inst.a) * sizeof(TValue))); + build.add(x2, rBase, uint16_t(vmRegOp(inst.b) * sizeof(TValue))); + build.add(x3, rBase, uint16_t(vmRegOp(inst.c) * sizeof(TValue))); + build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, luaV_prepareFORN))); + build.blr(x4); + // note: no emitUpdateBase necessary because prepareFORN does not reallocate stack + break; case IrCmd::CHECK_TAG: build.cmp(regOp(inst.a), tagOp(inst.b)); build.b(ConditionA64::NotEqual, labelOp(inst.c)); @@ -426,12 +930,55 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) case IrCmd::CHECK_SAFE_ENV: { RegisterA64 temp = regs.allocTemp(KindA64::x); - RegisterA64 tempw{KindA64::w, temp.index}; + RegisterA64 tempw = castReg(KindA64::w, temp); build.ldr(temp, mem(rClosure, offsetof(Closure, env))); build.ldrb(tempw, mem(temp, offsetof(Table, safeenv))); build.cbz(tempw, labelOp(inst.a)); break; } + case IrCmd::CHECK_ARRAY_SIZE: + { + RegisterA64 temp = regs.allocTemp(KindA64::w); + build.ldr(temp, mem(regOp(inst.a), offsetof(Table, sizearray))); + + if (inst.b.kind == IrOpKind::Inst) + build.cmp(temp, regOp(inst.b)); + else if (inst.b.kind == IrOpKind::Constant) + { + LUAU_ASSERT(size_t(intOp(inst.b)) <= AssemblyBuilderA64::kMaxImmediate); // TODO: handle out of range values + build.cmp(temp, uint16_t(intOp(inst.b))); + } + else + LUAU_ASSERT(!"Unsupported instruction form"); + + build.b(ConditionA64::UnsignedLessEqual, labelOp(inst.c)); + break; + } + case IrCmd::CHECK_SLOT_MATCH: + { + RegisterA64 temp1 = regs.allocTemp(KindA64::x); + RegisterA64 temp1w = castReg(KindA64::w, temp1); + RegisterA64 temp2 = regs.allocTemp(KindA64::x); + RegisterA64 temp2w = castReg(KindA64::w, temp2); + + build.ldr(temp1w, mem(regOp(inst.a), kOffsetOfLuaNodeTag)); + // TODO: this needs bitfield extraction, or and-immediate + build.mov(temp2w, kLuaNodeTagMask); + build.and_(temp1w, temp1w, temp2w); + build.cmp(temp1w, LUA_TSTRING); + build.b(ConditionA64::NotEqual, labelOp(inst.c)); + + AddressA64 addr = tempAddr(inst.b, offsetof(TValue, value)); + build.ldr(temp1, mem(regOp(inst.a), offsetof(LuaNode, key.value))); + build.ldr(temp2, addr); + build.cmp(temp1, temp2); + build.b(ConditionA64::NotEqual, labelOp(inst.c)); + + build.ldr(temp1w, mem(regOp(inst.a), offsetof(LuaNode, val.tt))); + LUAU_ASSERT(LUA_TNIL == 0); + build.cbz(temp1w, labelOp(inst.c)); + break; + } case IrCmd::INTERRUPT: { unsigned int pcpos = uintOp(inst.a); @@ -450,6 +997,93 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.setLabel(skip); break; } + case IrCmd::CHECK_GC: + { + regs.assertAllFree(); + RegisterA64 temp1 = regs.allocTemp(KindA64::x); + RegisterA64 temp2 = regs.allocTemp(KindA64::x); + + Label skip; + build.ldr(temp1, mem(rState, offsetof(lua_State, global))); + build.ldr(temp2, mem(temp1, offsetof(global_State, totalbytes))); + build.ldr(temp1, mem(temp1, offsetof(global_State, GCthreshold))); + build.cmp(temp1, temp2); + build.b(ConditionA64::UnsignedGreater, skip); + + build.mov(x0, rState); + build.mov(w1, 1); + build.ldr(x1, mem(rNativeContext, offsetof(NativeContext, luaC_step))); + build.blr(x1); + + emitUpdateBase(build); + build.setLabel(skip); + break; + } + case IrCmd::BARRIER_OBJ: + { + regs.assertAllFreeExcept(regOp(inst.a)); + + Label skip; + RegisterA64 temp1 = regs.allocTemp(KindA64::x); + RegisterA64 temp2 = regs.allocTemp(KindA64::x); + + checkObjectBarrierConditions(build, regOp(inst.a), temp1, temp2, vmRegOp(inst.b), skip); + + build.mov(x0, rState); + build.mov(x1, regOp(inst.a)); // TODO: aliasing hazard + build.ldr(x2, mem(rBase, vmRegOp(inst.b) * sizeof(TValue) + offsetof(TValue, value))); + build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, luaC_barrierf))); + build.blr(x3); + + // note: no emitUpdateBase necessary because luaC_ barriers do not reallocate stack + build.setLabel(skip); + break; + } + case IrCmd::BARRIER_TABLE_BACK: + { + regs.assertAllFreeExcept(regOp(inst.a)); + + Label skip; + RegisterA64 temp1 = regs.allocTemp(KindA64::w); + RegisterA64 temp2 = regs.allocTemp(KindA64::w); + + // isblack(obj2gco(t)) + build.ldrb(temp1, mem(regOp(inst.a), offsetof(GCheader, marked))); + // TODO: conditional bit test with BLACKBIT + build.mov(temp2, bitmask(BLACKBIT)); + build.and_(temp1, temp1, temp2); + build.cbz(temp1, skip); + + build.mov(x0, rState); + build.mov(x1, regOp(inst.a)); // TODO: aliasing hazard here and below + build.add(x2, regOp(inst.a), uint16_t(offsetof(Table, gclist))); + build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, luaC_barrierback))); + build.blr(x3); + + // note: no emitUpdateBase necessary because luaC_ barriers do not reallocate stack + build.setLabel(skip); + break; + } + case IrCmd::BARRIER_TABLE_FORWARD: + { + regs.assertAllFreeExcept(regOp(inst.a)); + + Label skip; + RegisterA64 temp1 = regs.allocTemp(KindA64::x); + RegisterA64 temp2 = regs.allocTemp(KindA64::x); + + checkObjectBarrierConditions(build, regOp(inst.a), temp1, temp2, vmRegOp(inst.b), skip); + + build.mov(x0, rState); + build.mov(x1, regOp(inst.a)); // TODO: aliasing hazard + build.ldr(x2, mem(rBase, vmRegOp(inst.b) * sizeof(TValue) + offsetof(TValue, value))); + build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, luaC_barriertable))); + build.blr(x3); + + // note: no emitUpdateBase necessary because luaC_ barriers do not reallocate stack + build.setLabel(skip); + break; + } case IrCmd::SET_SAVEDPC: { unsigned int pcpos = uintOp(inst.a); @@ -471,6 +1105,34 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.str(temp1, mem(temp2, offsetof(CallInfo, savedpc))); break; } + case IrCmd::CLOSE_UPVALS: + { + regs.assertAllFree(); + Label skip; + RegisterA64 temp1 = regs.allocTemp(KindA64::x); + RegisterA64 temp2 = regs.allocTemp(KindA64::x); + + // L->openupval != 0 + build.ldr(temp1, mem(rState, offsetof(lua_State, openupval))); + build.cbz(temp1, skip); + + // ra <= L->openuval->v + build.ldr(temp1, mem(temp1, offsetof(UpVal, v))); + build.add(temp2, rBase, uint16_t(vmRegOp(inst.a) * sizeof(TValue))); + build.cmp(temp2, temp1); + build.b(ConditionA64::UnsignedGreater, skip); + + build.mov(x0, rState); + build.mov(x1, temp2); // TODO: aliasing hazard + build.ldr(x2, mem(rNativeContext, offsetof(NativeContext, luaF_close))); + build.blr(x2); + + build.setLabel(skip); + break; + } + case IrCmd::CAPTURE: + // no-op + break; case IrCmd::CALL: regs.assertAllFree(); emitInstCall(build, helpers, vmRegOp(inst.a), intOp(inst.b), intOp(inst.c)); @@ -479,6 +1141,74 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) regs.assertAllFree(); emitInstReturn(build, helpers, vmRegOp(inst.a), intOp(inst.b)); break; + + // Full instruction fallbacks + case IrCmd::FALLBACK_GETGLOBAL: + LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.c.kind == IrOpKind::VmConst); + + regs.assertAllFree(); + emitFallback(build, LOP_GETGLOBAL, uintOp(inst.a)); + break; + case IrCmd::FALLBACK_SETGLOBAL: + LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.c.kind == IrOpKind::VmConst); + + regs.assertAllFree(); + emitFallback(build, LOP_SETGLOBAL, uintOp(inst.a)); + break; + case IrCmd::FALLBACK_GETTABLEKS: + LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.c.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.d.kind == IrOpKind::VmConst); + + regs.assertAllFree(); + emitFallback(build, LOP_GETTABLEKS, uintOp(inst.a)); + break; + case IrCmd::FALLBACK_SETTABLEKS: + LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.c.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.d.kind == IrOpKind::VmConst); + + regs.assertAllFree(); + emitFallback(build, LOP_SETTABLEKS, uintOp(inst.a)); + break; + case IrCmd::FALLBACK_NAMECALL: + LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.c.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.d.kind == IrOpKind::VmConst); + + regs.assertAllFree(); + emitFallback(build, LOP_NAMECALL, uintOp(inst.a)); + break; + case IrCmd::FALLBACK_PREPVARARGS: + LUAU_ASSERT(inst.b.kind == IrOpKind::Constant); + + regs.assertAllFree(); + emitFallback(build, LOP_PREPVARARGS, uintOp(inst.a)); + break; + case IrCmd::FALLBACK_GETVARARGS: + LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.c.kind == IrOpKind::Constant); + + regs.assertAllFree(); + emitFallback(build, LOP_GETVARARGS, uintOp(inst.a)); + break; + case IrCmd::FALLBACK_NEWCLOSURE: + LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.c.kind == IrOpKind::Constant); + + regs.assertAllFree(); + emitFallback(build, LOP_NEWCLOSURE, uintOp(inst.a)); + break; + case IrCmd::FALLBACK_DUPCLOSURE: + LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.c.kind == IrOpKind::VmConst); + + regs.assertAllFree(); + emitFallback(build, LOP_DUPCLOSURE, uintOp(inst.a)); + break; + default: LUAU_ASSERT(!"Not supported yet"); break; @@ -488,6 +1218,11 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) regs.freeTempRegs(); } +bool IrLoweringA64::hasError() const +{ + return false; +} + bool IrLoweringA64::isFallthroughBlock(IrBlock target, IrBlock next) { return target.start == next.start; diff --git a/CodeGen/src/IrLoweringA64.h b/CodeGen/src/IrLoweringA64.h index f638432ff..b374a26a0 100644 --- a/CodeGen/src/IrLoweringA64.h +++ b/CodeGen/src/IrLoweringA64.h @@ -30,6 +30,8 @@ struct IrLoweringA64 void lowerInst(IrInst& inst, uint32_t index, IrBlock& next); + bool hasError() const; + bool isFallthroughBlock(IrBlock target, IrBlock next); void jumpOrFallthrough(IrBlock& target, IrBlock& next); diff --git a/CodeGen/src/IrLoweringX64.cpp b/CodeGen/src/IrLoweringX64.cpp index 8c45f36ad..f2dfdb3b1 100644 --- a/CodeGen/src/IrLoweringX64.cpp +++ b/CodeGen/src/IrLoweringX64.cpp @@ -27,18 +27,39 @@ IrLoweringX64::IrLoweringX64(AssemblyBuilderX64& build, ModuleHelpers& helpers, , helpers(helpers) , data(data) , function(function) - , regs(function) + , regs(build, function) { // In order to allocate registers during lowering, we need to know where instruction results are last used updateLastUseLocations(function); } +void IrLoweringX64::storeDoubleAsFloat(OperandX64 dst, IrOp src) +{ + ScopedRegX64 tmp{regs, SizeX64::xmmword}; + + if (src.kind == IrOpKind::Constant) + { + build.vmovss(tmp.reg, build.f32(float(doubleOp(src)))); + } + else if (src.kind == IrOpKind::Inst) + { + build.vcvtsd2ss(tmp.reg, regOp(src), regOp(src)); + } + else + { + LUAU_ASSERT(!"Unsupported instruction form"); + } + build.vmovss(dst, tmp.reg); +} + void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) { + regs.currInstIdx = index; + switch (inst.cmd) { case IrCmd::LOAD_TAG: - inst.regX64 = regs.allocGprReg(SizeX64::dword); + inst.regX64 = regs.allocGprReg(SizeX64::dword, index); if (inst.a.kind == IrOpKind::VmReg) build.mov(inst.regX64, luauRegTag(vmRegOp(inst.a))); @@ -52,7 +73,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) LUAU_ASSERT(!"Unsupported instruction form"); break; case IrCmd::LOAD_POINTER: - inst.regX64 = regs.allocGprReg(SizeX64::qword); + inst.regX64 = regs.allocGprReg(SizeX64::qword, index); if (inst.a.kind == IrOpKind::VmReg) build.mov(inst.regX64, luauRegValue(vmRegOp(inst.a))); @@ -66,7 +87,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) LUAU_ASSERT(!"Unsupported instruction form"); break; case IrCmd::LOAD_DOUBLE: - inst.regX64 = regs.allocXmmReg(); + inst.regX64 = regs.allocXmmReg(index); if (inst.a.kind == IrOpKind::VmReg) build.vmovsd(inst.regX64, luauRegValue(vmRegOp(inst.a))); @@ -76,12 +97,12 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) LUAU_ASSERT(!"Unsupported instruction form"); break; case IrCmd::LOAD_INT: - inst.regX64 = regs.allocGprReg(SizeX64::dword); + inst.regX64 = regs.allocGprReg(SizeX64::dword, index); build.mov(inst.regX64, luauRegValueInt(vmRegOp(inst.a))); break; case IrCmd::LOAD_TVALUE: - inst.regX64 = regs.allocXmmReg(); + inst.regX64 = regs.allocXmmReg(index); if (inst.a.kind == IrOpKind::VmReg) build.vmovups(inst.regX64, luauReg(vmRegOp(inst.a))); @@ -93,12 +114,12 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) LUAU_ASSERT(!"Unsupported instruction form"); break; case IrCmd::LOAD_NODE_VALUE_TV: - inst.regX64 = regs.allocXmmReg(); + inst.regX64 = regs.allocXmmReg(index); build.vmovups(inst.regX64, luauNodeValue(regOp(inst.a))); break; case IrCmd::LOAD_ENV: - inst.regX64 = regs.allocGprReg(SizeX64::qword); + inst.regX64 = regs.allocGprReg(SizeX64::qword, index); build.mov(inst.regX64, sClosure); build.mov(inst.regX64, qword[inst.regX64 + offsetof(Closure, env)]); @@ -130,7 +151,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) break; case IrCmd::GET_SLOT_NODE_ADDR: { - inst.regX64 = regs.allocGprReg(SizeX64::qword); + inst.regX64 = regs.allocGprReg(SizeX64::qword, index); ScopedRegX64 tmp{regs, SizeX64::qword}; @@ -139,10 +160,10 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) } case IrCmd::GET_HASH_NODE_ADDR: { - inst.regX64 = regs.allocGprReg(SizeX64::qword); + inst.regX64 = regs.allocGprReg(SizeX64::qword, index); // Custom bit shift value can only be placed in cl - ScopedRegX64 shiftTmp{regs, regs.takeReg(rcx)}; + ScopedRegX64 shiftTmp{regs, regs.takeReg(rcx, kInvalidInstIdx)}; ScopedRegX64 tmp{regs, SizeX64::qword}; @@ -192,6 +213,13 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) LUAU_ASSERT(!"Unsupported instruction form"); break; } + case IrCmd::STORE_VECTOR: + { + storeDoubleAsFloat(luauRegValueVector(vmRegOp(inst.a), 0), inst.b); + storeDoubleAsFloat(luauRegValueVector(vmRegOp(inst.a), 1), inst.c); + storeDoubleAsFloat(luauRegValueVector(vmRegOp(inst.a), 2), inst.d); + break; + } case IrCmd::STORE_TVALUE: if (inst.a.kind == IrOpKind::VmReg) build.vmovups(luauReg(vmRegOp(inst.a)), regOp(inst.b)); @@ -330,7 +358,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) callWrap.addArgument(SizeX64::xmmword, memRegDoubleOp(inst.a), inst.a); callWrap.addArgument(SizeX64::xmmword, memRegDoubleOp(inst.b), inst.b); callWrap.call(qword[rNativeContext + offsetof(NativeContext, libm_pow)]); - inst.regX64 = regs.takeReg(xmm0); + inst.regX64 = regs.takeReg(xmm0, index); break; } case IrCmd::MIN_NUM: @@ -398,8 +426,10 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) ScopedRegX64 tmp1{regs, SizeX64::xmmword}; ScopedRegX64 tmp2{regs, SizeX64::xmmword}; - if (inst.a.kind != IrOpKind::Inst || regOp(inst.a) != inst.regX64) + if (inst.a.kind != IrOpKind::Inst) build.vmovsd(inst.regX64, memRegDoubleOp(inst.a)); + else if (regOp(inst.a) != inst.regX64) + build.vmovsd(inst.regX64, inst.regX64, regOp(inst.a)); build.vandpd(tmp1.reg, inst.regX64, build.f64x2(-0.0, -0.0)); build.vmovsd(tmp2.reg, build.i64(0x3fdfffffffffffff)); // 0.49999999999999994 @@ -416,8 +446,10 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) case IrCmd::ABS_NUM: inst.regX64 = regs.allocXmmRegOrReuse(index, {inst.a}); - if (inst.a.kind != IrOpKind::Inst || regOp(inst.a) != inst.regX64) + if (inst.a.kind != IrOpKind::Inst) build.vmovsd(inst.regX64, memRegDoubleOp(inst.a)); + else if (regOp(inst.a) != inst.regX64) + build.vmovsd(inst.regX64, inst.regX64, regOp(inst.a)); build.vandpd(inst.regX64, inst.regX64, build.i64(~(1LL << 63))); break; @@ -526,7 +558,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) callWrap.addArgument(SizeX64::qword, regOp(inst.a), inst.a); callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaH_getn)]); - inst.regX64 = regs.allocXmmReg(); + inst.regX64 = regs.allocXmmReg(index); build.vcvtsi2sd(inst.regX64, inst.regX64, eax); break; } @@ -537,7 +569,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) callWrap.addArgument(SizeX64::dword, int32_t(uintOp(inst.a)), inst.a); callWrap.addArgument(SizeX64::dword, int32_t(uintOp(inst.b)), inst.b); callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaH_new)]); - inst.regX64 = regs.takeReg(rax); + inst.regX64 = regs.takeReg(rax, index); break; } case IrCmd::DUP_TABLE: @@ -546,12 +578,12 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) callWrap.addArgument(SizeX64::qword, rState); callWrap.addArgument(SizeX64::qword, regOp(inst.a), inst.a); callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaH_clone)]); - inst.regX64 = regs.takeReg(rax); + inst.regX64 = regs.takeReg(rax, index); break; } case IrCmd::TRY_NUM_TO_INDEX: { - inst.regX64 = regs.allocGprReg(SizeX64::dword); + inst.regX64 = regs.allocGprReg(SizeX64::dword, index); ScopedRegX64 tmp{regs, SizeX64::xmmword}; @@ -574,35 +606,39 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) ScopedRegX64 tmp2{regs, SizeX64::qword}; build.mov(tmp2.reg, qword[rState + offsetof(lua_State, global)]); - IrCallWrapperX64 callWrap(regs, build, index); - callWrap.addArgument(SizeX64::qword, tmp); - callWrap.addArgument(SizeX64::qword, intOp(inst.b)); - callWrap.addArgument(SizeX64::qword, qword[tmp2.release() + offsetof(global_State, tmname) + intOp(inst.b) * sizeof(TString*)]); - callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaT_gettm)]); - inst.regX64 = regs.takeReg(rax); + { + ScopedSpills spillGuard(regs); + + IrCallWrapperX64 callWrap(regs, build, index); + callWrap.addArgument(SizeX64::qword, tmp); + callWrap.addArgument(SizeX64::qword, intOp(inst.b)); + callWrap.addArgument(SizeX64::qword, qword[tmp2.release() + offsetof(global_State, tmname) + intOp(inst.b) * sizeof(TString*)]); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaT_gettm)]); + } + + inst.regX64 = regs.takeReg(rax, index); break; } case IrCmd::INT_TO_NUM: - inst.regX64 = regs.allocXmmReg(); + inst.regX64 = regs.allocXmmReg(index); build.vcvtsi2sd(inst.regX64, inst.regX64, regOp(inst.a)); break; case IrCmd::ADJUST_STACK_TO_REG: { + ScopedRegX64 tmp{regs, SizeX64::qword}; + if (inst.b.kind == IrOpKind::Constant) { - ScopedRegX64 tmp{regs, SizeX64::qword}; - build.lea(tmp.reg, addr[rBase + (vmRegOp(inst.a) + intOp(inst.b)) * sizeof(TValue)]); build.mov(qword[rState + offsetof(lua_State, top)], tmp.reg); } else if (inst.b.kind == IrOpKind::Inst) { - ScopedRegX64 tmp(regs, regs.allocGprRegOrReuse(SizeX64::dword, index, {inst.b})); - - build.shl(qwordReg(tmp.reg), kTValueSizeLog2); - build.lea(qwordReg(tmp.reg), addr[rBase + qwordReg(tmp.reg) + vmRegOp(inst.a) * sizeof(TValue)]); - build.mov(qword[rState + offsetof(lua_State, top)], qwordReg(tmp.reg)); + build.mov(dwordReg(tmp.reg), regOp(inst.b)); + build.shl(tmp.reg, kTValueSizeLog2); + build.lea(tmp.reg, addr[rBase + tmp.reg + vmRegOp(inst.a) * sizeof(TValue)]); + build.mov(qword[rState + offsetof(lua_State, top)], tmp.reg); } else { @@ -640,52 +676,37 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) int nparams = intOp(inst.e); int nresults = intOp(inst.f); - regs.assertAllFree(); - - build.mov(rax, qword[rNativeContext + offsetof(NativeContext, luauF_table) + bfid * sizeof(luau_FastFunction)]); + ScopedRegX64 func{regs, SizeX64::qword}; + build.mov(func.reg, qword[rNativeContext + offsetof(NativeContext, luauF_table) + bfid * sizeof(luau_FastFunction)]); - // 5th parameter (args) is left unset for LOP_FASTCALL1 - if (args.cat == CategoryX64::mem) - { - if (build.abi == ABIX64::Windows) - { - build.lea(rcx, args); - build.mov(sArg5, rcx); - } - else - { - build.lea(rArg5, args); - } - } + IrCallWrapperX64 callWrap(regs, build, index); + callWrap.addArgument(SizeX64::qword, rState); + callWrap.addArgument(SizeX64::qword, luauRegAddress(ra)); + callWrap.addArgument(SizeX64::qword, luauRegAddress(arg)); + callWrap.addArgument(SizeX64::dword, nresults); + callWrap.addArgument(SizeX64::qword, args); if (nparams == LUA_MULTRET) { - // L->top - (ra + 1) - RegisterX64 reg = (build.abi == ABIX64::Windows) ? rcx : rArg6; + // Compute 'L->top - (ra + 1)', on SystemV, take r9 register to compute directly into the argument + // TODO: IrCallWrapperX64 should provide a way to 'guess' target argument register correctly + RegisterX64 reg = build.abi == ABIX64::Windows ? regs.allocGprReg(SizeX64::qword, kInvalidInstIdx) : regs.takeReg(rArg6, kInvalidInstIdx); + ScopedRegX64 tmp{regs, SizeX64::qword}; + build.mov(reg, qword[rState + offsetof(lua_State, top)]); - build.lea(rdx, addr[rBase + (ra + 1) * sizeof(TValue)]); - build.sub(reg, rdx); + build.lea(tmp.reg, addr[rBase + (ra + 1) * sizeof(TValue)]); + build.sub(reg, tmp.reg); build.shr(reg, kTValueSizeLog2); - if (build.abi == ABIX64::Windows) - build.mov(sArg6, reg); + callWrap.addArgument(SizeX64::dword, dwordReg(reg)); } else { - if (build.abi == ABIX64::Windows) - build.mov(sArg6, nparams); - else - build.mov(rArg6, nparams); + callWrap.addArgument(SizeX64::dword, nparams); } - build.mov(rArg1, rState); - build.lea(rArg2, luauRegAddress(ra)); - build.lea(rArg3, luauRegAddress(arg)); - build.mov(dwordReg(rArg4), nresults); - - build.call(rax); - - inst.regX64 = regs.takeReg(eax); // Result of a builtin call is returned in eax + callWrap.call(func.release()); + inst.regX64 = regs.takeReg(eax, index); // Result of a builtin call is returned in eax break; } case IrCmd::CHECK_FASTCALL_RES: @@ -738,6 +759,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) } break; case IrCmd::GET_IMPORT: + regs.assertAllFree(); emitInstGetImportFallback(build, vmRegOp(inst.a), uintOp(inst.b)); break; case IrCmd::CONCAT: @@ -777,7 +799,6 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) } case IrCmd::SET_UPVALUE: { - Label next; ScopedRegX64 tmp1{regs, SizeX64::qword}; ScopedRegX64 tmp2{regs, SizeX64::qword}; @@ -794,8 +815,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) tmp1.free(); - callBarrierObject(regs, build, tmp2.release(), {}, vmRegOp(inst.b), next); - build.setLabel(next); + callBarrierObject(regs, build, tmp2.release(), {}, vmRegOp(inst.b)); break; } case IrCmd::PREPARE_FORN: @@ -859,26 +879,14 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) emitInterrupt(build, uintOp(inst.a)); break; case IrCmd::CHECK_GC: - { - Label skip; - callCheckGc(regs, build, skip); - build.setLabel(skip); + callStepGc(regs, build); break; - } case IrCmd::BARRIER_OBJ: - { - Label skip; - callBarrierObject(regs, build, regOp(inst.a), inst.a, vmRegOp(inst.b), skip); - build.setLabel(skip); + callBarrierObject(regs, build, regOp(inst.a), inst.a, vmRegOp(inst.b)); break; - } case IrCmd::BARRIER_TABLE_BACK: - { - Label skip; - callBarrierTableFast(regs, build, regOp(inst.a), inst.a, skip); - build.setLabel(skip); + callBarrierTableFast(regs, build, regOp(inst.a), inst.a); break; - } case IrCmd::BARRIER_TABLE_FORWARD: { Label skip; @@ -886,11 +894,15 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) ScopedRegX64 tmp{regs, SizeX64::qword}; checkObjectBarrierConditions(build, tmp.reg, regOp(inst.a), vmRegOp(inst.b), skip); - IrCallWrapperX64 callWrap(regs, build, index); - callWrap.addArgument(SizeX64::qword, rState); - callWrap.addArgument(SizeX64::qword, regOp(inst.a), inst.a); - callWrap.addArgument(SizeX64::qword, tmp); - callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaC_barriertable)]); + { + ScopedSpills spillGuard(regs); + + IrCallWrapperX64 callWrap(regs, build, index); + callWrap.addArgument(SizeX64::qword, rState); + callWrap.addArgument(SizeX64::qword, regOp(inst.a), inst.a); + callWrap.addArgument(SizeX64::qword, tmp); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaC_barriertable)]); + } build.setLabel(skip); break; @@ -925,10 +937,14 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) tmp1.free(); - IrCallWrapperX64 callWrap(regs, build, index); - callWrap.addArgument(SizeX64::qword, rState); - callWrap.addArgument(SizeX64::qword, tmp2); - callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaF_close)]); + { + ScopedSpills spillGuard(regs); + + IrCallWrapperX64 callWrap(regs, build, index); + callWrap.addArgument(SizeX64::qword, rState); + callWrap.addArgument(SizeX64::qword, tmp2); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaF_close)]); + } build.setLabel(next); break; @@ -939,19 +955,17 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) // Fallbacks to non-IR instruction implementations case IrCmd::SETLIST: - { - Label next; regs.assertAllFree(); - emitInstSetList(regs, build, next, vmRegOp(inst.b), vmRegOp(inst.c), intOp(inst.d), uintOp(inst.e)); - build.setLabel(next); + emitInstSetList(regs, build, vmRegOp(inst.b), vmRegOp(inst.c), intOp(inst.d), uintOp(inst.e)); break; - } case IrCmd::CALL: regs.assertAllFree(); + regs.assertNoSpills(); emitInstCall(build, helpers, vmRegOp(inst.a), intOp(inst.b), intOp(inst.c)); break; case IrCmd::RETURN: regs.assertAllFree(); + regs.assertNoSpills(); emitInstReturn(build, helpers, vmRegOp(inst.a), intOp(inst.b)); break; case IrCmd::FORGLOOP: @@ -967,22 +981,6 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) regs.assertAllFree(); emitInstForGPrepXnextFallback(build, uintOp(inst.a), vmRegOp(inst.b), labelOp(inst.c)); break; - case IrCmd::AND: - regs.assertAllFree(); - emitInstAnd(build, vmRegOp(inst.a), vmRegOp(inst.b), vmRegOp(inst.c)); - break; - case IrCmd::ANDK: - regs.assertAllFree(); - emitInstAndK(build, vmRegOp(inst.a), vmRegOp(inst.b), vmConstOp(inst.c)); - break; - case IrCmd::OR: - regs.assertAllFree(); - emitInstOr(build, vmRegOp(inst.a), vmRegOp(inst.b), vmRegOp(inst.c)); - break; - case IrCmd::ORK: - regs.assertAllFree(); - emitInstOrK(build, vmRegOp(inst.a), vmRegOp(inst.b), vmConstOp(inst.c)); - break; case IrCmd::COVERAGE: regs.assertAllFree(); emitInstCoverage(build, uintOp(inst.a)); @@ -1066,6 +1064,15 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) regs.freeLastUseRegs(inst, index); } +bool IrLoweringX64::hasError() const +{ + // If register allocator had to use more stack slots than we have available, this function can't run natively + if (regs.maxUsedSlot > kSpillSlots) + return true; + + return false; +} + bool IrLoweringX64::isFallthroughBlock(IrBlock target, IrBlock next) { return target.start == next.start; @@ -1077,7 +1084,7 @@ void IrLoweringX64::jumpOrFallthrough(IrBlock& target, IrBlock& next) build.jmp(target.label); } -OperandX64 IrLoweringX64::memRegDoubleOp(IrOp op) const +OperandX64 IrLoweringX64::memRegDoubleOp(IrOp op) { switch (op.kind) { @@ -1096,7 +1103,7 @@ OperandX64 IrLoweringX64::memRegDoubleOp(IrOp op) const return noreg; } -OperandX64 IrLoweringX64::memRegTagOp(IrOp op) const +OperandX64 IrLoweringX64::memRegTagOp(IrOp op) { switch (op.kind) { @@ -1113,9 +1120,13 @@ OperandX64 IrLoweringX64::memRegTagOp(IrOp op) const return noreg; } -RegisterX64 IrLoweringX64::regOp(IrOp op) const +RegisterX64 IrLoweringX64::regOp(IrOp op) { IrInst& inst = function.instOp(op); + + if (inst.spilled) + regs.restore(inst, false); + LUAU_ASSERT(inst.regX64 != noreg); return inst.regX64; } diff --git a/CodeGen/src/IrLoweringX64.h b/CodeGen/src/IrLoweringX64.h index ecaa6a1d5..42d262775 100644 --- a/CodeGen/src/IrLoweringX64.h +++ b/CodeGen/src/IrLoweringX64.h @@ -27,13 +27,17 @@ struct IrLoweringX64 void lowerInst(IrInst& inst, uint32_t index, IrBlock& next); + bool hasError() const; + bool isFallthroughBlock(IrBlock target, IrBlock next); void jumpOrFallthrough(IrBlock& target, IrBlock& next); + void storeDoubleAsFloat(OperandX64 dst, IrOp src); + // Operand data lookup helpers - OperandX64 memRegDoubleOp(IrOp op) const; - OperandX64 memRegTagOp(IrOp op) const; - RegisterX64 regOp(IrOp op) const; + OperandX64 memRegDoubleOp(IrOp op); + OperandX64 memRegTagOp(IrOp op); + RegisterX64 regOp(IrOp op); IrConst constOp(IrOp op) const; uint8_t tagOp(IrOp op) const; diff --git a/CodeGen/src/IrRegAllocA64.cpp b/CodeGen/src/IrRegAllocA64.cpp index dc18ab56d..c6db9e9e0 100644 --- a/CodeGen/src/IrRegAllocA64.cpp +++ b/CodeGen/src/IrRegAllocA64.cpp @@ -151,6 +151,15 @@ void IrRegAllocA64::assertAllFree() const LUAU_ASSERT(simd.free == simd.base); } +void IrRegAllocA64::assertAllFreeExcept(RegisterA64 reg) const +{ + const Set& set = const_cast(this)->getSet(reg.kind); + const Set& other = &set == &gpr ? simd : gpr; + + LUAU_ASSERT(set.free == (set.base & ~(1u << reg.index))); + LUAU_ASSERT(other.free == other.base); +} + IrRegAllocA64::Set& IrRegAllocA64::getSet(KindA64 kind) { switch (kind) diff --git a/CodeGen/src/IrRegAllocA64.h b/CodeGen/src/IrRegAllocA64.h index 2ed0787aa..9ff035528 100644 --- a/CodeGen/src/IrRegAllocA64.h +++ b/CodeGen/src/IrRegAllocA64.h @@ -30,6 +30,7 @@ struct IrRegAllocA64 void freeTempRegs(); void assertAllFree() const; + void assertAllFreeExcept(RegisterA64 reg) const; IrFunction& function; diff --git a/CodeGen/src/IrRegAllocX64.cpp b/CodeGen/src/IrRegAllocX64.cpp index eeb6cfe69..dc9e7f908 100644 --- a/CodeGen/src/IrRegAllocX64.cpp +++ b/CodeGen/src/IrRegAllocX64.cpp @@ -1,6 +1,8 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/IrRegAllocX64.h" +#include "EmitCommonX64.h" + namespace Luau { namespace CodeGen @@ -10,14 +12,22 @@ namespace X64 static const RegisterX64 kGprAllocOrder[] = {rax, rdx, rcx, rbx, rsi, rdi, r8, r9, r10, r11}; -IrRegAllocX64::IrRegAllocX64(IrFunction& function) - : function(function) +static bool isFullTvalueOperand(IrCmd cmd) +{ + return cmd == IrCmd::LOAD_TVALUE || cmd == IrCmd::LOAD_NODE_VALUE_TV; +} + +IrRegAllocX64::IrRegAllocX64(AssemblyBuilderX64& build, IrFunction& function) + : build(build) + , function(function) { freeGprMap.fill(true); + gprInstUsers.fill(kInvalidInstIdx); freeXmmMap.fill(true); + xmmInstUsers.fill(kInvalidInstIdx); } -RegisterX64 IrRegAllocX64::allocGprReg(SizeX64 preferredSize) +RegisterX64 IrRegAllocX64::allocGprReg(SizeX64 preferredSize, uint32_t instIdx) { LUAU_ASSERT( preferredSize == SizeX64::byte || preferredSize == SizeX64::word || preferredSize == SizeX64::dword || preferredSize == SizeX64::qword); @@ -27,30 +37,40 @@ RegisterX64 IrRegAllocX64::allocGprReg(SizeX64 preferredSize) if (freeGprMap[reg.index]) { freeGprMap[reg.index] = false; + gprInstUsers[reg.index] = instIdx; return RegisterX64{preferredSize, reg.index}; } } + // If possible, spill the value with the furthest next use + if (uint32_t furthestUseTarget = findInstructionWithFurthestNextUse(gprInstUsers); furthestUseTarget != kInvalidInstIdx) + return takeReg(function.instructions[furthestUseTarget].regX64, instIdx); + LUAU_ASSERT(!"Out of GPR registers to allocate"); return noreg; } -RegisterX64 IrRegAllocX64::allocXmmReg() +RegisterX64 IrRegAllocX64::allocXmmReg(uint32_t instIdx) { for (size_t i = 0; i < freeXmmMap.size(); ++i) { if (freeXmmMap[i]) { freeXmmMap[i] = false; + xmmInstUsers[i] = instIdx; return RegisterX64{SizeX64::xmmword, uint8_t(i)}; } } + // Out of registers, spill the value with the furthest next use + if (uint32_t furthestUseTarget = findInstructionWithFurthestNextUse(xmmInstUsers); furthestUseTarget != kInvalidInstIdx) + return takeReg(function.instructions[furthestUseTarget].regX64, instIdx); + LUAU_ASSERT(!"Out of XMM registers to allocate"); return noreg; } -RegisterX64 IrRegAllocX64::allocGprRegOrReuse(SizeX64 preferredSize, uint32_t index, std::initializer_list oprefs) +RegisterX64 IrRegAllocX64::allocGprRegOrReuse(SizeX64 preferredSize, uint32_t instIdx, std::initializer_list oprefs) { for (IrOp op : oprefs) { @@ -59,20 +79,21 @@ RegisterX64 IrRegAllocX64::allocGprRegOrReuse(SizeX64 preferredSize, uint32_t in IrInst& source = function.instructions[op.index]; - if (source.lastUse == index && !source.reusedReg) + if (source.lastUse == instIdx && !source.reusedReg && !source.spilled) { LUAU_ASSERT(source.regX64.size != SizeX64::xmmword); LUAU_ASSERT(source.regX64 != noreg); source.reusedReg = true; + gprInstUsers[source.regX64.index] = instIdx; return RegisterX64{preferredSize, source.regX64.index}; } } - return allocGprReg(preferredSize); + return allocGprReg(preferredSize, instIdx); } -RegisterX64 IrRegAllocX64::allocXmmRegOrReuse(uint32_t index, std::initializer_list oprefs) +RegisterX64 IrRegAllocX64::allocXmmRegOrReuse(uint32_t instIdx, std::initializer_list oprefs) { for (IrOp op : oprefs) { @@ -81,32 +102,45 @@ RegisterX64 IrRegAllocX64::allocXmmRegOrReuse(uint32_t index, std::initializer_l IrInst& source = function.instructions[op.index]; - if (source.lastUse == index && !source.reusedReg) + if (source.lastUse == instIdx && !source.reusedReg && !source.spilled) { LUAU_ASSERT(source.regX64.size == SizeX64::xmmword); LUAU_ASSERT(source.regX64 != noreg); source.reusedReg = true; + xmmInstUsers[source.regX64.index] = instIdx; return source.regX64; } } - return allocXmmReg(); + return allocXmmReg(instIdx); } -RegisterX64 IrRegAllocX64::takeReg(RegisterX64 reg) +RegisterX64 IrRegAllocX64::takeReg(RegisterX64 reg, uint32_t instIdx) { - // In a more advanced register allocator, this would require a spill for the current register user - // But at the current stage we don't have register live ranges intersecting forced register uses if (reg.size == SizeX64::xmmword) { + if (!freeXmmMap[reg.index]) + { + LUAU_ASSERT(xmmInstUsers[reg.index] != kInvalidInstIdx); + preserve(function.instructions[xmmInstUsers[reg.index]]); + } + LUAU_ASSERT(freeXmmMap[reg.index]); freeXmmMap[reg.index] = false; + xmmInstUsers[reg.index] = instIdx; } else { + if (!freeGprMap[reg.index]) + { + LUAU_ASSERT(gprInstUsers[reg.index] != kInvalidInstIdx); + preserve(function.instructions[gprInstUsers[reg.index]]); + } + LUAU_ASSERT(freeGprMap[reg.index]); freeGprMap[reg.index] = false; + gprInstUsers[reg.index] = instIdx; } return reg; @@ -118,17 +152,19 @@ void IrRegAllocX64::freeReg(RegisterX64 reg) { LUAU_ASSERT(!freeXmmMap[reg.index]); freeXmmMap[reg.index] = true; + xmmInstUsers[reg.index] = kInvalidInstIdx; } else { LUAU_ASSERT(!freeGprMap[reg.index]); freeGprMap[reg.index] = true; + gprInstUsers[reg.index] = kInvalidInstIdx; } } -void IrRegAllocX64::freeLastUseReg(IrInst& target, uint32_t index) +void IrRegAllocX64::freeLastUseReg(IrInst& target, uint32_t instIdx) { - if (isLastUseReg(target, index)) + if (isLastUseReg(target, instIdx)) { // Register might have already been freed if it had multiple uses inside a single instruction if (target.regX64 == noreg) @@ -139,11 +175,11 @@ void IrRegAllocX64::freeLastUseReg(IrInst& target, uint32_t index) } } -void IrRegAllocX64::freeLastUseRegs(const IrInst& inst, uint32_t index) +void IrRegAllocX64::freeLastUseRegs(const IrInst& inst, uint32_t instIdx) { - auto checkOp = [this, index](IrOp op) { + auto checkOp = [this, instIdx](IrOp op) { if (op.kind == IrOpKind::Inst) - freeLastUseReg(function.instructions[op.index], index); + freeLastUseReg(function.instructions[op.index], instIdx); }; checkOp(inst.a); @@ -154,9 +190,132 @@ void IrRegAllocX64::freeLastUseRegs(const IrInst& inst, uint32_t index) checkOp(inst.f); } -bool IrRegAllocX64::isLastUseReg(const IrInst& target, uint32_t index) const +bool IrRegAllocX64::isLastUseReg(const IrInst& target, uint32_t instIdx) const +{ + return target.lastUse == instIdx && !target.reusedReg; +} + +void IrRegAllocX64::preserve(IrInst& inst) +{ + bool doubleSlot = isFullTvalueOperand(inst.cmd); + + // Find a free stack slot. Two consecutive slots might be required for 16 byte TValues, so '- 1' is used + for (unsigned i = 0; i < unsigned(usedSpillSlots.size() - 1); ++i) + { + if (usedSpillSlots.test(i)) + continue; + + if (doubleSlot && usedSpillSlots.test(i + 1)) + { + ++i; // No need to retest this double position + continue; + } + + if (inst.regX64.size == SizeX64::xmmword && doubleSlot) + { + build.vmovups(xmmword[sSpillArea + i * 8], inst.regX64); + } + else if (inst.regX64.size == SizeX64::xmmword) + { + build.vmovsd(qword[sSpillArea + i * 8], inst.regX64); + } + else + { + OperandX64 location = addr[sSpillArea + i * 8]; + location.memSize = inst.regX64.size; // Override memory access size + build.mov(location, inst.regX64); + } + + usedSpillSlots.set(i); + + if (i + 1 > maxUsedSlot) + maxUsedSlot = i + 1; + + if (doubleSlot) + { + usedSpillSlots.set(i + 1); + + if (i + 2 > maxUsedSlot) + maxUsedSlot = i + 2; + } + + IrSpillX64 spill; + spill.instIdx = function.getInstIndex(inst); + spill.useDoubleSlot = doubleSlot; + spill.stackSlot = uint8_t(i); + spill.originalLoc = inst.regX64; + + spills.push_back(spill); + + freeReg(inst.regX64); + + inst.regX64 = noreg; + inst.spilled = true; + return; + } + + LUAU_ASSERT(!"nowhere to spill"); +} + +void IrRegAllocX64::restore(IrInst& inst, bool intoOriginalLocation) +{ + uint32_t instIdx = function.getInstIndex(inst); + + for (size_t i = 0; i < spills.size(); i++) + { + const IrSpillX64& spill = spills[i]; + + if (spill.instIdx == instIdx) + { + LUAU_ASSERT(spill.stackSlot != kNoStackSlot); + RegisterX64 reg; + + if (spill.originalLoc.size == SizeX64::xmmword) + { + reg = intoOriginalLocation ? takeReg(spill.originalLoc, instIdx) : allocXmmReg(instIdx); + + if (spill.useDoubleSlot) + build.vmovups(reg, xmmword[sSpillArea + spill.stackSlot * 8]); + else + build.vmovsd(reg, qword[sSpillArea + spill.stackSlot * 8]); + } + else + { + reg = intoOriginalLocation ? takeReg(spill.originalLoc, instIdx) : allocGprReg(spill.originalLoc.size, instIdx); + + OperandX64 location = addr[sSpillArea + spill.stackSlot * 8]; + location.memSize = reg.size; // Override memory access size + build.mov(reg, location); + } + + inst.regX64 = reg; + inst.spilled = false; + + usedSpillSlots.set(spill.stackSlot, false); + + if (spill.useDoubleSlot) + usedSpillSlots.set(spill.stackSlot + 1, false); + + spills[i] = spills.back(); + spills.pop_back(); + return; + } + } +} + +void IrRegAllocX64::preserveAndFreeInstValues() { - return target.lastUse == index && !target.reusedReg; + for (uint32_t instIdx : gprInstUsers) + { + if (instIdx != kInvalidInstIdx) + preserve(function.instructions[instIdx]); + } + + for (uint32_t instIdx : xmmInstUsers) + { + if (instIdx != kInvalidInstIdx) + preserve(function.instructions[instIdx]); + } } bool IrRegAllocX64::shouldFreeGpr(RegisterX64 reg) const @@ -175,6 +334,33 @@ bool IrRegAllocX64::shouldFreeGpr(RegisterX64 reg) const return false; } +uint32_t IrRegAllocX64::findInstructionWithFurthestNextUse(const std::array& regInstUsers) const +{ + uint32_t furthestUseTarget = kInvalidInstIdx; + uint32_t furthestUseLocation = 0; + + for (uint32_t regInstUser : regInstUsers) + { + // Cannot spill temporary registers or the register of the value that's defined in the current instruction + if (regInstUser == kInvalidInstIdx || regInstUser == currInstIdx) + continue; + + uint32_t nextUse = getNextInstUse(function, regInstUser, currInstIdx); + + // Cannot spill value that is about to be used in the current instruction + if (nextUse == currInstIdx) + continue; + + if (furthestUseTarget == kInvalidInstIdx || nextUse > furthestUseLocation) + { + furthestUseLocation = nextUse; + furthestUseTarget = regInstUser; + } + } + + return furthestUseTarget; +} + void IrRegAllocX64::assertFree(RegisterX64 reg) const { if (reg.size == SizeX64::xmmword) @@ -192,6 +378,11 @@ void IrRegAllocX64::assertAllFree() const LUAU_ASSERT(free); } +void IrRegAllocX64::assertNoSpills() const +{ + LUAU_ASSERT(spills.empty()); +} + ScopedRegX64::ScopedRegX64(IrRegAllocX64& owner) : owner(owner) , reg(noreg) @@ -222,9 +413,9 @@ void ScopedRegX64::alloc(SizeX64 size) LUAU_ASSERT(reg == noreg); if (size == SizeX64::xmmword) - reg = owner.allocXmmReg(); + reg = owner.allocXmmReg(kInvalidInstIdx); else - reg = owner.allocGprReg(size); + reg = owner.allocGprReg(size, kInvalidInstIdx); } void ScopedRegX64::free() @@ -241,6 +432,41 @@ RegisterX64 ScopedRegX64::release() return tmp; } +ScopedSpills::ScopedSpills(IrRegAllocX64& owner) + : owner(owner) +{ + snapshot = owner.spills; +} + +ScopedSpills::~ScopedSpills() +{ + // Taking a copy of current spills because we are going to potentially restore them + std::vector current = owner.spills; + + // Restore registers that were spilled inside scope protected by this object + for (IrSpillX64& curr : current) + { + // If spill existed before current scope, it can be restored outside of it + if (!wasSpilledBefore(curr)) + { + IrInst& inst = owner.function.instructions[curr.instIdx]; + + owner.restore(inst, /*intoOriginalLocation*/ true); + } + } +} + +bool ScopedSpills::wasSpilledBefore(const IrSpillX64& spill) const +{ + for (const IrSpillX64& preexisting : snapshot) + { + if (spill.instIdx == preexisting.instIdx) + return true; + } + + return false; +} + } // namespace X64 } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/IrTranslateBuiltins.cpp b/CodeGen/src/IrTranslateBuiltins.cpp index 2955aaffb..ba4915645 100644 --- a/CodeGen/src/IrTranslateBuiltins.cpp +++ b/CodeGen/src/IrTranslateBuiltins.cpp @@ -61,7 +61,8 @@ BuiltinImplResult translateBuiltinNumberTo2Number( if (ra != arg) build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); - build.inst(IrCmd::STORE_TAG, build.vmReg(ra + 1), build.constTag(LUA_TNUMBER)); + if (nresults > 1) + build.inst(IrCmd::STORE_TAG, build.vmReg(ra + 1), build.constTag(LUA_TNUMBER)); return {BuiltinImplType::UsesFallback, 2}; } @@ -190,10 +191,10 @@ BuiltinImplResult translateBuiltinMathClamp(IrBuilder& build, int nparams, int r build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); build.loadAndCheckTag(args, LUA_TNUMBER, fallback); - build.loadAndCheckTag(build.vmReg(args.index + 1), LUA_TNUMBER, fallback); + build.loadAndCheckTag(build.vmReg(vmRegOp(args) + 1), LUA_TNUMBER, fallback); IrOp min = build.inst(IrCmd::LOAD_DOUBLE, args); - IrOp max = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(args.index + 1)); + IrOp max = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(vmRegOp(args) + 1)); build.inst(IrCmd::JUMP_CMP_NUM, min, max, build.cond(IrCondition::NotLessEqual), fallback, block); build.beginBlock(block); @@ -274,6 +275,27 @@ BuiltinImplResult translateBuiltinTypeof(IrBuilder& build, int nparams, int ra, return {BuiltinImplType::UsesFallback, 1}; } +BuiltinImplResult translateBuiltinVector(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) +{ + if (nparams < 3 || nresults > 1) + return {BuiltinImplType::None, -1}; + + LUAU_ASSERT(LUA_VECTOR_SIZE == 3); + + build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); + build.loadAndCheckTag(args, LUA_TNUMBER, fallback); + build.loadAndCheckTag(build.vmReg(vmRegOp(args) + 1), LUA_TNUMBER, fallback); + + IrOp x = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(arg)); + IrOp y = build.inst(IrCmd::LOAD_DOUBLE, args); + IrOp z = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(vmRegOp(args) + 1)); + + build.inst(IrCmd::STORE_VECTOR, build.vmReg(ra), x, y, z); + build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TVECTOR)); + + return {BuiltinImplType::UsesFallback, 1}; +} + BuiltinImplResult translateBuiltin(IrBuilder& build, int bfid, int ra, int arg, IrOp args, int nparams, int nresults, IrOp fallback) { // Builtins are not allowed to handle variadic arguments @@ -332,6 +354,8 @@ BuiltinImplResult translateBuiltin(IrBuilder& build, int bfid, int ra, int arg, return translateBuiltinType(build, nparams, ra, arg, args, nresults, fallback); case LBF_TYPEOF: return translateBuiltinTypeof(build, nparams, ra, arg, args, nresults, fallback); + case LBF_VECTOR: + return translateBuiltinVector(build, nparams, ra, arg, args, nresults, fallback); default: return {BuiltinImplType::None, -1}; } diff --git a/CodeGen/src/IrTranslation.cpp b/CodeGen/src/IrTranslation.cpp index e366888e6..a985318b9 100644 --- a/CodeGen/src/IrTranslation.cpp +++ b/CodeGen/src/IrTranslation.cpp @@ -301,7 +301,7 @@ static void translateInstBinaryNumeric(IrBuilder& build, int ra, int rb, int rc, if (opc.kind == IrOpKind::VmConst) { LUAU_ASSERT(build.function.proto); - TValue protok = build.function.proto->k[opc.index]; + TValue protok = build.function.proto->k[vmConstOp(opc)]; LUAU_ASSERT(protok.tt == LUA_TNUMBER); @@ -1108,5 +1108,71 @@ void translateInstNamecall(IrBuilder& build, const Instruction* pc, int pcpos) build.beginBlock(next); } +void translateInstAndX(IrBuilder& build, const Instruction* pc, int pcpos, IrOp c) +{ + int ra = LUAU_INSN_A(*pc); + int rb = LUAU_INSN_B(*pc); + + IrOp fallthrough = build.block(IrBlockKind::Internal); + IrOp next = build.blockAtInst(pcpos + 1); + + IrOp target = (ra == rb) ? next : build.block(IrBlockKind::Internal); + + build.inst(IrCmd::JUMP_IF_FALSY, build.vmReg(rb), target, fallthrough); + build.beginBlock(fallthrough); + + IrOp load = build.inst(IrCmd::LOAD_TVALUE, c); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra), load); + build.inst(IrCmd::JUMP, next); + + if (ra == rb) + { + build.beginBlock(next); + } + else + { + build.beginBlock(target); + + IrOp load1 = build.inst(IrCmd::LOAD_TVALUE, build.vmReg(rb)); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra), load1); + build.inst(IrCmd::JUMP, next); + + build.beginBlock(next); + } +} + +void translateInstOrX(IrBuilder& build, const Instruction* pc, int pcpos, IrOp c) +{ + int ra = LUAU_INSN_A(*pc); + int rb = LUAU_INSN_B(*pc); + + IrOp fallthrough = build.block(IrBlockKind::Internal); + IrOp next = build.blockAtInst(pcpos + 1); + + IrOp target = (ra == rb) ? next : build.block(IrBlockKind::Internal); + + build.inst(IrCmd::JUMP_IF_TRUTHY, build.vmReg(rb), target, fallthrough); + build.beginBlock(fallthrough); + + IrOp load = build.inst(IrCmd::LOAD_TVALUE, c); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra), load); + build.inst(IrCmd::JUMP, next); + + if (ra == rb) + { + build.beginBlock(next); + } + else + { + build.beginBlock(target); + + IrOp load1 = build.inst(IrCmd::LOAD_TVALUE, build.vmReg(rb)); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra), load1); + build.inst(IrCmd::JUMP, next); + + build.beginBlock(next); + } +} + } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/IrTranslation.h b/CodeGen/src/IrTranslation.h index 0be111dca..87a530b50 100644 --- a/CodeGen/src/IrTranslation.h +++ b/CodeGen/src/IrTranslation.h @@ -61,6 +61,8 @@ void translateInstSetGlobal(IrBuilder& build, const Instruction* pc, int pcpos); void translateInstConcat(IrBuilder& build, const Instruction* pc, int pcpos); void translateInstCapture(IrBuilder& build, const Instruction* pc, int pcpos); void translateInstNamecall(IrBuilder& build, const Instruction* pc, int pcpos); +void translateInstAndX(IrBuilder& build, const Instruction* pc, int pcpos, IrOp c); +void translateInstOrX(IrBuilder& build, const Instruction* pc, int pcpos, IrOp c); } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/IrUtils.cpp b/CodeGen/src/IrUtils.cpp index 45e2bae09..c5e7c887a 100644 --- a/CodeGen/src/IrUtils.cpp +++ b/CodeGen/src/IrUtils.cpp @@ -14,6 +14,134 @@ namespace Luau namespace CodeGen { +IrValueKind getCmdValueKind(IrCmd cmd) +{ + switch (cmd) + { + case IrCmd::NOP: + return IrValueKind::None; + case IrCmd::LOAD_TAG: + return IrValueKind::Tag; + case IrCmd::LOAD_POINTER: + return IrValueKind::Pointer; + case IrCmd::LOAD_DOUBLE: + return IrValueKind::Double; + case IrCmd::LOAD_INT: + return IrValueKind::Int; + case IrCmd::LOAD_TVALUE: + case IrCmd::LOAD_NODE_VALUE_TV: + return IrValueKind::Tvalue; + case IrCmd::LOAD_ENV: + case IrCmd::GET_ARR_ADDR: + case IrCmd::GET_SLOT_NODE_ADDR: + case IrCmd::GET_HASH_NODE_ADDR: + return IrValueKind::Pointer; + case IrCmd::STORE_TAG: + case IrCmd::STORE_POINTER: + case IrCmd::STORE_DOUBLE: + case IrCmd::STORE_INT: + case IrCmd::STORE_VECTOR: + case IrCmd::STORE_TVALUE: + case IrCmd::STORE_NODE_VALUE_TV: + return IrValueKind::None; + case IrCmd::ADD_INT: + case IrCmd::SUB_INT: + return IrValueKind::Int; + case IrCmd::ADD_NUM: + case IrCmd::SUB_NUM: + case IrCmd::MUL_NUM: + case IrCmd::DIV_NUM: + case IrCmd::MOD_NUM: + case IrCmd::POW_NUM: + case IrCmd::MIN_NUM: + case IrCmd::MAX_NUM: + case IrCmd::UNM_NUM: + case IrCmd::FLOOR_NUM: + case IrCmd::CEIL_NUM: + case IrCmd::ROUND_NUM: + case IrCmd::SQRT_NUM: + case IrCmd::ABS_NUM: + return IrValueKind::Double; + case IrCmd::NOT_ANY: + return IrValueKind::Int; + case IrCmd::JUMP: + case IrCmd::JUMP_IF_TRUTHY: + case IrCmd::JUMP_IF_FALSY: + case IrCmd::JUMP_EQ_TAG: + case IrCmd::JUMP_EQ_INT: + case IrCmd::JUMP_EQ_POINTER: + case IrCmd::JUMP_CMP_NUM: + case IrCmd::JUMP_CMP_ANY: + case IrCmd::JUMP_SLOT_MATCH: + return IrValueKind::None; + case IrCmd::TABLE_LEN: + return IrValueKind::Double; + case IrCmd::NEW_TABLE: + case IrCmd::DUP_TABLE: + return IrValueKind::Pointer; + case IrCmd::TRY_NUM_TO_INDEX: + return IrValueKind::Int; + case IrCmd::TRY_CALL_FASTGETTM: + return IrValueKind::Pointer; + case IrCmd::INT_TO_NUM: + return IrValueKind::Double; + case IrCmd::ADJUST_STACK_TO_REG: + case IrCmd::ADJUST_STACK_TO_TOP: + return IrValueKind::None; + case IrCmd::FASTCALL: + return IrValueKind::None; + case IrCmd::INVOKE_FASTCALL: + return IrValueKind::Int; + case IrCmd::CHECK_FASTCALL_RES: + case IrCmd::DO_ARITH: + case IrCmd::DO_LEN: + case IrCmd::GET_TABLE: + case IrCmd::SET_TABLE: + case IrCmd::GET_IMPORT: + case IrCmd::CONCAT: + case IrCmd::GET_UPVALUE: + case IrCmd::SET_UPVALUE: + case IrCmd::PREPARE_FORN: + case IrCmd::CHECK_TAG: + case IrCmd::CHECK_READONLY: + case IrCmd::CHECK_NO_METATABLE: + case IrCmd::CHECK_SAFE_ENV: + case IrCmd::CHECK_ARRAY_SIZE: + case IrCmd::CHECK_SLOT_MATCH: + case IrCmd::CHECK_NODE_NO_NEXT: + case IrCmd::INTERRUPT: + case IrCmd::CHECK_GC: + case IrCmd::BARRIER_OBJ: + case IrCmd::BARRIER_TABLE_BACK: + case IrCmd::BARRIER_TABLE_FORWARD: + case IrCmd::SET_SAVEDPC: + case IrCmd::CLOSE_UPVALS: + case IrCmd::CAPTURE: + case IrCmd::SETLIST: + case IrCmd::CALL: + case IrCmd::RETURN: + case IrCmd::FORGLOOP: + case IrCmd::FORGLOOP_FALLBACK: + case IrCmd::FORGPREP_XNEXT_FALLBACK: + case IrCmd::COVERAGE: + case IrCmd::FALLBACK_GETGLOBAL: + case IrCmd::FALLBACK_SETGLOBAL: + case IrCmd::FALLBACK_GETTABLEKS: + case IrCmd::FALLBACK_SETTABLEKS: + case IrCmd::FALLBACK_NAMECALL: + case IrCmd::FALLBACK_PREPVARARGS: + case IrCmd::FALLBACK_GETVARARGS: + case IrCmd::FALLBACK_NEWCLOSURE: + case IrCmd::FALLBACK_DUPCLOSURE: + case IrCmd::FALLBACK_FORGPREP: + return IrValueKind::None; + case IrCmd::SUBSTITUTE: + return IrValueKind::Unknown; + } + + LUAU_UNREACHABLE(); +} + static void removeInstUse(IrFunction& function, uint32_t instIdx) { IrInst& inst = function.instructions[instIdx]; diff --git a/CodeGen/src/NativeState.cpp b/CodeGen/src/NativeState.cpp index ddc9c03d1..524796929 100644 --- a/CodeGen/src/NativeState.cpp +++ b/CodeGen/src/NativeState.cpp @@ -45,6 +45,7 @@ void initFallbackTable(NativeState& data) CODEGEN_SET_FALLBACK(LOP_BREAK, 0); // Fallbacks that are called from partial implementation of an instruction + // TODO: these fallbacks should be replaced with special functions that exclude the (redundantly executed) fast path from the fallback CODEGEN_SET_FALLBACK(LOP_GETGLOBAL, 0); CODEGEN_SET_FALLBACK(LOP_SETGLOBAL, 0); CODEGEN_SET_FALLBACK(LOP_GETTABLEKS, 0); diff --git a/CodeGen/src/OptimizeConstProp.cpp b/CodeGen/src/OptimizeConstProp.cpp index f767f5496..7157a18c4 100644 --- a/CodeGen/src/OptimizeConstProp.cpp +++ b/CodeGen/src/OptimizeConstProp.cpp @@ -96,20 +96,17 @@ struct ConstPropState void invalidateTag(IrOp regOp) { - LUAU_ASSERT(regOp.kind == IrOpKind::VmReg); - invalidate(regs[regOp.index], /* invalidateTag */ true, /* invalidateValue */ false); + invalidate(regs[vmRegOp(regOp)], /* invalidateTag */ true, /* invalidateValue */ false); } void invalidateValue(IrOp regOp) { - LUAU_ASSERT(regOp.kind == IrOpKind::VmReg); - invalidate(regs[regOp.index], /* invalidateTag */ false, /* invalidateValue */ true); + invalidate(regs[vmRegOp(regOp)], /* invalidateTag */ false, /* invalidateValue */ true); } void invalidate(IrOp regOp) { - LUAU_ASSERT(regOp.kind == IrOpKind::VmReg); - invalidate(regs[regOp.index], /* invalidateTag */ true, /* invalidateValue */ true); + invalidate(regs[vmRegOp(regOp)], /* invalidateTag */ true, /* invalidateValue */ true); } void invalidateRegistersFrom(int firstReg) @@ -156,17 +153,16 @@ struct ConstPropState void createRegLink(uint32_t instIdx, IrOp regOp) { - LUAU_ASSERT(regOp.kind == IrOpKind::VmReg); LUAU_ASSERT(!instLink.contains(instIdx)); - instLink[instIdx] = RegisterLink{uint8_t(regOp.index), regs[regOp.index].version}; + instLink[instIdx] = RegisterLink{uint8_t(vmRegOp(regOp)), regs[vmRegOp(regOp)].version}; } RegisterInfo* tryGetRegisterInfo(IrOp op) { if (op.kind == IrOpKind::VmReg) { - maxReg = int(op.index) > maxReg ? int(op.index) : maxReg; - return ®s[op.index]; + maxReg = vmRegOp(op) > maxReg ? vmRegOp(op) : maxReg; + return ®s[vmRegOp(op)]; } if (RegisterLink* link = tryGetRegLink(op)) @@ -368,6 +364,9 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& } } break; + case IrCmd::STORE_VECTOR: + state.invalidateValue(inst.a); + break; case IrCmd::STORE_TVALUE: if (inst.a.kind == IrOpKind::VmReg) { @@ -503,15 +502,9 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& } } break; - case IrCmd::AND: - case IrCmd::ANDK: - case IrCmd::OR: - case IrCmd::ORK: - state.invalidate(inst.a); - break; case IrCmd::FASTCALL: case IrCmd::INVOKE_FASTCALL: - handleBuiltinEffects(state, LuauBuiltinFunction(function.uintOp(inst.a)), inst.b.index, function.intOp(inst.f)); + handleBuiltinEffects(state, LuauBuiltinFunction(function.uintOp(inst.a)), vmRegOp(inst.b), function.intOp(inst.f)); break; // These instructions don't have an effect on register/memory state we are tracking @@ -590,7 +583,7 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& state.invalidateUserCall(); break; case IrCmd::CONCAT: - state.invalidateRegisterRange(inst.a.index, function.uintOp(inst.b)); + state.invalidateRegisterRange(vmRegOp(inst.a), function.uintOp(inst.b)); state.invalidateUserCall(); // TODO: if only strings and numbers are concatenated, there will be no user calls break; case IrCmd::PREPARE_FORN: @@ -605,14 +598,14 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& state.invalidateUserCall(); break; case IrCmd::CALL: - state.invalidateRegistersFrom(inst.a.index); + state.invalidateRegistersFrom(vmRegOp(inst.a)); state.invalidateUserCall(); break; case IrCmd::FORGLOOP: - state.invalidateRegistersFrom(inst.a.index + 2); // Rn and Rn+1 are not modified + state.invalidateRegistersFrom(vmRegOp(inst.a) + 2); // Rn and Rn+1 are not modified break; case IrCmd::FORGLOOP_FALLBACK: - state.invalidateRegistersFrom(inst.a.index + 2); // Rn and Rn+1 are not modified + state.invalidateRegistersFrom(vmRegOp(inst.a) + 2); // Rn and Rn+1 are not modified state.invalidateUserCall(); break; case IrCmd::FORGPREP_XNEXT_FALLBACK: @@ -633,14 +626,14 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& state.invalidateUserCall(); break; case IrCmd::FALLBACK_NAMECALL: - state.invalidate(IrOp{inst.b.kind, inst.b.index + 0u}); - state.invalidate(IrOp{inst.b.kind, inst.b.index + 1u}); + state.invalidate(IrOp{inst.b.kind, vmRegOp(inst.b) + 0u}); + state.invalidate(IrOp{inst.b.kind, vmRegOp(inst.b) + 1u}); state.invalidateUserCall(); break; case IrCmd::FALLBACK_PREPVARARGS: break; case IrCmd::FALLBACK_GETVARARGS: - state.invalidateRegistersFrom(inst.b.index); + state.invalidateRegistersFrom(vmRegOp(inst.b)); break; case IrCmd::FALLBACK_NEWCLOSURE: state.invalidate(inst.b); @@ -649,9 +642,9 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& state.invalidate(inst.b); break; case IrCmd::FALLBACK_FORGPREP: - state.invalidate(IrOp{inst.b.kind, inst.b.index + 0u}); - state.invalidate(IrOp{inst.b.kind, inst.b.index + 1u}); - state.invalidate(IrOp{inst.b.kind, inst.b.index + 2u}); + state.invalidate(IrOp{inst.b.kind, vmRegOp(inst.b) + 0u}); + state.invalidate(IrOp{inst.b.kind, vmRegOp(inst.b) + 1u}); + state.invalidate(IrOp{inst.b.kind, vmRegOp(inst.b) + 2u}); break; } } diff --git a/VM/include/lua.h b/VM/include/lua.h index 783e60a50..3f5e99f47 100644 --- a/VM/include/lua.h +++ b/VM/include/lua.h @@ -29,7 +29,7 @@ enum lua_Status LUA_OK = 0, LUA_YIELD, LUA_ERRRUN, - LUA_ERRSYNTAX, + LUA_ERRSYNTAX, // legacy error code, preserved for compatibility LUA_ERRMEM, LUA_ERRERR, LUA_BREAK, // yielded for a debug breakpoint diff --git a/VM/src/ldo.cpp b/VM/src/ldo.cpp index ff8105b8c..264388bc9 100644 --- a/VM/src/ldo.cpp +++ b/VM/src/ldo.cpp @@ -17,6 +17,8 @@ #include +LUAU_FASTFLAGVARIABLE(LuauBetterOOMHandling, false) + /* ** {====================================================== ** Error-recovery functions @@ -79,22 +81,17 @@ class lua_exception : public std::exception const char* what() const throw() override { - // LUA_ERRRUN/LUA_ERRSYNTAX pass an object on the stack which is intended to describe the error. - if (status == LUA_ERRRUN || status == LUA_ERRSYNTAX) - { - // Conversion to a string could still fail. For example if a user passes a non-string/non-number argument to `error()`. + // LUA_ERRRUN passes error object on the stack + if (status == LUA_ERRRUN || (status == LUA_ERRSYNTAX && !FFlag::LuauBetterOOMHandling)) if (const char* str = lua_tostring(L, -1)) - { return str; - } - } switch (status) { case LUA_ERRRUN: - return "lua_exception: LUA_ERRRUN (no string/number provided as description)"; + return "lua_exception: runtime error"; case LUA_ERRSYNTAX: - return "lua_exception: LUA_ERRSYNTAX (no string/number provided as description)"; + return "lua_exception: syntax error"; case LUA_ERRMEM: return "lua_exception: " LUA_MEMERRMSG; case LUA_ERRERR: @@ -550,19 +547,42 @@ int luaD_pcall(lua_State* L, Pfunc func, void* u, ptrdiff_t old_top, ptrdiff_t e int status = luaD_rawrunprotected(L, func, u); if (status != 0) { + int errstatus = status; + // call user-defined error function (used in xpcall) if (ef) { - // if errfunc fails, we fail with "error in error handling" - if (luaD_rawrunprotected(L, callerrfunc, restorestack(L, ef)) != 0) - status = LUA_ERRERR; + if (FFlag::LuauBetterOOMHandling) + { + // push error object to stack top if it's not already there + if (status != LUA_ERRRUN) + seterrorobj(L, status, L->top); + + // if errfunc fails, we fail with "error in error handling" or "not enough memory" + int err = luaD_rawrunprotected(L, callerrfunc, restorestack(L, ef)); + + // in general we preserve the status, except for cases when the error handler fails + // out of memory is treated specially because it's common for it to be cascading, in which case we preserve the code + if (err == 0) + errstatus = LUA_ERRRUN; + else if (status == LUA_ERRMEM && err == LUA_ERRMEM) + errstatus = LUA_ERRMEM; + else + errstatus = status = LUA_ERRERR; + } + else + { + // if errfunc fails, we fail with "error in error handling" + if (luaD_rawrunprotected(L, callerrfunc, restorestack(L, ef)) != 0) + status = LUA_ERRERR; + } } // since the call failed with an error, we might have to reset the 'active' thread state if (!oldactive) L->isactive = false; - // Restore nCcalls before calling the debugprotectederror callback which may rely on the proper value to have been restored. + // restore nCcalls before calling the debugprotectederror callback which may rely on the proper value to have been restored. L->nCcalls = oldnCcalls; // an error occurred, check if we have a protected error callback @@ -577,7 +597,7 @@ int luaD_pcall(lua_State* L, Pfunc func, void* u, ptrdiff_t old_top, ptrdiff_t e StkId oldtop = restorestack(L, old_top); luaF_close(L, oldtop); // close eventual pending closures - seterrorobj(L, status, oldtop); + seterrorobj(L, FFlag::LuauBetterOOMHandling ? errstatus : status, oldtop); L->ci = restoreci(L, old_ci); L->base = L->ci->base; restore_stack_limit(L); diff --git a/VM/src/ltablib.cpp b/VM/src/ltablib.cpp index ddee3a71e..4443be34f 100644 --- a/VM/src/ltablib.cpp +++ b/VM/src/ltablib.cpp @@ -10,7 +10,7 @@ #include "ldebug.h" #include "lvm.h" -LUAU_FASTFLAGVARIABLE(LuauOptimizedSort, false) +LUAU_FASTFLAGVARIABLE(LuauIntrosort, false) static int foreachi(lua_State* L) { @@ -298,120 +298,6 @@ static int tunpack(lua_State* L) return (int)n; } -/* -** {====================================================== -** Quicksort -** (based on `Algorithms in MODULA-3', Robert Sedgewick; -** Addison-Wesley, 1993.) -*/ - -static void set2(lua_State* L, int i, int j) -{ - LUAU_ASSERT(!FFlag::LuauOptimizedSort); - lua_rawseti(L, 1, i); - lua_rawseti(L, 1, j); -} - -static int sort_comp(lua_State* L, int a, int b) -{ - LUAU_ASSERT(!FFlag::LuauOptimizedSort); - if (!lua_isnil(L, 2)) - { // function? - int res; - lua_pushvalue(L, 2); - lua_pushvalue(L, a - 1); // -1 to compensate function - lua_pushvalue(L, b - 2); // -2 to compensate function and `a' - lua_call(L, 2, 1); - res = lua_toboolean(L, -1); - lua_pop(L, 1); - return res; - } - else // a < b? - return lua_lessthan(L, a, b); -} - -static void auxsort(lua_State* L, int l, int u) -{ - LUAU_ASSERT(!FFlag::LuauOptimizedSort); - while (l < u) - { // for tail recursion - int i, j; - // sort elements a[l], a[(l+u)/2] and a[u] - lua_rawgeti(L, 1, l); - lua_rawgeti(L, 1, u); - if (sort_comp(L, -1, -2)) // a[u] < a[l]? - set2(L, l, u); // swap a[l] - a[u] - else - lua_pop(L, 2); - if (u - l == 1) - break; // only 2 elements - i = (l + u) / 2; - lua_rawgeti(L, 1, i); - lua_rawgeti(L, 1, l); - if (sort_comp(L, -2, -1)) // a[i]= P - while (lua_rawgeti(L, 1, ++i), sort_comp(L, -1, -2)) - { - if (i >= u) - luaL_error(L, "invalid order function for sorting"); - lua_pop(L, 1); // remove a[i] - } - // repeat --j until a[j] <= P - while (lua_rawgeti(L, 1, --j), sort_comp(L, -3, -1)) - { - if (j <= l) - luaL_error(L, "invalid order function for sorting"); - lua_pop(L, 1); // remove a[j] - } - if (j < i) - { - lua_pop(L, 3); // pop pivot, a[i], a[j] - break; - } - set2(L, i, j); - } - lua_rawgeti(L, 1, u - 1); - lua_rawgeti(L, 1, i); - set2(L, u - 1, i); // swap pivot (a[u-1]) with a[i] - // a[l..i-1] <= a[i] == P <= a[i+1..u] - // adjust so that smaller half is in [j..i] and larger one in [l..u] - if (i - l < u - i) - { - j = l; - i = i - 1; - l = i + 2; - } - else - { - j = i + 1; - i = u; - u = j - 2; - } - auxsort(L, j, i); // call recursively the smaller one - } // repeat the routine for the larger one -} - typedef int (*SortPredicate)(lua_State* L, const TValue* l, const TValue* r); static int sort_func(lua_State* L, const TValue* l, const TValue* r) @@ -456,30 +342,77 @@ inline int sort_less(lua_State* L, Table* t, int i, int j, SortPredicate pred) return res; } -static void sort_rec(lua_State* L, Table* t, int l, int u, SortPredicate pred) +static void sort_siftheap(lua_State* L, Table* t, int l, int u, SortPredicate pred, int root) +{ + LUAU_ASSERT(l <= u); + int count = u - l + 1; + + // process all elements with two children + while (root * 2 + 2 < count) + { + int left = root * 2 + 1, right = root * 2 + 2; + int next = root; + next = sort_less(L, t, l + next, l + left, pred) ? left : next; + next = sort_less(L, t, l + next, l + right, pred) ? right : next; + + if (next == root) + break; + + sort_swap(L, t, l + root, l + next); + root = next; + } + + // process last element if it has just one child + int lastleft = root * 2 + 1; + if (lastleft == count - 1 && sort_less(L, t, l + root, l + lastleft, pred)) + sort_swap(L, t, l + root, l + lastleft); +} + +static void sort_heap(lua_State* L, Table* t, int l, int u, SortPredicate pred) +{ + LUAU_ASSERT(l <= u); + int count = u - l + 1; + + for (int i = count / 2 - 1; i >= 0; --i) + sort_siftheap(L, t, l, u, pred, i); + + for (int i = count - 1; i > 0; --i) + { + sort_swap(L, t, l, l + i); + sort_siftheap(L, t, l, l + i - 1, pred, 0); + } +} + +static void sort_rec(lua_State* L, Table* t, int l, int u, int limit, SortPredicate pred) { // sort range [l..u] (inclusive, 0-based) while (l < u) { - int i, j; + // if the limit has been reached, quick sort is going over the permitted nlogn complexity, so we fall back to heap sort + if (FFlag::LuauIntrosort && limit == 0) + return sort_heap(L, t, l, u, pred); + // sort elements a[l], a[(l+u)/2] and a[u] + // note: this simultaneously acts as a small sort and a median selector if (sort_less(L, t, u, l, pred)) // a[u] < a[l]? sort_swap(L, t, u, l); // swap a[l] - a[u] if (u - l == 1) break; // only 2 elements - i = l + ((u - l) >> 1); // midpoint - if (sort_less(L, t, i, l, pred)) // a[i]> 1); // midpoint + if (sort_less(L, t, m, l, pred)) // a[m]= P @@ -498,62 +431,71 @@ static void sort_rec(lua_State* L, Table* t, int l, int u, SortPredicate pred) break; sort_swap(L, t, i, j); } - // swap pivot (a[u-1]) with a[i], which is the new midpoint - sort_swap(L, t, u - 1, i); - // a[l..i-1] <= a[i] == P <= a[i+1..u] - // adjust so that smaller half is in [j..i] and larger one in [l..u] - if (i - l < u - i) + + // swap pivot a[p] with a[i], which is the new midpoint + sort_swap(L, t, p, i); + + if (FFlag::LuauIntrosort) { - j = l; - i = i - 1; - l = i + 2; + // adjust limit to allow 1.5 log2N recursive steps + limit = (limit >> 1) + (limit >> 2); + + // a[l..i-1] <= a[i] == P <= a[i+1..u] + // sort smaller half recursively; the larger half is sorted in the next loop iteration + if (i - l < u - i) + { + sort_rec(L, t, l, i - 1, limit, pred); + l = i + 1; + } + else + { + sort_rec(L, t, i + 1, u, limit, pred); + u = i - 1; + } } else { - j = i + 1; - i = u; - u = j - 2; + // a[l..i-1] <= a[i] == P <= a[i+1..u] + // adjust so that smaller half is in [j..i] and larger one in [l..u] + if (i - l < u - i) + { + j = l; + i = i - 1; + l = i + 2; + } + else + { + j = i + 1; + i = u; + u = j - 2; + } + + // sort smaller half recursively; the larger half is sorted in the next loop iteration + sort_rec(L, t, j, i, limit, pred); } - sort_rec(L, t, j, i, pred); // call recursively the smaller one - } // repeat the routine for the larger one + } } static int tsort(lua_State* L) { - if (FFlag::LuauOptimizedSort) - { - luaL_checktype(L, 1, LUA_TTABLE); - Table* t = hvalue(L->base); - int n = luaH_getn(t); - if (t->readonly) - luaG_readonlyerror(L); - - SortPredicate pred = luaV_lessthan; - if (!lua_isnoneornil(L, 2)) // is there a 2nd argument? - { - luaL_checktype(L, 2, LUA_TFUNCTION); - pred = sort_func; - } - lua_settop(L, 2); // make sure there are two arguments + luaL_checktype(L, 1, LUA_TTABLE); + Table* t = hvalue(L->base); + int n = luaH_getn(t); + if (t->readonly) + luaG_readonlyerror(L); - if (n > 0) - sort_rec(L, t, 0, n - 1, pred); - return 0; - } - else + SortPredicate pred = luaV_lessthan; + if (!lua_isnoneornil(L, 2)) // is there a 2nd argument? { - luaL_checktype(L, 1, LUA_TTABLE); - int n = lua_objlen(L, 1); - luaL_checkstack(L, 40, ""); // assume array is smaller than 2^40 - if (!lua_isnoneornil(L, 2)) // is there a 2nd argument? - luaL_checktype(L, 2, LUA_TFUNCTION); - lua_settop(L, 2); // make sure there is two arguments - auxsort(L, 1, n); - return 0; + luaL_checktype(L, 2, LUA_TFUNCTION); + pred = sort_func; } -} + lua_settop(L, 2); // make sure there are two arguments -// }====================================================== + if (n > 0) + sort_rec(L, t, 0, n - 1, n, pred); + return 0; +} static int tcreate(lua_State* L) { diff --git a/tests/AssemblyBuilderX64.test.cpp b/tests/AssemblyBuilderX64.test.cpp index 6aa7aa561..054eca7bf 100644 --- a/tests/AssemblyBuilderX64.test.cpp +++ b/tests/AssemblyBuilderX64.test.cpp @@ -507,6 +507,8 @@ TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AVXConversionInstructionForms") SINGLE_COMPARE(vcvtsi2sd(xmm6, xmm11, dword[rcx + rdx]), 0xc4, 0xe1, 0x23, 0x2a, 0x34, 0x11); SINGLE_COMPARE(vcvtsi2sd(xmm5, xmm10, r13), 0xc4, 0xc1, 0xab, 0x2a, 0xed); SINGLE_COMPARE(vcvtsi2sd(xmm6, xmm11, qword[rcx + rdx]), 0xc4, 0xe1, 0xa3, 0x2a, 0x34, 0x11); + SINGLE_COMPARE(vcvtsd2ss(xmm5, xmm10, xmm11), 0xc4, 0xc1, 0x2b, 0x5a, 0xeb); + SINGLE_COMPARE(vcvtsd2ss(xmm6, xmm11, qword[rcx + rdx]), 0xc4, 0xe1, 0xa3, 0x5a, 0x34, 0x11); } TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AVXTernaryInstructionForms") diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 53dc99e15..c79bf35ea 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -85,8 +85,8 @@ struct ACFixtureImpl : BaseType { GlobalTypes& globals = this->frontend.globalsForAutocomplete; unfreeze(globals.globalTypes); - LoadDefinitionFileResult result = - loadDefinitionFile(this->frontend.typeChecker, globals, globals.globalScope, source, "@test", /* captureComments */ false); + LoadDefinitionFileResult result = this->frontend.loadDefinitionFile( + globals, globals.globalScope, source, "@test", /* captureComments */ false, /* typeCheckForAutocomplete */ true); freeze(globals.globalTypes); REQUIRE_MESSAGE(result.success, "loadDefinition: unable to load definition file"); @@ -3448,8 +3448,6 @@ TEST_CASE_FIXTURE(ACFixture, "string_contents_is_available_to_callback") TEST_CASE_FIXTURE(ACFixture, "autocomplete_response_perf1" * doctest::timeout(0.5)) { - ScopedFastFlag luauAutocompleteSkipNormalization{"LuauAutocompleteSkipNormalization", true}; - // Build a function type with a large overload set const int parts = 100; std::string source; diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index 957d32719..2a32bce2d 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -9,6 +9,7 @@ #include "Luau/StringUtils.h" #include "Luau/BytecodeBuilder.h" #include "Luau/CodeGen.h" +#include "Luau/Frontend.h" #include "doctest.h" #include "ScopedFlags.h" @@ -243,6 +244,24 @@ static StateRef runConformance(const char* name, void (*setup)(lua_State* L) = n return globalState; } +static void* limitedRealloc(void* ud, void* ptr, size_t osize, size_t nsize) +{ + if (nsize == 0) + { + free(ptr); + return nullptr; + } + else if (nsize > 8 * 1024 * 1024) + { + // For testing purposes return null for large allocations so we can generate errors related to memory allocation failures + return nullptr; + } + else + { + return realloc(ptr, nsize); + } +} + TEST_SUITE_BEGIN("Conformance"); TEST_CASE("Assert") @@ -381,6 +400,8 @@ static int cxxthrow(lua_State* L) TEST_CASE("PCall") { + ScopedFastFlag sff("LuauBetterOOMHandling", true); + runConformance("pcall.lua", [](lua_State* L) { lua_pushcfunction(L, cxxthrow, "cxxthrow"); lua_setglobal(L, "cxxthrow"); @@ -395,7 +416,7 @@ TEST_CASE("PCall") }, "resumeerror"); lua_setglobal(L, "resumeerror"); - }); + }, nullptr, lua_newstate(limitedRealloc, nullptr)); } TEST_CASE("Pack") @@ -501,17 +522,15 @@ TEST_CASE("Types") { runConformance("types.lua", [](lua_State* L) { Luau::NullModuleResolver moduleResolver; - Luau::InternalErrorReporter iceHandler; - Luau::BuiltinTypes builtinTypes; - Luau::GlobalTypes globals{Luau::NotNull{&builtinTypes}}; - Luau::TypeChecker env(globals.globalScope, &moduleResolver, Luau::NotNull{&builtinTypes}, &iceHandler); - - Luau::registerBuiltinGlobals(env, globals); - Luau::freeze(globals.globalTypes); + Luau::NullFileResolver fileResolver; + Luau::NullConfigResolver configResolver; + Luau::Frontend frontend{&fileResolver, &configResolver}; + Luau::registerBuiltinGlobals(frontend, frontend.globals); + Luau::freeze(frontend.globals.globalTypes); lua_newtable(L); - for (const auto& [name, binding] : globals.globalScope->bindings) + for (const auto& [name, binding] : frontend.globals.globalScope->bindings) { populateRTTI(L, binding.typeId); lua_setfield(L, -2, toString(name).c_str()); @@ -882,7 +901,7 @@ TEST_CASE("ApiIter") TEST_CASE("ApiCalls") { - StateRef globalState = runConformance("apicalls.lua"); + StateRef globalState = runConformance("apicalls.lua", nullptr, nullptr, lua_newstate(limitedRealloc, nullptr)); lua_State* L = globalState.get(); // lua_call @@ -981,6 +1000,55 @@ TEST_CASE("ApiCalls") CHECK(lua_tonumber(L, -1) == 4); lua_pop(L, 1); } + + ScopedFastFlag sff("LuauBetterOOMHandling", true); + + // lua_pcall on OOM + { + lua_getfield(L, LUA_GLOBALSINDEX, "largealloc"); + int res = lua_pcall(L, 0, 0, 0); + CHECK(res == LUA_ERRMEM); + } + + // lua_pcall on OOM with an error handler + { + lua_getfield(L, LUA_GLOBALSINDEX, "oops"); + lua_getfield(L, LUA_GLOBALSINDEX, "largealloc"); + int res = lua_pcall(L, 0, 1, -2); + CHECK(res == LUA_ERRMEM); + CHECK((lua_isstring(L, -1) && strcmp(lua_tostring(L, -1), "oops") == 0)); + lua_pop(L, 1); + } + + // lua_pcall on OOM with an error handler that errors + { + lua_getfield(L, LUA_GLOBALSINDEX, "error"); + lua_getfield(L, LUA_GLOBALSINDEX, "largealloc"); + int res = lua_pcall(L, 0, 1, -2); + CHECK(res == LUA_ERRERR); + CHECK((lua_isstring(L, -1) && strcmp(lua_tostring(L, -1), "error in error handling") == 0)); + lua_pop(L, 1); + } + + // lua_pcall on OOM with an error handler that OOMs + { + lua_getfield(L, LUA_GLOBALSINDEX, "largealloc"); + lua_getfield(L, LUA_GLOBALSINDEX, "largealloc"); + int res = lua_pcall(L, 0, 1, -2); + CHECK(res == LUA_ERRMEM); + CHECK((lua_isstring(L, -1) && strcmp(lua_tostring(L, -1), "not enough memory") == 0)); + lua_pop(L, 1); + } + + // lua_pcall on error with an error handler that OOMs + { + lua_getfield(L, LUA_GLOBALSINDEX, "largealloc"); + lua_getfield(L, LUA_GLOBALSINDEX, "error"); + int res = lua_pcall(L, 0, 1, -2); + CHECK(res == LUA_ERRERR); + CHECK((lua_isstring(L, -1) && strcmp(lua_tostring(L, -1), "error in error handling") == 0)); + lua_pop(L, 1); + } } TEST_CASE("ApiAtoms") @@ -1051,26 +1119,7 @@ TEST_CASE("ExceptionObject") return ExceptionResult{false, ""}; }; - auto reallocFunc = [](void* /*ud*/, void* ptr, size_t /*osize*/, size_t nsize) -> void* { - if (nsize == 0) - { - free(ptr); - return nullptr; - } - else if (nsize > 512 * 1024) - { - // For testing purposes return null for large allocations - // so we can generate exceptions related to memory allocation - // failures. - return nullptr; - } - else - { - return realloc(ptr, nsize); - } - }; - - StateRef globalState = runConformance("exceptions.lua", nullptr, nullptr, lua_newstate(reallocFunc, nullptr)); + StateRef globalState = runConformance("exceptions.lua", nullptr, nullptr, lua_newstate(limitedRealloc, nullptr)); lua_State* L = globalState.get(); { diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index 4d2e83fc2..aebf177cd 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -506,7 +506,8 @@ void Fixture::validateErrors(const std::vector& errors) LoadDefinitionFileResult Fixture::loadDefinition(const std::string& source) { unfreeze(frontend.globals.globalTypes); - LoadDefinitionFileResult result = frontend.loadDefinitionFile(source, "@test", /* captureComments */ false); + LoadDefinitionFileResult result = + frontend.loadDefinitionFile(frontend.globals, frontend.globals.globalScope, source, "@test", /* captureComments */ false); freeze(frontend.globals.globalTypes); if (result.module) @@ -521,9 +522,9 @@ BuiltinsFixture::BuiltinsFixture(bool freeze, bool prepareAutocomplete) Luau::unfreeze(frontend.globals.globalTypes); Luau::unfreeze(frontend.globalsForAutocomplete.globalTypes); - registerBuiltinGlobals(frontend); + registerBuiltinGlobals(frontend, frontend.globals); if (prepareAutocomplete) - registerBuiltinGlobals(frontend.typeCheckerForAutocomplete, frontend.globalsForAutocomplete); + registerBuiltinGlobals(frontend, frontend.globalsForAutocomplete, /*typeCheckForAutocomplete*/ true); registerTestTypes(); Luau::freeze(frontend.globals.globalTypes); @@ -594,8 +595,12 @@ void registerHiddenTypes(Frontend* frontend) TypeId t = globals.globalTypes.addType(GenericType{"T"}); GenericTypeDefinition genericT{t}; + TypeId u = globals.globalTypes.addType(GenericType{"U"}); + GenericTypeDefinition genericU{u}; + ScopePtr globalScope = globals.globalScope; globalScope->exportedTypeBindings["Not"] = TypeFun{{genericT}, globals.globalTypes.addType(NegationType{t})}; + globalScope->exportedTypeBindings["Mt"] = TypeFun{{genericT, genericU}, globals.globalTypes.addType(MetatableType{t, u})}; globalScope->exportedTypeBindings["fun"] = TypeFun{{}, frontend->builtinTypes->functionType}; globalScope->exportedTypeBindings["cls"] = TypeFun{{}, frontend->builtinTypes->classType}; globalScope->exportedTypeBindings["err"] = TypeFun{{}, frontend->builtinTypes->errorType}; diff --git a/tests/Fixture.h b/tests/Fixture.h index 4c49593cc..8d48ab1dc 100644 --- a/tests/Fixture.h +++ b/tests/Fixture.h @@ -94,7 +94,6 @@ struct Fixture TypeId requireTypeAlias(const std::string& name); ScopedFastFlag sff_DebugLuauFreezeArena; - ScopedFastFlag luauLintInTypecheck{"LuauLintInTypecheck", true}; TestFileResolver fileResolver; TestConfigResolver configResolver; diff --git a/tests/Frontend.test.cpp b/tests/Frontend.test.cpp index 3b1ec4ad1..13fd6e0f8 100644 --- a/tests/Frontend.test.cpp +++ b/tests/Frontend.test.cpp @@ -877,7 +877,7 @@ TEST_CASE_FIXTURE(FrontendFixture, "environments") ScopePtr testScope = frontend.addEnvironment("test"); unfreeze(frontend.globals.globalTypes); - loadDefinitionFile(frontend.typeChecker, frontend.globals, testScope, R"( + frontend.loadDefinitionFile(frontend.globals, testScope, R"( export type Foo = number | string )", "@test", /* captureComments */ false); diff --git a/tests/IrCallWrapperX64.test.cpp b/tests/IrCallWrapperX64.test.cpp index 8c7b1393f..c8918dbde 100644 --- a/tests/IrCallWrapperX64.test.cpp +++ b/tests/IrCallWrapperX64.test.cpp @@ -12,7 +12,7 @@ class IrCallWrapperX64Fixture public: IrCallWrapperX64Fixture() : build(/* logText */ true, ABIX64::Windows) - , regs(function) + , regs(build, function) , callWrap(regs, build, ~0u) { } @@ -46,8 +46,8 @@ TEST_SUITE_BEGIN("IrCallWrapperX64"); TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "SimpleRegs") { - ScopedRegX64 tmp1{regs, regs.takeReg(rax)}; - ScopedRegX64 tmp2{regs, regs.takeReg(rArg2)}; + ScopedRegX64 tmp1{regs, regs.takeReg(rax, kInvalidInstIdx)}; + ScopedRegX64 tmp2{regs, regs.takeReg(rArg2, kInvalidInstIdx)}; callWrap.addArgument(SizeX64::qword, tmp1); callWrap.addArgument(SizeX64::qword, tmp2); // Already in its place callWrap.call(qword[r12]); @@ -60,7 +60,7 @@ TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "SimpleRegs") TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "TrickyUse1") { - ScopedRegX64 tmp1{regs, regs.takeReg(rArg1)}; + ScopedRegX64 tmp1{regs, regs.takeReg(rArg1, kInvalidInstIdx)}; callWrap.addArgument(SizeX64::qword, tmp1.reg); // Already in its place callWrap.addArgument(SizeX64::qword, tmp1.release()); callWrap.call(qword[r12]); @@ -73,7 +73,7 @@ TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "TrickyUse1") TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "TrickyUse2") { - ScopedRegX64 tmp1{regs, regs.takeReg(rArg1)}; + ScopedRegX64 tmp1{regs, regs.takeReg(rArg1, kInvalidInstIdx)}; callWrap.addArgument(SizeX64::qword, qword[tmp1.reg]); callWrap.addArgument(SizeX64::qword, tmp1.release()); callWrap.call(qword[r12]); @@ -87,8 +87,8 @@ TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "TrickyUse2") TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "SimpleMemImm") { - ScopedRegX64 tmp1{regs, regs.takeReg(rax)}; - ScopedRegX64 tmp2{regs, regs.takeReg(rsi)}; + ScopedRegX64 tmp1{regs, regs.takeReg(rax, kInvalidInstIdx)}; + ScopedRegX64 tmp2{regs, regs.takeReg(rsi, kInvalidInstIdx)}; callWrap.addArgument(SizeX64::dword, 32); callWrap.addArgument(SizeX64::dword, -1); callWrap.addArgument(SizeX64::qword, qword[r14 + 32]); @@ -106,7 +106,7 @@ TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "SimpleMemImm") TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "SimpleStackArgs") { - ScopedRegX64 tmp{regs, regs.takeReg(rax)}; + ScopedRegX64 tmp{regs, regs.takeReg(rax, kInvalidInstIdx)}; callWrap.addArgument(SizeX64::qword, tmp); callWrap.addArgument(SizeX64::qword, qword[r14 + 16]); callWrap.addArgument(SizeX64::qword, qword[r14 + 32]); @@ -148,10 +148,10 @@ TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "FixedRegisters") TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "EasyInterference") { - ScopedRegX64 tmp1{regs, regs.takeReg(rdi)}; - ScopedRegX64 tmp2{regs, regs.takeReg(rsi)}; - ScopedRegX64 tmp3{regs, regs.takeReg(rArg2)}; - ScopedRegX64 tmp4{regs, regs.takeReg(rArg1)}; + ScopedRegX64 tmp1{regs, regs.takeReg(rdi, kInvalidInstIdx)}; + ScopedRegX64 tmp2{regs, regs.takeReg(rsi, kInvalidInstIdx)}; + ScopedRegX64 tmp3{regs, regs.takeReg(rArg2, kInvalidInstIdx)}; + ScopedRegX64 tmp4{regs, regs.takeReg(rArg1, kInvalidInstIdx)}; callWrap.addArgument(SizeX64::qword, tmp1); callWrap.addArgument(SizeX64::qword, tmp2); callWrap.addArgument(SizeX64::qword, tmp3); @@ -169,8 +169,8 @@ TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "EasyInterference") TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "FakeInterference") { - ScopedRegX64 tmp1{regs, regs.takeReg(rArg1)}; - ScopedRegX64 tmp2{regs, regs.takeReg(rArg2)}; + ScopedRegX64 tmp1{regs, regs.takeReg(rArg1, kInvalidInstIdx)}; + ScopedRegX64 tmp2{regs, regs.takeReg(rArg2, kInvalidInstIdx)}; callWrap.addArgument(SizeX64::qword, qword[tmp1.release() + 8]); callWrap.addArgument(SizeX64::qword, qword[tmp2.release() + 8]); callWrap.call(qword[r12]); @@ -184,10 +184,10 @@ TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "FakeInterference") TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "HardInterferenceInt") { - ScopedRegX64 tmp1{regs, regs.takeReg(rArg4)}; - ScopedRegX64 tmp2{regs, regs.takeReg(rArg3)}; - ScopedRegX64 tmp3{regs, regs.takeReg(rArg2)}; - ScopedRegX64 tmp4{regs, regs.takeReg(rArg1)}; + ScopedRegX64 tmp1{regs, regs.takeReg(rArg4, kInvalidInstIdx)}; + ScopedRegX64 tmp2{regs, regs.takeReg(rArg3, kInvalidInstIdx)}; + ScopedRegX64 tmp3{regs, regs.takeReg(rArg2, kInvalidInstIdx)}; + ScopedRegX64 tmp4{regs, regs.takeReg(rArg1, kInvalidInstIdx)}; callWrap.addArgument(SizeX64::qword, tmp1); callWrap.addArgument(SizeX64::qword, tmp2); callWrap.addArgument(SizeX64::qword, tmp3); @@ -207,10 +207,10 @@ TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "HardInterferenceInt") TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "HardInterferenceInt2") { - ScopedRegX64 tmp1{regs, regs.takeReg(rArg4d)}; - ScopedRegX64 tmp2{regs, regs.takeReg(rArg3d)}; - ScopedRegX64 tmp3{regs, regs.takeReg(rArg2d)}; - ScopedRegX64 tmp4{regs, regs.takeReg(rArg1d)}; + ScopedRegX64 tmp1{regs, regs.takeReg(rArg4d, kInvalidInstIdx)}; + ScopedRegX64 tmp2{regs, regs.takeReg(rArg3d, kInvalidInstIdx)}; + ScopedRegX64 tmp3{regs, regs.takeReg(rArg2d, kInvalidInstIdx)}; + ScopedRegX64 tmp4{regs, regs.takeReg(rArg1d, kInvalidInstIdx)}; callWrap.addArgument(SizeX64::dword, tmp1); callWrap.addArgument(SizeX64::dword, tmp2); callWrap.addArgument(SizeX64::dword, tmp3); @@ -230,8 +230,8 @@ TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "HardInterferenceInt2") TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "HardInterferenceFp") { - ScopedRegX64 tmp1{regs, regs.takeReg(xmm1)}; - ScopedRegX64 tmp2{regs, regs.takeReg(xmm0)}; + ScopedRegX64 tmp1{regs, regs.takeReg(xmm1, kInvalidInstIdx)}; + ScopedRegX64 tmp2{regs, regs.takeReg(xmm0, kInvalidInstIdx)}; callWrap.addArgument(SizeX64::xmmword, tmp1); callWrap.addArgument(SizeX64::xmmword, tmp2); callWrap.call(qword[r12]); @@ -246,10 +246,10 @@ TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "HardInterferenceFp") TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "HardInterferenceBoth") { - ScopedRegX64 int1{regs, regs.takeReg(rArg2)}; - ScopedRegX64 int2{regs, regs.takeReg(rArg1)}; - ScopedRegX64 fp1{regs, regs.takeReg(xmm3)}; - ScopedRegX64 fp2{regs, regs.takeReg(xmm2)}; + ScopedRegX64 int1{regs, regs.takeReg(rArg2, kInvalidInstIdx)}; + ScopedRegX64 int2{regs, regs.takeReg(rArg1, kInvalidInstIdx)}; + ScopedRegX64 fp1{regs, regs.takeReg(xmm3, kInvalidInstIdx)}; + ScopedRegX64 fp2{regs, regs.takeReg(xmm2, kInvalidInstIdx)}; callWrap.addArgument(SizeX64::qword, int1); callWrap.addArgument(SizeX64::qword, int2); callWrap.addArgument(SizeX64::xmmword, fp1); @@ -269,8 +269,8 @@ TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "HardInterferenceBoth") TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "FakeMultiuseInterferenceMem") { - ScopedRegX64 tmp1{regs, regs.takeReg(rArg1)}; - ScopedRegX64 tmp2{regs, regs.takeReg(rArg2)}; + ScopedRegX64 tmp1{regs, regs.takeReg(rArg1, kInvalidInstIdx)}; + ScopedRegX64 tmp2{regs, regs.takeReg(rArg2, kInvalidInstIdx)}; callWrap.addArgument(SizeX64::qword, qword[tmp1.reg + tmp2.reg + 8]); callWrap.addArgument(SizeX64::qword, qword[tmp2.reg + 16]); tmp1.release(); @@ -286,8 +286,8 @@ TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "FakeMultiuseInterferenceMem") TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "HardMultiuseInterferenceMem1") { - ScopedRegX64 tmp1{regs, regs.takeReg(rArg1)}; - ScopedRegX64 tmp2{regs, regs.takeReg(rArg2)}; + ScopedRegX64 tmp1{regs, regs.takeReg(rArg1, kInvalidInstIdx)}; + ScopedRegX64 tmp2{regs, regs.takeReg(rArg2, kInvalidInstIdx)}; callWrap.addArgument(SizeX64::qword, qword[tmp1.reg + tmp2.reg + 8]); callWrap.addArgument(SizeX64::qword, qword[tmp1.reg + 16]); tmp1.release(); @@ -304,8 +304,8 @@ TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "HardMultiuseInterferenceMem1") TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "HardMultiuseInterferenceMem2") { - ScopedRegX64 tmp1{regs, regs.takeReg(rArg1)}; - ScopedRegX64 tmp2{regs, regs.takeReg(rArg2)}; + ScopedRegX64 tmp1{regs, regs.takeReg(rArg1, kInvalidInstIdx)}; + ScopedRegX64 tmp2{regs, regs.takeReg(rArg2, kInvalidInstIdx)}; callWrap.addArgument(SizeX64::qword, qword[tmp1.reg + tmp2.reg + 8]); callWrap.addArgument(SizeX64::qword, qword[tmp1.reg + tmp2.reg + 16]); tmp1.release(); @@ -322,9 +322,9 @@ TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "HardMultiuseInterferenceMem2") TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "HardMultiuseInterferenceMem3") { - ScopedRegX64 tmp1{regs, regs.takeReg(rArg3)}; - ScopedRegX64 tmp2{regs, regs.takeReg(rArg2)}; - ScopedRegX64 tmp3{regs, regs.takeReg(rArg1)}; + ScopedRegX64 tmp1{regs, regs.takeReg(rArg3, kInvalidInstIdx)}; + ScopedRegX64 tmp2{regs, regs.takeReg(rArg2, kInvalidInstIdx)}; + ScopedRegX64 tmp3{regs, regs.takeReg(rArg1, kInvalidInstIdx)}; callWrap.addArgument(SizeX64::qword, qword[tmp1.reg + tmp2.reg + 8]); callWrap.addArgument(SizeX64::qword, qword[tmp2.reg + tmp3.reg + 16]); callWrap.addArgument(SizeX64::qword, qword[tmp3.reg + tmp1.reg + 16]); @@ -345,7 +345,7 @@ TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "HardMultiuseInterferenceMem3") TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "InterferenceWithCallArg1") { - ScopedRegX64 tmp1{regs, regs.takeReg(rArg1)}; + ScopedRegX64 tmp1{regs, regs.takeReg(rArg1, kInvalidInstIdx)}; callWrap.addArgument(SizeX64::qword, qword[tmp1.reg + 8]); callWrap.call(qword[tmp1.release() + 16]); @@ -358,8 +358,8 @@ TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "InterferenceWithCallArg1") TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "InterferenceWithCallArg2") { - ScopedRegX64 tmp1{regs, regs.takeReg(rArg1)}; - ScopedRegX64 tmp2{regs, regs.takeReg(rArg2)}; + ScopedRegX64 tmp1{regs, regs.takeReg(rArg1, kInvalidInstIdx)}; + ScopedRegX64 tmp2{regs, regs.takeReg(rArg2, kInvalidInstIdx)}; callWrap.addArgument(SizeX64::qword, tmp2); callWrap.call(qword[tmp1.release() + 16]); @@ -372,7 +372,7 @@ TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "InterferenceWithCallArg2") TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "InterferenceWithCallArg3") { - ScopedRegX64 tmp1{regs, regs.takeReg(rArg1)}; + ScopedRegX64 tmp1{regs, regs.takeReg(rArg1, kInvalidInstIdx)}; callWrap.addArgument(SizeX64::qword, tmp1.reg); callWrap.call(qword[tmp1.release() + 16]); @@ -385,7 +385,7 @@ TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "WithLastIrInstUse1") { IrInst irInst1; IrOp irOp1 = {IrOpKind::Inst, 0}; - irInst1.regX64 = regs.takeReg(xmm0); + irInst1.regX64 = regs.takeReg(xmm0, irOp1.index); irInst1.lastUse = 1; function.instructions.push_back(irInst1); callWrap.instIdx = irInst1.lastUse; @@ -404,7 +404,7 @@ TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "WithLastIrInstUse2") { IrInst irInst1; IrOp irOp1 = {IrOpKind::Inst, 0}; - irInst1.regX64 = regs.takeReg(xmm0); + irInst1.regX64 = regs.takeReg(xmm0, irOp1.index); irInst1.lastUse = 1; function.instructions.push_back(irInst1); callWrap.instIdx = irInst1.lastUse; @@ -424,7 +424,7 @@ TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "WithLastIrInstUse3") { IrInst irInst1; IrOp irOp1 = {IrOpKind::Inst, 0}; - irInst1.regX64 = regs.takeReg(xmm0); + irInst1.regX64 = regs.takeReg(xmm0, irOp1.index); irInst1.lastUse = 1; function.instructions.push_back(irInst1); callWrap.instIdx = irInst1.lastUse; @@ -443,12 +443,12 @@ TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "WithLastIrInstUse4") { IrInst irInst1; IrOp irOp1 = {IrOpKind::Inst, 0}; - irInst1.regX64 = regs.takeReg(rax); + irInst1.regX64 = regs.takeReg(rax, irOp1.index); irInst1.lastUse = 1; function.instructions.push_back(irInst1); callWrap.instIdx = irInst1.lastUse; - ScopedRegX64 tmp{regs, regs.takeReg(rdx)}; + ScopedRegX64 tmp{regs, regs.takeReg(rdx, kInvalidInstIdx)}; callWrap.addArgument(SizeX64::qword, r15); callWrap.addArgument(SizeX64::qword, irInst1.regX64, irOp1); callWrap.addArgument(SizeX64::qword, tmp); @@ -464,8 +464,8 @@ TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "WithLastIrInstUse4") TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "ExtraCoverage") { - ScopedRegX64 tmp1{regs, regs.takeReg(rArg1)}; - ScopedRegX64 tmp2{regs, regs.takeReg(rArg2)}; + ScopedRegX64 tmp1{regs, regs.takeReg(rArg1, kInvalidInstIdx)}; + ScopedRegX64 tmp2{regs, regs.takeReg(rArg2, kInvalidInstIdx)}; callWrap.addArgument(SizeX64::qword, addr[r12 + 8]); callWrap.addArgument(SizeX64::qword, addr[r12 + 16]); callWrap.addArgument(SizeX64::xmmword, xmmword[r13]); @@ -481,4 +481,42 @@ TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "ExtraCoverage") )"); } +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "AddressInStackArguments") +{ + callWrap.addArgument(SizeX64::dword, 1); + callWrap.addArgument(SizeX64::dword, 2); + callWrap.addArgument(SizeX64::dword, 3); + callWrap.addArgument(SizeX64::dword, 4); + callWrap.addArgument(SizeX64::qword, addr[r12 + 16]); + callWrap.call(qword[r14]); + + checkMatch(R"( + lea rax,none ptr [r12+010h] + mov qword ptr [rsp+020h],rax + mov ecx,1 + mov edx,2 + mov r8d,3 + mov r9d,4 + call qword ptr [r14] +)"); +} + +TEST_CASE_FIXTURE(IrCallWrapperX64Fixture, "ImmediateConflictWithFunction") +{ + ScopedRegX64 tmp1{regs, regs.takeReg(rArg1, kInvalidInstIdx)}; + ScopedRegX64 tmp2{regs, regs.takeReg(rArg2, kInvalidInstIdx)}; + + callWrap.addArgument(SizeX64::dword, 1); + callWrap.addArgument(SizeX64::dword, 2); + callWrap.call(qword[tmp1.release() + tmp2.release()]); + + checkMatch(R"( + mov rax,rcx + mov ecx,1 + mov rbx,rdx + mov edx,2 + call qword ptr [rax+rbx] +)"); +} + TEST_SUITE_END(); diff --git a/tests/Linter.test.cpp b/tests/Linter.test.cpp index 8bef5922f..54a1f44cb 100644 --- a/tests/Linter.test.cpp +++ b/tests/Linter.test.cpp @@ -1273,7 +1273,7 @@ TEST_CASE_FIXTURE(Fixture, "use_all_parent_scopes_for_globals") { ScopePtr testScope = frontend.addEnvironment("Test"); unfreeze(frontend.globals.globalTypes); - loadDefinitionFile(frontend.typeChecker, frontend.globals, testScope, R"( + frontend.loadDefinitionFile(frontend.globals, testScope, R"( declare Foo: number )", "@test", /* captureComments */ false); diff --git a/tests/Normalize.test.cpp b/tests/Normalize.test.cpp index 4378bab8b..6552a24da 100644 --- a/tests/Normalize.test.cpp +++ b/tests/Normalize.test.cpp @@ -748,6 +748,20 @@ TEST_CASE_FIXTURE(NormalizeFixture, "narrow_union_of_classes_with_intersection") CHECK("Child" == toString(normal("(Child | Unrelated) & Child"))); } +TEST_CASE_FIXTURE(NormalizeFixture, "intersection_of_metatables_where_the_metatable_is_top_or_bottom") +{ + ScopedFastFlag sff{"LuauNormalizeMetatableFixes", true}; + + CHECK("{ @metatable *error-type*, {| |} }" == toString(normal("Mt<{}, any> & Mt<{}, err>"))); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "crazy_metatable") +{ + ScopedFastFlag sff{"LuauNormalizeMetatableFixes", true}; + + CHECK("never" == toString(normal("Mt<{}, number> & Mt<{}, string>"))); +} + TEST_CASE_FIXTURE(NormalizeFixture, "negations_of_classes") { ScopedFastFlag sffs[] = { diff --git a/tests/TypeInfer.definitions.test.cpp b/tests/TypeInfer.definitions.test.cpp index f3f464130..d67997574 100644 --- a/tests/TypeInfer.definitions.test.cpp +++ b/tests/TypeInfer.definitions.test.cpp @@ -78,7 +78,7 @@ TEST_CASE_FIXTURE(Fixture, "definition_file_loading") TEST_CASE_FIXTURE(Fixture, "load_definition_file_errors_do_not_pollute_global_scope") { unfreeze(frontend.globals.globalTypes); - LoadDefinitionFileResult parseFailResult = loadDefinitionFile(frontend.typeChecker, frontend.globals, frontend.globals.globalScope, R"( + LoadDefinitionFileResult parseFailResult = frontend.loadDefinitionFile(frontend.globals, frontend.globals.globalScope, R"( declare foo )", "@test", /* captureComments */ false); @@ -88,7 +88,7 @@ TEST_CASE_FIXTURE(Fixture, "load_definition_file_errors_do_not_pollute_global_sc std::optional fooTy = tryGetGlobalBinding(frontend.globals, "foo"); CHECK(!fooTy.has_value()); - LoadDefinitionFileResult checkFailResult = loadDefinitionFile(frontend.typeChecker, frontend.globals, frontend.globals.globalScope, R"( + LoadDefinitionFileResult checkFailResult = frontend.loadDefinitionFile(frontend.globals, frontend.globals.globalScope, R"( local foo: string = 123 declare bar: typeof(foo) )", @@ -140,7 +140,7 @@ TEST_CASE_FIXTURE(Fixture, "definition_file_classes") TEST_CASE_FIXTURE(Fixture, "class_definitions_cannot_overload_non_function") { unfreeze(frontend.globals.globalTypes); - LoadDefinitionFileResult result = loadDefinitionFile(frontend.typeChecker, frontend.globals, frontend.globals.globalScope, R"( + LoadDefinitionFileResult result = frontend.loadDefinitionFile(frontend.globals, frontend.globals.globalScope, R"( declare class A X: number X: string @@ -161,7 +161,7 @@ TEST_CASE_FIXTURE(Fixture, "class_definitions_cannot_overload_non_function") TEST_CASE_FIXTURE(Fixture, "class_definitions_cannot_extend_non_class") { unfreeze(frontend.globals.globalTypes); - LoadDefinitionFileResult result = loadDefinitionFile(frontend.typeChecker, frontend.globals, frontend.globals.globalScope, R"( + LoadDefinitionFileResult result = frontend.loadDefinitionFile(frontend.globals, frontend.globals.globalScope, R"( type NotAClass = {} declare class Foo extends NotAClass @@ -182,7 +182,7 @@ TEST_CASE_FIXTURE(Fixture, "class_definitions_cannot_extend_non_class") TEST_CASE_FIXTURE(Fixture, "no_cyclic_defined_classes") { unfreeze(frontend.globals.globalTypes); - LoadDefinitionFileResult result = loadDefinitionFile(frontend.typeChecker, frontend.globals, frontend.globals.globalScope, R"( + LoadDefinitionFileResult result = frontend.loadDefinitionFile(frontend.globals, frontend.globals.globalScope, R"( declare class Foo extends Bar end @@ -397,7 +397,7 @@ TEST_CASE_FIXTURE(Fixture, "class_definition_string_props") TEST_CASE_FIXTURE(Fixture, "class_definitions_reference_other_classes") { unfreeze(frontend.globals.globalTypes); - LoadDefinitionFileResult result = loadDefinitionFile(frontend.typeChecker, frontend.globals, frontend.globals.globalScope, R"( + LoadDefinitionFileResult result = frontend.loadDefinitionFile(frontend.globals, frontend.globals.globalScope, R"( declare class Channel Messages: { Message } OnMessage: (message: Message) -> () diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index b3b2e4c94..b97848176 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -874,7 +874,7 @@ TEST_CASE_FIXTURE(Fixture, "generic_table_method") std::vector args = flatten(ftv->argTypes).first; TypeId argType = args.at(1); - CHECK_MESSAGE(get(argType), "Should be generic: " << *barType); + CHECK_MESSAGE(get(argType), "Should be generic: " << *barType); } TEST_CASE_FIXTURE(Fixture, "correctly_instantiate_polymorphic_member_functions") diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index 38e7e2f31..87419debb 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -477,10 +477,10 @@ TEST_CASE_FIXTURE(Fixture, "free_options_cannot_be_unified_together") std::unique_ptr scope = std::make_unique(builtinTypes->anyTypePack); - TypeId free1 = arena.addType(FreeTypePack{scope.get()}); + TypeId free1 = arena.addType(FreeType{scope.get()}); TypeId option1 = arena.addType(UnionType{{nilType, free1}}); - TypeId free2 = arena.addType(FreeTypePack{scope.get()}); + TypeId free2 = arena.addType(FreeType{scope.get()}); TypeId option2 = arena.addType(UnionType{{nilType, free2}}); InternalErrorReporter iceHandler; diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 7e317f2ef..3088235ae 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -490,8 +490,13 @@ struct FindFreeTypes return !foundOne; } - template - bool operator()(ID, Unifiable::Free) + bool operator()(TypeId, FreeType) + { + foundOne = true; + return false; + } + + bool operator()(TypePackId, FreeTypePack) { foundOne = true; return false; diff --git a/tests/TypePack.test.cpp b/tests/TypePack.test.cpp index 20404434a..7d8ed38f7 100644 --- a/tests/TypePack.test.cpp +++ b/tests/TypePack.test.cpp @@ -25,7 +25,7 @@ struct TypePackFixture TypePackId freshTypePack() { - typePacks.emplace_back(new TypePackVar{Unifiable::Free{TypeLevel{}}}); + typePacks.emplace_back(new TypePackVar{FreeTypePack{TypeLevel{}}}); return typePacks.back().get(); } diff --git a/tests/TypeVar.test.cpp b/tests/TypeVar.test.cpp index 64ba63c8d..3f0becc54 100644 --- a/tests/TypeVar.test.cpp +++ b/tests/TypeVar.test.cpp @@ -74,7 +74,7 @@ TEST_CASE_FIXTURE(Fixture, "return_type_of_function_is_parenthesized_if_not_just TEST_CASE_FIXTURE(Fixture, "return_type_of_function_is_parenthesized_if_tail_is_free") { auto emptyArgumentPack = TypePackVar{TypePack{}}; - auto free = Unifiable::Free(TypeLevel()); + auto free = FreeTypePack(TypeLevel()); auto freePack = TypePackVar{TypePackVariant{free}}; auto returnPack = TypePackVar{TypePack{{builtinTypes->numberType}, &freePack}}; auto returnsTwo = Type(FunctionType(frontend.globals.globalScope->level, &emptyArgumentPack, &returnPack)); diff --git a/tests/conformance/apicalls.lua b/tests/conformance/apicalls.lua index 274166237..8db62d96c 100644 --- a/tests/conformance/apicalls.lua +++ b/tests/conformance/apicalls.lua @@ -22,4 +22,12 @@ function getpi() return pi end +function largealloc() + table.create(1000000) +end + +function oops() + return "oops" +end + return('OK') diff --git a/tests/conformance/pcall.lua b/tests/conformance/pcall.lua index 969209fc4..b94f7972e 100644 --- a/tests/conformance/pcall.lua +++ b/tests/conformance/pcall.lua @@ -161,4 +161,11 @@ checkresults({ false, "ok" }, xpcall(recurse, function() return string.reverse(" -- however, if xpcall handler itself runs out of extra stack space, we get "error in error handling" checkresults({ false, "error in error handling" }, xpcall(recurse, function() return recurse(calllimit) end, calllimit - 2)) +-- simulate OOM and make sure we can catch it with pcall or xpcall +checkresults({ false, "not enough memory" }, pcall(function() table.create(1e6) end)) +checkresults({ false, "not enough memory" }, xpcall(function() table.create(1e6) end, function(e) return e end)) +checkresults({ false, "oops" }, xpcall(function() table.create(1e6) end, function(e) return "oops" end)) +checkresults({ false, "error in error handling" }, xpcall(function() error("oops") end, function(e) table.create(1e6) end)) +checkresults({ false, "not enough memory" }, xpcall(function() table.create(1e6) end, function(e) table.create(1e6) end)) + return 'OK' diff --git a/tests/conformance/sort.lua b/tests/conformance/sort.lua index 693a10dc5..3c2c20dd4 100644 --- a/tests/conformance/sort.lua +++ b/tests/conformance/sort.lua @@ -99,12 +99,12 @@ a = {" table.sort(a) check(a) --- TODO: assert that pcall returns false for new sort implementation (table is modified during sorting) -pcall(table.sort, a, function (x, y) +local ok = pcall(table.sort, a, function (x, y) loadstring(string.format("a[%q] = ''", x))() collectgarbage() return x + + + + + count + capacity + + capacity + data + + + + + + + impl + + + + + + impl + + + + diff --git a/tools/test_dcr.py b/tools/test_dcr.py index d30490b30..817d08313 100644 --- a/tools/test_dcr.py +++ b/tools/test_dcr.py @@ -107,6 +107,12 @@ def main(): action="store_true", help="Write a new faillist.txt after running tests.", ) + parser.add_argument( + "--lti", + dest="lti", + action="store_true", + help="Run the tests with local type inference enabled.", + ) parser.add_argument("--randomize", action="store_true", help="Pick a random seed") @@ -120,13 +126,19 @@ def main(): args = parser.parse_args() + if args.write and args.lti: + print_stderr( + "Cannot run test_dcr.py with --write *and* --lti. You don't want to commit local type inference faillist.txt yet." + ) + sys.exit(1) + failList = loadFailList() - commandLine = [ - args.path, - "--reporters=xml", - "--fflags=true,DebugLuauDeferredConstraintResolution=true", - ] + flags = ["true", "DebugLuauDeferredConstraintResolution"] + if args.lti: + flags.append("DebugLuauLocalTypeInference") + + commandLine = [args.path, "--reporters=xml", "--fflags=" + ",".join(flags)] if args.random_seed: commandLine.append("--random-seed=" + str(args.random_seed))