Skip to content

Commit

Permalink
ADT recursion
Browse files Browse the repository at this point in the history
  • Loading branch information
Grant Wuerker committed Feb 21, 2024
1 parent 574c1f0 commit 08fe771
Show file tree
Hide file tree
Showing 11 changed files with 238 additions and 60 deletions.
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions crates/common2/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,5 @@ camino = "1.1.4"
smol_str = "0.1.24"
salsa = { git = "https://github.com/salsa-rs/salsa", package = "salsa-2022" }
parser = { path = "../parser2", package = "fe-parser2" }
rustc-hash = "1.1.0"
ena = "0.14"
1 change: 1 addition & 0 deletions crates/common2/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
pub mod diagnostics;
pub mod input;
pub mod recursive_def;

pub use input::{InputFile, InputIngot};

Expand Down
138 changes: 138 additions & 0 deletions crates/common2/src/recursive_def.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
use std::{fmt::Debug, hash::Hash};

use ena::unify::{InPlaceUnificationTable, UnifyKey};
use rustc_hash::FxHashMap;

/// Represents a definition that contains a direct reference to itself.
///
/// Recursive definitions are not valid and must be reported to the user.
/// It is preferable to group definitions together such that recursions
/// are reported in-whole rather than separately. `RecursiveDef` can be
/// used with `RecursiveDefHelper` to perform this grouping operation.
///
/// The fields `from` and `to` are the relevant identifiers and `site` can
/// be used to carry diagnostic information.
#[derive(Eq, PartialEq, Clone, Debug, Hash)]
pub struct RecursiveDef<T, U>
where
T: PartialEq + Copy,
{
pub from: T,
pub to: T,
pub site: U,
}

impl<T, U> RecursiveDef<T, U>
where
T: PartialEq + Copy,
{
pub fn new(from: T, to: T, site: U) -> Self {
Self { from, to, site }
}
}

#[derive(PartialEq, Debug, Clone, Copy)]
struct RecursiveDefKey(u32);

impl UnifyKey for RecursiveDefKey {
type Value = ();

fn index(&self) -> u32 {
self.0
}

fn from_index(idx: u32) -> Self {
Self(idx)
}

fn tag() -> &'static str {
"RecursiveDefKey"
}
}

pub struct RecursiveDefHelper<T, U>
where
T: Eq + Clone + Debug + Copy,
{
defs: Vec<RecursiveDef<T, U>>,
table: InPlaceUnificationTable<RecursiveDefKey>,
keys: FxHashMap<T, RecursiveDefKey>,
}

impl<T, U> RecursiveDefHelper<T, U>
where
T: Eq + Clone + Debug + Copy + Hash,
{
pub fn new(defs: Vec<RecursiveDef<T, U>>) -> Self {
let mut table = InPlaceUnificationTable::new();
let keys: FxHashMap<_, _> = defs
.iter()
.map(|def| (def.from, table.new_key(())))
.collect();

for def in defs.iter() {
table.union(keys[&def.from], keys[&def.to])
}

Self { defs, table, keys }
}

/// Removes a disjoint set of recursive definitions from the helper
/// and returns it, if one exists.
pub fn remove_disjoint_set(&mut self) -> Option<Vec<RecursiveDef<T, U>>> {
let mut disjoint_set = vec![];
let mut remaining_set = vec![];
let mut union_key: Option<&RecursiveDefKey> = None;

while let Some(def) = self.defs.pop() {
let cur_key = &self.keys[&def.from];

if union_key.is_none() || self.table.unioned(*union_key.unwrap(), *cur_key) {
disjoint_set.push(def)
} else {
remaining_set.push(def)
}

if union_key.is_none() {
union_key = Some(cur_key)
}
}

self.defs = remaining_set;

if union_key.is_some() {
Some(disjoint_set)
} else {
None
}
}
}

#[test]
fn one_recursion() {
let defs = vec![RecursiveDef::new(0, 1, ()), RecursiveDef::new(1, 0, ())];

let mut helper = RecursiveDefHelper::new(defs);
let disjoint_constituents = helper.remove_disjoint_set();
panic!("{:?}", disjoint_constituents)
// assert_eq!(disjoint_constituents[0].from.0, 0);
// assert_eq!(disjoint_constituents[1].from.0, 0);
}

#[test]
fn two_recursions() {
let defs = vec![
RecursiveDef::new(0, 1, ()),
RecursiveDef::new(1, 0, ()),
RecursiveDef::new(2, 3, ()),
RecursiveDef::new(3, 4, ()),
RecursiveDef::new(4, 2, ()),
];

let mut helper = RecursiveDefHelper::new(defs);
let disjoint_constituents1 = helper.remove_disjoint_set();
let disjoint_constituents2 = helper.remove_disjoint_set();
panic!("{:?}", disjoint_constituents1)
// assert_eq!(disjoint_constituents[0].from.0, 0);
// assert_eq!(disjoint_constituents[1].from.0, 0);
}
1 change: 1 addition & 0 deletions crates/hir-analysis/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ pub struct Jar(
ty::diagnostics::ImplTraitDefDiagAccumulator,
ty::diagnostics::ImplDefDiagAccumulator,
ty::diagnostics::FuncDefDiagAccumulator,
ty::diagnostics::RecursiveAdtDefAccumulator,
);

pub trait HirAnalysisDb: salsa::DbWithJar<Jar> + HirDb {
Expand Down
29 changes: 18 additions & 11 deletions crates/hir-analysis/src/ty/def_analysis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@ use super::{
collect_impl_block_constraints, collect_super_traits, AssumptionListId, SuperTraitCycle,
},
constraint_solver::{is_goal_satisfiable, GoalSatisfiability},
diagnostics::{ImplDiag, TraitConstraintDiag, TraitLowerDiag, TyDiagCollection, TyLowerDiag},
diagnostics::{
ImplDiag, RecursiveAdtDef, TraitConstraintDiag, TraitLowerDiag, TyDiagCollection,
TyLowerDiag,
},
trait_def::{ingot_trait_env, Implementor, TraitDef, TraitMethod},
trait_lower::{lower_trait, lower_trait_ref, TraitRefLowerError},
ty_def::{AdtDef, AdtRef, AdtRefId, FuncDef, InvalidCause, TyData, TyId},
Expand All @@ -33,7 +36,8 @@ use crate::{
ty::{
diagnostics::{
AdtDefDiagAccumulator, FuncDefDiagAccumulator, ImplDefDiagAccumulator,
ImplTraitDefDiagAccumulator, TraitDefDiagAccumulator, TypeAliasDefDiagAccumulator,
ImplTraitDefDiagAccumulator, RecursiveAdtDefAccumulator, TraitDefDiagAccumulator,
TypeAliasDefDiagAccumulator,
},
method_table::collect_methods,
trait_lower::lower_impl_trait,
Expand Down Expand Up @@ -62,8 +66,8 @@ pub fn analyze_adt(db: &dyn HirAnalysisDb, adt_ref: AdtRefId) {
AdtDefDiagAccumulator::push(db, diag);
}

if let Some(diag) = check_recursive_adt(db, adt_ref) {
AdtDefDiagAccumulator::push(db, diag);
if let Some(def) = check_recursive_adt(db, adt_ref) {
RecursiveAdtDefAccumulator::push(db, def);
}
}

Expand Down Expand Up @@ -764,7 +768,7 @@ impl<'db> Visitor for DefAnalyzer<'db> {
pub(crate) fn check_recursive_adt(
db: &dyn HirAnalysisDb,
adt: AdtRefId,
) -> Option<TyDiagCollection> {
) -> Option<RecursiveAdtDef> {
let adt_def = lower_adt(db, adt);
for field in adt_def.fields(db) {
for ty in field.iter_types(db) {
Expand All @@ -781,7 +785,7 @@ fn check_recursive_adt_impl(
db: &dyn HirAnalysisDb,
cycle: &salsa::Cycle,
adt: AdtRefId,
) -> Option<TyDiagCollection> {
) -> Option<RecursiveAdtDef> {
let participants: FxHashSet<_> = cycle
.participant_keys()
.map(|key| check_recursive_adt::key_from_id(key.key_index()))
Expand All @@ -792,11 +796,14 @@ fn check_recursive_adt_impl(
for (ty_idx, ty) in field.iter_types(db).enumerate() {
for field_adt_ref in ty.collect_direct_adts(db) {
if participants.contains(&field_adt_ref) && participants.contains(&adt) {
let diag = TyLowerDiag::recursive_type(
adt.name_span(db),
adt_def.variant_ty_span(db, field_idx, ty_idx),
);
return Some(diag.into());
return Some(RecursiveAdtDef::new(
adt,
field_adt_ref,
(
adt.name_span(db),
adt_def.variant_ty_span(db, field_idx, ty_idx),
),
));
}
}
}
Expand Down
57 changes: 30 additions & 27 deletions crates/hir-analysis/src/ty/diagnostics.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use std::collections::BTreeSet;

use common::diagnostics::{
CompleteDiagnostic, DiagnosticPass, GlobalErrorCode, LabelStyle, Severity, SubDiagnostic,
use common::{
diagnostics::{
CompleteDiagnostic, DiagnosticPass, GlobalErrorCode, LabelStyle, Severity, SubDiagnostic,
},
recursive_def::RecursiveDef,
};
use hir::{
diagnostics::DiagnosticVoucher,
Expand All @@ -11,11 +14,12 @@ use hir::{
};
use itertools::Itertools;

use crate::HirAnalysisDb;

use super::{
constraint::PredicateId,
ty_def::{Kind, TyId},
ty_def::{AdtRefId, Kind, TyId},
};
use crate::HirAnalysisDb;

#[salsa::accumulator]
pub struct AdtDefDiagAccumulator(pub(super) TyDiagCollection);
Expand All @@ -29,6 +33,10 @@ pub struct ImplDefDiagAccumulator(pub(super) TyDiagCollection);
pub struct FuncDefDiagAccumulator(pub(super) TyDiagCollection);
#[salsa::accumulator]
pub struct TypeAliasDefDiagAccumulator(pub(super) TyDiagCollection);
#[salsa::accumulator]
pub struct RecursiveAdtDefAccumulator(pub(super) RecursiveAdtDef);

pub type RecursiveAdtDef = RecursiveDef<AdtRefId, (DynLazySpan, DynLazySpan)>;

#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::From)]
pub enum TyDiagCollection {
Expand All @@ -53,10 +61,7 @@ impl TyDiagCollection {
pub enum TyLowerDiag {
ExpectedStarKind(DynLazySpan),
InvalidTypeArgKind(DynLazySpan, String),
RecursiveType {
primary_span: DynLazySpan,
field_span: DynLazySpan,
},
AdtRecursion(Vec<RecursiveAdtDef>),

UnboundTypeAliasParam {
span: DynLazySpan,
Expand Down Expand Up @@ -140,11 +145,8 @@ impl TyLowerDiag {
Self::InvalidTypeArgKind(span, msg)
}

pub(super) fn recursive_type(primary_span: DynLazySpan, field_span: DynLazySpan) -> Self {
Self::RecursiveType {
primary_span,
field_span,
}
pub(super) fn adt_recursion(defs: Vec<RecursiveAdtDef>) -> Self {
Self::AdtRecursion(defs)
}

pub(super) fn unbound_type_alias_param(
Expand Down Expand Up @@ -249,7 +251,7 @@ impl TyLowerDiag {
match self {
Self::ExpectedStarKind(_) => 0,
Self::InvalidTypeArgKind(_, _) => 1,
Self::RecursiveType { .. } => 2,
Self::AdtRecursion { .. } => 2,
Self::UnboundTypeAliasParam { .. } => 3,
Self::TypeAliasCycle { .. } => 4,
Self::InconsistentKindBound(_, _) => 5,
Expand All @@ -270,7 +272,7 @@ impl TyLowerDiag {
match self {
Self::ExpectedStarKind(_) => "expected `*` kind in this context".to_string(),
Self::InvalidTypeArgKind(_, _) => "invalid type argument kind".to_string(),
Self::RecursiveType { .. } => "recursive type is not allowed".to_string(),
Self::AdtRecursion { .. } => "recursive type is not allowed".to_string(),

Self::UnboundTypeAliasParam { .. } => {
"all type parameters of type alias must be given".to_string()
Expand Down Expand Up @@ -326,22 +328,23 @@ impl TyLowerDiag {
span.resolve(db),
)],

Self::RecursiveType {
primary_span,
field_span,
} => {
vec![
SubDiagnostic::new(
Self::AdtRecursion(defs) => {
let mut diags = vec![];

for RecursiveAdtDef { site, .. } in defs {
diags.push(SubDiagnostic::new(
LabelStyle::Primary,
"recursive type definition".to_string(),
primary_span.resolve(db),
),
SubDiagnostic::new(
site.0.resolve(db),
));
diags.push(SubDiagnostic::new(
LabelStyle::Secondary,
"recursion occurs here".to_string(),
field_span.resolve(db),
),
]
site.1.resolve(db),
));
}

diags
}

Self::UnboundTypeAliasParam {
Expand Down
Loading

0 comments on commit 08fe771

Please sign in to comment.