Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,4 @@ repository = "https://github.com/mveril/wslplugins-rs"
description = "A Rust framework for developing WSL plugins using safe and idiomatic Rust."

[patch.crates-io]
wslpluginapi-sys = { version = "0.1.0-rc.1", git = "https://github.com/mveril/wslpluginapi-sys.git", branch = "release/0.1.0-rc.1+2.4.4" }
wslpluginapi-sys = { version = "0.1.0-rc.1", git = "https://github.com/mveril/wslpluginapi-sys.git", branch = "develop" }
69 changes: 30 additions & 39 deletions wslplugins-macro-core/src/generator/c_funcs_tokens.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,99 +13,90 @@ pub(super) fn get_c_func_tokens(hook: Hooks) -> Result<Option<TokenStream>> {
extern "C" fn #c_method_ident(
session: *const ::wslplugins_rs::sys::WSLSessionInformation,
settings: *const ::wslplugins_rs::sys::WSLVmCreationSettings,
) -> ::windows::core::HRESULT {
) -> ::wslplugins_rs::sys::windows_sys::core::HRESULT {
let session_ptr = unsafe { &*session };
let settings_ptr = unsafe { &*settings };
if let Some(plugin) = PLUGIN.get() {
PLUGIN.get().map(|plugin|{
let result = plugin.#trait_method_ident(
session_ptr.as_ref(),
settings_ptr.as_ref(),
);
::wslplugins_rs::plugin::utils::consume_to_win_result(result).into()
} else {
::windows::Win32::Foundation::E_FAIL
}
::wslplugins_rs::windows_core::HRESULT::from(::wslplugins_rs::plugin::utils::consume_to_win_result(result)).0
}).unwrap_or(::wslplugins_rs::sys::windows_sys::Win32::Foundation::E_FAIL)
}
}),
Hooks::OnVMStopping => Some(quote! {
extern "C" fn #c_method_ident(
session: *const ::wslplugins_rs::sys::WSLSessionInformation
) -> ::windows::core::HRESULT {
) -> ::wslplugins_rs::sys::windows_sys::core::HRESULT {
let session_ptr = unsafe { &*session };
if let Some(plugin) = PLUGIN.get() {
plugin.#trait_method_ident(session_ptr.as_ref()).into()
} else {
::windows::Win32::Foundation::E_FAIL
}
PLUGIN.get()
.map(|plugin| {
let result = plugin.#trait_method_ident(session_ptr.as_ref());
::wslplugins_rs::windows_core::HRESULT::from(result).0
})
.unwrap_or(::wslplugins_rs::sys::windows_sys::Win32::Foundation::E_FAIL)
}
}),
Hooks::OnDistributionStarted => Some(quote! {
extern "C" fn #c_method_ident(
session: *const ::wslplugins_rs::sys::WSLSessionInformation,
distribution: *const ::wslplugins_rs::sys::WSLDistributionInformation,
) -> ::windows::core::HRESULT {
) -> ::wslplugins_rs::sys::windows_sys::core::HRESULT {
let session_ptr = unsafe { &*session };
let distribution_ptr = unsafe { &*distribution };
if let Some(plugin) = PLUGIN.get() {
PLUGIN.get().map(|plugin|{
let result = plugin.#trait_method_ident(
session_ptr.as_ref(),
distribution_ptr.as_ref(),
);
::wslplugins_rs::plugin::utils::consume_to_win_result(result).into()
} else {
::windows::Win32::Foundation::E_FAIL
}
::wslplugins_rs::windows_core::HRESULT::from(::wslplugins_rs::plugin::utils::consume_to_win_result(result)).0
}).unwrap_or(::wslplugins_rs::sys::windows_sys::Win32::Foundation::E_FAIL)
}
}),
Hooks::OnDistributionStopping => Some(quote! {
extern "C" fn #c_method_ident(
session: *const ::wslplugins_rs::sys::WSLSessionInformation,
distribution: *const ::wslplugins_rs::sys::WSLDistributionInformation,
) -> ::windows::core::HRESULT {
) -> ::wslplugins_rs::sys::windows_sys::core::HRESULT {
let session_ptr = unsafe { &*session };
let distribution_ptr = unsafe { &*distribution };
if let Some(plugin) = PLUGIN.get() {
plugin.#trait_method_ident(
PLUGIN.get().map(|plugin|{
::wslplugins_rs::windows_core::HRESULT::from(plugin.#trait_method_ident(
session_ptr.as_ref(),
distribution_ptr.as_ref(),
).into()
} else {
::windows::Win32::Foundation::E_FAIL
}
)).0
}).unwrap_or(::wslplugins_rs::sys::windows_sys::Win32::Foundation::E_FAIL)
}
}),
Hooks::OnDistributionRegistered => Some(quote! {
extern "C" fn #c_method_ident(
session: *const ::wslplugins_rs::sys::WSLSessionInformation,
distribution: *const ::wslplugins_rs::sys::WSLOfflineDistributionInformation,
) -> ::windows::core::HRESULT {
) -> ::wslplugins_rs::sys::windows_sys::core::HRESULT {
let session_ptr = unsafe { &*session };
let distribution_ptr = unsafe { &*distribution };
if let Some(plugin) = PLUGIN.get() {
plugin.#trait_method_ident(
PLUGIN.get().map(|plugin|{
::wslplugins_rs::windows_core::HRESULT::from(plugin.#trait_method_ident(
session_ptr.as_ref(),
distribution_ptr.as_ref(),
).into()
} else {
::windows::Win32::Foundation::E_FAIL
}
)).0
}).unwrap_or(::wslplugins_rs::sys::windows_sys::Win32::Foundation)
}
}),
Hooks::OnDistributionUnregistered => Some(quote! {
extern "C" fn #c_method_ident(
session: *const ::wslplugins_rs::sys::WSLSessionInformation,
distribution: *const ::wslplugins_rs::sys::WSLOfflineDistributionInformation,
) -> ::windows::core::HRESULT {
) -> ::wslplugins_rs::sys::windows_sys::core::HRESULT {
let session_ptr = unsafe { &*session };
let distribution_ptr = unsafe { &*distribution };
if let Some(plugin) = PLUGIN.get() {
plugin.#trait_method_ident(
PLUGIN.get().map(|plugin|{
::wslplugins_rs::windows_core::HRESULT::from(plugin.#trait_method_ident(
session_ptr.as_ref(),
distribution_ptr.as_ref(),
).into()
} else {
::windows::Win32::Foundation::E_FAIL
}
)).0
}).unwrap_or(::wslplugins_rs::sys::windows_sys::Win32::Foundation)
}
}),
};
Expand Down
6 changes: 3 additions & 3 deletions wslplugins-macro-core/src/generator/hook_field_mapping.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ fn generate_entry_point(imp: &ParsedImpl, version: &RequiredVersion) -> Result<T
pub unsafe extern "C" fn WSLPluginAPIV1_EntryPoint(
api: *const ::wslplugins_rs::sys::WSLPluginAPIV1,
hooks: *mut ::wslplugins_rs::sys::WSLPluginHooksV1,
) -> ::windows::core::HRESULT {
) -> ::wslplugins_rs::windows_core::HRESULT {
unsafe {
let api_ref: &'static ::wslplugins_rs::sys::WSLPluginAPIV1 = unsafe { &*api};
let #hooks_ref_name: &mut ::wslplugins_rs::sys::WSLPluginHooksV1 = unsafe{ &mut *hooks };
Expand All @@ -109,10 +109,10 @@ fn generate_entry_point(imp: &ParsedImpl, version: &RequiredVersion) -> Result<T
fn create_plugin(
api: &'static ::wslplugins_rs::sys::WSLPluginAPIV1,
hooks_ref: &mut ::wslplugins_rs::sys::WSLPluginHooksV1,
) -> ::windows::core::Result<()> {
) -> ::wslplugins_rs::windows_core::Result<()> {
let plugin: #static_plugin_type = ::wslplugins_rs::plugin::create_plugin_with_required_version(api, #major, #minor, #revision)?;
#(#hook_set)*
PLUGIN.set(plugin).map_err(|_| ::windows::core::Error::from(::windows::Win32::Foundation::E_ABORT))
PLUGIN.set(plugin).map_err(|_| ::wslplugins_rs::windows_core::Error::from(::wslplugins_rs::windows_core::HRESULT(::wslplugins_rs::sys::windows_sys::Win32::Foundation::E_ABORT)))
}
})
}
Expand Down
4 changes: 1 addition & 3 deletions wslplugins-rs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,9 @@ repository.workspace = true
description.workspace = true
edition = "2021"

[dependencies.windows]
workspace = true
features = ["Win32_Foundation", "Win32_System", "Win32_Networking_WinSock"]

[dependencies]
windows-core = "0.61.2"
bitflags = { version = ">0.1.0", optional = true }
enumflags2 = { version = ">0.5", optional = true }
flagset = { version = ">0.1.0", optional = true }
Expand Down
51 changes: 25 additions & 26 deletions wslplugins-rs/src/api/api_v1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ use std::os::windows::raw::SOCKET;
use std::path::Path;
use typed_path::Utf8UnixPath;
use widestring::U16CString;
use windows::core::{Result as WinResult, BOOL, GUID, PCSTR, PCWSTR};
use windows::Win32::Networking::WinSock::SOCKET as WinSocket;
use windows_core::{Result as WinResult, GUID, HRESULT};
use wslpluginapi_sys::windows_sys::Win32::Networking::WinSock::SOCKET as WinSocket;

use wslpluginapi_sys::WSLPluginAPIV1;

Expand Down Expand Up @@ -104,13 +104,13 @@ impl ApiV1 {
let result = unsafe {
self.0.MountFolder.unwrap_unchecked()(
session.id(),
PCWSTR::from_raw(encoded_windows_path.as_ptr()),
PCWSTR::from_raw(encoded_linux_path.as_ptr()),
BOOL::from(read_only),
PCWSTR::from_raw(encoded_name.as_ptr()),
encoded_windows_path.as_ptr(),
encoded_linux_path.as_ptr(),
read_only as i32,
encoded_name.as_ptr(),
)
};
result.ok()
HRESULT(result).ok()
}

/// Execute a program in the root namespace.
Expand All @@ -129,7 +129,7 @@ impl ApiV1 {
/// - **Standard Output**: Data output by the process will be readable from the stream.
///
/// # Errors
/// This method can return the following a [`windows::core::Error`]: If the underlying Windows API call fails.
/// This method can return the following a [windows_core::Error]: If the underlying Windows API call fails.
///
/// # Example
/// ```rust,ignore
Expand Down Expand Up @@ -162,23 +162,23 @@ impl ApiV1 {
.iter()
.map(|&arg| CString::from_str_truncate(arg))
.collect();
let mut args_ptrs: Vec<PCSTR> = c_args
let mut args_ptrs: Vec<*const u8> = c_args
.iter()
.map(|arg| PCSTR::from_raw(arg.as_ptr() as *const u8))
.chain(Some(PCSTR::null()))
.map(|arg| arg.as_ptr() as *const u8)
.chain(once(std::ptr::null::<u8>()))
.collect();
let args_ptr = args_ptrs.as_mut_ptr();
let mut socket = MaybeUninit::<WinSocket>::uninit();
let stream = unsafe {
self.0.ExecuteBinary.unwrap_unchecked()(
HRESULT(self.0.ExecuteBinary.unwrap_unchecked()(
session.id(),
PCSTR::from_raw(c_path.as_ptr()),
c_path.as_ptr(),
args_ptr,
socket.as_mut_ptr(),
)
))
.ok()?;
let socket = socket.assume_init();
TcpStream::from_raw_socket(socket.0 as SOCKET)
TcpStream::from_raw_socket(socket as SOCKET)
};
Ok(stream)
}
Expand All @@ -187,9 +187,7 @@ impl ApiV1 {
#[cfg_attr(feature = "log-instrument", instrument)]
pub(crate) fn plugin_error(&self, error: &OsStr) -> WinResult<()> {
let error_utf16 = U16CString::from_os_str_truncate(error);
unsafe {
self.0.PluginError.unwrap_unchecked()(PCWSTR::from_raw(error_utf16.as_ptr())).ok()
}
HRESULT(unsafe { self.0.PluginError.unwrap_unchecked()(error_utf16.as_ptr()) }).ok()
}

/// Execute a program in a user distribution
Expand Down Expand Up @@ -241,29 +239,30 @@ impl ApiV1 {
.copied()
.chain(once(0))
.collect();
let path_ptr = PCSTR::from_raw(c_path.as_ptr());
let path_ptr = c_path.as_ptr();
let c_args: Vec<CString> = args
.iter()
.map(|&arg| CString::from_str_truncate(arg))
.collect();
let mut args_ptrs: Vec<PCSTR> = c_args
let mut args_ptrs: Vec<_> = c_args
.iter()
.map(|arg| PCSTR::from_raw(arg.as_ptr() as *const u8))
.chain(Some(PCSTR::null()))
.map(|arg| arg.as_ptr() as *const u8)
.chain(once(std::ptr::null()))
.collect();
let args_ptr = args_ptrs.as_mut_ptr();
let mut socket = MaybeUninit::<WinSocket>::uninit();
let stream = unsafe {
self.0.ExecuteBinaryInDistribution.unwrap_unchecked()(
HRESULT(self.0.ExecuteBinaryInDistribution.unwrap_unchecked()(
session.id(),
&distribution_id,
(&distribution_id) as *const GUID
as *const wslpluginapi_sys::windows_sys::core::GUID,
path_ptr,
args_ptr,
socket.as_mut_ptr(),
)
))
.ok()?;
let socket = socket.assume_init();
TcpStream::from_raw_socket(socket.0 as SOCKET)
TcpStream::from_raw_socket(socket as SOCKET)
};
Ok(stream)
}
Expand Down
4 changes: 2 additions & 2 deletions wslplugins-rs/src/api/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
use thiserror::Error;
pub mod require_update_error;
pub use require_update_error::Error as RequireUpdateError;
use windows::core::{Error as WinError, HRESULT};
use windows_core::{Error as WinError, HRESULT};
use wslpluginapi_sys::WSL_E_PLUGIN_REQUIRES_UPDATE;

/// A comprehensive error type for WSL plugins.
Expand Down Expand Up @@ -52,7 +52,7 @@ impl From<Error> for WinError {
/// A `WinError` representing the error.
fn from(value: Error) -> Self {
match value {
Error::RequiresUpdate { .. } => WSL_E_PLUGIN_REQUIRES_UPDATE.into(),
Error::RequiresUpdate { .. } => HRESULT(WSL_E_PLUGIN_REQUIRES_UPDATE).into(),
Error::WinError(error) => error,
}
}
Expand Down
4 changes: 2 additions & 2 deletions wslplugins-rs/src/api/errors/require_update_error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

use crate::WSLVersion;
use thiserror::Error;
use windows::core::HRESULT;
use windows_core::HRESULT;
use wslpluginapi_sys::WSL_E_PLUGIN_REQUIRES_UPDATE;

/// Represents an error when the current WSL version is unsupported.
Expand Down Expand Up @@ -51,7 +51,7 @@ impl From<Error> for HRESULT {
/// # Returns
/// - `[WSL_E_PLUGIN_REQUIRES_UPDATE]: Indicates the WSL version is insufficient for the plugin.
fn from(_: Error) -> Self {
WSL_E_PLUGIN_REQUIRES_UPDATE
HRESULT(WSL_E_PLUGIN_REQUIRES_UPDATE)
}
}

Expand Down
2 changes: 1 addition & 1 deletion wslplugins-rs/src/core_distribution_information.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

use crate::api::errors::require_update_error::Result;
use std::ffi::OsString;
use windows::core::GUID;
use windows_core::GUID;

/// A trait representing the core information of a WSL distribution.
///
Expand Down
Loading