diff --git a/rs_bindings_from_cc/generate_bindings/database/rs_snippet.rs b/rs_bindings_from_cc/generate_bindings/database/rs_snippet.rs index 0eace25cd..849b96c31 100644 --- a/rs_bindings_from_cc/generate_bindings/database/rs_snippet.rs +++ b/rs_bindings_from_cc/generate_bindings/database/rs_snippet.rs @@ -491,6 +491,63 @@ fn new_c9_co_record( })) } +/// Information about how the owned function object may be called. +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] +pub enum FnKind { + /// A function object that may be called in any context, any number of times. + Fn, + + /// A function object that requires mutable access in order to invoke, and may be called any + /// number of times. + FnMut, + + /// A function object that may be called at most once. + FnOnce, +} + +impl ToTokens for FnKind { + fn to_tokens(&self, tokens: &mut TokenStream) { + match self { + FnKind::Fn => quote! { ::core::ops::Fn }, + FnKind::FnMut => quote! { ::core::ops::FnMut }, + FnKind::FnOnce => quote! { ::core::ops::FnOnce }, + } + .to_tokens(tokens); + } +} + +/// Information about a dyn callable type. +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub struct DynCallable { + pub fn_kind: FnKind, + pub return_type: Rc, + pub param_types: Rc<[RsTypeKind]>, + pub thunk_ident: Ident, +} + +impl DynCallable { + /// Returns a `TokenStream` in the shape of `-> Output`, or None if the return type is void. + pub fn rust_return_type_fragment(&self, db: &dyn BindingsGenerator) -> Option { + if self.return_type.is_void() { + None + } else { + let return_type_tokens = self.return_type.to_token_stream(db); + Some(quote! { -> #return_type_tokens }) + } + } + + /// Returns a `TokenStream` in the shape of `dyn Trait(Inputs) -> Output`. + pub fn dyn_fn_spelling(&self, db: &dyn BindingsGenerator) -> TokenStream { + let rust_return_type_fragment = self.rust_return_type_fragment(db); + let param_type_tokens = + self.param_types.iter().map(|param_ty| param_ty.to_token_stream(db)); + let fn_kind = self.fn_kind; + quote! { + dyn #fn_kind(#(#param_type_tokens),*) #rust_return_type_fragment + } + } +} + #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub enum BridgeRsTypeKind { BridgeVoidConverters { @@ -512,6 +569,7 @@ pub enum BridgeRsTypeKind { StdString { in_cc_std: bool, }, + DynCallable(Rc), } impl BridgeRsTypeKind { @@ -565,6 +623,24 @@ impl BridgeRsTypeKind { BridgeRsTypeKind::StdString { in_cc_std } } + BridgeType::DynCallable { fn_kind, return_type, param_types } => { + BridgeRsTypeKind::DynCallable(Rc::new(DynCallable { + fn_kind: match fn_kind { + ir::FnKind::Fn => FnKind::Fn, + ir::FnKind::FnMut => FnKind::FnMut, + ir::FnKind::FnOnce => FnKind::FnOnce, + }, + return_type: Rc::new(db.rs_type_kind(return_type.clone())?), + param_types: param_types + .iter() + .map(|param_type| db.rs_type_kind(param_type.clone())) + .collect::>()?, + // TODO(okabayashi): use something more sophisticated than the mangled name + // of the class template specialization. + thunk_ident: syn::parse_str(record.rs_name.identifier.as_ref()) + .expect("should be a valid identifier"), + })) + } }; Ok(Some(bridge_rs_type_kind)) @@ -1123,6 +1199,10 @@ impl RsTypeKind { BridgeRsTypeKind::StdOptional(t) => t.implements_copy(), BridgeRsTypeKind::StdPair(t1, t2) => t1.implements_copy() && t2.implements_copy(), BridgeRsTypeKind::StdString { .. } => false, + BridgeRsTypeKind::DynCallable { .. } => { + // DynCallable represents an owned function object, so it is not copyable. + false + } }, RsTypeKind::ExistingRustType(_) => true, RsTypeKind::C9Co { .. } => false, @@ -1602,6 +1682,10 @@ impl RsTypeKind { quote! { ::cc_std::std::string } } } + BridgeRsTypeKind::DynCallable(dyn_callable) => { + let dyn_callable_spelling = dyn_callable.dyn_fn_spelling(db); + quote! { ::alloc::boxed::Box<#dyn_callable_spelling> } + } } } RsTypeKind::ExistingRustType(existing_rust_type) => fully_qualify_type( @@ -1756,6 +1840,10 @@ impl<'ty> Iterator for RsTypeKindIter<'ty> { self.todo.push(t1); } BridgeRsTypeKind::StdString { .. } => {} + BridgeRsTypeKind::DynCallable(dyn_callable) => { + self.todo.push(&dyn_callable.return_type); + self.todo.extend(dyn_callable.param_types.iter().rev()); + } }, RsTypeKind::ExistingRustType(_) => {} RsTypeKind::C9Co { result_type, .. } => { diff --git a/rs_bindings_from_cc/generate_bindings/lib.rs b/rs_bindings_from_cc/generate_bindings/lib.rs index 2f2621445..51e996fb1 100644 --- a/rs_bindings_from_cc/generate_bindings/lib.rs +++ b/rs_bindings_from_cc/generate_bindings/lib.rs @@ -457,6 +457,14 @@ fn is_rs_type_kind_unsafe(db: &dyn BindingsGenerator, rs_type_kind: RsTypeKind) || db.is_rs_type_kind_unsafe(t2.as_ref().clone()) } BridgeRsTypeKind::StdString { .. } => false, + BridgeRsTypeKind::DynCallable(dyn_callable) => { + db.is_rs_type_kind_unsafe(dyn_callable.return_type.as_ref().clone()) + || dyn_callable + .param_types + .iter() + .cloned() + .any(|param_type| db.is_rs_type_kind_unsafe(param_type)) + } }, RsTypeKind::Record { record, .. } => is_record_unsafe(db, &record), RsTypeKind::C9Co { result_type, .. } => { @@ -727,6 +735,7 @@ fn crubit_abi_type(db: &dyn BindingsGenerator, rs_type_kind: RsTypeKind) -> Resu Ok(CrubitAbiType::Pair(Rc::from(first_abi), Rc::from(second_abi))) } BridgeRsTypeKind::StdString { in_cc_std } => Ok(CrubitAbiType::StdString { in_cc_std }), + BridgeRsTypeKind::DynCallable(_) => bail!("DynCallable is not supported yet"), }, RsTypeKind::Record { record, crate_path, .. } => { ensure!( diff --git a/rs_bindings_from_cc/importers/cxx_record.cc b/rs_bindings_from_cc/importers/cxx_record.cc index fa30f55f3..6fc839360 100644 --- a/rs_bindings_from_cc/importers/cxx_record.cc +++ b/rs_bindings_from_cc/importers/cxx_record.cc @@ -624,6 +624,71 @@ absl::StatusOr GetTemplateSpecializationKind( return TemplateSpecialization::NonSpecial(); } +// Returns the `DynCallable` information for the given `specialization_decl`. +// +// This should only be called on template specializations of +// `rs_std::DynCallable`. +absl::StatusOr GetDynCallable( + ImportContext& ictx, + const clang::ClassTemplateSpecializationDecl& specialization_decl) { + if (specialization_decl.getTemplateArgs().size() != 1) { + return absl::InvalidArgumentError( + "DynCallable template specialization must have exactly one template " + "argument"); + } + const clang::FunctionProtoType* sig_fn_type = + specialization_decl.getTemplateArgs() + .get(0) + .getAsType() + .getTypePtr() + ->getAs(); + + if (sig_fn_type == nullptr) { + return absl::InvalidArgumentError( + "Failed to get function signature for DynCallable"); + } + + // Extract the function kind based on the qualifiers. + BridgeType::DynCallable::FnKind fn_kind; + if (sig_fn_type->getRefQualifier() == clang::RQ_RValue) { + // Regardless of whether it's && or const &&, it's a FnOnce. + fn_kind = BridgeType::DynCallable::FnKind::kFnOnce; + } else if (sig_fn_type->getMethodQuals().hasConst()) { + fn_kind = BridgeType::DynCallable::FnKind::kFn; + } else { + fn_kind = BridgeType::DynCallable::FnKind::kFnMut; + } + + // Convert the return type, ensuring that it is complete first. + if (sig_fn_type->getReturnType()->isIncompleteType()) { + (void)ictx.sema_.isCompleteType(specialization_decl.getLocation(), + sig_fn_type->getReturnType()); + } + CRUBIT_ASSIGN_OR_RETURN(CcType return_type, + ictx.ConvertQualType(sig_fn_type->getReturnType(), + /*lifetimes=*/nullptr)); + + std::vector param_types; + // Convert the parameter types, ensuring that they are complete first. + param_types.reserve(sig_fn_type->getNumParams()); + for (clang::QualType param_type : sig_fn_type->getParamTypes()) { + if (param_type->isIncompleteType()) { + (void)ictx.sema_.isCompleteType(specialization_decl.getLocation(), + param_type); + } + CRUBIT_ASSIGN_OR_RETURN( + CcType param_cc_type, + ictx.ConvertQualType(param_type, /*lifetimes=*/nullptr)); + param_types.push_back(std::move(param_cc_type)); + } + + return BridgeType::DynCallable{ + .fn_kind = fn_kind, + .return_type = std::make_shared(std::move(return_type)), + .param_types = std::move(param_types), + }; +} + } // namespace std::optional CXXRecordDeclImporter::GetTranslatedFieldName( @@ -812,6 +877,24 @@ std::optional CXXRecordDeclImporter::Import( owning_target = ts.defining_target; } + const clang::CXXRecordDecl* templated_decl = + specialization_decl->getSpecializedTemplate()->getTemplatedDecl(); + if (IsTopLevelNamespace("rs_std", templated_decl->getDeclContext())) { + if (templated_decl->getName() == "DynCallable") { + LOG_IF(FATAL, specialization_decl->getTemplateArgs().size() != 1) + << "rs_std::DynCallable should have one template arg"; + absl::StatusOr status_or_dyn_callable = + GetDynCallable(ictx_, *specialization_decl); + if (!status_or_dyn_callable.ok()) { + return ictx_.ImportUnsupportedItem( + *record_decl, std::nullopt, + FormattedError::FromStatus( + std::move(status_or_dyn_callable).status())); + } + bridge_type.emplace().variant = *std::move(status_or_dyn_callable); + } + } + if (!bridge_type.has_value()) { absl::StatusOr> builtin_bridge_type = GetBuiltinBridgeType(specialization_decl); diff --git a/rs_bindings_from_cc/ir.cc b/rs_bindings_from_cc/ir.cc index 93639dd1d..154d96b84 100644 --- a/rs_bindings_from_cc/ir.cc +++ b/rs_bindings_from_cc/ir.cc @@ -446,6 +446,17 @@ llvm::json::Value SizeAlign::ToJson() const { }; } +static llvm::json::Value toJSON(BridgeType::DynCallable::FnKind fn_kind) { + switch (fn_kind) { + case BridgeType::DynCallable::FnKind::kFn: + return "Fn"; + case BridgeType::DynCallable::FnKind::kFnMut: + return "FnMut"; + case BridgeType::DynCallable::FnKind::kFnOnce: + return "FnOnce"; + } +} + llvm::json::Value BridgeType::ToJson() const { return std::visit( visitor{ @@ -492,7 +503,16 @@ llvm::json::Value BridgeType::ToJson() const { }, }}; }, - }, + [&](const BridgeType::DynCallable& dyn_callable) { + return llvm::json::Object{{ + "DynCallable", + llvm::json::Object{ + {"fn_kind", dyn_callable.fn_kind}, + {"return_type", dyn_callable.return_type->ToJson()}, + {"param_types", dyn_callable.param_types}, + }, + }}; + }}, variant); } diff --git a/rs_bindings_from_cc/ir.h b/rs_bindings_from_cc/ir.h index 9bbaf226d..add74d1f8 100644 --- a/rs_bindings_from_cc/ir.h +++ b/rs_bindings_from_cc/ir.h @@ -557,8 +557,19 @@ struct BridgeType { struct StdString {}; + struct DynCallable { + enum FnKind { + kFn, + kFnMut, + kFnOnce, + }; + FnKind fn_kind; + std::shared_ptr return_type; + std::vector param_types; + }; + std::variant + ProtoMessageBridge, StdString, DynCallable> variant; }; diff --git a/rs_bindings_from_cc/ir.rs b/rs_bindings_from_cc/ir.rs index 0706b2ad0..745db40c1 100644 --- a/rs_bindings_from_cc/ir.rs +++ b/rs_bindings_from_cc/ir.rs @@ -1039,6 +1039,14 @@ pub struct SizeAlign { pub alignment: usize, } +#[derive(Debug, PartialEq, Eq, Hash, Clone, Deserialize)] +#[serde(deny_unknown_fields)] +pub enum FnKind { + Fn, + FnMut, + FnOnce, +} + #[derive(Debug, PartialEq, Eq, Hash, Clone, Deserialize)] #[serde(deny_unknown_fields)] pub enum BridgeType { @@ -1059,6 +1067,11 @@ pub enum BridgeType { StdOptional(CcType), StdPair(CcType, CcType), StdString, + DynCallable { + fn_kind: FnKind, + return_type: CcType, + param_types: Vec, + }, } #[derive(Debug, PartialEq, Eq, Hash, Clone, Deserialize)] diff --git a/rs_bindings_from_cc/test/golden/callables_rs_api.rs b/rs_bindings_from_cc/test/golden/callables_rs_api.rs index bbed9ca85..5613d8c12 100644 --- a/rs_bindings_from_cc/test/golden/callables_rs_api.rs +++ b/rs_bindings_from_cc/test/golden/callables_rs_api.rs @@ -15,19 +15,13 @@ #![deny(warnings)] // Error while generating bindings for function 'apply': -// Can't generate bindings for apply, because of missing required features (): -// //rs_bindings_from_cc/test/golden:callables_cc needs [//features:wrapper] for apply (the type of callback (parameter #0): error: Can't generate bindings for rs_std::DynCallable, because of missing required features (): -// //rs_bindings_from_cc/test/golden:callables_cc needs [//features:wrapper] for rs_std::DynCallable (crate::__CcTemplateInstN6rs_std11DynCallableIKFiiEEE is a template instantiation)) +// while generating bridge param 'callback': DynCallable is not supported yet // Error while generating bindings for function 'apply_mut': -// Can't generate bindings for apply_mut, because of missing required features (): -// //rs_bindings_from_cc/test/golden:callables_cc needs [//features:wrapper] for apply_mut (the type of callback (parameter #0): error: Can't generate bindings for rs_std::DynCallable, because of missing required features (): -// //rs_bindings_from_cc/test/golden:callables_cc needs [//features:wrapper] for rs_std::DynCallable (crate::__CcTemplateInstN6rs_std11DynCallableIFiiEEE is a template instantiation)) +// while generating bridge param 'callback': DynCallable is not supported yet // Error while generating bindings for function 'apply_once': -// Can't generate bindings for apply_once, because of missing required features (): -// //rs_bindings_from_cc/test/golden:callables_cc needs [//features:wrapper] for apply_once (the type of callback (parameter #0): error: Can't generate bindings for rs_std::DynCallable, because of missing required features (): -// //rs_bindings_from_cc/test/golden:callables_cc needs [//features:wrapper] for rs_std::DynCallable (crate::__CcTemplateInstN6rs_std11DynCallableIFiiOEEE is a template instantiation)) +// while generating bridge param 'callback': DynCallable is not supported yet #[derive(Clone, Copy, ::ctor::MoveAndAssignViaCopy)] #[repr(C, align(4))] @@ -72,9 +66,7 @@ impl NotCABICompatible { } // Error while generating bindings for function 'rust_inspect_non_c_abi_compatible_struct': -// Can't generate bindings for rust_inspect_non_c_abi_compatible_struct, because of missing required features (): -// //rs_bindings_from_cc/test/golden:callables_cc needs [//features:wrapper] for rust_inspect_non_c_abi_compatible_struct (the type of cb (parameter #0): error: Can't generate bindings for rs_std::DynCallable, because of missing required features (): -// //rs_bindings_from_cc/test/golden:callables_cc needs [//features:wrapper] for rs_std::DynCallable (crate::__CcTemplateInstN6rs_std11DynCallableIF17NotCABICompatibleS1_EEE is a template instantiation)) +// while generating bridge param 'cb': DynCallable is not supported yet // Error while generating bindings for struct 'std::integral_constant': // Can't generate bindings for std::integral_constant, because of missing required features (): @@ -84,22 +76,6 @@ impl NotCABICompatible { // Can't generate bindings for std::integral_constant, because of missing required features (): // //rs_bindings_from_cc/test/golden:callables_cc needs [//features:wrapper] for std::integral_constant (crate::__CcTemplateInstNSt3__u17integral_constantIbLb1EEE is a template instantiation) -// Error while generating bindings for class 'rs_std::DynCallable': -// Can't generate bindings for rs_std::DynCallable, because of missing required features (): -// //rs_bindings_from_cc/test/golden:callables_cc needs [//features:wrapper] for rs_std::DynCallable (crate::__CcTemplateInstN6rs_std11DynCallableIF17NotCABICompatibleS1_EEE is a template instantiation) - -// Error while generating bindings for class 'rs_std::DynCallable': -// Can't generate bindings for rs_std::DynCallable, because of missing required features (): -// //rs_bindings_from_cc/test/golden:callables_cc needs [//features:wrapper] for rs_std::DynCallable (crate::__CcTemplateInstN6rs_std11DynCallableIFiiEEE is a template instantiation) - -// Error while generating bindings for class 'rs_std::DynCallable': -// Can't generate bindings for rs_std::DynCallable, because of missing required features (): -// //rs_bindings_from_cc/test/golden:callables_cc needs [//features:wrapper] for rs_std::DynCallable (crate::__CcTemplateInstN6rs_std11DynCallableIFiiOEEE is a template instantiation) - -// Error while generating bindings for class 'rs_std::DynCallable': -// Can't generate bindings for rs_std::DynCallable, because of missing required features (): -// //rs_bindings_from_cc/test/golden:callables_cc needs [//features:wrapper] for rs_std::DynCallable (crate::__CcTemplateInstN6rs_std11DynCallableIKFiiEEE is a template instantiation) - mod detail { #[allow(unused_imports)] use super::*; diff --git a/rs_bindings_from_cc/test/golden/callables_rs_api_impl.cc b/rs_bindings_from_cc/test/golden/callables_rs_api_impl.cc index 4108b5d9d..e60669cda 100644 --- a/rs_bindings_from_cc/test/golden/callables_rs_api_impl.cc +++ b/rs_bindings_from_cc/test/golden/callables_rs_api_impl.cc @@ -5,6 +5,7 @@ // Automatically @generated Rust bindings for the following C++ target: // //rs_bindings_from_cc/test/golden:callables_cc +#include "support/bridge.h" #include "support/internal/cxx20_backports.h" #include "support/internal/offsetof.h" #include "support/internal/sizeof.h"