Skip to content

Commit

Permalink
Move unification logic out of InferenceTable.
Browse files Browse the repository at this point in the history
Instead, just have the table remember the type annotations that have been seen on nodes associated with a given type variable. We can then unify those type annotations at the time of conversion of the table to `TypeInfo`.

This change also gets parametric signedness working.

PiperOrigin-RevId: 700611106
  • Loading branch information
richmckeever authored and copybara-github committed Nov 27, 2024
1 parent 62e2a7e commit 72da07d
Show file tree
Hide file tree
Showing 7 changed files with 422 additions and 119 deletions.
19 changes: 18 additions & 1 deletion xls/dslx/type_system_v2/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -47,20 +47,22 @@ cc_library(
hdrs = ["inference_table_to_type_info.h"],
deps = [
":inference_table",
":type_annotation_utils",
"//xls/common/status:status_macros",
"//xls/dslx:constexpr_evaluator",
"//xls/dslx:errors",
"//xls/dslx:import_data",
"//xls/dslx:interp_value",
"//xls/dslx:warning_collector",
"//xls/dslx/frontend:ast",
"//xls/dslx/frontend:module",
"//xls/dslx/frontend:pos",
"//xls/dslx/type_system:deduce_utils",
"//xls/dslx/type_system:parametric_env",
"//xls/dslx/type_system:type",
"//xls/dslx/type_system:type_info",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
Expand Down Expand Up @@ -99,3 +101,18 @@ cc_test(
"@com_google_googletest//:gtest",
],
)

cc_library(
name = "type_annotation_utils",
srcs = ["type_annotation_utils.cc"],
hdrs = ["type_annotation_utils.h"],
deps = [
"//xls/common/status:status_macros",
"//xls/dslx/frontend:ast",
"//xls/dslx/frontend:module",
"//xls/dslx/frontend:pos",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
],
)
113 changes: 56 additions & 57 deletions xls/dslx/type_system_v2/inference_table.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,32 +89,41 @@ absl::StatusOr<InferenceVariableKind> TypeAnnotationToInferenceVariableKind(
// Represents the immutable metadata for a variable in an `InferenceTable`.
class InferenceVariable {
public:
InferenceVariable(const AstNode* definer, std::string_view name,
InferenceVariable(const AstNode* definer, const NameRef* name_ref,
InferenceVariableKind kind, bool parametric)
: definer_(definer), name_(name), kind_(kind), parametric_(parametric) {}
: definer_(definer),
name_ref_(name_ref),
kind_(kind),
parametric_(parametric) {}

const AstNode* definer() const { return definer_; }

std::string_view name() const { return name_; }
std::string_view name() const { return name_ref_->identifier(); }

InferenceVariableKind kind() const { return kind_; }

bool parametric() const { return parametric_; }

// Returns the `NameRef` dealt out for this variable at creation time. This
// is used to avoid spurious creations of additional refs by table functions
// that look up variables. However, any `NameRef` to the variable's `NameDef`
// is equally usable.
const NameRef* name_ref() const { return name_ref_; }

template <typename H>
friend H AbslHashValue(H h, const InferenceVariable& v) {
return H::combine(std::move(h), v.definer_, v.name_, v.kind_);
return H::combine(std::move(h), v.definer_, v.name(), v.kind_);
}

std::string ToString() const {
return absl::Substitute("InferenceVariable(name=$0, kind=$1, definer=$2)",
name_, InferenceVariableKindToString(kind_),
name(), InferenceVariableKindToString(kind_),
definer_->ToString());
}

private:
const AstNode* const definer_;
const std::string name_;
const NameRef* const name_ref_;
const InferenceVariableKind kind_;
const bool parametric_;
};
Expand Down Expand Up @@ -177,28 +186,32 @@ class InferenceTableImpl : public InferenceTable {
InferenceTableImpl(Module& module, const FileTable& file_table)
: module_(module), file_table_(file_table) {}

absl::StatusOr<NameRef*> DefineInternalVariable(
absl::StatusOr<const NameRef*> DefineInternalVariable(
InferenceVariableKind kind, AstNode* definer,
std::string_view name) override {
CHECK(definer->GetSpan().has_value());
Span span = *definer->GetSpan();
NameDef* name_def = module_.Make<NameDef>(span, std::string(name), definer);
const NameDef* name_def =
module_.Make<NameDef>(span, std::string(name), definer);
const NameRef* name_ref =
module_.Make<NameRef>(span, std::string(name), name_def);
AddVariable(name_def, std::make_unique<InferenceVariable>(
definer, name, kind, /*parametric=*/false));
return module_.Make<NameRef>(span, std::string(name), name_def);
definer, name_ref, kind, /*parametric=*/false));
return name_ref;
}

absl::StatusOr<NameRef*> DefineParametricVariable(
absl::StatusOr<const NameRef*> DefineParametricVariable(
const ParametricBinding& binding) override {
XLS_ASSIGN_OR_RETURN(
InferenceVariableKind kind,
TypeAnnotationToInferenceVariableKind(binding.type_annotation()));
const NameDef* name_def = binding.name_def();
const NameRef* name_ref = module_.Make<NameRef>(
name_def->span(), name_def->identifier(), name_def);
AddVariable(name_def, std::make_unique<InferenceVariable>(
name_def, name_def->identifier(), kind,
/*parametric=*/true));
return module_.Make<NameRef>(name_def->span(), name_def->identifier(),
name_def);
name_def, name_ref, kind, /*parametric=*/true));
XLS_RETURN_IF_ERROR(SetTypeAnnotation(name_def, binding.type_annotation()));
return name_ref;
}

absl::StatusOr<const ParametricInvocation*> AddParametricInvocation(
Expand Down Expand Up @@ -323,17 +336,34 @@ class InferenceTableImpl : public InferenceTable {
return it->second.type_annotation;
}

std::optional<const NameRef*> GetTypeVariable(
const AstNode* node) const override {
const auto it = node_data_.find(node);
if (it == node_data_.end()) {
return std::nullopt;
}
const std::optional<const InferenceVariable*>& variable =
it->second.type_variable;
return variable.has_value() ? std::make_optional((*variable)->name_ref())
: std::nullopt;
}

absl::StatusOr<std::vector<const TypeAnnotation*>>
GetTypeAnnotationsForTypeVariable(const NameRef* ref) const override {
XLS_ASSIGN_OR_RETURN(const InferenceVariable* variable, GetVariable(ref));
const auto it = type_annotations_per_type_variable_.find(variable);
return it == type_annotations_per_type_variable_.end()
? std::vector<const TypeAnnotation*>()
: it->second;
}

private:
void AddVariable(const NameDef* name_def,
std::unique_ptr<InferenceVariable> variable) {
if (variable->kind() == InferenceVariableKind::kType) {
type_constraints_.emplace(variable.get(),
std::make_unique<TypeConstraints>());
}
variables_.emplace(name_def, std::move(variable));
}

absl::StatusOr<InferenceVariable*> GetVariable(const NameRef* ref) {
absl::StatusOr<InferenceVariable*> GetVariable(const NameRef* ref) const {
if (std::holds_alternative<const NameDef*>(ref->name_def())) {
const auto it =
variables_.find(std::get<const NameDef*>(ref->name_def()));
Expand Down Expand Up @@ -362,8 +392,8 @@ class InferenceTableImpl : public InferenceTable {
// Refine and check the associated type variable.
if (node_data.type_variable.has_value() &&
node_data.type_annotation.has_value()) {
XLS_RETURN_IF_ERROR(RefineAndCheckTypeVariable(
*node_data.type_variable, *node_data.type_annotation));
type_annotations_per_type_variable_[*node_data.type_variable].push_back(
*node_data.type_annotation);
}
// Update the dependencies of the node.
if (node_data.type_variable.has_value()) {
Expand All @@ -380,37 +410,6 @@ class InferenceTableImpl : public InferenceTable {
return absl::OkStatus();
}

// Refines what is known about the given `variable` (which is assumed to be a
// type-kind variable) based on the given `annotation` that it must satisfy,
// and errors if there is a conflict with existing information.
absl::Status RefineAndCheckTypeVariable(const InferenceVariable* variable,
const TypeAnnotation* annotation) {
const auto* builtin_annotation =
dynamic_cast<const BuiltinTypeAnnotation*>(annotation);
if (builtin_annotation == nullptr) {
return absl::InvalidArgumentError(
absl::StrCat("Type inference version 2 does not yet support refining "
"and updating a variable with type annotation: ",
annotation->ToString()));
}
TypeConstraints& constraints = *type_constraints_[variable];
if (!constraints.min_width.has_value() ||
builtin_annotation->GetBitCount() > *constraints.min_width) {
constraints.min_width = builtin_annotation->GetBitCount();
}
XLS_ASSIGN_OR_RETURN(const bool annotation_is_signed,
builtin_annotation->GetSignedness());
if (constraints.is_signed.has_value() &&
annotation_is_signed != *constraints.is_signed) {
return SignednessMismatchErrorStatus(
annotation, *constraints.signedness_definer, file_table_);
} else if (!constraints.is_signed.has_value()) {
constraints.is_signed = annotation_is_signed;
constraints.signedness_definer = annotation;
}
return absl::OkStatus();
}

void AddDependency(const AstNode* node, const InferenceVariable* variable) {
variable_dependents_[variable].insert(node);
if (variable->parametric()) {
Expand Down Expand Up @@ -444,11 +443,11 @@ class InferenceTableImpl : public InferenceTable {
// internally.
absl::flat_hash_map<const NameDef*, std::unique_ptr<InferenceVariable>>
variables_;
// The constraints that have been determined for `variables_` that are
// of `kType` kind.
// The type annotations that have been associated with each inference
// variable of type-kind.
absl::flat_hash_map<const InferenceVariable*,
std::unique_ptr<TypeConstraints>>
type_constraints_;
std::vector<const TypeAnnotation*>>
type_annotations_per_type_variable_;
// The `AstNode` objects that have associated data.
absl::flat_hash_map<const AstNode*, NodeData> node_data_;
// Which `AstNode` objects depend on which variables.
Expand Down
13 changes: 11 additions & 2 deletions xls/dslx/type_system_v2/inference_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ class InferenceTable {
// which has no direct representation in the DSLX source code that is being
// analyzed. It is up to the inference system using the table to decide a
// naming scheme for such variables.
virtual absl::StatusOr<NameRef*> DefineInternalVariable(
virtual absl::StatusOr<const NameRef*> DefineInternalVariable(
InferenceVariableKind kind, AstNode* definer, std::string_view name) = 0;

// Defines an inference variable corresponding to a parametric in the DSLX
Expand All @@ -161,7 +161,7 @@ class InferenceTable {
//
// At the time of conversion of the table to `TypeInfo`, we distinctly resolve
// `N` and its dependent types for each invocation context of `foo`.
virtual absl::StatusOr<NameRef*> DefineParametricVariable(
virtual absl::StatusOr<const NameRef*> DefineParametricVariable(
const ParametricBinding& binding) = 0;

// Defines an invocation context for a parametric function, giving its
Expand Down Expand Up @@ -212,6 +212,15 @@ class InferenceTable {
// Returns the type annotation for `node` in the table, if any.
virtual std::optional<const TypeAnnotation*> GetTypeAnnotation(
const AstNode* node) const = 0;

// Returns the type variable for `node` in the table, if any.
virtual std::optional<const NameRef*> GetTypeVariable(
const AstNode* node) const = 0;

// Returns all type annotations that have been associated with the given
// variable, in the order they were added to the table.
virtual absl::StatusOr<std::vector<const TypeAnnotation*>>
GetTypeAnnotationsForTypeVariable(const NameRef* variable) const = 0;
};

} // namespace xls::dslx
Expand Down
86 changes: 64 additions & 22 deletions xls/dslx/type_system_v2/inference_table_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ TEST_F(InferenceTableTest, SetTypeVariableToNonType) {
NameDef* x = module_->Make<NameDef>(Span::Fake(), "x", /*definer=*/nullptr);
NameDef* n = module_->Make<NameDef>(Span::Fake(), "N", /*definer=*/nullptr);
XLS_ASSERT_OK_AND_ASSIGN(
NameRef * n_var,
const NameRef* n_var,
table_->DefineInternalVariable(InferenceVariableKind::kInteger, n, "N"));
EXPECT_THAT(table_->SetTypeVariable(x, n_var),
StatusIs(absl::StatusCode::kInvalidArgument));
Expand All @@ -169,15 +169,15 @@ TEST_F(InferenceTableTest, SignednessMismatch) {
Span::Fake(), BuiltinType::kS32,
module_->GetOrCreateBuiltinNameDef("s32"));

XLS_ASSERT_OK_AND_ASSIGN(
NameRef * t0, table_->DefineInternalVariable(InferenceVariableKind::kType,
add_node, "T0"));
XLS_ASSERT_OK_AND_ASSIGN(const NameRef* t0,
table_->DefineInternalVariable(
InferenceVariableKind::kType, add_node, "T0"));
XLS_EXPECT_OK(table_->SetTypeVariable(x_ref, t0));
XLS_EXPECT_OK(table_->SetTypeVariable(y_ref, t0));
XLS_EXPECT_OK(table_->SetTypeAnnotation(x_ref, u32_annotation));

XLS_EXPECT_OK(table_->SetTypeAnnotation(y_ref, s32_annotation));
EXPECT_THAT(
table_->SetTypeAnnotation(y_ref, s32_annotation),
ConvertTableToTypeInfo(),
StatusIs(absl::StatusCode::kInvalidArgument,
ContainsRegex("signed vs. unsigned mismatch.*s32.*vs. u32")));
}
Expand All @@ -197,19 +197,17 @@ TEST_F(InferenceTableTest, SignednessAgreement) {
Span::Fake(), BuiltinType::kU32,
module_->GetOrCreateBuiltinNameDef("u32"));

XLS_ASSERT_OK_AND_ASSIGN(
NameRef * t0, table_->DefineInternalVariable(InferenceVariableKind::kType,
add_node, "T0"));
XLS_ASSERT_OK_AND_ASSIGN(const NameRef* t0,
table_->DefineInternalVariable(
InferenceVariableKind::kType, add_node, "T0"));
XLS_EXPECT_OK(table_->SetTypeVariable(x_ref, t0));
XLS_EXPECT_OK(table_->SetTypeVariable(y_ref, t0));
XLS_EXPECT_OK(table_->SetTypeAnnotation(x_ref, u32_annotation));
XLS_EXPECT_OK(table_->SetTypeAnnotation(y_ref, u32_annotation));
XLS_ASSERT_OK_AND_ASSIGN(std::string type_info_string,
ConvertTableToTypeInfoString());
EXPECT_EQ(type_info_string, R"(
span: <no-file>:1:1-1:1, node: `x`, type: uN[32]
span: <no-file>:1:1-1:1, node: `y`, type: uN[32]
)");
EXPECT_THAT(type_info_string, AllOf(HasSubstr("node: `x`, type: uN[32]"),
HasSubstr("node: `y`, type: uN[32]")));
}

TEST_F(InferenceTableTest, ParametricVariable) {
Expand Down Expand Up @@ -255,16 +253,60 @@ TEST_F(InferenceTableTest, ParametricVariable) {
TypeInfoToString(**invocation_ti1));
XLS_ASSERT_OK_AND_ASSIGN(std::string invocation_ti2_string,
TypeInfoToString(**invocation_ti2));
EXPECT_EQ(invocation_ti1_string, R"(
span: fake.x:2:20-2:28, node: `a: uN[N]`, type: uN[4]
span: fake.x:2:26-2:27, node: `N`, type: uN[32]
span: fake.x:2:26-2:27, node: `u32`, type: typeof(uN[32])
)");
EXPECT_EQ(invocation_ti2_string, R"(
span: fake.x:2:20-2:28, node: `a: uN[N]`, type: uN[5]
span: fake.x:2:26-2:27, node: `N`, type: uN[32]
span: fake.x:2:26-2:27, node: `u32`, type: typeof(uN[32])
EXPECT_THAT(invocation_ti1_string,
HasSubstr("node: `a: uN[N]`, type: uN[4]"));
EXPECT_THAT(invocation_ti2_string,
HasSubstr("node: `a: uN[N]`, type: uN[5]"));
}

TEST_F(InferenceTableTest, ParametricVariableForSignedness) {
ParseAndInitModuleAndTable(R"(
fn foo<S: bool, N: u32>(a: xN[S][N]) -> xN[S][N] { a }
fn bar() {
foo<true, u32:4>(u4:1);
foo<false, u32:5>(u5:3);
}
)");

XLS_ASSERT_OK_AND_ASSIGN(const Function* foo,
module_->GetMemberOrError<Function>("foo"));
ASSERT_EQ(foo->parametric_bindings().size(), 2);
ASSERT_EQ(foo->params().size(), 1);
for (const ParametricBinding* binding : foo->parametric_bindings()) {
XLS_ASSERT_OK(table_->DefineParametricVariable(*binding));
}
for (const Param* param : foo->params()) {
XLS_ASSERT_OK(table_->SetTypeAnnotation(param, param->type_annotation()));
}
XLS_ASSERT_OK_AND_ASSIGN(const Function* bar,
module_->GetMemberOrError<Function>("bar"));
ASSERT_EQ(bar->body()->statements().size(), 2);
const Invocation* invocation1 = down_cast<const Invocation*>(
ToAstNode(bar->body()->statements().at(0)->wrapped()));
const Invocation* invocation2 = down_cast<const Invocation*>(
ToAstNode(bar->body()->statements().at(1)->wrapped()));
XLS_ASSERT_OK(
table_->AddParametricInvocation(*invocation1, *foo, *bar,
/*caller_invocation=*/std::nullopt));
XLS_ASSERT_OK(
table_->AddParametricInvocation(*invocation2, *foo, *bar,
/*caller_invocation=*/std::nullopt));

XLS_ASSERT_OK_AND_ASSIGN(TypeInfo * ti, ConvertTableToTypeInfo());
std::optional<TypeInfo*> invocation_ti1 =
ti->GetInvocationTypeInfo(invocation1, ParametricEnv());
std::optional<TypeInfo*> invocation_ti2 =
ti->GetInvocationTypeInfo(invocation2, ParametricEnv());
EXPECT_TRUE(invocation_ti1.has_value());
EXPECT_TRUE(invocation_ti2.has_value());
XLS_ASSERT_OK_AND_ASSIGN(std::string invocation_ti1_string,
TypeInfoToString(**invocation_ti1));
XLS_ASSERT_OK_AND_ASSIGN(std::string invocation_ti2_string,
TypeInfoToString(**invocation_ti2));
EXPECT_THAT(invocation_ti1_string,
HasSubstr("node: `a: xN[S][N]`, type: sN[4]"));
EXPECT_THAT(invocation_ti2_string,
HasSubstr("node: `a: xN[S][N]`, type: uN[5]"));
}

TEST_F(InferenceTableTest, ParametricVariableWithDefault) {
Expand Down
Loading

0 comments on commit 72da07d

Please sign in to comment.