Skip to content

Commit 2a2f8d8

Browse files
committed
[NV TRT RTX EP]: add support for OrtHardwareDevice on Linux
1 parent 970a9a0 commit 2a2f8d8

File tree

1 file changed

+21
-0
lines changed

1 file changed

+21
-0
lines changed

onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include "core/providers/shared_library/provider_api.h"
66
#include "nv_provider_factory.h"
77
#include <atomic>
8+
#include <string>
89
#include "nv_execution_provider.h"
910
#include "nv_provider_factory_creator.h"
1011
#include "nv_data_transfer.h"
@@ -575,6 +576,7 @@ struct NvTensorRtRtxEpFactory : OrtEpFactory {
575576
* @return True if the device is a supported NVIDIA GPU, false otherwise.
576577
*/
577578
bool IsOrtHardwareDeviceSupported(const OrtHardwareDevice& device) {
579+
#if _WIN32
578580
const auto& metadata_entries = device.metadata.Entries();
579581
const auto it = metadata_entries.find("LUID");
580582
if (it == metadata_entries.end()) {
@@ -616,6 +618,25 @@ struct NvTensorRtRtxEpFactory : OrtEpFactory {
616618
}
617619

618620
return false;
621+
#else
622+
const auto& metadata_entries = device.metadata.Entries();
623+
const auto it = metadata_entries.find("bus_id");
624+
if (it == metadata_entries.end()) {
625+
return false;
626+
}
627+
auto& target_id = it->second;
628+
int cuda_device_idx = 0;
629+
if (cudaDeviceGetByPCIBusId(&cuda_device_idx, target_id.c_str()) != cudaSuccess) {
630+
return false;
631+
}
632+
633+
cudaDeviceProp prop;
634+
if (cudaGetDeviceProperties(&prop, cuda_device_idx) != cudaSuccess) {
635+
return false;
636+
}
637+
// Ampere architecture or newer is required.
638+
return prop.major >= 8;
639+
#endif
619640
}
620641

621642
// Creates and returns OrtEpDevice instances for all OrtHardwareDevices that this factory supports.

0 commit comments

Comments
 (0)