diff --git a/crates/cairo-lang-starknet/src/plugin/embeddable.rs b/crates/cairo-lang-starknet/src/plugin/embeddable.rs index 2326bcee493..45658dcd6fb 100644 --- a/crates/cairo-lang-starknet/src/plugin/embeddable.rs +++ b/crates/cairo-lang-starknet/src/plugin/embeddable.rs @@ -140,6 +140,7 @@ pub fn handle_embeddable<'db>( handle_entry_point( db, EntryPointGenerationParams { + trigger_attribute: embeddable_attr.clone(), entry_point_kind: EntryPointKind::External, item_function: &item_function, wrapped_function_path: function_path, diff --git a/crates/cairo-lang-starknet/src/plugin/entry_point.rs b/crates/cairo-lang-starknet/src/plugin/entry_point.rs index c8feda73e1d..551adda5e8d 100644 --- a/crates/cairo-lang-starknet/src/plugin/entry_point.rs +++ b/crates/cairo-lang-starknet/src/plugin/entry_point.rs @@ -4,7 +4,8 @@ use cairo_lang_plugins::plugins::HIDDEN_ATTR_SYNTAX; use cairo_lang_semantic::keyword::SELF_PARAM_KW; use cairo_lang_syntax::attribute::consts::IMPLICIT_PRECEDENCE_ATTR; use cairo_lang_syntax::node::ast::{ - self, FunctionWithBody, OptionReturnTypeClause, OptionTypeClause, OptionWrappedGenericParamList, + self, Attribute, FunctionWithBody, OptionReturnTypeClause, OptionTypeClause, + OptionWrappedGenericParamList, }; use cairo_lang_syntax::node::helpers::QueryAttrs; use cairo_lang_syntax::node::{Terminal, TypedStablePtr, TypedSyntaxNode}; @@ -18,7 +19,7 @@ use super::consts::{ IMPLICIT_PRECEDENCE, L1_HANDLER_ATTR, L1_HANDLER_FIRST_PARAM_NAME, L1_HANDLER_MODULE, RAW_OUTPUT_ATTR, WRAPPER_PREFIX, }; -use super::utils::{AstPathExtract, ParamEx, has_v0_attribute, maybe_strip_underscore}; +use super::utils::{AstPathExtract, ParamEx, find_v0_attribute, maybe_strip_underscore}; /// Kind of an entry point. Determined by the entry point's attributes. #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -27,43 +28,30 @@ pub enum EntryPointKind { Constructor, L1Handler, } -impl EntryPointKind { - /// Returns the entry point kind if the given function is indeed marked as an entry point. - pub fn try_from_function_with_body<'db>( + +/// Helper trait for entry point kind extraction. +pub trait GetEntryPointKind<'db> { + /// Returns the entry point kind and its trigger attribute if the attributes mark it as an entry + /// point. + fn entry_point_kind( + &self, db: &'db dyn Database, diagnostics: &mut Vec>, - item_function: &FunctionWithBody<'db>, - ) -> Option { - if has_v0_attribute( - db, - diagnostics, - &ast::ModuleItem::FreeFunction(item_function.clone()), - EXTERNAL_ATTR, - ) { - Some(EntryPointKind::External) - } else if item_function.has_attr(db, CONSTRUCTOR_ATTR) { - Some(EntryPointKind::Constructor) - } else if item_function.has_attr(db, L1_HANDLER_ATTR) { - Some(EntryPointKind::L1Handler) - } else { - None - } - } + ) -> Option<(EntryPointKind, Attribute<'db>)>; +} - /// Returns the entry point kind if the attributes mark it as an entry point. - pub fn try_from_attrs<'db>( +impl<'db> GetEntryPointKind<'db> for FunctionWithBody<'db> { + fn entry_point_kind( + &self, db: &'db dyn Database, diagnostics: &mut Vec>, - attrs: &impl QueryAttrs<'db>, - ) -> Option { - if has_v0_attribute(db, diagnostics, attrs, EXTERNAL_ATTR) { - Some(EntryPointKind::External) - } else if attrs.has_attr(db, CONSTRUCTOR_ATTR) { - Some(EntryPointKind::Constructor) - } else if attrs.has_attr(db, L1_HANDLER_ATTR) { - Some(EntryPointKind::L1Handler) + ) -> Option<(EntryPointKind, Attribute<'db>)> { + if let Some(trigger) = find_v0_attribute(db, diagnostics, self, EXTERNAL_ATTR) { + Some((EntryPointKind::External, trigger)) + } else if let Some(trigger) = self.find_attr(db, CONSTRUCTOR_ATTR) { + Some((EntryPointKind::Constructor, trigger)) } else { - None + self.find_attr(db, L1_HANDLER_ATTR).map(|trigger| (EntryPointKind::L1Handler, trigger)) } } } @@ -125,6 +113,7 @@ fn generate_submodule<'db>( /// Parameters for generating an entry point, used when calling `handle_entry_point`. pub struct EntryPointGenerationParams<'db, 'a> { + pub trigger_attribute: Attribute<'db>, pub entry_point_kind: EntryPointKind, pub item_function: &'a FunctionWithBody<'db>, pub wrapped_function_path: RewriteNode<'db>, @@ -137,6 +126,7 @@ pub struct EntryPointGenerationParams<'db, 'a> { pub fn handle_entry_point<'db, 'a>( db: &'db dyn Database, EntryPointGenerationParams { + trigger_attribute, entry_point_kind, item_function, wrapped_function_path, @@ -195,7 +185,8 @@ pub fn handle_entry_point<'db, 'a>( unsafe_new_contract_state_prefix, ) { Ok(generated_function) => { - data.generated_wrapper_functions.push(generated_function); + data.generated_wrapper_functions + .push(generated_function.mapped(db, &trigger_attribute)); data.generated_wrapper_functions.push(RewriteNode::text("\n")); let generated = match entry_point_kind { EntryPointKind::Constructor => &mut data.constructor_functions, @@ -205,14 +196,17 @@ pub fn handle_entry_point<'db, 'a>( } EntryPointKind::External => &mut data.external_functions, }; - generated.push(RewriteNode::interpolate_patched( - "\n pub use super::$wrapper_function_name$ as $function_name$;", - &[ - ("wrapper_function_name".into(), wrapper_function_name), - ("function_name".into(), function_name), - ] - .into(), - )); + generated.push( + RewriteNode::interpolate_patched( + "\n pub use super::$wrapper_function_name$ as $function_name$;", + &[ + ("wrapper_function_name".into(), wrapper_function_name), + ("function_name".into(), function_name), + ] + .into(), + ) + .mapped(db, &trigger_attribute), + ); } Err(entry_point_diagnostics) => { diagnostics.extend(entry_point_diagnostics); @@ -374,7 +368,7 @@ fn generate_entry_point_wrapper<'db>( ("implicit_precedence".to_string(), implicit_precedence), ] .into(), - ).mapped(db, function)) + )) } /// Validates the second parameter of an L1 handler is `from_address: felt252` or `_from_address: diff --git a/crates/cairo-lang-starknet/src/plugin/plugin_test_data/contracts/diagnostics b/crates/cairo-lang-starknet/src/plugin/plugin_test_data/contracts/diagnostics index 775a665aa0b..bebbfd54181 100644 --- a/crates/cairo-lang-starknet/src/plugin/plugin_test_data/contracts/diagnostics +++ b/crates/cairo-lang-starknet/src/plugin/plugin_test_data/contracts/diagnostics @@ -940,11 +940,9 @@ impl StorageStorageBaseMutCopy<> of core::traits::Copy::; //! > expected_diagnostics error: Trait has no implementation in context: core::serde::Serde::. - --> lib.cairo:5:5-6:56 - #[external(v0)] - _____^ -| fn foo(ref self: ContractState, x: super::MyType) {} -|________________________________________________________^ + --> lib.cairo:5:5 + #[external(v0)] + ^^^^^^^^^^^^^^^ //! > ========================================================================== @@ -1382,18 +1380,14 @@ warning: Plugin diagnostic: Failed to generate ABI: Got unexpected type. ^^^^^^^^^^^^^^^^^^^^^ error[E0006]: Type not found. - --> lib.cairo:5:5-6:47 - #[external(v0)] - _____^ -| fn foo(ref self: ContractState, x: T) {} -|_______________________________________________^ + --> lib.cairo:5:5 + #[external(v0)] + ^^^^^^^^^^^^^^^ error: Trait has no implementation in context: core::serde::Serde::<>. - --> lib.cairo:5:5-6:47 - #[external(v0)] - _____^ -| fn foo(ref self: ContractState, x: T) {} -|_______________________________________________^ + --> lib.cairo:5:5 + #[external(v0)] + ^^^^^^^^^^^^^^^ error: Variable not dropped. --> lib.cairo:6:40 @@ -17790,18 +17784,14 @@ error[E0006]: Type not found. ^^^^^^^ error[E0006]: Type not found. - --> lib.cairo:6:5-7:54 - #[external(v0)] - _____^ -| fn foo(ref self: ContractState, value: BadType) {} -|______________________________________________________^ + --> lib.cairo:6:5 + #[external(v0)] + ^^^^^^^^^^^^^^^ error: Trait has no implementation in context: core::serde::Serde::<>. - --> lib.cairo:6:5-7:54 - #[external(v0)] - _____^ -| fn foo(ref self: ContractState, value: BadType) {} -|______________________________________________________^ + --> lib.cairo:6:5 + #[external(v0)] + ^^^^^^^^^^^^^^^ //! > generated_cairo_code lib.cairo: @@ -18031,9 +18021,9 @@ error: Parameter type of impl function `Impl::foo` is incompatible with `Interfa ^^^^^^^ error: Unexpected argument type. Expected: "core::integer::u256", found: "core::felt252". - --> lib.cairo:15:9 - fn foo(ref self: ContractState, value: felt252) {} - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + --> lib.cairo:13:5 + #[abi(embed_v0)] + ^^^^^^^^^^^^^^^^ //! > generated_cairo_code lib.cairo: diff --git a/crates/cairo-lang-starknet/src/plugin/starknet_module/contract.rs b/crates/cairo-lang-starknet/src/plugin/starknet_module/contract.rs index 4271f9c85c2..f06f05b14a6 100644 --- a/crates/cairo-lang-starknet/src/plugin/starknet_module/contract.rs +++ b/crates/cairo-lang-starknet/src/plugin/starknet_module/contract.rs @@ -5,7 +5,7 @@ use cairo_lang_filesystem::ids::SmolStrId; use cairo_lang_parser::macro_helpers::AsLegacyInlineMacro; use cairo_lang_plugins::plugins::HasItemsInCfgEx; use cairo_lang_starknet_classes::keccak::starknet_keccak; -use cairo_lang_syntax::node::ast::OptionTypeClause; +use cairo_lang_syntax::node::ast::{Attribute, OptionTypeClause}; use cairo_lang_syntax::node::helpers::{ BodyItems, GetIdentifier, PathSegmentEx, QueryAttrs, is_single_arg_attr, }; @@ -25,10 +25,11 @@ use crate::plugin::consts::{ EXTERNAL_ATTR, HAS_COMPONENT_TRAIT, STORAGE_STRUCT_NAME, SUBSTORAGE_ATTR, }; use crate::plugin::entry_point::{ - EntryPointGenerationParams, EntryPointKind, EntryPointsGenerationData, handle_entry_point, + EntryPointGenerationParams, EntryPointKind, EntryPointsGenerationData, GetEntryPointKind, + handle_entry_point, }; use crate::plugin::storage::handle_storage_struct; -use crate::plugin::utils::{forbid_attributes_in_impl, has_v0_attribute_ex}; +use crate::plugin::utils::{find_v0_attribute_ex, forbid_attributes_in_impl}; /// Accumulated data specific for contract generation. #[derive(Default)] @@ -369,8 +370,7 @@ fn generate_constructor_deploy_function<'db>( for item in body.iter_items(db) { if let ast::ModuleItem::FreeFunction(func) = item - && let Some(EntryPointKind::Constructor) = - EntryPointKind::try_from_function_with_body(db, diagnostics, &func) + && let Some((EntryPointKind::Constructor, _)) = func.entry_point_kind(db, diagnostics) { let signature_params = func.declaration(db).signature(db).parameters(db); let params = signature_params.elements(db); @@ -434,31 +434,6 @@ fn generate_deploy_function<'db>( )) } -/// Handles a contract entrypoint function. -fn handle_contract_entry_point<'db>( - entry_point_kind: EntryPointKind, - item_function: &ast::FunctionWithBody<'db>, - wrapped_function_path: RewriteNode<'db>, - wrapper_identifier: String, - db: &'db dyn Database, - diagnostics: &mut Vec>, - data: &mut EntryPointsGenerationData<'db>, -) { - handle_entry_point( - db, - EntryPointGenerationParams { - entry_point_kind, - item_function, - wrapped_function_path, - wrapper_identifier, - unsafe_new_contract_state_prefix: "", - generic_params: RewriteNode::empty(), - }, - diagnostics, - data, - ) -} - /// Handles a free function inside a contract module. fn handle_contract_free_function<'db>( db: &'db dyn Database, @@ -466,19 +441,24 @@ fn handle_contract_free_function<'db>( item_function: &ast::FunctionWithBody<'db>, data: &mut EntryPointsGenerationData<'db>, ) { - let Some(entry_point_kind) = - EntryPointKind::try_from_function_with_body(db, diagnostics, item_function) + let Some((entry_point_kind, trigger_attribute)) = + item_function.entry_point_kind(db, diagnostics) else { return; }; let function_name = item_function.declaration(db).name(db); - let function_name_node = RewriteNode::from_ast_trimmed(&function_name); - handle_contract_entry_point( - entry_point_kind, - item_function, - function_name_node, - function_name.text(db).to_string(db), + + handle_entry_point( db, + EntryPointGenerationParams { + trigger_attribute, + entry_point_kind, + item_function, + wrapped_function_path: RewriteNode::from_ast_trimmed(&function_name), + wrapper_identifier: function_name.text(db).to_string(db), + unsafe_new_contract_state_prefix: "", + generic_params: RewriteNode::empty(), + }, diagnostics, data, ); @@ -492,10 +472,9 @@ fn handle_contract_impl<'db, 'a>( metadata: &'a MacroPluginMetadata<'a>, data: &mut EntryPointsGenerationData<'db>, ) { - let abi_config = impl_abi_config(db, diagnostics, imp); - if abi_config == ImplAbiConfig::None { + let Some((abi_config, abi_attr)) = impl_abi_config(db, diagnostics, imp) else { return; - } + }; let ast::MaybeImplBody::Some(impl_body) = imp.body(db) else { return; }; @@ -511,16 +490,14 @@ fn handle_contract_impl<'db, 'a>( let ast::ImplItem::Function(item_function) = item else { continue; }; - let entry_point_kind = if abi_config == ImplAbiConfig::PerItem { - let Some(entry_point_kind) = - EntryPointKind::try_from_attrs(db, diagnostics, &item_function) - else { + let (entry_point_kind, trigger_attribute) = if abi_config == ImplAbiConfig::PerItem { + let Some((kind, trigger)) = item_function.entry_point_kind(db, diagnostics) else { continue; }; - entry_point_kind + (kind, trigger) } else { // matches!(abi_config, ImplAbiConfig::Embed | ImplAbiConfig::External) - EntryPointKind::External + (EntryPointKind::External, abi_attr.clone()) }; let function_name = item_function.declaration(db).name(db); let function_name_node = RewriteNode::interpolate_patched( @@ -533,12 +510,17 @@ fn handle_contract_impl<'db, 'a>( ); let wrapper_identifier = format!("{}__{}", impl_name.text(db).long(db), function_name.text(db).long(db)); - handle_contract_entry_point( - entry_point_kind, - &item_function, - function_name_node, - wrapper_identifier, + handle_entry_point( db, + EntryPointGenerationParams { + trigger_attribute, + entry_point_kind, + item_function: &item_function, + wrapped_function_path: function_name_node, + wrapper_identifier, + unsafe_new_contract_state_prefix: "", + generic_params: RewriteNode::empty(), + }, diagnostics, data, ); @@ -548,8 +530,6 @@ fn handle_contract_impl<'db, 'a>( /// The configuration of an impl addition to the abi. #[derive(PartialEq, Eq)] enum ImplAbiConfig { - /// No ABI configuration. - None, /// The impl is marked with `#[abi(per_item)]`. Each item should provide its own configuration. PerItem, /// The impl is marked with `#[abi(embed_v0)]`. The entire impl and the interface are embedded @@ -561,17 +541,18 @@ enum ImplAbiConfig { } /// Returns the configuration of an impl addition to the abi using `#[abi(...)]` or the old -/// equivalent `#[external(v0)]`. +/// equivalent `#[external(v0)]`, as well as the actual attribute. +/// If none exists - returns None. fn impl_abi_config<'db>( db: &'db dyn Database, diagnostics: &mut Vec>, imp: &ast::ItemImpl<'db>, -) -> ImplAbiConfig { +) -> Option<(ImplAbiConfig, Attribute<'db>)> { if let Some(abi_attr) = imp.find_attr(db, ABI_ATTR) { if is_single_arg_attr(db, &abi_attr, ABI_ATTR_PER_ITEM_ARG) { - ImplAbiConfig::PerItem + Some((ImplAbiConfig::PerItem, abi_attr)) } else if is_single_arg_attr(db, &abi_attr, ABI_ATTR_EMBED_V0_ARG) { - ImplAbiConfig::Embed + Some((ImplAbiConfig::Embed, abi_attr)) } else { diagnostics.push(PluginDiagnostic::error( abi_attr.stable_ptr(db), @@ -580,17 +561,16 @@ fn impl_abi_config<'db>( '{ABI_ATTR_PER_ITEM_ARG}' or '{ABI_ATTR_EMBED_V0_ARG}' argument.", ), )); - ImplAbiConfig::None + None } - } else if has_v0_attribute_ex(db, diagnostics, imp, EXTERNAL_ATTR, || { - Some(format!( - "The '{EXTERNAL_ATTR}' attribute on impls is deprecated. Use \ - '{ABI_ATTR}({ABI_ATTR_PER_ITEM_ARG})' or '{ABI_ATTR}({ABI_ATTR_EMBED_V0_ARG})'." - )) - }) { - ImplAbiConfig::External } else { - ImplAbiConfig::None + let attr = find_v0_attribute_ex(db, diagnostics, imp, EXTERNAL_ATTR, || { + Some(format!( + "The '{EXTERNAL_ATTR}' attribute on impls is deprecated. Use \ + '{ABI_ATTR}({ABI_ATTR_PER_ITEM_ARG})' or '{ABI_ATTR}({ABI_ATTR_EMBED_V0_ARG})'." + )) + })?; + Some((ImplAbiConfig::External, attr)) } } diff --git a/crates/cairo-lang-starknet/src/plugin/utils.rs b/crates/cairo-lang-starknet/src/plugin/utils.rs index ec681a81762..a4720d3a6d2 100644 --- a/crates/cairo-lang-starknet/src/plugin/utils.rs +++ b/crates/cairo-lang-starknet/src/plugin/utils.rs @@ -153,34 +153,32 @@ pub fn maybe_strip_underscore(s: &str) -> &str { // === Attributes utilities === -/// Checks if the given (possibly-attributed-)object is attributed with the given `attr_name`. Also -/// validates that the attribute is v0. -pub fn has_v0_attribute<'db>( +/// Finds the attributes with the name `attr_name`. If exists also validates that the attribute is +/// v0. +pub fn find_v0_attribute<'db>( db: &'db dyn Database, diagnostics: &mut Vec>, object: &impl QueryAttrs<'db>, attr_name: &'db str, -) -> bool { - has_v0_attribute_ex(db, diagnostics, object, attr_name, || None) +) -> Option> { + find_v0_attribute_ex(db, diagnostics, object, attr_name, || None) } -/// Checks if the given (possibly-attributed-)object is attributed with the given `attr_name`. Also -/// validates that the attribute is v0, and adds a warning if supplied `deprecated` returns a value. -pub fn has_v0_attribute_ex<'db>( +/// Finds the attributes with the name `attr_name`. If exists also validates that the attribute is +/// v0, and adds a warning if supplied `deprecated` returns a value. +pub fn find_v0_attribute_ex<'db>( db: &'db dyn Database, diagnostics: &mut Vec>, object: &impl QueryAttrs<'db>, attr_name: &'db str, deprecated: impl FnOnce() -> Option, -) -> bool { - let Some(attr) = object.find_attr(db, attr_name) else { - return false; - }; +) -> Option> { + let attr = object.find_attr(db, attr_name)?; validate_v0(db, diagnostics, &attr, attr_name); if let Some(deprecated) = deprecated() { diagnostics.push(PluginDiagnostic::warning(attr.stable_ptr(db), deprecated)); } - true + Some(attr) } /// Assuming the attribute is `name`, validates it's in the form "#[name(v0)]".