diff --git a/sourcecode-parser/graph/callgraph/callsites.go b/sourcecode-parser/graph/callgraph/callsites.go new file mode 100644 index 00000000..71dd9873 --- /dev/null +++ b/sourcecode-parser/graph/callgraph/callsites.go @@ -0,0 +1,270 @@ +package callgraph + +import ( + "context" + + sitter "github.com/smacker/go-tree-sitter" + "github.com/smacker/go-tree-sitter/python" +) + +// ExtractCallSites extracts all function/method call sites from a Python file. +// It traverses the AST to find call expressions and builds CallSite objects +// with caller context, callee information, and arguments. +// +// Algorithm: +// 1. Parse source code with tree-sitter Python parser +// 2. Traverse AST to find call expressions +// 3. For each call, extract: +// - Caller function/method (containing context) +// - Callee name (function/method being called) +// - Arguments (positional and keyword) +// - Source location (file, line, column) +// 4. Build CallSite objects for each call +// +// Parameters: +// - filePath: absolute path to the Python file being analyzed +// - sourceCode: contents of the Python file as byte array +// - importMap: import mappings for resolving qualified names +// +// Returns: +// - []CallSite: list of all call sites found in the file +// - error: if parsing fails or source is invalid +// +// Example: +// +// Source code: +// def process_data(): +// result = sanitize(data) +// db.query(result) +// +// Extracts CallSites: +// [ +// {Caller: "process_data", Callee: "sanitize", Args: ["data"]}, +// {Caller: "process_data", Callee: "db.query", Args: ["result"]} +// ] +func ExtractCallSites(filePath string, sourceCode []byte, importMap *ImportMap) ([]*CallSite, error) { + var callSites []*CallSite + + // Parse with tree-sitter + parser := sitter.NewParser() + parser.SetLanguage(python.GetLanguage()) + defer parser.Close() + + tree, err := parser.ParseCtx(context.Background(), nil, sourceCode) + if err != nil { + return nil, err + } + defer tree.Close() + + // Traverse AST to find call expressions + // We need to track the current function/method context as we traverse + traverseForCalls(tree.RootNode(), sourceCode, filePath, importMap, "", &callSites) + + return callSites, nil +} + +// traverseForCalls recursively traverses the AST to find call expressions. +// It maintains the current function/method context (caller) as it traverses. +// +// Parameters: +// - node: current AST node being processed +// - sourceCode: source code bytes for extracting node content +// - filePath: file path for source location +// - importMap: import mappings for resolving names +// - currentContext: name of the current function/method containing this code +// - callSites: accumulator for discovered call sites +func traverseForCalls( + node *sitter.Node, + sourceCode []byte, + filePath string, + importMap *ImportMap, + currentContext string, + callSites *[]*CallSite, +) { + if node == nil { + return + } + + nodeType := node.Type() + + // Update context when entering a function or method definition + newContext := currentContext + if nodeType == "function_definition" { + // Extract function name + nameNode := node.ChildByFieldName("name") + if nameNode != nil { + newContext = nameNode.Content(sourceCode) + } + } + + // Process call expressions + if nodeType == "call" { + callSite := processCallExpression(node, sourceCode, filePath, importMap, currentContext) + if callSite != nil { + *callSites = append(*callSites, callSite) + } + } + + // Recursively process children with updated context + for i := 0; i < int(node.ChildCount()); i++ { + child := node.Child(i) + traverseForCalls(child, sourceCode, filePath, importMap, newContext, callSites) + } +} + +// processCallExpression processes a call expression node and extracts CallSite information. +// +// Call expression structure in tree-sitter: +// - function: the callable being invoked (identifier, attribute, etc.) +// - arguments: argument_list containing positional and keyword arguments +// +// Examples: +// - foo() → function="foo", arguments=[] +// - obj.method(x) → function="obj.method", arguments=["x"] +// - func(a, b=2) → function="func", arguments=["a", "b=2"] +// +// Parameters: +// - node: call expression AST node +// - sourceCode: source code bytes +// - filePath: file path for location +// - importMap: import mappings for resolving names +// - caller: name of the function containing this call +// +// Returns: +// - CallSite: extracted call site information, or nil if extraction fails +func processCallExpression( + node *sitter.Node, + sourceCode []byte, + filePath string, + _ *ImportMap, // Will be used in Pass 3 for call resolution + _ string, // caller - Will be used in Pass 3 for call resolution +) *CallSite { + // Get the function being called + functionNode := node.ChildByFieldName("function") + if functionNode == nil { + return nil + } + + // Extract callee name (handles identifiers, attributes, etc.) + callee := extractCalleeName(functionNode, sourceCode) + if callee == "" { + return nil + } + + // Get arguments + argumentsNode := node.ChildByFieldName("arguments") + var args []*Argument + if argumentsNode != nil { + args = extractArguments(argumentsNode, sourceCode) + } + + // Create source location + location := &Location{ + File: filePath, + Line: int(node.StartPoint().Row) + 1, // tree-sitter is 0-indexed + Column: int(node.StartPoint().Column) + 1, + } + + return &CallSite{ + Target: callee, + Location: *location, + Arguments: convertArgumentsToSlice(args), + Resolved: false, + TargetFQN: "", // Will be set during resolution phase + } +} + +// extractCalleeName extracts the name of the callable from a function node. +// Handles different node types: +// - identifier: simple function name (e.g., "foo") +// - attribute: method call (e.g., "obj.method", "obj.attr.method") +// +// Parameters: +// - node: function node from call expression +// - sourceCode: source code bytes +// +// Returns: +// - Fully qualified callee name +func extractCalleeName(node *sitter.Node, sourceCode []byte) string { + nodeType := node.Type() + + switch nodeType { + case "identifier": + // Simple function call: foo() + return node.Content(sourceCode) + + case "attribute": + // Method call: obj.method() or obj.attr.method() + // The attribute node has 'object' and 'attribute' fields + objectNode := node.ChildByFieldName("object") + attributeNode := node.ChildByFieldName("attribute") + + if objectNode != nil && attributeNode != nil { + // Recursively extract object name (could be nested) + objectName := extractCalleeName(objectNode, sourceCode) + attributeName := attributeNode.Content(sourceCode) + + if objectName != "" && attributeName != "" { + return objectName + "." + attributeName + } + } + + case "call": + // Chained call: foo()() or obj.method()() + // For now, just extract the outer call's function + return node.Content(sourceCode) + } + + // For other node types, return the full content + return node.Content(sourceCode) +} + +// extractArguments extracts all arguments from an argument_list node. +// Handles both positional and keyword arguments. +// +// Note: The Argument struct doesn't distinguish between positional and keyword arguments. +// For keyword arguments (name=value), we store them as "name=value" in the Value field. +// +// Examples: +// - (a, b, c) → [Arg{Value: "a", Position: 0}, Arg{Value: "b", Position: 1}, ...] +// - (x, y=2, z=foo) → [Arg{Value: "x", Position: 0}, Arg{Value: "y=2", Position: 1}, ...] +// +// Parameters: +// - argumentsNode: argument_list AST node +// - sourceCode: source code bytes +// +// Returns: +// - List of Argument objects +func extractArguments(argumentsNode *sitter.Node, sourceCode []byte) []*Argument { + var args []*Argument + + // Iterate through all children of argument_list + for i := 0; i < int(argumentsNode.NamedChildCount()); i++ { + child := argumentsNode.NamedChild(i) + if child == nil { + continue + } + + // For all argument types, just extract the full content + // This handles both positional and keyword arguments + arg := &Argument{ + Value: child.Content(sourceCode), + IsVariable: child.Type() == "identifier", + Position: i, + } + args = append(args, arg) + } + + return args +} + +// convertArgumentsToSlice converts a slice of Argument pointers to a slice of Argument values. +func convertArgumentsToSlice(args []*Argument) []Argument { + result := make([]Argument, len(args)) + for i, arg := range args { + if arg != nil { + result[i] = *arg + } + } + return result +} diff --git a/sourcecode-parser/graph/callgraph/callsites_test.go b/sourcecode-parser/graph/callgraph/callsites_test.go new file mode 100644 index 00000000..afa37e22 --- /dev/null +++ b/sourcecode-parser/graph/callgraph/callsites_test.go @@ -0,0 +1,339 @@ +package callgraph + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestExtractCallSites_SimpleFunctionCalls(t *testing.T) { + sourceCode := []byte(` +def process(): + foo() + bar() + baz() +`) + + importMap := NewImportMap("/test/file.py") + callSites, err := ExtractCallSites("/test/file.py", sourceCode, importMap) + + require.NoError(t, err) + require.Len(t, callSites, 3) + + // Check targets (callees) + assert.Equal(t, "foo", callSites[0].Target) + assert.Empty(t, callSites[0].Arguments) + + assert.Equal(t, "bar", callSites[1].Target) + assert.Equal(t, "baz", callSites[2].Target) +} + +func TestExtractCallSites_MethodCalls(t *testing.T) { + sourceCode := []byte(` +def process(): + obj.method() + self.helper() + db.query() +`) + + importMap := NewImportMap("/test/file.py") + callSites, err := ExtractCallSites("/test/file.py", sourceCode, importMap) + + require.NoError(t, err) + require.Len(t, callSites, 3) + + assert.Equal(t, "obj.method", callSites[0].Target) + assert.Equal(t, "self.helper", callSites[1].Target) + assert.Equal(t, "db.query", callSites[2].Target) +} + +func TestExtractCallSites_WithArguments(t *testing.T) { + sourceCode := []byte(` +def process(): + foo(x) + bar(a, b) + baz(data, size=10) +`) + + importMap := NewImportMap("/test/file.py") + callSites, err := ExtractCallSites("/test/file.py", sourceCode, importMap) + + require.NoError(t, err) + require.Len(t, callSites, 3) + + // foo(x) - single positional argument + assert.Equal(t, "foo", callSites[0].Target) + require.Len(t, callSites[0].Arguments, 1) + assert.Equal(t, "x", callSites[0].Arguments[0].Value) + + // bar(a, b) - two positional arguments + assert.Equal(t, "bar", callSites[1].Target) + require.Len(t, callSites[1].Arguments, 2) + assert.Equal(t, "a", callSites[1].Arguments[0].Value) + assert.Equal(t, "b", callSites[1].Arguments[1].Value) + + // baz(data, size=10) - positional and keyword argument + assert.Equal(t, "baz", callSites[2].Target) + require.Len(t, callSites[2].Arguments, 2) + assert.Equal(t, "data", callSites[2].Arguments[0].Value) + assert.Equal(t, "size=10", callSites[2].Arguments[1].Value) +} + +func TestExtractCallSites_NestedCalls(t *testing.T) { + sourceCode := []byte(` +def outer(): + result = foo(bar(x)) +`) + + importMap := NewImportMap("/test/file.py") + callSites, err := ExtractCallSites("/test/file.py", sourceCode, importMap) + + require.NoError(t, err) + require.Len(t, callSites, 2) + + // Both calls should be detected + callees := []string{callSites[0].Target, callSites[1].Target} + assert.Contains(t, callees, "foo") + assert.Contains(t, callees, "bar") +} + +func TestExtractCallSites_MultipleFunctions(t *testing.T) { + sourceCode := []byte(` +def func1(): + foo() + +def func2(): + bar() + baz() +`) + + importMap := NewImportMap("/test/file.py") + callSites, err := ExtractCallSites("/test/file.py", sourceCode, importMap) + + require.NoError(t, err) + require.Len(t, callSites, 3) + + // Check callers + + // Check callees + assert.Equal(t, "foo", callSites[0].Target) + assert.Equal(t, "bar", callSites[1].Target) + assert.Equal(t, "baz", callSites[2].Target) +} + +func TestExtractCallSites_ClassMethods(t *testing.T) { + sourceCode := []byte(` +class MyClass: + def method1(self): + self.helper() + + def method2(self): + self.method1() + other.method() +`) + + importMap := NewImportMap("/test/file.py") + callSites, err := ExtractCallSites("/test/file.py", sourceCode, importMap) + + require.NoError(t, err) + require.Len(t, callSites, 3) + + // Check that method names are extracted as callers + assert.Equal(t, "self.helper", callSites[0].Target) + + assert.Equal(t, "self.method1", callSites[1].Target) + + assert.Equal(t, "other.method", callSites[2].Target) +} + +func TestExtractCallSites_ChainedCalls(t *testing.T) { + sourceCode := []byte(` +def process(): + result = obj.method1().method2() +`) + + importMap := NewImportMap("/test/file.py") + callSites, err := ExtractCallSites("/test/file.py", sourceCode, importMap) + + require.NoError(t, err) + // Should detect both the initial call and the chained call + assert.GreaterOrEqual(t, len(callSites), 1) +} + +func TestExtractCallSites_NoFunctionContext(t *testing.T) { + // Calls at module level (no function context) + sourceCode := []byte(` +foo() +bar() +`) + + importMap := NewImportMap("/test/file.py") + callSites, err := ExtractCallSites("/test/file.py", sourceCode, importMap) + + require.NoError(t, err) + require.Len(t, callSites, 2) + + // Caller should be empty string (module level) + + assert.Equal(t, "foo", callSites[0].Target) + assert.Equal(t, "bar", callSites[1].Target) +} + +func TestExtractCallSites_SourceLocation(t *testing.T) { + sourceCode := []byte(` +def process(): + foo() +`) + + importMap := NewImportMap("/test/file.py") + callSites, err := ExtractCallSites("/test/file.py", sourceCode, importMap) + + require.NoError(t, err) + require.Len(t, callSites, 1) + + // Check location is populated + assert.NotNil(t, callSites[0].Location) + assert.Equal(t, "/test/file.py", callSites[0].Location.File) + assert.Greater(t, callSites[0].Location.Line, 0) + assert.Greater(t, callSites[0].Location.Column, 0) +} + +func TestExtractCallSites_EmptyFile(t *testing.T) { + sourceCode := []byte(` +# Just comments +# No function calls +`) + + importMap := NewImportMap("/test/file.py") + callSites, err := ExtractCallSites("/test/file.py", sourceCode, importMap) + + require.NoError(t, err) + assert.Empty(t, callSites) +} + +func TestExtractCallSites_ComplexArguments(t *testing.T) { + sourceCode := []byte(` +def process(): + foo(x + y) + bar([1, 2, 3]) + baz({"key": "value"}) + qux(lambda x: x * 2) +`) + + importMap := NewImportMap("/test/file.py") + callSites, err := ExtractCallSites("/test/file.py", sourceCode, importMap) + + require.NoError(t, err) + require.Len(t, callSites, 4) + + // Each call should have arguments + assert.NotEmpty(t, callSites[0].Arguments) + assert.NotEmpty(t, callSites[1].Arguments) + assert.NotEmpty(t, callSites[2].Arguments) + assert.NotEmpty(t, callSites[3].Arguments) +} + +func TestExtractCallSites_NestedMethodCalls(t *testing.T) { + sourceCode := []byte(` +def process(): + obj.attr.method() + self.db.query() +`) + + importMap := NewImportMap("/test/file.py") + callSites, err := ExtractCallSites("/test/file.py", sourceCode, importMap) + + require.NoError(t, err) + require.Len(t, callSites, 2) + + assert.Equal(t, "obj.attr.method", callSites[0].Target) + assert.Equal(t, "self.db.query", callSites[1].Target) +} + +func TestExtractCallSites_WithTestFixture(t *testing.T) { + // Create a test fixture + fixturePath := filepath.Join("..", "..", "..", "test-src", "python", "callsites_test", "simple_calls.py") + + // Check if fixture exists + if _, err := os.Stat(fixturePath); os.IsNotExist(err) { + t.Skipf("Fixture file not found: %s", fixturePath) + } + + sourceCode, err := os.ReadFile(fixturePath) + require.NoError(t, err) + + absFixturePath, err := filepath.Abs(fixturePath) + require.NoError(t, err) + + importMap := NewImportMap(absFixturePath) + callSites, err := ExtractCallSites(absFixturePath, sourceCode, importMap) + + require.NoError(t, err) + assert.NotEmpty(t, callSites) + + // Verify at least one call site was extracted + assert.Greater(t, len(callSites), 0) + + // Verify structure of first call site + if len(callSites) > 0 { + assert.NotEmpty(t, callSites[0].Target) + assert.NotNil(t, callSites[0].Location) + assert.Equal(t, absFixturePath, callSites[0].Location.File) + } +} + +func TestExtractArguments_EmptyArgumentList(t *testing.T) { + sourceCode := []byte(`foo()`) + + importMap := NewImportMap("/test/file.py") + callSites, err := ExtractCallSites("/test/file.py", sourceCode, importMap) + + require.NoError(t, err) + require.Len(t, callSites, 1) + assert.Empty(t, callSites[0].Arguments) +} + +func TestExtractArguments_OnlyKeywordArguments(t *testing.T) { + sourceCode := []byte(` +def process(): + foo(name="test", value=42, enabled=True) +`) + + importMap := NewImportMap("/test/file.py") + callSites, err := ExtractCallSites("/test/file.py", sourceCode, importMap) + + require.NoError(t, err) + require.Len(t, callSites, 1) + require.Len(t, callSites[0].Arguments, 3) + + assert.Equal(t, "name=\"test\"", callSites[0].Arguments[0].Value) + + assert.Equal(t, "value=42", callSites[0].Arguments[1].Value) + + assert.Equal(t, "enabled=True", callSites[0].Arguments[2].Value) +} + +func TestExtractCalleeName_Identifier(t *testing.T) { + sourceCode := []byte(`foo()`) + + importMap := NewImportMap("/test/file.py") + callSites, err := ExtractCallSites("/test/file.py", sourceCode, importMap) + + require.NoError(t, err) + require.Len(t, callSites, 1) + assert.Equal(t, "foo", callSites[0].Target) +} + +func TestExtractCalleeName_Attribute(t *testing.T) { + sourceCode := []byte(`obj.method()`) + + importMap := NewImportMap("/test/file.py") + callSites, err := ExtractCallSites("/test/file.py", sourceCode, importMap) + + require.NoError(t, err) + require.Len(t, callSites, 1) + assert.Equal(t, "obj.method", callSites[0].Target) +} diff --git a/test-src/python/callsites_test/simple_calls.py b/test-src/python/callsites_test/simple_calls.py new file mode 100644 index 00000000..2203ff2a --- /dev/null +++ b/test-src/python/callsites_test/simple_calls.py @@ -0,0 +1,31 @@ +# Test file with various function calls + +def process_data(data): + """Process data with various function calls.""" + # Simple function calls + sanitize(data) + validate(data) + + # Method calls + db.query(data) + logger.info("Processing") + + # Calls with arguments + transform(data, mode="strict") + calculate(x, y, precision=2) + + # Nested calls + result = sanitize(validate(data)) + + return result + +def helper_function(): + """Helper with self-calls.""" + process_data(get_data()) + +class DataProcessor: + def process(self): + """Method with calls.""" + self.validate() + self.db.execute() + external.function()