From 583014c289cd2efc4651c88210e48206cbffbfc3 Mon Sep 17 00:00:00 2001 From: orizi <104711814+orizi@users.noreply.github.com> Date: Tue, 14 Jan 2025 23:38:48 +0100 Subject: [PATCH] Inferring closures types within higher order functions --- corelib/src/test/option_test.cairo | 2 +- .../src/lower/test_data/for | 8 +- .../src/lower/test_data/loop | 20 ++--- crates/cairo-lang-lowering/src/test_data/for | 2 +- .../cairo-lang-lowering/src/test_data/strings | 2 +- crates/cairo-lang-semantic/src/db.rs | 18 +++- .../cairo-lang-semantic/src/expr/compute.rs | 82 +++++++++++++++---- .../src/expr/test_data/closure | 20 +++++ .../src/items/functions.rs | 67 +++++++++++++-- .../account__account.contract_class.json | 4 +- .../test_data/account__account.sierra | 8 +- ...age__libfuncs_coverage.contract_class.json | 16 ++-- ...ibfuncs_coverage__libfuncs_coverage.sierra | 32 ++++---- 13 files changed, 210 insertions(+), 71 deletions(-) diff --git a/corelib/src/test/option_test.cairo b/corelib/src/test/option_test.cairo index 53ce0448b2b..dc6441c9a0b 100644 --- a/corelib/src/test/option_test.cairo +++ b/corelib/src/test/option_test.cairo @@ -209,7 +209,7 @@ fn test_default_for_option() { #[test] fn test_option_some_map() { let maybe_some_string: Option = Option::Some("Hello, World!"); - let maybe_some_len = maybe_some_string.map(|s: ByteArray| s.len()); + let maybe_some_len = maybe_some_string.map(|s| s.len()); assert!(maybe_some_len == Option::Some(13)); } diff --git a/crates/cairo-lang-lowering/src/lower/test_data/for b/crates/cairo-lang-lowering/src/lower/test_data/for index d8ef132d47b..a657d19f7a4 100644 --- a/crates/cairo-lang-lowering/src/lower/test_data/for +++ b/crates/cairo-lang-lowering/src/lower/test_data/for @@ -45,7 +45,7 @@ Statements: (v16: core::array::Array::, v17: @core::array::Array::) <- snapshot(v15) (v18: core::array::Span::) <- core::array::ArrayToSpan::::span(v17) (v19: core::array::SpanIter::) <- core::array::SpanIntoIterator::::into_iter(v18) - (v21: core::array::SpanIter::, v22: core::felt252, v20: ()) <- test::foo[expr30](v19, v0, v2) + (v21: core::array::SpanIter::, v22: core::felt252, v20: ()) <- test::foo[expr38](v19, v0, v2) End: Return(v22) @@ -69,7 +69,7 @@ Statements: (v14: core::array::Array::, v15: @core::array::Array::) <- snapshot(v10) (v16: core::array::Span::) <- struct_construct(v15) (v17: core::array::SpanIter::) <- struct_construct(v16) - (v18: core::RangeCheck, v19: core::gas::GasBuiltin, v20: core::panics::PanicResult::<(core::array::SpanIter::, core::felt252, ())>) <- test::foo[expr30](v0, v1, v17, v11, v13) + (v18: core::RangeCheck, v19: core::gas::GasBuiltin, v20: core::panics::PanicResult::<(core::array::SpanIter::, core::felt252, ())>) <- test::foo[expr38](v0, v1, v17, v11, v13) End: Match(match_enum(v20) { PanicResult::Ok(v21) => blk1, @@ -111,7 +111,7 @@ End: blk1: Statements: (v6: core::felt252) <- core::Felt252Add::add(v1, v2) - (v8: core::array::SpanIter::, v9: core::felt252, v7: ()) <- test::foo[expr30](v4, v6, v2) + (v8: core::array::SpanIter::, v9: core::felt252, v7: ()) <- test::foo[expr38](v4, v6, v2) End: Goto(blk3, {v9 -> v12, v8 -> v13, v7 -> v11}) @@ -195,7 +195,7 @@ End: blk8: Statements: (v30: core::felt252) <- core::felt252_add(v3, v4) - (v31: core::RangeCheck, v32: core::gas::GasBuiltin, v33: core::panics::PanicResult::<(core::array::SpanIter::, core::felt252, ())>) <- test::foo[expr30](v5, v6, v27, v30, v4) + (v31: core::RangeCheck, v32: core::gas::GasBuiltin, v33: core::panics::PanicResult::<(core::array::SpanIter::, core::felt252, ())>) <- test::foo[expr38](v5, v6, v27, v30, v4) End: Return(v31, v32, v33) diff --git a/crates/cairo-lang-lowering/src/lower/test_data/loop b/crates/cairo-lang-lowering/src/lower/test_data/loop index 14998075155..a48c96def51 100644 --- a/crates/cairo-lang-lowering/src/lower/test_data/loop +++ b/crates/cairo-lang-lowering/src/lower/test_data/loop @@ -1064,7 +1064,7 @@ Statements: (v1: test::T) <- struct_construct(v0) (v2: test::S) <- struct_destructure(v1) (v3: test::S, v4: @test::S) <- snapshot(v2) - (v5: ()) <- test::foo[expr7](v4) + (v5: ()) <- test::foo[expr8](v4) (v6: ()) <- test::STT::f1oo(v4) (v7: ()) <- struct_construct() End: @@ -1158,7 +1158,7 @@ Statements: (v4: test::B, v5: @test::B) <- snapshot(v3) (v6: core::integer::u32) <- struct_destructure(v4) (v7: test::B) <- struct_construct(v6) - (v8: ()) <- test::foo[expr16](v6, v5) + (v8: ()) <- test::foo[expr18](v6, v5) (v9: ()) <- struct_construct() End: Return(v9) @@ -1214,7 +1214,7 @@ Parameters: v0: core::integer::u32, v1: @test::B blk0 (root): Statements: () <- test::ex1(v1) - (v2: ()) <- test::foo[expr14](v0) + (v2: ()) <- test::foo[expr16](v0) (v3: ()) <- struct_construct() End: Return(v3) @@ -1282,7 +1282,7 @@ Statements: (v1: test::B) <- struct_construct(v0) (v2: test::A) <- struct_construct(v1) (v3: test::B) <- struct_destructure(v2) - (v4: ()) <- test::foo[expr16](v3) + (v4: ()) <- test::foo[expr18](v3) (v5: ()) <- struct_construct() End: Return(v5) @@ -1339,7 +1339,7 @@ Statements: () <- test::ex1(v0) (v1: core::integer::u32) <- struct_destructure(v0) (v2: core::integer::u32, v3: @core::integer::u32) <- snapshot(v1) - (v4: ()) <- test::foo[expr14](v3) + (v4: ()) <- test::foo[expr16](v3) (v5: ()) <- struct_construct() End: Return(v5) @@ -1399,7 +1399,7 @@ Statements: (v0: core::felt252) <- 0 (v1: test::A) <- struct_construct(v0) (v2: core::felt252) <- 0 - (v4: test::A, v5: core::felt252, v3: ()) <- test::foo[expr19](v2, v1) + (v4: test::A, v5: core::felt252, v3: ()) <- test::foo[expr20](v2, v1) (v6: ()) <- struct_construct() End: Return(v6) @@ -1412,7 +1412,7 @@ Statements: (v2: core::felt252) <- 0 (v3: core::felt252) <- 0 (v4: test::A) <- struct_construct(v2) - (v5: core::RangeCheck, v6: core::gas::GasBuiltin, v7: core::panics::PanicResult::<(test::A, core::felt252, ())>) <- test::foo[expr19](v0, v1, v3, v4) + (v5: core::RangeCheck, v6: core::gas::GasBuiltin, v7: core::panics::PanicResult::<(test::A, core::felt252, ())>) <- test::foo[expr20](v0, v1, v3, v4) End: Match(match_enum(v7) { PanicResult::Ok(v8) => blk1, @@ -1462,7 +1462,7 @@ Statements: () <- test::use_a(v12) (v13: core::felt252) <- 1 (v15: core::felt252, v14: ()) <- core::ops::arith::DeprecatedAddAssign::::add_assign(v2, v13) - (v17: test::A, v18: core::felt252, v16: ()) <- test::foo[expr19](v15, v11) + (v17: test::A, v18: core::felt252, v16: ()) <- test::foo[expr20](v15, v11) End: Goto(blk3, {v17 -> v21, v18 -> v22, v16 -> v20}) @@ -1513,7 +1513,7 @@ Statements: () <- test::use_a(v16) (v17: core::felt252) <- 1 (v18: core::felt252) <- core::felt252_add(v2, v17) - (v19: core::RangeCheck, v20: core::gas::GasBuiltin, v21: core::panics::PanicResult::<(test::A, core::felt252, ())>) <- test::foo[expr19](v4, v5, v18, v15) + (v19: core::RangeCheck, v20: core::gas::GasBuiltin, v21: core::panics::PanicResult::<(test::A, core::felt252, ())>) <- test::foo[expr20](v4, v5, v18, v15) End: Return(v19, v20, v21) @@ -1584,7 +1584,7 @@ Parameters: v0: core::RangeCheck, v1: core::gas::GasBuiltin blk0 (root): Statements: (v2: core::integer::u8) <- 0 - (v3: core::RangeCheck, v4: core::gas::GasBuiltin, v5: core::panics::PanicResult::<(core::integer::u8, ())>) <- test::MyImpl::impl_in_trait[expr9](v0, v1, v2) + (v3: core::RangeCheck, v4: core::gas::GasBuiltin, v5: core::panics::PanicResult::<(core::integer::u8, ())>) <- test::MyImpl::impl_in_trait[expr10](v0, v1, v2) End: Match(match_enum(v5) { PanicResult::Ok(v6) => blk1, diff --git a/crates/cairo-lang-lowering/src/test_data/for b/crates/cairo-lang-lowering/src/test_data/for index 9fad6ec3029..39bebf1275d 100644 --- a/crates/cairo-lang-lowering/src/test_data/for +++ b/crates/cairo-lang-lowering/src/test_data/for @@ -39,7 +39,7 @@ Statements: (v11: core::array::Array::, v12: @core::array::Array::) <- snapshot(v10) (v13: core::array::Span::) <- struct_construct(v12) (v14: core::array::SpanIter::) <- struct_construct(v13) - (v15: core::RangeCheck, v16: core::gas::GasBuiltin, v17: core::panics::PanicResult::<(core::array::SpanIter::, ())>) <- test::foo[expr26](v0, v1, v14) + (v15: core::RangeCheck, v16: core::gas::GasBuiltin, v17: core::panics::PanicResult::<(core::array::SpanIter::, ())>) <- test::foo[expr34](v0, v1, v14) End: Match(match_enum(v17) { PanicResult::Ok(v18) => blk1, diff --git a/crates/cairo-lang-lowering/src/test_data/strings b/crates/cairo-lang-lowering/src/test_data/strings index 1b04f79c88f..dc75e4a5138 100644 --- a/crates/cairo-lang-lowering/src/test_data/strings +++ b/crates/cairo-lang-lowering/src/test_data/strings @@ -201,7 +201,7 @@ Statements: (v8: core::byte_array::ByteArray, v9: @core::byte_array::ByteArray) <- snapshot(v7) (v10: @core::array::Array::, v11: @core::felt252, v12: @core::integer::u32) <- struct_destructure(v9) (v13: core::array::Span::) <- struct_construct(v10) - (v14: core::RangeCheck, v15: core::gas::GasBuiltin, v16: core::panics::PanicResult::<(core::array::Span::, core::array::Array::, ())>) <- core::array::ArrayTCloneImpl::clone[expr14](v0, v1, v13, v4) + (v14: core::RangeCheck, v15: core::gas::GasBuiltin, v16: core::panics::PanicResult::<(core::array::Span::, core::array::Array::, ())>) <- core::array::ArrayTCloneImpl::clone[expr15](v0, v1, v13, v4) End: Match(match_enum(v16) { PanicResult::Ok(v17) => blk1, diff --git a/crates/cairo-lang-semantic/src/db.rs b/crates/cairo-lang-semantic/src/db.rs index a5cdb3995a5..51b3fcc4988 100644 --- a/crates/cairo-lang-semantic/src/db.rs +++ b/crates/cairo-lang-semantic/src/db.rs @@ -26,7 +26,7 @@ use crate::diagnostic::SemanticDiagnosticKind; use crate::expr::inference::{self, ImplVar, ImplVarId}; use crate::items::constant::{ConstCalcInfo, ConstValueId, Constant, ImplConstantId}; use crate::items::function_with_body::FunctionBody; -use crate::items::functions::{ImplicitPrecedence, InlineConfiguration}; +use crate::items::functions::{GenericFunctionId, ImplicitPrecedence, InlineConfiguration}; use crate::items::generics::{GenericParam, GenericParamData, GenericParamsData}; use crate::items::imp::{ ImplId, ImplImplId, ImplLookupContext, ImplicitImplImplData, UninferredImpl, @@ -1431,6 +1431,22 @@ pub trait SemanticGroup: #[salsa::invoke(items::functions::concrete_function_signature)] fn concrete_function_signature(&self, function_id: FunctionId) -> Maybe; + /// Returns a mapping of closure types to their associated parameter types for a concrete + /// function. + #[salsa::invoke(items::functions::concrete_function_closure_params)] + fn concrete_function_closure_params( + &self, + function_id: FunctionId, + ) -> Maybe>; + + /// Returns a mapping of closure types to their associated parameter types for a generic + /// function. + #[salsa::invoke(items::functions::get_closure_params)] + fn get_closure_params( + &self, + generic_function_id: GenericFunctionId, + ) -> Maybe>; + // Generic type. // ============= /// Returns the generic params of a generic type. diff --git a/crates/cairo-lang-semantic/src/expr/compute.rs b/crates/cairo-lang-semantic/src/expr/compute.rs index 89bb0cb858e..33d8385f066 100644 --- a/crates/cairo-lang-semantic/src/expr/compute.rs +++ b/crates/cairo-lang-semantic/src/expr/compute.rs @@ -22,7 +22,7 @@ use cairo_lang_filesystem::ids::{FileKind, FileLongId, VirtualFile}; use cairo_lang_proc_macros::DebugWithDb; use cairo_lang_syntax::node::ast::{ BinaryOperator, BlockOrIf, ClosureParamWrapper, ExprPtr, OptionReturnTypeClause, PatternListOr, - PatternStructParam, UnaryOperator, + PatternStructParam, TerminalIdentifier, UnaryOperator, }; use cairo_lang_syntax::node::db::SyntaxGroup; use cairo_lang_syntax::node::helpers::{GetIdentifier, PathSegmentEx}; @@ -66,7 +66,7 @@ use crate::expr::inference::{ImplVarTraitItemMappings, InferenceId}; use crate::items::constant::{ConstValue, resolve_const_expr_and_evaluate, validate_const_expr}; use crate::items::enm::SemanticEnumEx; use crate::items::feature_kind::extract_item_feature_config; -use crate::items::functions::function_signature_params; +use crate::items::functions::{concrete_function_closure_params, function_signature_params}; use crate::items::imp::{ImplLookupContext, filter_candidate_traits, infer_impl_by_self}; use crate::items::modifiers::compute_mutability; use crate::items::us::get_use_path_segments; @@ -424,7 +424,7 @@ pub fn maybe_compute_expr_semantic( ast::Expr::Indexed(expr) => compute_expr_indexed_semantic(ctx, expr), ast::Expr::FixedSizeArray(expr) => compute_expr_fixed_size_array_semantic(ctx, expr), ast::Expr::For(expr) => compute_expr_for_semantic(ctx, expr), - ast::Expr::Closure(expr) => compute_expr_closure_semantic(ctx, expr), + ast::Expr::Closure(expr) => compute_expr_closure_semantic(ctx, expr, None), } } @@ -882,7 +882,7 @@ fn compute_expr_function_call_semantic( let mut arg_types = vec![]; for arg_syntax in args_iter { let stable_ptr = arg_syntax.stable_ptr(); - let arg = compute_named_argument_clause(ctx, arg_syntax); + let arg = compute_named_argument_clause(ctx, arg_syntax, None); if arg.2 != Mutability::Immutable { return Err(ctx.diagnostics.report(stable_ptr, RefClosureArgument)); } @@ -930,7 +930,7 @@ fn compute_expr_function_call_semantic( let named_args: Vec<_> = args_syntax .elements(syntax_db) .into_iter() - .map(|arg_syntax| compute_named_argument_clause(ctx, arg_syntax)) + .map(|arg_syntax| compute_named_argument_clause(ctx, arg_syntax, None)) .collect(); if named_args.len() != 1 { return Err(ctx.diagnostics.report(syntax, WrongNumberOfArguments { @@ -979,16 +979,21 @@ fn compute_expr_function_call_semantic( let mut args_iter = args_syntax.elements(syntax_db).into_iter(); // Normal parameters let mut named_args = vec![]; - for _ in function_parameter_types(ctx, function)? { + let closure_params = concrete_function_closure_params(db, function)?; + for ty in function_parameter_types(ctx, function)? { let Some(arg_syntax) = args_iter.next() else { continue; }; - named_args.push(compute_named_argument_clause(ctx, arg_syntax)); + named_args.push(compute_named_argument_clause( + ctx, + arg_syntax, + closure_params.get(&ty).cloned(), + )); } // Maybe coupon if let Some(arg_syntax) = args_iter.next() { - named_args.push(compute_named_argument_clause(ctx, arg_syntax)); + named_args.push(compute_named_argument_clause(ctx, arg_syntax, None)); } expr_function_call(ctx, function, named_args, syntax, syntax.stable_ptr().into()) @@ -1006,6 +1011,7 @@ fn compute_expr_function_call_semantic( pub fn compute_named_argument_clause( ctx: &mut ComputationContext<'_>, arg_syntax: ast::Arg, + closure_param_types: Option, ) -> NamedArg { let syntax_db = ctx.db.upcast(); @@ -1017,11 +1023,16 @@ pub fn compute_named_argument_clause( let arg_clause = arg_syntax.arg_clause(syntax_db); let (expr, arg_name_identifier) = match arg_clause { - ast::ArgClause::Unnamed(arg_unnamed) => { - (compute_expr_semantic(ctx, &arg_unnamed.value(syntax_db)), None) - } - ast::ArgClause::Named(arg_named) => ( - compute_expr_semantic(ctx, &arg_named.value(syntax_db)), + ast::ArgClause::Unnamed(arg_unnamed) => handle_possible_closure_expr( + ctx, + &arg_unnamed.value(syntax_db), + closure_param_types, + None, + ), + ast::ArgClause::Named(arg_named) => handle_possible_closure_expr( + ctx, + &arg_named.value(syntax_db), + closure_param_types, Some(arg_named.name(syntax_db)), ), ast::ArgClause::FieldInitShorthand(arg_field_init_shorthand) => { @@ -1034,10 +1045,32 @@ pub fn compute_named_argument_clause( (expr, Some(arg_name_identifier)) } }; - NamedArg(expr, arg_name_identifier, mutability) } +/// Handles the semantic computation of a closure expression. +/// It processes a closure expression, computes its semantic model, +/// allocates it in the expression arena, and ensures that the closure's +/// parameter types are conformed if provided. +fn handle_possible_closure_expr( + ctx: &mut ComputationContext<'_>, + expr: &ast::Expr, + closure_param_types: Option, + arg_name: Option, +) -> (ExprAndId, Option) { + if let ast::Expr::Closure(expr_closure) = expr { + let expr = compute_expr_closure_semantic(ctx, expr_closure, closure_param_types); + let expr = wrap_maybe_with_missing(ctx, expr, expr_closure.stable_ptr().into()); + let id = ctx.arenas.exprs.alloc(expr.clone()); + (ExprAndId { expr, id }, arg_name) + } else { + let expr = compute_expr_semantic(ctx, expr); + let expr = wrap_maybe_with_missing(ctx, Ok(expr.expr.clone()), expr.stable_ptr()); + let id = ctx.arenas.exprs.alloc(expr.clone()); + (ExprAndId { expr, id }, arg_name) + } +} + pub fn compute_root_expr( ctx: &mut ComputationContext<'_>, syntax: &ast::ExprBlock, @@ -1645,6 +1678,7 @@ fn compute_loop_body_semantic( fn compute_expr_closure_semantic( ctx: &mut ComputationContext<'_>, syntax: &ast::ExprClosure, + param_types: Option, ) -> Maybe { ctx.are_closures_in_context = true; let syntax_db = ctx.db.upcast(); @@ -1663,6 +1697,14 @@ fn compute_expr_closure_semantic( } else { vec![] }; + let closure_type = + TypeLongId::Tuple(params.iter().map(|param| param.ty).collect()).intern(new_ctx.db); + if let Some(param_types) = param_types { + if let Err(err_set) = new_ctx.resolver.inference().conform_ty(closure_type, param_types) + { + new_ctx.resolver.inference().consume_error_without_reporting(err_set); + } + } params.iter().filter(|param| param.mutability == Mutability::Reference).for_each(|param| { new_ctx.diagnostics.report(param.stable_ptr(ctx.db.upcast()), RefClosureParam); @@ -2834,16 +2876,22 @@ fn method_call_expr( // Self argument. let mut named_args = vec![NamedArg(fixed_lexpr, None, mutability)]; // Other arguments. - for _ in function_parameter_types(ctx, function_id)?.skip(1) { + let closure_params: OrderedHashMap = + concrete_function_closure_params(ctx.db, function_id)?; + for ty in function_parameter_types(ctx, function_id)?.skip(1) { let Some(arg_syntax) = args_iter.next() else { break; }; - named_args.push(compute_named_argument_clause(ctx, arg_syntax)); + named_args.push(compute_named_argument_clause( + ctx, + arg_syntax, + closure_params.get(&ty).cloned(), + )); } // Maybe coupon if let Some(arg_syntax) = args_iter.next() { - named_args.push(compute_named_argument_clause(ctx, arg_syntax)); + named_args.push(compute_named_argument_clause(ctx, arg_syntax, None)); } expr_function_call(ctx, function_id, named_args, &expr, stable_ptr) diff --git a/crates/cairo-lang-semantic/src/expr/test_data/closure b/crates/cairo-lang-semantic/src/expr/test_data/closure index d50eb9e0af7..8f907f5ff3e 100644 --- a/crates/cairo-lang-semantic/src/expr/test_data/closure +++ b/crates/cairo-lang-semantic/src/expr/test_data/closure @@ -768,3 +768,23 @@ error: Closure parameters cannot be references --> lib.cairo:2:14 let _ = |ref a| { ^^^^^ + +//! > ========================================================================== + +//! > Passing closures as args with less explicit typing. + +//! > test_runner_name +test_function_diagnostics(expect_diagnostics: false) + +//! > function +fn foo() -> Option { + let x: Option> = Option::Some(array![1, 2, 3]); + x.map(|x| x.len()) +} + +//! > function_name +foo + +//! > module_code + +//! > expected_diagnostics diff --git a/crates/cairo-lang-semantic/src/items/functions.rs b/crates/cairo-lang-semantic/src/items/functions.rs index 65dc2ae0d3a..fe463adce88 100644 --- a/crates/cairo-lang-semantic/src/items/functions.rs +++ b/crates/cairo-lang-semantic/src/items/functions.rs @@ -1,4 +1,3 @@ -use std::fmt::Debug; use std::sync::Arc; use cairo_lang_debug::DebugWithDb; @@ -14,6 +13,7 @@ use cairo_lang_proc_macros::{DebugWithDb, SemanticObject}; use cairo_lang_syntax as syntax; use cairo_lang_syntax::attribute::structured::Attribute; use cairo_lang_syntax::node::{Terminal, TypedSyntaxNode, ast}; +use cairo_lang_utils::ordered_hash_map::OrderedHashMap; use cairo_lang_utils::{ Intern, LookupIntern, OptionFrom, define_short_id, require, try_extract_matches, }; @@ -27,7 +27,7 @@ use super::generics::{fmt_generic_args, generic_params_to_args}; use super::imp::{ImplId, ImplLongId}; use super::modifiers; use super::trt::ConcreteTraitGenericFunctionId; -use crate::corelib::{panic_destruct_trait_fn, unit_ty}; +use crate::corelib::{fn_traits, panic_destruct_trait_fn, unit_ty}; use crate::db::SemanticGroup; use crate::diagnostic::{SemanticDiagnosticKind, SemanticDiagnostics, SemanticDiagnosticsBuilder}; use crate::expr::compute::Environment; @@ -35,8 +35,8 @@ use crate::resolve::{Resolver, ResolverData}; use crate::substitution::{GenericSubstitution, SemanticRewriter, SubstitutionRewriter}; use crate::types::resolve_type; use crate::{ - ConcreteImplId, ConcreteImplLongId, ConcreteTraitLongId, GenericParam, SemanticDiagnostic, - TypeId, semantic, semantic_object_for_id, + ConcreteImplId, ConcreteImplLongId, ConcreteTraitLongId, GenericArgumentId, GenericParam, + SemanticDiagnostic, TypeId, semantic, semantic_object_for_id, }; /// A generic function of an impl. @@ -146,8 +146,11 @@ impl GenericFunctionId { GenericFunctionId::Extern(id) => db.extern_function_declaration_generic_params(id), GenericFunctionId::Impl(id) => { let concrete_trait_id = db.impl_concrete_trait(id.impl_id)?; - let id = ConcreteTraitGenericFunctionId::new(db, concrete_trait_id, id.function); - db.concrete_trait_function_generic_params(id) + let concrete_id = + ConcreteTraitGenericFunctionId::new(db, concrete_trait_id, id.function); + let substitution = GenericSubstitution::from_impl(id.impl_id); + let mut rewriter = SubstitutionRewriter { db, substitution: &substitution }; + rewriter.rewrite(db.concrete_trait_function_generic_params(concrete_id)?) } GenericFunctionId::Trait(id) => db.concrete_trait_function_generic_params(id), } @@ -860,6 +863,19 @@ pub fn concrete_function_signature( SubstitutionRewriter { db, substitution: &substitution }.rewrite(generic_signature) } +/// Query implementation of [crate::db::SemanticGroup::concrete_function_closure_params]. +pub fn concrete_function_closure_params( + db: &dyn SemanticGroup, + function_id: FunctionId, +) -> Maybe> { + let ConcreteFunction { generic_function, generic_args, .. } = + function_id.lookup_intern(db).function; + let generic_params = generic_function.generic_params(db)?; + let generic_closure_params = db.get_closure_params(generic_function)?; + let substitution = GenericSubstitution::new(&generic_params, &generic_args); + SubstitutionRewriter { db, substitution: &substitution }.rewrite(generic_closure_params) +} + /// For a given list of AST parameters, returns the list of semantic parameters along with the /// corresponding environment. fn update_env_with_ast_params( @@ -1010,3 +1026,42 @@ impl FromIterator for ImplicitPrecedence { Self(Vec::from_iter(iter)) } } + +/// This function retrieves a mapping of closure types to their associated parameter types. +/// It analyzes the generic parameters of the current context +/// to identify any closures and their respective parameter types. It checks +/// for `Fn`, `FnMut`, or `FnOnce` traits among the generic parameters and +/// returns a `HashMap` where the key is the closure type, and the value is a +/// vector of parameter types. +/// Query implementation of [crate::db::SemanticGroup::get_closure_params]. +pub fn get_closure_params( + db: &dyn SemanticGroup, + generic_function_id: GenericFunctionId, +) -> Maybe> { + let mut closure_params_map = OrderedHashMap::default(); + let generic_params = generic_function_id.generic_params(db)?; + + for param in generic_params { + if let GenericParam::Impl(generic_param_impl) = param { + let trait_id = generic_param_impl.concrete_trait?.trait_id(db); + + if fn_traits(db).contains(&trait_id) { + if let Ok(concrete_trait) = generic_param_impl.concrete_trait { + let [ + GenericArgumentId::Type(closure_type), + GenericArgumentId::Type(params_type), + ] = *concrete_trait.generic_args(db) + else { + unreachable!( + "Fn trait must have exactly two generic arguments: closure type and \ + parameter type." + ) + }; + + closure_params_map.insert(closure_type, params_type); + } + } + } + } + Ok(closure_params_map) +} diff --git a/crates/cairo-lang-starknet/test_data/account__account.contract_class.json b/crates/cairo-lang-starknet/test_data/account__account.contract_class.json index 0446795cf1b..833886b9ee3 100644 --- a/crates/cairo-lang-starknet/test_data/account__account.contract_class.json +++ b/crates/cairo-lang-starknet/test_data/account__account.contract_class.json @@ -1956,7 +1956,7 @@ ], [ 145, - "function_call" + "function_call" ], [ 146, @@ -2342,7 +2342,7 @@ ], [ 11, - "cairo_level_tests::contracts::account::account::AccountContractImpl::__execute__[expr33]" + "cairo_level_tests::contracts::account::account::AccountContractImpl::__execute__[expr41]" ], [ 12, diff --git a/crates/cairo-lang-starknet/test_data/account__account.sierra b/crates/cairo-lang-starknet/test_data/account__account.sierra index 47befe0ba15..7edd505f319 100644 --- a/crates/cairo-lang-starknet/test_data/account__account.sierra +++ b/crates/cairo-lang-starknet/test_data/account__account.sierra @@ -232,7 +232,7 @@ libfunc enum_init>)>> = store_temp>)>>; libfunc array_new> = array_new>; libfunc store_temp>> = store_temp>>; -libfunc function_call = function_call; +libfunc function_call = function_call; libfunc enum_match, core::array::Array::>, ())>> = enum_match, core::array::Array::>, ())>>; libfunc struct_deconstruct, Array>, Unit>> = struct_deconstruct, Array>, Unit>>; libfunc struct_construct>>> = struct_construct>>>; @@ -1336,7 +1336,7 @@ store_temp([19]) -> ([19]); // 1014 store_temp([20]) -> ([20]); // 1015 store_temp>([4]) -> ([4]); // 1016 store_temp>>([52]) -> ([52]); // 1017 -function_call([0], [19], [20], [4], [52]) -> ([53], [54], [55], [56]); // 1018 +function_call([0], [19], [20], [4], [52]) -> ([53], [54], [55], [56]); // 1018 enum_match, core::array::Array::>, ())>>([56]) { fallthrough([57]) 1031([58]) }; // 1019 branch_align() -> (); // 1020 struct_deconstruct, Array>, Unit>>([57]) -> ([59], [60], [61]); // 1021 @@ -1806,7 +1806,7 @@ store_temp([16]) -> ([16]); // 1484 store_temp([17]) -> ([17]); // 1485 store_temp>([9]) -> ([9]); // 1486 store_temp>>([22]) -> ([22]); // 1487 -function_call([5], [16], [17], [9], [22]) -> ([23], [24], [25], [26]); // 1488 +function_call([5], [16], [17], [9], [22]) -> ([23], [24], [25], [26]); // 1488 return([23], [24], [25], [26]); // 1489 branch_align() -> (); // 1490 drop>([9]) -> (); // 1491 @@ -2000,6 +2000,6 @@ cairo_level_tests::contracts::account::account::AccountContractImpl::__execute__ core::array::serialize_array_helper::, core::array::SpanFelt252Serde, core::array::SpanDrop::>@1077([0]: RangeCheck, [1]: GasBuiltin, [2]: core::array::Span::>, [3]: Array) -> (RangeCheck, GasBuiltin, core::panics::PanicResult::<(core::array::Array::, ())>); core::ecdsa::check_ecdsa_signature@1145([0]: RangeCheck, [1]: EcOp, [2]: felt252, [3]: felt252, [4]: felt252, [5]: felt252) -> (RangeCheck, EcOp, core::panics::PanicResult::<(core::bool,)>); core::starknet::account::CallSerde::deserialize@1376([0]: RangeCheck, [1]: core::array::Span::) -> (RangeCheck, core::panics::PanicResult::<(core::array::Span::, core::option::Option::)>); -cairo_level_tests::contracts::account::account::AccountContractImpl::__execute__[expr33]@1467([0]: RangeCheck, [1]: GasBuiltin, [2]: System, [3]: Array, [4]: Array>) -> (RangeCheck, GasBuiltin, System, core::panics::PanicResult::<(core::array::Array::, core::array::Array::>, ())>); +cairo_level_tests::contracts::account::account::AccountContractImpl::__execute__[expr41]@1467([0]: RangeCheck, [1]: GasBuiltin, [2]: System, [3]: Array, [4]: Array>) -> (RangeCheck, GasBuiltin, System, core::panics::PanicResult::<(core::array::Array::, core::array::Array::>, ())>); core::array::serialize_array_helper::@1525([0]: RangeCheck, [1]: GasBuiltin, [2]: core::array::Span::, [3]: Array) -> (RangeCheck, GasBuiltin, core::panics::PanicResult::<(core::array::Array::, ())>); core::array::SpanFelt252Serde::deserialize@1566([0]: RangeCheck, [1]: core::array::Span::) -> (RangeCheck, core::panics::PanicResult::<(core::array::Span::, core::option::Option::>)>); diff --git a/crates/cairo-lang-starknet/test_data/libfuncs_coverage__libfuncs_coverage.contract_class.json b/crates/cairo-lang-starknet/test_data/libfuncs_coverage__libfuncs_coverage.contract_class.json index 33cc02f52a2..9e4ce77c4b1 100644 --- a/crates/cairo-lang-starknet/test_data/libfuncs_coverage__libfuncs_coverage.contract_class.json +++ b/crates/cairo-lang-starknet/test_data/libfuncs_coverage__libfuncs_coverage.contract_class.json @@ -12536,7 +12536,7 @@ ], [ 532, - "function_call" + "function_call" ], [ 533, @@ -12664,7 +12664,7 @@ ], [ 564, - "function_call" + "function_call" ], [ 565, @@ -15192,7 +15192,7 @@ ], [ 1196, - "function_call" + "function_call" ], [ 1197, @@ -15320,7 +15320,7 @@ ], [ 1228, - "function_call" + "function_call" ], [ 1229, @@ -16122,7 +16122,7 @@ ], [ 83, - "core::sha256::compute_sha256_byte_array[expr62]" + "core::sha256::compute_sha256_byte_array[expr67]" ], [ 84, @@ -16134,7 +16134,7 @@ ], [ 86, - "core::sha256::compute_sha256_u32_array[expr20]" + "core::sha256::compute_sha256_u32_array[expr27]" ], [ 87, @@ -16278,7 +16278,7 @@ ], [ 122, - "core::keccak::keccak_u256s_be_inputs[expr12]" + "core::keccak::keccak_u256s_be_inputs[expr14]" ], [ 123, @@ -16286,7 +16286,7 @@ ], [ 124, - "core::starknet::storage_access::inner_write_byte_array[expr56]" + "core::starknet::storage_access::inner_write_byte_array[expr72]" ], [ 125, diff --git a/crates/cairo-lang-starknet/test_data/libfuncs_coverage__libfuncs_coverage.sierra b/crates/cairo-lang-starknet/test_data/libfuncs_coverage__libfuncs_coverage.sierra index 466ee02eea2..8d357033fc9 100644 --- a/crates/cairo-lang-starknet/test_data/libfuncs_coverage__libfuncs_coverage.sierra +++ b/crates/cairo-lang-starknet/test_data/libfuncs_coverage__libfuncs_coverage.sierra @@ -1125,7 +1125,7 @@ libfunc store_temp> = store_temp>; libfunc const_as_immediate> = const_as_immediate>; libfunc snapshot_take = snapshot_take; libfunc store_temp> = store_temp>; -libfunc function_call = function_call; +libfunc function_call = function_call; libfunc enum_match, core::integer::u32, ())>> = enum_match, core::integer::u32, ())>>; libfunc struct_deconstruct, u32, Unit>> = struct_deconstruct, u32, Unit>>; libfunc enable_ap_tracking = enable_ap_tracking; @@ -1157,7 +1157,7 @@ libfunc snapshot_take> = snapshot_take>; libfunc struct_construct> = struct_construct>; libfunc store_temp> = store_temp>; libfunc store_temp = store_temp; -libfunc function_call = function_call; +libfunc function_call = function_call; libfunc enum_match, core::sha256::Sha256StateHandle, ())>> = enum_match, core::sha256::Sha256StateHandle, ())>>; libfunc struct_deconstruct, Sha256StateHandle, Unit>> = struct_deconstruct, Sha256StateHandle, Unit>>; libfunc drop> = drop>; @@ -1789,7 +1789,7 @@ libfunc span_from_tuple> = span_ libfunc array_new = array_new; libfunc store_temp> = store_temp>; libfunc store_local = store_local; -libfunc function_call = function_call; +libfunc function_call = function_call; libfunc store_local = store_local; libfunc enum_match, core::array::Array::, ())>> = enum_match, core::array::Array::, ())>>; libfunc struct_deconstruct, Array, Unit>> = struct_deconstruct, Array, Unit>>; @@ -1821,7 +1821,7 @@ libfunc drop> = drop>; libfunc const_as_immediate> = const_as_immediate>; libfunc struct_construct> = struct_construct>; libfunc store_temp> = store_temp>; -libfunc function_call = function_call; +libfunc function_call = function_call; libfunc enum_match, core::felt252, core::starknet::storage_access::StorageBaseAddress, core::integer::u8, core::result::Result::<(), core::array::Array::>)>> = enum_match, core::felt252, core::starknet::storage_access::StorageBaseAddress, core::integer::u8, core::result::Result::<(), core::array::Array::>)>>; libfunc struct_deconstruct, felt252, StorageBaseAddress, u8, core::result::Result::<(), core::array::Array::>>> = struct_deconstruct, felt252, StorageBaseAddress, u8, core::result::Result::<(), core::array::Array::>>>; libfunc drop> = drop>; @@ -5902,7 +5902,7 @@ store_temp>([40]) -> ([40]); // 3959 store_temp([37]) -> ([37]); // 3960 store_temp>([8]) -> ([8]); // 3961 store_temp([39]) -> ([39]); // 3962 -function_call([33], [1], [40], [37], [8], [39]) -> ([41], [42], [43]); // 3963 +function_call([33], [1], [40], [37], [8], [39]) -> ([41], [42], [43]); // 3963 enum_match, core::integer::u32, ())>>([43]) { fallthrough([44]) 4570([45]) }; // 3964 branch_align() -> (); // 3965 dup([4]) -> ([4], [46]); // 3966 @@ -6255,7 +6255,7 @@ store_temp([42]) -> ([42]); // 4312 store_temp([2]) -> ([2]); // 4313 store_temp>([235]) -> ([235]); // 4314 store_temp([230]) -> ([230]); // 4315 -function_call([225], [42], [2], [235], [230]) -> ([236], [237], [238], [239]); // 4316 +function_call([225], [42], [2], [235], [230]) -> ([236], [237], [238], [239]); // 4316 enum_match, core::sha256::Sha256StateHandle, ())>>([239]) { fallthrough([240]) 4332([241]) }; // 4317 branch_align() -> (); // 4318 struct_deconstruct, Sha256StateHandle, Unit>>([240]) -> ([242], [243], [244]); // 4319 @@ -10513,7 +10513,7 @@ store_temp>([2]) -> ([2]); // 8570 store_temp([98]) -> ([98]); // 8571 store_temp>([95]) -> ([95]); // 8572 store_temp([5]) -> ([5]); // 8573 -function_call([97], [7], [2], [98], [95], [5]) -> ([101], [102], [103]); // 8574 +function_call([97], [7], [2], [98], [95], [5]) -> ([101], [102], [103]); // 8574 return([101], [102], [103]); // 8575 branch_align() -> (); // 8576 drop([100]) -> (); // 8577 @@ -11256,7 +11256,7 @@ store_temp([25]) -> ([25]); // 9313 store_temp([26]) -> ([26]); // 9314 store_temp>([21]) -> ([21]); // 9315 store_temp([27]) -> ([27]); // 9316 -function_call([16], [25], [26], [21], [27]) -> ([31], [32], [33], [34]); // 9317 +function_call([16], [25], [26], [21], [27]) -> ([31], [32], [33], [34]); // 9317 return([31], [32], [33], [34]); // 9318 branch_align() -> (); // 9319 drop>([21]) -> (); // 9320 @@ -11874,7 +11874,7 @@ store_temp([2]) -> ([2]); // 9931 store_temp>([21]) -> ([21]); // 9932 store_temp>([20]) -> ([20]); // 9933 store_local([8], [7]) -> ([7]); // 9934 -function_call([0], [9], [2], [21], [20]) -> ([22], [23], [5], [24]); // 9935 +function_call([0], [9], [2], [21], [20]) -> ([22], [23], [5], [24]); // 9935 store_local([6], [5]) -> ([5]); // 9936 enum_match, core::array::Array::, ())>>([24]) { fallthrough([25]) 9999([26]) }; // 9937 branch_align() -> (); // 9938 @@ -12019,7 +12019,7 @@ store_temp([53]) -> ([53]); // 10076 store_temp([48]) -> ([48]); // 10077 store_temp([51]) -> ([51]); // 10078 store_temp([40]) -> ([40]); // 10079 -function_call([47], [30], [43], [31], [52], [5], [53], [48], [51], [40]) -> ([54], [55], [56], [57], [58]); // 10080 +function_call([47], [30], [43], [31], [52], [5], [53], [48], [51], [40]) -> ([54], [55], [56], [57], [58]); // 10080 enum_match, core::felt252, core::starknet::storage_access::StorageBaseAddress, core::integer::u8, core::result::Result::<(), core::array::Array::>)>>([58]) { fallthrough([59]) 10144([60]) }; // 10081 branch_align() -> (); // 10082 struct_deconstruct, felt252, StorageBaseAddress, u8, core::result::Result::<(), core::array::Array::>>>([59]) -> ([61], [62], [63], [64], [65]); // 10083 @@ -13046,7 +13046,7 @@ store_temp([6]) -> ([6]); // 11103 store_temp([24]) -> ([24]); // 11104 store_temp>([18]) -> ([18]); // 11105 store_temp>([28]) -> ([28]); // 11106 -function_call([23], [6], [24], [18], [28]) -> ([30], [31], [32], [33]); // 11107 +function_call([23], [6], [24], [18], [28]) -> ([30], [31], [32], [33]); // 11107 return([30], [31], [32], [33]); // 11108 branch_align() -> (); // 11109 disable_ap_tracking() -> (); // 11110 @@ -13362,7 +13362,7 @@ store_temp([6]) -> ([6]); // 11419 store_temp([46]) -> ([46]); // 11420 store_temp([47]) -> ([47]); // 11421 store_temp([45]) -> ([45]); // 11422 -function_call([43], [33], [44], [34], [23], [5], [6], [46], [47], [45]) -> ([61], [62], [63], [64], [65]); // 11423 +function_call([43], [33], [44], [34], [23], [5], [6], [46], [47], [45]) -> ([61], [62], [63], [64], [65]); // 11423 return([61], [62], [63], [64], [65]); // 11424 branch_align() -> (); // 11425 disable_ap_tracking() -> (); // 11426 @@ -13758,10 +13758,10 @@ cairo_level_tests::contracts::libfuncs_coverage::use_and_panic::, core::traits::PanicDestructForDestruct::, core::traits::DestructFromDrop::, core::option::OptionDrop::>>>@8340([0]: core::option::Option::) -> (core::panics::PanicResult::<((),)>); cairo_level_tests::contracts::libfuncs_coverage::use_and_panic::, core::traits::PanicDestructForDestruct::, core::traits::DestructFromDrop::, core::option::OptionDrop::>>>@8359([0]: core::option::Option::) -> (core::panics::PanicResult::<((),)>); core::integer::u256_wide_mul@8378([0]: RangeCheck, [1]: core::integer::u256, [2]: core::integer::u256) -> (RangeCheck, core::integer::u512); -core::sha256::compute_sha256_byte_array[expr62]@8472([0]: RangeCheck, [1]: GasBuiltin, [2]: Snapshot, [3]: u32, [4]: Array, [5]: u32) -> (RangeCheck, GasBuiltin, core::panics::PanicResult::<(core::array::Array::, core::integer::u32, ())>); +core::sha256::compute_sha256_byte_array[expr67]@8472([0]: RangeCheck, [1]: GasBuiltin, [2]: Snapshot, [3]: u32, [4]: Array, [5]: u32) -> (RangeCheck, GasBuiltin, core::panics::PanicResult::<(core::array::Array::, core::integer::u32, ())>); core::byte_array::ByteArrayImpl::at@8887([0]: RangeCheck, [1]: Snapshot, [2]: u32) -> (RangeCheck, core::panics::PanicResult::<(core::option::Option::,)>); core::sha256::add_sha256_padding@9107([0]: RangeCheck, [1]: Array, [2]: u32, [3]: u32) -> (RangeCheck, core::panics::PanicResult::<(core::array::Array::, ())>); -core::sha256::compute_sha256_u32_array[expr20]@9287([0]: RangeCheck, [1]: GasBuiltin, [2]: System, [3]: core::array::Span::, [4]: Sha256StateHandle) -> (RangeCheck, GasBuiltin, System, core::panics::PanicResult::<(core::array::Span::, core::sha256::Sha256StateHandle, ())>); +core::sha256::compute_sha256_u32_array[expr27]@9287([0]: RangeCheck, [1]: GasBuiltin, [2]: System, [3]: core::array::Span::, [4]: Sha256StateHandle) -> (RangeCheck, GasBuiltin, System, core::panics::PanicResult::<(core::array::Span::, core::sha256::Sha256StateHandle, ())>); cairo_level_tests::contracts::libfuncs_coverage::use_and_panic::, core::traits::PanicDestructForDestruct::, core::traits::DestructFromDrop::, core::array::ArrayDrop::>>>@9355([0]: Array) -> (core::panics::PanicResult::<((),)>); cairo_level_tests::contracts::libfuncs_coverage::use_and_panic::<(), core::traits::PanicDestructForDestruct::<(), core::traits::DestructFromDrop::<(), core::traits::TupleSize0Drop>>>@9374([0]: Unit) -> (core::panics::PanicResult::<((),)>); cairo_level_tests::contracts::libfuncs_coverage::use_and_panic_drop::, core::option::OptionDrop::>@9393([0]: core::option::Option::) -> (core::panics::PanicResult::<((),)>); @@ -13797,9 +13797,9 @@ core::integer::I128Mul::mul@10567([0]: RangeCheck, [1]: i128, [2]: i128) -> (Ran core::bytes_31::Bytes31Impl::at@10650([0]: RangeCheck, [1]: bytes31, [2]: u32) -> (RangeCheck, core::panics::PanicResult::<(core::integer::u8,)>); core::bytes_31::one_shift_left_bytes_u128_nz@10733([0]: RangeCheck, [1]: u32) -> (RangeCheck, core::panics::PanicResult::<(core::zeroable::NonZero::,)>); core::sha256::append_zeros@10834([0]: Array, [1]: felt252) -> (Array); -core::keccak::keccak_u256s_be_inputs[expr12]@11070([0]: RangeCheck, [1]: GasBuiltin, [2]: Bitwise, [3]: core::array::Span::, [4]: Array) -> (RangeCheck, GasBuiltin, Bitwise, core::panics::PanicResult::<(core::array::Span::, core::array::Array::, ())>); +core::keccak::keccak_u256s_be_inputs[expr14]@11070([0]: RangeCheck, [1]: GasBuiltin, [2]: Bitwise, [3]: core::array::Span::, [4]: Array) -> (RangeCheck, GasBuiltin, Bitwise, core::panics::PanicResult::<(core::array::Span::, core::array::Array::, ())>); core::keccak::add_padding@11144([0]: RangeCheck, [1]: GasBuiltin, [2]: Array, [3]: u64, [4]: u32) -> (RangeCheck, GasBuiltin, core::panics::PanicResult::<(core::array::Array::, ())>); -core::starknet::storage_access::inner_write_byte_array[expr56]@11347([0]: RangeCheck, [1]: GasBuiltin, [2]: Poseidon, [3]: System, [4]: core::array::Span::, [5]: StorageAddress, [6]: u32, [7]: StorageBaseAddress, [8]: u8, [9]: felt252) -> (RangeCheck, GasBuiltin, Poseidon, System, core::panics::PanicResult::<(core::array::Span::, core::felt252, core::starknet::storage_access::StorageBaseAddress, core::integer::u8, core::result::Result::<(), core::array::Array::>)>); +core::starknet::storage_access::inner_write_byte_array[expr72]@11347([0]: RangeCheck, [1]: GasBuiltin, [2]: Poseidon, [3]: System, [4]: core::array::Span::, [5]: StorageAddress, [6]: u32, [7]: StorageBaseAddress, [8]: u8, [9]: felt252) -> (RangeCheck, GasBuiltin, Poseidon, System, core::panics::PanicResult::<(core::array::Span::, core::felt252, core::starknet::storage_access::StorageBaseAddress, core::integer::u8, core::result::Result::<(), core::array::Array::>)>); core::integer::u256_overflowing_mul@11473([0]: RangeCheck, [1]: core::integer::u256, [2]: core::integer::u256) -> (RangeCheck, Tuple); core::keccak::keccak_add_u256_be@11586([0]: RangeCheck, [1]: Bitwise, [2]: Array, [3]: core::integer::u256) -> (RangeCheck, Bitwise, core::panics::PanicResult::<(core::array::Array::, ())>); core::keccak::finalize_padding@11673([0]: RangeCheck, [1]: GasBuiltin, [2]: Array, [3]: u32) -> (RangeCheck, GasBuiltin, core::panics::PanicResult::<(core::array::Array::, ())>);