From c8b26cf8e9540cfaba9b42308d94d2c9bb7fe8bd Mon Sep 17 00:00:00 2001
From: Pantelis Sopasakis
Date: Fri, 28 Mar 2025 22:53:45 +0000
Subject: [PATCH] streams in kernels
---
CHANGELOG.md | 10 ++++++++++
include/tensor.cuh | 13 +++++++++++--
2 files changed, 21 insertions(+), 2 deletions(-)
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5d853fe..bb276dd 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,16 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
+
+## v1.9.0 - 28-03-2025
+
+### Fixed
+
+- Using streams in kernels
+
+
diff --git a/include/tensor.cuh b/include/tensor.cuh
index b73e280..8ba6f8b 100644
--- a/include/tensor.cuh
+++ b/include/tensor.cuh
@@ -197,6 +197,13 @@ public:
*/
cusolverDnHandle_t &cuSolverHandle(size_t idx = 0) { return m_cusolverHandles[idx]; }
+ /**
+ *
+ * @param idx index of stream
+ * @return stream
+ */
+ cudaStream_t &stream(size_t idx = 0) { return m_cublasStreams[idx]; }
+
/**
* Preferred method for CUDA memory allocation; it allocated memory on the device
* and counts the allocated bytes (you can then call #totalAllocatedBytes()).
@@ -1602,7 +1609,8 @@ public:
for (size_t i = 0; i < m_rank->numMats(); i++) {
DTensor Si(*m_S, 2, i, i);
DTensor rankI(*m_rank, 2, i, i);
- k_countNonzeroSingularValues<<>>(Si.raw(), numElS,
+ cudaStream_t s = Session::getInstance().stream(m_tensor->streamIdx());
+ k_countNonzeroSingularValues<<>>(Si.raw(), numElS,
rankI.raw(), epsilon);
}
return *m_rank;
@@ -2301,7 +2309,8 @@ inline void GivensAnnihilator::annihilate(size_t i, size_t k, size_t j) {
T *matData = m_matrix->raw();
/* Call kernel to determine 1/sqrt(Ai^2 + Ak^2) */
- k_givensAnnihilateRHypot<<<1, 1>>>(m_matrix->raw(), aux, i, k, j, nR);
+ cudaStream_t s = Session::getInstance().stream(m_matrix->streamIdx());
+ k_givensAnnihilateRHypot<<<1, 1, 0, s>>>(m_matrix->raw(), aux, i, k, j, nR);
/* Apply Givens rotation */
m_matrix->applyLeftGivensRotation(i, k, aux + 1, aux + 2);