diff --git a/README.md b/README.md
index 2e26be1..c20ff48 100644
--- a/README.md
+++ b/README.md
@@ -96,13 +96,16 @@ After pushing this workflow to the main branch, you can run it manually from the
 
 ```text
 # cog-safe-push --help
-usage: cog-safe-push [-h] [--config CONFIG] [--help-config] [--test-model TEST_MODEL]
-                     [--no-push] [--test-hardware TEST_HARDWARE] [--no-compare-outputs]
+usage: cog-safe-push [-h] [--config CONFIG] [--help-config]
+                     [--test-model TEST_MODEL] [--no-push]
+                     [--test-hardware TEST_HARDWARE] [--no-compare-outputs]
                      [--predict-timeout PREDICT_TIMEOUT] [--fast-push]
-                     [--test-case TEST_CASES] [--fuzz-fixed-inputs FUZZ_FIXED_INPUTS]
+                     [--test-case TEST_CASES]
+                     [--fuzz-fixed-inputs FUZZ_FIXED_INPUTS]
                      [--fuzz-disabled-inputs FUZZ_DISABLED_INPUTS]
-                     [--fuzz-iterations FUZZ_ITERATIONS] [--fuzz-prompt FUZZ_PROMPT]
-                     [--parallel PARALLEL] [--ignore-schema-compatibility] [-v]
+                     [--fuzz-iterations FUZZ_ITERATIONS]
+                     [--fuzz-prompt FUZZ_PROMPT] [--parallel PARALLEL]
+                     [--ignore-schema-compatibility] [-v]
                      [--push-official-model]
                      [model]
 
@@ -113,53 +116,62 @@ positional arguments:
 
 options:
   -h, --help            show this help message and exit
-  --config CONFIG       Path to the YAML config file. If --config is not passed, ./cog-
-                        safe-push.yaml will be used, if it exists. Any arguments you pass
-                        in will override fields on the predict configuration stanza.
+  --config CONFIG       Path to the YAML config file. If --config is not
+                        passed, ./cog-safe-push.yaml will be used, if it
+                        exists. Any arguments you pass in will override fields
+                        on the predict configuration stanza.
   --help-config         Print a default cog-safe-push.yaml config to stdout.
   --test-model TEST_MODEL
-                        Replicate model to test on, in the format /.
-                        If omitted, -test will be used. The test model is created
-                        automatically if it doesn't exist already
+                        Replicate model to test on, in the format
+                        /. If omitted, -test will
+                        be used. The test model is created automatically if it
+                        doesn't exist already
   --no-push             Only test the model, don't push it to 
   --test-hardware TEST_HARDWARE
-                        Hardware to run the test model on. Only used when creating the
-                        test model, if it doesn't already exist.
-  --no-compare-outputs  Don't make predictions to compare that prediction outputs match
-                        the current version
+                        Hardware to run the test model on. Only used when
+                        creating the test model, if it doesn't already exist.
+  --no-compare-outputs  Don't make predictions to compare that prediction
+                        outputs match the current version
   --predict-timeout PREDICT_TIMEOUT
                         Timeout (in seconds) for predictions. Default: 300
   --fast-push           Use the --x-fast flag when doing cog push
   --test-case TEST_CASES
-                        Inputs and expected output that will be used for testing, you can
-                        provide multiple --test-case options for multiple test cases. The
-                        first test case will be used when comparing outputs to the current
-                        version. Each --test-case is semicolon-separated key-value pairs
-                        in the format '=;[]'.
-                         can either be '==' or
-                        '~='. If you use '==' then the
-                        output of the model must match exactly the string or url you
-                        specify. If you use '~=' then the AI will verify your
-                        output based on . If you omit , it will
-                        just verify that the prediction doesn't throw an error.
+                        Inputs and expected output that will be used for
+                        testing, you can provide multiple --test-case options
+                        for multiple test cases. The first test case will be
+                        used when comparing outputs to the current version.
+                        Each --test-case is semicolon-separated key-value
+                        pairs in the format
+                        '=;[]'. 
+                        can either be '==' or
+                        '~='. If you use '==' then the
+                        output of the model must match
+                        exactly the string or url you specify. If you use
+                        '~=' then the AI will verify your output
+                        based on . If you omit , it
+                        will just verify that the prediction doesn't throw an
+                        error.
   --fuzz-fixed-inputs FUZZ_FIXED_INPUTS
-                        Inputs that should have fixed values during fuzzing. All other
-                        non-disabled input values will be generated by AI. If no test
-                        cases are specified, these will also be used when comparing
-                        outputs to the current version. Semicolon-separated key-value
-                        pairs in the format '=;' (etc.)
+                        Inputs that should have fixed values during fuzzing.
+                        All other non-disabled input values will be generated
+                        by AI. If no test cases are specified, these will also
+                        be used when comparing outputs to the current version.
+                        Semicolon-separated key-value pairs in the format
+                        '=;' (etc.)
   --fuzz-disabled-inputs FUZZ_DISABLED_INPUTS
-                        Don't pass values for these inputs during fuzzing. Semicolon-
-                        separated keys in the format ';' (etc.). If no test
-                        cases are specified, these will also be disabled when comparing
-                        outputs to the current version.
+                        Don't pass values for these inputs during fuzzing.
+                        Semicolon-separated keys in the format ';'
+                        (etc.). If no test cases are specified, these will
+                        also be disabled when comparing outputs to the current
+                        version.
   --fuzz-iterations FUZZ_ITERATIONS
                         Maximum number of iterations to run fuzzing.
   --fuzz-prompt FUZZ_PROMPT
                         Additional prompting for the fuzz input generation
   --parallel PARALLEL   Number of parallel prediction threads.
   --ignore-schema-compatibility
-                        Ignore schema compatibility checks when pushing the model
+                        Ignore schema compatibility checks when pushing the
+                        model
   -v, --verbose         Increase verbosity level (max 3)
   --push-official-model
                         Push to the official model defined in the config
@@ -191,7 +203,11 @@ predict:
       :
     match_prompt:
   - inputs:
-      :
+      :
+    jq_query:  0.8">
+  - inputs:
+      :
     error_contains:
   fuzz:
@@ -215,7 +231,11 @@ train:
       :
     match_prompt:
   - inputs:
-      :
+      :
+    jq_query:  0.8">
+  - inputs:
+      :
     error_contains:
   fuzz:
diff --git a/cog_safe_push/config.py b/cog_safe_push/config.py
index d614798..e252a42 100644
--- a/cog_safe_push/config.py
+++ b/cog_safe_push/config.py
@@ -19,17 +19,24 @@ class TestCase(BaseModel):
     exact_string: str | None = None
     match_url: str | None = None
     match_prompt: str | None = None
+    jq_query: str | None = None
     error_contains: str | None = None
 
     @model_validator(mode="after")
     def check_mutually_exclusive(self):
         set_fields = sum(
             getattr(self, field) is not None
-            for field in ["exact_string", "match_url", "match_prompt", "error_contains"]
+            for field in [
+                "exact_string",
+                "match_url",
+                "match_prompt",
+                "jq_query",
+                "error_contains",
+            ]
         )
         if set_fields > 1:
             raise ArgumentError(
-                "At most one of 'exact_string', 'match_url', 'match_prompt', or 'error_contains' must be set"
+                "At most one of 'exact_string', 'match_url', 'match_prompt', 'jq_query', or 'error_contains' must be set"
             )
         return self
 
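For orientation, the new `jq_query` field slots in alongside the existing checkers: a test case sets at most one of `exact_string`, `match_url`, `match_prompt`, `jq_query`, or `error_contains`, as enforced by the validator above. A minimal config sketch follows; the model name, input names, and values are illustrative, and the queries are taken from the test fixtures added later in this diff.

```yaml
# cog-safe-push.yaml (sketch; "query" and "count" are hypothetical model inputs)
model: user/model
predict:
  test_cases:
  - inputs:
      query: test
    # passes only if the jq expression evaluates to a truthy value
    jq_query: '.status == "success" and .confidence > 0.8'
  - inputs:
      count: 5
    jq_query: '.results | length == 5'
```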
diff --git a/cog_safe_push/main.py b/cog_safe_push/main.py
index 79d1ce7..6a31ff8 100644
--- a/cog_safe_push/main.py
+++ b/cog_safe_push/main.py
@@ -25,6 +25,7 @@
     AIChecker,
     ErrorContainsChecker,
     ExactStringChecker,
+    JqQueryChecker,
     MatchURLChecker,
     NoChecker,
     OutputChecker,
@@ -552,6 +553,8 @@ def parse_config_test_case(
         checker = MatchURLChecker(url=config_test_case.match_url)
     elif config_test_case.match_prompt:
         checker = AIChecker(prompt=config_test_case.match_prompt)
+    elif config_test_case.jq_query:
+        checker = JqQueryChecker(query=config_test_case.jq_query)
     elif config_test_case.error_contains:
         checker = ErrorContainsChecker(string=config_test_case.error_contains)
     else:
@@ -581,7 +584,11 @@ def print_help_config():
                 match_prompt="",
             ),
             ConfigTestCase(
-                inputs={"": ""},
+                inputs={"": ""},
+                jq_query=' 0.8">',
+            ),
+            ConfigTestCase(
+                inputs={"": ""},
                 error_contains="",
             ),
         ]
diff --git a/cog_safe_push/output_checkers.py b/cog_safe_push/output_checkers.py
index dd53ddb..9044cb4 100644
--- a/cog_safe_push/output_checkers.py
+++ b/cog_safe_push/output_checkers.py
@@ -1,3 +1,4 @@
+import json
 from dataclasses import dataclass
 from typing import Any, Protocol
 
@@ -81,6 +82,46 @@ async def __call__(self, output: Any | None, error: str | None) -> None:
             raise TestCaseFailedError(f"AI error: {str(e)}")
 
 
+@dataclass
+class JqQueryChecker(OutputChecker):
+    query: str
+
+    async def __call__(self, output: Any | None, error: str | None) -> None:
+        check_no_error(error)
+
+        try:
+            import jq
+        except ImportError:
+            raise TestCaseFailedError(
+                "jq library not installed. Install with: pip install jq"
+            )
+
+        json_data = output
+        if isinstance(output, str):
+            try:
+                json_data = json.loads(output)
+            except json.JSONDecodeError:
+                raise TestCaseFailedError(
+                    f"Output is a string but not valid JSON: {truncate(output, 200)}"
+                )
+
+        try:
+            compiled = jq.compile(self.query)
+            result = compiled.input_value(json_data).first()
+        except ValueError as e:
+            raise TestCaseFailedError(f"jq query error: {str(e)}")
+        except Exception as e:
+            raise TestCaseFailedError(f"jq execution failed: {str(e)}")
+
+        if not result:
+            json_str = json.dumps(json_data, indent=2)
+            raise TestCaseFailedError(
+                f"jq query '{self.query}' returned falsy value: {result}\n\nApplied to data:\n{truncate(json_str, 500)}"
+            )
+
+        log.info(f"jq query '{self.query}' matched successfully with result: {result}")
+
+
 @dataclass
 class ErrorContainsChecker(OutputChecker):
     string: str
diff --git a/setup.py b/setup.py
index 842ba78..a6b11cc 100644
--- a/setup.py
+++ b/setup.py
@@ -14,6 +14,7 @@
         "pydantic>=2,<3",
         "PyYAML>=6,<7",
         "requests>=2,<3",
+        "jq>=1.6.0,<2",
     ],
     entry_points={
         "console_scripts": [
diff --git a/test/test_main.py b/test/test_main.py
index f5ae734..13dec72 100644
--- a/test/test_main.py
+++ b/test/test_main.py
@@ -287,6 +287,36 @@ def test_parse_config_with_train(tmp_path, monkeypatch):
     assert config.train.fuzz.iterations == 8
 
 
+def test_parse_config_with_jq_query(tmp_path, monkeypatch):
+    config_yaml = """
+    model: user/model
+    predict:
+      test_cases:
+      - inputs:
+          query: test
+        jq_query: '.status == "success" and .confidence > 0.8'
+      - inputs:
+          count: 5
+        jq_query: '.results | length == 5'
+    """
+    config_file = tmp_path / "cog-safe-push.yaml"
+    config_file.write_text(config_yaml)
+    monkeypatch.setattr("sys.argv", ["cog-safe-push", "--config", str(config_file)])
+
+    config, _, _ = parse_args_and_config()
+
+    assert config.model == "user/model"
+    assert config.predict is not None
+    assert len(config.predict.test_cases) == 2
+    assert config.predict.test_cases[0].inputs == {"query": "test"}
+    assert (
+        config.predict.test_cases[0].jq_query
+        == '.status == "success" and .confidence > 0.8'
+    )
+    assert config.predict.test_cases[1].inputs == {"count": 5}
+    assert config.predict.test_cases[1].jq_query == ".results | length == 5"
+
+
 def test_parse_args_with_default_config(tmp_path, monkeypatch):
     config_yaml = """
     model: user/default-model
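The checker compiles the query with the `jq` Python bindings and keeps the first value the program emits (`jq.compile(...).input_value(...).first()`). Below is a standalone sketch of that same call pattern, handy for trying a query against a captured prediction output before wiring it into a test case. The sample data is made up, and having run `pip install jq` is assumed.

```python
import json

import jq  # PyPI package "jq", the same binding the checker imports lazily

# Hypothetical prediction output captured from a model run.
output = '{"status": "success", "confidence": 0.9, "results": [1, 2, 3]}'

# Mirror the checker: parse JSON strings, then evaluate the query and take
# the first value it produces.
data = json.loads(output) if isinstance(output, str) else output
program = jq.compile('.status == "success" and .confidence > 0.8')
result = program.input_value(data).first()

print(result)  # True, so the corresponding test case would pass
```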
diff --git a/test/test_output_checkers.py b/test/test_output_checkers.py
new file mode 100644
index 0000000..1184909
--- /dev/null
+++ b/test/test_output_checkers.py
@@ -0,0 +1,256 @@
+import pytest
+
+from cog_safe_push.exceptions import TestCaseFailedError
+from cog_safe_push.output_checkers import (
+    JqQueryChecker,
+)
+
+
+@pytest.mark.asyncio
+async def test_jq_query_checker_basic_equality():
+    checker = JqQueryChecker(query='.status == "success"')
+    output = {"status": "success", "data": "test"}
+    await checker(output, None)
+
+
+@pytest.mark.asyncio
+async def test_jq_query_checker_basic_equality_fails():
+    checker = JqQueryChecker(query='.status == "success"')
+    output = {"status": "failure", "data": "test"}
+    with pytest.raises(TestCaseFailedError, match="returned falsy value"):
+        await checker(output, None)
+
+
+@pytest.mark.asyncio
+async def test_jq_query_checker_numeric_comparison():
+    checker = JqQueryChecker(query=".confidence > 0.8")
+    output = {"confidence": 0.9, "result": "good"}
+    await checker(output, None)
+
+
+@pytest.mark.asyncio
+async def test_jq_query_checker_numeric_comparison_fails():
+    checker = JqQueryChecker(query=".confidence > 0.8")
+    output = {"confidence": 0.5, "result": "bad"}
+    with pytest.raises(TestCaseFailedError, match="returned falsy value"):
+        await checker(output, None)
+
+
+@pytest.mark.asyncio
+async def test_jq_query_checker_array_length():
+    checker = JqQueryChecker(query=".results | length == 5")
+    output = {"results": [1, 2, 3, 4, 5]}
+    await checker(output, None)
+
+
+@pytest.mark.asyncio
+async def test_jq_query_checker_array_length_fails():
+    checker = JqQueryChecker(query=".results | length == 5")
+    output = {"results": [1, 2, 3]}
+    with pytest.raises(TestCaseFailedError, match="returned falsy value"):
+        await checker(output, None)
+
+
+@pytest.mark.asyncio
+async def test_jq_query_checker_multiple_conditions():
+    checker = JqQueryChecker(
+        query='.status == "success" and .confidence > 0.8 and (.results | length) > 0'
+    )
+    output = {"status": "success", "confidence": 0.9, "results": [1, 2, 3]}
+    await checker(output, None)
+
+
+@pytest.mark.asyncio
+async def test_jq_query_checker_multiple_conditions_fails():
+    checker = JqQueryChecker(
+        query='.status == "success" and .confidence > 0.8 and (.results | length) > 0'
+    )
+    output = {"status": "success", "confidence": 0.9, "results": []}
+    with pytest.raises(TestCaseFailedError, match="returned falsy value"):
+        await checker(output, None)
+
+
+@pytest.mark.asyncio
+async def test_jq_query_checker_nested_fields():
+    checker = JqQueryChecker(query=".metadata.author and .metadata.version")
+    output = {"metadata": {"author": "test", "version": "1.0"}, "data": "content"}
+    await checker(output, None)
+
+
+@pytest.mark.asyncio
+async def test_jq_query_checker_nested_fields_fails():
+    checker = JqQueryChecker(query=".metadata.author and .metadata.version")
+    output = {"metadata": {"author": "test"}, "data": "content"}
+    with pytest.raises(TestCaseFailedError, match="returned falsy value"):
+        await checker(output, None)
+
+
+@pytest.mark.asyncio
+async def test_jq_query_checker_all_operator():
+    checker = JqQueryChecker(query=".predictions | all(.[]; .score > 0.5)")
+    output = {"predictions": [{"score": 0.6}, {"score": 0.7}, {"score": 0.8}]}
+    await checker(output, None)
+
+
+@pytest.mark.asyncio
+async def test_jq_query_checker_all_operator_fails():
+    checker = JqQueryChecker(query=".predictions | all(.[]; .score > 0.5)")
+    output = {"predictions": [{"score": 0.6}, {"score": 0.3}, {"score": 0.8}]}
+    with pytest.raises(TestCaseFailedError, match="returned falsy value"):
+        await checker(output, None)
+
+
+@pytest.mark.asyncio
+async def test_jq_query_checker_any_operator():
+    checker = JqQueryChecker(query='.results | any(.[]; .category == "animal")')
+    output = {
+        "results": [
+            {"category": "plant"},
+            {"category": "animal"},
+            {"category": "mineral"},
+        ]
+    }
+    await checker(output, None)
+
+
+@pytest.mark.asyncio
+async def test_jq_query_checker_any_operator_fails():
+    checker = JqQueryChecker(query='.results | any(.[]; .category == "animal")')
+    output = {"results": [{"category": "plant"}, {"category": "mineral"}]}
+    with pytest.raises(TestCaseFailedError, match="returned falsy value"):
+        await checker(output, None)
+
+
+@pytest.mark.asyncio
+async def test_jq_query_checker_type_check():
+    checker = JqQueryChecker(query='(.id | type) == "string"')
+    output = {"id": "abc123", "value": 42}
+    await checker(output, None)
+
+
+@pytest.mark.asyncio
+async def test_jq_query_checker_type_check_fails():
+    checker = JqQueryChecker(query='(.id | type) == "string"')
+    output = {"id": 123, "value": 42}
+    with pytest.raises(TestCaseFailedError, match="returned falsy value"):
+        await checker(output, None)
+
+
+@pytest.mark.asyncio
+async def test_jq_query_checker_with_json_string():
+    checker = JqQueryChecker(query='.status == "success"')
+    output = '{"status": "success", "data": "test"}'
+    await checker(output, None)
+
+
+@pytest.mark.asyncio
+async def test_jq_query_checker_with_invalid_json_string():
+    checker = JqQueryChecker(query='.status == "success"')
+    output = "not valid json"
+    with pytest.raises(TestCaseFailedError, match="not valid JSON"):
+        await checker(output, None)
+
+
+@pytest.mark.asyncio
+async def test_jq_query_checker_invalid_query():
+    checker = JqQueryChecker(query="invalid jq syntax ][")
+    output = {"status": "success"}
+    with pytest.raises(TestCaseFailedError, match="jq query error|jq execution failed"):
+        await checker(output, None)
+
+
+@pytest.mark.asyncio
+async def test_jq_query_checker_with_error():
+    checker = JqQueryChecker(query='.status == "success"')
+    with pytest.raises(TestCaseFailedError, match="unexpected error"):
+        await checker({"status": "success"}, "some error occurred")
+
+
+@pytest.mark.asyncio
+async def test_jq_query_checker_keys_validation():
+    checker = JqQueryChecker(
+        query='keys | length == 3 and contains(["status", "data", "timestamp"])'
+    )
+    output = {"status": "ok", "data": "test", "timestamp": 12345}
+    await checker(output, None)
+
+
+@pytest.mark.asyncio
+async def test_jq_query_checker_keys_validation_fails():
+    checker = JqQueryChecker(
+        query='keys | length == 3 and contains(["status", "data", "timestamp"])'
+    )
+    output = {"status": "ok", "data": "test"}
+    with pytest.raises(TestCaseFailedError, match="returned falsy value"):
+        await checker(output, None)
+
+
+@pytest.mark.asyncio
+async def test_jq_query_checker_array_indexing():
+    checker = JqQueryChecker(query='.[0].type == "start" and .[-1].type == "end"')
+    output = [{"type": "start"}, {"type": "middle"}, {"type": "end"}]
+    await checker(output, None)
+
+
+@pytest.mark.asyncio
+async def test_jq_query_checker_string_contains():
+    checker = JqQueryChecker(query='.message | contains("success")')
+    output = {"message": "Operation completed successfully"}
+    await checker(output, None)
+
+
+@pytest.mark.asyncio
+async def test_jq_query_checker_string_contains_fails():
+    checker = JqQueryChecker(query='.message | contains("success")')
+    output = {"message": "Operation failed"}
+    with pytest.raises(TestCaseFailedError, match="returned falsy value"):
+        await checker(output, None)
+
+
+@pytest.mark.asyncio
+async def test_jq_query_checker_regex_test():
+    checker = JqQueryChecker(query='.id | test("^[0-9a-f]{8}-[0-9a-f]{4}")')
+    output = {"id": "12345678-abcd-1234-5678-abcdef123456"}
+    await checker(output, None)
+
+
+@pytest.mark.asyncio
+async def test_jq_query_checker_regex_test_fails():
+    checker = JqQueryChecker(query='.id | test("^[0-9a-f]{8}-[0-9a-f]{4}")')
+    output = {"id": "not-a-uuid"}
+    with pytest.raises(TestCaseFailedError, match="returned falsy value"):
+        await checker(output, None)
+
+
+@pytest.mark.asyncio
+async def test_jq_query_checker_or_conditions():
+    checker = JqQueryChecker(
+        query='.status == "success" or .status == "completed" or .status == "done"'
+    )
+    output = {"status": "completed"}
+    await checker(output, None)
+
+
+@pytest.mark.asyncio
+async def test_jq_query_checker_or_conditions_fails():
+    checker = JqQueryChecker(
+        query='.status == "success" or .status == "completed" or .status == "done"'
+    )
+    output = {"status": "failed"}
+    with pytest.raises(TestCaseFailedError, match="returned falsy value"):
+        await checker(output, None)
+
+
+@pytest.mark.asyncio
+async def test_jq_query_checker_select_filter():
+    checker = JqQueryChecker(query="[.items[] | select(.price > 100)] | length > 0")
+    output = {"items": [{"price": 50}, {"price": 150}, {"price": 200}]}
+    await checker(output, None)
+
+
+@pytest.mark.asyncio
+async def test_jq_query_checker_select_filter_fails():
+    checker = JqQueryChecker(query="[.items[] | select(.price > 100)] | length > 0")
+    output = {"items": [{"price": 50}, {"price": 75}]}
+    with pytest.raises(TestCaseFailedError, match="returned falsy value"):
+        await checker(output, None)
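One behavioral detail worth noting when writing queries: the checker's `if not result:` branch fails on any falsy jq result, not just `false`, so a query that evaluates to `0`, `null`, or an empty string or array is treated as a failed test case. A hedged illustration in the same style as the tests above (not part of the patch; the query and output are invented):

```python
import pytest

from cog_safe_push.exceptions import TestCaseFailedError
from cog_safe_push.output_checkers import JqQueryChecker


@pytest.mark.asyncio
async def test_jq_query_checker_zero_is_treated_as_falsy():
    # ".count" evaluates to 0 here, so the truthiness check rejects it even
    # though the key exists; prefer explicit comparisons like ".count == 0".
    checker = JqQueryChecker(query=".count")
    output = {"count": 0}
    with pytest.raises(TestCaseFailedError, match="returned falsy value"):
        await checker(output, None)
```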