diff --git a/Cargo.lock b/Cargo.lock index 63c6a7f..245a54b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1849,9 +1849,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "toml" -version = "0.9.11+spec-1.1.0" +version = "1.0.3+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f3afc9a848309fe1aaffaed6e1546a7a14de1f935dc9d89d32afd9a44bab7c46" +checksum = "c7614eaf19ad818347db24addfa201729cf2a9b6fdfd9eb0ab870fcacc606c0c" dependencies = [ "indexmap", "serde_core", @@ -1864,18 +1864,18 @@ dependencies = [ [[package]] name = "toml_datetime" -version = "0.7.5+spec-1.1.0" +version = "1.0.0+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "92e1cfed4a3038bc5a127e35a2d360f145e1f4b971b551a2ba5fd7aedf7e1347" +checksum = "32c2555c699578a4f59f0cc68e5116c8d7cabbd45e1409b989d4be085b53f13e" dependencies = [ "serde_core", ] [[package]] name = "toml_parser" -version = "1.0.6+spec-1.1.0" +version = "1.0.9+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a3198b4b0a8e11f09dd03e133c0280504d0801269e9afa46362ffde1cbeebf44" +checksum = "702d4415e08923e7e1ef96cd5727c0dfed80b4d2fa25db9647fe5eb6f7c5a4c4" dependencies = [ "winnow", ] diff --git a/ruff b/ruff index 48c6257..700e0aa 160000 --- a/ruff +++ b/ruff @@ -1 +1 @@ -Subproject commit 48c6257da158c5ec56eac66176ee6da4d759b690 +Subproject commit 700e0aaa6408366b7da3a64ae40d333423f9f866 diff --git a/src/collector.rs b/src/collector.rs index 899f3ed..d1682a3 100644 --- a/src/collector.rs +++ b/src/collector.rs @@ -5,10 +5,8 @@ use ruff_python_ast::{ self as ast, visitor::source_order, visitor::source_order::SourceOrderVisitor, }; use ruff_text_size::Ranged; -use ty_python_semantic::types::ParameterKind; -use ty_python_semantic::types::ide_support::{ - call_signature_details, find_active_signature_from_details, -}; +use ty_python_semantic::types::call::CallArguments; +use ty_python_semantic::types::{ParameterKind, Type, TypeContext}; use ty_python_semantic::{Db, HasType, SemanticModel}; use crate::protocol::{CallSignatureInfo, NodeAttribution, ParameterInfo, TypeDescriptor, TypeId}; @@ -92,41 +90,62 @@ impl<'db, 'reg> TypeCollector<'db, 'reg> { call_expr: &ast::ExprCall, return_type_id: Option, ) -> Option { - let signatures = call_signature_details(&self.model, call_expr); - if signatures.is_empty() { - return None; - } - - let active_idx = find_active_signature_from_details(&signatures).unwrap_or(0); - let sig = &signatures[active_idx]; - - let parameters: Vec = sig - .parameter_names + let db = self.db; + + // Get the callable type from the function expression + let func_type = call_expr.func.inferred_type(&self.model)?; + let callable_type = func_type.try_upcast_to_callable(db)?.into_type(db); + + // Build typed arguments so check_types can infer TypeVar specializations + let call_arguments = + CallArguments::from_arguments_typed(&call_expr.arguments, |splatted_value| { + splatted_value + .inferred_type(&self.model) + .unwrap_or(Type::unknown()) + }); + + // Bind parameters and run type checking to resolve specializations + let mut bindings = callable_type + .bindings(db) + .match_parameters(db, &call_arguments); + let _ = bindings.check_types_impl(db, &call_arguments, TypeContext::default(), &[]); + + // Pick the first matching overload (fallback to first overload) + let binding = bindings.iter_flat().flatten().next()?; + + let specialization = binding.specialization(); + + // Extract parameters from the binding's signature + let parameters: Vec = binding + .signature + .parameters() .iter() - .enumerate() - .map(|(i, name)| { - let type_id = sig.parameter_types.get(i).map(|&ty| self.register_type(ty)); - - let (kind, has_default) = if let Some(pk) = sig.parameter_kinds.get(i) { - match pk { - ParameterKind::PositionalOnly { default_type, .. } => { - ("positionalOnly", default_type.is_some()) - } - ParameterKind::PositionalOrKeyword { default_type, .. } => { - ("positionalOrKeyword", default_type.is_some()) - } - ParameterKind::Variadic { .. } => ("variadic", false), - ParameterKind::KeywordOnly { default_type, .. } => { - ("keywordOnly", default_type.is_some()) - } - ParameterKind::KeywordVariadic { .. } => ("keywordVariadic", false), + .map(|param| { + let mut ty = param.annotated_type(); + if let Some(spec) = specialization { + ty = ty.apply_specialization(db, spec); + } + let type_id = Some(self.register_type(ty)); + + let (kind, has_default) = match param.kind() { + ParameterKind::PositionalOnly { default_type, .. } => { + ("positionalOnly", default_type.is_some()) } - } else { - ("positionalOrKeyword", false) + ParameterKind::PositionalOrKeyword { default_type, .. } => { + ("positionalOrKeyword", default_type.is_some()) + } + ParameterKind::Variadic { .. } => ("variadic", false), + ParameterKind::KeywordOnly { default_type, .. } => { + ("keywordOnly", default_type.is_some()) + } + ParameterKind::KeywordVariadic { .. } => ("keywordVariadic", false), }; ParameterInfo { - name: name.clone(), + name: param + .display_name() + .map(|n| n.to_string()) + .unwrap_or_default(), type_id, kind, has_default, @@ -134,9 +153,20 @@ impl<'db, 'reg> TypeCollector<'db, 'reg> { }) .collect(); + // Extract type arguments from the inferred specialization + let type_arguments: Vec = specialization + .map(|spec| { + spec.types(db) + .iter() + .map(|&ty| self.register_type(ty)) + .collect() + }) + .unwrap_or_default(); + Some(CallSignatureInfo { parameters, return_type_id, + type_arguments, }) } diff --git a/src/protocol.rs b/src/protocol.rs index 66a32d3..7560012 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -119,6 +119,8 @@ pub struct NodeAttribution { pub struct CallSignatureInfo { pub parameters: Vec, pub return_type_id: Option, + #[serde(skip_serializing_if = "Vec::is_empty")] + pub type_arguments: Vec, } #[derive(Debug, Clone, Serialize)] diff --git a/tests/integration/main.rs b/tests/integration/main.rs index 014159d..dd6ea99 100644 --- a/tests/integration/main.rs +++ b/tests/integration/main.rs @@ -407,6 +407,52 @@ fn test_non_generic_function_no_type_parameters() { ); } +#[test] +fn test_generic_call_type_arguments() { + let dir = create_test_project(&[( + "g.py", + "def identity[T](x: T) -> T: return x\nresult = identity(42)\n", + )]); + + let responses = run_session(&[ + &initialize_request(dir.path().to_str().unwrap(), 1), + &get_types_request("g.py", 2), + &shutdown_request(99), + ]); + + let result = &responses[1]["result"]; + let nodes: Vec = serde_json::from_value(result["nodes"].clone()).unwrap(); + let types: TypeMap = serde_json::from_value(result["types"].clone()).unwrap(); + + // Find the ExprCall node for identity(42) + let call_node = nodes + .iter() + .find(|n| n.node_kind == "ExprCall") + .expect("should have an ExprCall node"); + + let call_sig = call_node + .call_signature + .as_ref() + .expect("ExprCall should have a call signature"); + + // Should have one type argument (T resolved to int) + assert_eq!( + call_sig.type_arguments.len(), + 1, + "identity(42) should have 1 type argument, got {:?}", + call_sig.type_arguments + ); + + // The type argument should be Literal[42] or int + let ta_id = call_sig.type_arguments[0].to_string(); + let ta_type = &types[&ta_id]; + assert!( + ta_type["kind"] == "intLiteral" || ta_type["kind"] == "instance", + "type argument should be int-like, got {:?}", + ta_type + ); +} + #[test] fn test_error_before_initialize() { let responses = run_session(&[&get_types_request("a.py", 1), &shutdown_request(99)]); diff --git a/tests/integration/protocol.rs b/tests/integration/protocol.rs index ff833f1..d6e74f5 100644 --- a/tests/integration/protocol.rs +++ b/tests/integration/protocol.rs @@ -10,4 +10,23 @@ pub struct NodeInfo { pub end: u32, pub node_kind: String, pub type_id: Option, + pub call_signature: Option, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct CallSignatureInfo { + pub parameters: Vec, + pub return_type_id: Option, + #[serde(default)] + pub type_arguments: Vec, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ParameterInfo { + pub name: String, + pub type_id: Option, + pub kind: String, + pub has_default: bool, }