Skip to content

Commit c8efd3e

Browse files
authored
return pydantic metrics instead of dict
1 parent fd67211 commit c8efd3e

File tree

13 files changed

+458
-148
lines changed

13 files changed

+458
-148
lines changed

README.md

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,21 +27,21 @@ import json
2727
from mostlyai import qa
2828

2929
# fetch original + synthetic data
30-
base_url = 'https://github.com/mostly-ai/mostlyai-qa/raw/refs/heads/main/examples/quick-start'
31-
syn = pd.read_csv(f'{base_url}/census2k-syn_mostly.csv.gz')
30+
base_url = "https://github.com/mostly-ai/mostlyai-qa/raw/refs/heads/main/examples/quick-start"
31+
syn = pd.read_csv(f"{base_url}/census2k-syn_mostly.csv.gz")
3232
# syn = pd.read_csv(f"{base_url}/census2k-syn_flip30.csv.gz")  # a 30% perturbation of trn
33-
trn = pd.read_csv(f'{base_url}/census2k-trn.csv.gz')
34-
hol = pd.read_csv(f'{base_url}/census2k-hol.csv.gz')
33+
trn = pd.read_csv(f"{base_url}/census2k-trn.csv.gz")
34+
hol = pd.read_csv(f"{base_url}/census2k-hol.csv.gz")
3535

3636
# runs for ~30secs
3737
report_path, metrics = qa.report(
38-
syn_tgt_data = syn,
39-
trn_tgt_data = trn,
40-
hol_tgt_data = hol,
38+
syn_tgt_data=syn,
39+
trn_tgt_data=trn,
40+
hol_tgt_data=hol,
4141
)
4242

4343
# pretty print metrics
44-
print(json.dumps(metrics, indent=4))
44+
print(metrics.model_dump_json(indent=4))
4545

4646
# open up HTML report in new browser window
4747
webbrowser.open(f"file://{report_path.absolute()}")
@@ -104,7 +104,7 @@ def report(
104104
max_sample_size_embeddings: int | None = None,
105105
statistics_path: str | Path | None = None,
106106
on_progress: ProgressCallback | None = None,
107-
) -> tuple[Path, dict | None]:
107+
) -> tuple[Path, Metrics | None]:
108108
"""
109109
Generate HTML report and metrics for comparing synthetic and original data samples.
110110
@@ -128,7 +128,7 @@ def report(
128128
on_progress: A custom progress callback
129129
Returns:
130130
1. Path to the HTML report
131-
2. Dictionary of calculated metrics:
131+
2. Pydantic `Metrics` object with the calculated metrics:
132132
- `accuracy`: Accuracy is defined as (100% - Total Variation Distance), for each distribution, and then averaged across all distributions.
133133
- `overall`: Overall accuracy of synthetic data, i.e. average across univariate, bivariate and coherence.
134134
- `univariate`: Average accuracy of discretized univariate distributions.

examples/baseball-players.ipynb

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,15 @@
1313
"execution_count": null,
1414
"id": "082b3689-2807-420b-8bb7-a9a40cedf3c3",
1515
"metadata": {},
16-
"outputs": [],
1716
"source": [
1817
"import pandas as pd\n",
1918
"import webbrowser\n",
2019
"from pathlib import Path\n",
2120
"from mostlyai.qa import report\n",
2221
"\n",
2322
"wdir = Path(\"baseball-players\")"
24-
]
23+
],
24+
"outputs": []
2525
},
2626
{
2727
"cell_type": "markdown",
@@ -36,7 +36,6 @@
3636
"execution_count": null,
3737
"id": "95f6914a-e6cf-4e1f-8183-7f482228317f",
3838
"metadata": {},
39-
"outputs": [],
4039
"source": [
4140
"report_path, metrics = report(\n",
4241
" syn_tgt_data=pd.read_parquet(wdir / \"generated-data\"),\n",
@@ -46,17 +45,18 @@
4645
" report_path=\"baseball-players.html\",\n",
4746
")\n",
4847
"metrics"
49-
]
48+
],
49+
"outputs": []
5050
},
5151
{
5252
"cell_type": "code",
5353
"execution_count": null,
5454
"id": "88046c1f-0343-4e15-a1d0-ab2191417492",
5555
"metadata": {},
56-
"outputs": [],
5756
"source": [
5857
"webbrowser.open(f\"file://{report_path.absolute()}\")"
59-
]
58+
],
59+
"outputs": []
6060
},
6161
{
6262
"cell_type": "markdown",
@@ -71,7 +71,6 @@
7171
"execution_count": null,
7272
"id": "b45c06d4-1a7e-4bf2-aa83-f3f6e411caa9",
7373
"metadata": {},
74-
"outputs": [],
7574
"source": [
7675
"report_path, metrics = report(\n",
7776
" syn_tgt_data=pd.read_parquet(wdir / \"generated-data\"),\n",
@@ -86,17 +85,18 @@
8685
" report_path=\"baseball-players-with-context.html\",\n",
8786
")\n",
8887
"metrics"
89-
]
88+
],
89+
"outputs": []
9090
},
9191
{
9292
"cell_type": "code",
9393
"execution_count": null,
9494
"id": "00aa1fe7-6a9a-40c6-94d0-8b03632f1fa8",
9595
"metadata": {},
96-
"outputs": [],
9796
"source": [
9897
"webbrowser.open(f\"file://{report_path.absolute()}\")"
99-
]
98+
],
99+
"outputs": []
100100
}
101101
],
102102
"metadata": {

examples/baseball-seasons.ipynb

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,15 @@
1313
"execution_count": null,
1414
"id": "082b3689-2807-420b-8bb7-a9a40cedf3c3",
1515
"metadata": {},
16-
"outputs": [],
1716
"source": [
1817
"import pandas as pd\n",
1918
"import webbrowser\n",
2019
"from pathlib import Path\n",
2120
"from mostlyai.qa import report\n",
2221
"\n",
2322
"wdir = Path(\"baseball-seasons\")"
24-
]
23+
],
24+
"outputs": []
2525
},
2626
{
2727
"cell_type": "markdown",
@@ -36,7 +36,6 @@
3636
"execution_count": null,
3737
"id": "95f6914a-e6cf-4e1f-8183-7f482228317f",
3838
"metadata": {},
39-
"outputs": [],
4039
"source": [
4140
"report_path, metrics = report(\n",
4241
" syn_tgt_data=pd.read_parquet(wdir / \"generated-data\"),\n",
@@ -47,17 +46,18 @@
4746
" report_path=\"baseball-seasons.html\",\n",
4847
")\n",
4948
"metrics"
50-
]
49+
],
50+
"outputs": []
5151
},
5252
{
5353
"cell_type": "code",
5454
"execution_count": null,
5555
"id": "88046c1f-0343-4e15-a1d0-ab2191417492",
5656
"metadata": {},
57-
"outputs": [],
5857
"source": [
5958
"webbrowser.open(f\"file://{report_path.absolute()}\")"
60-
]
59+
],
60+
"outputs": []
6161
},
6262
{
6363
"cell_type": "markdown",
@@ -72,7 +72,6 @@
7272
"execution_count": null,
7373
"id": "b45c06d4-1a7e-4bf2-aa83-f3f6e411caa9",
7474
"metadata": {},
75-
"outputs": [],
7675
"source": [
7776
"report_path, metrics = report(\n",
7877
" syn_tgt_data=pd.read_parquet(wdir / \"generated-data\"),\n",
@@ -87,17 +86,18 @@
8786
" report_path=\"baseball-seasons-with-context.html\",\n",
8887
")\n",
8988
"metrics"
90-
]
89+
],
90+
"outputs": []
9191
},
9292
{
9393
"cell_type": "code",
9494
"execution_count": null,
9595
"id": "00aa1fe7-6a9a-40c6-94d0-8b03632f1fa8",
9696
"metadata": {},
97-
"outputs": [],
9897
"source": [
9998
"webbrowser.open(f\"file://{report_path.absolute()}\")"
100-
]
99+
],
100+
"outputs": []
101101
}
102102
],
103103
"metadata": {

examples/benchmark.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@
5959
" trn_tgt_data=tgt,\n",
6060
" hol_tgt_data=hol,\n",
6161
" )\n",
62-
" row = pd.json_normalize(metrics, sep=\"_\")\n",
62+
" row = pd.json_normalize(metrics.model_dump(), sep=\"_\")\n",
6363
" row.insert(0, \"dataset\", dataset)\n",
6464
" row.insert(1, \"synthesizer\", synthesizer)\n",
6565
" rows += [row]\n",
@@ -665,7 +665,7 @@
665665
"name": "python",
666666
"nbconvert_exporter": "python",
667667
"pygments_lexer": "ipython3",
668-
"version": "3.10.0"
668+
"version": "3.11.7"
669669
}
670670
},
671671
"nbformat": 4,

examples/quick-start.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
")\n",
3636
"\n",
3737
"# pretty print metrics\n",
38-
"print(json.dumps(metrics, indent=4))\n",
38+
"print(metrics.model_dump_json(indent=4))\n",
3939
"\n",
4040
"# open up HTML report in new browser window\n",
4141
"webbrowser.open(f\"file://{report_path.absolute()}\")"

0 commit comments

Comments
 (0)