From 13d2767ee2fe9cac37b07ca59b0ac41354ac60bc Mon Sep 17 00:00:00 2001 From: rob-maron <132852777+rob-maron@users.noreply.github.com> Date: Tue, 10 Mar 2026 11:40:29 -0400 Subject: [PATCH] cuda compute capability bindings --- src/nn/linear.rs | 7 ++++++- src/nn/optimizer.rs | 6 +++++- src/tensor/display.rs | 12 +++++++---- src/tensor/mod.rs | 3 +-- src/wrappers/device.rs | 12 +++++++++++ src/wrappers/kind.rs | 2 +- src/wrappers/tensor.rs | 4 ++-- tests/device_tests.rs | 11 +++++++++- torch-sys/build.rs | 37 ++++++++++++---------------------- torch-sys/libtch/torch_api.cpp | 17 ++++++++++++++++ torch-sys/libtch/torch_api.h | 3 +++ torch-sys/src/cuda.rs | 7 +++++++ 12 files changed, 85 insertions(+), 36 deletions(-) diff --git a/src/nn/linear.rs b/src/nn/linear.rs index 0b0d6013..3e5c6953 100644 --- a/src/nn/linear.rs +++ b/src/nn/linear.rs @@ -13,7 +13,12 @@ pub struct LinearConfig { impl Default for LinearConfig { fn default() -> Self { - LinearConfig { ws_init: super::init::DEFAULT_KAIMING_UNIFORM, bs_init: None, bias: true, shard: None } + LinearConfig { + ws_init: super::init::DEFAULT_KAIMING_UNIFORM, + bs_init: None, + bias: true, + shard: None, + } } } diff --git a/src/nn/optimizer.rs b/src/nn/optimizer.rs index ff035bcd..94e15255 100644 --- a/src/nn/optimizer.rs +++ b/src/nn/optimizer.rs @@ -322,7 +322,11 @@ impl Optimizer { /// Returns all the trainable variables and their sharding for this optimizer. pub fn trainable_variables_with_sharding(&self) -> Vec<(Tensor, Option)> { let variables = self.variables.lock().unwrap(); - variables.trainable_variables.iter().map(|v| (v.0.tensor.shallow_clone(), v.1.clone())).collect() + variables + .trainable_variables + .iter() + .map(|v| (v.0.tensor.shallow_clone(), v.1.clone())) + .collect() } /// Sets the optimizer weight decay. 
diff --git a/src/tensor/display.rs b/src/tensor/display.rs index 17806ddf..cd10165b 100644 --- a/src/tensor/display.rs +++ b/src/tensor/display.rs @@ -57,7 +57,11 @@ impl BasicKind { fn _is_floating_point(&self) -> bool { match self { BasicKind::Float => true, - BasicKind::Bool | BasicKind::Int | BasicKind::Complex | BasicKind::Bits | BasicKind::Packed => false, + BasicKind::Bool + | BasicKind::Int + | BasicKind::Complex + | BasicKind::Bits + | BasicKind::Packed => false, } } } @@ -459,9 +463,9 @@ impl std::fmt::Display for Tensor { tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?; writeln!(f)?; } - BasicKind::Complex => {}, - BasicKind::Bits => {}, - BasicKind::Packed => {}, + BasicKind::Complex => {} + BasicKind::Bits => {} + BasicKind::Packed => {} }; let kind = match self.f_kind() { Ok(kind) => format!("{kind:?}"), diff --git a/src/tensor/mod.rs b/src/tensor/mod.rs index c532fa84..1f1f356e 100644 --- a/src/tensor/mod.rs +++ b/src/tensor/mod.rs @@ -11,8 +11,7 @@ mod ops; mod safetensors; pub use super::wrappers::tensor::{ - autocast, no_grad, no_grad_guard, with_grad, NoGradGuard, Reduction, - Tensor, + autocast, no_grad, no_grad_guard, with_grad, NoGradGuard, Reduction, Tensor, }; pub use index::{IndexOp, NewAxis, TensorIndexer}; diff --git a/src/wrappers/device.rs b/src/wrappers/device.rs index f7d78a8b..d9aef17a 100644 --- a/src/wrappers/device.rs +++ b/src/wrappers/device.rs @@ -80,6 +80,18 @@ impl Cuda { pub fn cudnn_set_benchmark(b: bool) { unsafe_torch!(torch_sys::cuda::atc_set_benchmark_cudnn(i32::from(b))) } + + /// Returns the compute capability of a CUDA device as a `(major, minor)` pair. 
+ pub fn get_device_capability(device_index: usize) -> Result<(i32, i32), crate::TchError> { + let mut major: libc::c_int = 0; + let mut minor: libc::c_int = 0; + unsafe_torch_err!(torch_sys::cuda::atc_cuda_get_device_capability( + device_index as libc::c_int, + &mut major, + &mut minor, + )); + Ok((major as i32, minor as i32)) + } } impl Device { diff --git a/src/wrappers/kind.rs b/src/wrappers/kind.rs index ca0b9319..b3f049c5 100644 --- a/src/wrappers/kind.rs +++ b/src/wrappers/kind.rs @@ -42,7 +42,7 @@ pub enum Kind { UInt4, UInt5, UInt6, - UInt7 + UInt7, } impl Kind { diff --git a/src/wrappers/tensor.rs b/src/wrappers/tensor.rs index d64bc2e8..a868cbc5 100644 --- a/src/wrappers/tensor.rs +++ b/src/wrappers/tensor.rs @@ -763,14 +763,14 @@ impl Tensor { } /// Returns a tensor with pinned memory. - /// + /// /// Pinned memory allows for faster data transfer between CPU and GPU. pub fn pin_memory(&self) -> Tensor { self.f_internal_pin_memory(None::<Device>).unwrap() } /// Returns true if this tensor resides in pinned memory. - /// + /// /// Pinned memory allows for faster data transfer between CPU and GPU. 
pub fn is_pinned(&self) -> bool { self.f_is_pinned(None::<Device>).unwrap() diff --git a/tests/device_tests.rs b/tests/device_tests.rs index 213b70ea..974a5476 100644 --- a/tests/device_tests.rs +++ b/tests/device_tests.rs @@ -1,7 +1,16 @@ -use tch::{Device, Tensor}; +use tch::{Cuda, Device, Tensor}; #[test] fn tensor_device() { let t = Tensor::from_slice(&[3, 1, 4]); assert_eq!(t.device(), Device::Cpu) } + +#[test] +fn cuda_device_capability() { + if Cuda::is_available() { + let (major, minor) = Cuda::get_device_capability(0).unwrap(); + assert!(major > 0, "expected a positive major compute capability, got {major}"); + assert!(minor >= 0, "expected a non-negative minor compute capability, got {minor}"); + } +} diff --git a/torch-sys/build.rs b/torch-sys/build.rs index f9719f25..0c70ac18 100644 --- a/torch-sys/build.rs +++ b/torch-sys/build.rs @@ -146,10 +146,7 @@ fn extract<P: AsRef<Path>>(filename: P, outpath: P) -> anyhow::Result<()> { // This is if we're unzipping a python wheel. if outpath.as_ref().join("torch").exists() { - fs::rename( - outpath.as_ref().join("torch"), - outpath.as_ref().join("libtorch"), - )?; + fs::rename(outpath.as_ref().join("torch"), outpath.as_ref().join("libtorch"))?; } Ok(()) } @@ -179,10 +176,7 @@ fn version_check(version: &str) -> Result<()> { impl SystemInfo { fn new() -> Result<Self> { let os = match env::var("CARGO_CFG_TARGET_OS").expect("Unable to get TARGET_OS").as_str() { "linux" => Os::Linux, "windows" => Os::Windows, "macos" => Os::Macos, @@ -276,10 +270,7 @@ impl SystemInfo { libtorch_include_dirs.push(includes.join("include")); libtorch_include_dirs.push(includes.join("include/torch/csrc/api/include")); libtorch_lib_dir = Some(lib.join("lib")); - ( - env_var_rerun("LIBTORCH_CXX11_ABI").unwrap_or_else(|_| "1".to_owned()), - None, - ) + (env_var_rerun("LIBTORCH_CXX11_ABI").unwrap_or_else(|_| "1".to_owned()), None) }; if let Ok(cuda_root) = 
env_var_rerun("CUDA_ROOT") { libtorch_include_dirs.push(PathBuf::from(cuda_root).join("include")) @@ -302,9 +293,7 @@ impl SystemInfo { fn check_system_location(os: Os) -> Option<PathBuf> { match os { - Os::Linux => Path::new("/usr/lib/libtorch.so") - .exists() - .then(|| PathBuf::from("/usr")), + Os::Linux => Path::new("/usr/lib/libtorch.so").exists().then(|| PathBuf::from("/usr")), _ => None, } } @@ -417,11 +406,8 @@ impl SystemInfo { println!("cargo:rerun-if-changed=libtch/stb_image_write.h"); println!("cargo:rerun-if-changed=libtch/stb_image_resize.h"); println!("cargo:rerun-if-changed=libtch/stb_image.h"); - let mut c_files = vec![ - "libtch/torch_api.cpp", - "libtch/torch_api_generated.cpp", - cuda_dependency, - ]; + let mut c_files = + vec!["libtch/torch_api.cpp", "libtch/torch_api_generated.cpp", cuda_dependency]; if cfg!(feature = "python-extension") { c_files.push("libtch/torch_python.cpp") } @@ -442,6 +428,9 @@ impl SystemInfo { .flag("-std=c++17") .flag(format!("-D_GLIBCXX_USE_CXX11_ABI={}", self.cxx11_abi)) .flag("-DGLOG_USE_GLOG_EXPORT"); + if use_cuda { + builder.define("USE_CUDA", None); + } if cfg!(feature = "nccl") { builder.flag("-DUSE_C10D_NCCL"); } @@ -459,6 +448,9 @@ impl SystemInfo { .includes(&self.libtorch_include_dirs) .flag("/std:c++17") .flag("/p:DefineConstants=GLOG_USE_GLOG_EXPORT"); + if use_cuda { + builder.define("USE_CUDA", None); + } if cfg!(feature = "nccl") { builder.flag("/p:DefineConstants=USE_C10D_NCCL"); } @@ -531,10 +523,7 @@ fn main() -> anyhow::Result<()> { system_info.link("torch_python"); system_info.link(&format!( "python{}", - system_info - .python_version - .as_ref() - .expect("python version is set") + system_info.python_version.as_ref().expect("python version is set") )); } if system_info.link_type == LinkType::Static { diff --git a/torch-sys/libtch/torch_api.cpp b/torch-sys/libtch/torch_api.cpp index 31e6c963..174c93b3 100644 --- a/torch-sys/libtch/torch_api.cpp +++ b/torch-sys/libtch/torch_api.cpp @@ -7,6 +7,9 @@ #include 
#include #include +#ifdef USE_CUDA +#include <ATen/cuda/CUDAContext.h> +#endif #include #include #include @@ -1022,6 +1025,20 @@ void atc_set_benchmark_cudnn(int b) { ) } +void atc_cuda_get_device_capability(int device_index, int *major, int *minor) { +#if defined(USE_CUDA) + PROTECT( + auto props = at::cuda::getDeviceProperties(device_index); + *major = props->major; + *minor = props->minor; + ) +#else + *major = -1; + *minor = -1; + torch_last_err = strdup("CUDA is not available in this build"); +#endif +} + bool at_context_has_openmp() { PROTECT ( return at::globalContext().hasOpenMP(); diff --git a/torch-sys/libtch/torch_api.h b/torch-sys/libtch/torch_api.h index f019fcff..68687ee4 100644 --- a/torch-sys/libtch/torch_api.h +++ b/torch-sys/libtch/torch_api.h @@ -197,6 +197,9 @@ void atc_manual_seed_all(uint64_t seed); /// Waits for all kernels in all streams on a CUDA device to complete. void atc_synchronize(int64_t device_index); +/// Retrieves the compute capability of a CUDA device. +void atc_cuda_get_device_capability(int device_index, int *major, int *minor); + int atc_user_enabled_cudnn(); void atc_set_user_enabled_cudnn(int b); diff --git a/torch-sys/src/cuda.rs b/torch-sys/src/cuda.rs index 7ea4bbef..1da3231e 100644 --- a/torch-sys/src/cuda.rs +++ b/torch-sys/src/cuda.rs @@ -27,4 +27,11 @@ extern "C" { /// Sets CUDNN benchmark mode. pub fn atc_set_benchmark_cudnn(b: c_int); + + /// Retrieves the compute capability of a CUDA device. + pub fn atc_cuda_get_device_capability( + device_index: c_int, + major: *mut c_int, + minor: *mut c_int, + ); }