diff --git a/flexeval/core/utils/json_util.py b/flexeval/core/utils/json_util.py index e47e6586..451312fc 100644 --- a/flexeval/core/utils/json_util.py +++ b/flexeval/core/utils/json_util.py @@ -13,6 +13,8 @@ def _truncate_base64(o: Any) -> Any: # noqa: ANN401 return type(o)(_truncate_base64(item) for item in o) if isinstance(o, dict): return {k: _truncate_base64(v) for k, v in o.items()} + if isinstance(o, (int, float, bool, type(None))): + return o s = str(o) diff --git a/tests/core/utils/test_json_util.py b/tests/core/utils/test_json_util.py index 42f460c9..32af4574 100644 --- a/tests/core/utils/test_json_util.py +++ b/tests/core/utils/test_json_util.py @@ -1,6 +1,7 @@ import dataclasses import json -from ast import literal_eval + +import pytest from flexeval.core.utils.json_util import Base64TruncatingJSONEncoder @@ -8,7 +9,7 @@ @dataclasses.dataclass class TestDataClass: field1: str - field2: int + field2: int | float | bool | None class TestData: @@ -27,25 +28,32 @@ def test_truncate_base64() -> None: def _json_dumps(x): # noqa: ANN001, ANN202 return json.dumps(x, cls=Base64TruncatingJSONEncoder) - assert literal_eval(_json_dumps(TestDataClass("example", 123))) == {"field1": "example", "field2": "123"} + assert json.loads(_json_dumps(TestDataClass("example", 123))) == {"field1": "example", "field2": 123} + + assert json.loads(_json_dumps(TestDataClass("example", 1.23))) == pytest.approx( + {"field1": "example", "field2": 1.23} + ) + + assert json.loads(_json_dumps(TestDataClass("example", True))) == {"field1": "example", "field2": True} + assert json.loads(_json_dumps(TestDataClass("example", None))) == {"field1": "example", "field2": None} - assert literal_eval(_json_dumps(TestData())) == "TestData" + assert json.loads(_json_dumps(TestData())) == "TestData" - assert literal_eval(_json_dumps({"key": base64_string})) == { + assert json.loads(_json_dumps({"key": base64_string})) == { "key": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAgA/... [truncated, 169 chars total]" } - assert literal_eval(_json_dumps([base64_string, "normal string"])) == [ + assert json.loads(_json_dumps([base64_string, "normal string"])) == [ "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAgA/... [truncated, 169 chars total]", "normal string", ] - assert literal_eval(_json_dumps(TestDataClass(base64_string, 456))) == { + assert json.loads(_json_dumps(TestDataClass(base64_string, 456))) == { "field1": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAgA/... [truncated, 169 chars total]", - "field2": "456", + "field2": 456, } - image_url = literal_eval( + image_url = json.loads( _json_dumps({"messages": [{"content": {"type": "image_url", "image_url": {"url": base64_string}}}]}) ) assert (