From 654355d1b97de183d42208dfbecf7d67be9e2fda Mon Sep 17 00:00:00 2001 From: orizi <104711814+orizi@users.noreply.github.com> Date: Tue, 14 Jan 2025 23:38:48 +0100 Subject: [PATCH] Fixed generated functions in trait-fns. (#7081) feat(corelib): Iterator::enumerate (#7048) fix: Fix handling of --skip-first argument Update release_crates.sh (#7087) chore: orthographic correction in file if_else (#7088) prevents closure parameters from being declared as refrences (#7078) Refactored bounded_int_trim. (#7062) Added const for starknet types. (#6961) feat(corelib): Iterator::fold (#7084) feat(corelib): Iterator::advance_by (#7059) fix(corelib): Add the #[test] annotation to enumerate test (#7098) feat(corelib): storage vectors iterators (#6941) Extract ModuleHelper from const folding. (#7099) Added support for basic `Into`s in consts. (#7100) Removed taking value for `validate_literal`. (#7101) added closure params to semantic defs in lowering (#7085) Added support for `downcast` in constant context. (#7102) fix(corelib): Add the #[test] annotation to enumerate test (#7098) --- .../cairo-lang-semantic/src/expr/compute.rs | 87 +++++++++++++++---- .../src/expr/test_data/closure | 20 +++++ .../src/items/functions.rs | 58 +++++++++++-- 3 files changed, 141 insertions(+), 24 deletions(-) diff --git a/crates/cairo-lang-semantic/src/expr/compute.rs b/crates/cairo-lang-semantic/src/expr/compute.rs index 89bb0cb858e..895f7402175 100644 --- a/crates/cairo-lang-semantic/src/expr/compute.rs +++ b/crates/cairo-lang-semantic/src/expr/compute.rs @@ -66,7 +66,7 @@ use crate::expr::inference::{ImplVarTraitItemMappings, InferenceId}; use crate::items::constant::{ConstValue, resolve_const_expr_and_evaluate, validate_const_expr}; use crate::items::enm::SemanticEnumEx; use crate::items::feature_kind::extract_item_feature_config; -use crate::items::functions::function_signature_params; +use crate::items::functions::{concrete_function_closure_params, function_signature_params}; use crate::items::imp::{ImplLookupContext, filter_candidate_traits, infer_impl_by_self}; use crate::items::modifiers::compute_mutability; use crate::items::us::get_use_path_segments; @@ -84,8 +84,8 @@ use crate::types::{ }; use crate::usage::Usages; use crate::{ - ConcreteEnumId, GenericArgumentId, GenericParam, LocalItem, Member, Mutability, Parameter, - PatternStringLiteral, PatternStruct, Signature, StatementItemKind, + ConcreteEnumId, ConcreteFunction, GenericArgumentId, GenericParam, LocalItem, Member, + Mutability, Parameter, PatternStringLiteral, PatternStruct, Signature, StatementItemKind, }; /// Expression with its id. @@ -424,7 +424,7 @@ pub fn maybe_compute_expr_semantic( ast::Expr::Indexed(expr) => compute_expr_indexed_semantic(ctx, expr), ast::Expr::FixedSizeArray(expr) => compute_expr_fixed_size_array_semantic(ctx, expr), ast::Expr::For(expr) => compute_expr_for_semantic(ctx, expr), - ast::Expr::Closure(expr) => compute_expr_closure_semantic(ctx, expr), + ast::Expr::Closure(expr) => compute_expr_closure_semantic(ctx, expr, None), } } @@ -882,7 +882,7 @@ fn compute_expr_function_call_semantic( let mut arg_types = vec![]; for arg_syntax in args_iter { let stable_ptr = arg_syntax.stable_ptr(); - let arg = compute_named_argument_clause(ctx, arg_syntax); + let arg = compute_named_argument_clause(ctx, arg_syntax, None); if arg.2 != Mutability::Immutable { return Err(ctx.diagnostics.report(stable_ptr, RefClosureArgument)); } @@ -930,7 +930,7 @@ fn compute_expr_function_call_semantic( let named_args: Vec<_> = args_syntax .elements(syntax_db) .into_iter() - .map(|arg_syntax| compute_named_argument_clause(ctx, arg_syntax)) + .map(|arg_syntax| compute_named_argument_clause(ctx, arg_syntax, None)) .collect(); if named_args.len() != 1 { return Err(ctx.diagnostics.report(syntax, WrongNumberOfArguments { @@ -979,16 +979,22 @@ fn compute_expr_function_call_semantic( let mut args_iter = args_syntax.elements(syntax_db).into_iter(); // Normal parameters let mut named_args = vec![]; - for _ in function_parameter_types(ctx, function)? { + let ConcreteFunction { .. } = function.lookup_intern(db).function; + let closure_params = concrete_function_closure_params(db, function)?; + for ty in function_parameter_types(ctx, function)? { let Some(arg_syntax) = args_iter.next() else { continue; }; - named_args.push(compute_named_argument_clause(ctx, arg_syntax)); + named_args.push(compute_named_argument_clause( + ctx, + arg_syntax, + closure_params.get(&ty).cloned(), + )); } // Maybe coupon if let Some(arg_syntax) = args_iter.next() { - named_args.push(compute_named_argument_clause(ctx, arg_syntax)); + named_args.push(compute_named_argument_clause(ctx, arg_syntax, None)); } expr_function_call(ctx, function, named_args, syntax, syntax.stable_ptr().into()) @@ -1006,6 +1012,7 @@ fn compute_expr_function_call_semantic( pub fn compute_named_argument_clause( ctx: &mut ComputationContext<'_>, arg_syntax: ast::Arg, + closure_param_types: Option, ) -> NamedArg { let syntax_db = ctx.db.upcast(); @@ -1018,12 +1025,38 @@ pub fn compute_named_argument_clause( let arg_clause = arg_syntax.arg_clause(syntax_db); let (expr, arg_name_identifier) = match arg_clause { ast::ArgClause::Unnamed(arg_unnamed) => { - (compute_expr_semantic(ctx, &arg_unnamed.value(syntax_db)), None) + let arg_expr = arg_unnamed.value(syntax_db); + if let ast::Expr::Closure(expr_closure) = arg_expr { + let expr = compute_expr_closure_semantic(ctx, &expr_closure, closure_param_types); + let expr = wrap_maybe_with_missing( + ctx, + expr, + ast::ExprPtr::from(expr_closure.stable_ptr()), + ); + let id = ctx.arenas.exprs.alloc(expr.clone()); + (ExprAndId { expr, id }, None) + } else { + (compute_expr_semantic(ctx, &arg_unnamed.value(syntax_db)), None) + } + } + ast::ArgClause::Named(arg_named) => { + let arg_expr = arg_named.value(syntax_db); + if let ast::Expr::Closure(expr_closure) = arg_expr { + let expr = compute_expr_closure_semantic(ctx, &expr_closure, closure_param_types); + let expr = wrap_maybe_with_missing( + ctx, + expr, + ast::ExprPtr::from(expr_closure.stable_ptr()), + ); + let id = ctx.arenas.exprs.alloc(expr.clone()); + (ExprAndId { expr, id }, None) + } else { + ( + compute_expr_semantic(ctx, &arg_named.value(syntax_db)), + Some(arg_named.name(syntax_db)), + ) + } } - ast::ArgClause::Named(arg_named) => ( - compute_expr_semantic(ctx, &arg_named.value(syntax_db)), - Some(arg_named.name(syntax_db)), - ), ast::ArgClause::FieldInitShorthand(arg_field_init_shorthand) => { let name_expr = arg_field_init_shorthand.name(syntax_db); let stable_ptr: ast::ExprPtr = name_expr.stable_ptr().into(); @@ -1645,6 +1678,7 @@ fn compute_loop_body_semantic( fn compute_expr_closure_semantic( ctx: &mut ComputationContext<'_>, syntax: &ast::ExprClosure, + param_types: Option, ) -> Maybe { ctx.are_closures_in_context = true; let syntax_db = ctx.db.upcast(); @@ -1663,6 +1697,18 @@ fn compute_expr_closure_semantic( } else { vec![] }; + let closure_type = + TypeLongId::Tuple(params.iter().map(|param| param.ty).collect()).intern(new_ctx.db); + if let Some(param_types) = param_types { + if let Err(err_set) = new_ctx.resolver.inference().conform_ty(closure_type, param_types) + { + new_ctx.resolver.inference().consume_error_without_reporting(err_set); + } + } + + params.iter().filter(|param| param.mutability == Mutability::Reference).for_each(|param| { + new_ctx.diagnostics.report(param.stable_ptr(ctx.db.upcast()), RefClosureParam); + }); params.iter().filter(|param| param.mutability == Mutability::Reference).for_each(|param| { new_ctx.diagnostics.report(param.stable_ptr(ctx.db.upcast()), RefClosureParam); @@ -2834,16 +2880,22 @@ fn method_call_expr( // Self argument. let mut named_args = vec![NamedArg(fixed_lexpr, None, mutability)]; // Other arguments. - for _ in function_parameter_types(ctx, function_id)?.skip(1) { + let ConcreteFunction { .. } = function_id.lookup_intern(ctx.db).function; + let closure_params = concrete_function_closure_params(ctx.db, function_id)?; + for ty in function_parameter_types(ctx, function_id)?.skip(1) { let Some(arg_syntax) = args_iter.next() else { break; }; - named_args.push(compute_named_argument_clause(ctx, arg_syntax)); + named_args.push(compute_named_argument_clause( + ctx, + arg_syntax, + closure_params.get(&ty).cloned(), + )); } // Maybe coupon if let Some(arg_syntax) = args_iter.next() { - named_args.push(compute_named_argument_clause(ctx, arg_syntax)); + named_args.push(compute_named_argument_clause(ctx, arg_syntax, None)); } expr_function_call(ctx, function_id, named_args, &expr, stable_ptr) @@ -3263,7 +3315,6 @@ fn expr_function_call( // Check argument names and types. check_named_arguments(&named_args, &signature, ctx)?; - let mut args = Vec::new(); for (NamedArg(arg, _name, mutability), param) in named_args.into_iter().zip(signature.params.iter()) diff --git a/crates/cairo-lang-semantic/src/expr/test_data/closure b/crates/cairo-lang-semantic/src/expr/test_data/closure index d50eb9e0af7..8f907f5ff3e 100644 --- a/crates/cairo-lang-semantic/src/expr/test_data/closure +++ b/crates/cairo-lang-semantic/src/expr/test_data/closure @@ -768,3 +768,23 @@ error: Closure parameters cannot be references --> lib.cairo:2:14 let _ = |ref a| { ^^^^^ + +//! > ========================================================================== + +//! > Passing closures as args with less explicit typing. + +//! > test_runner_name +test_function_diagnostics(expect_diagnostics: false) + +//! > function +fn foo() -> Option { + let x: Option> = Option::Some(array![1, 2, 3]); + x.map(|x| x.len()) +} + +//! > function_name +foo + +//! > module_code + +//! > expected_diagnostics diff --git a/crates/cairo-lang-semantic/src/items/functions.rs b/crates/cairo-lang-semantic/src/items/functions.rs index 65dc2ae0d3a..7d98f2f4fd6 100644 --- a/crates/cairo-lang-semantic/src/items/functions.rs +++ b/crates/cairo-lang-semantic/src/items/functions.rs @@ -1,4 +1,3 @@ -use std::fmt::Debug; use std::sync::Arc; use cairo_lang_debug::DebugWithDb; @@ -14,6 +13,7 @@ use cairo_lang_proc_macros::{DebugWithDb, SemanticObject}; use cairo_lang_syntax as syntax; use cairo_lang_syntax::attribute::structured::Attribute; use cairo_lang_syntax::node::{Terminal, TypedSyntaxNode, ast}; +use cairo_lang_utils::ordered_hash_map::OrderedHashMap; use cairo_lang_utils::{ Intern, LookupIntern, OptionFrom, define_short_id, require, try_extract_matches, }; @@ -27,7 +27,7 @@ use super::generics::{fmt_generic_args, generic_params_to_args}; use super::imp::{ImplId, ImplLongId}; use super::modifiers; use super::trt::ConcreteTraitGenericFunctionId; -use crate::corelib::{panic_destruct_trait_fn, unit_ty}; +use crate::corelib::{fn_traits, panic_destruct_trait_fn, unit_ty}; use crate::db::SemanticGroup; use crate::diagnostic::{SemanticDiagnosticKind, SemanticDiagnostics, SemanticDiagnosticsBuilder}; use crate::expr::compute::Environment; @@ -35,8 +35,8 @@ use crate::resolve::{Resolver, ResolverData}; use crate::substitution::{GenericSubstitution, SemanticRewriter, SubstitutionRewriter}; use crate::types::resolve_type; use crate::{ - ConcreteImplId, ConcreteImplLongId, ConcreteTraitLongId, GenericParam, SemanticDiagnostic, - TypeId, semantic, semantic_object_for_id, + ConcreteImplId, ConcreteImplLongId, ConcreteTraitLongId, GenericArgumentId, GenericParam, + SemanticDiagnostic, TypeId, semantic, semantic_object_for_id, }; /// A generic function of an impl. @@ -124,6 +124,36 @@ impl GenericFunctionId { } } } + + pub fn get_closure_params( + &self, + db: &dyn SemanticGroup, + ) -> Maybe> { + let mut closure_params_map = OrderedHashMap::default(); + let generic_params = self.generic_params(db)?; + + for param in generic_params { + if let GenericParam::Impl(generic_param_impl) = param { + let trait_id = generic_param_impl.concrete_trait?.trait_id(db); + + if fn_traits(db).contains(&trait_id) { + if let Ok(concrete_trait) = generic_param_impl.concrete_trait { + let [ + GenericArgumentId::Type(closure_type), + GenericArgumentId::Type(params_type), + ] = *concrete_trait.generic_args(db) + else { + unreachable!() + }; + + closure_params_map.insert(closure_type, params_type); + } + } + } + } + Ok(closure_params_map) + } + pub fn generic_signature(&self, db: &dyn SemanticGroup) -> Maybe { match *self { GenericFunctionId::Free(id) => db.free_function_signature(id), @@ -146,8 +176,11 @@ impl GenericFunctionId { GenericFunctionId::Extern(id) => db.extern_function_declaration_generic_params(id), GenericFunctionId::Impl(id) => { let concrete_trait_id = db.impl_concrete_trait(id.impl_id)?; - let id = ConcreteTraitGenericFunctionId::new(db, concrete_trait_id, id.function); - db.concrete_trait_function_generic_params(id) + let concrete_id = + ConcreteTraitGenericFunctionId::new(db, concrete_trait_id, id.function); + let substitution = GenericSubstitution::from_impl(id.impl_id); + let mut rewriter = SubstitutionRewriter { db, substitution: &substitution }; + rewriter.rewrite(db.concrete_trait_function_generic_params(concrete_id)?) } GenericFunctionId::Trait(id) => db.concrete_trait_function_generic_params(id), } @@ -860,6 +893,19 @@ pub fn concrete_function_signature( SubstitutionRewriter { db, substitution: &substitution }.rewrite(generic_signature) } +/// Query implementation of [crate::db::SemanticGroup::concrete_function_closure_params]. +pub fn concrete_function_closure_params( + db: &dyn SemanticGroup, + function_id: FunctionId, +) -> Maybe> { + let ConcreteFunction { generic_function, generic_args, .. } = + function_id.lookup_intern(db).function; + let generic_params = generic_function.generic_params(db)?; + let generic_closure_params = generic_function.get_closure_params(db)?; + let substitution = GenericSubstitution::new(&generic_params, &generic_args); + SubstitutionRewriter { db, substitution: &substitution }.rewrite(generic_closure_params) +} + /// For a given list of AST parameters, returns the list of semantic parameters along with the /// corresponding environment. fn update_env_with_ast_params(