Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/nn/linear.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@ pub struct LinearConfig {

impl Default for LinearConfig {
    /// Default linear-layer configuration: Kaiming-uniform weight init,
    /// bias enabled with no explicit bias initializer, and no sharding.
    fn default() -> Self {
        let ws_init = super::init::DEFAULT_KAIMING_UNIFORM;
        Self { ws_init, bs_init: None, bias: true, shard: None }
    }
}

Expand Down
6 changes: 5 additions & 1 deletion src/nn/optimizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,11 @@ impl Optimizer {
/// Returns all the trainable variables and their sharding for this optimizer.
pub fn trainable_variables_with_sharding(&self) -> Vec<(Tensor, Option<Shard>)> {
    // Hold the variables lock only for the duration of the copy-out.
    let variables = self.variables.lock().unwrap();
    let mut result = Vec::with_capacity(variables.trainable_variables.len());
    for var in variables.trainable_variables.iter() {
        // shallow_clone copies the tensor handle; shard metadata is cloned as-is.
        result.push((var.0.tensor.shallow_clone(), var.1.clone()));
    }
    result
}

/// Sets the optimizer weight decay.
Expand Down
12 changes: 8 additions & 4 deletions src/tensor/display.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,11 @@ impl BasicKind {
/// Returns true only for the floating-point basic kind.
fn _is_floating_point(&self) -> bool {
    // Exhaustive match (no `_` arm) so adding a new BasicKind variant
    // forces an explicit decision here at compile time.
    match self {
        BasicKind::Bool
        | BasicKind::Int
        | BasicKind::Complex
        | BasicKind::Bits
        | BasicKind::Packed => false,
        BasicKind::Float => true,
    }
}
}
Expand Down Expand Up @@ -459,9 +463,9 @@ impl std::fmt::Display for Tensor {
tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
writeln!(f)?;
}
BasicKind::Complex => {},
BasicKind::Bits => {},
BasicKind::Packed => {},
BasicKind::Complex => {}
BasicKind::Bits => {}
BasicKind::Packed => {}
};
let kind = match self.f_kind() {
Ok(kind) => format!("{kind:?}"),
Expand Down
3 changes: 1 addition & 2 deletions src/tensor/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@ mod ops;
mod safetensors;

pub use super::wrappers::tensor::{
autocast, no_grad, no_grad_guard, with_grad, NoGradGuard, Reduction,
Tensor,
autocast, no_grad, no_grad_guard, with_grad, NoGradGuard, Reduction, Tensor,
};
pub use index::{IndexOp, NewAxis, TensorIndexer};

Expand Down
12 changes: 12 additions & 0 deletions src/wrappers/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,18 @@ impl Cuda {
/// Enables or disables cuDNN benchmark mode (named after the underlying
/// `atc_set_benchmark_cudnn` C call — presumably toggles cuDNN's algorithm
/// auto-tuner; confirm against the libtorch docs).
pub fn cudnn_set_benchmark(b: bool) {
// `i32::from(b)` maps false -> 0 and true -> 1 for the C API.
unsafe_torch!(torch_sys::cuda::atc_set_benchmark_cudnn(i32::from(b)))
}

/// Returns the compute capability of a CUDA device as a `(major, minor)` pair.
pub fn get_device_capability(device_index: usize) -> Result<(i32, i32), crate::TchError> {
    // Out-parameters filled in by the C side; zero-initialized defensively.
    let (mut major, mut minor): (libc::c_int, libc::c_int) = (0, 0);
    unsafe_torch_err!(torch_sys::cuda::atc_cuda_get_device_capability(
        device_index as libc::c_int,
        &mut major,
        &mut minor,
    ));
    Ok((major as i32, minor as i32))
}
}

impl Device {
Expand Down
2 changes: 1 addition & 1 deletion src/wrappers/kind.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ pub enum Kind {
UInt4,
UInt5,
UInt6,
UInt7
UInt7,
}

impl Kind {
Expand Down
4 changes: 2 additions & 2 deletions src/wrappers/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -763,14 +763,14 @@ impl Tensor {
}

/// Returns a tensor with pinned memory.
///
/// Pinned memory allows for faster data transfer between CPU and GPU.
///
/// # Panics
///
/// Panics if the underlying pin-memory call returns an error.
pub fn pin_memory(&self) -> Tensor {
// `None::<Device>` leaves the target device unspecified — presumably the
// default CUDA device; confirm against `f_internal_pin_memory`.
self.f_internal_pin_memory(None::<Device>).unwrap()
}

/// Returns true if this tensor resides in pinned memory.
///
///
/// Pinned memory allows for faster data transfer between CPU and GPU.
pub fn is_pinned(&self) -> bool {
self.f_is_pinned(None::<Device>).unwrap()
Expand Down
11 changes: 10 additions & 1 deletion tests/device_tests.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
use tch::{Device, Tensor};
use tch::{Cuda, Device, Tensor};

#[test]
fn tensor_device() {
    // A tensor built from a plain slice should live on the CPU.
    let tensor = Tensor::from_slice(&[3, 1, 4]);
    assert_eq!(tensor.device(), Device::Cpu);
}

#[test]
fn cuda_device_capability() {
    // Guard clause: nothing to check on machines without a CUDA device.
    if !Cuda::is_available() {
        return;
    }
    let (major, minor) = Cuda::get_device_capability(0).unwrap();
    assert!(major > 0, "expected a positive major compute capability, got {major}");
    assert!(minor >= 0, "expected a non-negative minor compute capability, got {minor}");
}
37 changes: 13 additions & 24 deletions torch-sys/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,10 +146,7 @@ fn extract<P: AsRef<Path>>(filename: P, outpath: P) -> anyhow::Result<()> {

// This is if we're unzipping a python wheel.
if outpath.as_ref().join("torch").exists() {
fs::rename(
outpath.as_ref().join("torch"),
outpath.as_ref().join("libtorch"),
)?;
fs::rename(outpath.as_ref().join("torch"), outpath.as_ref().join("libtorch"))?;
}
Ok(())
}
Expand Down Expand Up @@ -179,10 +176,7 @@ fn version_check(version: &str) -> Result<()> {

impl SystemInfo {
fn new() -> Result<Self> {
let os = match env::var("CARGO_CFG_TARGET_OS")
.expect("Unable to get TARGET_OS")
.as_str()
{
let os = match env::var("CARGO_CFG_TARGET_OS").expect("Unable to get TARGET_OS").as_str() {
"linux" => Os::Linux,
"windows" => Os::Windows,
"macos" => Os::Macos,
Expand Down Expand Up @@ -276,10 +270,7 @@ impl SystemInfo {
libtorch_include_dirs.push(includes.join("include"));
libtorch_include_dirs.push(includes.join("include/torch/csrc/api/include"));
libtorch_lib_dir = Some(lib.join("lib"));
(
env_var_rerun("LIBTORCH_CXX11_ABI").unwrap_or_else(|_| "1".to_owned()),
None,
)
(env_var_rerun("LIBTORCH_CXX11_ABI").unwrap_or_else(|_| "1".to_owned()), None)
};
if let Ok(cuda_root) = env_var_rerun("CUDA_ROOT") {
libtorch_include_dirs.push(PathBuf::from(cuda_root).join("include"))
Expand All @@ -302,9 +293,7 @@ impl SystemInfo {

fn check_system_location(os: Os) -> Option<PathBuf> {
match os {
Os::Linux => Path::new("/usr/lib/libtorch.so")
.exists()
.then(|| PathBuf::from("/usr")),
Os::Linux => Path::new("/usr/lib/libtorch.so").exists().then(|| PathBuf::from("/usr")),
_ => None,
}
}
Expand Down Expand Up @@ -417,11 +406,8 @@ impl SystemInfo {
println!("cargo:rerun-if-changed=libtch/stb_image_write.h");
println!("cargo:rerun-if-changed=libtch/stb_image_resize.h");
println!("cargo:rerun-if-changed=libtch/stb_image.h");
let mut c_files = vec![
"libtch/torch_api.cpp",
"libtch/torch_api_generated.cpp",
cuda_dependency,
];
let mut c_files =
vec!["libtch/torch_api.cpp", "libtch/torch_api_generated.cpp", cuda_dependency];
if cfg!(feature = "python-extension") {
c_files.push("libtch/torch_python.cpp")
}
Expand All @@ -442,6 +428,9 @@ impl SystemInfo {
.flag("-std=c++17")
.flag(format!("-D_GLIBCXX_USE_CXX11_ABI={}", self.cxx11_abi))
.flag("-DGLOG_USE_GLOG_EXPORT");
if use_cuda {
builder.define("USE_CUDA", None);
}
if cfg!(feature = "nccl") {
builder.flag("-DUSE_C10D_NCCL");
}
Expand All @@ -459,6 +448,9 @@ impl SystemInfo {
.includes(&self.libtorch_include_dirs)
.flag("/std:c++17")
.flag("/p:DefineConstants=GLOG_USE_GLOG_EXPORT");
if use_cuda {
builder.define("USE_CUDA", None);
}
if cfg!(feature = "nccl") {
builder.flag("/p:DefineConstants=USE_C10D_NCCL");
}
Expand Down Expand Up @@ -531,10 +523,7 @@ fn main() -> anyhow::Result<()> {
system_info.link("torch_python");
system_info.link(&format!(
"python{}",
system_info
.python_version
.as_ref()
.expect("python version is set")
system_info.python_version.as_ref().expect("python version is set")
));
}
if system_info.link_type == LinkType::Static {
Expand Down
17 changes: 17 additions & 0 deletions torch-sys/libtch/torch_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
#include<torch/csrc/jit/runtime/graph_executor.h>
#include<torch/torch.h>
#include<ATen/autocast_mode.h>
#ifdef USE_CUDA
#include<ATen/cuda/CUDAContext.h>
#endif
#include<torch/script.h>
#include<torch/csrc/jit/passes/tensorexpr_fuser.h>
#include<torch/csrc/jit/codegen/cuda/interface.h>
Expand Down Expand Up @@ -1022,6 +1025,20 @@ void atc_set_benchmark_cudnn(int b) {
)
}

// Retrieves the compute capability of a CUDA device.
// Writes the (major, minor) version pair through the out-parameters, which
// must be valid writable pointers. On CPU-only builds both are set to -1 and
// torch_last_err records the failure instead.
void atc_cuda_get_device_capability(int device_index, int *major, int *minor) {
#if defined(USE_CUDA)
  PROTECT(
    auto props = at::cuda::getDeviceProperties(device_index);
    *major = props->major;
    *minor = props->minor;
  )
#else
  // Fix: silence the unused-parameter warning emitted on CPU-only builds.
  (void)device_index;
  *major = -1;
  *minor = -1;
  torch_last_err = strdup("CUDA is not available in this build");
#endif
}

bool at_context_has_openmp() {
PROTECT (
return at::globalContext().hasOpenMP();
Expand Down
3 changes: 3 additions & 0 deletions torch-sys/libtch/torch_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,9 @@ void atc_manual_seed_all(uint64_t seed);
/// Waits for all kernels in all streams on a CUDA device to complete.
void atc_synchronize(int64_t device_index);

/// Retrieves the compute capability of a CUDA device.
void atc_cuda_get_device_capability(int device_index, int *major, int *minor);


int atc_user_enabled_cudnn();
void atc_set_user_enabled_cudnn(int b);
Expand Down
7 changes: 7 additions & 0 deletions torch-sys/src/cuda.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,11 @@ extern "C" {

/// Sets CUDNN benchmark mode.
pub fn atc_set_benchmark_cudnn(b: c_int);

/// Retrieves the compute capability of a CUDA device.
///
/// `major` and `minor` must be valid, writable pointers; the C side writes
/// the version pair through them. NOTE(review): on CPU-only builds the C
/// implementation appears to report an error via its error slot instead —
/// confirm against torch_api.cpp before relying on the out-values.
pub fn atc_cuda_get_device_capability(
device_index: c_int,
major: *mut c_int,
minor: *mut c_int,
);
}