diff --git a/cog_safe_push/schema.py b/cog_safe_push/schema.py index 703c9c5..74909dd 100644 --- a/cog_safe_push/schema.py +++ b/cog_safe_push/schema.py @@ -72,10 +72,9 @@ def check_backwards_compatible( # We allow defaults to be changed elif "allOf" in spec: - choice_schema = model_schemas[spec["allOf"][0]["$ref"].split("/")[-1]] - test_choice_schema = test_model_schemas[ - spec["allOf"][0]["$ref"].split("/")[-1] - ] + choice_schema = spec["allOf"][0] + test_choice_schema = test_spec["allOf"][0] + choice_type = choice_schema["type"] test_choice_type = test_choice_schema["type"] if test_choice_type != choice_type: diff --git a/test/test_schema.py b/test/test_schema.py index f30eb8b..9c76dc8 100644 --- a/test/test_schema.py +++ b/test/test_schema.py @@ -134,18 +134,17 @@ def test_decreased_maximum(): def test_changed_choice_type(): + """Test enum type change with inline allOf.""" old = { "Input": make_input_schema( - {"choice": {"allOf": [{"$ref": "#/components/schemas/choice"}]}} + {"choice": {"allOf": [{"type": "string", "enum": ["A", "B", "C"]}]}} ), - "choice": {"type": "string", "enum": ["A", "B", "C"]}, "Output": {"type": "string"}, } new = { "Input": make_input_schema( - {"choice": {"allOf": [{"$ref": "#/components/schemas/choice"}]}} + {"choice": {"allOf": [{"type": "integer", "enum": [1, 2, 3]}]}} ), - "choice": {"type": "integer", "enum": [1, 2, 3]}, "Output": {"type": "string"}, } with pytest.raises( @@ -156,36 +155,34 @@ def test_changed_choice_type(): def test_added_choice(): + """Test adding enum choices is backwards compatible.""" old = { "Input": make_input_schema( - {"choice": {"allOf": [{"$ref": "#/components/schemas/choice"}]}} + {"choice": {"allOf": [{"type": "string", "enum": ["A", "B", "C"]}]}} ), - "choice": {"type": "string", "enum": ["A", "B", "C"]}, "Output": {"type": "string"}, } new = { "Input": make_input_schema( - {"choice": {"allOf": [{"$ref": "#/components/schemas/choice"}]}} + {"choice": {"allOf": [{"type": "string", "enum": ["A", "B", "C", "D"]}]}} ), - "choice": {"type": "string", "enum": ["A", "B", "C", "D"]}, "Output": {"type": "string"}, } check_backwards_compatible(new, old, train=False) # Should not raise def test_removed_choice(): + """Test removing enum choices breaks compatibility.""" old = { "Input": make_input_schema( - {"choice": {"allOf": [{"$ref": "#/components/schemas/choice"}]}} + {"choice": {"allOf": [{"type": "string", "enum": ["A", "B", "C"]}]}} ), - "choice": {"type": "string", "enum": ["A", "B", "C"]}, "Output": {"type": "string"}, } new = { "Input": make_input_schema( - {"choice": {"allOf": [{"$ref": "#/components/schemas/choice"}]}} + {"choice": {"allOf": [{"type": "string", "enum": ["A", "B"]}]}} ), - "choice": {"type": "string", "enum": ["A", "B"]}, "Output": {"type": "string"}, } with pytest.raises( @@ -194,6 +191,54 @@ def test_removed_choice(): check_backwards_compatible(new, old, train=False) +def test_realistic_enum_with_metadata(): + """Test enum with full realistic metadata (like aspect_ratio).""" + old = { + "Input": make_input_schema( + { + "aspect_ratio": { + "allOf": [ + { + "enum": ["1:1", "2:3", "3:2", "4:3", "16:9"], + "type": "string", + "title": "aspect_ratio", + "description": "An enumeration.", + } + ], + "default": "1:1", + "x-order": 2, + "description": "Aspect ratio for expansion.", + } + } + ), + "Output": {"type": "string", "title": "Output", "format": "uri"}, + } + new = { + "Input": make_input_schema( + { + "aspect_ratio": { + "allOf": [ + { + "enum": ["1:1", "2:3", "3:2", "4:3"], # removed 16:9 + "type": "string", + "title": "aspect_ratio", + "description": "An enumeration.", + } + ], + "default": "1:1", + "x-order": 2, + "description": "Aspect ratio for expansion.", + } + } + ), + "Output": {"type": "string", "title": "Output", "format": "uri"}, + } + with pytest.raises( + IncompatibleSchemaError, match="Input aspect_ratio is missing choices: '16:9'" + ): + check_backwards_compatible(new, old, train=False) + + def test_new_required_input(): old = { "Input": make_input_schema({"text": {"type": "string"}}), @@ -230,10 +275,9 @@ def test_multiple_incompatibilities(): { "text": {"type": "string"}, "number": {"type": "integer", "minimum": 0}, - "choice": {"allOf": [{"$ref": "#/components/schemas/choice"}]}, + "choice": {"allOf": [{"type": "string", "enum": ["A", "B", "C"]}]}, } ), - "choice": {"type": "string", "enum": ["A", "B", "C"]}, "Output": {"type": "string"}, } new = { @@ -241,11 +285,10 @@ def test_multiple_incompatibilities(): { "text": {"type": "integer"}, "number": {"type": "integer", "minimum": 1}, - "choice": {"allOf": [{"$ref": "#/components/schemas/choice"}]}, + "choice": {"allOf": [{"type": "string", "enum": ["A", "B"]}]}, "new_required": {"type": "string"}, } ), - "choice": {"type": "string", "enum": ["A", "B"]}, "Output": {"type": "integer"}, } with pytest.raises(IncompatibleSchemaError) as exc_info: