Skip to content

Commit 0c6b654

Browse files
authored
Implement FFI task context and task context provider (#18918)
## Which issue does this PR close? Addresses part of #18671 but does not close it. ## Rationale for this change We have an issue in the current FFI code in that we will need access to the `TaskContext` to do function encoding and decoding to pass logical expressions across the boundary. We use the proto crate for this operations. We cannot encode and decode functions that have been registered since our current code creates a default session context. This causes problems in crates where we are using UDFs as inputs to either aggregations, window functions, or table filters. With this change we keep a _weak_ reference to a new trait, a `TaskContextProvider`. By keeping a weak trait we make sure we do not create a circular dependency between function that is internally holding on to the provider, which holds the task context, which holds the function. ## What changes are included in this PR? - Introduce the `TaskContextProvider` trait - Implement FFI versions of `TaskContext` and `TaskContextProvider`. This PR does _not_ use these structures in the current code. That is coming as part of a later PR in an effort to keep the size of the PRs small for effective code review. ## Are these changes tested? Unit tests are added. Coverage report: <img width="525" height="218" alt="Screenshot 2025-11-25 at 2 36 57 PM" src="https://github.com/user-attachments/assets/4e1da78c-8058-41dc-bf87-e4ccb3d2b894" /> ## Are there any user-facing changes? No.
1 parent d4820d1 commit 0c6b654

File tree

20 files changed

+657
-131
lines changed

20 files changed

+657
-131
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion/core/src/execution/context/mod.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1860,6 +1860,12 @@ impl FunctionRegistry for SessionContext {
18601860
}
18611861
}
18621862

1863+
impl datafusion_execution::TaskContextProvider for SessionContext {
1864+
fn task_ctx(&self) -> Arc<TaskContext> {
1865+
SessionContext::task_ctx(self)
1866+
}
1867+
}
1868+
18631869
/// Create a new task context instance from SessionContext
18641870
impl From<&SessionContext> for TaskContext {
18651871
fn from(session: &SessionContext) -> Self {

datafusion/core/src/execution/session_state.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1988,6 +1988,12 @@ impl FunctionRegistry for SessionState {
19881988
}
19891989
}
19901990

1991+
impl datafusion_execution::TaskContextProvider for SessionState {
1992+
fn task_ctx(&self) -> Arc<TaskContext> {
1993+
SessionState::task_ctx(self)
1994+
}
1995+
}
1996+
19911997
impl OptimizerConfig for SessionState {
19921998
fn query_execution_start_time(&self) -> DateTime<Utc> {
19931999
self.execution_props.query_execution_start_time

datafusion/execution/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,4 +47,4 @@ pub mod registry {
4747
pub use disk_manager::DiskManager;
4848
pub use registry::FunctionRegistry;
4949
pub use stream::{RecordBatchStream, SendableRecordBatchStream};
50-
pub use task::TaskContext;
50+
pub use task::{TaskContext, TaskContextProvider};

datafusion/execution/src/task.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,11 @@ impl FunctionRegistry for TaskContext {
211211
}
212212
}
213213

214+
/// Produce the [`TaskContext`].
215+
pub trait TaskContextProvider {
216+
fn task_ctx(&self) -> Arc<TaskContext>;
217+
}
218+
214219
#[cfg(test)]
215220
mod tests {
216221
use super::*;

datafusion/ffi/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ async-ffi = { version = "0.5.0", features = ["abi_stable"] }
4848
async-trait = { workspace = true }
4949
datafusion = { workspace = true, default-features = false }
5050
datafusion-common = { workspace = true }
51+
datafusion-execution = { workspace = true }
5152
datafusion-expr = { workspace = true }
5253
datafusion-functions-aggregate-common = { workspace = true }
5354
datafusion-physical-expr = { workspace = true }

datafusion/ffi/README.md

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,29 @@ your unit tests you should override this with
170170
`crate::mock_foreign_marker_id` to force your test to create the foreign
171171
variant of your struct.
172172

173+
## Task Context Provider
174+
175+
Many of the FFI structs in this crate contain a `FFI_TaskContextProvider`. The
176+
purpose of this struct is to _weakly_ hold a reference to a method to
177+
access the current `TaskContext`. The reason we need this accessor is because
178+
we use the `datafusion-proto` crate to serialize and deserialize data across
179+
the FFI boundary. In particular, we need to serialize and deserialize
180+
functions using a `TaskContext`, which implements `FunctionRegistry`.
181+
182+
This becomes difficult because we may need to register multiple user defined
183+
functions, table or catalog providers, etc with a `Session`, and each of these
184+
will need the `TaskContext` to perform the processing. For this reason we
185+
cannot simply include the `TaskContext` at the time of registration because
186+
it would not have knowledge of anything registered afterward.
187+
188+
The `FFI_TaskContextProvider` is built from a trait that provides a method
189+
to get the current `TaskContext`. `FFI_TaskContextProvider` only holds a
190+
`Weak` reference to the `TaskContextProvider`, because otherwise we could
191+
create a circular dependency at runtime. It is imperative that if you use
192+
these methods that your provider remains valid for the lifetime of the
193+
calls. The `FFI_TaskContextProvider` is implemented on `SessionContext`
194+
and it is easy to implement on any struct that implements `Session`.
195+
173196
[apache datafusion]: https://datafusion.apache.org/
174197
[api docs]: http://docs.rs/datafusion-ffi/latest
175198
[rust abi]: https://doc.rust-lang.org/reference/abi.html
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
mod task_ctx;
19+
pub mod task_ctx_provider;
20+
21+
pub use task_ctx::FFI_TaskContext;
22+
pub use task_ctx_provider::FFI_TaskContextProvider;
Lines changed: 270 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,270 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use std::ffi::c_void;
19+
use std::sync::Arc;
20+
21+
use abi_stable::pmr::ROption;
22+
use abi_stable::std_types::{RHashMap, RString};
23+
use abi_stable::StableAbi;
24+
use datafusion_execution::config::SessionConfig;
25+
use datafusion_execution::runtime_env::RuntimeEnv;
26+
use datafusion_execution::TaskContext;
27+
use datafusion_expr::{
28+
AggregateUDF, AggregateUDFImpl, ScalarUDF, ScalarUDFImpl, WindowUDF, WindowUDFImpl,
29+
};
30+
31+
use crate::session_config::FFI_SessionConfig;
32+
use crate::udaf::FFI_AggregateUDF;
33+
use crate::udf::FFI_ScalarUDF;
34+
use crate::udwf::FFI_WindowUDF;
35+
36+
/// A stable struct for sharing [`TaskContext`] across FFI boundaries.
37+
#[repr(C)]
38+
#[derive(Debug, StableAbi)]
39+
#[allow(non_camel_case_types)]
40+
pub struct FFI_TaskContext {
41+
/// Return the session ID.
42+
pub session_id: unsafe extern "C" fn(&Self) -> RString,
43+
44+
/// Return the task ID.
45+
pub task_id: unsafe extern "C" fn(&Self) -> ROption<RString>,
46+
47+
/// Return the session configuration.
48+
pub session_config: unsafe extern "C" fn(&Self) -> FFI_SessionConfig,
49+
50+
/// Returns a hashmap of names to scalar functions.
51+
pub scalar_functions: unsafe extern "C" fn(&Self) -> RHashMap<RString, FFI_ScalarUDF>,
52+
53+
/// Returns a hashmap of names to aggregate functions.
54+
pub aggregate_functions:
55+
unsafe extern "C" fn(&Self) -> RHashMap<RString, FFI_AggregateUDF>,
56+
57+
/// Returns a hashmap of names to window functions.
58+
pub window_functions: unsafe extern "C" fn(&Self) -> RHashMap<RString, FFI_WindowUDF>,
59+
60+
/// Release the memory of the private data when it is no longer being used.
61+
pub release: unsafe extern "C" fn(arg: &mut Self),
62+
63+
/// Internal data. This is only to be accessed by the provider of the plan.
64+
/// The foreign library should never attempt to access this data.
65+
pub private_data: *mut c_void,
66+
67+
/// Utility to identify when FFI objects are accessed locally through
68+
/// the foreign interface. See [`crate::get_library_marker_id`] and
69+
/// the crate's `README.md` for more information.
70+
pub library_marker_id: extern "C" fn() -> usize,
71+
}
72+
73+
struct TaskContextPrivateData {
74+
ctx: Arc<TaskContext>,
75+
}
76+
77+
impl FFI_TaskContext {
78+
unsafe fn inner(&self) -> &Arc<TaskContext> {
79+
let private_data = self.private_data as *const TaskContextPrivateData;
80+
&(*private_data).ctx
81+
}
82+
}
83+
84+
unsafe extern "C" fn session_id_fn_wrapper(ctx: &FFI_TaskContext) -> RString {
85+
let ctx = ctx.inner();
86+
ctx.session_id().into()
87+
}
88+
89+
unsafe extern "C" fn task_id_fn_wrapper(ctx: &FFI_TaskContext) -> ROption<RString> {
90+
let ctx = ctx.inner();
91+
ctx.task_id().map(|s| s.as_str().into()).into()
92+
}
93+
94+
unsafe extern "C" fn session_config_fn_wrapper(
95+
ctx: &FFI_TaskContext,
96+
) -> FFI_SessionConfig {
97+
let ctx = ctx.inner();
98+
ctx.session_config().into()
99+
}
100+
101+
unsafe extern "C" fn scalar_functions_fn_wrapper(
102+
ctx: &FFI_TaskContext,
103+
) -> RHashMap<RString, FFI_ScalarUDF> {
104+
let ctx = ctx.inner();
105+
ctx.scalar_functions()
106+
.iter()
107+
.map(|(name, udf)| (name.to_owned().into(), Arc::clone(udf).into()))
108+
.collect()
109+
}
110+
111+
unsafe extern "C" fn aggregate_functions_fn_wrapper(
112+
ctx: &FFI_TaskContext,
113+
) -> RHashMap<RString, FFI_AggregateUDF> {
114+
let ctx = ctx.inner();
115+
ctx.aggregate_functions()
116+
.iter()
117+
.map(|(name, udaf)| {
118+
(
119+
name.to_owned().into(),
120+
FFI_AggregateUDF::from(Arc::clone(udaf)),
121+
)
122+
})
123+
.collect()
124+
}
125+
126+
unsafe extern "C" fn window_functions_fn_wrapper(
127+
ctx: &FFI_TaskContext,
128+
) -> RHashMap<RString, FFI_WindowUDF> {
129+
let ctx = ctx.inner();
130+
ctx.window_functions()
131+
.iter()
132+
.map(|(name, udf)| (name.to_owned().into(), FFI_WindowUDF::from(Arc::clone(udf))))
133+
.collect()
134+
}
135+
136+
unsafe extern "C" fn release_fn_wrapper(ctx: &mut FFI_TaskContext) {
137+
let private_data = Box::from_raw(ctx.private_data as *mut TaskContextPrivateData);
138+
drop(private_data);
139+
}
140+
141+
impl Drop for FFI_TaskContext {
142+
fn drop(&mut self) {
143+
unsafe { (self.release)(self) }
144+
}
145+
}
146+
147+
impl From<Arc<TaskContext>> for FFI_TaskContext {
148+
fn from(ctx: Arc<TaskContext>) -> Self {
149+
let private_data = Box::new(TaskContextPrivateData { ctx });
150+
151+
FFI_TaskContext {
152+
session_id: session_id_fn_wrapper,
153+
task_id: task_id_fn_wrapper,
154+
session_config: session_config_fn_wrapper,
155+
scalar_functions: scalar_functions_fn_wrapper,
156+
aggregate_functions: aggregate_functions_fn_wrapper,
157+
window_functions: window_functions_fn_wrapper,
158+
release: release_fn_wrapper,
159+
private_data: Box::into_raw(private_data) as *mut c_void,
160+
library_marker_id: crate::get_library_marker_id,
161+
}
162+
}
163+
}
164+
165+
impl From<FFI_TaskContext> for Arc<TaskContext> {
166+
fn from(ffi_ctx: FFI_TaskContext) -> Self {
167+
unsafe {
168+
if (ffi_ctx.library_marker_id)() == crate::get_library_marker_id() {
169+
return Arc::clone(ffi_ctx.inner());
170+
}
171+
172+
let task_id = (ffi_ctx.task_id)(&ffi_ctx).map(|s| s.to_string()).into();
173+
let session_id = (ffi_ctx.session_id)(&ffi_ctx).into();
174+
let session_config = (ffi_ctx.session_config)(&ffi_ctx);
175+
let session_config =
176+
SessionConfig::try_from(&session_config).unwrap_or_default();
177+
178+
let scalar_functions = (ffi_ctx.scalar_functions)(&ffi_ctx)
179+
.into_iter()
180+
.map(|kv_pair| {
181+
let udf = <Arc<dyn ScalarUDFImpl>>::from(&kv_pair.1);
182+
183+
(
184+
kv_pair.0.into_string(),
185+
Arc::new(ScalarUDF::new_from_shared_impl(udf)),
186+
)
187+
})
188+
.collect();
189+
let aggregate_functions = (ffi_ctx.aggregate_functions)(&ffi_ctx)
190+
.into_iter()
191+
.map(|kv_pair| {
192+
let udaf = <Arc<dyn AggregateUDFImpl>>::from(&kv_pair.1);
193+
194+
(
195+
kv_pair.0.into_string(),
196+
Arc::new(AggregateUDF::new_from_shared_impl(udaf)),
197+
)
198+
})
199+
.collect();
200+
let window_functions = (ffi_ctx.window_functions)(&ffi_ctx)
201+
.into_iter()
202+
.map(|kv_pair| {
203+
let udwf = <Arc<dyn WindowUDFImpl>>::from(&kv_pair.1);
204+
205+
(
206+
kv_pair.0.into_string(),
207+
Arc::new(WindowUDF::new_from_shared_impl(udwf)),
208+
)
209+
})
210+
.collect();
211+
212+
let runtime = Arc::new(RuntimeEnv::default());
213+
214+
Arc::new(TaskContext::new(
215+
task_id,
216+
session_id,
217+
session_config,
218+
scalar_functions,
219+
aggregate_functions,
220+
window_functions,
221+
runtime,
222+
))
223+
}
224+
}
225+
}
226+
227+
#[cfg(test)]
228+
mod tests {
229+
use std::sync::Arc;
230+
231+
use datafusion::prelude::SessionContext;
232+
use datafusion_common::Result;
233+
use datafusion_execution::TaskContext;
234+
235+
use crate::execution::FFI_TaskContext;
236+
237+
#[test]
238+
fn ffi_task_ctx_round_trip() -> Result<()> {
239+
let session_ctx = SessionContext::new();
240+
let original = session_ctx.task_ctx();
241+
let mut ffi_task_ctx = FFI_TaskContext::from(Arc::clone(&original));
242+
ffi_task_ctx.library_marker_id = crate::mock_foreign_marker_id;
243+
244+
let foreign_task_ctx: Arc<TaskContext> = ffi_task_ctx.into();
245+
246+
// TaskContext doesn't implement Eq (nor should it) so check some of the
247+
// data is round tripping correctly.
248+
249+
assert_eq!(
250+
original.scalar_functions(),
251+
foreign_task_ctx.scalar_functions()
252+
);
253+
assert_eq!(
254+
original.aggregate_functions(),
255+
foreign_task_ctx.aggregate_functions()
256+
);
257+
assert_eq!(
258+
original.window_functions(),
259+
foreign_task_ctx.window_functions()
260+
);
261+
assert_eq!(original.task_id(), foreign_task_ctx.task_id());
262+
assert_eq!(original.session_id(), foreign_task_ctx.session_id());
263+
assert_eq!(
264+
format!("{:?}", original.session_config()),
265+
format!("{:?}", foreign_task_ctx.session_config())
266+
);
267+
268+
Ok(())
269+
}
270+
}

0 commit comments

Comments
 (0)