Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

98 changes: 64 additions & 34 deletions src/collector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,8 @@ use ruff_python_ast::{
self as ast, visitor::source_order, visitor::source_order::SourceOrderVisitor,
};
use ruff_text_size::Ranged;
use ty_python_semantic::types::ParameterKind;
use ty_python_semantic::types::ide_support::{
call_signature_details, find_active_signature_from_details,
};
use ty_python_semantic::types::call::CallArguments;
use ty_python_semantic::types::{ParameterKind, Type, TypeContext};
use ty_python_semantic::{Db, HasType, SemanticModel};

use crate::protocol::{CallSignatureInfo, NodeAttribution, ParameterInfo, TypeDescriptor, TypeId};
Expand Down Expand Up @@ -92,51 +90,83 @@ impl<'db, 'reg> TypeCollector<'db, 'reg> {
call_expr: &ast::ExprCall,
return_type_id: Option<TypeId>,
) -> Option<CallSignatureInfo> {
let signatures = call_signature_details(&self.model, call_expr);
if signatures.is_empty() {
return None;
}

let active_idx = find_active_signature_from_details(&signatures).unwrap_or(0);
let sig = &signatures[active_idx];

let parameters: Vec<ParameterInfo> = sig
.parameter_names
let db = self.db;

// Get the callable type from the function expression
let func_type = call_expr.func.inferred_type(&self.model)?;
let callable_type = func_type.try_upcast_to_callable(db)?.into_type(db);

// Build typed arguments so check_types can infer TypeVar specializations
let call_arguments =
CallArguments::from_arguments_typed(&call_expr.arguments, |splatted_value| {
splatted_value
.inferred_type(&self.model)
.unwrap_or(Type::unknown())
});

// Bind parameters and run type checking to resolve specializations
let mut bindings = callable_type
.bindings(db)
.match_parameters(db, &call_arguments);
let _ = bindings.check_types_impl(db, &call_arguments, TypeContext::default(), &[]);

// Pick the first matching overload (fallback to first overload)
let binding = bindings.iter_flat().flatten().next()?;

let specialization = binding.specialization();

// Extract parameters from the binding's signature
let parameters: Vec<ParameterInfo> = binding
.signature
.parameters()
.iter()
.enumerate()
.map(|(i, name)| {
let type_id = sig.parameter_types.get(i).map(|&ty| self.register_type(ty));

let (kind, has_default) = if let Some(pk) = sig.parameter_kinds.get(i) {
match pk {
ParameterKind::PositionalOnly { default_type, .. } => {
("positionalOnly", default_type.is_some())
}
ParameterKind::PositionalOrKeyword { default_type, .. } => {
("positionalOrKeyword", default_type.is_some())
}
ParameterKind::Variadic { .. } => ("variadic", false),
ParameterKind::KeywordOnly { default_type, .. } => {
("keywordOnly", default_type.is_some())
}
ParameterKind::KeywordVariadic { .. } => ("keywordVariadic", false),
.map(|param| {
let mut ty = param.annotated_type();
if let Some(spec) = specialization {
ty = ty.apply_specialization(db, spec);
}
let type_id = Some(self.register_type(ty));

let (kind, has_default) = match param.kind() {
ParameterKind::PositionalOnly { default_type, .. } => {
("positionalOnly", default_type.is_some())
}
} else {
("positionalOrKeyword", false)
ParameterKind::PositionalOrKeyword { default_type, .. } => {
("positionalOrKeyword", default_type.is_some())
}
ParameterKind::Variadic { .. } => ("variadic", false),
ParameterKind::KeywordOnly { default_type, .. } => {
("keywordOnly", default_type.is_some())
}
ParameterKind::KeywordVariadic { .. } => ("keywordVariadic", false),
};

ParameterInfo {
name: name.clone(),
name: param
.display_name()
.map(|n| n.to_string())
.unwrap_or_default(),
type_id,
kind,
has_default,
}
})
.collect();

// Extract type arguments from the inferred specialization
let type_arguments: Vec<TypeId> = specialization
.map(|spec| {
spec.types(db)
.iter()
.map(|&ty| self.register_type(ty))
.collect()
})
.unwrap_or_default();

Some(CallSignatureInfo {
parameters,
return_type_id,
type_arguments,
})
}

Expand Down
2 changes: 2 additions & 0 deletions src/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ pub struct NodeAttribution {
pub struct CallSignatureInfo {
pub parameters: Vec<ParameterInfo>,
pub return_type_id: Option<TypeId>,
#[serde(skip_serializing_if = "Vec::is_empty")]
pub type_arguments: Vec<TypeId>,
}

#[derive(Debug, Clone, Serialize)]
Expand Down
46 changes: 46 additions & 0 deletions tests/integration/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,52 @@ fn test_non_generic_function_no_type_parameters() {
);
}

#[test]
fn test_generic_call_type_arguments() {
let dir = create_test_project(&[(
"g.py",
"def identity[T](x: T) -> T: return x\nresult = identity(42)\n",
)]);

let responses = run_session(&[
&initialize_request(dir.path().to_str().unwrap(), 1),
&get_types_request("g.py", 2),
&shutdown_request(99),
]);

let result = &responses[1]["result"];
let nodes: Vec<NodeInfo> = serde_json::from_value(result["nodes"].clone()).unwrap();
let types: TypeMap = serde_json::from_value(result["types"].clone()).unwrap();

// Find the ExprCall node for identity(42)
let call_node = nodes
.iter()
.find(|n| n.node_kind == "ExprCall")
.expect("should have an ExprCall node");

let call_sig = call_node
.call_signature
.as_ref()
.expect("ExprCall should have a call signature");

// Should have one type argument (T resolved to int)
assert_eq!(
call_sig.type_arguments.len(),
1,
"identity(42) should have 1 type argument, got {:?}",
call_sig.type_arguments
);

// The type argument should be Literal[42] or int
let ta_id = call_sig.type_arguments[0].to_string();
let ta_type = &types[&ta_id];
assert!(
ta_type["kind"] == "intLiteral" || ta_type["kind"] == "instance",
"type argument should be int-like, got {:?}",
ta_type
);
}

#[test]
fn test_error_before_initialize() {
let responses = run_session(&[&get_types_request("a.py", 1), &shutdown_request(99)]);
Expand Down
19 changes: 19 additions & 0 deletions tests/integration/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,23 @@ pub struct NodeInfo {
pub end: u32,
pub node_kind: String,
pub type_id: Option<u32>,
pub call_signature: Option<CallSignatureInfo>,
}

#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CallSignatureInfo {
pub parameters: Vec<ParameterInfo>,
pub return_type_id: Option<u32>,
#[serde(default)]
pub type_arguments: Vec<u32>,
}

#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ParameterInfo {
pub name: String,
pub type_id: Option<u32>,
pub kind: String,
pub has_default: bool,
}