Skip to content

Commit 91eba1d

Browse files
authored
feat: support spark sec (#18728)
## Which issue does this PR close? Partially implements #15914 ## Rationale for this change Spark supports the secant function https://spark.apache.org/docs/latest/api/sql/index.html#sec. This function is not present in other databases such as PostgreSQL, MySQL, and SQLite, so it has been added to datafusion-spark. ## What changes are included in this PR? ## Are these changes tested? Yes, unit tests with inputs/outputs obtained from Spark. ## Are there any user-facing changes? Yes
1 parent 3ac3d0e commit 91eba1d

File tree

3 files changed

+93
-2
lines changed

3 files changed

+93
-2
lines changed

datafusion/spark/src/function/math/mod.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ make_udf_function!(modulus::SparkPmod, pmod);
3535
make_udf_function!(rint::SparkRint, rint);
3636
make_udf_function!(width_bucket::SparkWidthBucket, width_bucket);
3737
make_udf_function!(trigonometry::SparkCsc, csc);
38+
make_udf_function!(trigonometry::SparkSec, sec);
3839

3940
pub mod expr_fn {
4041
use datafusion_functions::export_functions;
@@ -51,6 +52,7 @@ pub mod expr_fn {
5152
export_functions!((rint, "Returns the double value that is closest in value to the argument and is equal to a mathematical integer.", arg1));
5253
export_functions!((width_bucket, "Returns the bucket number into which the value of this expression would fall after being evaluated.", arg1 arg2 arg3 arg4));
5354
export_functions!((csc, "Returns the cosecant of expr.", arg1));
55+
export_functions!((sec, "Returns the secant of expr.", arg1));
5456
}
5557

5658
pub fn functions() -> Vec<Arc<ScalarUDF>> {
@@ -63,5 +65,6 @@ pub fn functions() -> Vec<Arc<ScalarUDF>> {
6365
rint(),
6466
width_bucket(),
6567
csc(),
68+
sec(),
6669
]
6770
}

datafusion/spark/src/function/math/trigonometry.rs

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,3 +95,73 @@ fn spark_csc(arg: &ColumnarValue) -> Result<ColumnarValue> {
9595
)),
9696
}
9797
}
98+
99+
static SEC_FUNCTION_NAME: &str = "sec";
100+
101+
/// <https://spark.apache.org/docs/latest/api/sql/index.html#sec>
102+
#[derive(Debug, PartialEq, Eq, Hash)]
103+
pub struct SparkSec {
104+
signature: Signature,
105+
}
106+
107+
impl Default for SparkSec {
108+
fn default() -> Self {
109+
Self::new()
110+
}
111+
}
112+
113+
impl SparkSec {
114+
pub fn new() -> Self {
115+
Self {
116+
signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable),
117+
}
118+
}
119+
}
120+
121+
impl ScalarUDFImpl for SparkSec {
122+
fn as_any(&self) -> &dyn Any {
123+
self
124+
}
125+
126+
fn name(&self) -> &str {
127+
SEC_FUNCTION_NAME
128+
}
129+
130+
fn signature(&self) -> &Signature {
131+
&self.signature
132+
}
133+
134+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
135+
Ok(DataType::Float64)
136+
}
137+
138+
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
139+
let [arg] = take_function_args(self.name(), &args.args)?;
140+
spark_sec(arg)
141+
}
142+
}
143+
144+
fn spark_sec(arg: &ColumnarValue) -> Result<ColumnarValue> {
145+
match arg {
146+
ColumnarValue::Scalar(ScalarValue::Float64(value)) => Ok(ColumnarValue::Scalar(
147+
ScalarValue::Float64(value.map(|x| 1.0 / x.cos())),
148+
)),
149+
ColumnarValue::Array(array) => match array.data_type() {
150+
DataType::Float64 => Ok(ColumnarValue::Array(Arc::new(
151+
array
152+
.as_primitive::<Float64Type>()
153+
.unary::<_, Float64Type>(|x| 1.0 / x.cos()),
154+
) as ArrayRef)),
155+
other => Err(unsupported_data_type_exec_err(
156+
SEC_FUNCTION_NAME,
157+
format!("{}", DataType::Float64).as_str(),
158+
other,
159+
)),
160+
},
161+
other => Err(unsupported_data_type_exec_err(
162+
SEC_FUNCTION_NAME,
163+
format!("{}", DataType::Float64).as_str(),
164+
&other.data_type(),
165+
)),
166+
}
167+
}

datafusion/sqllogictest/test_files/spark/math/sec.slt

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,5 +23,23 @@
2323

2424
## Original Query: SELECT sec(0);
2525
## PySpark 3.5.5 Result: {'SEC(0)': 1.0, 'typeof(SEC(0))': 'double', 'typeof(0)': 'int'}
26-
#query
27-
#SELECT sec(0::int);
26+
query R
27+
SELECT sec(0::int);
28+
----
29+
1
30+
31+
query R
32+
SELECT sec(a) FROM (VALUES (0::INT), (1::INT), (-1::INT), (null)) AS t(a);
33+
----
34+
1
35+
1.850815717680926
36+
1.850815717680926
37+
NULL
38+
39+
query R
40+
SELECT sec(a) FROM (VALUES (pi()), (3 * pi()/2), (pi()/2) , (arrow_cast('NAN','Float32'))) AS t(a);
41+
----
42+
-1
43+
-5443746451065123
44+
16331239353195370
45+
NaN

0 commit comments

Comments
 (0)