Skip to content

Commit 16c66bb

Browse files
committed
fix list_files
1 parent e78e59a commit 16c66bb

File tree

4 files changed

+306
-97
lines changed

4 files changed

+306
-97
lines changed

src/services/code_browsing_service.py

Lines changed: 99 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -50,12 +50,13 @@ def list_methods(
5050
"file_pattern": file_pattern,
5151
"callee_pattern": callee_pattern,
5252
"include_external": include_external,
53+
"limit": limit,
5354
}
5455

5556
def execute_query():
5657
codebase_info = self.codebase_tracker.get_codebase(codebase_hash)
57-
if not codebase_info or not codebase_info.cpg_path:
58-
raise ValidationError(f"CPG not found for codebase {codebase_hash}")
58+
if not codebase_info:
59+
raise ValidationError(f"Codebase not found for codebase {codebase_hash}")
5960

6061
query_parts = ["cpg.method"]
6162
if not include_external:
@@ -71,7 +72,7 @@ def execute_query():
7172
".map(m => (m.name, m.id, m.fullName, m.signature, m.filename, m.lineNumber.getOrElse(-1), m.isExternal))"
7273
)
7374

74-
query_limit = max(limit, 10000)
75+
query_limit = min(limit, 10000)
7576
query = "".join(query_parts) + f".dedup.take({query_limit}).l"
7677

7778
result = self.query_executor.execute_query(
@@ -106,6 +107,9 @@ def execute_query():
106107
return full_result
107108

108109
methods = full_result.get("methods", [])
110+
# Respect the provided 'limit' for the returned list, independent of page_size
111+
if limit is not None and limit > 0:
112+
methods = methods[:limit]
109113
total = len(methods)
110114

111115
# Pagination
@@ -125,49 +129,109 @@ def execute_query():
125129
def list_files(
126130
self,
127131
codebase_hash: str,
132+
local_path: Optional[str] = None,
128133
limit: int = 1000,
129134
page: int = 1,
130135
page_size: int = 100,
131136
) -> Dict[str, Any]:
132137

133138
validate_codebase_hash(codebase_hash)
134-
cache_params = {} # No filters for now
139+
cache_params = {"local_path": local_path}
135140

136141
def execute_query():
137142
codebase_info = self.codebase_tracker.get_codebase(codebase_hash)
138-
if not codebase_info or not codebase_info.cpg_path:
139-
raise ValidationError(f"CPG not found for codebase {codebase_hash}")
140-
141-
query = f"cpg.file.map(f => (f.name, f.hash.getOrElse(\"\"))).take({limit}).l"
142-
143-
result = self.query_executor.execute_query(
144-
codebase_hash=codebase_hash,
145-
cpg_path=codebase_info.cpg_path,
146-
query=query,
147-
timeout=30,
148-
limit=limit,
143+
if not codebase_info:
144+
raise ValidationError(f"Codebase not found for codebase {codebase_hash}")
145+
# Determine the actual filesystem path to list
146+
playground_path = os.path.abspath(
147+
os.path.join(os.path.dirname(__file__), "..", "..", "playground")
149148
)
150149

151-
if not result.success:
152-
return {"success": False, "error": {"code": "QUERY_ERROR", "message": result.error}}
153-
154-
files = []
155-
for item in result.data:
156-
if isinstance(item, dict):
157-
files.append({
158-
"name": item.get("_1", ""),
159-
"hash": item.get("_2", ""),
160-
})
161-
return {"success": True, "files": files, "total": len(files)}
150+
if codebase_info.source_type == "github":
151+
from ..tools.core_tools import get_cpg_cache_key
152+
153+
cpg_cache_key = get_cpg_cache_key(
154+
codebase_info.source_type,
155+
codebase_info.source_path,
156+
codebase_info.language,
157+
)
158+
source_dir = os.path.join(playground_path, "codebases", cpg_cache_key)
159+
else:
160+
source_path = codebase_info.source_path
161+
if not os.path.isabs(source_path):
162+
source_path = os.path.abspath(source_path)
163+
source_dir = source_path
164+
165+
if not os.path.exists(source_dir) or not os.path.isdir(source_dir):
166+
raise ValidationError(f"Source directory not found for codebase {codebase_hash}: {source_dir}")
167+
168+
# Resolve target directory if a local_path is provided; otherwise, use source_dir
169+
if local_path:
170+
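# Support both absolute and relative local_path; intended to keep it within source_dir. NOTE(review): the startswith() check below is a plain string-prefix match, so a sibling directory whose name merely begins with source_dir's name would also pass; comparing os.path.commonpath([candidate, source_dir_norm]) against source_dir_norm is stricter.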
171+
candidate = local_path
172+
if not os.path.isabs(candidate):
173+
candidate = os.path.join(source_dir, candidate)
174+
candidate = os.path.normpath(candidate)
175+
source_dir_norm = os.path.normpath(source_dir)
176+
if not candidate.startswith(source_dir_norm):
177+
raise ValidationError("local_path must be inside the codebase source directory")
178+
target_dir = candidate
179+
else:
180+
target_dir = source_dir
181+
182+
# per-directory limits: default 20; 50 when a local_path was provided
183+
per_dir_limit = 50 if local_path else 20
184+
185+
def _list_dir_tree(root: str, base: str, per_dir_limit: int) -> List[Dict[str, Any]]:
186+
try:
187+
entries = sorted(os.listdir(root))
188+
except OSError:
189+
entries = []
190+
191+
result = []
192+
for name in entries[:per_dir_limit]:
193+
path = os.path.join(root, name)
194+
rel_path = os.path.relpath(path, base)
195+
if os.path.isdir(path):
196+
children = _list_dir_tree(path, base, per_dir_limit)
197+
result.append({
198+
"name": name,
199+
"path": rel_path,
200+
"type": "dir",
201+
"children": children,
202+
})
203+
else:
204+
result.append({
205+
"name": name,
206+
"path": rel_path,
207+
"type": "file",
208+
})
209+
return result
210+
211+
tree = _list_dir_tree(target_dir, source_dir, per_dir_limit)
212+
213+
# Count every entry in the returned tree, recursing into each directory's children
214+
def _count_nodes(nodes: List[Dict[str, Any]]) -> int:
215+
total = 0
216+
for n in nodes:
217+
total += 1
218+
if n.get("type") == "dir":
219+
total += _count_nodes(n.get("children", []))
220+
return total
221+
222+
total_count = _count_nodes(tree)
223+
return {"success": True, "files": tree, "total": total_count}
162224

163225
full_result = self._get_cached_or_execute("list_files", codebase_hash, cache_params, execute_query)
164-
226+
165227
if not full_result.get("success"):
166228
return full_result
167229

168230
files = full_result.get("files", [])
169-
total = len(files)
170-
231+
total = full_result.get("total", len(files))
232+
233+
# The nested tree cannot be paginated as a whole, so for backward compatibility
234+
# only the top-level entries are paged; note that 'total' counts every node in the tree, nested ones included.
171235
start_idx = (page - 1) * page_size
172236
end_idx = start_idx + page_size
173237
paged_files = files[start_idx:end_idx]
@@ -178,7 +242,7 @@ def execute_query():
178242
"total": total,
179243
"page": page,
180244
"page_size": page_size,
181-
"total_pages": (total + page_size - 1) // page_size if page_size > 0 else 1
245+
"total_pages": (total + page_size - 1) // page_size if page_size > 0 else 1,
182246
}
183247

184248
def list_calls(
@@ -195,6 +259,7 @@ def list_calls(
195259
cache_params = {
196260
"caller_pattern": caller_pattern,
197261
"callee_pattern": callee_pattern,
262+
"limit": limit,
198263
}
199264

200265
def execute_query():
@@ -212,7 +277,7 @@ def execute_query():
212277
".map(c => (c.method.name, c.name, c.code, c.method.filename, c.lineNumber.getOrElse(-1)))"
213278
)
214279

215-
query_limit = max(limit, 10000)
280+
query_limit = min(limit, 10000)
216281
query = "".join(query_parts) + f".dedup.take({query_limit}).l"
217282

218283
result = self.query_executor.execute_query(
@@ -244,6 +309,9 @@ def execute_query():
244309
return full_result
245310

246311
calls = full_result.get("calls", [])
312+
# Apply the provided limit to final result set
313+
if limit is not None and limit > 0:
314+
calls = calls[:limit]
247315
total = len(calls)
248316

249317
start_idx = (page - 1) * page_size

src/tools/code_browsing_tools.py

Lines changed: 27 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -87,6 +87,7 @@ def list_methods(
8787
@mcp.tool()
8888
def list_files(
8989
codebase_hash: str,
90+
local_path: Optional[str] = None,
9091
limit: int = 1000,
9192
page: int = 1,
9293
page_size: int = 100,
@@ -95,7 +96,8 @@ def list_files(
9596
List source files in the codebase.
9697
9798
Args:
98-
codebase_hash: The session ID from create_cpg_session
99+
codebase_hash: The session ID from create_cpg_session
100+
local_path: Optional path inside the codebase to list (relative to source root or absolute). When provided, per-directory limit is increased to 50.
99101
limit: Maximum number of results to fetch for caching (default: 1000)
100102
page: Page number (default: 1)
101103
page_size: Number of results per page (default: 100)
@@ -114,6 +116,7 @@ def list_files(
114116
code_browsing_service = services["code_browsing_service"]
115117
return code_browsing_service.list_files(
116118
codebase_hash=codebase_hash,
119+
local_path=local_path,
117120
limit=limit,
118121
page=page,
119122
page_size=page_size,
@@ -995,46 +998,40 @@ def run_cpgql_query(
995998
},
996999
}
9971000

998-
# Execute the query directly via Joern client to get raw stdout/stderr
999-
start_time = time.time()
1000-
1001-
port = query_executor.joern_server_manager.get_server_port(codebase_hash)
1002-
if not port:
1003-
return {
1004-
"success": False,
1005-
"error": {"code": "SERVER_ERROR", "message": f"No Joern server running for codebase {codebase_hash}"},
1006-
}
1007-
1008-
joern_client = JoernServerClient(host="localhost", port=port)
1009-
1010-
# Execute query with the provided query string as-is
1011-
result = joern_client.execute_query(query.strip(), timeout=timeout or 30)
1012-
1013-
execution_time = time.time() - start_time
1014-
1001+
# Use the QueryExecutor service to get structured output (data and row_count)
1002+
result = query_executor.execute_query(
1003+
codebase_hash=codebase_hash,
1004+
cpg_path=codebase_info.cpg_path,
1005+
query=query.strip(),
1006+
timeout=timeout or 30,
1007+
limit=None,
1008+
)
1009+
10151010
response = {
1016-
"success": result.get("success", False),
1017-
"stdout": result.get("stdout", ""),
1018-
"stderr": result.get("stderr", ""),
1019-
"execution_time": execution_time,
1011+
"success": result.success,
1012+
"data": result.data,
1013+
"row_count": result.row_count,
1014+
"execution_time": getattr(result, "execution_time", None),
10201015
}
1021-
1016+
1017+
# Include error information if present
1018+
if not result.success and getattr(result, "error", None):
1019+
response["error"] = result.error
1020+
10221021
# If validation was requested, include it in response
10231022
if validate and validation_result:
10241023
response["validation"] = validation_result
1025-
1026-
# If query failed, try to provide helpful suggestions
1027-
if not response["success"] and response["stderr"]:
1028-
stderr = response["stderr"]
1029-
error_suggestion = CPGQLValidator.get_error_suggestion(stderr)
1024+
1025+
# If query failed, try to provide helpful suggestions based on the error message (if available)
1026+
if not response["success"] and result.error:
1027+
error_suggestion = CPGQLValidator.get_error_suggestion(result.error)
10301028
if error_suggestion:
10311029
response["suggestion"] = error_suggestion
10321030
response["help"] = {
10331031
"description": error_suggestion.get("description"),
10341032
"solution": error_suggestion.get("solution"),
1035-
"examples": error_suggestion.get("examples", [])[:3], # First 3 examples
1033+
"examples": error_suggestion.get("examples", [])[:3],
10361034
}
1037-
10381035
return response
10391036

10401037
except ValidationError as e:

0 commit comments

Comments
 (0)