Skip to content
Open
63 changes: 45 additions & 18 deletions src/s2dm/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import rich_click as click
import yaml
from graphql import build_schema, parse
from graphql import GraphQLSchema, build_schema, parse
from rich.traceback import install

from s2dm import __version__, log
Expand All @@ -21,6 +21,7 @@
from s2dm.exporters.utils.graphql_type import is_builtin_scalar_type, is_introspection_type
from s2dm.exporters.utils.schema import load_schema_with_naming, search_schema
from s2dm.exporters.utils.schema_loader import (
check_correct_schema,
create_tempfile_to_composed_schema,
load_schema,
load_schema_as_str,
Expand Down Expand Up @@ -52,15 +53,6 @@ def process_value(self, ctx: click.Context, value: Any) -> list[Path] | None:
if not value:
return None
paths = set(value)

# Include the default QUDT units directory if it exists, otherwise warn and don't include it
if DEFAULT_QUDT_UNITS_DIR.exists():
paths.add(DEFAULT_QUDT_UNITS_DIR)
else:
log.warning(
f"No QUDT units directory found at {DEFAULT_QUDT_UNITS_DIR}. Please run 's2dm units sync' first."
)

return resolve_graphql_files(list(paths))


Expand Down Expand Up @@ -102,6 +94,7 @@ def selection_query_option(required: bool = False) -> Callable[[Callable[..., An
help="Output file",
)


optional_output_option = click.option(
"--output",
"-o",
Expand All @@ -111,6 +104,16 @@ def selection_query_option(required: bool = False) -> Callable[[Callable[..., An
)


# Reusable ``--directory`` / ``-d`` option: a directory-only path (files are
# rejected, value converted to ``Path``) that falls back to
# DEFAULT_QUDT_UNITS_DIR, with the default shown in ``--help``.
units_directory_option = click.option(
    "--directory",
    "-d",
    help="Directory for QUDT unit enums",
    type=click.Path(file_okay=False, path_type=Path),
    default=DEFAULT_QUDT_UNITS_DIR,
    show_default=True,
)


expanded_instances_option = click.option(
"--expanded-instances",
"-e",
Expand Down Expand Up @@ -140,6 +143,16 @@ def multiline_str_representer(obj: Any) -> Any:
return {k: multiline_str_representer(v) for k, v in result.items()}


def assert_correct_schema(schema: GraphQLSchema) -> None:
    """Stop the CLI when the schema does not pass validation.

    Every message returned by ``check_correct_schema`` is logged as an error,
    followed by a summary line, and the process then exits with status 1.
    When no messages are returned the function simply returns.

    Args:
        schema: The GraphQL schema to validate before exporting.
    """
    problems = check_correct_schema(schema)
    if not problems:
        return

    log.error("Schema validation failed:")
    for problem in problems:
        log.error(problem)
    log.error(f"Found {len(problems)} validation error(s). Please fix the schema before exporting.")
    sys.exit(1)


def validate_naming_config(config: dict[str, Any]) -> None:
VALID_CASES = {
"camelCase",
Expand Down Expand Up @@ -315,39 +328,48 @@ def units() -> None:
"QUDT version tag (e.g., 3.1.6). Defaults to the latest tag, falls back to 'main' when tags are unavailable."
),
)
@units_directory_option
@click.option(
"--dry-run",
is_flag=True,
help="Show what would be generated without actually writing files",
)
def units_sync(version: str | None, dry_run: bool) -> None:
"""Fetch QUDT quantity kinds and generate GraphQL enums under the output directory."""
def units_sync(version: str | None, directory: Path, dry_run: bool) -> None:
"""Fetch QUDT quantity kinds and generate GraphQL enums under the specified directory.

Args:
version: QUDT version tag. Defaults to the latest tag.
directory: Output directory for generated QUDT unit enums (default: ~/.s2dm/units/qudt)
dry_run: Show what would be generated without actually writing files
"""

version_to_use = version or get_latest_qudt_version()

try:
written = sync_qudt_units(DEFAULT_QUDT_UNITS_DIR, version_to_use, dry_run=dry_run)
written = sync_qudt_units(directory, version_to_use, dry_run=dry_run)
except UnitEnumError as e:
log.error(f"Units sync failed: {e}")
sys.exit(1)

if dry_run:
log.info(f"Would generate {len(written)} enum files under {DEFAULT_QUDT_UNITS_DIR}")
log.info(f"Would generate {len(written)} enum files under {directory}")
log.print(f"Version: {version_to_use}")
log.hint("Use without --dry-run to actually write files")
else:
log.success(f"Generated {len(written)} enum files under {DEFAULT_QUDT_UNITS_DIR}")
log.success(f"Generated {len(written)} enum files under {directory}")
log.print(f"Version: {version_to_use}")


@units.command(name="check-version")
def units_check_version() -> None:
@units_directory_option
def units_check_version(directory: Path) -> None:
"""Compare local synced QUDT version with the latest remote version and print a message.

Args:
qudt_units_dir: Directory containing generated QUDT unit enums (default: ~/.s2dm/units/qudt)
directory: Directory containing generated QUDT unit enums (default: ~/.s2dm/units/qudt)
"""

meta_path = DEFAULT_QUDT_UNITS_DIR / UNITS_META_FILENAME
meta_path = directory / UNITS_META_FILENAME
if not meta_path.exists():
log.warning("No metadata.json found. Run 's2dm units sync' first.")
sys.exit(1)
Expand Down Expand Up @@ -398,6 +420,7 @@ def compose(schemas: list[Path], root_type: str | None, selection_query: Path |
composed_schema_str = load_schema_as_str(schemas, add_references=True)

graphql_schema = build_schema(composed_schema_str)
assert_correct_schema(graphql_schema)

if selection_query:
query_document = parse(selection_query.read_text())
Expand Down Expand Up @@ -487,6 +510,7 @@ def shacl(
naming_config = ctx.obj.get("naming_config")

graphql_schema = load_schema_with_naming(schemas, naming_config)
assert_correct_schema(graphql_schema)

if selection_query:
query_document = parse(selection_query.read_text())
Expand Down Expand Up @@ -515,6 +539,7 @@ def vspec(ctx: click.Context, schemas: list[Path], selection_query: Path | None,
"""Generate VSPEC from a given GraphQL schema."""
naming_config = ctx.obj.get("naming_config")
graphql_schema = load_schema_with_naming(schemas, naming_config)
assert_correct_schema(graphql_schema)

if selection_query:
query_document = parse(selection_query.read_text())
Expand Down Expand Up @@ -553,6 +578,7 @@ def jsonschema(
"""Generate JSON Schema from a given GraphQL schema."""
naming_config = ctx.obj.get("naming_config")
graphql_schema = load_schema_with_naming(schemas, naming_config)
assert_correct_schema(graphql_schema)

if selection_query:
query_document = parse(selection_query.read_text())
Expand Down Expand Up @@ -597,6 +623,7 @@ def protobuf(
"""Generate Protocol Buffers (.proto) file from GraphQL schema."""
naming_config = ctx.obj.get("naming_config")
graphql_schema = load_schema_with_naming(schemas, naming_config)
assert_correct_schema(graphql_schema)

query_document = parse(selection_query.read_text())
graphql_schema = prune_schema_using_query_selection(graphql_schema, query_document)
Expand Down
153 changes: 150 additions & 3 deletions src/s2dm/exporters/utils/schema_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from ariadne import load_schema_from_path
from graphql import (
DocumentNode,
GraphQLEnumType,
GraphQLField,
GraphQLInputObjectType,
GraphQLInterfaceType,
Expand All @@ -17,6 +18,7 @@
GraphQLString,
GraphQLType,
GraphQLUnionType,
Undefined,
build_schema,
get_named_type,
is_input_object_type,
Expand All @@ -26,9 +28,10 @@
is_object_type,
is_union_type,
print_schema,
validate_schema,
)
from graphql import validate as graphql_validate
from graphql.language.ast import SelectionSetNode
from graphql.language.ast import DirectiveNode, EnumValueNode, SelectionSetNode

from s2dm import log
from s2dm.exporters.utils.directive import (
Expand Down Expand Up @@ -268,6 +271,150 @@ def create_tempfile_to_composed_schema(graphql_schema_paths: list[Path]) -> Path
return Path(temp_path)


def _check_directive_usage_on_node(schema: GraphQLSchema, directive_node: DirectiveNode, context: str) -> list[str]:
    """Check the enum literals passed to one directive usage.

    Args:
        schema: Schema used to look up the directive definition by name.
        directive_node: AST node of the directive usage to inspect.
        context: Human-readable owner of the usage (e.g. ``Type 'Foo'``); used as the message prefix.

    Returns:
        One message per argument whose enum literal is not a value of the
        argument's enum type. The list is empty when the directive is not
        defined in the schema or when nothing is wrong.
    """
    directive_name = directive_node.name.value
    directive_def = schema.get_directive(directive_name)
    if not directive_def:
        return []

    errors: list[str] = []
    # ``arguments`` may be empty; ``or ()`` also tolerates it being None.
    for arg_node in directive_node.arguments or ():
        arg_name = arg_node.name.value

        # An argument the directive does not declare is not an enum problem;
        # skip it instead of failing with KeyError on the lookup.
        arg_def = directive_def.args.get(arg_name)
        if arg_def is None:
            continue

        named_type = get_named_type(arg_def.type)
        # Only enum-typed arguments that were given an enum literal are checked here.
        if not isinstance(named_type, GraphQLEnumType) or not isinstance(arg_node.value, EnumValueNode):
            continue

        enum_value = arg_node.value.value
        if enum_value not in named_type.values:
            errors.append(
                f"{context} uses directive '@{directive_name}({arg_name})' "
                f"with invalid enum value '{enum_value}'. Valid values are: {list(named_type.values.keys())}"
            )

    return errors


def _rejected_default_literal(ast_node: object, coerced_default: object) -> str | None:
    """Return the printed default literal if one was declared but not accepted, else ``None``.

    A default counts as rejected when the AST node carries a ``default_value``
    while the built schema element's default is ``Undefined``.
    """
    default_node = getattr(ast_node, "default_value", None)
    if default_node is None or coerced_default is not Undefined:
        return None

    literal = getattr(default_node, "value", Undefined)
    if literal is not Undefined:
        return str(literal)

    # List, object and null literals have no ``value`` attribute; print the node itself.
    from graphql.language import Node, print_ast

    return print_ast(default_node) if isinstance(default_node, Node) else str(default_node)


def check_enum_defaults(schema: GraphQLSchema) -> list[str]:
    """Check enum default values and enum literals used in directives.

    Three kinds of defaults are inspected: input object fields, arguments of
    object/interface fields, and arguments of directive definitions. A default
    is reported when its named type is an enum, the AST declares a default and
    the built element's ``default_value`` is ``Undefined``. Directive usages on
    types and on fields are checked through ``_check_directive_usage_on_node``.

    Args:
        schema: The GraphQL schema to validate

    Returns:
        List of error messages for invalid enum defaults and directive enum values
    """
    errors: list[str] = []

    for type_name, type_obj in schema.type_map.items():
        # Validate directive usage on types
        if type_obj.ast_node and type_obj.ast_node.directives:
            for directive_node in type_obj.ast_node.directives:
                errors.extend(_check_directive_usage_on_node(schema, directive_node, f"Type '{type_name}'"))

        # Validate input object field defaults
        if isinstance(type_obj, GraphQLInputObjectType):
            for field_name, field in type_obj.fields.items():
                named_type = get_named_type(field.type)
                if not isinstance(named_type, GraphQLEnumType):
                    continue

                invalid_value = _rejected_default_literal(field.ast_node, field.default_value)
                if invalid_value is None:
                    continue

                errors.append(
                    f"Input type '{type_name}.{field_name}' has invalid enum default value '{invalid_value}'. "
                    f"Valid values are: {list(named_type.values.keys())}"
                )

        # Validate field argument defaults and directive usage on fields
        if isinstance(type_obj, (GraphQLObjectType, GraphQLInterfaceType, GraphQLInputObjectType)):
            # Only object and interface fields take arguments; decide once per type.
            has_field_args = isinstance(type_obj, (GraphQLObjectType, GraphQLInterfaceType))

            for field_name, field in type_obj.fields.items():
                # Validate directive usage on fields
                if field.ast_node and field.ast_node.directives:
                    for directive_node in field.ast_node.directives:
                        errors.extend(
                            _check_directive_usage_on_node(schema, directive_node, f"Field '{type_name}.{field_name}'")
                        )

                if not has_field_args:
                    continue

                # Validate field argument defaults
                for arg_name, arg in field.args.items():
                    named_type = get_named_type(arg.type)
                    if not isinstance(named_type, GraphQLEnumType):
                        continue

                    invalid_value = _rejected_default_literal(arg.ast_node, arg.default_value)
                    if invalid_value is None:
                        continue

                    errors.append(
                        f"Field argument '{type_name}.{field_name}({arg_name})' "
                        f"has invalid enum default value '{invalid_value}'. "
                        f"Valid values are: {list(named_type.values.keys())}"
                    )

    # Validate directive definition defaults
    for directive in schema.directives:
        for arg_name, arg in directive.args.items():
            named_type = get_named_type(arg.type)
            if not isinstance(named_type, GraphQLEnumType):
                continue

            invalid_value = _rejected_default_literal(arg.ast_node, arg.default_value)
            if invalid_value is None:
                continue

            errors.append(
                f"Directive definition '@{directive.name}({arg_name})' "
                f"has invalid enum default value '{invalid_value}'. Valid values are: {list(named_type.values.keys())}"
            )

    return errors


def check_correct_schema(schema: GraphQLSchema) -> list[str]:
    """Collect every validation problem found in the schema.

    Runs ``validate_schema`` first and ``check_enum_defaults`` second. This
    function only reports problems by returning them; deciding what to do
    with the messages is left to the caller.

    Args:
        schema: The GraphQL schema to validate

    Returns:
        list[str]: Specification errors followed by enum errors, each
        formatted as a bullet line; empty when nothing was found
    """
    spec_errors = validate_schema(schema)
    enum_errors = check_enum_defaults(schema)

    return [
        *(f" - {violation.message}" for violation in spec_errors),
        *(f" - {problem}" for problem in enum_errors),
    ]


def ensure_query(schema: GraphQLSchema) -> GraphQLSchema:
"""
Ensures that the provided GraphQL schema has a Query type. If the schema does not have a Query type,
Expand Down Expand Up @@ -377,7 +524,7 @@ def visit_field_type(field_type: GraphQLType) -> None:
return referenced


def validate_schema(schema: GraphQLSchema, document: DocumentNode) -> GraphQLSchema | None:
def _validate_schema(schema: GraphQLSchema, document: DocumentNode) -> GraphQLSchema | None:
log.debug("Validating schema against the provided document")

errors = graphql_validate(schema, document)
Expand Down Expand Up @@ -406,7 +553,7 @@ def prune_schema_using_query_selection(schema: GraphQLSchema, document: Document
if not schema.query_type:
raise ValueError("Schema has no query type defined")

if validate_schema(schema, document) is None:
if _validate_schema(schema, document) is None:
raise ValueError("Schema validation failed")

fields_to_keep: dict[str, set[str]] = {}
Expand Down
13 changes: 0 additions & 13 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,19 +49,6 @@ class TestSchemaData:
BREAKING_SCHEMA = TESTS_DATA_DIR / "breaking.graphql"


@pytest.fixture(autouse=True)
def patch_default_units_dir(monkeypatch: pytest.MonkeyPatch) -> None:
    """Redirect the CLI's default QUDT units directory to the test data.

    Applied automatically to every test: ``s2dm.cli.DEFAULT_QUDT_UNITS_DIR``
    is replaced with ``TestSchemaData.UNITS_SCHEMA_PATH`` for the duration of
    the test, so the unit enum definitions come from the test data directory.
    """
    test_units_dir = TestSchemaData.UNITS_SCHEMA_PATH
    monkeypatch.setattr("s2dm.cli.DEFAULT_QUDT_UNITS_DIR", test_units_dir)


def parsed_console_output() -> str:
    """Return an empty string; placeholder for console output parsing."""
    placeholder = ""
    return placeholder
Expand Down
Loading