Skip to content

Commit

Permalink
Removed the concrete trait from generated closure function key.
Browse files Browse the repository at this point in the history
Fixed #7095
  • Loading branch information
orizi committed Jan 18, 2025
1 parent 85f06b9 commit e28ef5f
Show file tree
Hide file tree
Showing 7 changed files with 47 additions and 38 deletions.
4 changes: 2 additions & 2 deletions crates/cairo-lang-lowering/src/borrow_check/test_data/closure
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ Statements:
(v2: {[email protected]:3:13: 3:16}, v3: @{[email protected]:3:13: 3:16}) <- snapshot(v1{`c`})
(v4: core::felt252) <- 2
(v5: (core::felt252,)) <- struct_construct(v4{`2`})
(v6: core::felt252) <- Generated core::ops::function::Fn::<{[email protected]:3:13: 3:16}, (core::felt252,)>::call(v3{`c`}, v5{`c(2)`})
(v6: core::felt252) <- Generated `core::ops::function::Fn::call` for {[email protected]:3:13: 3:16}(v3{`c`}, v5{`c(2)`})
(v7: core::felt252) <- core::Felt252Add::add(v6{`y`}, v0{`x`})
End:
Return(v7)
Expand Down Expand Up @@ -318,7 +318,7 @@ Statements:
(v7: {[email protected]:3:13: 3:16}, v8: @{[email protected]:3:13: 3:16}) <- snapshot(v6{`c`})
(v9: core::integer::u32) <- 2
(v10: (core::integer::u32,)) <- struct_construct(v9{`2`})
(v11: core::integer::u32) <- Generated core::ops::function::Fn::<{[email protected]:3:13: 3:16}, (core::integer::u32,)>::call(v8{`c`}, v10{`c(2)`})
(v11: core::integer::u32) <- Generated `core::ops::function::Fn::call` for {[email protected]:3:13: 3:16}(v8{`c`}, v10{`c(2)`})
(v12: core::array::Array::<core::felt252>, v13: @core::array::Array::<core::felt252>) <- snapshot(v4{`x`})
(v14: core::integer::u32) <- 0
(v15: @core::felt252) <- core::ops::index::DeprecatedIndexViewImpl::<core::array::Array::<core::felt252>, core::integer::u32, @core::felt252, core::array::ArrayIndex::<core::felt252>>::index(v13{`x`}, v14{`0`})
Expand Down
40 changes: 18 additions & 22 deletions crates/cairo-lang-lowering/src/ids.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
use cairo_lang_debug::DebugWithDb;
use cairo_lang_defs::ids::UnstableSalsaId;
use cairo_lang_defs::ids::{
NamedLanguageElementId, TopLevelLanguageElementId, TraitFunctionId, UnstableSalsaId,
};
use cairo_lang_diagnostics::{DiagnosticAdded, DiagnosticNote, Maybe};
use cairo_lang_proc_macros::{DebugWithDb, SemanticObject};
use cairo_lang_semantic::corelib::panic_destruct_trait_fn;
use cairo_lang_semantic::items::functions::ImplGenericFunctionId;
use cairo_lang_semantic::items::imp::ImplLongId;
use cairo_lang_semantic::{GenericArgumentId, TypeLongId};
use cairo_lang_syntax::node::{TypedStablePtr, ast};
use cairo_lang_utils::{
Intern, LookupIntern, define_short_id, extract_matches, try_extract_matches,
};
use cairo_lang_utils::{Intern, LookupIntern, define_short_id, try_extract_matches};
use defs::diagnostic_utils::StableLocation;
use defs::ids::{ExternFunctionId, FreeFunctionId};
use semantic::items::functions::GenericFunctionId;
Expand Down Expand Up @@ -106,12 +106,7 @@ impl ConcreteFunctionWithBodyId {
ConcreteFunctionWithBodyLongId::Generated(GeneratedFunction {
parent: _,
key: GeneratedFunctionKey::TraitFunc(function, _),
}) => Ok(extract_matches!(
function.get_concrete(db.upcast()).generic_function,
GenericFunctionId::Impl
)
.function
== panic_destruct_trait_fn(db.upcast())),
}) => Ok(function == panic_destruct_trait_fn(db.upcast())),
_ => Ok(false),
}
}
Expand Down Expand Up @@ -287,7 +282,7 @@ impl FunctionLongId {
return Ok(Some(
GeneratedFunction {
parent,
key: GeneratedFunctionKey::TraitFunc(id, ty.wrapper_location),
key: GeneratedFunctionKey::TraitFunc(function, ty.wrapper_location),
}
.body(db),
));
Expand Down Expand Up @@ -385,7 +380,7 @@ impl<'a> DebugWithDb<dyn LoweringGroup + 'a> for FunctionLongId {
pub enum GeneratedFunctionKey {
/// Generated loop functions are identified by the loop expr_id.
Loop(semantic::ExprId),
TraitFunc(semantic::FunctionId, StableLocation),
TraitFunc(TraitFunctionId, StableLocation),
}

/// Generated function.
Expand All @@ -396,17 +391,11 @@ pub struct GeneratedFunction {
}
impl GeneratedFunction {
pub fn body(&self, db: &dyn LoweringGroup) -> ConcreteFunctionWithBodyId {
let GeneratedFunction { parent, key } = *self;
let long_id = ConcreteFunctionWithBodyLongId::Generated(GeneratedFunction { parent, key });
let long_id = ConcreteFunctionWithBodyLongId::Generated(*self);
long_id.intern(db)
}
pub fn full_path(&self, db: &dyn LoweringGroup) -> String {
match self.key {
GeneratedFunctionKey::Loop(expr_id) => {
format!("{}[expr{}]", self.parent.full_path(db.upcast()), expr_id.index())
}
GeneratedFunctionKey::TraitFunc(trait_func, _) => trait_func.full_path(db.upcast()),
}
format!("{:?}", self.debug(db))
}
}
impl<'a> DebugWithDb<dyn LoweringGroup + 'a> for GeneratedFunction {
Expand All @@ -419,8 +408,15 @@ impl<'a> DebugWithDb<dyn LoweringGroup + 'a> for GeneratedFunction {
GeneratedFunctionKey::Loop(expr_id) => {
write!(f, "{:?}[expr{}]", self.parent.debug(db), expr_id.index())
}
GeneratedFunctionKey::TraitFunc(trait_func, _) => {
write!(f, "{:?}", trait_func.debug(db))
GeneratedFunctionKey::TraitFunc(trait_func, loc) => {
let trait_id = trait_func.trait_id(db.upcast());
write!(
f,
"Generated `{}::{}` for {{closure@{:?}}}",
trait_id.full_path(db.upcast()),
trait_func.name(db.upcast()),
loc.debug(db.upcast()),
)
}
}
}
Expand Down
10 changes: 2 additions & 8 deletions crates/cairo-lang-lowering/src/lower/generated_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@ use std::fmt::Write;
use cairo_lang_debug::DebugWithDb;
use cairo_lang_defs::ids::TopLevelLanguageElementId;
use cairo_lang_diagnostics::get_location_marks;
use cairo_lang_semantic::items::functions::GenericFunctionId;
use cairo_lang_semantic::test_utils::setup_test_function;
use cairo_lang_test_utils::parse_test_file::TestRunnerResult;
use cairo_lang_utils::Intern;
use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
use cairo_lang_utils::{Intern, LookupIntern, extract_matches};

use crate::db::LoweringGroup;
use crate::fmt::LoweredFormatter;
Expand Down Expand Up @@ -73,12 +72,7 @@ fn test_generated_function(

let func_description = match key {
crate::ids::GeneratedFunctionKey::Loop(_) => "loop".into(),
crate::ids::GeneratedFunctionKey::TraitFunc(func, _) => extract_matches!(
func.lookup_intern(db).function.generic_function,
GenericFunctionId::Impl
)
.function
.full_path(db),
crate::ids::GeneratedFunctionKey::TraitFunc(func, _) => func.full_path(db),
};

writeln!(
Expand Down
11 changes: 6 additions & 5 deletions crates/cairo-lang-lowering/src/lower/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1719,7 +1719,7 @@ fn add_capture_destruct_impl(
let signature =
Signature::from_semantic(ctx.db, semantic_db.concrete_function_signature(function)?);

let func_key = GeneratedFunctionKey::TraitFunc(function, location);
let func_key = GeneratedFunctionKey::TraitFunc(trait_function, location);
let function_id =
FunctionWithBodyLongId::Generated { parent: ctx.semantic_function_id, key: func_key }
.intern(ctx.db);
Expand Down Expand Up @@ -1824,7 +1824,7 @@ fn add_closure_call_function(
.intern(semantic_db);
let function_with_body_id = FunctionWithBodyLongId::Generated {
parent: encapsulated_ctx.semantic_function_id,
key: GeneratedFunctionKey::TraitFunc(function, closure_ty.wrapper_location),
key: GeneratedFunctionKey::TraitFunc(trait_function, closure_ty.wrapper_location),
}
.intern(encapsulated_ctx.db);
let signature = Signature::from_semantic(
Expand Down Expand Up @@ -1909,9 +1909,10 @@ fn add_closure_call_function(
signature: ctx.signature.clone(),
parameters,
};
encapsulated_ctx
.lowerings
.insert(GeneratedFunctionKey::TraitFunc(function, closure_ty.wrapper_location), lowered);
encapsulated_ctx.lowerings.insert(
GeneratedFunctionKey::TraitFunc(trait_function, closure_ty.wrapper_location),
lowered,
);
Ok(())
}

Expand Down
2 changes: 1 addition & 1 deletion crates/cairo-lang-lowering/src/lower/test_data/closure
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ Statements:
(v2: {[email protected]:2:13: 2:25}, v3: @{[email protected]:2:13: 2:25}) <- snapshot(v1)
(v4: core::felt252) <- 0
(v5: (core::felt252,)) <- struct_construct(v4)
(v6: ()) <- Generated core::ops::function::Fn::<{[email protected]:2:13: 2:25}, (core::felt252,)>::call(v3, v5)
(v6: ()) <- Generated `core::ops::function::Fn::call` for {[email protected]:2:13: 2:25}(v3, v5)
(v7: ()) <- struct_construct()
End:
Return(v7)
Expand Down
17 changes: 17 additions & 0 deletions tests/bug_samples/issue7095.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
trait MyTrait<T> {
fn consume<B, F, +core::ops::Fn<F, (T,)>[Output: B], +Destruct<F>>(
self: T, f: F,
) -> B {
f(self)
}
fn user<+Destruct<T>>(self: T) -> usize {
Self::consume(self, |_v| 1)
}
}

impl MyImpl<T> of MyTrait<T> {}

#[test]
fn test_call() {
MyTrait::<u8>::user(0);
}
1 change: 1 addition & 0 deletions tests/bug_samples/lib.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ mod issue7038;
mod issue7060;
mod issue7071;
mod issue7083;
mod issue7095;
mod loop_break_in_match;
mod loop_only_change;
mod partial_param_local;
Expand Down

0 comments on commit e28ef5f

Please sign in to comment.