diff --git a/rust/rubydex/src/diagnostic.rs b/rust/rubydex/src/diagnostic.rs index bdd3fe8c..eddd8599 100644 --- a/rust/rubydex/src/diagnostic.rs +++ b/rust/rubydex/src/diagnostic.rs @@ -103,6 +103,8 @@ rules! { DynamicSingletonDefinition; DynamicAncestor; TopLevelMixinSelf; + InvalidConstantName; + DynamicConstantField; // Resolution } diff --git a/rust/rubydex/src/indexing/ruby_indexer.rs b/rust/rubydex/src/indexing/ruby_indexer.rs index 4a4d5e74..a2b11f9c 100644 --- a/rust/rubydex/src/indexing/ruby_indexer.rs +++ b/rust/rubydex/src/indexing/ruby_indexer.rs @@ -572,14 +572,15 @@ impl<'a> RubyIndexer<'a> { Some(definition_id) } - fn handle_class_definition( + fn handle_class_definition( &mut self, location: &ruby_prism::Location, name_node: Option<&ruby_prism::Node>, - body_node: Option, superclass_node: Option, - nesting_type: fn(DefinitionId) -> Nesting, - ) { + process_body: F, + ) where + F: FnOnce(&mut Self, DefinitionId), + { let offset = Offset::from_prism_location(location); let (comments, flags) = self.find_comments_for(offset.start()); let lexical_nesting_id = self.parent_lexical_scope_id(); @@ -625,16 +626,9 @@ impl<'a> RubyIndexer<'a> { ))); let definition_id = self.local_graph.add_definition(definition); - self.add_member_to_current_lexical_scope(definition_id); - if let Some(body) = body_node { - self.nesting_stack.push(nesting_type(definition_id)); - self.visibility_stack.push(Visibility::Public); - self.visit(&body); - self.visibility_stack.pop(); - self.nesting_stack.pop(); - } + process_body(self, definition_id); } } @@ -675,9 +669,7 @@ impl<'a> RubyIndexer<'a> { if let Some(body) = body_node { self.nesting_stack.push(nesting_type(definition_id)); - self.visibility_stack.push(Visibility::Public); - self.visit(&body); - self.visibility_stack.pop(); + self.visit_with_new_visibility(&body); self.nesting_stack.pop(); } } @@ -710,16 +702,225 @@ impl<'a> RubyIndexer<'a> { self.handle_class_definition( &node.location(), Some(node), - call_node.block(), call_node.arguments().and_then(|args| args.arguments().iter().next()), - Nesting::Owner, + |indexer, definition_id| { + indexer.nesting_stack.push(Nesting::Owner(definition_id)); + + if let Some(body) = call_node.block() { + indexer.visit_with_new_visibility(&body); + } + + indexer.nesting_stack.pop(); + }, ); return true; } + // Handle `Struct.new` + if receiver_name == b"Struct" || receiver_name == b"::Struct" { + // Create the class + self.handle_class_definition( + &node.location(), + Some(node), + Some(receiver), + |indexer, definition_id| { + indexer.nesting_stack.push(Nesting::Owner(definition_id)); + + // From the arguments, we need to create attr accessors, instance variables and an initialize method + indexer.handle_struct_fields(&call_node); + + if let Some(body) = call_node.block() { + indexer.visit_with_new_visibility(&body); + } + + indexer.nesting_stack.pop(); + }, + ); + + return true; + } + false } + /// Handles the arguments to `Struct.new` and creates the appropriate definitions + fn handle_struct_fields(&mut self, call_node: &ruby_prism::CallNode) { + let mut parameters = Vec::new(); + let parent_nesting_id = self.parent_nesting_id(); + + // If there are no arguments, we still need to create `initialize` accepting no parameters. Otherwise, we have the following combinations: + // - First argument might be a capitalized string, in which case a new constant alias is created. Otherwise, it + // must be a symbol + // - Regular arguments for fields are either strings or symbols + // - The `keyword_init: true` option changes the `initialize` method to use keyword arguments + if let Some(arg_node) = call_node.arguments() { + let mut arg_iter = arg_node.arguments().iter().peekable(); + + // If the first argument is a string node, then it _needs_ to be a valid constant name and we need to create an + // alias for this struct under the `Struct` namespace + if let Some(first_node) = arg_iter.peek() + && let Some(alias_name) = first_node.as_string_node() + { + self.handle_struct_alias(first_node, &alias_name); + // Consume the first argument since it matched + arg_iter.next(); + } + + // Collect argument names + let mut fields = Vec::new(); + let mut keyword_init = false; + + for argument in arg_iter { + match argument { + ruby_prism::Node::StringNode { .. } => { + let string_node = argument.as_string_node().unwrap(); + let field_name = String::from_utf8_lossy(string_node.unescaped()).to_string(); + fields.push((field_name, argument.location())); + } + ruby_prism::Node::SymbolNode { .. } => { + let symbol_node = argument.as_symbol_node().unwrap(); + + if let Some(value_loc) = symbol_node.value_loc() { + let field_name = Self::location_to_string(&value_loc); + fields.push((field_name, value_loc)); + } + } + ruby_prism::Node::KeywordHashNode { .. } => { + let hash_node = argument.as_keyword_hash_node().unwrap(); + + for element in &hash_node.elements() { + if let Some(assoc_node) = element.as_assoc_node() + && let Some(symbol_key) = assoc_node.key().as_symbol_node() + && let Some(symbol_value) = symbol_key.value_loc() + && symbol_value.as_slice() == b"keyword_init" + { + keyword_init = matches!(assoc_node.value(), ruby_prism::Node::TrueNode { .. }); + } + } + } + _ => { + // Dynamic argument + let offset = Offset::from_prism_location(&argument.location()); + + self.local_graph.add_diagnostic( + Rule::DynamicConstantField, + offset, + "Struct arguments that aren't string or symbol literals will be ignored".into(), + ); + } + } + } + + // For each field name, create an instance variable and an attr_accessor. Also, track parameters to create + // the `initialize` method below + for (name, loc) in fields { + let ivar_name = format!("@{name}"); + let str_id = self.local_graph.intern_string(name); + let offset = Offset::from_prism_location(&loc); + + if keyword_init { + parameters.push(Parameter::OptionalKeyword(ParameterStruct::new(offset.clone(), str_id))); + } else { + parameters.push(Parameter::OptionalPositional(ParameterStruct::new( + offset.clone(), + str_id, + ))); + } + + let attr_def = Definition::AttrAccessor(Box::new(AttrAccessorDefinition::new( + str_id, + self.uri_id, + offset.clone(), + Vec::new(), + DefinitionFlags::empty(), + parent_nesting_id, + Visibility::Public, + ))); + + let attr_def_id = self.local_graph.add_definition(attr_def); + self.add_member_to_current_owner(attr_def_id); + + let str_id = self.local_graph.intern_string(ivar_name); + let ivar_def = Definition::InstanceVariable(Box::new(InstanceVariableDefinition::new( + str_id, + self.uri_id, + offset, + Vec::new(), + DefinitionFlags::empty(), + parent_nesting_id, + ))); + + let ivar_def_id = self.local_graph.add_definition(ivar_def); + self.add_member_to_current_owner(ivar_def_id); + } + } + + // Create the `initialize` method based on the field names and the `keyword_init` option + let str_id = self.local_graph.intern_string("initialize".into()); + let offset = Offset::from_prism_location(&call_node.location()); + + let initialize = Definition::Method(Box::new(MethodDefinition::new( + str_id, + self.uri_id, + offset, + Vec::new(), + DefinitionFlags::empty(), + parent_nesting_id, + parameters, + Visibility::Private, + None, + ))); + + let definition_id = self.local_graph.add_definition(initialize); + self.add_member_to_current_owner(definition_id); + } + + fn handle_struct_alias(&mut self, first_node: &ruby_prism::Node, alias_name: &ruby_prism::StringNode) { + let offset = Offset::from_prism_location(&first_node.location()); + + // Check if the first character is an uppercase letter + let name_as_bytes = alias_name.unescaped(); + + if name_as_bytes.first().is_some_and(u8::is_ascii_uppercase) { + // Create alias under the `Struct` namespace + let lexical_nesting_id = self.parent_lexical_scope_id(); + let constant_name_id = self.current_owner_name_id().unwrap(); + let struct_str_id = self.local_graph.intern_string("Struct".into()); + let struct_name = Name::new(struct_str_id, None, None); + let struct_name_id = self.local_graph.add_name(struct_name); + + let constant_name = self.local_graph.names().get(&constant_name_id).unwrap(); + let alias_name = Name::new(*constant_name.str(), Some(struct_name_id), *constant_name.nesting()); + + let alias_constant = ConstantDefinition::new( + self.local_graph.add_name(alias_name), + self.uri_id, + offset, + Vec::new(), + DefinitionFlags::empty(), + lexical_nesting_id, + ); + let definition = + Definition::ConstantAlias(Box::new(ConstantAliasDefinition::new(constant_name_id, alias_constant))); + + let definition_id = self.local_graph.add_definition(definition); + self.add_member_to_current_owner(definition_id); + } else { + // First argument is a string, but not a valid constant name. This will crash in the runtime + self.local_graph.add_diagnostic( + Rule::InvalidConstantName, + offset, + "When the first argument to Struct.new is a string, it must be a valid constant name".into(), + ); + } + } + + fn visit_with_new_visibility(&mut self, node: &ruby_prism::Node) { + self.visibility_stack.push(Visibility::Public); + self.visit(node); + self.visibility_stack.pop(); + } + /// Returns the definition ID of the current nesting (class, module, or singleton class), /// but skips methods in the definitions stack. fn current_nesting_definition_id(&self) -> Option { @@ -1009,9 +1210,16 @@ impl Visit<'_> for RubyIndexer<'_> { self.handle_class_definition( &node.location(), Some(&node.constant_path()), - node.body(), node.superclass(), - Nesting::LexicalScope, + |indexer, definition_id| { + indexer.nesting_stack.push(Nesting::LexicalScope(definition_id)); + + if let Some(body) = node.body() { + indexer.visit_with_new_visibility(&body); + } + + indexer.nesting_stack.pop(); + }, ); } @@ -1091,9 +1299,7 @@ impl Visit<'_> for RubyIndexer<'_> { if let Some(body) = node.body() { self.nesting_stack.push(Nesting::LexicalScope(definition_id)); - self.visibility_stack.push(Visibility::Public); - self.visit(&body); - self.visibility_stack.pop(); + self.visit_with_new_visibility(&body); self.nesting_stack.pop(); } } @@ -1511,9 +1717,16 @@ impl Visit<'_> for RubyIndexer<'_> { self.handle_class_definition( &node.location(), None, - node.block(), node.arguments().and_then(|args| args.arguments().iter().next()), - Nesting::Owner, + |indexer, definition_id| { + indexer.nesting_stack.push(Nesting::Owner(definition_id)); + + if let Some(body) = node.block() { + indexer.visit_with_new_visibility(&body); + } + + indexer.nesting_stack.pop(); + }, ); return; } @@ -5310,4 +5523,134 @@ mod tests { assert_name_path_eq!(&context, "ALIAS1", def.target_name_id()); }); } + + #[test] + fn index_structs() { + let context = index_source({ + " + Foo = Struct.new(:bar) do + include Qux + def self.baz; end + end + " + }); + assert_no_diagnostics!(&context); + + assert_definition_at!(&context, "1:1-4:4", Class, |foo| { + assert_mixins_eq!(&context, foo, Include, vec!["Qux"]); + + let definitions = context.all_definitions_at("1:19-1:22"); + let Definition::InstanceVariable(bar_ivar) = &definitions[0] else { + panic!("Expected InstanceVariable definition for bar"); + }; + assert_name_eq!(&context, "@bar", bar_ivar); + assert_eq!(foo.id(), bar_ivar.lexical_nesting_id().unwrap()); + + let Definition::AttrAccessor(bar_attribute) = &definitions[1] else { + panic!("Expected Method definition for bar attribute"); + }; + assert_name_eq!(&context, "bar", bar_attribute); + assert_eq!(foo.id(), bar_attribute.lexical_nesting_id().unwrap()); + + assert_definition_at!(&context, "1:7-4:4", Method, |initialize| { + assert_name_eq!(&context, "initialize", initialize); + assert_eq!(foo.id(), initialize.lexical_nesting_id().unwrap()); + assert_eq!(1, initialize.parameters().len()); + assert_eq!(&Visibility::Private, initialize.visibility()); + + assert_parameter!(&initialize.parameters()[0], OptionalPositional, |param| { + assert_string_eq!(context, param.str(), "bar"); + }); + }); + + assert_definition_at!(&context, "3:3-3:20", Method, |baz| { + assert_eq!(foo.id(), baz.lexical_nesting_id().unwrap()); + + let receiver = baz.receiver().unwrap(); + let name_ref = context.graph().names().get(&receiver).unwrap(); + assert_eq!(StringId::from("Foo"), *name_ref.str()); + }); + }); + } + + #[test] + fn index_structs_with_keyword_init() { + let context = index_source({ + " + Foo = Struct.new(:bar, keyword_init: true) do + end + " + }); + assert_no_diagnostics!(&context); + + assert_definition_at!(&context, "1:1-2:4", Class, |foo| { + assert_definition_at!(&context, "1:7-2:4", Method, |initialize| { + assert_name_eq!(&context, "initialize", initialize); + assert_eq!(foo.id(), initialize.lexical_nesting_id().unwrap()); + + assert_parameter!(&initialize.parameters()[0], OptionalKeyword, |param| { + assert_string_eq!(context, param.str(), "bar"); + }); + }); + }); + } + + #[test] + fn index_named_structs() { + let context = index_source({ + " + Foo = Struct.new('Foo', :bar) do + end + " + }); + assert_no_diagnostics!(&context); + + assert_definition_at!(&context, "1:1-2:4", Class, |foo| { + assert_name_id_to_string_eq!(&context, "Foo", foo); + assert_superclass_ref_eq!(&context, foo, "Struct"); + + assert_definition_at!(&context, "1:18-1:23", ConstantAlias, |struct_foo| { + assert_name_id_to_string_eq!(&context, "Struct::Foo", struct_foo); + assert_name_path_eq!(&context, "Foo", struct_foo.target_name_id()); + }); + }); + } + + #[test] + fn index_struct_with_no_arguments() { + let context = index_source({ + " + Foo = Struct.new do + end + " + }); + assert_no_diagnostics!(&context); + + assert_definition_at!(&context, "1:1-2:4", Class, |foo| { + assert_definition_at!(&context, "1:7-2:4", Method, |initialize| { + assert_name_eq!(&context, "initialize", initialize); + assert_eq!(foo.id(), initialize.lexical_nesting_id().unwrap()); + assert!(initialize.parameters().is_empty()); + }); + }); + } + + #[test] + fn index_struct_with_dynamic_arguments() { + let context = index_source({ + " + Foo = Struct.new(var, :\"foo_#{var}\", \"bar_#{var}\") do + end + " + }); + + assert_diagnostics_eq!( + &context, + vec![ + "dynamic-constant-field: Struct arguments that aren't string or symbol literals will be ignored (1:18-1:21)", + "dynamic-constant-field: Struct arguments that aren't string or symbol literals will be ignored (1:23-1:36)", + "dynamic-constant-field: Struct arguments that aren't string or symbol literals will be ignored (1:38-1:50)" + ] + ); + } }