diff --git a/config.yaml b/config.yaml index bc70cc3..adf3ada 100644 --- a/config.yaml +++ b/config.yaml @@ -1,7 +1,19 @@ python: + source_root: src/google/adk exclude: - examples typescript: + source_root: core/src/ + exclude: + - examples + +java: + source_root: core/src/main/java/com/google/adk + exclude: + - examples + +go: + source_root: . exclude: - examples \ No newline at end of file diff --git a/run.sh b/run.sh index 437791a..7c8cf5f 100755 --- a/run.sh +++ b/run.sh @@ -4,6 +4,11 @@ echo "Extracting Python features..." echo "Extracting TypeScript features..." ./extract.sh --language typescript --input-repo ../adk-js ./output +echo "Extracting Java features..." +./extract.sh --language java --input-repo ../adk-java ./output + +# Py -> TS + echo "Generating symmetric reports..." ./report.sh --base output/py.txtpb --target output/ts.txtpb --output ./output --report-type symmetric @@ -11,4 +16,13 @@ echo "Generating directional reports.. ." ./report.sh --base output/py.txtpb --target output/ts.txtpb --output ./output --report-type directional echo "Generating raw reports..." -./report.sh --base output/py.txtpb --target output/ts.txtpb --output ./output --report-type raw \ No newline at end of file +./report.sh --base output/py.txtpb --target output/ts.txtpb --output ./output --report-type raw + +# Py -> Java + +echo "Generating symmetric reports..." +./report.sh --base output/py.txtpb --target output/java.txtpb --output ./output --report-type symmetric + +echo "Generating directional reports (py->java)..." +./report.sh --base output/py.txtpb --target output/java.txtpb --output ./output --report-type directional + diff --git a/src/google/adk/scope/extractors/converter_java.py b/src/google/adk/scope/extractors/converter_java.py new file mode 100644 index 0000000..bc9ca87 --- /dev/null +++ b/src/google/adk/scope/extractors/converter_java.py @@ -0,0 +1,380 @@ +""" +Converter to transform Tree-sitter Java nodes into Feature objects. 
class NodeProcessor:
    """Process Tree-sitter nodes into Feature objects for Java.

    Only ``method_declaration`` and ``constructor_declaration`` nodes are
    converted. Test methods, ``Object`` boilerplate and bean-style accessors
    are filtered out because they carry no feature information.
    """

    _SUPPORTED_NODE_TYPES = ("method_declaration", "constructor_declaration")
    _BOILERPLATE_METHODS = (
        "equals",
        "hashCode",
        "toString",
        "canEqual",
        "clone",
    )
    _ACCESSOR_PREFIXES = ("get", "set", "is")
    _TYPE_DECLARATIONS = (
        "class_declaration",
        "interface_declaration",
        "enum_declaration",
    )
    # ``marker_annotation`` is a bare ``@Name``; ``annotation`` is the form
    # that carries arguments, e.g. ``@Deprecated(since = "1.0")``.
    _ANNOTATION_NODE_TYPES = ("marker_annotation", "annotation")
    _ASYNC_RETURN_PREFIXES = ("CompletableFuture", "Future", "Mono", "Flux")
    # Package prefix removed from namespaces so they line up across languages.
    _STRIPPED_PACKAGE = "com.google.adk"

    def __init__(self):
        self.normalizer = TypeNormalizer()

    def process(
        self, node: Node, file_path: Path, repo_root: Path
    ) -> Optional[feature_pb2.Feature]:
        """Convert a Tree-sitter node into a Feature.

        Args:
            node: The method_declaration or constructor_declaration node.
            file_path: Absolute path to the file.
            repo_root: Root of the repository.

        Returns:
            The populated Feature, or None when the node is of another type,
            has no name, or is filtered out (test method, boilerplate,
            getter/setter).
        """
        if node.type not in self._SUPPORTED_NODE_TYPES:
            return None

        # 1. Identity
        original_name = self._extract_name(node)
        if not original_name:
            return None

        # Skip testing methods if they happen to sneak in
        if original_name.startswith("test"):
            # A simplistic heuristic, could be improved
            logger.debug("Skipping test method: %s", original_name)
            return None

        # Exclude boilerplate methods
        if original_name in self._BOILERPLATE_METHODS:
            logger.debug("Skipping boilerplate method: %s", original_name)
            return None

        # Exclude getters and setters (constructors are never accessors)
        if node.type == "method_declaration" and self._is_accessor(
            original_name
        ):
            logger.debug("Skipping getter/setter: %s", original_name)
            return None

        normalized_name = normalize_name(original_name)

        # 2. Location
        member_of, normalized_member_of = self._extract_class(node)
        if not member_of:
            member_of = "null"

        namespace, normalized_namespace = self._extract_namespace(
            file_path, repo_root, node
        )

        # 3. Contract
        javadoc = self._extract_javadoc(node)
        parameters = self._extract_params(node)

        feature_type = feature_pb2.Feature.Type.INSTANCE_METHOD
        if node.type == "constructor_declaration":
            feature_type = feature_pb2.Feature.Type.CONSTRUCTOR
        elif self._is_static(node):
            feature_type = feature_pb2.Feature.Type.CLASS_METHOD

        original_returns, normalized_returns = self._extract_return_types(node)
        is_async = self._is_async(node, original_returns)
        maturity = self._extract_maturity(node)

        feature_kwargs = {
            "original_name": original_name,
            "normalized_name": normalized_name,
            "member_of": member_of,
            "normalized_member_of": normalized_member_of,
            "file_path": str(file_path.resolve()),
            "namespace": namespace,
            "normalized_namespace": normalized_namespace,
            "type": feature_type,
            "parameters": parameters,
            "original_return_types": original_returns,
            "normalized_return_types": normalized_returns,
        }

        # ``async`` is a Python keyword, so it can only be passed via **kwargs.
        if is_async:
            feature_kwargs["async"] = True

        feature = feature_pb2.Feature(**feature_kwargs)

        if javadoc:
            feature.description = javadoc

        if maturity is not None:
            feature.maturity = maturity

        return feature

    @classmethod
    def _is_accessor(cls, name: str) -> bool:
        """Return True for bean-style names such as getX, setX or isX."""
        for prefix in cls._ACCESSOR_PREFIXES:
            rest = name[len(prefix):]
            if name.startswith(prefix) and rest and rest[0].isupper():
                return True
        return False

    def _extract_name(self, node: Node) -> str:
        """Return the text of the node's ``name`` field, or "" if absent."""
        name_node = node.child_by_field_name("name")
        if name_node:
            return name_node.text.decode("utf-8")
        return ""

    def _extract_class(self, node: Node) -> Tuple[str, str]:
        """Return (original, normalized) name of the nearest enclosing type.

        Walks up the parents until a class, interface or enum declaration
        with a name is found; returns ("", "") when there is none.
        """
        parent = node.parent
        while parent:
            if parent.type in self._TYPE_DECLARATIONS:
                name_node = parent.child_by_field_name("name")
                if name_node:
                    original = name_node.text.decode("utf-8")
                    return original, normalize_name(original)
            parent = parent.parent
        return "", ""

    def _extract_namespace(
        self, file_path: Path, repo_root: Path, node: Node
    ) -> Tuple[str, str]:
        """Return (namespace, normalized namespace) for the node's file.

        The ``package`` declaration is preferred; the directory layout is
        only a fallback. The ``com.google.adk`` prefix is stripped, and the
        normalized form replaces dots with underscores.
        """
        namespace = self._package_from_tree(node)
        if not namespace:
            namespace = self._package_from_path(file_path, repo_root)

        prefix = self._STRIPPED_PACKAGE
        if namespace == prefix:
            namespace = ""
        elif namespace.startswith(prefix + "."):
            namespace = namespace[len(prefix) + 1:]

        return namespace, namespace.replace(".", "_")

    @staticmethod
    def _package_from_tree(node: Node) -> str:
        """Return the package named by the file's package_declaration."""
        root = node
        while root.parent:
            root = root.parent

        for child in root.children:
            if child.type != "package_declaration":
                continue
            # Find scoped_identifier or identifier
            for sub in child.children:
                if sub.type in ("scoped_identifier", "identifier"):
                    package = sub.text.decode("utf-8")
                    if package:
                        return package
                    break
        return ""

    @staticmethod
    def _package_from_path(file_path: Path, repo_root: Path) -> str:
        """Derive a dotted package name from the file's directory."""
        try:
            parts = list(file_path.relative_to(repo_root).parent.parts)
            # Try to strip common java roots like src/main/java
            if "src" in parts:
                idx = parts.index("src")
                if parts[idx + 1:idx + 3] == ["main", "java"]:
                    parts = parts[idx + 3:]
                elif len(parts) > idx + 1:
                    parts = parts[idx + 1:]
        except ValueError:
            # File is outside the repo: keep only the closest directories.
            parts = list(file_path.parent.parts)[-3:]

        # Drop relative markers and the filesystem anchor ("/" or a drive),
        # which would otherwise leak into the namespace for shallow paths.
        ignored = (".", "..", file_path.anchor)
        return ".".join(p for p in parts if p and p not in ignored)

    def _extract_params(self, node: Node) -> List[feature_pb2.Param]:
        """Build a Param for every formal_parameter of the declaration."""
        parameters_node = node.child_by_field_name("parameters")
        if not parameters_node:
            return []

        params = []
        for p_node in parameters_node.children:
            if p_node.type != "formal_parameter":
                continue

            p_name_node = p_node.child_by_field_name("name")
            p_type_node = p_node.child_by_field_name("type")

            name = p_name_node.text.decode("utf-8") if p_name_node else ""
            type_str = (
                p_type_node.text.decode("utf-8") if p_type_node else "Object"
            )

            normalized_types = self.normalizer.normalize(type_str, "java")

            params.append(
                feature_pb2.Param(
                    original_name=name,
                    normalized_name=normalize_name(name),
                    original_types=[type_str],
                    normalized_types=[
                        getattr(feature_pb2.ParamType, nt)
                        for nt in normalized_types
                    ],
                    is_optional=False,  # Java params aren't optional
                )
            )

        return params

    def _extract_return_types(self, node: Node) -> Tuple[List[str], List[str]]:
        """Return (raw, normalized) return type names.

        Both lists are empty for constructors and when the node has no
        ``type`` field. The normalized list holds the type names produced by
        the normalizer (strings), not enum values.
        """
        if node.type == "constructor_declaration":
            return [], []

        type_node = node.child_by_field_name("type")
        if not type_node:
            return [], []

        raw = type_node.text.decode("utf-8")
        return [raw], self.normalizer.normalize(raw, "java")

    @staticmethod
    def _modifier_children(node: Node) -> List[Node]:
        """Return the children of every ``modifiers`` node of a declaration.

        The ``modifiers`` node is looked up both as a named field and as a
        direct child, so the result is the same whichever way the grammar
        exposes it.
        """
        modifier_nodes = []
        field_node = node.child_by_field_name("modifiers")
        if field_node is not None:
            modifier_nodes.append(field_node)
        for child in node.children:
            if child.type == "modifiers" and child != field_node:
                modifier_nodes.append(child)

        children = []
        for modifiers in modifier_nodes:
            children.extend(modifiers.children)
        return children

    def _annotation_names(self, node: Node) -> List[str]:
        """Return the simple names of the annotations on a declaration.

        Both ``@Name`` and ``@Name(...)`` are recognised. A qualified name
        such as ``@com.example.Beta`` is reduced to ``Beta``.
        """
        names = []
        for child in self._modifier_children(node):
            if child.type not in self._ANNOTATION_NODE_TYPES:
                continue
            name_node = child.child_by_field_name("name")
            if name_node is not None:
                qualified = name_node.text.decode("utf-8")
                names.append(qualified.rsplit(".", 1)[-1])
        return names

    def _is_static(self, node: Node) -> bool:
        """Return True when the declaration carries the ``static`` modifier."""
        return any(
            child.text.decode("utf-8") == "static"
            for child in self._modifier_children(node)
        )

    def _is_async(self, node: Node, return_types: List[str]) -> bool:
        """Return True for asynchronous methods.

        A method counts as async when a return type starts with
        CompletableFuture, Future, Mono or Flux, or when it is annotated
        with ``@Async``.
        """
        if any(rt.startswith(self._ASYNC_RETURN_PREFIXES) for rt in return_types):
            return True
        return "Async" in self._annotation_names(node)

    def _extract_maturity(self, node: Node) -> Optional[int]:
        """Map maturity annotations to a Feature.Maturity value.

        ``@Experimental`` / ``@Beta`` give EXPERIMENTAL and ``@Deprecated``
        gives DEPRECATED; the first matching annotation wins. Returns None
        when no such annotation is present.
        """
        for anno in self._annotation_names(node):
            if anno in ("Experimental", "Beta"):
                return feature_pb2.Feature.Maturity.EXPERIMENTAL
            if anno == "Deprecated":
                return feature_pb2.Feature.Maturity.DEPRECATED
        return None

    @staticmethod
    def _clean_javadoc(text: str) -> str:
        """Strip the ``/**``, ``*/`` and leading ``*`` decoration."""
        clean_lines = []
        for line in text.split("\n"):
            line = line.strip()
            if line.startswith("/**"):
                line = line[3:]
            if line.endswith("*/"):
                line = line[:-2]
            if line.startswith("*"):
                line = line[1:]
            clean_lines.append(line.strip())
        return "\n".join(clean_lines).strip()

    def _extract_javadoc(self, node: Node) -> str:
        """Return the cleaned Javadoc attached to a declaration, or "".

        Preceding siblings are scanned first. Plain block comments,
        modifiers and marker annotations are stepped over; any other
        sibling ends the search. Comments nested inside the modifiers are
        checked last.
        """
        prev = node.prev_sibling
        while prev:
            if prev.type == "block_comment":
                text = prev.text.decode("utf-8")
                if text.startswith("/**"):
                    return self._clean_javadoc(text)
            elif prev.type not in ("modifiers", "marker_annotation"):
                break
            prev = prev.prev_sibling

        # Also check if it's placed inside modifiers by some AST quirks
        for child in self._modifier_children(node):
            if child.type == "block_comment":
                text = child.text.decode("utf-8")
                if text.startswith("/**"):
                    return self._clean_javadoc(text)

        return ""
get_config(repo_root) + source_root = config.get(args.language, {}).get("source_root", ".") - features = extractor_module.extract_features(input_path, repo_root) + features = extractor_module.extract_features( + input_path, repo_root, source_root + ) all_features.extend(features) try: @@ -126,8 +136,13 @@ def main(): files = list(extractor_module.find_files(input_path, recursive=False)) logger.info("Found %d %s files.", len(files), args.language) + config = get_config(repo_root) + source_root = config.get(args.language, {}).get("source_root", ".") + for p in files: - features = extractor_module.extract_features(p, repo_root) + features = extractor_module.extract_features( + p, repo_root, source_root + ) all_features.extend(features) # Log only if features found? Or keep unified summary at end. if features: @@ -170,8 +185,11 @@ def main(): "Found %d %s files in %s.", len(files), args.language, search_dir ) + source_root = config.get(args.language, {}).get("source_root", ".") for p in files: - features = extractor_module.extract_features(p, repo_root) + features = extractor_module.extract_features( + p, repo_root, source_root + ) all_features.extend(features) else: @@ -198,7 +216,11 @@ def main(): logger.error("Failed to create output directory %s: %s", output_dir, e) sys.exit(1) - prefix = "py" if args.language in {"python", "py"} else "ts" + prefix = ( + "py" + if args.language in {"python", "py"} + else "ts" if args.language in {"typescript", "ts"} else "java" + ) base_filename = f"{prefix}" if _JSON_OUTPUT: diff --git a/src/google/adk/scope/extractors/extractor_java.py b/src/google/adk/scope/extractors/extractor_java.py new file mode 100644 index 0000000..15bc0d8 --- /dev/null +++ b/src/google/adk/scope/extractors/extractor_java.py @@ -0,0 +1,203 @@ +import logging +import pathlib +from typing import Iterator, List + +import tree_sitter_java as tsjava +from tree_sitter import Language, Parser, Query, QueryCursor + +from google.adk.scope.extractors.converter_java 
def find_files(
    root: pathlib.Path, recursive: bool = True
) -> Iterator[pathlib.Path]:
    """Find Java files in the given directory.

    Exclusion rules are applied to the path relative to ``root``, so where
    the repository is checked out (for example under ``~/build/`` or a
    dot-directory) does not affect the result.

    Args:
        root (Path): The root directory to search.
        recursive (bool): Whether to search recursively.

    Yields:
        Iterator[Path]: An iterator of Paths to Java files.
    """
    if not root.exists():
        logger.warning("Directory %s does not exist. Skipping traversal.", root)
        return

    # Files to exclude from extraction
    excluded_files = {
        "module-info.java",
        "package-info.java",
    }
    # exclude node_modules, build, etc.
    excluded_dirs = {"node_modules", "build", "target", "out", "bin"}

    iterator = root.rglob("*.java") if recursive else root.glob("*.java")

    for path in iterator:
        if path.name in excluded_files:
            logger.debug("Skipping excluded file: %s", path)
            continue

        try:
            parts = path.relative_to(root).parts
        except ValueError:
            # Not located under root: judge the path as given.
            parts = path.parts
        parts = [part for part in parts if part not in (".", "..")]

        if any(
            part in excluded_dirs or part.startswith(".") for part in parts
        ):
            logger.debug("Skipping hidden/build path: %s", path)
            continue

        # Also exclude commonly known test directories
        if "test" in parts:
            logger.debug("Skipping test file: %s", path)
            continue

        yield path
def _version_from_java_source(repo_root: pathlib.Path) -> str:
    """Return JAVA_ADK_VERSION from core's Version.java, or "" if not found."""
    import re

    version_file = (
        repo_root
        / "core"
        / "src"
        / "main"
        / "java"
        / "com"
        / "google"
        / "adk"
        / "Version.java"
    )
    if not version_file.exists():
        return ""
    try:
        content = version_file.read_text(encoding="utf-8")
    except (OSError, ValueError) as e:
        logger.warning("Failed to read or parse Version.java: %s", e)
        return ""
    match = re.search(r'JAVA_ADK_VERSION\s*=\s*"([^"]+)"', content)
    return match.group(1) if match else ""


def _version_from_pom(repo_root: pathlib.Path) -> str:
    """Return the project (or parent) version from pom.xml, or ""."""
    import xml.etree.ElementTree as ET

    pom_xml = repo_root / "pom.xml"
    if not pom_xml.exists():
        return ""
    try:
        root = ET.parse(pom_xml).getroot()
    except (OSError, ValueError, ET.ParseError) as e:
        logger.warning("Failed to parse pom.xml for version: %s", e)
        return ""

    # Handle XML namespace usually present in Maven POMs, then retry without
    # it. In both cases the project's own version wins over the parent's.
    ns = {"mvn": "http://maven.apache.org/POM/4.0.0"}
    candidates = (
        root.find("mvn:version", ns),
        root.find("mvn:parent/mvn:version", ns),
        root.find("version"),
        root.find("parent/version"),
    )
    for version_node in candidates:
        if version_node is not None and version_node.text:
            version = version_node.text.strip()
            if version:
                return version
    return ""


def _version_from_gradle(repo_root: pathlib.Path) -> str:
    """Return the version assigned in build.gradle(.kts), or ""."""
    import re

    pattern = re.compile(r"""version\s*=?\s*['"]([^'"]+)['"]""")
    for gradle_file in ("build.gradle", "build.gradle.kts"):
        path = repo_root / gradle_file
        if not path.exists():
            continue
        try:
            content = path.read_text(encoding="utf-8")
        except (OSError, ValueError) as e:
            # Best effort: report the problem and try the next candidate.
            logger.warning("Failed to read %s: %s", path, e)
            continue
        for line in content.splitlines():
            if line.strip().startswith("version"):
                match = pattern.search(line)
                if match:
                    return match.group(1)
    return ""


def get_version(repo_root: pathlib.Path) -> str:
    """Determine the version of the Java repository.

    Sources are tried in order and the first hit wins:
      1. ``JAVA_ADK_VERSION`` in core's ``Version.java``
      2. ``pom.xml`` (project version, then parent version)
      3. ``build.gradle`` / ``build.gradle.kts``

    Args:
        repo_root (Path): Path to the repository root.

    Returns:
        str: The detected version, or "0.0.0" when none is found. Unreadable
        or malformed files are logged and skipped rather than raised.
    """
    finders = (
        _version_from_java_source,
        _version_from_pom,
        _version_from_gradle,
    )
    for finder in finders:
        version = finder(repo_root)
        if version:
            return version
    return "0.0.0"
pathlib.Path, source_root: str ) -> List[Feature]: """Extract Feature objects from a Python file. @@ -106,6 +105,9 @@ def extract_features( # The node is a function_definition feature = processor.process(node, file_path, repo_root) if feature: + feature.normalized_namespace = normalize_namespace( + str(file_path), str(repo_root / source_root) + ) features.append(feature) logger.debug("Extracted feature: %s", feature.original_name) else: diff --git a/src/google/adk/scope/extractors/extractor_ts.py b/src/google/adk/scope/extractors/extractor_ts.py index eca4251..225a498 100644 --- a/src/google/adk/scope/extractors/extractor_ts.py +++ b/src/google/adk/scope/extractors/extractor_ts.py @@ -7,6 +7,7 @@ from google.adk.scope.extractors.converter_ts import NodeProcessor from google.adk.scope.features_pb2 import Feature +from google.adk.scope.utils.normalizer import normalize_namespace # Initialize Tree-sitter try: @@ -51,7 +52,7 @@ def find_files( def extract_features( - file_path: pathlib.Path, repo_root: pathlib.Path + file_path: pathlib.Path, repo_root: pathlib.Path, source_root: str ) -> List[Feature]: try: content = file_path.read_bytes() @@ -77,8 +78,6 @@ def extract_features( cursor = QueryCursor(query) captures = cursor.captures(root_node) - processed_ids = set() - all_nodes = [] if "func" in captures: all_nodes.extend(captures["func"]) @@ -88,12 +87,12 @@ def extract_features( logger.debug("Found %d potential nodes in %s", len(all_nodes), file_path) for node in all_nodes: - if node.id in processed_ids: - continue - processed_ids.add(node.id) feature = processor.process(node, file_path, repo_root) if feature: + feature.normalized_namespace = normalize_namespace( + str(file_path), str(repo_root / source_root) + ) features.append(feature) logger.debug("Extracted feature: %s", feature.original_name) diff --git a/src/google/adk/scope/utils/args.py b/src/google/adk/scope/utils/args.py index 745c99e..3155865 100644 --- a/src/google/adk/scope/utils/args.py +++ 
def normalize_namespace(file_path: str, source_root: str) -> str:
    """Derives a normalized namespace from a file's path.

    The namespace is ``adk`` followed by the file's directory relative to
    ``source_root``, with path separators replaced by underscores.

    Args:
        file_path: Path of the source file (absolute or relative to cwd).
        source_root: Directory that maps to the top-level ``adk`` namespace.

    Returns:
        ``"adk_<dir>_<subdir>"`` for files in a sub-directory of the source
        root, and plain ``"adk"`` for files directly in the source root or
        outside of it.
    """
    abs_file_path = os.path.normcase(os.path.abspath(file_path))
    abs_source_root = os.path.normcase(os.path.abspath(source_root))

    # Compare whole path components: a plain string-prefix test would treat
    # "/repo/src_extra/x.py" as living inside "/repo/src".
    try:
        common = os.path.commonpath([abs_file_path, abs_source_root])
    except ValueError:
        # Raised e.g. for paths on different drives: not under the root.
        return "adk"
    if common != abs_source_root:
        return "adk"

    relative_path = os.path.relpath(abs_file_path, abs_source_root)
    relative_dir = os.path.dirname(relative_path)

    if not relative_dir:
        return "adk"

    return f'adk_{relative_dir.replace(os.path.sep, "_")}'
def _normalize_java_type(self, t: str) -> List[str]:
    """Map a Java type expression to normalized type names.

    Handles ``void``, arrays, generic and raw container types, async and
    optional wrappers (which are unwrapped to their type argument) and the
    primitive/boxed scalars. Package qualifiers are ignored, so
    ``java.util.List<String>`` is treated like ``List<String>``.

    Args:
        t: The type as written in the source, e.g. ``Map<String, Integer>``.

    Returns:
        A list of normalized type names; ``["OBJECT"]`` for anything
        unrecognised. Optional types additionally contain ``"NULL"``.
    """
    # Handle fundamental Java types
    t = t.strip()
    if not t:
        return ["OBJECT"]

    if t in ("void", "Void"):
        return ["NULL"]

    # Handle formatting like byte[] as array
    if t.endswith("[]"):
        return ["LIST"]

    # Generics: CompletableFuture<T>, List<T>, Map<K, V>
    match = re.match(r"([a-zA-Z0-9_.]+)\s*<(.+)>$", t)
    if match:
        base, inner = match.groups()
    else:
        # Raw (non-parameterised) type: there is no type argument.
        base, inner = t, None
    # Drop any package qualifier, keeping the simple class name.
    base = base.rsplit(".", 1)[-1]

    # Async types resolve to whatever they eventually produce
    if base in ("CompletableFuture", "Future", "Mono", "Flux", "Promise"):
        return self._normalize_java_type(inner) if inner else ["OBJECT"]

    if base in ("List", "ArrayList", "LinkedList", "Collection", "Iterable"):
        return ["LIST"]
    if base in ("Map", "HashMap", "TreeMap", "ConcurrentHashMap"):
        return ["MAP"]
    if base in ("Set", "HashSet", "TreeSet"):
        return ["SET"]
    if base in ("Optional", "Maybe"):
        result = self._normalize_java_type(inner) if inner else ["OBJECT"]
        if "NULL" not in result:
            result.append("NULL")
        return result

    if inner is not None:
        # Fallback for other generics
        return ["OBJECT"]

    t_lower = base.lower()

    if t_lower in ("string", "char", "character", "charsequence"):
        return ["STRING"]
    if t_lower in (
        "int",
        "integer",
        "long",
        "short",
        "byte",
        "float",
        "double",
        "number",
        "bigdecimal",
        "biginteger",
    ):
        return ["NUMBER"]
    if t_lower in ("boolean", "bool"):
        return ["BOOLEAN"]

    return ["OBJECT"]
result") + + self.assertEqual(len(feature.parameters), 1) + self.assertEqual(feature.parameters[0].original_name, "count") + self.assertEqual(feature.parameters[0].normalized_name, "count") + self.assertEqual( + feature.parameters[0].normalized_types, [ParamType.NUMBER] + ) + + self.assertEqual(list(feature.original_return_types), ["String"]) + self.assertEqual(list(feature.normalized_return_types), ["STRING"]) + self.assertFalse(getattr(feature, "async")) + + def test_extract_constructor(self): + code = b""" + package com.example; + public class Test { + public Test(String name) {} + } + """ + tree = self.parser.parse(code) + + # find the constructor_declaration + root = tree.root_node + constructor_node = None + for child in root.children: + if child.type == "class_declaration": + body = child.child_by_field_name("body") + for member in body.children: + if member.type == "constructor_declaration": + constructor_node = member + break + + feature = self.processor.process( + constructor_node, self.file_path, self.repo_root + ) + + self.assertIsNotNone(feature) + self.assertEqual(feature.original_name, "Test") + self.assertEqual(feature.normalized_name, "test") + self.assertEqual(feature.type, Feature.Type.CONSTRUCTOR) + + def test_extract_static_async(self): + code = b""" + package com.example; + import java.util.concurrent.CompletableFuture; + public class Test { + @Beta + public static CompletableFuture runAsync() { return null; } + } + """ + tree = self.parser.parse(code) + + root = tree.root_node + method_node = None + for child in root.children: + if child.type == "class_declaration": + body = child.child_by_field_name("body") + for member in body.children: + if member.type == "method_declaration": + method_node = member + break + + feature = self.processor.process( + method_node, self.file_path, self.repo_root + ) + + self.assertIsNotNone(feature) + self.assertEqual(feature.original_name, "runAsync") + self.assertEqual(feature.type, Feature.Type.CLASS_METHOD) 
+ self.assertEqual(list(feature.normalized_return_types), ["NUMBER"]) + self.assertTrue(getattr(feature, "async")) + self.assertEqual(feature.maturity, Feature.Maturity.EXPERIMENTAL) + + def test_extract_namespace_rewriting(self): + code = b""" + package com.google.adk.agents; + public class Test { + public void agentMethod() {} + } + """ + tree = self.parser.parse(code) + + root = tree.root_node + method_node = None + for child in root.children: + if child.type == "class_declaration": + body = child.child_by_field_name("body") + for member in body.children: + if member.type == "method_declaration": + method_node = member + break + + feature = self.processor.process( + method_node, self.file_path, self.repo_root + ) + + self.assertIsNotNone(feature) + self.assertEqual(feature.namespace, "agents") + self.assertEqual(feature.normalized_namespace, "agents") + + def test_extract_boilerplate_filter(self): + code = b""" + package com.google.adk; + public class Test { + public String getName() { return ""; } + public void setName(String name) {} + public boolean isValid() { return true; } + public boolean equals(Object o) { return false; } + public int hashCode() { return 0; } + public String toString() { return ""; } + public void normalMethod() {} + } + """ + tree = self.parser.parse(code) + + root = tree.root_node + methods = [] + for child in root.children: + if child.type == "class_declaration": + body = child.child_by_field_name("body") + for member in body.children: + if member.type == "method_declaration": + methods.append(member) + + features = [ + self.processor.process(m, self.file_path, self.repo_root) + for m in methods + ] + + # Only 'normalMethod' should not be filtered out + valid_features = [f for f in features if f is not None] + self.assertEqual(len(valid_features), 1) + self.assertEqual(valid_features[0].original_name, "normalMethod") + # Ensure 'com.google.adk' was completely stripped to empty string + self.assertEqual(valid_features[0].namespace, "") 
+ + +if __name__ == "__main__": + unittest.main() diff --git a/test/adk/scope/extractors/test_extractor_java.py b/test/adk/scope/extractors/test_extractor_java.py new file mode 100644 index 0000000..15eec72 --- /dev/null +++ b/test/adk/scope/extractors/test_extractor_java.py @@ -0,0 +1,85 @@ +import unittest +from pathlib import Path +from unittest.mock import MagicMock, patch + +from google.adk.scope.extractors.extractor_java import ( + extract_features, + find_files, +) +from google.adk.scope.features_pb2 import Feature + + +class TestExtractor(unittest.TestCase): + def test_find_files(self): + # Mock Path.rglob + with ( + patch("pathlib.Path.exists", return_value=True), + patch("pathlib.Path.rglob") as mock_rglob, + ): + p1 = Path("src/main/java/A.java") + p2 = Path("src/test/java/TestA.java") # Should be excluded + p3 = Path("build/classes/B.java") # Should be excluded + p4 = Path("package-info.java") # Should be excluded + p5 = Path("src/main/java/subdir/C.java") + p6 = Path("node_modules/lib.java") # Should be excluded + + mock_rglob.return_value = [p1, p2, p3, p4, p5, p6] + + results = list(find_files(Path("src"))) + + self.assertIn(p1, results) + self.assertNotIn(p2, results) # test excluded + self.assertNotIn(p3, results) # build excluded + self.assertNotIn(p4, results) # package-info excluded + self.assertIn(p5, results) + self.assertNotIn(p6, results) # node_modules excluded + + @patch("google.adk.scope.extractors.extractor_java.QueryCursor") + @patch("google.adk.scope.extractors.extractor_java.Query") + @patch("google.adk.scope.extractors.extractor_java.PARSER") + def test_extract_features( + self, mock_parser, mock_query_cls, mock_cursor_cls + ): + mock_path = MagicMock(spec=Path) + mock_path.read_bytes.return_value = b"class A { void foo() {} }" + + mock_tree = MagicMock() + mock_parser.parse.return_value = mock_tree + mock_tree.root_node = MagicMock() + + mock_cursor_instance = mock_cursor_cls.return_value + + mock_node = MagicMock() + mock_node.id = 
123 + mock_cursor_instance.captures.return_value = {"method": [mock_node]} + + with patch( + "google.adk.scope.extractors.extractor_java.NodeProcessor" + ) as MockProcessor: + processor_instance = MockProcessor.return_value + expected_feature = Feature( + original_name="foo", normalized_name="foo" + ) + processor_instance.process.return_value = expected_feature + + features = extract_features(mock_path, Path("/repo"), ".") + + self.assertEqual(len(features), 1) + self.assertEqual(features[0], expected_feature) + + processor_instance.process.assert_called_once() + + def test_find_files_not_exists(self): + with patch("pathlib.Path.exists", return_value=False): + results = list(find_files(Path("bad_path"))) + self.assertEqual(results, []) + + def test_extract_features_read_error(self): + mock_path = MagicMock(spec=Path) + mock_path.read_bytes.side_effect = IOError("Read error") + features = extract_features(mock_path, Path("/repo"), ".") + self.assertEqual(features, []) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/adk/scope/extractors/test_extractor_py.py b/test/adk/scope/extractors/test_extractor_py.py index d16ec3c..165d7f4 100644 --- a/test/adk/scope/extractors/test_extractor_py.py +++ b/test/adk/scope/extractors/test_extractor_py.py @@ -76,7 +76,7 @@ def test_extract_features( ) processor_instance.process.return_value = expected_feature - features = extract_features(mock_path, Path("/repo")) + features = extract_features(mock_path, Path("/repo"), ".") self.assertEqual(len(features), 1) self.assertEqual(features[0], expected_feature) @@ -92,7 +92,7 @@ def test_find_files_not_exists(self): def test_extract_features_read_error(self): mock_path = MagicMock(spec=Path) mock_path.read_bytes.side_effect = IOError("Read error") - features = extract_features(mock_path, Path("/repo")) + features = extract_features(mock_path, Path("/repo"), ".") self.assertEqual(features, []) diff --git a/test/adk/scope/extractors/test_extractor_ts.py 
b/test/adk/scope/extractors/test_extractor_ts.py index c52bd19..5fc6d5a 100644 --- a/test/adk/scope/extractors/test_extractor_ts.py +++ b/test/adk/scope/extractors/test_extractor_ts.py @@ -82,7 +82,7 @@ def test_extract_features(self, mock_parser): original_name="foo" ) - features = extractor.extract_features(p, self.test_dir) + features = extractor.extract_features(p, self.test_dir, ".") self.assertEqual(len(features), 1) self.assertEqual(features[0].original_name, "foo") diff --git a/test/adk/scope/utils/test_normalizer.py b/test/adk/scope/utils/test_normalizer.py index 284bf54..3802a1c 100644 --- a/test/adk/scope/utils/test_normalizer.py +++ b/test/adk/scope/utils/test_normalizer.py @@ -116,6 +116,47 @@ def test_typescript_normalization(self): ["STRING", "NULL"], ) + def test_java_normalization(self): + self.assertEqual( + self.normalizer.normalize("String", "java"), ["STRING"] + ) + self.assertEqual(self.normalizer.normalize("int", "java"), ["NUMBER"]) + self.assertEqual( + self.normalizer.normalize("Integer", "java"), ["NUMBER"] + ) + self.assertEqual( + self.normalizer.normalize("boolean", "java"), ["BOOLEAN"] + ) + self.assertEqual(self.normalizer.normalize("byte[]", "java"), ["LIST"]) + self.assertEqual( + self.normalizer.normalize("List", "java"), ["LIST"] + ) + self.assertEqual( + self.normalizer.normalize("Map", "java"), ["MAP"] + ) + self.assertEqual( + self.normalizer.normalize("Set", "java"), ["SET"] + ) + self.assertEqual(self.normalizer.normalize("void", "java"), ["NULL"]) + self.assertEqual( + self.normalizer.normalize("Object", "java"), ["OBJECT"] + ) + self.assertEqual( + self.normalizer.normalize("MyCustomClass", "java"), ["OBJECT"] + ) + self.assertEqual( + self.normalizer.normalize("CompletableFuture", "java"), + ["STRING"], + ) + self.assertEqual( + self.normalizer.normalize("Optional", "java"), + ["STRING", "NULL"], + ) + self.assertEqual( + self.normalizer.normalize("Maybe", "java"), + ["STRING", "NULL"], + ) + def test_edge_cases(self): 
self.assertEqual(self.normalizer.normalize("", "python"), ["OBJECT"]) self.assertEqual(self.normalizer.normalize(" ", "python"), ["OBJECT"])