Skip to content

Commit

Permalink
Sync to upstream/release/512 (#330)
Browse files Browse the repository at this point in the history
- Improve refinement support for unions, in particular it's now possible to implement tagged unions as a union of tables where individual branches use a string literal type for one of the fields.
- Fix `string.split` type information
- Optimize `select(_, ...)` to run in constant time (~2.7x faster on VariadicSelect benchmark)
- Improve debug line information for multi-line assignments
- Improve compilation of table literals when table keys are constant expressions/variables
- Use forward GC barrier for `setmetatable` which slightly accelerates GC progress
  • Loading branch information
zeux authored Jan 27, 2022
1 parent 4b96f7e commit 2f989fc
Show file tree
Hide file tree
Showing 49 changed files with 1,771 additions and 1,122 deletions.
15 changes: 15 additions & 0 deletions Analysis/include/Luau/AstQuery.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,21 @@ struct ExprOrLocal
{
return expr ? expr->location : (local ? local->location : std::optional<Location>{});
}
std::optional<AstName> getName()
{
if (expr)
{
if (AstName name = getIdentifier(expr); name.value)
{
return name;
}
}
else if (local)
{
return local->name;
}
return std::nullopt;
}

private:
AstExpr* expr = nullptr;
Expand Down
8 changes: 8 additions & 0 deletions Analysis/include/Luau/Module.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
#include <unordered_map>
#include <optional>

LUAU_FASTFLAG(LuauPrepopulateUnionOptionsBeforeAllocation)

namespace Luau
{

Expand Down Expand Up @@ -58,6 +60,12 @@ struct TypeArena
template<typename T>
TypeId addType(T tv)
{
if (FFlag::LuauPrepopulateUnionOptionsBeforeAllocation)
{
if constexpr (std::is_same_v<T, UnionTypeVar>)
LUAU_ASSERT(tv.options.size() >= 2);
}

return addTV(TypeVar(std::move(tv)));
}

Expand Down
21 changes: 10 additions & 11 deletions Analysis/include/Luau/TypeInfer.h
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,8 @@ struct TypeChecker
void checkBlock(const ScopePtr& scope, const AstStatBlock& statement);
void checkBlockTypeAliases(const ScopePtr& scope, std::vector<AstStat*>& sorted);

ExprResult<TypeId> checkExpr(const ScopePtr& scope, const AstExpr& expr, std::optional<TypeId> expectedType = std::nullopt);
ExprResult<TypeId> checkExpr(
const ScopePtr& scope, const AstExpr& expr, std::optional<TypeId> expectedType = std::nullopt, bool forceSingleton = false);
ExprResult<TypeId> checkExpr(const ScopePtr& scope, const AstExprLocal& expr);
ExprResult<TypeId> checkExpr(const ScopePtr& scope, const AstExprGlobal& expr);
ExprResult<TypeId> checkExpr(const ScopePtr& scope, const AstExprVarargs& expr);
Expand All @@ -160,14 +161,12 @@ struct TypeChecker
// Returns the type of the lvalue.
TypeId checkLValue(const ScopePtr& scope, const AstExpr& expr);

// Returns both the type of the lvalue and its binding (if the caller wants to mutate the binding).
// Note: the binding may be null.
// TODO: remove second return value with FFlagLuauUpdateFunctionNameBinding
std::pair<TypeId, TypeId*> checkLValueBinding(const ScopePtr& scope, const AstExpr& expr);
std::pair<TypeId, TypeId*> checkLValueBinding(const ScopePtr& scope, const AstExprLocal& expr);
std::pair<TypeId, TypeId*> checkLValueBinding(const ScopePtr& scope, const AstExprGlobal& expr);
std::pair<TypeId, TypeId*> checkLValueBinding(const ScopePtr& scope, const AstExprIndexName& expr);
std::pair<TypeId, TypeId*> checkLValueBinding(const ScopePtr& scope, const AstExprIndexExpr& expr);
// Returns the type of the lvalue.
TypeId checkLValueBinding(const ScopePtr& scope, const AstExpr& expr);
TypeId checkLValueBinding(const ScopePtr& scope, const AstExprLocal& expr);
TypeId checkLValueBinding(const ScopePtr& scope, const AstExprGlobal& expr);
TypeId checkLValueBinding(const ScopePtr& scope, const AstExprIndexName& expr);
TypeId checkLValueBinding(const ScopePtr& scope, const AstExprIndexExpr& expr);

TypeId checkFunctionName(const ScopePtr& scope, AstExpr& funName, TypeLevel level);
std::pair<TypeId, ScopePtr> checkFunctionSignature(const ScopePtr& scope, int subLevel, const AstExprFunction& expr,
Expand Down Expand Up @@ -322,8 +321,6 @@ struct TypeChecker
return addTV(TypeVar(tv));
}

TypeId addType(const UnionTypeVar& utv);

TypeId addTV(TypeVar&& tv);

TypePackId addTypePack(TypePackVar&& tp);
Expand All @@ -349,6 +346,8 @@ struct TypeChecker
ErrorVec resolve(const PredicateVec& predicates, const ScopePtr& scope, bool sense);

private:
void refineLValue(const LValue& lvalue, RefinementMap& refis, const ScopePtr& scope, TypeIdPredicate predicate);

std::optional<TypeId> resolveLValue(const ScopePtr& scope, const LValue& lvalue);
std::optional<TypeId> DEPRECATED_resolveLValue(const ScopePtr& scope, const LValue& lvalue);
std::optional<TypeId> resolveLValue(const RefinementMap& refis, const ScopePtr& scope, const LValue& lvalue);
Expand Down
8 changes: 4 additions & 4 deletions Analysis/include/Luau/TypeVar.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,16 +111,16 @@ struct PrimitiveTypeVar

// Singleton types https://github.com/Roblox/luau/blob/master/rfcs/syntax-singleton-types.md
// Types for true and false
struct BoolSingleton
struct BooleanSingleton
{
bool value;

bool operator==(const BoolSingleton& rhs) const
bool operator==(const BooleanSingleton& rhs) const
{
return value == rhs.value;
}

bool operator!=(const BoolSingleton& rhs) const
bool operator!=(const BooleanSingleton& rhs) const
{
return !(*this == rhs);
}
Expand All @@ -145,7 +145,7 @@ struct StringSingleton
// No type for float singletons, partly because === isn't any equalivalence on floats
// (NaN != NaN).

using SingletonVariant = Luau::Variant<BoolSingleton, StringSingleton>;
using SingletonVariant = Luau::Variant<BooleanSingleton, StringSingleton>;

struct SingletonTypeVar
{
Expand Down
7 changes: 7 additions & 0 deletions Analysis/include/Luau/Unifier.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,13 @@ struct Unifier

Unifier makeChildUnifier();

// A utility function that appends the given error to the unifier's error log.
// This allows setting a breakpoint wherever the unifier reports an error.
void reportError(TypeError error)
{
errors.push_back(error);
}

private:
bool isNonstrictMode() const;

Expand Down
101 changes: 51 additions & 50 deletions Analysis/src/Autocomplete.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@

LUAU_FASTFLAG(LuauUseCommittingTxnLog)
LUAU_FASTFLAGVARIABLE(LuauAutocompleteAvoidMutation, false);
LUAU_FASTFLAGVARIABLE(LuauAutocompleteFirstArg, false);
LUAU_FASTFLAGVARIABLE(LuauCompleteBrokenStringParams, false);
LUAU_FASTFLAGVARIABLE(LuauMissingFollowACMetatables, false);
LUAU_FASTFLAGVARIABLE(PreferToCallFunctionsForIntersects, false);

static const std::unordered_set<std::string> kStatementStartingKeywords = {
"while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"};
Expand Down Expand Up @@ -194,8 +194,6 @@ static ParenthesesRecommendation getParenRecommendation(TypeId id, const std::ve

static std::optional<TypeId> findExpectedTypeAt(const Module& module, AstNode* node, Position position)
{
LUAU_ASSERT(FFlag::LuauAutocompleteFirstArg);

auto expr = node->asExpr();
if (!expr)
return std::nullopt;
Expand Down Expand Up @@ -266,43 +264,63 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ
}
};

TypeId expectedType;
auto typeAtPosition = findExpectedTypeAt(module, node, position);

if (FFlag::LuauAutocompleteFirstArg)
{
auto typeAtPosition = findExpectedTypeAt(module, node, position);
if (!typeAtPosition)
return TypeCorrectKind::None;

if (!typeAtPosition)
return TypeCorrectKind::None;
TypeId expectedType = follow(*typeAtPosition);

expectedType = follow(*typeAtPosition);
}
else
if (FFlag::PreferToCallFunctionsForIntersects)
{
auto expr = node->asExpr();
if (!expr)
return TypeCorrectKind::None;
auto checkFunctionType = [&canUnify, &expectedType](const FunctionTypeVar* ftv) {
auto [retHead, retTail] = flatten(ftv->retType);

auto it = module.astExpectedTypes.find(expr);
if (!it)
return TypeCorrectKind::None;
if (!retHead.empty() && canUnify(retHead.front(), expectedType))
return true;

expectedType = follow(*it);
}
// We might only have a variadic tail pack, check if the element is compatible
if (retTail)
{
if (const VariadicTypePack* vtp = get<VariadicTypePack>(follow(*retTail)); vtp && canUnify(vtp->ty, expectedType))
return true;
}

// We also want to suggest functions that return compatible result
if (const FunctionTypeVar* ftv = get<FunctionTypeVar>(ty))
{
auto [retHead, retTail] = flatten(ftv->retType);
return false;
};

if (!retHead.empty() && canUnify(retHead.front(), expectedType))
// We also want to suggest functions that return compatible result
if (const FunctionTypeVar* ftv = get<FunctionTypeVar>(ty); ftv && checkFunctionType(ftv))
{
return TypeCorrectKind::CorrectFunctionResult;

// We might only have a variadic tail pack, check if the element is compatible
if (retTail)
}
else if (const IntersectionTypeVar* itv = get<IntersectionTypeVar>(ty))
{
for (TypeId id : itv->parts)
{
if (const FunctionTypeVar* ftv = get<FunctionTypeVar>(id); ftv && checkFunctionType(ftv))
{
return TypeCorrectKind::CorrectFunctionResult;
}
}
}
}
else
{
// We also want to suggest functions that return compatible result
if (const FunctionTypeVar* ftv = get<FunctionTypeVar>(ty))
{
if (const VariadicTypePack* vtp = get<VariadicTypePack>(follow(*retTail)); vtp && canUnify(vtp->ty, expectedType))
auto [retHead, retTail] = flatten(ftv->retType);

if (!retHead.empty() && canUnify(retHead.front(), expectedType))
return TypeCorrectKind::CorrectFunctionResult;

// We might only have a variadic tail pack, check if the element is compatible
if (retTail)
{
if (const VariadicTypePack* vtp = get<VariadicTypePack>(follow(*retTail)); vtp && canUnify(vtp->ty, expectedType))
return TypeCorrectKind::CorrectFunctionResult;
}
}
}

Expand Down Expand Up @@ -741,29 +759,12 @@ std::optional<const T*> returnFirstNonnullOptionOfType(const UnionTypeVar* utv)

static std::optional<bool> functionIsExpectedAt(const Module& module, AstNode* node, Position position)
{
TypeId expectedType;
auto typeAtPosition = findExpectedTypeAt(module, node, position);

if (FFlag::LuauAutocompleteFirstArg)
{
auto typeAtPosition = findExpectedTypeAt(module, node, position);

if (!typeAtPosition)
return std::nullopt;

expectedType = follow(*typeAtPosition);
}
else
{
auto expr = node->asExpr();
if (!expr)
return std::nullopt;

auto it = module.astExpectedTypes.find(expr);
if (!it)
return std::nullopt;
if (!typeAtPosition)
return std::nullopt;

expectedType = follow(*it);
}
TypeId expectedType = follow(*typeAtPosition);

if (get<FunctionTypeVar>(expectedType))
return true;
Expand Down
7 changes: 2 additions & 5 deletions Analysis/src/Frontend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
LUAU_FASTFLAG(LuauInferInNoCheckMode)
LUAU_FASTFLAGVARIABLE(LuauTypeCheckTwice, false)
LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false)
LUAU_FASTFLAGVARIABLE(LuauPersistDefinitionFileTypes, false)

namespace Luau
{
Expand Down Expand Up @@ -102,8 +101,7 @@ LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, ScopePtr t
generateDocumentationSymbols(globalTy, documentationSymbol);
targetScope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol};

if (FFlag::LuauPersistDefinitionFileTypes)
persist(globalTy);
persist(globalTy);
}

for (const auto& [name, ty] : checkedModule->getModuleScope()->exportedTypeBindings)
Expand All @@ -113,8 +111,7 @@ LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, ScopePtr t
generateDocumentationSymbols(globalTy.type, documentationSymbol);
targetScope->exportedTypeBindings[name] = globalTy;

if (FFlag::LuauPersistDefinitionFileTypes)
persist(globalTy.type);
persist(globalTy.type);
}

return LoadDefinitionFileResult{true, parseResult, checkedModule};
Expand Down
28 changes: 22 additions & 6 deletions Analysis/src/Module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ LUAU_FASTFLAGVARIABLE(DebugLuauTrackOwningArena, false)
LUAU_FASTINTVARIABLE(LuauTypeCloneRecursionLimit, 300)
LUAU_FASTFLAG(LuauTypeAliasDefaults)

LUAU_FASTFLAGVARIABLE(LuauPrepopulateUnionOptionsBeforeAllocation, false)

namespace Luau
{

Expand Down Expand Up @@ -377,14 +379,28 @@ void TypeCloner::operator()(const AnyTypeVar& t)

void TypeCloner::operator()(const UnionTypeVar& t)
{
TypeId result = dest.addType(UnionTypeVar{});
seenTypes[typeId] = result;
if (FFlag::LuauPrepopulateUnionOptionsBeforeAllocation)
{
std::vector<TypeId> options;
options.reserve(t.options.size());

UnionTypeVar* option = getMutable<UnionTypeVar>(result);
LUAU_ASSERT(option != nullptr);
for (TypeId ty : t.options)
options.push_back(clone(ty, dest, seenTypes, seenTypePacks, cloneState));

TypeId result = dest.addType(UnionTypeVar{std::move(options)});
seenTypes[typeId] = result;
}
else
{
TypeId result = dest.addType(UnionTypeVar{});
seenTypes[typeId] = result;

for (TypeId ty : t.options)
option->options.push_back(clone(ty, dest, seenTypes, seenTypePacks, cloneState));
UnionTypeVar* option = getMutable<UnionTypeVar>(result);
LUAU_ASSERT(option != nullptr);

for (TypeId ty : t.options)
option->options.push_back(clone(ty, dest, seenTypes, seenTypePacks, cloneState));
}
}

void TypeCloner::operator()(const IntersectionTypeVar& t)
Expand Down
11 changes: 3 additions & 8 deletions Analysis/src/ToString.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
#include <algorithm>
#include <stdexcept>

LUAU_FASTFLAG(LuauOccursCheckOkWithRecursiveFunctions)
LUAU_FASTFLAG(LuauTypeAliasDefaults)

/*
Expand Down Expand Up @@ -374,7 +373,7 @@ struct TypeVarStringifier

void operator()(TypeId, const SingletonTypeVar& stv)
{
if (const BoolSingleton* bs = Luau::get<BoolSingleton>(&stv))
if (const BooleanSingleton* bs = Luau::get<BooleanSingleton>(&stv))
state.emit(bs->value ? "true" : "false");
else if (const StringSingleton* ss = Luau::get<StringSingleton>(&stv))
{
Expand Down Expand Up @@ -617,9 +616,7 @@ struct TypeVarStringifier

std::string saved = std::move(state.result.name);

bool needParens = FFlag::LuauOccursCheckOkWithRecursiveFunctions
? !state.cycleNames.count(el) && (get<IntersectionTypeVar>(el) || get<FunctionTypeVar>(el))
: get<IntersectionTypeVar>(el) || get<FunctionTypeVar>(el);
bool needParens = !state.cycleNames.count(el) && (get<IntersectionTypeVar>(el) || get<FunctionTypeVar>(el));

if (needParens)
state.emit("(");
Expand Down Expand Up @@ -675,9 +672,7 @@ struct TypeVarStringifier

std::string saved = std::move(state.result.name);

bool needParens = FFlag::LuauOccursCheckOkWithRecursiveFunctions
? !state.cycleNames.count(el) && (get<UnionTypeVar>(el) || get<FunctionTypeVar>(el))
: get<UnionTypeVar>(el) || get<FunctionTypeVar>(el);
bool needParens = !state.cycleNames.count(el) && (get<UnionTypeVar>(el) || get<FunctionTypeVar>(el));

if (needParens)
state.emit("(");
Expand Down
2 changes: 1 addition & 1 deletion Analysis/src/TypeAttach.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ class TypeRehydrationVisitor

AstType* operator()(const SingletonTypeVar& stv)
{
if (const BoolSingleton* bs = get<BoolSingleton>(&stv))
if (const BooleanSingleton* bs = get<BooleanSingleton>(&stv))
return allocator->alloc<AstTypeSingletonBool>(Location(), bs->value);
else if (const StringSingleton* ss = get<StringSingleton>(&stv))
{
Expand Down
Loading

0 comments on commit 2f989fc

Please sign in to comment.