Skip to content

Commit

Permalink
BaseClass wraps Expr: fixes invariant violation
Browse files Browse the repository at this point in the history
Summary:
This commit solves a bug where Pyre was unable to correctly handle
generic classes that referred to themselves in their own base classes,
because it triggered recursive lookups of the type parameters that
had inconsistent lengths.

This happens because we have to analyze the base classes in order to *get* the
type parameters, and we were doing type-level analysis. The fix is to
modify `BaseClass` to be expression-level, so that we can use an AST visitor
rather than a type visitor. In doing so, we can avoid ever triggering an
`expr_untype` call on the current class, which bypasses the cause of the
invariant violation.

---------

This fix is a bit messy: we're now putting `Expr` into the
`Answer` of a `Keyed` instance, which is awkward because of extra clones,
because we have to define "recursive promotion" even though there's no
good recursive behavior anymore, and because we need `Display` which can't
really be implemented well on `Expr`.

I want to punt on all these issues for this diff, because with this approach
in place it becomes much easier to just roll the base classes into `Class`,
which will make all of these things more tractable and also make it much
easier for us to truly enforce the invariants that keep causing panics.

So instead of spending time optimizing this setup I'd prefer to just move
forward with that plan, which will take some effort but should yield big
benefits.

Reviewed By: ndmitchell

Differential Revision: D66511478

fbshipit-source-id: a1e900f0ef6bb710b599d23725791ad2d051bf95
  • Loading branch information
stroxler authored and facebook-github-bot committed Nov 26, 2024
1 parent 5a26748 commit f3402bf
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 17 deletions.
6 changes: 5 additions & 1 deletion pyre2/pyre2/bin/alt/answers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,11 @@ impl Solve for KeyBaseClass {
fn recursive(_answers: &AnswersSolver) -> Self::Recursive {}

fn promote_recursive(_: Self::Recursive) -> Self::Answer {
BaseClass::Type(Type::any_implicit())
// TODO(stroxler): Putting a panic here is risky, but I am expecting to refactor
// within a few commits to make the base class handling internal to `classes.rs`
// and eliminate the binding, which will eliminate this kind of boilerplate
// altogether.
unreachable!("BaseClass cannot hit recursive cases without violating invariants");
}

fn visit_type_mut(v: &mut BaseClass, f: &mut dyn FnMut(&mut Type)) {
Expand Down
28 changes: 25 additions & 3 deletions pyre2/pyre2/bin/alt/classes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ use crate::types::types::Quantified;
use crate::types::types::QuantifiedVec;
use crate::types::types::Type;
use crate::util::prelude::SliceExt;
use crate::visitors::Visitors;

/// Class members can fail to be
pub enum NoClassAttribute {
Expand Down Expand Up @@ -93,7 +94,7 @@ impl<'a> AnswersSolver<'a> {
special_base_class
} else {
// This branch handles all other base classes.
BaseClass::Type(self.expr_untype(base_expr))
BaseClass::Expr(base_expr.clone())
}
}

Expand All @@ -104,8 +105,26 @@ impl<'a> AnswersSolver<'a> {
/// If the base class is a "normal" generic base (not `Protocol` or `Generic`), then
/// call `f` on each `Quantified` in left-to-right order.
fn for_each_quantified_if_not_special(&self, base: &BaseClass, f: &mut impl FnMut(Quantified)) {
fn for_each_quantified_in_expr(
x: &Expr,
answers_solver: &AnswersSolver,
f: &mut impl FnMut(Quantified),
) {
match x {
Expr::Name(_) => match answers_solver.expr(x, None) {
Type::Type(box Type::Quantified(q)) => f(q),
_ => {}
},
_ => {}
}
Visitors::visit_expr(x, &mut |x: &Expr| {
for_each_quantified_in_expr(x, answers_solver, f)
})
}
match base {
BaseClass::Type(t) => t.for_each_quantified(f),
BaseClass::Expr(base) => Visitors::visit_expr(base, &mut |x: &Expr| {
for_each_quantified_in_expr(x, self, f)
}),
_ => {}
}
}
Expand Down Expand Up @@ -172,7 +191,10 @@ impl<'a> AnswersSolver<'a> {
self.bases_of_class(class)
.iter()
.filter_map(|base| match base.deref() {
BaseClass::Type(Type::ClassType(c)) => Some(c.clone()),
BaseClass::Expr(x) => match self.expr_untype(x) {
Type::ClassType(c) => Some(c),
_ => None,
},
_ => None,
})
.collect()
Expand Down
9 changes: 5 additions & 4 deletions pyre2/pyre2/bin/test/legacy_generic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -333,14 +333,15 @@ def f2(c: C[Child, Parent]):
simple_test!(
test_generic_with_reference_to_self_in_base,
r#"
from typing import Generic, TypeVar, Any
from typing import Generic, TypeVar, Any, assert_type
T = TypeVar("T")
class C(list[C[T]]): # E: Expected 0 type arguments for class `C`, got 1.
class C(list[C[T]]):
t: T
def f(c: C[int]): # E: Expected 0 type arguments for class `C`, got 1
pass
def f(c: C[int]):
assert_type(c.t, int)
assert_type(c[0], C[int])
"#,
);
12 changes: 8 additions & 4 deletions pyre2/pyre2/bin/types/base_class.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
use std::fmt;
use std::fmt::Display;

use ruff_python_ast::Expr;
use ruff_text_size::TextRange;

use crate::error::collector::ErrorCollector;
Expand All @@ -25,7 +26,7 @@ pub enum BaseClass {
TypedDict,
Generic(Vec<Type>),
Protocol(Vec<Type>),
Type(Type),
Expr(Expr),
}

impl Display for BaseClass {
Expand All @@ -35,16 +36,19 @@ impl Display for BaseClass {
BaseClass::TypedDict => write!(f, "TypedDict"),
BaseClass::Generic(xs) => write!(f, "Generic[{}]", commas_iter(|| xs.iter())),
BaseClass::Protocol(xs) => write!(f, "Protocol[{}]", commas_iter(|| xs.iter())),
BaseClass::Type(t) => write!(f, "{t}"),
// TODO(stroxler): Do not use Debug here. Putting this off for now because I'm expecting
// to refactor in upcoming commits until this is an implementation detail of `classes.rs`,
// at which point we won't need Display at all anymore.
BaseClass::Expr(s) => write!(f, "Expr({s:?})"),
}
}
}

impl BaseClass {
pub fn visit_mut<'a>(&'a mut self, mut f: impl FnMut(&'a mut Type)) {
pub fn visit_mut<'a>(&'a mut self, f: impl FnMut(&'a mut Type)) {
match self {
BaseClass::Generic(xs) | BaseClass::Protocol(xs) => xs.iter_mut().for_each(f),
BaseClass::Type(t) => f(t),
BaseClass::Expr(_) => {}
BaseClass::NamedTuple | BaseClass::TypedDict => {}
}
}
Expand Down
5 changes: 0 additions & 5 deletions pyre2/pyre2/bin/types/class.rs
Original file line number Diff line number Diff line change
Expand Up @@ -226,17 +226,12 @@ impl ClassType {
if targs.len() != tparams.len() {
// Invariant violation: all type arguments should be constructed through
// `check_and_sanitize_targs_for_class`, which should guarantee zippability.
/* TODO(stroxler): until this is enabled, there are edge cases with undefined
behavior. Needed because we are still working out how to handle recursion in
base classes.
unreachable!(
"Encountered invalid type arguments of length {} in class `{}` (expected {})",
targs.len(),
self.name().id,
tparams.len(),
);
*/
return Substitution(SmallMap::new());
}
Substitution(
tparams
Expand Down

0 comments on commit f3402bf

Please sign in to comment.