Jq #41 (merged)

96 changes: 58 additions & 38 deletions README.md
@@ -96,13 +96,16 @@ After pushing this workflow to the main branch, you can run it manually from the
 ```text
 # cog-safe-push --help
 
-usage: cog-safe-push [-h] [--config CONFIG] [--help-config] [--test-model TEST_MODEL]
-                     [--no-push] [--test-hardware TEST_HARDWARE] [--no-compare-outputs]
+usage: cog-safe-push [-h] [--config CONFIG] [--help-config]
+                     [--test-model TEST_MODEL] [--no-push]
+                     [--test-hardware TEST_HARDWARE] [--no-compare-outputs]
                      [--predict-timeout PREDICT_TIMEOUT] [--fast-push]
-                     [--test-case TEST_CASES] [--fuzz-fixed-inputs FUZZ_FIXED_INPUTS]
+                     [--test-case TEST_CASES]
+                     [--fuzz-fixed-inputs FUZZ_FIXED_INPUTS]
                      [--fuzz-disabled-inputs FUZZ_DISABLED_INPUTS]
-                     [--fuzz-iterations FUZZ_ITERATIONS] [--fuzz-prompt FUZZ_PROMPT]
-                     [--parallel PARALLEL] [--ignore-schema-compatibility] [-v]
+                     [--fuzz-iterations FUZZ_ITERATIONS]
+                     [--fuzz-prompt FUZZ_PROMPT] [--parallel PARALLEL]
+                     [--ignore-schema-compatibility] [-v]
                      [--push-official-model]
                      [model]

@@ -113,53 +116,62 @@ positional arguments:
 
 options:
   -h, --help            show this help message and exit
-  --config CONFIG       Path to the YAML config file. If --config is not passed, ./cog-
-                        safe-push.yaml will be used, if it exists. Any arguments you pass
-                        in will override fields on the predict configuration stanza.
+  --config CONFIG       Path to the YAML config file. If --config is not
+                        passed, ./cog-safe-push.yaml will be used, if it
+                        exists. Any arguments you pass in will override fields
+                        on the predict configuration stanza.
   --help-config         Print a default cog-safe-push.yaml config to stdout.
   --test-model TEST_MODEL
-                        Replicate model to test on, in the format <username>/<model-name>.
-                        If omitted, <model>-test will be used. The test model is created
-                        automatically if it doesn't exist already
+                        Replicate model to test on, in the format
+                        <username>/<model-name>. If omitted, <model>-test will
+                        be used. The test model is created automatically if it
+                        doesn't exist already
   --no-push             Only test the model, don't push it to <model>
   --test-hardware TEST_HARDWARE
-                        Hardware to run the test model on. Only used when creating the
-                        test model, if it doesn't already exist.
-  --no-compare-outputs  Don't make predictions to compare that prediction outputs match
-                        the current version
+                        Hardware to run the test model on. Only used when
+                        creating the test model, if it doesn't already exist.
+  --no-compare-outputs  Don't make predictions to compare that prediction
+                        outputs match the current version
   --predict-timeout PREDICT_TIMEOUT
                         Timeout (in seconds) for predictions. Default: 300
   --fast-push           Use the --x-fast flag when doing cog push
   --test-case TEST_CASES
-                        Inputs and expected output that will be used for testing, you can
-                        provide multiple --test-case options for multiple test cases. The
-                        first test case will be used when comparing outputs to the current
-                        version. Each --test-case is semicolon-separated key-value pairs
-                        in the format '<key1>=<value1>;<key2=value2>[<output-checker>]'.
-                        <output-checker> can either be '==<exact-string-or-url>' or
-                        '~=<ai-prompt>'. If you use '==<exact-string-or-url>' then the
-                        output of the model must match exactly the string or url you
-                        specify. If you use '~=<ai-prompt>' then the AI will verify your
-                        output based on <ai-prompt>. If you omit <output-checker>, it will
-                        just verify that the prediction doesn't throw an error.
+                        Inputs and expected output that will be used for
+                        testing, you can provide multiple --test-case options
+                        for multiple test cases. The first test case will be
+                        used when comparing outputs to the current version.
+                        Each --test-case is semicolon-separated key-value
+                        pairs in the format
+                        '<key1>=<value1>;<key2=value2>[<output-checker>]'.
+                        <output-checker> can either be '==<exact-string-or-
+                        url>' or '~=<ai-prompt>'. If you use '==<exact-string-
+                        or-url>' then the output of the model must match
+                        exactly the string or url you specify. If you use
+                        '~=<ai-prompt>' then the AI will verify your output
+                        based on <ai-prompt>. If you omit <output-checker>, it
+                        will just verify that the prediction doesn't throw an
+                        error.
   --fuzz-fixed-inputs FUZZ_FIXED_INPUTS
-                        Inputs that should have fixed values during fuzzing. All other
-                        non-disabled input values will be generated by AI. If no test
-                        cases are specified, these will also be used when comparing
-                        outputs to the current version. Semicolon-separated key-value
-                        pairs in the format '<key1>=<value1>;<key2=value2>' (etc.)
+                        Inputs that should have fixed values during fuzzing.
+                        All other non-disabled input values will be generated
+                        by AI. If no test cases are specified, these will also
+                        be used when comparing outputs to the current version.
+                        Semicolon-separated key-value pairs in the format
+                        '<key1>=<value1>;<key2=value2>' (etc.)
   --fuzz-disabled-inputs FUZZ_DISABLED_INPUTS
-                        Don't pass values for these inputs during fuzzing. Semicolon-
-                        separated keys in the format '<key1>;<key2>' (etc.). If no test
-                        cases are specified, these will also be disabled when comparing
-                        outputs to the current version.
+                        Don't pass values for these inputs during fuzzing.
+                        Semicolon-separated keys in the format '<key1>;<key2>'
+                        (etc.). If no test cases are specified, these will
+                        also be disabled when comparing outputs to the current
+                        version.
   --fuzz-iterations FUZZ_ITERATIONS
                         Maximum number of iterations to run fuzzing.
   --fuzz-prompt FUZZ_PROMPT
                         Additional prompting for the fuzz input generation
   --parallel PARALLEL   Number of parallel prediction threads.
   --ignore-schema-compatibility
-                        Ignore schema compatibility checks when pushing the model
+                        Ignore schema compatibility checks when pushing the
+                        model
   -v, --verbose         Increase verbosity level (max 3)
   --push-official-model
                         Push to the official model defined in the config
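The `--test-case` help text packs inputs and an optional output checker into a single semicolon-separated argument. A hypothetical, simplified parser for one plausible reading of that format (an illustration only, not cog-safe-push's actual parsing code):

```python
def parse_test_case(arg: str):
    """Split a --test-case argument into an inputs dict and an optional
    output checker. '==<exact>' selects exact matching, '~=<ai-prompt>'
    selects AI-based matching; anything else is a key=value input pair.
    Simplified sketch; the real tool's parsing may differ."""
    inputs, checker = {}, None
    for part in arg.split(";"):
        if part.startswith("=="):
            checker = ("exact", part[2:])
        elif part.startswith("~="):
            checker = ("ai", part[2:])
        else:
            key, _, value = part.partition("=")
            inputs[key] = value
    return inputs, checker

inputs, checker = parse_test_case("prompt=a cat;steps=20;~=an image of a cat")
```

Omitting the trailing checker part leaves `checker` as `None`, which matches the documented behavior of only verifying that the prediction does not error.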
@@ -191,7 +203,11 @@ predict:
       <input3>: <value3>
     match_prompt: <match output using AI prompt, e.g. 'an image of a cat'>
   - inputs:
-      <input3>: <value3>
+      <input4>: <value4>
+    jq_query: <jq query to validate JSON output, e.g. ".status == \"success\" and
+      .confidence > 0.8">
+  - inputs:
+      <input5>: <value5>
     error_contains: <assert that these inputs throws an error, and that the error
       message contains a string>
   fuzz:
@@ -215,7 +231,11 @@ train:
       <input3>: <value3>
     match_prompt: <match output using AI prompt, e.g. 'an image of a cat'>
   - inputs:
-      <input3>: <value3>
+      <input4>: <value4>
+    jq_query: <jq query to validate JSON output, e.g. ".status == \"success\" and
+      .confidence > 0.8">
+  - inputs:
+      <input5>: <value5>
     error_contains: <assert that these inputs throws an error, and that the error
       message contains a string>
   fuzz:
11 changes: 9 additions & 2 deletions cog_safe_push/config.py
@@ -19,17 +19,24 @@ class TestCase(BaseModel):
     exact_string: str | None = None
     match_url: str | None = None
     match_prompt: str | None = None
+    jq_query: str | None = None
     error_contains: str | None = None
 
     @model_validator(mode="after")
     def check_mutually_exclusive(self):
         set_fields = sum(
             getattr(self, field) is not None
-            for field in ["exact_string", "match_url", "match_prompt", "error_contains"]
+            for field in [
+                "exact_string",
+                "match_url",
+                "match_prompt",
+                "jq_query",
+                "error_contains",
+            ]
         )
         if set_fields > 1:
             raise ArgumentError(
-                "At most one of 'exact_string', 'match_url', 'match_prompt', or 'error_contains' must be set"
+                "At most one of 'exact_string', 'match_url', 'match_prompt', 'jq_query', or 'error_contains' must be set"
             )
         return self
9 changes: 8 additions & 1 deletion cog_safe_push/main.py
@@ -25,6 +25,7 @@
     AIChecker,
     ErrorContainsChecker,
     ExactStringChecker,
+    JqQueryChecker,
     MatchURLChecker,
     NoChecker,
     OutputChecker,
@@ -552,6 +553,8 @@ def parse_config_test_case(
         checker = MatchURLChecker(url=config_test_case.match_url)
     elif config_test_case.match_prompt:
         checker = AIChecker(prompt=config_test_case.match_prompt)
+    elif config_test_case.jq_query:
+        checker = JqQueryChecker(query=config_test_case.jq_query)
     elif config_test_case.error_contains:
         checker = ErrorContainsChecker(string=config_test_case.error_contains)
     else:
@@ -581,7 +584,11 @@ def print_help_config():
             match_prompt="<match output using AI prompt, e.g. 'an image of a cat'>",
         ),
         ConfigTestCase(
-            inputs={"<input3>": "<value3>"},
+            inputs={"<input4>": "<value4>"},
+            jq_query='<jq query to validate JSON output, e.g. ".status == \\"success\\" and .confidence > 0.8">',
+        ),
+        ConfigTestCase(
+            inputs={"<input5>": "<value5>"},
             error_contains="<assert that these inputs throws an error, and that the error message contains a string>",
         ),
     ]
41 changes: 41 additions & 0 deletions cog_safe_push/output_checkers.py
@@ -1,3 +1,4 @@
+import json
 from dataclasses import dataclass
 from typing import Any, Protocol
 
@@ -81,6 +82,46 @@ async def __call__(self, output: Any | None, error: str | None) -> None:
             raise TestCaseFailedError(f"AI error: {str(e)}")
 
 
+@dataclass
+class JqQueryChecker(OutputChecker):
+    query: str
+
+    async def __call__(self, output: Any | None, error: str | None) -> None:
+        check_no_error(error)
+
+        try:
+            import jq
+        except ImportError:
+            raise TestCaseFailedError(
+                "jq library not installed. Install with: pip install jq"
+            )
+
+        json_data = output
+        if isinstance(output, str):
+            try:
+                json_data = json.loads(output)
+            except json.JSONDecodeError:
+                raise TestCaseFailedError(
+                    f"Output is a string but not valid JSON: {truncate(output, 200)}"
+                )
+
+        try:
+            compiled = jq.compile(self.query)
+            result = compiled.input_value(json_data).first()
+        except ValueError as e:
+            raise TestCaseFailedError(f"jq query error: {str(e)}")
+        except Exception as e:
+            raise TestCaseFailedError(f"jq execution failed: {str(e)}")
+
+        if not result:
+            json_str = json.dumps(json_data, indent=2)
+            raise TestCaseFailedError(
+                f"jq query '{self.query}' returned falsy value: {result}\n\nApplied to data:\n{truncate(json_str, 500)}"
+            )
+
+        log.info(f"jq query '{self.query}' matched successfully with result: {result}")
+
+
 @dataclass
 class ErrorContainsChecker(OutputChecker):
     string: str
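The new `JqQueryChecker` fails a test case whenever the jq program yields a falsy result (`false`, `null`, `0`, an empty string). A dependency-free sketch of that decision logic, with a plain Python callable standing in for the compiled jq program (the lambda below is an illustrative assumption, not the jq API):

```python
import json

def check_output(output, predicate):
    """Mimic JqQueryChecker's core flow: parse string output as JSON,
    apply a predicate (standing in for the compiled jq query), and
    raise on any falsy result. Sketch only; the real checker raises
    TestCaseFailedError and uses the jq library."""
    data = output
    if isinstance(output, str):
        data = json.loads(output)  # real checker wraps JSONDecodeError
    result = predicate(data)
    if not result:
        raise ValueError(f"query returned falsy value: {result}")
    return result

# Predicate equivalent to the jq query '.status == "success" and .confidence > 0.8'
ok = check_output(
    '{"status": "success", "confidence": 0.9}',
    lambda d: d["status"] == "success" and d["confidence"] > 0.8,
)
```

Note the falsy-result convention: a query like `.results | length` would fail on an empty list even though it executed without error, which is exactly what makes jq queries usable as assertions here.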
1 change: 1 addition & 0 deletions setup.py
@@ -14,6 +14,7 @@
         "pydantic>=2,<3",
         "PyYAML>=6,<7",
         "requests>=2,<3",
+        "jq>=1.6.0,<2",
     ],
     entry_points={
         "console_scripts": [
30 changes: 30 additions & 0 deletions test/test_main.py
@@ -287,6 +287,36 @@ def test_parse_config_with_train(tmp_path, monkeypatch):
     assert config.train.fuzz.iterations == 8
 
 
+def test_parse_config_with_jq_query(tmp_path, monkeypatch):
+    config_yaml = """
+model: user/model
+predict:
+  test_cases:
+    - inputs:
+        query: test
+      jq_query: '.status == "success" and .confidence > 0.8'
+    - inputs:
+        count: 5
+      jq_query: '.results | length == 5'
+"""
+    config_file = tmp_path / "cog-safe-push.yaml"
+    config_file.write_text(config_yaml)
+    monkeypatch.setattr("sys.argv", ["cog-safe-push", "--config", str(config_file)])
+
+    config, _, _ = parse_args_and_config()
+
+    assert config.model == "user/model"
+    assert config.predict is not None
+    assert len(config.predict.test_cases) == 2
+    assert config.predict.test_cases[0].inputs == {"query": "test"}
+    assert (
+        config.predict.test_cases[0].jq_query
+        == '.status == "success" and .confidence > 0.8'
+    )
+    assert config.predict.test_cases[1].inputs == {"count": 5}
+    assert config.predict.test_cases[1].jq_query == ".results | length == 5"
+
+
 def test_parse_args_with_default_config(tmp_path, monkeypatch):
     config_yaml = """
 model: user/default-model