diff --git a/src/s2dm/cli.py b/src/s2dm/cli.py index 2d246d01..a7bbfa1f 100644 --- a/src/s2dm/cli.py +++ b/src/s2dm/cli.py @@ -7,7 +7,7 @@ import rich_click as click import yaml -from graphql import build_schema, parse +from graphql import GraphQLSchema, build_schema, parse from rich.traceback import install from s2dm import __version__, log @@ -21,6 +21,7 @@ from s2dm.exporters.utils.graphql_type import is_builtin_scalar_type, is_introspection_type from s2dm.exporters.utils.schema import load_schema_with_naming, search_schema from s2dm.exporters.utils.schema_loader import ( + check_correct_schema, create_tempfile_to_composed_schema, load_schema, load_schema_as_str, @@ -52,15 +53,6 @@ def process_value(self, ctx: click.Context, value: Any) -> list[Path] | None: if not value: return None paths = set(value) - - # Include the default QUDT units directory if it exists, otherwise warn and don't include it - if DEFAULT_QUDT_UNITS_DIR.exists(): - paths.add(DEFAULT_QUDT_UNITS_DIR) - else: - log.warning( - f"No QUDT units directory found at {DEFAULT_QUDT_UNITS_DIR}. Please run 's2dm units sync' first." - ) - return resolve_graphql_files(list(paths)) @@ -102,6 +94,7 @@ def selection_query_option(required: bool = False) -> Callable[[Callable[..., An help="Output file", ) + optional_output_option = click.option( "--output", "-o", @@ -111,6 +104,16 @@ def selection_query_option(required: bool = False) -> Callable[[Callable[..., An ) +units_directory_option = click.option( + "--directory", + "-d", + type=click.Path(file_okay=False, path_type=Path), + default=DEFAULT_QUDT_UNITS_DIR, + help="Directory for QUDT unit enums", + show_default=True, +) + + expanded_instances_option = click.option( "--expanded-instances", "-e", @@ -140,6 +143,16 @@ def multiline_str_representer(obj: Any) -> Any: return {k: multiline_str_representer(v) for k, v in result.items()} +def assert_correct_schema(schema: GraphQLSchema) -> None: + schema_errors = check_correct_schema(schema) + if schema_errors: + log.error("Schema validation failed:") + for error in schema_errors: + log.error(error) + log.error(f"Found {len(schema_errors)} validation error(s). Please fix the schema before exporting.") + sys.exit(1) + + def validate_naming_config(config: dict[str, Any]) -> None: VALID_CASES = { "camelCase", @@ -315,39 +328,48 @@ def units() -> None: "QUDT version tag (e.g., 3.1.6). Defaults to the latest tag, falls back to 'main' when tags are unavailable." ), ) +@units_directory_option @click.option( "--dry-run", is_flag=True, help="Show what would be generated without actually writing files", ) -def units_sync(version: str | None, dry_run: bool) -> None: - """Fetch QUDT quantity kinds and generate GraphQL enums under the output directory.""" +def units_sync(version: str | None, directory: Path, dry_run: bool) -> None: + """Fetch QUDT quantity kinds and generate GraphQL enums under the specified directory. + + Args: + version: QUDT version tag. Defaults to the latest tag. + directory: Output directory for generated QUDT unit enums (default: ~/.s2dm/units/qudt) + dry_run: Show what would be generated without actually writing files + """ version_to_use = version or get_latest_qudt_version() try: - written = sync_qudt_units(DEFAULT_QUDT_UNITS_DIR, version_to_use, dry_run=dry_run) + written = sync_qudt_units(directory, version_to_use, dry_run=dry_run) except UnitEnumError as e: log.error(f"Units sync failed: {e}") sys.exit(1) if dry_run: - log.info(f"Would generate {len(written)} enum files under {DEFAULT_QUDT_UNITS_DIR}") + log.info(f"Would generate {len(written)} enum files under {directory}") log.print(f"Version: {version_to_use}") log.hint("Use without --dry-run to actually write files") else: - log.success(f"Generated {len(written)} enum files under {DEFAULT_QUDT_UNITS_DIR}") + log.success(f"Generated {len(written)} enum files under {directory}") log.print(f"Version: {version_to_use}") @units.command(name="check-version") -def units_check_version() -> None: +@units_directory_option +def units_check_version(directory: Path) -> None: """Compare local synced QUDT version with the latest remote version and print a message. + Args: - qudt_units_dir: Directory containing generated QUDT unit enums (default: ~/.s2dm/units/qudt) + directory: Directory containing generated QUDT unit enums (default: ~/.s2dm/units/qudt) """ - meta_path = DEFAULT_QUDT_UNITS_DIR / UNITS_META_FILENAME + meta_path = directory / UNITS_META_FILENAME if not meta_path.exists(): log.warning("No metadata.json found. Run 's2dm units sync' first.") sys.exit(1) @@ -398,6 +420,7 @@ def compose(schemas: list[Path], root_type: str | None, selection_query: Path | composed_schema_str = load_schema_as_str(schemas, add_references=True) graphql_schema = build_schema(composed_schema_str) + assert_correct_schema(graphql_schema) if selection_query: query_document = parse(selection_query.read_text()) @@ -487,6 +510,7 @@ def shacl( naming_config = ctx.obj.get("naming_config") graphql_schema = load_schema_with_naming(schemas, naming_config) + assert_correct_schema(graphql_schema) if selection_query: query_document = parse(selection_query.read_text()) @@ -515,6 +539,7 @@ def vspec(ctx: click.Context, schemas: list[Path], selection_query: Path | None, """Generate VSPEC from a given GraphQL schema.""" naming_config = ctx.obj.get("naming_config") graphql_schema = load_schema_with_naming(schemas, naming_config) + assert_correct_schema(graphql_schema) if selection_query: query_document = parse(selection_query.read_text()) @@ -553,6 +578,7 @@ def jsonschema( """Generate JSON Schema from a given GraphQL schema.""" naming_config = ctx.obj.get("naming_config") graphql_schema = load_schema_with_naming(schemas, naming_config) + assert_correct_schema(graphql_schema) if selection_query: query_document = parse(selection_query.read_text()) @@ -597,6 +623,7 @@ def protobuf( """Generate Protocol Buffers (.proto) file from GraphQL schema.""" naming_config = ctx.obj.get("naming_config") graphql_schema = load_schema_with_naming(schemas, naming_config) + assert_correct_schema(graphql_schema) query_document = parse(selection_query.read_text()) graphql_schema = prune_schema_using_query_selection(graphql_schema, query_document) diff --git a/src/s2dm/exporters/utils/schema_loader.py b/src/s2dm/exporters/utils/schema_loader.py index e691494d..365952fc 100644 --- a/src/s2dm/exporters/utils/schema_loader.py +++ b/src/s2dm/exporters/utils/schema_loader.py @@ -6,6 +6,7 @@ from ariadne import load_schema_from_path from graphql import ( DocumentNode, + GraphQLEnumType, GraphQLField, GraphQLInputObjectType, GraphQLInterfaceType, @@ -17,6 +18,7 @@ GraphQLString, GraphQLType, GraphQLUnionType, + Undefined, build_schema, get_named_type, is_input_object_type, @@ -26,9 +28,10 @@ is_object_type, is_union_type, print_schema, + validate_schema, ) from graphql import validate as graphql_validate -from graphql.language.ast import SelectionSetNode +from graphql.language.ast import DirectiveNode, EnumValueNode, SelectionSetNode from s2dm import log from s2dm.exporters.utils.directive import ( @@ -268,6 +271,150 @@ def create_tempfile_to_composed_schema(graphql_schema_paths: list[Path]) -> Path return Path(temp_path) +def _check_directive_usage_on_node(schema: GraphQLSchema, directive_node: DirectiveNode, context: str) -> list[str]: + """Check enum values in directive usage on a specific node.""" + errors: list[str] = [] + + directive_def = schema.get_directive(directive_node.name.value) + if not directive_def: + return errors + + for arg_node in directive_node.arguments: + arg_name = arg_node.name.value + arg_def = directive_def.args[arg_name] + named_type = get_named_type(arg_def.type) + + if not isinstance(named_type, GraphQLEnumType): + continue + + if not isinstance(arg_node.value, EnumValueNode): + continue + + enum_value = arg_node.value.value + if enum_value not in named_type.values: + errors.append( + f"{context} uses directive '@{directive_node.name.value}({arg_name})' " + f"with invalid enum value '{enum_value}'. Valid values are: {list(named_type.values.keys())}" + ) + + return errors + + +def check_enum_defaults(schema: GraphQLSchema) -> list[str]: + """Check that all enum default values exist in their enum definitions. + + Args: + schema: The GraphQL schema to validate + + Returns: + List of error messages for invalid enum defaults + """ + errors = [] + + for type_name, type_obj in schema.type_map.items(): + # Validate directive usage on types + if type_obj.ast_node and type_obj.ast_node.directives: + for directive_node in type_obj.ast_node.directives: + errors.extend(_check_directive_usage_on_node(schema, directive_node, f"Type '{type_name}'")) + + # Validate input object field defaults + if isinstance(type_obj, GraphQLInputObjectType): + for field_name, field in type_obj.fields.items(): + named_type = get_named_type(field.type) + if not isinstance(named_type, GraphQLEnumType): + continue + + has_default_in_ast = field.ast_node and field.ast_node.default_value is not None + if not (has_default_in_ast and field.default_value is Undefined): + continue + + invalid_value = field.ast_node.default_value.value + errors.append( + f"Input type '{type_name}.{field_name}' has invalid enum default value '{invalid_value}'. " + f"Valid values are: {list(named_type.values.keys())}" + ) + + # Validate field argument defaults and directive usage on fields + if isinstance(type_obj, GraphQLObjectType | GraphQLInterfaceType | GraphQLInputObjectType): + for field_name, field in type_obj.fields.items(): + # Validate directive usage on fields + if field.ast_node and field.ast_node.directives: + for directive_node in field.ast_node.directives: + errors.extend( + _check_directive_usage_on_node(schema, directive_node, f"Field '{type_name}.{field_name}'") + ) + + # Validate field argument defaults + if isinstance(type_obj, GraphQLObjectType | GraphQLInterfaceType): + for arg_name, arg in field.args.items(): + named_type = get_named_type(arg.type) + if not isinstance(named_type, GraphQLEnumType): + continue + + has_default_in_ast = arg.ast_node and arg.ast_node.default_value is not None + if not (has_default_in_ast and arg.default_value is Undefined): + continue + + invalid_value = arg.ast_node.default_value.value + errors.append( + f"Field argument '{type_name}.{field_name}({arg_name})' " + f"has invalid enum default value '{invalid_value}'. " + f"Valid values are: {list(named_type.values.keys())}" + ) + + # Validate directive definition defaults + for directive in schema.directives: + for arg_name, arg in directive.args.items(): + named_type = get_named_type(arg.type) + if not isinstance(named_type, GraphQLEnumType): + continue + + if not arg.ast_node or not arg.ast_node.default_value: + continue + + if arg.default_value is not Undefined: + continue + + if not isinstance(arg.ast_node.default_value, EnumValueNode): + continue + + invalid_value = arg.ast_node.default_value.value + errors.append( + f"Directive definition '@{directive.name}({arg_name})' " + f"has invalid enum default value '{invalid_value}'. Valid values are: {list(named_type.values.keys())}" + ) + + return errors + + +def check_correct_schema(schema: GraphQLSchema) -> list[str]: + """Assert that the schema conforms to GraphQL specification and has valid enum defaults. + + Args: + schema: The GraphQL schema to validate + + Returns: + list[str]: List of error messages if any validation errors are found + + Exits: + Calls sys.exit(1) if the schema has validation errors + """ + spec_errors = validate_schema(schema) + enum_errors = check_enum_defaults(schema) + + all_errors: list[str] = [] + + if spec_errors: + for spec_error in spec_errors: + all_errors.append(f" - {spec_error.message}") + + if enum_errors: + for enum_error in enum_errors: + all_errors.append(f" - {enum_error}") + + return all_errors + + def ensure_query(schema: GraphQLSchema) -> GraphQLSchema: """ Ensures that the provided GraphQL schema has a Query type. If the schema does not have a Query type, @@ -377,7 +524,7 @@ def visit_field_type(field_type: GraphQLType) -> None: return referenced -def validate_schema(schema: GraphQLSchema, document: DocumentNode) -> GraphQLSchema | None: +def _validate_schema(schema: GraphQLSchema, document: DocumentNode) -> GraphQLSchema | None: log.debug("Validating schema against the provided document") errors = graphql_validate(schema, document) @@ -406,7 +553,7 @@ def prune_schema_using_query_selection(schema: GraphQLSchema, document: Document if not schema.query_type: raise ValueError("Schema has no query type defined") - if validate_schema(schema, document) is None: + if _validate_schema(schema, document) is None: raise ValueError("Schema validation failed") fields_to_keep: dict[str, set[str]] = {} diff --git a/tests/conftest.py b/tests/conftest.py index bbe8901a..0d728b7f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -49,19 +49,6 @@ class TestSchemaData: BREAKING_SCHEMA = TESTS_DATA_DIR / "breaking.graphql" -@pytest.fixture(autouse=True) -def patch_default_units_dir(monkeypatch: pytest.MonkeyPatch) -> None: - """Patch DEFAULT_QUDT_UNITS_DIR to use tests/data/units for all tests. - - This prevents the "No QUDT units directory found" warning during tests - and provides the necessary unit enum definitions that test schemas reference. - - Tests that use the units_sync_mocks fixture will have this overridden with - their own tmp_path directory for isolation. - """ - monkeypatch.setattr("s2dm.cli.DEFAULT_QUDT_UNITS_DIR", TestSchemaData.UNITS_SCHEMA_PATH) - - def parsed_console_output() -> str: """Parse console output (placeholder function).""" return "" diff --git a/tests/test_e2e_cli.py b/tests/test_e2e_cli.py index f57b531d..aed1e8bf 100644 --- a/tests/test_e2e_cli.py +++ b/tests/test_e2e_cli.py @@ -11,6 +11,12 @@ from tests.conftest import TestSchemaData as TSD +@pytest.fixture(scope="session") +def units_directory() -> Path: + """Return the test data units directory.""" + return TSD.UNITS_SCHEMA_PATH + + @pytest.fixture(scope="module") def runner() -> CliRunner: return CliRunner() @@ -50,10 +56,24 @@ def contains_value(obj: dict[str, Any] | list[Any] | str, target: str) -> bool: # ToDo(DA): please update this test to do proper asserts for the shacl exporter -def test_export_shacl(runner: CliRunner, tmp_outputs: Path) -> None: +def test_export_shacl(runner: CliRunner, tmp_outputs: Path, units_directory: Path) -> None: out = tmp_outputs / "shacl.ttl" result = runner.invoke( - cli, ["export", "shacl", "-s", str(TSD.SAMPLE1_1), "-s", str(TSD.SAMPLE1_2), "-o", str(out), "-f", "ttl"] + cli, + [ + "export", + "shacl", + "-s", + str(TSD.SAMPLE1_1), + "-s", + str(TSD.SAMPLE1_2), + "-s", + str(units_directory), + "-o", + str(out), + "-f", + "ttl", + ], ) assert result.exit_code == 0, result.output assert out.exists() @@ -65,9 +85,23 @@ def test_export_shacl(runner: CliRunner, tmp_outputs: Path) -> None: # ToDo(DA): please update this test to do proper asserts for the vspec exporter -def test_export_vspec(runner: CliRunner, tmp_outputs: Path) -> None: +def test_export_vspec(runner: CliRunner, tmp_outputs: Path, units_directory: Path) -> None: out = tmp_outputs / "vspec.yaml" - result = runner.invoke(cli, ["export", "vspec", "-s", str(TSD.SAMPLE1_1), "-s", str(TSD.SAMPLE1_2), "-o", str(out)]) + result = runner.invoke( + cli, + [ + "export", + "vspec", + "-s", + str(TSD.SAMPLE1_1), + "-s", + str(TSD.SAMPLE1_2), + "-s", + str(units_directory), + "-o", + str(out), + ], + ) assert result.exit_code == 0, result.output assert out.exists() with open(out, encoding="utf-8") as f: @@ -77,10 +111,22 @@ def test_export_vspec(runner: CliRunner, tmp_outputs: Path) -> None: assert "Vehicle_ADAS_ObstacleDetection:" in content -def test_export_jsonschema(runner: CliRunner, tmp_outputs: Path) -> None: +def test_export_jsonschema(runner: CliRunner, tmp_outputs: Path, units_directory: Path) -> None: out = tmp_outputs / "jsonschema.yaml" result = runner.invoke( - cli, ["export", "jsonschema", "-s", str(TSD.SAMPLE1_1), "-s", str(TSD.SAMPLE1_2), "-o", str(out)] + cli, + [ + "export", + "jsonschema", + "-s", + str(TSD.SAMPLE1_1), + "-s", + str(TSD.SAMPLE1_2), + "-s", + str(units_directory), + "-o", + str(out), + ], ) assert result.exit_code == 0, result.output assert out.exists() @@ -91,7 +137,7 @@ def test_export_jsonschema(runner: CliRunner, tmp_outputs: Path) -> None: assert '"Vehicle_ADAS_ObstacleDetection"' in content -def test_export_protobuf(runner: CliRunner, tmp_outputs: Path) -> None: +def test_export_protobuf(runner: CliRunner, tmp_outputs: Path, units_directory: Path) -> None: out = tmp_outputs / "schema.proto" result = runner.invoke( cli, @@ -102,6 +148,8 @@ def test_export_protobuf(runner: CliRunner, tmp_outputs: Path) -> None: str(TSD.SAMPLE1_1), "-s", str(TSD.SAMPLE1_2), + "-s", + str(units_directory), "-q", str(TSD.SCHEMA1_QUERY), "-o", @@ -128,7 +176,7 @@ def test_export_protobuf(runner: CliRunner, tmp_outputs: Path) -> None: assert "optional bool isEngaged = 1;" in content -def test_export_protobuf_flattened_naming(runner: CliRunner, tmp_outputs: Path) -> None: +def test_export_protobuf_flattened_naming(runner: CliRunner, tmp_outputs: Path, units_directory: Path) -> None: out = tmp_outputs / "schema.proto" result = runner.invoke( cli, @@ -139,6 +187,8 @@ def test_export_protobuf_flattened_naming(runner: CliRunner, tmp_outputs: Path) str(TSD.SAMPLE1_1), "-s", str(TSD.SAMPLE1_2), + "-s", + str(units_directory), "-q", str(TSD.SCHEMA1_QUERY), "-o", @@ -166,10 +216,22 @@ def test_export_protobuf_flattened_naming(runner: CliRunner, tmp_outputs: Path) assert 'optional bool Vehicle_adas_abs_isEngaged = 3 [(field_source) = "Vehicle_ADAS_ABS"];' in content -def test_generate_skos_skeleton(runner: CliRunner, tmp_outputs: Path) -> None: +def test_generate_skos_skeleton(runner: CliRunner, tmp_outputs: Path, units_directory: Path) -> None: out = tmp_outputs / "skos_skeleton.ttl" result = runner.invoke( - cli, ["generate", "skos-skeleton", "-s", str(TSD.SAMPLE1_1), "-s", str(TSD.SAMPLE1_2), "-o", str(out)] + cli, + [ + "generate", + "skos-skeleton", + "-s", + str(TSD.SAMPLE1_1), + "-s", + str(TSD.SAMPLE1_2), + "-s", + str(units_directory), + "-o", + str(out), + ], ) assert result.exit_code == 0, result.output assert out.exists() @@ -201,13 +263,15 @@ def test_generate_skos_skeleton(runner: CliRunner, tmp_outputs: Path) -> None: ], ) def test_check_version_bump( - runner: CliRunner, schema_file: list[Path], previous_file: list[Path], expected_output: str + runner: CliRunner, schema_file: list[Path], previous_file: list[Path], expected_output: str, units_directory: Path ) -> None: result = runner.invoke( cli, ["check", "version-bump"] + [item for schema in schema_file for item in ["-s", str(schema)]] - + [item for previous in previous_file for item in ["--previous", str(previous)]], + + ["-s", str(units_directory)] + + [item for previous in previous_file for item in ["--previous", str(previous)]] + + ["--previous", str(units_directory)], ) assert result.exit_code == 0, result.output # Replace all newlines and additional spaces with a single space with regex @@ -242,16 +306,32 @@ def test_check_version_bump_output_type( ((TSD.SAMPLE1_1, TSD.SAMPLE1_2), "All constraints passed"), ], ) -def test_check_constraints(runner: CliRunner, input_file: tuple[Path, Path], expected_output: str) -> None: - result = runner.invoke(cli, ["check", "constraints", "-s", str(input_file[0]), "-s", str(input_file[1])]) +def test_check_constraints( + runner: CliRunner, input_file: tuple[Path, Path], expected_output: str, units_directory: Path +) -> None: + result = runner.invoke( + cli, ["check", "constraints", "-s", str(input_file[0]), "-s", str(input_file[1]), "-s", str(units_directory)] + ) assert expected_output.lower() in normalize_whitespace(result.output).lower() assert result.exit_code in (0, 1) -def test_validate_graphql(runner: CliRunner, tmp_outputs: Path) -> None: +def test_validate_graphql(runner: CliRunner, tmp_outputs: Path, units_directory: Path) -> None: out = tmp_outputs / "validate.json" result = runner.invoke( - cli, ["validate", "graphql", "-s", str(TSD.SAMPLE1_1), "-s", str(TSD.SAMPLE1_2), "-o", str(out)] + cli, + [ + "validate", + "graphql", + "-s", + str(TSD.SAMPLE1_1), + "-s", + str(TSD.SAMPLE1_2), + "-s", + str(units_directory), + "-o", + str(out), + ], ) assert result.exit_code == 0, result.output assert out.exists() @@ -274,6 +354,7 @@ def test_diff_graphql( schemas: tuple[Path, Path], val_schemas: tuple[Path, Path], expected_output: str, + units_directory: Path, ) -> None: out = tmp_outputs / "diff.json" result = runner.invoke( @@ -285,10 +366,14 @@ def test_diff_graphql( str(schemas[0]), "-s", str(schemas[1]), + "-s", + str(units_directory), "--val-schema", str(val_schemas[0]), "--val-schema", str(val_schemas[1]), + "--val-schema", + str(units_directory), "-o", str(out), ], @@ -299,7 +384,7 @@ def test_diff_graphql( assert expected_output in file_content or expected_output in result.output -def test_registry_export_concept_uri(runner: CliRunner, tmp_outputs: Path) -> None: +def test_registry_export_concept_uri(runner: CliRunner, tmp_outputs: Path, units_directory: Path) -> None: out = tmp_outputs / "concept_uris.json" result = runner.invoke( cli, @@ -310,6 +395,8 @@ def test_registry_export_concept_uri(runner: CliRunner, tmp_outputs: Path) -> No str(TSD.SAMPLE1_1), "-s", str(TSD.SAMPLE1_2), + "-s", + str(units_directory), "-o", str(out), ], @@ -329,7 +416,7 @@ def test_registry_export_concept_uri(runner: CliRunner, tmp_outputs: Path) -> No ), 'Expected value "ns:Person.name" not found in the concept URI output.' -def test_registry_export_id(runner: CliRunner, tmp_outputs: Path) -> None: +def test_registry_export_id(runner: CliRunner, tmp_outputs: Path, units_directory: Path) -> None: out = tmp_outputs / "ids.json" result = runner.invoke( cli, @@ -340,6 +427,8 @@ def test_registry_export_id(runner: CliRunner, tmp_outputs: Path) -> None: str(TSD.SAMPLE1_1), "-s", str(TSD.SAMPLE1_2), + "-s", + str(units_directory), "-o", str(out), ], @@ -352,10 +441,22 @@ def test_registry_export_id(runner: CliRunner, tmp_outputs: Path) -> None: assert any("Person.name" in k for k in data), "Expected 'Person.name' not found in the output." -def test_registry_init(runner: CliRunner, tmp_outputs: Path) -> None: +def test_registry_init(runner: CliRunner, tmp_outputs: Path, units_directory: Path) -> None: out = tmp_outputs / "spec_history.json" result = runner.invoke( - cli, ["registry", "init", "-s", str(TSD.SAMPLE1_1), "-s", str(TSD.SAMPLE1_2), "-o", str(out)] + cli, + [ + "registry", + "init", + "-s", + str(TSD.SAMPLE1_1), + "-s", + str(TSD.SAMPLE1_2), + "-s", + str(units_directory), + "-o", + str(out), + ], ) assert result.exit_code == 0, result.output assert out.exists() @@ -396,11 +497,25 @@ def test_registry_init(runner: CliRunner, tmp_outputs: Path) -> None: ), f'Expected entry with "@id": "ns:Person.height" and specHistory id "{ExpectedIds.PERSON_HEIGHT_ID}" not found.' -def test_registry_update(runner: CliRunner, tmp_outputs: Path) -> None: +def test_registry_update(runner: CliRunner, tmp_outputs: Path, units_directory: Path) -> None: out = tmp_outputs / "spec_history_update.json" # First, create a spec history file init_out = tmp_outputs / "spec_history.json" - runner.invoke(cli, ["registry", "init", "-s", str(TSD.SAMPLE1_1), "-s", str(TSD.SAMPLE1_2), "-o", str(init_out)]) + runner.invoke( + cli, + [ + "registry", + "init", + "-s", + str(TSD.SAMPLE1_1), + "-s", + str(TSD.SAMPLE1_2), + "-s", + str(units_directory), + "-o", + str(init_out), + ], + ) runner.invoke( cli, [ @@ -410,6 +525,8 @@ def test_registry_update(runner: CliRunner, tmp_outputs: Path) -> None: str(TSD.SAMPLE2_1), "-s", str(TSD.SAMPLE2_2), + "-s", + str(units_directory), "-sh", str(init_out), "-o", @@ -466,18 +583,43 @@ def test_registry_update(runner: CliRunner, tmp_outputs: Path) -> None: ("nonExistentField", "No matches found"), ], ) -def test_search_graphql(runner: CliRunner, search_term: str, expected_output: str) -> None: +def test_search_graphql(runner: CliRunner, search_term: str, expected_output: str, units_directory: Path) -> None: result = runner.invoke( - cli, ["search", "graphql", "-s", str(TSD.SAMPLE1_1), "-s", str(TSD.SAMPLE1_2), "-t", search_term, "--exact"] + cli, + [ + "search", + "graphql", + "-s", + str(TSD.SAMPLE1_1), + "-s", + str(TSD.SAMPLE1_2), + "-s", + str(units_directory), + "-t", + search_term, + "--exact", + ], ) assert result.exit_code == 0, result.output assert expected_output.lower() in normalize_whitespace(result.output).lower() -def test_search_skos(runner: CliRunner, tmp_outputs: Path) -> None: +def test_search_skos(runner: CliRunner, tmp_outputs: Path, units_directory: Path) -> None: skos_file = tmp_outputs / "test_skos.ttl" result = runner.invoke( - cli, ["generate", "skos-skeleton", "-s", str(TSD.SAMPLE1_1), "-s", str(TSD.SAMPLE1_2), "-o", str(skos_file)] + cli, + [ + "generate", + "skos-skeleton", + "-s", + str(TSD.SAMPLE1_1), + "-s", + str(TSD.SAMPLE1_2), + "-s", + str(units_directory), + "-o", + str(skos_file), + ], ) assert result.exit_code == 0, result.output assert skos_file.exists() @@ -503,21 +645,41 @@ def test_search_skos(runner: CliRunner, tmp_outputs: Path) -> None: [("Vehicle", 0, "Vehicle"), ("Seat", 1, "Type 'Seat' doesn't exist")], ) def test_similar_graphql( - runner: CliRunner, tmp_outputs: Path, search_term: str, expected_returncode: int, expected_output: str + runner: CliRunner, + tmp_outputs: Path, + search_term: str, + expected_returncode: int, + expected_output: str, + units_directory: Path, ) -> None: out = tmp_outputs / "similar.json" result = runner.invoke( cli, - ["similar", "graphql", "-s", str(TSD.SAMPLE1_1), "-s", str(TSD.SAMPLE1_2), "-k", search_term, "-o", str(out)], + [ + "similar", + "graphql", + "-s", + str(TSD.SAMPLE1_1), + "-s", + str(TSD.SAMPLE1_2), + "-s", + str(units_directory), + "-k", + search_term, + "-o", + str(out), + ], ) assert expected_returncode == result.exit_code, result.output assert expected_output in normalize_whitespace(result.output) assert out.exists() -def test_compose_graphql(runner: CliRunner, tmp_outputs: Path) -> None: +def test_compose_graphql(runner: CliRunner, tmp_outputs: Path, units_directory: Path) -> None: out = tmp_outputs / "composed.graphql" - result = runner.invoke(cli, ["compose", "-s", str(TSD.SAMPLE1_1), "-s", str(TSD.SAMPLE1_2), "-o", str(out)]) + result = runner.invoke( + cli, ["compose", "-s", str(TSD.SAMPLE1_1), "-s", str(TSD.SAMPLE1_2), "-s", str(units_directory), "-o", str(out)] + ) assert result.exit_code == 0, result.output assert out.exists() @@ -531,10 +693,23 @@ def test_compose_graphql(runner: CliRunner, tmp_outputs: Path) -> None: assert "Successfully composed schema" in normalize_whitespace(result.output) -def test_compose_graphql_with_root_type(runner: CliRunner, tmp_outputs: Path) -> None: +def test_compose_graphql_with_root_type(runner: CliRunner, tmp_outputs: Path, units_directory: Path) -> None: out = tmp_outputs / "composed_filtered.graphql" result = runner.invoke( - cli, ["compose", "-s", str(TSD.SAMPLE1_1), "-s", str(TSD.SAMPLE1_2), "-r", "Vehicle", "-o", str(out)] + cli, + [ + "compose", + "-s", + str(TSD.SAMPLE1_1), + "-s", + str(TSD.SAMPLE1_2), + "-s", + str(units_directory), + "-r", + "Vehicle", + "-o", + str(out), + ], ) assert result.exit_code == 0, result.output assert out.exists() @@ -549,10 +724,25 @@ def test_compose_graphql_with_root_type(runner: CliRunner, tmp_outputs: Path) -> assert "type Person" not in composed_content -def test_compose_graphql_with_root_type_nonexistent(runner: CliRunner, tmp_outputs: Path) -> None: +def test_compose_graphql_with_root_type_nonexistent( + runner: CliRunner, tmp_outputs: Path, units_directory: Path +) -> None: out = tmp_outputs / "composed_error.graphql" result = runner.invoke( - cli, ["compose", "-s", str(TSD.SAMPLE1_1), "-s", str(TSD.SAMPLE1_2), "-r", "NonExistentType", "-o", str(out)] + cli, + [ + "compose", + "-s", + str(TSD.SAMPLE1_1), + "-s", + str(TSD.SAMPLE1_2), + "-s", + str(units_directory), + "-r", + "NonExistentType", + "-o", + str(out), + ], ) assert result.exit_code == 1 @@ -560,10 +750,25 @@ def test_compose_graphql_with_root_type_nonexistent(runner: CliRunner, tmp_outpu assert "Root type 'NonExistentType' not found in schema" in normalize_whitespace(result.output) -def test_compose_graphql_root_type_filters_unreferenced_types(runner: CliRunner, tmp_outputs: Path) -> None: +def test_compose_graphql_root_type_filters_unreferenced_types( + runner: CliRunner, tmp_outputs: Path, units_directory: Path +) -> None: out = tmp_outputs / "composed_filtered.graphql" result = runner.invoke( - cli, ["compose", "-s", str(TSD.SAMPLE1_1), "-s", str(TSD.SAMPLE1_2), "-r", "Vehicle_ADAS", "-o", str(out)] + cli, + [ + "compose", + "-s", + str(TSD.SAMPLE1_1), + "-s", + str(TSD.SAMPLE1_2), + "-s", + str(units_directory), + "-r", + "Vehicle_ADAS", + "-o", + str(out), + ], ) assert result.exit_code == 0 @@ -577,10 +782,12 @@ def test_compose_graphql_root_type_filters_unreferenced_types(runner: CliRunner, assert "type InCabinArea2x2" not in composed_content -def test_compose_preserves_custom_directives(runner: CliRunner, tmp_outputs: Path) -> None: +def test_compose_preserves_custom_directives(runner: CliRunner, tmp_outputs: Path, units_directory: Path) -> None: """Test that compose preserves all types of custom directives and formatting.""" out = tmp_outputs / "directive_preservation_test.graphql" - result = runner.invoke(cli, ["compose", "-s", str(TSD.SAMPLE1_1), "-s", str(TSD.SAMPLE1_2), "-o", str(out)]) + result = runner.invoke( + cli, ["compose", "-s", str(TSD.SAMPLE1_1), "-s", str(TSD.SAMPLE1_2), "-s", str(units_directory), "-o", str(out)] + ) assert result.exit_code == 0, result.output assert out.exists() @@ -596,10 +803,12 @@ def test_compose_preserves_custom_directives(runner: CliRunner, tmp_outputs: Pat assert "type Vehicle_ADAS_ObstacleDetection" in composed_content -def test_compose_adds_reference_directives(runner: CliRunner, tmp_outputs: Path) -> None: +def test_compose_adds_reference_directives(runner: CliRunner, tmp_outputs: Path, units_directory: Path) -> None: """Test that compose adds @reference directives to track source files.""" out = tmp_outputs / "reference_directives_test.graphql" - result = runner.invoke(cli, ["compose", "-s", str(TSD.SAMPLE1_1), "-s", str(TSD.SAMPLE1_2), "-o", str(out)]) + result = runner.invoke( + cli, ["compose", "-s", str(TSD.SAMPLE1_1), "-s", str(TSD.SAMPLE1_2), "-s", str(units_directory), "-o", str(out)] + ) assert result.exit_code == 0, result.output assert out.exists() @@ -666,19 +875,46 @@ def test_compose_reference_directive_placement_after_other_directives(runner: Cl assert 'union Person @reference(source: "test.graphql")' in composed_content -def test_compose_with_invalid_selection_query(runner: CliRunner, tmp_outputs: Path) -> None: +def test_compose_with_invalid_selection_query(runner: CliRunner, tmp_outputs: Path, units_directory: Path) -> None: out = tmp_outputs / "composed_invalid_query.graphql" result = runner.invoke( cli, - ["compose", "-s", str(TSD.SAMPLE2_1), "-s", str(TSD.SAMPLE2_2), "-q", str(TSD.INVALID_QUERY), "-o", str(out)], + [ + "compose", + "-s", + str(TSD.SAMPLE2_1), + "-s", + str(TSD.SAMPLE2_2), + "-s", + str(units_directory), + "-q", + str(TSD.INVALID_QUERY), + "-o", + str(out), + ], ) assert result.exit_code == 1 -def test_compose_with_valid_selection_query_prunes_schema(runner: CliRunner, tmp_outputs: Path) -> None: +def test_compose_with_valid_selection_query_prunes_schema( + runner: CliRunner, tmp_outputs: Path, units_directory: Path +) -> None: out = tmp_outputs / "composed_pruned.graphql" result = runner.invoke( - cli, ["compose", "-s", str(TSD.SAMPLE2_1), "-s", str(TSD.SAMPLE2_2), "-q", str(TSD.VALID_QUERY), "-o", str(out)] + cli, + [ + "compose", + "-s", + str(TSD.SAMPLE2_1), + "-s", + str(TSD.SAMPLE2_2), + "-s", + str(units_directory), + "-q", + str(TSD.VALID_QUERY), + "-o", + str(out), + ], ) assert result.exit_code == 0 @@ -705,8 +941,10 @@ def test_compose_with_valid_selection_query_prunes_schema(runner: CliRunner, tmp # ToDo(DA): needs refactoring after final decision how stats will work -def test_stats_graphql(runner: CliRunner) -> None: - result = runner.invoke(cli, ["stats", "graphql", "-s", str(TSD.SAMPLE1_1), "-s", str(TSD.SAMPLE1_2)]) +def test_stats_graphql(runner: CliRunner, units_directory: Path) -> None: + result = runner.invoke( + cli, ["stats", "graphql", "-s", str(TSD.SAMPLE1_1), "-s", str(TSD.SAMPLE1_2), "-s", str(units_directory)] + ) print(f"{result.output=}") assert result.exit_code == 0, normalize_whitespace(result.output) assert '"UInt32": 1' in normalize_whitespace(result.output) diff --git a/tests/test_enum_validation.py b/tests/test_enum_validation.py new file mode 100644 index 00000000..5e2e4d9e --- /dev/null +++ b/tests/test_enum_validation.py @@ -0,0 +1,250 @@ +from graphql import build_schema + +from s2dm.exporters.utils.schema_loader import check_enum_defaults + + +class TestEnumDefaultValidation: + def test_valid_input_field_default(self) -> None: + schema = build_schema( + """ + enum Color { RED BLUE GREEN } + + input VehicleInput { + color: Color = RED + } + """ + ) + errors = check_enum_defaults(schema) + assert errors == [] + + def test_invalid_input_field_default(self) -> None: + schema = build_schema( + """ + enum Color { RED BLUE GREEN } + + input VehicleInput { + color: Color = YELLOW + } + """ + ) + errors = check_enum_defaults(schema) + assert len(errors) == 1 + + error = errors[0] + assert "VehicleInput.color" in error + assert "YELLOW" in error + assert "['RED', 'BLUE', 'GREEN']" in error + + def test_invalid_enum_defaults_on_multiple_inputs(self) -> None: + schema = build_schema( + """ + enum Color { RED BLUE GREEN } + enum EngineType { ELECTRIC HYBRID COMBUSTION } + + input VehicleInput { + color: Color = YELLOW + engineType: EngineType = DIESEL + } + + input BikeInput { + color: Color = PURPLE + } + """ + ) + errors = check_enum_defaults(schema) + assert len(errors) == 3 + + vehicle_color_error = errors[0] + assert "VehicleInput.color" in vehicle_color_error + assert "YELLOW" in vehicle_color_error + + vehicle_engine_error = errors[1] + assert "VehicleInput.engineType" in vehicle_engine_error + assert "DIESEL" in vehicle_engine_error + + bike_color_error = errors[2] + assert "BikeInput.color" in bike_color_error + assert "PURPLE" in bike_color_error + + def test_valid_field_argument_default(self) -> None: + schema = build_schema( + """ + type Query { + vehicle(color: Color = RED): String + } + + enum Color { RED BLUE GREEN } + """ + ) + errors = check_enum_defaults(schema) + assert errors == [] + + def test_invalid_field_argument_default(self) -> None: + schema = build_schema( + """ + type Query { + vehicle(color: Color = YELLOW): String + } + + enum Color { RED BLUE GREEN } + """ + ) + errors = check_enum_defaults(schema) + assert len(errors) == 1 + + error = errors[0] + assert "Query.vehicle(color)" in error + assert "YELLOW" in error + + def test_valid_directive_definition_default(self) -> None: + schema = build_schema( + """ + directive @color(value: Color = RED) on OBJECT + + enum Color { RED BLUE GREEN } + """ + ) + errors = check_enum_defaults(schema) + assert errors == [] + + def test_invalid_directive_definition_default(self) -> None: + schema = build_schema( + """ + directive @color(value: Color = YELLOW) on OBJECT + + enum Color { RED BLUE GREEN } + """ + ) + errors = check_enum_defaults(schema) + assert len(errors) == 1 + + error = errors[0] + assert "@color(value)" in error + assert "YELLOW" in error + + def test_valid_directive_usage_on_type(self) -> None: + schema = build_schema( + """ + directive @color(value: Color) on OBJECT + + enum Color { RED BLUE GREEN } + + type Vehicle @color(value: RED) { + field: String + } + """ + ) + errors = check_enum_defaults(schema) + assert errors == [] + + def test_invalid_directive_usage_on_type(self) -> None: + schema = build_schema( + """ + directive @color(value: Color) on OBJECT + + enum Color { RED BLUE GREEN } + + type Vehicle @color(value: YELLOW) { + field: String + } + """ + ) + errors = check_enum_defaults(schema) + assert len(errors) == 1 + + error = errors[0] + assert "Type 'Vehicle'" in error + assert "@color(value)" in error + assert "YELLOW" in error + + def test_valid_directive_usage_on_field(self) -> None: + schema = build_schema( + """ + directive @color(value: Color) on FIELD_DEFINITION + + enum Color { RED BLUE GREEN } + + type Vehicle { + field: String @color(value: RED) + } + """ + ) + errors = check_enum_defaults(schema) + assert errors == [] + + def test_invalid_directive_usage_on_field(self) -> None: + schema = build_schema( + """ + directive @color(value: Color) on FIELD_DEFINITION + + enum Color { RED BLUE GREEN } + + type Vehicle { + field: String @color(value: YELLOW) + } + """ + ) + errors = check_enum_defaults(schema) + assert len(errors) == 1 + + error = errors[0] + assert "Field 'Vehicle.field'" in error + assert "@color(value)" in error + assert "YELLOW" in error + + def test_multiple_invalid_enum_defaults(self) -> None: + schema = build_schema( + """ + directive @color(value: Color = WHITE) on OBJECT | FIELD_DEFINITION + + type Query { + vehicle(color: Color = CYAN): String + } + + enum Color { RED BLUE GREEN } + + type Vehicle @color(value: MAROON) { + field: String @color(value: MAGENTA) + } + + input VehicleInput { + color: Color = YELLOW + } + """ + ) + errors = check_enum_defaults(schema) + assert len(errors) == 5 + + field_arg_error = errors[0] + assert "Query.vehicle(color)" in field_arg_error + assert "CYAN" in field_arg_error + + type_directive_error = errors[1] + assert "Type 'Vehicle'" in type_directive_error + assert "MAROON" in type_directive_error + + field_directive_error = errors[2] + assert "Field 'Vehicle.field'" in field_directive_error + assert "MAGENTA" in field_directive_error + + input_field_error = errors[3] + assert "VehicleInput.color" in input_field_error + assert "YELLOW" in input_field_error + + directive_def_error = errors[4] + assert "@color(value)" in directive_def_error + assert "WHITE" in directive_def_error + + def test_no_default_is_not_error(self) -> None: + """Verify fields without defaults aren't flagged (both cases result in Undefined).""" + schema = build_schema( + """ + enum Color { RED BLUE GREEN } + + input VehicleInput { + color: Color + } + """ + ) + errors = check_enum_defaults(schema) + assert errors == []