diff --git a/backends/vulkan/runtime/vk_api/Adapter.cpp b/backends/vulkan/runtime/vk_api/Adapter.cpp index 8eae5ff35e4..82482a5b7c4 100644 --- a/backends/vulkan/runtime/vk_api/Adapter.cpp +++ b/backends/vulkan/runtime/vk_api/Adapter.cpp @@ -375,6 +375,10 @@ void Adapter::submit_cmd( VK_CHECK(vkQueueSubmit(device_queue.handle, 1u, &submit_info, fence)); } +void Adapter::override_device_name(const std::string& new_name) { + physical_device_.override_device_name(new_name); +} + std::string Adapter::stringize() const { std::stringstream ss; diff --git a/backends/vulkan/runtime/vk_api/Adapter.h b/backends/vulkan/runtime/vk_api/Adapter.h index 89beb5c3a5c..3c503deab70 100644 --- a/backends/vulkan/runtime/vk_api/Adapter.h +++ b/backends/vulkan/runtime/vk_api/Adapter.h @@ -306,6 +306,8 @@ class Adapter final { VkSemaphore wait_semaphore = VK_NULL_HANDLE, VkSemaphore signal_semaphore = VK_NULL_HANDLE); + void override_device_name(const std::string& new_name); + std::string stringize() const; friend std::ostream& operator<<(std::ostream&, const Adapter&); }; diff --git a/backends/vulkan/runtime/vk_api/Device.cpp b/backends/vulkan/runtime/vk_api/Device.cpp index dbe8a73651c..cb6a54dc489 100644 --- a/backends/vulkan/runtime/vk_api/Device.cpp +++ b/backends/vulkan/runtime/vk_api/Device.cpp @@ -21,6 +21,25 @@ namespace vkcompute { namespace vkapi { +namespace { + +DeviceType determine_device_type(const std::string& device_name) { + if (device_name.find("adreno") != std::string::npos) { + return DeviceType::ADRENO; + } else if (device_name.find("swiftshader") != std::string::npos) { + return DeviceType::SWIFTSHADER; + } else if (device_name.find("nvidia") != std::string::npos) { + return DeviceType::NVIDIA; + } else if ( + device_name.find("mali") != std::string::npos || + device_name.find("immortalis") != std::string::npos) { + return DeviceType::MALI; + } + return DeviceType::UNKNOWN; +} + +} // namespace + PhysicalDevice::PhysicalDevice( VkInstance instance_handle, VkPhysicalDevice physical_device_handle) @@ -126,17 +145,7 @@ PhysicalDevice::PhysicalDevice( device_name.begin(), [](unsigned char c) { return std::tolower(c); }); - if (device_name.find("adreno") != std::string::npos) { - device_type = DeviceType::ADRENO; - } else if (device_name.find("swiftshader") != std::string::npos) { - device_type = DeviceType::SWIFTSHADER; - } else if (device_name.find("nvidia") != std::string::npos) { - device_type = DeviceType::NVIDIA; - } else if ( - device_name.find("mali") != std::string::npos || - device_name.find("immortalis") != std::string::npos) { - device_type = DeviceType::MALI; - } + device_type = determine_device_type(device_name); } void PhysicalDevice::query_extensions_vk_1_0() { @@ -292,6 +301,17 @@ void PhysicalDevice::query_extensions_vk_1_1() { vkGetPhysicalDeviceProperties2(handle, &properties2); } +void PhysicalDevice::override_device_name(const std::string& new_name) { + device_name = new_name; + std::transform( + device_name.begin(), + device_name.end(), + device_name.begin(), + [](unsigned char c) { return std::tolower(c); }); + + device_type = determine_device_type(device_name); +} + // // DeviceHandle // diff --git a/backends/vulkan/runtime/vk_api/Device.h b/backends/vulkan/runtime/vk_api/Device.h index ac5e381e46a..9fa413b2457 100644 --- a/backends/vulkan/runtime/vk_api/Device.h +++ b/backends/vulkan/runtime/vk_api/Device.h @@ -84,6 +84,9 @@ struct PhysicalDevice final { private: void query_extensions_vk_1_0(); void query_extensions_vk_1_1(); + + public: + void override_device_name(const std::string& new_name); }; struct DeviceHandle final {