File tree Expand file tree Collapse file tree 3 files changed +30
-0
lines changed Expand file tree Collapse file tree 3 files changed +30
-0
lines changed Original file line number Diff line number Diff line change 1313# limitations under the License.
1414
1515import logging
16+ import os
17+ import struct
1618from typing import Protocol
1719
1820import pandas as pd
@@ -122,3 +124,21 @@ def determine_data_size(
122124 return len (tgt_keys )
123125 else :
124126 return len (tgt_data )
127+
128+
def set_random_state(random_state: int | None = None) -> None:
    """
    Seed the global random number generators of `random` and `numpy`.

    Args:
        random_state: Seed to apply to both generators. If None, a fresh 32-bit
            seed is drawn from the operating system's entropy source, so every
            call yields a different (non-reproducible) state.

    Raises:
        ValueError: If numpy rejects the seed. Its legacy seeding only accepts
            values in [0, 2**32 - 1].
    """
    # kept function-local, as in the original implementation
    import random

    import numpy as np

    if random_state is None:
        # 4 bytes of OS entropy -> unsigned 32-bit integer, i.e. within the
        # range that numpy's legacy seeding accepts
        random_state = int.from_bytes(os.urandom(4), "little")
        # record the generated seed so that an unseeded run can still be replayed
        _LOG.debug("No random_state provided; using OS-generated seed `%s`", random_state)
    else:
        _LOG.info("Global random_state set to `%s`", random_state)

    # Seed numpy first: it validates the seed range, so an invalid seed fails
    # before the stdlib generator has been touched.
    np.random.seed(random_state)
    random.seed(random_state)
Original file line number Diff line number Diff line change 6262 TGT_COLUMN_PREFIX ,
6363 REPORT_CREDITS ,
6464 ProgressCallbackWrapper ,
65+ set_random_state ,
6566)
6667from mostlyai .qa ._filesystem import Statistics , TemporaryWorkspace
6768
@@ -87,6 +88,7 @@ def report(
8788 max_sample_size_embeddings : int | None = None ,
8889 statistics_path : str | Path | None = None ,
8990 update_progress : ProgressCallback | None = None ,
91+ random_state : int | None = None ,
9092) -> tuple [Path , ModelMetrics | None ]:
9193 """
9294 Generate an HTML report and metrics for assessing synthetic data quality.
@@ -121,12 +123,15 @@ def report(
121123 max_sample_size_embeddings: The maximum sample size for embedding calculations.
122124 statistics_path: The path of where to store the statistics to be used by `report_from_statistics`
123125 update_progress: The progress callback.
126+ random_state: Seed for the random number generators.
124127
125128 Returns:
126129 The path to the generated HTML report.
127130 Metrics instance with accuracy, similarity, and distances metrics.
128131 """
129132
133+ set_random_state (random_state )
134+
130135 if syn_ctx_data is not None :
131136 if ctx_primary_key is None :
132137 raise ValueError ("If syn_ctx_data is provided, then ctx_primary_key must also be provided." )
Original file line number Diff line number Diff line change 3333 determine_data_size ,
3434 REPORT_CREDITS ,
3535 ProgressCallbackWrapper ,
36+ set_random_state ,
3637)
3738from mostlyai .qa ._filesystem import Statistics , TemporaryWorkspace
3839
@@ -53,6 +54,7 @@ def report_from_statistics(
5354 max_sample_size_accuracy : int | None = None ,
5455 max_sample_size_coherence : int | None = None ,
5556 update_progress : ProgressCallback | None = None ,
57+ random_state : int | None = None ,
5658) -> Path :
5759 """
5860 Generate an HTML report based on previously generated statistics and newly provided synthetic data samples.
@@ -70,11 +72,14 @@ def report_from_statistics(
7072 max_sample_size_accuracy: The maximum sample size for accuracy calculations.
7173 max_sample_size_coherence: The maximum sample size for coherence calculations.
7274 update_progress: The progress callback.
75+ random_state: Seed for the random number generators.
7376
7477 Returns:
7578 The path to the generated HTML report.
7679 """
7780
81+ set_random_state (random_state )
82+
7883 with (
7984 TemporaryWorkspace () as workspace ,
8085 ProgressCallbackWrapper (update_progress ) as progress ,
You can’t perform that action at this time.
0 commit comments