|
| 1 | +import ast |
| 2 | +import pathlib |
| 3 | +import tomllib |
| 4 | +import sys |
| 5 | + |
| 6 | +IMPORT_TO_PACKAGE = { |
| 7 | + "google": "protobuf", |
| 8 | + "dateutil": "python_dateutil", |
| 9 | + "grpc": "grpcio", |
| 10 | + "azure.functions": "azure_functions", |
| 11 | + "azurefunctions.extensions.base": "azurefunctions_extensions_base", |
| 12 | +} |
| 13 | + |
| 14 | + |
| 15 | +def normalize_import(import_name): |
| 16 | + return IMPORT_TO_PACKAGE.get(import_name, import_name) |
| 17 | + |
| 18 | + |
| 19 | +def find_local_modules(src_dir): |
| 20 | + local = set() |
| 21 | + for py in src_dir.rglob("*.py"): |
| 22 | + if py.name == "__init__.py": |
| 23 | + local.add(py.parent.name) |
| 24 | + else: |
| 25 | + local.add(py.stem) |
| 26 | + return local |
| 27 | + |
| 28 | + |
| 29 | +def find_imports(src_dir): |
| 30 | + imports = set() |
| 31 | + for py in src_dir.rglob("*.py"): |
| 32 | + with open(py, "r", encoding="utf8") as f: |
| 33 | + tree = ast.parse(f.read(), filename=str(py)) |
| 34 | + |
| 35 | + for node in ast.walk(tree): |
| 36 | + # import x.y |
| 37 | + if isinstance(node, ast.Import): |
| 38 | + for n in node.names: |
| 39 | + if n.name == "azurefunctions.extensions.base": |
| 40 | + imports.add("azurefunctions.extensions.base") |
| 41 | + elif n.name == "azure.functions": |
| 42 | + imports.add("azure.functions") |
| 43 | + else: |
| 44 | + imports.add(n.name.split(".")[0]) |
| 45 | + |
| 46 | + # from x import y |
| 47 | + elif isinstance(node, ast.ImportFrom): |
| 48 | + # 🔹 Ignore relative imports |
| 49 | + if node.level > 0: |
| 50 | + continue |
| 51 | + |
| 52 | + if node.module: |
| 53 | + if node.module == "azure.functions": |
| 54 | + imports.add("azure.functions") |
| 55 | + elif node.module == "azurefunctions.extensions.base": |
| 56 | + imports.add("azurefunctions.extensions.base") |
| 57 | + # Special cases to ignore |
| 58 | + elif str(src_dir).startswith("workers") and ( |
| 59 | + node.module == "azure.monitor.opentelemetry" |
| 60 | + or node.module == "opentelemetry" |
| 61 | + or node.module == "opentelemetry.trace.propagation.tracecontext" |
| 62 | + or node.module == "Cookie"): |
| 63 | + pass |
| 64 | + elif str(src_dir).startswith("runtimes\\v1\\azure_functions_runtime_v1") and ( |
| 65 | + node.module == "google.protobuf.timestamp_pb2" |
| 66 | + or node.module == "azure.monitor.opentelemetry" |
| 67 | + or node.module == "opentelemetry" |
| 68 | + or node.module == "opentelemetry.trace.propagation.tracecontext" |
| 69 | + or node.module == "Cookie"): |
| 70 | + pass |
| 71 | + elif str(src_dir).startswith("runtimes\\v2\\azure_functions_runtime") and ( |
| 72 | + node.module == "google.protobuf.duration_pb2" |
| 73 | + or node.module == "google.protobuf.timestamp_pb2" |
| 74 | + or node.module == "azure.monitor.opentelemetry" |
| 75 | + or node.module == "opentelemetry" |
| 76 | + or node.module == "opentelemetry.trace.propagation.tracecontext" |
| 77 | + or node.module == "Cookie"): |
| 78 | + pass |
| 79 | + else: |
| 80 | + imports.add(node.module.split(".")[0]) |
| 81 | + |
| 82 | + return imports |
| 83 | + |
| 84 | + |
| 85 | +def load_declared_dependencies(pyproject): |
| 86 | + data = tomllib.loads(pyproject.read_text()) |
| 87 | + deps = data["project"]["dependencies"] |
| 88 | + # Strip extras/markers, e.g. "protobuf~=4.25.3; python_version < '3.13'" |
| 89 | + normalized = set() |
| 90 | + for d in deps: |
| 91 | + name = d.split(";")[0].strip() # strip environment marker |
| 92 | + name = name.split("[")[0].strip() # strip extras |
| 93 | + pkg = name.split("==")[0].split("~=")[0].split(">=")[0].split("<=")[0] |
| 94 | + normalized.add(pkg.lower().replace("-", "_")) |
| 95 | + return normalized |
| 96 | + |
| 97 | + |
| 98 | +def check_package(pkg_root, package_name): |
| 99 | + pyproject = pkg_root / "pyproject.toml" |
| 100 | + src_dir = pkg_root / package_name |
| 101 | + |
| 102 | + imports = find_imports(src_dir) |
| 103 | + deps = load_declared_dependencies(pyproject) |
| 104 | + stdlib = set(stdlib_modules()) |
| 105 | + local_modules = find_local_modules(src_dir) |
| 106 | + print("Found imports:", imports) |
| 107 | + print("Declared dependencies:", deps) |
| 108 | + |
| 109 | + missing = [] |
| 110 | + |
| 111 | + for imp in imports: |
| 112 | + |
| 113 | + normalized = normalize_import(imp) |
| 114 | + if ( |
| 115 | + normalized not in deps |
| 116 | + and imp not in stdlib |
| 117 | + and imp not in local_modules |
| 118 | + and imp != package_name |
| 119 | + ): |
| 120 | + missing.append(imp) |
| 121 | + |
| 122 | + |
| 123 | + |
| 124 | + if missing: |
| 125 | + print("Missing required dependencies:") |
| 126 | + for m in missing: |
| 127 | + print(" -", m) |
| 128 | + raise SystemExit(1) |
| 129 | + |
| 130 | + |
| 131 | +def stdlib_modules(): |
| 132 | + # simple version |
| 133 | + import sys |
| 134 | + return set(sys.stdlib_module_names) |
| 135 | + |
| 136 | + |
| 137 | +def main(): |
| 138 | + roots = sys.argv[1] |
| 139 | + package_name = sys.argv[2] |
| 140 | + |
| 141 | + if not roots: |
| 142 | + print("Usage: python check_imports.py <pkg_dir> [<pkg_dir> ...]") |
| 143 | + sys.exit(2) |
| 144 | + |
| 145 | + failed = False |
| 146 | + check_package(pathlib.Path(roots), package_name) |
| 147 | + |
| 148 | + sys.exit(1 if failed else 0) |
| 149 | + |
| 150 | + |
| 151 | +if __name__ == "__main__": |
| 152 | + main() |
0 commit comments