From 6658e2d6fc6c3b1608e3f1a6084ef141f4c323ac Mon Sep 17 00:00:00 2001 From: Knut Wannheden Date: Mon, 23 Feb 2026 10:19:33 +0100 Subject: [PATCH 1/3] Add typeParameters to Function, ClassLiteral, and BoundMethod descriptors Surface generic type parameters (PEP 695) in the wire protocol. Extracts type parameters from Signature.generic_context and ClassLiteral.generic_context, filtering out implicit Self type variables. Also fixes TypeVar name to use the simple name (e.g. "T") instead of the qualified display form ("T@identity"). --- src/protocol.rs | 6 +++ src/registry.rs | 52 ++++++++++++++++------ tests/integration/main.rs | 93 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 138 insertions(+), 13 deletions(-) diff --git a/src/protocol.rs b/src/protocol.rs index 0eb6f33..329cc2b 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -175,6 +175,8 @@ pub enum TypeDescriptor { display: Option, class_name: String, #[serde(skip_serializing_if = "Vec::is_empty")] + type_parameters: Vec, + #[serde(skip_serializing_if = "Vec::is_empty")] supertypes: Vec, #[serde(skip_serializing_if = "Vec::is_empty")] members: Vec, @@ -210,6 +212,8 @@ pub enum TypeDescriptor { display: Option, name: String, #[serde(skip_serializing_if = "Vec::is_empty")] + type_parameters: Vec, + #[serde(skip_serializing_if = "Vec::is_empty")] parameters: Vec, #[serde(skip_serializing_if = "Option::is_none")] return_type: Option, @@ -228,6 +232,8 @@ pub enum TypeDescriptor { #[serde(skip_serializing_if = "Option::is_none")] name: Option, #[serde(skip_serializing_if = "Vec::is_empty")] + type_parameters: Vec, + #[serde(skip_serializing_if = "Vec::is_empty")] parameters: Vec, #[serde(skip_serializing_if = "Option::is_none")] return_type: Option, diff --git a/src/registry.rs b/src/registry.rs index 45430d4..09316db 100644 --- a/src/registry.rs +++ b/src/registry.rs @@ -2,7 +2,7 @@ use rustc_hash::FxHashMap; use ty_python_semantic::Db; use ty_python_semantic::types::list_members; use ty_python_semantic::types::{ - ClassLiteral, LiteralValueTypeKind, ParameterKind, Type, TypeGuardLike, + ClassLiteral, GenericContext, LiteralValueTypeKind, ParameterKind, Type, TypeGuardLike, }; use crate::protocol::{ClassMemberInfo, ParameterInfo, TypeDescriptor, TypeId, TypedDictFieldInfo}; @@ -96,23 +96,42 @@ impl<'db> TypeRegistry<'db> { Some(format!("{}", ty.display(db))) } + fn build_type_parameters( + &mut self, + generic_context: Option>, + db: &'db dyn Db, + ) -> Vec { + let Some(ctx) = generic_context else { + return vec![]; + }; + let vars: Vec<_> = ctx + .variables(db) + .filter(|bound_tv| !bound_tv.typevar(db).is_self(db)) + .collect(); + vars.into_iter() + .map(|bound_tv| self.register_component(Type::TypeVar(bound_tv), db)) + .collect() + } + fn build_function_params( &mut self, func_ty: Type<'db>, db: &'db dyn Db, - ) -> (Vec, Option) { + ) -> (Vec, Vec, Option) { let func = match func_ty.as_function_literal() { Some(f) => f, - None => return (vec![], None), + None => return (vec![], vec![], None), }; let callable_sig = func.signature(db); // TODO: only the first overload is used; overloaded functions lose // all but the first signature. Consider representing overloads. let sig = match callable_sig.iter().next() { Some(s) => s, - None => return (vec![], None), + None => return (vec![], vec![], None), }; + let type_parameters = self.build_type_parameters(sig.generic_context, db); + let parameters: Vec = sig .parameters() .into_iter() @@ -158,7 +177,7 @@ impl<'db> TypeRegistry<'db> { Some(self.register_component(return_ty, db)) }; - (parameters, return_type) + (type_parameters, parameters, return_type) } fn build_descriptor(&mut self, ty: Type<'db>, db: &'db dyn Db) -> TypeDescriptor { @@ -315,6 +334,8 @@ impl<'db> TypeRegistry<'db> { Type::ClassLiteral(class_literal) => { let display = self.display_string(ty, db); let class_name = class_literal.name(db).to_string(); + let type_parameters = + self.build_type_parameters(class_literal.generic_context(db), db); let supertypes: Vec = match class_literal { ClassLiteral::Static(static_class) => static_class .explicit_bases(db) @@ -348,6 +369,7 @@ impl<'db> TypeRegistry<'db> { TypeDescriptor::ClassLiteral { display, class_name, + type_parameters, supertypes, members, } @@ -375,6 +397,7 @@ impl<'db> TypeRegistry<'db> { TypeDescriptor::ClassLiteral { display, class_name, + type_parameters: vec![], supertypes, members, } @@ -397,10 +420,12 @@ impl<'db> TypeRegistry<'db> { Type::FunctionLiteral(func) => { let display = self.display_string(ty, db); let name = func.name(db).to_string(); - let (parameters, return_type) = self.build_function_params(ty, db); + let (type_parameters, parameters, return_type) = + self.build_function_params(ty, db); TypeDescriptor::Function { display, name, + type_parameters, parameters, return_type, } @@ -416,10 +441,12 @@ impl<'db> TypeRegistry<'db> { let func = bound.function(db); let func_ty = Type::FunctionLiteral(func); let name = Some(func.name(db).to_string()); - let (parameters, return_type) = self.build_function_params(func_ty, db); + let (type_parameters, parameters, return_type) = + self.build_function_params(func_ty, db); TypeDescriptor::BoundMethod { display, name, + type_parameters, parameters, return_type, } @@ -430,6 +457,7 @@ impl<'db> TypeRegistry<'db> { TypeDescriptor::BoundMethod { display, name: None, + type_parameters: vec![], parameters: vec![], return_type: None, } @@ -444,12 +472,10 @@ impl<'db> TypeRegistry<'db> { } } - Type::TypeVar(_) => { - let display_str = format!("{}", ty.display(db)); - TypeDescriptor::TypeVar { - display: Some(display_str.clone()), - name: display_str, - } + Type::TypeVar(bound_tv) => { + let display = self.display_string(ty, db); + let name = bound_tv.name(db).to_string(); + TypeDescriptor::TypeVar { display, name } } Type::TypeAlias(_) => { diff --git a/tests/integration/main.rs b/tests/integration/main.rs index 60cf46e..6982966 100644 --- a/tests/integration/main.rs +++ b/tests/integration/main.rs @@ -313,6 +313,99 @@ fn test_type_registry() { assert!(has_lit_42, "registry should have 'Literal[42]'"); } +#[test] +fn test_generic_function_type_parameters() { + let dir = create_test_project(&[("g.py", "def identity[T](x: T) -> T: return x\n")]); + + let responses = run_session(&[ + &initialize_request(dir.path().to_str().unwrap(), 1), + &get_types_request("g.py", 2), + &shutdown_request(99), + ]); + + let result = &responses[1]["result"]; + let types: TypeMap = serde_json::from_value(result["types"].clone()).unwrap(); + + // Find the function type for 'identity' + let func_type = types + .values() + .find(|t| t["kind"] == "function" && t["name"] == "identity") + .expect("should have a function type for 'identity'"); + + // Should have typeParameters with one entry + let type_params = func_type["typeParameters"] + .as_array() + .expect("typeParameters should be an array"); + assert_eq!(type_params.len(), 1, "identity[T] should have 1 type parameter"); + + // The type parameter should point to a TypeVar named T + let tv_id = type_params[0].to_string(); + let tv_type = &types[&tv_id]; + assert_eq!(tv_type["kind"], "typeVar"); + assert_eq!(tv_type["name"], "T"); +} + +#[test] +fn test_generic_class_type_parameters() { + let dir = create_test_project(&[( + "gc.py", + "class Box[T]:\n value: T\n", + )]); + + let responses = run_session(&[ + &initialize_request(dir.path().to_str().unwrap(), 1), + &get_types_request("gc.py", 2), + &shutdown_request(99), + ]); + + let result = &responses[1]["result"]; + let types: TypeMap = serde_json::from_value(result["types"].clone()).unwrap(); + + // Find the class literal for 'Box' + let class_type = types + .values() + .find(|t| t["kind"] == "classLiteral" && t["className"] == "Box") + .expect("should have a classLiteral for 'Box'"); + + // Should have typeParameters with one entry + let type_params = class_type["typeParameters"] + .as_array() + .expect("typeParameters should be an array"); + assert_eq!(type_params.len(), 1, "Box[T] should have 1 type parameter"); + + // The type parameter should point to a TypeVar named T + let tv_id = type_params[0].to_string(); + let tv_type = &types[&tv_id]; + assert_eq!(tv_type["kind"], "typeVar"); + assert_eq!(tv_type["name"], "T"); +} + +#[test] +fn test_non_generic_function_no_type_parameters() { + let dir = create_test_project(&[("ng.py", "def add(a: int, b: int) -> int: return a + b\n")]); + + let responses = run_session(&[ + &initialize_request(dir.path().to_str().unwrap(), 1), + &get_types_request("ng.py", 2), + &shutdown_request(99), + ]); + + let result = &responses[1]["result"]; + let types: TypeMap = serde_json::from_value(result["types"].clone()).unwrap(); + + // Find the function type for 'add' + let func_type = types + .values() + .find(|t| t["kind"] == "function" && t["name"] == "add") + .expect("should have a function type for 'add'"); + + // typeParameters should be absent (skip_serializing_if = "Vec::is_empty") + assert!( + func_type.get("typeParameters").is_none(), + "non-generic function should not have typeParameters key" + ); +} + #[test] fn test_error_before_initialize() { let responses = run_session(&[&get_types_request("a.py", 1), &shutdown_request(99)]); From d7721139a515859e5ed8bcaef22efd86a490988c Mon Sep 17 00:00:00 2001 From: Knut Wannheden Date: Mon, 23 Feb 2026 10:29:04 +0100 Subject: [PATCH 2/3] Always serialize parameters field on Function and BoundMethod descriptors A function declaration should always include a parameters array, even when empty, so consumers don't need to handle the absent-key case. --- src/protocol.rs | 2 -- src/registry.rs | 3 +-- tests/integration/main.rs | 11 ++++++----- 3 files changed, 7 insertions(+), 9 deletions(-) diff --git a/src/protocol.rs b/src/protocol.rs index 329cc2b..66a32d3 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -213,7 +213,6 @@ pub enum TypeDescriptor { name: String, #[serde(skip_serializing_if = "Vec::is_empty")] type_parameters: Vec, - #[serde(skip_serializing_if = "Vec::is_empty")] parameters: Vec, #[serde(skip_serializing_if = "Option::is_none")] return_type: Option, @@ -233,7 +232,6 @@ pub enum TypeDescriptor { name: Option, #[serde(skip_serializing_if = "Vec::is_empty")] type_parameters: Vec, - #[serde(skip_serializing_if = "Vec::is_empty")] parameters: Vec, #[serde(skip_serializing_if = "Option::is_none")] return_type: Option, diff --git a/src/registry.rs b/src/registry.rs index 09316db..5e82767 100644 --- a/src/registry.rs +++ b/src/registry.rs @@ -420,8 +420,7 @@ impl<'db> TypeRegistry<'db> { Type::FunctionLiteral(func) => { let display = self.display_string(ty, db); let name = func.name(db).to_string(); - let (type_parameters, parameters, return_type) = - self.build_function_params(ty, db); + let (type_parameters, parameters, return_type) = self.build_function_params(ty, db); TypeDescriptor::Function { display, name, diff --git a/tests/integration/main.rs b/tests/integration/main.rs index 6982966..014159d 100644 --- a/tests/integration/main.rs +++ b/tests/integration/main.rs @@ -336,7 +336,11 @@ fn test_generic_function_type_parameters() { let type_params = func_type["typeParameters"] .as_array() .expect("typeParameters should be an array"); - assert_eq!(type_params.len(), 1, "identity[T] should have 1 type parameter"); + assert_eq!( + type_params.len(), + 1, + "identity[T] should have 1 type parameter" + ); // The type parameter should point to a TypeVar named T let tv_id = type_params[0].to_string(); @@ -347,10 +351,7 @@ fn test_generic_function_type_parameters() { #[test] fn test_generic_class_type_parameters() { - let dir = create_test_project(&[( - "gc.py", - "class Box[T]:\n value: T\n", - )]); + let dir = create_test_project(&[("gc.py", "class Box[T]:\n value: T\n")]); let responses = run_session(&[ &initialize_request(dir.path().to_str().unwrap(), 1), From 5e8bc4f796a6fd258928fc47d4d2ec8e6fa1ef3a Mon Sep 17 00:00:00 2001 From: Knut Wannheden Date: Mon, 23 Feb 2026 11:18:03 +0100 Subject: [PATCH 3/3] Replace ide_support with direct call binding; add typeArguments Replace the ide_support call_signature_details API with direct use of ty's binding APIs (Type::bindings, match_parameters, check_types_impl). This gives access to Binding::specialization(), which we use to: - Resolve TypeVar parameters to their concrete types inline - Surface inferred type arguments on CallSignatureInfo --- Cargo.lock | 12 ++-- ruff | 2 +- src/collector.rs | 103 +++++++++++++++++++++++----------- src/protocol.rs | 2 + tests/integration/main.rs | 46 +++++++++++++++ tests/integration/protocol.rs | 19 +++++++ 6 files changed, 143 insertions(+), 41 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 63c6a7f..245a54b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1849,9 +1849,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "toml" -version = "0.9.11+spec-1.1.0" +version = "1.0.3+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f3afc9a848309fe1aaffaed6e1546a7a14de1f935dc9d89d32afd9a44bab7c46" +checksum = "c7614eaf19ad818347db24addfa201729cf2a9b6fdfd9eb0ab870fcacc606c0c" dependencies = [ "indexmap", "serde_core", @@ -1864,18 +1864,18 @@ dependencies = [ [[package]] name = "toml_datetime" -version = "0.7.5+spec-1.1.0" +version = "1.0.0+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "92e1cfed4a3038bc5a127e35a2d360f145e1f4b971b551a2ba5fd7aedf7e1347" +checksum = "32c2555c699578a4f59f0cc68e5116c8d7cabbd45e1409b989d4be085b53f13e" dependencies = [ "serde_core", ] [[package]] name = "toml_parser" -version = "1.0.6+spec-1.1.0" +version = "1.0.9+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a3198b4b0a8e11f09dd03e133c0280504d0801269e9afa46362ffde1cbeebf44" +checksum = "702d4415e08923e7e1ef96cd5727c0dfed80b4d2fa25db9647fe5eb6f7c5a4c4" dependencies = [ "winnow", ] diff --git a/ruff b/ruff index c4708d5..700e0aa 160000 --- a/ruff +++ b/ruff @@ -1 +1 @@ -Subproject commit c4708d581ed8f893a434986d762b49a5f8d4e02a +Subproject commit 700e0aaa6408366b7da3a64ae40d333423f9f866 diff --git a/src/collector.rs b/src/collector.rs index 899f3ed..76f9391 100644 --- a/src/collector.rs +++ b/src/collector.rs @@ -5,10 +5,8 @@ use ruff_python_ast::{ self as ast, visitor::source_order, visitor::source_order::SourceOrderVisitor, }; use ruff_text_size::Ranged; -use ty_python_semantic::types::ParameterKind; -use ty_python_semantic::types::ide_support::{ - call_signature_details, find_active_signature_from_details, -}; +use ty_python_semantic::types::call::CallArguments; +use ty_python_semantic::types::{ParameterKind, Type, TypeContext}; use ty_python_semantic::{Db, HasType, SemanticModel}; use crate::protocol::{CallSignatureInfo, NodeAttribution, ParameterInfo, TypeDescriptor, TypeId}; @@ -92,41 +90,67 @@ impl<'db, 'reg> TypeCollector<'db, 'reg> { call_expr: &ast::ExprCall, return_type_id: Option, ) -> Option { - let signatures = call_signature_details(&self.model, call_expr); - if signatures.is_empty() { - return None; - } - - let active_idx = find_active_signature_from_details(&signatures).unwrap_or(0); - let sig = &signatures[active_idx]; - - let parameters: Vec = sig - .parameter_names + let db = self.db; + + // Get the callable type from the function expression + let func_type = call_expr.func.inferred_type(&self.model)?; + let callable_type = func_type + .try_upcast_to_callable(db)? + .into_type(db); + + // Build typed arguments so check_types can infer TypeVar specializations + let call_arguments = + CallArguments::from_arguments_typed(&call_expr.arguments, |splatted_value| { + splatted_value + .inferred_type(&self.model) + .unwrap_or(Type::unknown()) + }); + + // Bind parameters and run type checking to resolve specializations + let mut bindings = callable_type + .bindings(db) + .match_parameters(db, &call_arguments); + let _ = bindings.check_types_impl(db, &call_arguments, TypeContext::default(), &[]); + + // Pick the first matching overload (fallback to first overload) + let binding = bindings + .iter_flat() + .flatten() + .next()?; + + let specialization = binding.specialization(); + + // Extract parameters from the binding's signature + let parameters: Vec = binding + .signature + .parameters() .iter() - .enumerate() - .map(|(i, name)| { - let type_id = sig.parameter_types.get(i).map(|&ty| self.register_type(ty)); - - let (kind, has_default) = if let Some(pk) = sig.parameter_kinds.get(i) { - match pk { - ParameterKind::PositionalOnly { default_type, .. } => { - ("positionalOnly", default_type.is_some()) - } - ParameterKind::PositionalOrKeyword { default_type, .. } => { - ("positionalOrKeyword", default_type.is_some()) - } - ParameterKind::Variadic { .. } => ("variadic", false), - ParameterKind::KeywordOnly { default_type, .. } => { - ("keywordOnly", default_type.is_some()) - } - ParameterKind::KeywordVariadic { .. } => ("keywordVariadic", false), + .map(|param| { + let mut ty = param.annotated_type(); + if let Some(spec) = specialization { + ty = ty.apply_specialization(db, spec); + } + let type_id = Some(self.register_type(ty)); + + let (kind, has_default) = match param.kind() { + ParameterKind::PositionalOnly { default_type, .. } => { + ("positionalOnly", default_type.is_some()) } - } else { - ("positionalOrKeyword", false) + ParameterKind::PositionalOrKeyword { default_type, .. } => { + ("positionalOrKeyword", default_type.is_some()) + } + ParameterKind::Variadic { .. } => ("variadic", false), + ParameterKind::KeywordOnly { default_type, .. } => { + ("keywordOnly", default_type.is_some()) + } + ParameterKind::KeywordVariadic { .. } => ("keywordVariadic", false), }; ParameterInfo { - name: name.clone(), + name: param + .display_name() + .map(|n| n.to_string()) + .unwrap_or_default(), type_id, kind, has_default, @@ -134,9 +158,20 @@ impl<'db, 'reg> TypeCollector<'db, 'reg> { }) .collect(); + // Extract type arguments from the inferred specialization + let type_arguments: Vec = specialization + .map(|spec| { + spec.types(db) + .iter() + .map(|&ty| self.register_type(ty)) + .collect() + }) + .unwrap_or_default(); + Some(CallSignatureInfo { parameters, return_type_id, + type_arguments, }) } diff --git a/src/protocol.rs b/src/protocol.rs index 66a32d3..7560012 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -119,6 +119,8 @@ pub struct NodeAttribution { pub struct CallSignatureInfo { pub parameters: Vec, pub return_type_id: Option, + #[serde(skip_serializing_if = "Vec::is_empty")] + pub type_arguments: Vec, } #[derive(Debug, Clone, Serialize)] diff --git a/tests/integration/main.rs b/tests/integration/main.rs index 014159d..dd6ea99 100644 --- a/tests/integration/main.rs +++ b/tests/integration/main.rs @@ -407,6 +407,52 @@ fn test_non_generic_function_no_type_parameters() { ); } +#[test] +fn test_generic_call_type_arguments() { + let dir = create_test_project(&[( + "g.py", + "def identity[T](x: T) -> T: return x\nresult = identity(42)\n", + )]); + + let responses = run_session(&[ + &initialize_request(dir.path().to_str().unwrap(), 1), + &get_types_request("g.py", 2), + &shutdown_request(99), + ]); + + let result = &responses[1]["result"]; + let nodes: Vec = serde_json::from_value(result["nodes"].clone()).unwrap(); + let types: TypeMap = serde_json::from_value(result["types"].clone()).unwrap(); + + // Find the ExprCall node for identity(42) + let call_node = nodes + .iter() + .find(|n| n.node_kind == "ExprCall") + .expect("should have an ExprCall node"); + + let call_sig = call_node + .call_signature + .as_ref() + .expect("ExprCall should have a call signature"); + + // Should have one type argument (T resolved to int) + assert_eq!( + call_sig.type_arguments.len(), + 1, + "identity(42) should have 1 type argument, got {:?}", + call_sig.type_arguments + ); + + // The type argument should be Literal[42] or int + let ta_id = call_sig.type_arguments[0].to_string(); + let ta_type = &types[&ta_id]; + assert!( + ta_type["kind"] == "intLiteral" || ta_type["kind"] == "instance", + "type argument should be int-like, got {:?}", + ta_type + ); +} + #[test] fn test_error_before_initialize() { let responses = run_session(&[&get_types_request("a.py", 1), &shutdown_request(99)]); diff --git a/tests/integration/protocol.rs b/tests/integration/protocol.rs index ff833f1..d6e74f5 100644 --- a/tests/integration/protocol.rs +++ b/tests/integration/protocol.rs @@ -10,4 +10,23 @@ pub struct NodeInfo { pub end: u32, pub node_kind: String, pub type_id: Option, + pub call_signature: Option, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct CallSignatureInfo { + pub parameters: Vec, + pub return_type_id: Option, + #[serde(default)] + pub type_arguments: Vec, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ParameterInfo { + pub name: String, + pub type_id: Option, + pub kind: String, + pub has_default: bool, }