|
| 1 | +// Licensed to the Apache Software Foundation (ASF) under one |
| 2 | +// or more contributor license agreements. See the NOTICE file |
| 3 | +// distributed with this work for additional information |
| 4 | +// regarding copyright ownership. The ASF licenses this file |
| 5 | +// to you under the Apache License, Version 2.0 (the |
| 6 | +// "License"); you may not use this file except in compliance |
| 7 | +// with the License. You may obtain a copy of the License at |
| 8 | +// |
| 9 | +// http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +// |
| 11 | +// Unless required by applicable law or agreed to in writing, |
| 12 | +// software distributed under the License is distributed on an |
| 13 | +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 14 | +// KIND, either express or implied. See the License for the |
| 15 | +// specific language governing permissions and limitations |
| 16 | +// under the License. |
| 17 | + |
| 18 | +use std::ffi::c_void; |
| 19 | +use std::sync::Arc; |
| 20 | + |
| 21 | +use abi_stable::pmr::ROption; |
| 22 | +use abi_stable::std_types::{RHashMap, RString}; |
| 23 | +use abi_stable::StableAbi; |
| 24 | +use datafusion_execution::config::SessionConfig; |
| 25 | +use datafusion_execution::runtime_env::RuntimeEnv; |
| 26 | +use datafusion_execution::TaskContext; |
| 27 | +use datafusion_expr::{ |
| 28 | + AggregateUDF, AggregateUDFImpl, ScalarUDF, ScalarUDFImpl, WindowUDF, WindowUDFImpl, |
| 29 | +}; |
| 30 | + |
| 31 | +use crate::session_config::FFI_SessionConfig; |
| 32 | +use crate::udaf::FFI_AggregateUDF; |
| 33 | +use crate::udf::FFI_ScalarUDF; |
| 34 | +use crate::udwf::FFI_WindowUDF; |
| 35 | + |
/// A stable struct for sharing [`TaskContext`] across FFI boundaries.
///
/// The struct is a table of `extern "C"` function pointers plus an opaque
/// `private_data` pointer. Every accessor takes `&Self` so that the
/// implementation on the providing side can reach its own private data.
/// Field order and types are part of the `#[repr(C)]` layout and must not
/// be rearranged.
#[repr(C)]
#[derive(Debug, StableAbi)]
#[allow(non_camel_case_types)]
pub struct FFI_TaskContext {
    /// Return the session ID.
    pub session_id: unsafe extern "C" fn(&Self) -> RString,

    /// Return the task ID, or `RNone` when the task has no ID.
    pub task_id: unsafe extern "C" fn(&Self) -> ROption<RString>,

    /// Return the session configuration.
    pub session_config: unsafe extern "C" fn(&Self) -> FFI_SessionConfig,

    /// Returns a hashmap of names to scalar functions.
    pub scalar_functions: unsafe extern "C" fn(&Self) -> RHashMap<RString, FFI_ScalarUDF>,

    /// Returns a hashmap of names to aggregate functions.
    pub aggregate_functions:
        unsafe extern "C" fn(&Self) -> RHashMap<RString, FFI_AggregateUDF>,

    /// Returns a hashmap of names to window functions.
    pub window_functions: unsafe extern "C" fn(&Self) -> RHashMap<RString, FFI_WindowUDF>,

    /// Release the memory of the private data when it is no longer being used.
    /// This is invoked from the `Drop` implementation of this struct.
    pub release: unsafe extern "C" fn(arg: &mut Self),

    /// Internal data. This is only to be accessed by the provider of the
    /// task context. The foreign library should never attempt to access
    /// this data.
    pub private_data: *mut c_void,

    /// Utility to identify when FFI objects are accessed locally through
    /// the foreign interface. See [`crate::get_library_marker_id`] and
    /// the crate's `README.md` for more information.
    pub library_marker_id: extern "C" fn() -> usize,
}
| 72 | + |
/// Provider-side payload stored behind `FFI_TaskContext::private_data`.
///
/// It is boxed and turned into a raw pointer when an `FFI_TaskContext` is
/// built from an `Arc<TaskContext>`, and turned back into a `Box` and
/// dropped by the release callback.
struct TaskContextPrivateData {
    /// The shared task context that the FFI accessors read from.
    ctx: Arc<TaskContext>,
}
| 76 | + |
| 77 | +impl FFI_TaskContext { |
| 78 | + unsafe fn inner(&self) -> &Arc<TaskContext> { |
| 79 | + let private_data = self.private_data as *const TaskContextPrivateData; |
| 80 | + &(*private_data).ctx |
| 81 | + } |
| 82 | +} |
| 83 | + |
| 84 | +unsafe extern "C" fn session_id_fn_wrapper(ctx: &FFI_TaskContext) -> RString { |
| 85 | + let ctx = ctx.inner(); |
| 86 | + ctx.session_id().into() |
| 87 | +} |
| 88 | + |
| 89 | +unsafe extern "C" fn task_id_fn_wrapper(ctx: &FFI_TaskContext) -> ROption<RString> { |
| 90 | + let ctx = ctx.inner(); |
| 91 | + ctx.task_id().map(|s| s.as_str().into()).into() |
| 92 | +} |
| 93 | + |
| 94 | +unsafe extern "C" fn session_config_fn_wrapper( |
| 95 | + ctx: &FFI_TaskContext, |
| 96 | +) -> FFI_SessionConfig { |
| 97 | + let ctx = ctx.inner(); |
| 98 | + ctx.session_config().into() |
| 99 | +} |
| 100 | + |
| 101 | +unsafe extern "C" fn scalar_functions_fn_wrapper( |
| 102 | + ctx: &FFI_TaskContext, |
| 103 | +) -> RHashMap<RString, FFI_ScalarUDF> { |
| 104 | + let ctx = ctx.inner(); |
| 105 | + ctx.scalar_functions() |
| 106 | + .iter() |
| 107 | + .map(|(name, udf)| (name.to_owned().into(), Arc::clone(udf).into())) |
| 108 | + .collect() |
| 109 | +} |
| 110 | + |
| 111 | +unsafe extern "C" fn aggregate_functions_fn_wrapper( |
| 112 | + ctx: &FFI_TaskContext, |
| 113 | +) -> RHashMap<RString, FFI_AggregateUDF> { |
| 114 | + let ctx = ctx.inner(); |
| 115 | + ctx.aggregate_functions() |
| 116 | + .iter() |
| 117 | + .map(|(name, udaf)| { |
| 118 | + ( |
| 119 | + name.to_owned().into(), |
| 120 | + FFI_AggregateUDF::from(Arc::clone(udaf)), |
| 121 | + ) |
| 122 | + }) |
| 123 | + .collect() |
| 124 | +} |
| 125 | + |
| 126 | +unsafe extern "C" fn window_functions_fn_wrapper( |
| 127 | + ctx: &FFI_TaskContext, |
| 128 | +) -> RHashMap<RString, FFI_WindowUDF> { |
| 129 | + let ctx = ctx.inner(); |
| 130 | + ctx.window_functions() |
| 131 | + .iter() |
| 132 | + .map(|(name, udf)| (name.to_owned().into(), FFI_WindowUDF::from(Arc::clone(udf)))) |
| 133 | + .collect() |
| 134 | +} |
| 135 | + |
| 136 | +unsafe extern "C" fn release_fn_wrapper(ctx: &mut FFI_TaskContext) { |
| 137 | + let private_data = Box::from_raw(ctx.private_data as *mut TaskContextPrivateData); |
| 138 | + drop(private_data); |
| 139 | +} |
| 140 | + |
| 141 | +impl Drop for FFI_TaskContext { |
| 142 | + fn drop(&mut self) { |
| 143 | + unsafe { (self.release)(self) } |
| 144 | + } |
| 145 | +} |
| 146 | + |
| 147 | +impl From<Arc<TaskContext>> for FFI_TaskContext { |
| 148 | + fn from(ctx: Arc<TaskContext>) -> Self { |
| 149 | + let private_data = Box::new(TaskContextPrivateData { ctx }); |
| 150 | + |
| 151 | + FFI_TaskContext { |
| 152 | + session_id: session_id_fn_wrapper, |
| 153 | + task_id: task_id_fn_wrapper, |
| 154 | + session_config: session_config_fn_wrapper, |
| 155 | + scalar_functions: scalar_functions_fn_wrapper, |
| 156 | + aggregate_functions: aggregate_functions_fn_wrapper, |
| 157 | + window_functions: window_functions_fn_wrapper, |
| 158 | + release: release_fn_wrapper, |
| 159 | + private_data: Box::into_raw(private_data) as *mut c_void, |
| 160 | + library_marker_id: crate::get_library_marker_id, |
| 161 | + } |
| 162 | + } |
| 163 | +} |
| 164 | + |
| 165 | +impl From<FFI_TaskContext> for Arc<TaskContext> { |
| 166 | + fn from(ffi_ctx: FFI_TaskContext) -> Self { |
| 167 | + unsafe { |
| 168 | + if (ffi_ctx.library_marker_id)() == crate::get_library_marker_id() { |
| 169 | + return Arc::clone(ffi_ctx.inner()); |
| 170 | + } |
| 171 | + |
| 172 | + let task_id = (ffi_ctx.task_id)(&ffi_ctx).map(|s| s.to_string()).into(); |
| 173 | + let session_id = (ffi_ctx.session_id)(&ffi_ctx).into(); |
| 174 | + let session_config = (ffi_ctx.session_config)(&ffi_ctx); |
| 175 | + let session_config = |
| 176 | + SessionConfig::try_from(&session_config).unwrap_or_default(); |
| 177 | + |
| 178 | + let scalar_functions = (ffi_ctx.scalar_functions)(&ffi_ctx) |
| 179 | + .into_iter() |
| 180 | + .map(|kv_pair| { |
| 181 | + let udf = <Arc<dyn ScalarUDFImpl>>::from(&kv_pair.1); |
| 182 | + |
| 183 | + ( |
| 184 | + kv_pair.0.into_string(), |
| 185 | + Arc::new(ScalarUDF::new_from_shared_impl(udf)), |
| 186 | + ) |
| 187 | + }) |
| 188 | + .collect(); |
| 189 | + let aggregate_functions = (ffi_ctx.aggregate_functions)(&ffi_ctx) |
| 190 | + .into_iter() |
| 191 | + .map(|kv_pair| { |
| 192 | + let udaf = <Arc<dyn AggregateUDFImpl>>::from(&kv_pair.1); |
| 193 | + |
| 194 | + ( |
| 195 | + kv_pair.0.into_string(), |
| 196 | + Arc::new(AggregateUDF::new_from_shared_impl(udaf)), |
| 197 | + ) |
| 198 | + }) |
| 199 | + .collect(); |
| 200 | + let window_functions = (ffi_ctx.window_functions)(&ffi_ctx) |
| 201 | + .into_iter() |
| 202 | + .map(|kv_pair| { |
| 203 | + let udwf = <Arc<dyn WindowUDFImpl>>::from(&kv_pair.1); |
| 204 | + |
| 205 | + ( |
| 206 | + kv_pair.0.into_string(), |
| 207 | + Arc::new(WindowUDF::new_from_shared_impl(udwf)), |
| 208 | + ) |
| 209 | + }) |
| 210 | + .collect(); |
| 211 | + |
| 212 | + let runtime = Arc::new(RuntimeEnv::default()); |
| 213 | + |
| 214 | + Arc::new(TaskContext::new( |
| 215 | + task_id, |
| 216 | + session_id, |
| 217 | + session_config, |
| 218 | + scalar_functions, |
| 219 | + aggregate_functions, |
| 220 | + window_functions, |
| 221 | + runtime, |
| 222 | + )) |
| 223 | + } |
| 224 | + } |
| 225 | +} |
| 226 | + |
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datafusion::prelude::SessionContext;
    use datafusion_common::Result;
    use datafusion_execution::TaskContext;

    use crate::execution::FFI_TaskContext;

    /// Export a task context, pretend it came from another library, import
    /// it again and check that the observable contents survived the trip.
    #[test]
    fn ffi_task_ctx_round_trip() -> Result<()> {
        let session_ctx = SessionContext::new();
        let original = session_ctx.task_ctx();

        // Swap the marker so the import takes the foreign code path rather
        // than cloning the original `Arc`.
        let mut exported = FFI_TaskContext::from(Arc::clone(&original));
        exported.library_marker_id = crate::mock_foreign_marker_id;

        let imported: Arc<TaskContext> = exported.into();

        // TaskContext doesn't implement Eq (nor should it), so compare the
        // pieces individually.
        assert_eq!(original.session_id(), imported.session_id());
        assert_eq!(original.task_id(), imported.task_id());
        assert_eq!(
            format!("{:?}", original.session_config()),
            format!("{:?}", imported.session_config())
        );
        assert_eq!(original.scalar_functions(), imported.scalar_functions());
        assert_eq!(
            original.aggregate_functions(),
            imported.aggregate_functions()
        );
        assert_eq!(original.window_functions(), imported.window_functions());

        Ok(())
    }
}
0 commit comments