diff --git a/pydeequ/verification.py b/pydeequ/verification.py
index c164246..fb88f0a 100644
--- a/pydeequ/verification.py
+++ b/pydeequ/verification.py
@@ -5,14 +5,30 @@
 from pyspark.sql import DataFrame, SparkSession
 
 from pydeequ.analyzers import AnalysisRunBuilder
-from pydeequ.checks import Check
+from pydeequ.checks import Check, CheckLevel
 from pydeequ.pandas_utils import ensure_pyspark_df
 
 # TODO integrate Analyzer context
 
 
 class AnomalyCheckConfig:
-    pass
+
+    def __init__(self, level: CheckLevel, description):
+        self.level = level
+        self.description = description
+
+    def _get_java_object(self, jvm):
+        self._jvm = jvm
+        self._java_level = self.level._get_java_object(self._jvm)
+        self._check_java_class = self._jvm.com.amazon.deequ.AnomalyCheckConfig
+        self._anomalyCheckConfig_jvm = self._check_java_class(
+            self._java_level,
+            self.description,
+            getattr(self._check_java_class, 'apply$default$3')(),
+            getattr(self._check_java_class, 'apply$default$4')(),
+            getattr(self._check_java_class, 'apply$default$5')(),
+        )
+        return self._anomalyCheckConfig_jvm
 
 
 class VerificationResult:
@@ -187,10 +203,11 @@ def addAnomalyCheck(self, anomaly, analyzer: _AnalyzerObject, anomalyCheckConfig
         :param anomalyCheckConfig: Some configuration settings for the Check
         :return: Adds an anomaly strategy to the run
         """
+        anomalyCheckConfig_jvm = None
         if anomalyCheckConfig:
-            raise NotImplementedError("anomalyCheckConfigs have not been implemented yet, using default value")
+            anomalyCheckConfig_jvm = anomalyCheckConfig._get_java_object(self._jvm)
 
-        AnomalyCheckConfig = self._jvm.scala.Option.apply(anomalyCheckConfig)
+        AnomalyCheckConfig = self._jvm.scala.Option.apply(anomalyCheckConfig_jvm)
 
         anomaly._set_jvm(self._jvm)
         anomaly_jvm = anomaly._anomaly_jvm
diff --git a/tests/test_anomaly_detection.py b/tests/test_anomaly_detection.py
index ae349ac..1e397e5 100644
--- a/tests/test_anomaly_detection.py
+++ b/tests/test_anomaly_detection.py
@@ -181,7 +181,8 @@ def OnlineNormalStrategy(
         print(df.collect())
         return df.select("check_status").collect()
 
-    def SimpleThresholdStrategy(self, df_prev, df_curr, analyzer_func, lowerBound, upperBound):
+    def SimpleThresholdStrategy(self, df_prev, df_curr, analyzer_func, lowerBound, upperBound,
+                                anomalyCheckConfig: AnomalyCheckConfig = None):
         metricsRepository = InMemoryMetricsRepository(self.spark)
         previousKey = ResultKey(self.spark, ResultKey.current_milli_time() - 24 * 60 * 1000 * 60)
 
@@ -196,7 +197,7 @@ def SimpleThresholdStrategy(self, df_prev, df_curr, analyzer_func, lowerBound, u
             .onData(df_curr)
             .useRepository(metricsRepository)
             .saveOrAppendResult(currKey)
-            .addAnomalyCheck(SimpleThresholdStrategy(lowerBound, upperBound), analyzer_func)
+            .addAnomalyCheck(SimpleThresholdStrategy(lowerBound, upperBound), analyzer_func, anomalyCheckConfig)
             .run()
         )
 
@@ -486,6 +487,20 @@ def get_anomalyDetector(self, anomaly):
     def test_anomalyDetector(self):
         self.get_anomalyDetector(SimpleThresholdStrategy(1.0, 3.0))
 
+    def test_SimpleThresholdStrategy_Error(self):
+        config = AnomalyCheckConfig(description='test error case', level=CheckLevel.Error)
+        # Lower bound is 1 upper bound is 6 (Range: 1-6 rows)
+        self.assertEqual(
+            self.SimpleThresholdStrategy(self.df_1, self.df_2, Size(), 1.0, 4.0, config), [Row(check_status="Error")]
+        )
+
+    def test_SimpleThresholdStrategy_Warning(self):
+        config = AnomalyCheckConfig(description='test warning case', level=CheckLevel.Warning)
+        # Lower bound is 1 upper bound is 6 (Range: 1-6 rows)
+        self.assertEqual(
+            self.SimpleThresholdStrategy(self.df_1, self.df_2, Size(), 1.0, 4.0, config), [Row(check_status="Warning")]
+        )
+
 #
 # def test_RelativeRateOfChangeStrategy(self):
 #     metricsRepository = InMemoryMetricsRepository(self.spark)