1 change: 1 addition & 0 deletions docs/analyzers.md
@@ -16,6 +16,7 @@ Here are the current supported functionalities of Analyzers.
| Compliance | Compliance(instance, predicate) | Done|
| Correlation | Correlation(column1, column2) | Done|
| CountDistinct | CountDistinct(columns) | Done|
| CustomSql | CustomSql(expression, disambiguator) | Done|
| Datatype | Datatype(column) | Done|
| Distinctness | Distinctness(columns) | Done|
| Entropy | Entropy(column) | Done|
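For reference, a minimal sketch of constructing the analyzer exactly as the new table row documents it. The table name and the "row_count" label below are hypothetical placeholders, not names taken from this change:

from pydeequ.analyzers import CustomSql

# "my_table" and "row_count" are illustrative only; any registered view name
# and any disambiguator label work with the documented signature.
analyzer = CustomSql("SELECT COUNT(*) FROM my_table", disambiguator="row_count")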
24 changes: 24 additions & 0 deletions pydeequ/analyzers.py
@@ -360,6 +360,30 @@ def _analyzer_jvm(self):
return self._deequAnalyzers.CountDistinct(to_scala_seq(self._jvm, self.columns))


class CustomSql(_AnalyzerObject):
"""
A custom SQL-based analyzer that executes the provided SQL expression.
The expression must return a single value.

:param str expression: A SQL expression to execute.
:param str disambiguator: A label used to distinguish this metric
when running multiple custom SQL analyzers. Defaults to "*".
"""

def __init__(self, expression: str, disambiguator: str = "*"):
self.expression = expression
self.disambiguator = disambiguator

@property
def _analyzer_jvm(self):
"""
Returns the underlying JVM CustomSql analyzer object.

:return: the JVM analyzer that executes the SQL expression
"""
return self._deequAnalyzers.CustomSql(self.expression, self.disambiguator)


class DataType(_AnalyzerObject):
"""
Data Type Analyzer. Returns the data types of the column
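For context, a minimal usage sketch of the new analyzer with pydeequ's AnalysisRunner. The SparkSession setup, the sample data, the "products" view, and the "avg_price" label are assumptions for illustration, not part of this change; note that, as in the tests below, the queried table must be registered as a temp view first:

from pyspark.sql import SparkSession
import pydeequ
from pydeequ.analyzers import AnalysisRunner, AnalyzerContext, CustomSql

# Assumed setup: a SparkSession with the Deequ jars on the classpath.
spark = (SparkSession.builder
         .config("spark.jars.packages", pydeequ.deequ_maven_coord)
         .config("spark.jars.excludes", pydeequ.f2j_maven_coord)
         .getOrCreate())

# Hypothetical data; CustomSql queries a registered temp view.
df = spark.createDataFrame([("thingA", 13.0), ("thingB", 5.0)], ["name", "price"])
df.createOrReplaceTempView("products")

result = (AnalysisRunner(spark)
          .onData(df)
          .addAnalyzer(CustomSql("SELECT AVG(price) FROM products", disambiguator="avg_price"))
          .run())

# The resulting metric is reported with instance == "avg_price".
AnalyzerContext.successMetricsAsDataFrame(spark, result).show()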
29 changes: 29 additions & 0 deletions tests/test_analyzers.py
@@ -14,6 +14,7 @@
Compliance,
Correlation,
CountDistinct,
CustomSql,
DataType,
Distinctness,
Entropy,
@@ -111,6 +112,14 @@ def CountDistinct(self, columns):
df_from_json = self.spark.read.json(self.sc.parallelize([result_json]))
self.assertEqual(df_from_json.select("value").collect(), result_df.select("value").collect())
return result_df.select("value").collect()

def CustomSql(self, expression, disambiguator="*"):
result = self.AnalysisRunner.onData(self.df).addAnalyzer(CustomSql(expression, disambiguator)).run()
result_df = AnalyzerContext.successMetricsAsDataFrame(self.spark, result)
result_json = AnalyzerContext.successMetricsAsJson(self.spark, result)
df_from_json = self.spark.read.json(self.sc.parallelize([result_json]))
self.assertEqual(df_from_json.select("value").collect(), result_df.select("value").collect())
return result_df.select("value", "instance").collect()

def Datatype(self, column, where=None):
result = self.AnalysisRunner.onData(self.df).addAnalyzer(DataType(column, where)).run()
@@ -298,6 +307,26 @@ def test_CountDistinct(self):
def test_fail_CountDistinct(self):
self.assertEqual(self.CountDistinct("b"), [Row(value=1.0)])

def test_CustomSql(self):
self.df.createOrReplaceTempView("input_table")
self.assertEqual(self.CustomSql("SELECT SUM(b) FROM input_table"), [Row(value=6.0, instance="*")])
self.assertEqual(
self.CustomSql("SELECT AVG(LENGTH(a)) FROM input_table", disambiguator="foo"),
[Row(value=3.0, instance="foo")]
)
self.assertEqual(
self.CustomSql("SELECT MAX(c) FROM input_table", disambiguator="bar"),
[Row(value=6.0, instance="bar")]
)

@pytest.mark.xfail(reason="@unittest.expectedFailure")
def test_fail_CustomSql(self):
self.assertEqual(self.CustomSql("SELECT SUM(b) FROM input_table"), [Row(value=1.0)])

@pytest.mark.xfail(reason="@unittest.expectedFailure")
def test_fail_CustomSql_incorrect_query(self):
self.CustomSql("SELECT SUM(b)")

def test_DataType(self):
self.assertEqual(
self.Datatype("b"),
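For reference, a data sketch consistent with the assertions in test_CustomSql above. The actual fixture is defined in the suite's setup and may differ; the rows below are hypothetical and assume an active SparkSession `spark`:

# Hypothetical rows that satisfy the expected metric values above.
rows = [("foo", 1, 5), ("bar", 2, 6), ("baz", 3, None)]
df = spark.createDataFrame(rows, ["a", "b", "c"])
df.createOrReplaceTempView("input_table")
# SUM(b) = 1 + 2 + 3 = 6.0, AVG(LENGTH(a)) = 3.0, MAX(c) = 6.0,
# matching Row(value=6.0), Row(value=3.0), and Row(value=6.0) in test_CustomSql.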