@@ -4,7 +4,8 @@ use cairo_lang_plugins::plugins::HIDDEN_ATTR_SYNTAX;
44use cairo_lang_semantic:: keyword:: SELF_PARAM_KW ;
55use cairo_lang_syntax:: attribute:: consts:: IMPLICIT_PRECEDENCE_ATTR ;
66use cairo_lang_syntax:: node:: ast:: {
7- self , FunctionWithBody , OptionReturnTypeClause , OptionTypeClause , OptionWrappedGenericParamList ,
7+ self , Attribute , FunctionWithBody , OptionReturnTypeClause , OptionTypeClause ,
8+ OptionWrappedGenericParamList ,
89} ;
910use cairo_lang_syntax:: node:: helpers:: QueryAttrs ;
1011use cairo_lang_syntax:: node:: { Terminal , TypedStablePtr , TypedSyntaxNode } ;
@@ -18,7 +19,7 @@ use super::consts::{
1819 IMPLICIT_PRECEDENCE , L1_HANDLER_ATTR , L1_HANDLER_FIRST_PARAM_NAME , L1_HANDLER_MODULE ,
1920 RAW_OUTPUT_ATTR , WRAPPER_PREFIX ,
2021} ;
21- use super :: utils:: { AstPathExtract , ParamEx , has_v0_attribute , maybe_strip_underscore} ;
22+ use super :: utils:: { AstPathExtract , ParamEx , find_v0_attribute , maybe_strip_underscore} ;
2223
2324/// Kind of an entry point. Determined by the entry point's attributes.
2425#[ derive( Debug , Clone , Copy , PartialEq , Eq ) ]
@@ -27,43 +28,30 @@ pub enum EntryPointKind {
2728 Constructor ,
2829 L1Handler ,
2930}
30- impl EntryPointKind {
31- /// Returns the entry point kind if the given function is indeed marked as an entry point.
32- pub fn try_from_function_with_body < ' db > (
31+
32+ /// Helper trait for entry point kind extraction.
33+ pub trait GetEntryPointKind < ' db > {
34+ /// Returns the entry point kind and its trigger attribute if the attributes mark it as an entry
35+ /// point.
36+ fn entry_point_kind (
37+ & self ,
3338 db : & ' db dyn Database ,
3439 diagnostics : & mut Vec < PluginDiagnostic < ' db > > ,
35- item_function : & FunctionWithBody < ' db > ,
36- ) -> Option < Self > {
37- if has_v0_attribute (
38- db,
39- diagnostics,
40- & ast:: ModuleItem :: FreeFunction ( item_function. clone ( ) ) ,
41- EXTERNAL_ATTR ,
42- ) {
43- Some ( EntryPointKind :: External )
44- } else if item_function. has_attr ( db, CONSTRUCTOR_ATTR ) {
45- Some ( EntryPointKind :: Constructor )
46- } else if item_function. has_attr ( db, L1_HANDLER_ATTR ) {
47- Some ( EntryPointKind :: L1Handler )
48- } else {
49- None
50- }
51- }
40+ ) -> Option < ( EntryPointKind , Attribute < ' db > ) > ;
41+ }
5242
53- /// Returns the entry point kind if the attributes mark it as an entry point.
54- pub fn try_from_attrs < ' db > (
43+ impl < ' db > GetEntryPointKind < ' db > for FunctionWithBody < ' db > {
44+ fn entry_point_kind (
45+ & self ,
5546 db : & ' db dyn Database ,
5647 diagnostics : & mut Vec < PluginDiagnostic < ' db > > ,
57- attrs : & impl QueryAttrs < ' db > ,
58- ) -> Option < Self > {
59- if has_v0_attribute ( db, diagnostics, attrs, EXTERNAL_ATTR ) {
60- Some ( EntryPointKind :: External )
61- } else if attrs. has_attr ( db, CONSTRUCTOR_ATTR ) {
62- Some ( EntryPointKind :: Constructor )
63- } else if attrs. has_attr ( db, L1_HANDLER_ATTR ) {
64- Some ( EntryPointKind :: L1Handler )
48+ ) -> Option < ( EntryPointKind , Attribute < ' db > ) > {
49+ if let Some ( trigger) = find_v0_attribute ( db, diagnostics, self , EXTERNAL_ATTR ) {
50+ Some ( ( EntryPointKind :: External , trigger) )
51+ } else if let Some ( trigger) = self . find_attr ( db, CONSTRUCTOR_ATTR ) {
52+ Some ( ( EntryPointKind :: Constructor , trigger) )
6553 } else {
66- None
54+ self . find_attr ( db , L1_HANDLER_ATTR ) . map ( |trigger| ( EntryPointKind :: L1Handler , trigger ) )
6755 }
6856 }
6957}
@@ -125,6 +113,7 @@ fn generate_submodule<'db>(
125113
126114/// Parameters for generating an entry point, used when calling `handle_entry_point`.
127115pub struct EntryPointGenerationParams < ' db , ' a > {
116+ pub trigger_attribute : Attribute < ' db > ,
128117 pub entry_point_kind : EntryPointKind ,
129118 pub item_function : & ' a FunctionWithBody < ' db > ,
130119 pub wrapped_function_path : RewriteNode < ' db > ,
@@ -137,6 +126,7 @@ pub struct EntryPointGenerationParams<'db, 'a> {
137126pub fn handle_entry_point < ' db , ' a > (
138127 db : & ' db dyn Database ,
139128 EntryPointGenerationParams {
129+ trigger_attribute,
140130 entry_point_kind,
141131 item_function,
142132 wrapped_function_path,
@@ -195,7 +185,8 @@ pub fn handle_entry_point<'db, 'a>(
195185 unsafe_new_contract_state_prefix,
196186 ) {
197187 Ok ( generated_function) => {
198- data. generated_wrapper_functions . push ( generated_function) ;
188+ data. generated_wrapper_functions
189+ . push ( generated_function. mapped ( db, & trigger_attribute) ) ;
199190 data. generated_wrapper_functions . push ( RewriteNode :: text ( "\n " ) ) ;
200191 let generated = match entry_point_kind {
201192 EntryPointKind :: Constructor => & mut data. constructor_functions ,
@@ -205,14 +196,17 @@ pub fn handle_entry_point<'db, 'a>(
205196 }
206197 EntryPointKind :: External => & mut data. external_functions ,
207198 } ;
208- generated. push ( RewriteNode :: interpolate_patched (
209- "\n pub use super::$wrapper_function_name$ as $function_name$;" ,
210- & [
211- ( "wrapper_function_name" . into ( ) , wrapper_function_name) ,
212- ( "function_name" . into ( ) , function_name) ,
213- ]
214- . into ( ) ,
215- ) ) ;
199+ generated. push (
200+ RewriteNode :: interpolate_patched (
201+ "\n pub use super::$wrapper_function_name$ as $function_name$;" ,
202+ & [
203+ ( "wrapper_function_name" . into ( ) , wrapper_function_name) ,
204+ ( "function_name" . into ( ) , function_name) ,
205+ ]
206+ . into ( ) ,
207+ )
208+ . mapped ( db, & trigger_attribute) ,
209+ ) ;
216210 }
217211 Err ( entry_point_diagnostics) => {
218212 diagnostics. extend ( entry_point_diagnostics) ;
@@ -374,7 +368,7 @@ fn generate_entry_point_wrapper<'db>(
374368 ( "implicit_precedence" . to_string ( ) , implicit_precedence) ,
375369 ]
376370 . into ( ) ,
377- ) . mapped ( db , function ) )
371+ ) )
378372}
379373
380374/// Validates the second parameter of an L1 handler is `from_address: felt252` or `_from_address:
0 commit comments