Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 21 additions & 4 deletions pydeequ/verification.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,30 @@
from pyspark.sql import DataFrame, SparkSession

from pydeequ.analyzers import AnalysisRunBuilder
from pydeequ.checks import Check
from pydeequ.checks import Check, CheckLevel
from pydeequ.pandas_utils import ensure_pyspark_df

# TODO integrate Analyzer context


class AnomalyCheckConfig:
    """Configuration for an anomaly Check added via ``addAnomalyCheck``.

    Python-side mirror of ``com.amazon.deequ.AnomalyCheckConfig``; only the
    check level and description are exposed here — the remaining JVM
    constructor parameters fall back to their Scala ``apply`` defaults.
    """

    def __init__(self, level: CheckLevel, description: str):
        """
        :param level: Severity of the check (e.g. CheckLevel.Error or CheckLevel.Warning).
        :param description: Human-readable description attached to the check.
        """
        self.level = level
        self.description = description

    def _get_java_object(self, jvm):
        """Build and return the JVM-side AnomalyCheckConfig for this config.

        :param jvm: py4j JVM gateway used to reach the Deequ classes.
        :return: a ``com.amazon.deequ.AnomalyCheckConfig`` java object.
        """
        self._jvm = jvm
        self._java_level = self.level._get_java_object(self._jvm)
        self._check_java_class = self._jvm.com.amazon.deequ.AnomalyCheckConfig
        # The 3rd-5th constructor arguments are filled with the Scala
        # case-class apply defaults (``apply$default$N``) — presumably the
        # optional tag/date settings of AnomalyCheckConfig; confirm against
        # the Deequ Scala source before relying on their semantics.
        self._anomalyCheckConfig_jvm = self._check_java_class(
            self._java_level,
            self.description,
            getattr(self._check_java_class, 'apply$default$3')(),
            getattr(self._check_java_class, 'apply$default$4')(),
            getattr(self._check_java_class, 'apply$default$5')(),
        )
        return self._anomalyCheckConfig_jvm


class VerificationResult:
Expand Down Expand Up @@ -187,10 +203,11 @@ def addAnomalyCheck(self, anomaly, analyzer: _AnalyzerObject, anomalyCheckConfig
:param anomalyCheckConfig: Some configuration settings for the Check
:return: Adds an anomaly strategy to the run
"""
anomalyCheckConfig_jvm = None
if anomalyCheckConfig:
raise NotImplementedError("anomalyCheckConfigs have not been implemented yet, using default value")
anomalyCheckConfig_jvm = anomalyCheckConfig._get_java_object(self._jvm)

AnomalyCheckConfig = self._jvm.scala.Option.apply(anomalyCheckConfig)
AnomalyCheckConfig = self._jvm.scala.Option.apply(anomalyCheckConfig_jvm)

anomaly._set_jvm(self._jvm)
anomaly_jvm = anomaly._anomaly_jvm
Expand Down
19 changes: 17 additions & 2 deletions tests/test_anomaly_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,8 @@ def OnlineNormalStrategy(
print(df.collect())
return df.select("check_status").collect()

def SimpleThresholdStrategy(self, df_prev, df_curr, analyzer_func, lowerBound, upperBound):
def SimpleThresholdStrategy(self, df_prev, df_curr, analyzer_func, lowerBound, upperBound,
anomalyCheckConfig: AnomalyCheckConfig = None):
metricsRepository = InMemoryMetricsRepository(self.spark)
previousKey = ResultKey(self.spark, ResultKey.current_milli_time() - 24 * 60 * 1000 * 60)

Expand All @@ -196,7 +197,7 @@ def SimpleThresholdStrategy(self, df_prev, df_curr, analyzer_func, lowerBound, u
.onData(df_curr)
.useRepository(metricsRepository)
.saveOrAppendResult(currKey)
.addAnomalyCheck(SimpleThresholdStrategy(lowerBound, upperBound), analyzer_func)
.addAnomalyCheck(SimpleThresholdStrategy(lowerBound, upperBound), analyzer_func, anomalyCheckConfig)
.run()
)

Expand Down Expand Up @@ -486,6 +487,20 @@ def get_anomalyDetector(self, anomaly):
def test_anomalyDetector(self):
self.get_anomalyDetector(SimpleThresholdStrategy(1.0, 3.0))

def test_SimpleThresholdStrategy_Error(self):
    # An Error-level AnomalyCheckConfig should surface as check_status "Error"
    # when the SimpleThresholdStrategy flags an anomaly.
    config = AnomalyCheckConfig(description='test error case', level=CheckLevel.Error)
    # Lower bound is 1 upper bound is 6 (Range: 1-6 rows)
    # NOTE(review): the bounds passed below are 1.0-4.0, not 1-6 as this
    # comment says — confirm whether the comment refers to the fixture sizes.
    self.assertEqual(
        self.SimpleThresholdStrategy(self.df_1, self.df_2, Size(), 1.0, 4.0, config), [Row(check_status="Error")]
    )

def test_SimpleThresholdStrategy_Warning(self):
    # A Warning-level AnomalyCheckConfig should surface as check_status "Warning"
    # when the SimpleThresholdStrategy flags an anomaly.
    # Fixed copy/paste inconsistency: the description previously read
    # 'test error case' although this exercises the Warning level.
    config = AnomalyCheckConfig(description='test warning case', level=CheckLevel.Warning)
    # NOTE(review): the sibling Error test's comment mentions a 1-6 row range,
    # but the bounds passed below are 1.0-4.0 — confirm which is intended.
    self.assertEqual(
        self.SimpleThresholdStrategy(self.df_1, self.df_2, Size(), 1.0, 4.0, config), [Row(check_status="Warning")]
    )

#
# def test_RelativeRateOfChangeStrategy(self):
# metricsRepository = InMemoryMetricsRepository(self.spark)
Expand Down