diff --git a/.ci/Brewfile b/.ci/Brewfile index fc84fbbcde6f..43559b823b16 100644 --- a/.ci/Brewfile +++ b/.ci/Brewfile @@ -36,6 +36,7 @@ brew 'jpeg-turbo' brew 'jpeg-xl' brew 'json-glib' brew 'lensfun' +brew 'libarchive' brew 'libavif' brew 'libheif' brew 'libraw' diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2f8e7151d854..6cf469ff622b 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -114,6 +114,7 @@ jobs: git \ gdb \ intltool \ + libarchive-dev \ libatk1.0-dev \ libavif-dev \ libcairo2-dev \ @@ -245,6 +246,7 @@ jobs: iso-codes:p lcms2:p lensfun:p + libarchive:p libavif:p libgphoto2:p libheif:p diff --git a/DefineOptions.cmake b/DefineOptions.cmake index 729faed12ecb..4ed0bd34437e 100644 --- a/DefineOptions.cmake +++ b/DefineOptions.cmake @@ -26,6 +26,7 @@ option(USE_LIBRAW "Enable LibRaw support" ON) option(DONT_USE_INTERNAL_LIBRAW "If possible, use system instead of intree copy of LibRaw" OFF) option(BUILD_CMSTEST "Build a test program to check your system's color management setup" ON) option(USE_OPENEXR "Enable OpenEXR support" ON) +option(BUILD_AI "Enable AI support" OFF) option(BUILD_PRINT "Build the print module" ON) option(BUILD_RS_IDENTIFY "Build the darktable-rs-identify debug aid" ON) option(BUILD_SSE2_CODEPATHS "(EXPERIMENTAL OPTION, DO NOT DISABLE) Building SSE2-optimized codepaths" ON) diff --git a/README.md b/README.md index 8b2638e92a32..580d311159eb 100644 --- a/README.md +++ b/README.md @@ -194,6 +194,7 @@ Optional dependencies (minimum version): * libgphoto2 2.5 *(for camera tethering)* * Imath 3.1.0 *(for 16-bit "half" float TIFF export and faster import)* * libavif 0.9.3 *(for AVIF import & export)* +* libarchive 3.8.5 *(for AI models download)* * libheif 1.13.0 *(for HEIF/HEIC/HIF import; also for AVIF import if no libavif)* * libjxl 0.7.0 *(for JPEG XL import & export)* * WebP 0.3.0 *(for WebP import & export)* diff --git a/cmake/modules/FindONNXRuntime.cmake b/cmake/modules/FindONNXRuntime.cmake 
new file mode 100644 index 000000000000..7c48742609da --- /dev/null +++ b/cmake/modules/FindONNXRuntime.cmake @@ -0,0 +1,461 @@ +# FindONNXRuntime.cmake +# +# Find ONNX Runtime pre-built binaries, downloading them automatically if not +# found. Searches two locations: +# +# 1. Source tree: ${CMAKE_SOURCE_DIR}/src/external/onnxruntime/ +# (manually pre-installed) +# 2. Build tree: ${CMAKE_BINARY_DIR}/_deps/onnxruntime/ +# (auto-download destination) +# +# After this module completes the imported target onnxruntime::onnxruntime is +# available for linking, along with the standard variables: +# +# ONNXRuntime_FOUND - TRUE if found +# ONNXRuntime_INCLUDE_DIRS - include directories +# ONNXRuntime_LIBRARIES - libraries to link +# ONNXRuntime_LIB_DIR - directory containing the shared library +# +# Cache variables that influence behaviour: +# +# ONNXRUNTIME_VERSION - version to download (default 1.23.2) +# ONNXRUNTIME_DIRECTML_VERSION - DirectML NuGet version for Windows (default 1.23.0) +# ONNXRUNTIME_OFFLINE - if TRUE, never attempt a download (default OFF) + +# Skip if already found (avoid re-entry issues with cached variables) +if(ONNXRuntime_FOUND) + return() +endif() + +# --------------------------------------------------------------------------- +# Configuration +# --------------------------------------------------------------------------- +set(ONNXRUNTIME_VERSION "1.23.2" CACHE STRING "ONNX Runtime version to download") +set(ONNXRUNTIME_DIRECTML_VERSION "1.23.0" CACHE STRING "DirectML NuGet package version for Windows") +option(ONNXRUNTIME_OFFLINE "Disable automatic download of ONNX Runtime" OFF) + +set(_ORT_SRC_ROOT "${CMAKE_SOURCE_DIR}/src/external/onnxruntime") +set(_ORT_BUILD_ROOT "${CMAKE_BINARY_DIR}/_deps/onnxruntime") + +# Invalidate cached results if the files no longer exist on disk +# (prevents stale paths from a previous build or deleted installation) +if(_ORT_HEADER AND NOT EXISTS "${_ORT_HEADER}/onnxruntime_c_api.h") + unset(_ORT_HEADER CACHE) +endif() 
+if(_ORT_LIBRARY AND NOT EXISTS "${_ORT_LIBRARY}") + unset(_ORT_LIBRARY CACHE) +endif() + +# --------------------------------------------------------------------------- +# Search for existing installation (source tree first, then build tree) +# --------------------------------------------------------------------------- +macro(_ort_find_at ROOT) + # Standard layout (GitHub releases): include/ + lib/ + find_path(_ORT_HEADER + NAMES onnxruntime_c_api.h + PATHS "${ROOT}/include" + NO_DEFAULT_PATH + ) + find_library(_ORT_LIBRARY + NAMES onnxruntime + PATHS "${ROOT}/lib" + NO_DEFAULT_PATH + ) + # NuGet layout: build/native/include/ + runtimes/<rid>/native/ + if(NOT _ORT_HEADER) + find_path(_ORT_HEADER + NAMES onnxruntime_c_api.h + PATHS "${ROOT}/build/native/include" + NO_DEFAULT_PATH + ) + endif() + if(NOT _ORT_LIBRARY AND WIN32) + find_library(_ORT_LIBRARY + NAMES onnxruntime + PATHS "${ROOT}/runtimes/win-x64/native" "${ROOT}/runtimes/win-arm64/native" + NO_DEFAULT_PATH + ) + endif() + if(_ORT_HEADER AND _ORT_LIBRARY) + set(_ORT_ROOT "${ROOT}") + endif() +endmacro() + +# Try source tree (manually pre-installed) +_ort_find_at("${_ORT_SRC_ROOT}") + +# Try build tree (populated by prior auto-download) +if(NOT _ORT_HEADER OR NOT _ORT_LIBRARY) + unset(_ORT_HEADER CACHE) + unset(_ORT_LIBRARY CACHE) + _ort_find_at("${_ORT_BUILD_ROOT}") +endif() + +# Try system-installed package (e.g.
libonnxruntime-dev on Linux) +if(NOT _ORT_HEADER OR NOT _ORT_LIBRARY) + unset(_ORT_HEADER CACHE) + unset(_ORT_LIBRARY CACHE) + find_path(_ORT_HEADER + NAMES onnxruntime_c_api.h + PATH_SUFFIXES onnxruntime + ) + find_library(_ORT_LIBRARY NAMES onnxruntime) + if(_ORT_HEADER AND _ORT_LIBRARY) + get_filename_component(_ORT_INC_DIR "${_ORT_HEADER}" DIRECTORY) + get_filename_component(_ORT_ROOT "${_ORT_INC_DIR}" DIRECTORY) + set(_ORT_SYSTEM_PACKAGE TRUE) + message(STATUS "Found system ONNX Runtime: ${_ORT_LIBRARY}") + endif() +endif() + +# --------------------------------------------------------------------------- +# Auto-download if not found +# --------------------------------------------------------------------------- +if(NOT _ORT_HEADER OR NOT _ORT_LIBRARY) + if(ONNXRUNTIME_OFFLINE) + message(STATUS "ONNX Runtime not found and ONNXRUNTIME_OFFLINE is ON.") + else() + # -- Determine package name for current platform -- + set(_ORT_VER "${ONNXRUNTIME_VERSION}") + + if(CMAKE_SYSTEM_NAME STREQUAL "Darwin") + if(CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") + set(_ORT_PACKAGE "onnxruntime-osx-arm64-${_ORT_VER}.tgz") + elseif(CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64") + set(_ORT_PACKAGE "onnxruntime-osx-x86_64-${_ORT_VER}.tgz") + else() + message(FATAL_ERROR "Unsupported macOS architecture: ${CMAKE_SYSTEM_PROCESSOR}") + endif() + elseif(CMAKE_SYSTEM_NAME STREQUAL "Linux") + if(CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64") + set(_ORT_PACKAGE "onnxruntime-linux-x64-${_ORT_VER}.tgz") + elseif(CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64") + set(_ORT_PACKAGE "onnxruntime-linux-aarch64-${_ORT_VER}.tgz") + else() + message(FATAL_ERROR "Unsupported Linux architecture: ${CMAKE_SYSTEM_PROCESSOR}") + endif() + elseif(CMAKE_SYSTEM_NAME STREQUAL "Windows") + # On Windows, download the DirectML NuGet package for GPU support + # (AMD/Intel/NVIDIA via DirectX 12). NuGet packages are just ZIP files. 
+ set(_ORT_VER "${ONNXRUNTIME_DIRECTML_VERSION}") + set(_ORT_PACKAGE "microsoft.ml.onnxruntime.directml.${_ORT_VER}.nupkg") + set(_ORT_NUGET TRUE) + if(CMAKE_SYSTEM_PROCESSOR MATCHES "AMD64|x86_64") + set(_ORT_NUGET_RID "win-x64") + elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "ARM64|aarch64") + set(_ORT_NUGET_RID "win-arm64") + else() + message(FATAL_ERROR "Unsupported Windows architecture: ${CMAKE_SYSTEM_PROCESSOR}") + endif() + else() + message(FATAL_ERROR "Unsupported OS for ONNX Runtime auto-download: ${CMAKE_SYSTEM_NAME}") + endif() + + if(_ORT_NUGET) + set(_ORT_URL "https://www.nuget.org/api/v2/package/Microsoft.ML.OnnxRuntime.DirectML/${_ORT_VER}") + else() + set(_ORT_URL "https://github.com/microsoft/onnxruntime/releases/download/v${_ORT_VER}/${_ORT_PACKAGE}") + endif() + set(_ORT_DOWNLOAD_DIR "${CMAKE_BINARY_DIR}/_deps") + set(_ORT_ARCHIVE "${_ORT_DOWNLOAD_DIR}/${_ORT_PACKAGE}") + + # -- Fetch integrity hash from package repository -- + # All auto-downloaded packages must be verified. The hash is fetched from the + # repository API (GitHub Releases or NuGet catalog) and checked after download. + set(_ORT_HASH "") + set(_ORT_HASH_ALGO "") + file(MAKE_DIRECTORY "${_ORT_DOWNLOAD_DIR}") + + if(_ORT_NUGET) + # NuGet: fetch SHA512 from the catalog API (base64-encoded). 
+ # Step 1: registration entry → catalog URL + set(_ORT_NUGET_ID "microsoft.ml.onnxruntime.directml") + set(_ORT_REG_JSON "${_ORT_DOWNLOAD_DIR}/nuget-reg-${_ORT_VER}.json") + set(_ORT_REG_URL "https://api.nuget.org/v3/registration5-semver1/${_ORT_NUGET_ID}/${_ORT_VER}.json") + if(NOT EXISTS "${_ORT_REG_JSON}") + message(STATUS "Fetching NuGet registration for ${_ORT_NUGET_ID} ${_ORT_VER}...") + file(DOWNLOAD "${_ORT_REG_URL}" "${_ORT_REG_JSON}" STATUS _ORT_REG_STATUS) + list(GET _ORT_REG_STATUS 0 _ORT_REG_CODE) + if(NOT _ORT_REG_CODE EQUAL 0) + file(REMOVE "${_ORT_REG_JSON}") + endif() + endif() + if(EXISTS "${_ORT_REG_JSON}") + file(READ "${_ORT_REG_JSON}" _ORT_REG_CONTENT) + string(REGEX MATCH "\"catalogEntry\" *: *\"([^\"]+)\"" _ORT_CAT_MATCH "${_ORT_REG_CONTENT}") + if(_ORT_CAT_MATCH) + set(_ORT_CAT_URL "${CMAKE_MATCH_1}") + # Step 2: catalog entry → packageHash (base64 SHA512) + set(_ORT_CAT_JSON "${_ORT_DOWNLOAD_DIR}/nuget-catalog-${_ORT_VER}.json") + if(NOT EXISTS "${_ORT_CAT_JSON}") + message(STATUS "Fetching NuGet catalog entry...") + file(DOWNLOAD "${_ORT_CAT_URL}" "${_ORT_CAT_JSON}" STATUS _ORT_CAT_STATUS) + list(GET _ORT_CAT_STATUS 0 _ORT_CAT_CODE) + if(NOT _ORT_CAT_CODE EQUAL 0) + file(REMOVE "${_ORT_CAT_JSON}") + endif() + endif() + if(EXISTS "${_ORT_CAT_JSON}") + file(READ "${_ORT_CAT_JSON}" _ORT_CAT_CONTENT) + string(REGEX MATCH "\"packageHash\" *: *\"([^\"]+)\"" _ORT_HASH_MATCH "${_ORT_CAT_CONTENT}") + if(_ORT_HASH_MATCH) + set(_ORT_B64_HASH "${CMAKE_MATCH_1}") + # Convert base64 SHA512 to hex via PowerShell (NuGet path is Windows-only) + execute_process( + COMMAND powershell -NoProfile -Command + "[BitConverter]::ToString([Convert]::FromBase64String('${_ORT_B64_HASH}')).Replace('-','').ToLower()" + OUTPUT_VARIABLE _ORT_HASH + OUTPUT_STRIP_TRAILING_WHITESPACE + RESULT_VARIABLE _ORT_PS_RESULT + ) + if(_ORT_PS_RESULT EQUAL 0 AND _ORT_HASH) + set(_ORT_HASH_ALGO "SHA512") + message(STATUS "NuGet package ${_ORT_PACKAGE} SHA512: ${_ORT_HASH}") + else() + 
set(_ORT_HASH "") + endif() + endif() + endif() + endif() + endif() + else() + # GitHub: fetch SHA256 from the Releases API (hex-encoded digest field). + set(_ORT_API_JSON "${_ORT_DOWNLOAD_DIR}/ort-release-${_ORT_VER}.json") + set(_ORT_API_URL "https://api.github.com/repos/microsoft/onnxruntime/releases/tags/v${_ORT_VER}") + if(NOT EXISTS "${_ORT_API_JSON}") + message(STATUS "Fetching ONNX Runtime v${_ORT_VER} release metadata from GitHub API...") + file(DOWNLOAD + "${_ORT_API_URL}" + "${_ORT_API_JSON}" + STATUS _ORT_API_STATUS + HTTPHEADER "Accept: application/vnd.github+json" + ) + list(GET _ORT_API_STATUS 0 _ORT_API_CODE) + if(NOT _ORT_API_CODE EQUAL 0) + file(REMOVE "${_ORT_API_JSON}") + endif() + endif() + if(EXISTS "${_ORT_API_JSON}") + file(READ "${_ORT_API_JSON}" _ORT_API_CONTENT) + # Two-step extraction: locate the package name, then find the first + # "digest" field that follows it within the same asset JSON object. + string(FIND "${_ORT_API_CONTENT}" "\"name\":\"${_ORT_PACKAGE}\"" _ORT_NAME_POS) + if(_ORT_NAME_POS EQUAL -1) + # Try with spaces around colon (GitHub API may format either way) + string(FIND "${_ORT_API_CONTENT}" "\"name\": \"${_ORT_PACKAGE}\"" _ORT_NAME_POS) + endif() + if(NOT _ORT_NAME_POS EQUAL -1) + string(SUBSTRING "${_ORT_API_CONTENT}" ${_ORT_NAME_POS} 2000 _ORT_ASSET_TAIL) + string(REGEX MATCH "\"digest\" *: *\"sha256:([a-f0-9]+)\"" _ORT_DIGEST_MATCH "${_ORT_ASSET_TAIL}") + if(_ORT_DIGEST_MATCH) + set(_ORT_HASH "${CMAKE_MATCH_1}") + set(_ORT_HASH_ALGO "SHA256") + message(STATUS "ONNX Runtime ${_ORT_PACKAGE} SHA256: ${_ORT_HASH}") + endif() + endif() + endif() + endif() + + if(NOT _ORT_HASH) + message(WARNING + "Could not obtain integrity hash for ${_ORT_PACKAGE} from package repository. " + "Refusing to download without verification. 
" + "Install ONNX Runtime manually or set ONNXRUNTIME_OFFLINE=ON.") + else() # _ORT_HASH obtained — proceed with verified download + + # -- Verify cached archive if it exists -- + if(EXISTS "${_ORT_ARCHIVE}") + file(${_ORT_HASH_ALGO} "${_ORT_ARCHIVE}" _ORT_CACHED_HASH) + if(NOT _ORT_CACHED_HASH STREQUAL "${_ORT_HASH}") + message(STATUS "Cached ONNX Runtime archive has wrong checksum, re-downloading...") + file(REMOVE "${_ORT_ARCHIVE}") + endif() + endif() + + # -- Download -- + if(NOT EXISTS "${_ORT_ARCHIVE}") + message(STATUS "Downloading ONNX Runtime ${_ORT_VER} (${_ORT_PACKAGE})...") + file(DOWNLOAD + "${_ORT_URL}" + "${_ORT_ARCHIVE}" + STATUS _ORT_DL_STATUS + SHOW_PROGRESS + EXPECTED_HASH "${_ORT_HASH_ALGO}=${_ORT_HASH}" + ) + list(GET _ORT_DL_STATUS 0 _ORT_DL_CODE) + list(GET _ORT_DL_STATUS 1 _ORT_DL_MSG) + if(NOT _ORT_DL_CODE EQUAL 0) + file(REMOVE "${_ORT_ARCHIVE}") + message(WARNING + "Failed to download ONNX Runtime from ${_ORT_URL}: ${_ORT_DL_MSG}. " + "Install ONNX Runtime manually or download from: " + "https://github.com/microsoft/onnxruntime/releases") + endif() + endif() + + # -- Extract to a temporary directory -- + set(_ORT_EXTRACT_DIR "${_ORT_DOWNLOAD_DIR}/onnxruntime-extract") + if(EXISTS "${_ORT_EXTRACT_DIR}") + file(REMOVE_RECURSE "${_ORT_EXTRACT_DIR}") + endif() + message(STATUS "Extracting ${_ORT_PACKAGE}...") + file(ARCHIVE_EXTRACT + INPUT "${_ORT_ARCHIVE}" + DESTINATION "${_ORT_EXTRACT_DIR}" + ) + + # -- Install into build tree -- + if(EXISTS "${_ORT_BUILD_ROOT}") + file(REMOVE_RECURSE "${_ORT_BUILD_ROOT}") + endif() + file(MAKE_DIRECTORY "${_ORT_BUILD_ROOT}/include") + file(MAKE_DIRECTORY "${_ORT_BUILD_ROOT}/lib") + + if(_ORT_NUGET) + # NuGet package layout: + # build/native/include/ -> headers + # runtimes//native/ -> DLLs and .lib files + file(GLOB _ORT_NUGET_HEADERS "${_ORT_EXTRACT_DIR}/build/native/include/*") + file(COPY ${_ORT_NUGET_HEADERS} DESTINATION "${_ORT_BUILD_ROOT}/include") + file(GLOB _ORT_NUGET_LIBS 
"${_ORT_EXTRACT_DIR}/runtimes/${_ORT_NUGET_RID}/native/*") + file(COPY ${_ORT_NUGET_LIBS} DESTINATION "${_ORT_BUILD_ROOT}/lib") + else() + # GitHub release layout: single top-level directory with lib/ and include/ + file(GLOB _ORT_INNER_DIRS "${_ORT_EXTRACT_DIR}/*") + list(LENGTH _ORT_INNER_DIRS _ORT_INNER_COUNT) + if(_ORT_INNER_COUNT EQUAL 1) + list(GET _ORT_INNER_DIRS 0 _ORT_INNER) + else() + set(_ORT_INNER "${_ORT_EXTRACT_DIR}") + endif() + file(COPY "${_ORT_INNER}/lib" DESTINATION "${_ORT_BUILD_ROOT}") + file(COPY "${_ORT_INNER}/include" DESTINATION "${_ORT_BUILD_ROOT}") + endif() + + # -- Cleanup extraction directory -- + file(REMOVE_RECURSE "${_ORT_EXTRACT_DIR}") + + # -- Set root and re-search -- + set(_ORT_ROOT "${_ORT_BUILD_ROOT}") + + unset(_ORT_HEADER CACHE) + unset(_ORT_LIBRARY CACHE) + + find_path(_ORT_HEADER + NAMES onnxruntime_c_api.h + PATHS "${_ORT_ROOT}/include" + NO_DEFAULT_PATH + ) + find_library(_ORT_LIBRARY + NAMES onnxruntime + PATHS "${_ORT_ROOT}/lib" + NO_DEFAULT_PATH + ) + endif() # _ORT_HASH + endif() # NOT ONNXRUNTIME_OFFLINE +endif() + +# --------------------------------------------------------------------------- +# Create onnxruntime::onnxruntime imported target +# --------------------------------------------------------------------------- +if(_ORT_HEADER AND _ORT_LIBRARY AND NOT TARGET onnxruntime::onnxruntime) + + # _ORT_HEADER is the directory where find_path() located onnxruntime_c_api.h. + # Use it directly — it's correct for all layouts (GitHub releases: include/, + # NuGet: build/native/include/, system: /usr/include/onnxruntime). 
+ set(_ORT_INCLUDE_DIR "${_ORT_HEADER}") + + # Try the shipped CMake config files first (they exist under lib/cmake/) + set(_ORT_CMAKE_DIR "${_ORT_ROOT}/lib/cmake/onnxruntime") + set(_ORT_CMAKE_TARGETS "${_ORT_CMAKE_DIR}/onnxruntimeTargets.cmake") + + # Patch the known include-path bug before loading the config + if(EXISTS "${_ORT_CMAKE_TARGETS}") + file(READ "${_ORT_CMAKE_TARGETS}" _ORT_TARGETS_CONTENT) + if(_ORT_TARGETS_CONTENT MATCHES "include/onnxruntime" + AND NOT EXISTS "${_ORT_ROOT}/include/onnxruntime") + string(REPLACE "/include/onnxruntime" "/include" + _ORT_TARGETS_CONTENT "${_ORT_TARGETS_CONTENT}") + file(WRITE "${_ORT_CMAKE_TARGETS}" "${_ORT_TARGETS_CONTENT}") + message(STATUS "Patched ONNX Runtime CMake config: include/onnxruntime -> include") + endif() + endif() + + # Patch the lib64 vs lib path issue in CMake targets + file(GLOB _ORT_CMAKE_TARGET_FILES "${_ORT_CMAKE_DIR}/onnxruntimeTargets*.cmake") + foreach(_ORT_TARGET_FILE ${_ORT_CMAKE_TARGET_FILES}) + if(EXISTS "${_ORT_TARGET_FILE}") + file(READ "${_ORT_TARGET_FILE}" _ORT_TARGET_CONTENT) + if(_ORT_TARGET_CONTENT MATCHES "/lib64/" + AND NOT EXISTS "${_ORT_ROOT}/lib64") + string(REPLACE "/lib64/" "/lib/" _ORT_TARGET_CONTENT "${_ORT_TARGET_CONTENT}") + file(WRITE "${_ORT_TARGET_FILE}" "${_ORT_TARGET_CONTENT}") + message(STATUS "Patched ONNX Runtime CMake config: lib64 -> lib") + endif() + endif() + endforeach() + + find_package(onnxruntime QUIET + PATHS "${_ORT_ROOT}" + NO_DEFAULT_PATH + ) + + # Fallback: create the imported target manually + if(NOT TARGET onnxruntime::onnxruntime) + message(STATUS "Creating onnxruntime::onnxruntime imported target manually") + add_library(onnxruntime::onnxruntime SHARED IMPORTED) + set_target_properties(onnxruntime::onnxruntime PROPERTIES + INTERFACE_INCLUDE_DIRECTORIES "${_ORT_INCLUDE_DIR}" + ) + if(WIN32) + # On Windows find_library finds the .lib import library; + # IMPORTED_IMPLIB = .lib, IMPORTED_LOCATION = .dll + set_target_properties(onnxruntime::onnxruntime 
PROPERTIES + IMPORTED_IMPLIB "${_ORT_LIBRARY}" + ) + # Search both standard (lib/) and NuGet (runtimes//native/) layouts + find_file(_ORT_DLL NAMES onnxruntime.dll + PATHS "${_ORT_ROOT}/lib" + "${_ORT_ROOT}/runtimes/win-x64/native" + "${_ORT_ROOT}/runtimes/win-arm64/native" + NO_DEFAULT_PATH + ) + if(_ORT_DLL) + set_target_properties(onnxruntime::onnxruntime PROPERTIES + IMPORTED_LOCATION "${_ORT_DLL}" + ) + endif() + else() + set_target_properties(onnxruntime::onnxruntime PROPERTIES + IMPORTED_LOCATION "${_ORT_LIBRARY}" + ) + if(APPLE) + get_filename_component(_ORT_LIB_NAME "${_ORT_LIBRARY}" NAME) + set_target_properties(onnxruntime::onnxruntime PROPERTIES + IMPORTED_SONAME "@rpath/${_ORT_LIB_NAME}" + ) + endif() + endif() + endif() +endif() + +# --------------------------------------------------------------------------- +# Standard find_package result handling +# --------------------------------------------------------------------------- +mark_as_advanced(_ORT_HEADER _ORT_LIBRARY) + +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args(ONNXRuntime + REQUIRED_VARS _ORT_LIBRARY _ORT_HEADER +) + +if(ONNXRuntime_FOUND) + set(ONNXRuntime_INCLUDE_DIRS "${_ORT_INCLUDE_DIR}") + set(ONNXRuntime_LIBRARIES "${_ORT_LIBRARY}") + # Derive library directory from the actual found library path + # (works for all layouts: standard lib/, NuGet runtimes//native/, system) + get_filename_component(ONNXRuntime_LIB_DIR "${_ORT_LIBRARY}" DIRECTORY) + set(ONNXRuntime_LIB_DIR "${ONNXRuntime_LIB_DIR}" CACHE PATH "ONNX Runtime library directory") + set(ONNXRuntime_SYSTEM_PACKAGE ${_ORT_SYSTEM_PACKAGE} CACHE BOOL "Using system ONNX Runtime package") + mark_as_advanced(ONNXRuntime_LIB_DIR ONNXRuntime_SYSTEM_PACKAGE) +endif() diff --git a/data/CMakeLists.txt b/data/CMakeLists.txt index fc2399eb00f0..2b1d5ff7faf2 100644 --- a/data/CMakeLists.txt +++ b/data/CMakeLists.txt @@ -165,6 +165,14 @@ endif() FILE(COPY wb_presets.json DESTINATION "${DARKTABLE_DATADIR}") install(FILES 
wb_presets.json DESTINATION ${CMAKE_INSTALL_DATAROOTDIR}/darktable COMPONENT DTApplication) +if(BUILD_AI) + # + # Install AI models registry + # + FILE(COPY ai_models.json DESTINATION "${DARKTABLE_DATADIR}") + install(FILES ai_models.json DESTINATION ${CMAKE_INSTALL_DATAROOTDIR}/darktable COMPONENT DTApplication) +endif(BUILD_AI) + # # Transform darktableconfig.xml into darktablerc # diff --git a/data/ai_models.json b/data/ai_models.json new file mode 100644 index 000000000000..5d473e312e2b --- /dev/null +++ b/data/ai_models.json @@ -0,0 +1,5 @@ +{ + "version": 1, + "models": [ + ] +} diff --git a/data/darktableconfig.xml.in b/data/darktableconfig.xml.in index b2e408363ecb..9c8a33091b88 100644 --- a/data/darktableconfig.xml.in +++ b/data/darktableconfig.xml.in @@ -367,6 +367,27 @@ <shortdescription>tags case sensitivity</shortdescription> <longdescription>tags case sensitivity. without the Sqlite ICU extension, insensitivity works only for the 26 latin letters</longdescription> </dtconfig> + <dtconfig> + <name>plugins/ai/enabled</name> + <type>bool</type> + <default>false</default> + <shortdescription>enable AI features</shortdescription> + <longdescription>enable AI-powered features such as denoising and upscaling using neural networks</longdescription> + </dtconfig> + <dtconfig> + <name>plugins/ai/provider</name> + <type>string</type> + <default>auto</default> + <shortdescription>AI execution provider</shortdescription> + <longdescription>select the hardware acceleration provider for AI inference. 'auto' will automatically detect the best available option for your system.</longdescription> + </dtconfig> + <dtconfig> + <name>plugins/ai/repository</name> + <type>string</type> + <default>andriiryzhkov/darktable-ai</default> + <shortdescription>AI models GitHub repository</shortdescription> + <longdescription>GitHub repository for downloading AI models (format: owner/repo)</longdescription> + </dtconfig> <dtconfig> <name>opencl</name> <type>bool</type> diff --git a/packaging/macosx/1_install_hb_dependencies.sh b/packaging/macosx/1_install_hb_dependencies.sh index bfa41aa5dfe1..92fa1c561854 100755 --- a/packaging/macosx/1_install_hb_dependencies.sh +++ b/packaging/macosx/1_install_hb_dependencies.sh @@ -44,6 +44,7 @@ hbDependencies="adwaita-icon-theme \ json-glib \ lensfun \ libavif \ + libarchive \ libheif \ libomp \ libraw \ diff --git a/packaging/macosx/2_build_hb_darktable_default.sh b/packaging/macosx/2_build_hb_darktable_default.sh index a30845ce1904..c70ac9252239 100755 --- a/packaging/macosx/2_build_hb_darktable_default.sh +++ b/packaging/macosx/2_build_hb_darktable_default.sh @@ -21,7 +21,7 @@ options=" \ -DUSE_GRAPHICSMAGICK=OFF \ -DUSE_IMAGEMAGICK=ON \ -DBUILD_CURVE_TOOLS=ON \ - -DBUILD_NOISE_TOOLS=ON \ + -DBUILD_NOISE_TOOLS=ON " # Check for previous attempt and clean diff --git a/packaging/windows/README.md b/packaging/windows/README.md index 78df283ef473..71a7f04c91c4 100644 --- a/packaging/windows/README.md +++ b/packaging/windows/README.md @@ -17,7 +17,7 @@ are as follows: * Install required and recommended dependencies for darktable: ```bash - pacman -S --needed mingw-w64-ucrt-x86_64-{libxslt,python-jsonschema,curl,drmingw,exiv2,gettext,gmic,graphicsmagick,gtk3,icu,imath,iso-codes,lcms2,lensfun,libavif,libgphoto2,libheif,libjpeg-turbo,libjxl,libpng,libraw,librsvg,libsecret,libtiff,libwebp,libxml2,lua,openexr,openjpeg2,osm-gps-map,portmidi,potrace,pugixml,SDL2,sqlite3,webp-pixbuf-loader,zlib} + pacman -S --needed 
mingw-w64-ucrt-x86_64-{libxslt,python-jsonschema,curl,drmingw,exiv2,gettext,gmic,graphicsmagick,gtk3,icu,imath,iso-codes,lcms2,lensfun,libavif,libarchive,libgphoto2,libheif,libjpeg-turbo,libjxl,libpng,libraw,librsvg,libsecret,libtiff,libwebp,libxml2,lua,openexr,openjpeg2,osm-gps-map,portmidi,potrace,pugixml,SDL2,sqlite3,webp-pixbuf-loader,zlib} ``` * Install the optional tool for building an installer image (currently x64 only): diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index a8a66c94ce8d..eb5a09b73644 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -1,4 +1,14 @@ add_subdirectory(external) +if(BUILD_AI) + find_package(ONNXRuntime QUIET) + if(NOT ONNXRuntime_FOUND) + message(STATUS "ONNX Runtime not found — disabling AI features (BUILD_AI=OFF)") + set(BUILD_AI OFF) + endif() +endif() +if(BUILD_AI) + add_subdirectory(ai) +endif(BUILD_AI) include(CheckCSourceCompiles) include(CheckCXXSymbolExists) @@ -443,6 +453,27 @@ foreach(lib ${OUR_LIBS} GIO GThread GModule PangoCairo Rsvg2 PNG JPEG TIFF LCMS2 add_definitions(${${lib}_DEFINITIONS}) endforeach(lib) +if(BUILD_AI) + option(BUILD_AI_DOWNLOAD "Enable downloading AI models from repository" ON) + + list(APPEND LIBS darktable_ai) + + FILE(GLOB SOURCE_FILES_AI + "common/ai_models.c" + "gui/preferences_ai.c" + ) + set(SOURCES ${SOURCES} ${SOURCE_FILES_AI}) + add_definitions("-DHAVE_AI") + if(BUILD_AI_DOWNLOAD) + add_definitions("-DHAVE_AI_DOWNLOAD") + message(STATUS "AI features: enabled (with model download)") + else() + message(STATUS "AI features: enabled (without model download)") + endif() +else(BUILD_AI) + message(STATUS "AI features: disabled") +endif(BUILD_AI) + if(PNG_VERSION_STRING VERSION_LESS 1.5) message(FATAL_ERROR "libpng version 1.5 or newer is required") endif() @@ -451,6 +482,31 @@ endif() find_package(CURL 7.56 REQUIRED) list(APPEND LIBS CURL::libcurl) +if(BUILD_AI) + # libarchive for AI model ZIP extraction + # On macOS with Homebrew, libarchive is keg-only + if(APPLE) + 
execute_process( + COMMAND brew --prefix libarchive + OUTPUT_VARIABLE HOMEBREW_LIBARCHIVE_PREFIX + OUTPUT_STRIP_TRAILING_WHITESPACE + ERROR_QUIET + RESULT_VARIABLE HOMEBREW_LIBARCHIVE_RESULT + ) + if(HOMEBREW_LIBARCHIVE_RESULT EQUAL 0 AND EXISTS "${HOMEBREW_LIBARCHIVE_PREFIX}") + list(APPEND CMAKE_PREFIX_PATH "${HOMEBREW_LIBARCHIVE_PREFIX}") + endif() + endif() + find_package(LibArchive) + if(NOT LibArchive_FOUND) + message(STATUS "libarchive not found — disabling AI features (BUILD_AI=OFF)") + set(BUILD_AI OFF) + else() + include_directories(SYSTEM ${LibArchive_INCLUDE_DIRS}) + list(APPEND LIBS ${LibArchive_LIBRARIES}) + endif() +endif(BUILD_AI) + # Require exiv2 >= 0.27.2 to make sure everything we need is available find_package(Exiv2 0.27.2 REQUIRED) include_directories(SYSTEM ${Exiv2_INCLUDE_DIRS}) @@ -977,6 +1033,15 @@ add_dependencies(lib_darktable generate_authors_h) if(APPLE) set_target_properties(lib_darktable PROPERTIES MACOSX_RPATH TRUE) endif(APPLE) +if(NOT WIN32) + # lib_darktable and ONNX Runtime are installed to the same directory + # (${CMAKE_INSTALL_LIBDIR}/darktable/) — set RPATH so the linker finds it + set_target_properties(lib_darktable PROPERTIES INSTALL_RPATH "${RPATH_ORIGIN}") +endif() +if(WIN32) + # Export all symbols on Windows so modules can access library functions + set_target_properties(lib_darktable PROPERTIES WINDOWS_EXPORT_ALL_SYMBOLS TRUE) +endif(WIN32) set_target_properties(lib_darktable PROPERTIES OUTPUT_NAME darktable) set_target_properties(lib_darktable PROPERTIES LINKER_LANGUAGE C) diff --git a/src/ai/CMakeLists.txt b/src/ai/CMakeLists.txt new file mode 100644 index 000000000000..877d9a397e5a --- /dev/null +++ b/src/ai/CMakeLists.txt @@ -0,0 +1,83 @@ +cmake_minimum_required(VERSION 3.18) + +add_library(darktable_ai STATIC + backend.h + backend_common.c + backend_onnx.c +) + +# Find ONNX Runtime (auto-downloads if not present) +find_package(ONNXRuntime REQUIRED) + +# Find Dependencies (inherited from main build if available, 
but explicit here for clarity) +find_package(PkgConfig REQUIRED) +pkg_check_modules(GLIB REQUIRED glib-2.0 gobject-2.0 gio-2.0 gmodule-2.0) +pkg_check_modules(JSON_GLIB REQUIRED json-glib-1.0) +# GTK3/RSVG needed for headers only (common/darktable.h -> utility.h -> gtk/gtk.h, librsvg/rsvg.h) +pkg_check_modules(GTK3 REQUIRED gtk+-3.0) +pkg_check_modules(RSVG2 REQUIRED librsvg-2.0) + +# Include current directory for header +target_include_directories(darktable_ai PUBLIC + $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}> + $<BUILD_INTERFACE:${CMAKE_SOURCE_DIR}/src> + $<INSTALL_INTERFACE:include> + ${GLIB_INCLUDE_DIRS} + ${JSON_GLIB_INCLUDE_DIRS} + ${GTK3_INCLUDE_DIRS} + ${RSVG2_INCLUDE_DIRS} +) + +target_link_directories(darktable_ai PUBLIC + ${GLIB_LIBRARY_DIRS} + ${JSON_GLIB_LIBRARY_DIRS} +) + +if(ONNXRuntime_SYSTEM_PACKAGE AND UNIX AND NOT APPLE) + # Ubuntu/Debian's libonnxruntime links against the system libonnx, so both + # register the same ONNX schemas at library load time, producing harmless but + # noisy "already registered" warnings on stderr. Suppress them by lazy-loading + # ORT at runtime via g_module_open, with stderr temporarily redirected, instead + # of having libonnxruntime.so as a direct ELF dependency of libdarktable.so. + # ONNXRuntime_LIBRARIES is set by FindONNXRuntime.cmake to the .so path. 
+ target_compile_definitions(darktable_ai PRIVATE + ORT_LAZY_LOAD=1 + ORT_LIBRARY_PATH="${ONNXRuntime_LIBRARIES}") + target_include_directories(darktable_ai PRIVATE ${ONNXRuntime_INCLUDE_DIRS}) +else() + target_link_libraries(darktable_ai PUBLIC onnxruntime::onnxruntime) +endif() + +target_link_libraries(darktable_ai PUBLIC + ${GLIB_LIBRARIES} + ${JSON_GLIB_LIBRARIES} +) + +# Install the bundled ONNX Runtime shared library alongside darktable_ai +# (skip when using system package — it's already in the system library path) +if(ONNXRuntime_SYSTEM_PACKAGE) + # Nothing to install +elseif(APPLE) + file(GLOB _ORT_DYLIB "${ONNXRuntime_LIB_DIR}/libonnxruntime.*.dylib") + if(_ORT_DYLIB) + install(FILES ${_ORT_DYLIB} + DESTINATION ${CMAKE_INSTALL_LIBDIR}/darktable + ) + endif() +elseif(WIN32) + # Install all ORT DLLs (onnxruntime.dll, onnxruntime_providers_shared.dll, etc.) + # DirectML.dll is a system component on Windows 10 1903+ and does not need bundling. + file(GLOB _ORT_DLL "${ONNXRuntime_LIB_DIR}/onnxruntime*.dll") + if(_ORT_DLL) + install(FILES ${_ORT_DLL} + DESTINATION ${CMAKE_INSTALL_BINDIR} + ) + endif() +elseif(UNIX) + file(GLOB _ORT_SO "${ONNXRuntime_LIB_DIR}/libonnxruntime.so*") + if(_ORT_SO) + install(FILES ${_ORT_SO} + DESTINATION ${CMAKE_INSTALL_LIBDIR}/darktable + ) + endif() +endif() diff --git a/src/ai/README.md b/src/ai/README.md new file mode 100644 index 000000000000..6e867c971edd --- /dev/null +++ b/src/ai/README.md @@ -0,0 +1,447 @@ +# darktable AI Backend — Developer Documentation + +## Architecture Overview + +The AI subsystem is built as a static library (`darktable_ai`) that wraps +ONNX Runtime's C API behind a backend-agnostic interface. 
It provides: + +- **Model discovery** — scans directories for `config.json` manifests +- **Session management** — load/unload ONNX models with hardware acceleration +- **Tensor I/O** — type-safe inference with automatic Float32/Float16 conversion +- **Provider abstraction** — unified API for CPU, GPU, and NPU execution providers + +### Source Files + +| File | Role | +|------|------| +| `backend.h` | Public API: types, enums, function declarations | +| `backend_common.c` | Environment, model registry, provider string conversion | +| `backend_onnx.c` | ONNX Runtime C API wrapper, inference engine | +| `segmentation.h` | Segmentation (SAM) public API | +| `segmentation.c` | SAM encoder-decoder implementation | +| `CMakeLists.txt` | Build config, ONNX Runtime linkage, install rules | + +Higher-level consumers live outside `src/ai/`: + +| File | Role | +|------|------| +| `src/libs/denoise_ai.c` | Denoise lighttable module (tiled inference) | +| `src/develop/masks/object.c` | Object mask tool using segmentation API | +| `src/common/ai_models.c` | Model registry UI, download, extraction | +| `src/gui/preferences_ai.c` | AI preferences tab (provider selection, model management) | + +--- + +## ONNX Runtime Integration + +### Initialization + +ONNX Runtime is initialized lazily via `g_once()` singletons: +- **`g_ort`** — `OrtApi` pointer (one per process) +- **`g_env`** — `OrtEnv` instance (one per process) + +Both are created on first model load or provider probe and persist for the +lifetime of the process. + +### Model Loading + +``` +dt_ai_load_model(env, model_id, model_file, provider) + └─ dt_ai_load_model_ext(env, model_id, model_file, provider, opt_level, dim_overrides, n_overrides) + └─ dt_ai_onnx_load_ext(model_dir, model_file, provider, opt_level, dim_overrides, n_overrides) +``` + +Loading a model: +1. Resolves `model_id` to a directory path via the environment's model registry +2. Creates `OrtSessionOptions` with intra-op parallelism (all cores) +3. 
Sets graph optimization level +4. Applies symbolic dimension overrides (for dynamic-shape models) +5. Calls `_enable_acceleration()` to attach the selected execution provider +6. Creates `OrtSession` from the `.onnx` file +7. Introspects all input/output names, types, and shapes +8. Detects dynamic output shapes (any dim <= 0) for ORT-allocated output mode + +### Inference (`dt_ai_run`) + +Callers provide `dt_ai_tensor_t` arrays for inputs and outputs: + +```c +typedef struct dt_ai_tensor_t { + void *data; // Pointer to raw data buffer + dt_ai_dtype_t type; // Data type of elements + int64_t *shape; // Array of dimensions + int ndim; // Number of dimensions +} dt_ai_tensor_t; +``` + +The runtime handles two special cases transparently: + +- **Float16 auto-conversion**: If the caller provides Float32 data but the model + expects Float16, the backend converts on-the-fly (and vice versa for outputs). +- **Dynamic output shapes**: If any output has symbolic dimensions, ORT allocates + the output tensors internally. After inference, the backend copies data to the + caller's buffer and updates the caller's shape array with actual dimensions. + +### Graph Optimization Levels + +| Level | Enum | Use Case | +|-------|------|----------| +| All | `DT_AI_OPT_ALL` | Default, fastest. Works for most models. | +| Basic | `DT_AI_OPT_BASIC` | Constant folding + redundancy elimination only. Required for SAM2 decoder (aggressive optimization breaks shape inference on dynamic dims). | +| Disabled | `DT_AI_OPT_DISABLED` | Reserved for future use. | + +### Symbolic Dimension Overrides + +Models with symbolic dimensions (e.g., SAM2 decoder's `num_labels`) can cause +ONNX Runtime to fail shape inference. 
Use `dt_ai_load_model_ext()` with +`dt_ai_dim_override_t` to bind concrete values: + +```c +dt_ai_dim_override_t overrides[] = { { "num_labels", 1 } }; +ctx = dt_ai_load_model_ext(env, id, file, provider, + DT_AI_OPT_BASIC, overrides, 1); +``` + +--- + +## Execution Providers + +### Provider Table + +| Provider | Enum | Config String | Platform | +|----------|------|---------------|----------| +| Auto | `DT_AI_PROVIDER_AUTO` | `auto` | All | +| CPU | `DT_AI_PROVIDER_CPU` | `CPU` | All | +| Apple CoreML | `DT_AI_PROVIDER_COREML` | `CoreML` | macOS | +| NVIDIA CUDA | `DT_AI_PROVIDER_CUDA` | `CUDA` | Linux, Windows | +| AMD MIGraphX | `DT_AI_PROVIDER_MIGRAPHX` | `MIGraphX` | Linux | +| Intel OpenVINO | `DT_AI_PROVIDER_OPENVINO` | `OpenVINO` | Linux, Windows, macOS (x86_64) | +| Windows DirectML | `DT_AI_PROVIDER_DIRECTML` | `DirectML` | Windows | + +The `available` field in the provider descriptor is a compile-time platform +guard controlled by `#if` preprocessor directives. It determines which providers +are shown in the UI — runtime availability is checked separately. + +### Auto-Detection Strategy + +When `DT_AI_PROVIDER_AUTO` is selected, the backend tries platform-native +acceleration first and falls back gracefully: + +- **macOS**: CoreML +- **Windows**: DirectML +- **Linux**: CUDA → MIGraphX → ROCm (legacy) + +All providers fall back to CPU if the accelerator is unavailable. + +### Runtime Provider Loading + +Provider functions are loaded at runtime via dynamic symbol lookup +(`GModule`/`dlsym` on Unix, `GetProcAddress` on Windows) from the linked +ONNX Runtime shared library. This means: + +- No compile-time dependency on provider-specific headers +- Providers are optional — missing symbols are handled gracefully +- The same binary works with CPU-only and GPU-enabled ORT builds + +### MIGraphX / ROCm Fallback + +ONNX Runtime is transitioning from ROCm to MIGraphX on AMD. 
When the MIGraphX +provider is selected (or reached via auto-detection on Linux), the backend tries +MIGraphX first, then falls back to ROCm for older ORT builds that only ship the +ROCm provider. + +### Provider Probing + +`dt_ai_probe_provider()` tests if a provider can be initialized at runtime +without loading a model. It creates temporary `OrtSessionOptions`, attempts to +attach the provider, and returns 1 (available) or 0 (unavailable). This is used +by the preferences UI to show a warning when a selected provider is not +available. + +### ONNX Runtime Packages + +| Platform | Package Source | Providers Included | +|----------|---------------|--------------------| +| macOS | GitHub releases (CPU) | CPU, CoreML (via Apple frameworks) | +| Linux | GitHub releases (CPU) | CPU only | +| Linux | System packages (vendor repos) | CPU + CUDA / MIGraphX / OpenVINO | +| Windows | NuGet (DirectML variant) | CPU, DirectML | + +The build system (`cmake/modules/FindONNXRuntime.cmake`) auto-downloads the +appropriate package if ONNX Runtime is not found on the system. On Windows, it +downloads the DirectML NuGet package for vendor-agnostic GPU support (AMD, +Intel, NVIDIA via DirectX 12). + +--- + +## Model Discovery + +### Directory Layout + +Models are discovered by scanning for `config.json` files in subdirectories of: + +1. Custom paths passed to `dt_ai_env_init()` (semicolon-separated) +2. `/models/` — darktable's own config directory (respects `--configdir`) + - Linux: `~/.config/darktable/models/` + - Windows: `%APPDATA%\darktable\models\` + - macOS: `~/.config/darktable/models/` + +First discovered model ID wins (duplicates are skipped). + +Model downloads (`ai_models.c`) also extract to the configdir-based path, so +downloaded models are immediately discoverable. 
+
+### config.json Format
+
+Each model directory must contain a `config.json`:
+
+```json
+{
+  "id": "denoise-nind",
+  "name": "denoise nind",
+  "description": "UNet denoiser trained on NIND dataset",
+  "task": "denoise",
+  "backend": "onnx",
+  "num_inputs": 1
+}
+```
+
+| Field | Required | Default | Description |
+|-------|----------|---------|-------------|
+| `id` | Yes | — | Unique identifier (see naming convention below) |
+| `name` | Yes | — | Display name shown in UI |
+| `description` | No | `""` | Short description |
+| `task` | No | `"general"` | Task type: `"denoise"`, `"mask"` |
+| `backend` | No | `"onnx"` | Backend type (only `"onnx"` supported) |
+| `num_inputs` | No | `1` | Number of model inputs (used by denoise for multi-input models) |
+
+### Model ID Naming Convention
+
+Model IDs follow the pattern `<task>-<name>[-<size>]`:
+
+| ID | Task | Description |
+|----|------|-------------|
+| `denoise-nind` | `denoise` | UNet denoiser (NIND dataset) |
+| `mask-object-sam21-small` | `mask` | SAM 2.1 Hiera Small for object masking |
+| `mask-light-hq-sam` | `mask` | SAM-HQ lightweight variant |
+
+Rules:
+- Use lowercase, hyphen-separated
+- First component is the task type (`denoise`, `mask`)
+- For masks, the second component is the subtask (`object`, `light`)
+- Append size suffix when multiple sizes exist (`small`, `base`, `large`)
+
+---
+
+## Task: Denoise
+
+**Module**: `src/libs/denoise_ai.c`
+**Task type**: `"denoise"`
+
+### Model Requirements
+
+#### Single-Input Denoise Models (NAFNet, UNet, etc.) 
+ +| Tensor | Name | Shape | Type | Description | +|--------|------|-------|------|-------------| +| Input 0 | `input` | `[1, 3, H, W]` | float32 | RGB image, NCHW layout, sRGB color space | +| Output 0 | `output` | `[1, 3, H, W]` | float32 | Denoised RGB image, same layout | + +- H and W are dynamic (determined by tile size at runtime) +- Input and output spatial dimensions must match + +#### Multi-Input Denoise Models (FFDNet) + +| Tensor | Name | Shape | Type | Description | +|--------|------|-------|------|-------------| +| Input 0 | `input` | `[1, 3, H, W]` | float32 | RGB image, NCHW layout, sRGB color space | +| Input 1 | `sigma` | `[1, 1, H, W]` | float32 | Noise level map, values = sigma / 255.0 | +| Output 0 | `output` | `[1, 3, H, W]` | float32 | Denoised RGB image, same layout | + +Set `"num_inputs": 2` in `config.json` for multi-input models. + +### Color Space + +Models operate in sRGB. The denoise module converts: +- **Before inference**: linear RGB → sRGB (IEC 61966-2-1 transfer function) +- **After inference**: sRGB → linear RGB + +### Tiling Strategy + +Large images are processed in overlapping tiles to fit within memory: + +- **Tile sizes** (tried in order): 2048, 1536, 1024, 768, 512, 384, 256 +- **Overlap**: 64 pixels on each edge +- **Memory budget**: 1/4 of available darktable memory +- **Border handling**: mirror padding at image edges + +The output is streamed to TIFF scanlines — no full-resolution buffer is needed. + +--- + +## Task: Mask (Object Segmentation) + +**Module**: `src/ai/segmentation.c` +**API**: `src/ai/segmentation.h` +**Task type**: `"mask"` + +### Supported Models + +All SAM variants (SAM, SAM-HQ, SAM2/2.1) are exported to a common interface +via conversion scripts in the +[darktable-ai](https://github.com/andriiryzhkov/darktable-ai) repository. +The model directory must contain both `encoder.onnx` and `decoder.onnx`. 
+
+### Encoder Requirements
+
+| Tensor | Shape | Type | Description |
+|--------|-------|------|-------------|
+| Input 0 | `[1, 3, 1024, 1024]` | float32 | Preprocessed image (CHW, ImageNet-normalized) |
+
+**Preprocessing** (applied by `segmentation.c`):
+1. Resize longest side to 1024 px (bilinear), preserve aspect ratio
+2. Zero-pad the shorter side to reach 1024x1024
+3. Normalize each channel: `(pixel - mean) / std`
+   - Mean: `[123.675, 116.28, 103.53]`
+   - Std: `[58.395, 57.12, 57.375]`
+4. Convert HWC → CHW layout
+
+| Output | Typical Shape | Description |
+|--------|---------------|-------------|
+| 0 | `[1, 256, 64, 64]` | Image embeddings |
+| 1 | `[1, 32, 256, 256]` | High-resolution feature (skip connection) |
+| 2 (SAM2 only) | `[1, 64, 128, 128]` | Mid-resolution feature |
+
+Exact shapes vary by model variant. The encoder's outputs are passed to the
+decoder, reordered by name matching if necessary (SAM2 encoder may output in a
+different order than the decoder expects).
+
+### Decoder Requirements
+
+All model families use the same decoder interface. The number of encoder
+outputs varies (2 for SAM-HQ, 3 for SAM2), but the prompt inputs and outputs
+are identical.
+
+#### Inputs
+
+| Index | Name | Shape | Type | Description |
+|-------|------|-------|------|-------------|
+| 0..E-1 | (encoder outputs) | (varies) | float32 | Passed through from encoder |
+| E | `point_coords` | `[1, N+1, 2]` | float32 | Point coordinates (N prompts + 1 padding) |
+| E+1 | `point_labels` | `[1, N+1]` | float32 | 1=foreground, 0=background, -1=padding |
+| E+2 | `mask_input` | `[1, 1, 256, 256]` | float32 | Low-res mask from previous iteration |
+| E+3 | `has_mask_input` | `[1]` | float32 | 0.0 (first decode) or 1.0 (iterative refinement) |
+
+Where E = number of encoder outputs. No `orig_im_size` input — resizing to
+original image dimensions is done by darktable at runtime. 
+ +#### Outputs + +| Index | Name | Shape | Description | +|-------|------|-------|-------------| +| 0 | `masks` | `[1, M, 1024, 1024]` | Mask logits (pre-sigmoid) at fixed resolution | +| 1 | `iou_predictions` | `[1, M]` | Predicted IoU score per mask | +| 2 | `low_res_masks` | `[1, M, 256, 256]` | Low-res masks for iterative refinement | + +Where M = number of mask candidates: + +- **SAM-HQ** (hq-token-only): M=1 +- **SAM2**: M=3 (multi-mask) +- **SAM2 HQ**: M=4 (3 SAM + 1 HQ) + +All spatial dimensions must be concrete (no symbolic dims). The `masks` output +must include `F.interpolate` upsampling from 256 to 1024 in the exported graph. + +### Mask Post-Processing + +1. Select mask with highest IoU score +2. Crop out the zero-padded region +3. Bilinear resize to original image dimensions +4. Apply sigmoid: `mask = 1 / (1 + exp(-logits))` +5. Output values in `[0, 1]` range + +### Iterative Refinement + +The decoder supports iterative mask refinement: +- **First decode**: `has_mask_input = 0.0`, decoder ignores `mask_input` +- **Subsequent decodes**: previous low-res mask is fed back with `has_mask_input = 1.0` +- `dt_seg_reset_prev_mask()` clears the cached mask without clearing image embeddings + +### Encoding Cache + +The encoder output is cached in `dt_seg_context_t`. Multiple decode calls +(different point prompts) reuse the same encoding. Call +`dt_seg_reset_encoding()` when the image changes. 
+ +### config.json for Mask Models + +```json +{ + "id": "mask-object-sam21-small", + "name": "sam2.1 hiera small", + "description": "Segment Anything 2.1 (Hiera Small) for masking", + "task": "mask", + "backend": "onnx" +} +``` + +The model directory must contain: + +``` +mask-object-sam21-small/ + config.json + encoder.onnx + decoder.onnx +``` + +--- + +## Exporting Models for darktable + +### Denoise Models + +Export to ONNX with dynamic batch and spatial dimensions: + +```python +torch.onnx.export(model, dummy_input, "model.onnx", + input_names=["input"], + output_names=["output"], + dynamic_axes={ + "input": {2: "height", 3: "width"}, + "output": {2: "height", 3: "width"} + }) +``` + +### SAM / SAM-HQ / SAM2 Models + +Export encoder and decoder as separate ONNX files (`encoder.onnx`, `decoder.onnx`). +Conversion scripts are maintained in the +[darktable-ai](https://github.com/andriiryzhkov/darktable-ai) repository. + +All decoders must follow the interface described above: + +- No `orig_im_size` input +- `masks` output at fixed 1024x1024 (include `F.interpolate` in the graph) +- `low_res_masks` output at 256x256 (for iterative refinement) +- All spatial dimensions concrete (no symbolic dims like `num_labels`) +- Only `num_points` may be dynamic (for variable number of point prompts) + +--- + +## Adding a New Provider + +1. Add enum value to `dt_ai_provider_t` in `backend.h` (before `DT_AI_PROVIDER_COUNT`) +2. Add entry to `dt_ai_providers[]` table with config string, display name, and platform guard +3. Add `case` in `_enable_acceleration()` in `backend_onnx.c` with the ORT append function symbol +4. Add `case` in `dt_ai_probe_provider()` in `backend_onnx.c` +5. The `_Static_assert` ensures the table stays in sync with the enum + +--- + +## Adding a New Task + +1. Create a new source file (e.g., `src/ai/newtask.c` + `newtask.h`) +2. Use `dt_ai_load_model()` / `dt_ai_load_model_ext()` to load models +3. 
Use `dt_ai_run()` for inference with `dt_ai_tensor_t` arrays +4. Set `"task": "newtask"` in the model's `config.json` +5. Add the new files to `src/ai/CMakeLists.txt` diff --git a/src/ai/backend.h b/src/ai/backend.h new file mode 100644 index 000000000000..c34878a1043d --- /dev/null +++ b/src/ai/backend.h @@ -0,0 +1,308 @@ +/* + This file is part of darktable, + Copyright (C) 2026 darktable developers. + + darktable is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + darktable is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with darktable. If not, see . +*/ + +#pragma once + +#include + +/** + * @brief AI Execution Provider + */ +typedef enum { + DT_AI_PROVIDER_AUTO = 0, + DT_AI_PROVIDER_CPU, + DT_AI_PROVIDER_COREML, + DT_AI_PROVIDER_CUDA, + DT_AI_PROVIDER_MIGRAPHX, + DT_AI_PROVIDER_OPENVINO, + DT_AI_PROVIDER_DIRECTML, + DT_AI_PROVIDER_COUNT // must be last +} dt_ai_provider_t; + +/** + * @brief Provider descriptor: maps enum to config/display strings. 
+ * + * config_string: persisted to darktablerc, matches ONNX Runtime provider names + * display_name: shown in UI combo boxes and log messages + * available: compile-time platform guard (FALSE = hidden from UI) + */ +typedef struct dt_ai_provider_desc_t { + dt_ai_provider_t value; + const char *config_string; + const char *display_name; + int available; +} dt_ai_provider_desc_t; + +extern const dt_ai_provider_desc_t dt_ai_providers[DT_AI_PROVIDER_COUNT]; + +/** Config key for the AI execution provider preference */ +#define DT_AI_CONF_PROVIDER "plugins/ai/provider" + +/** Get display name for a provider enum value */ +const char *dt_ai_provider_to_string(dt_ai_provider_t provider); + +/** Parse provider from config string (with legacy alias support) */ +dt_ai_provider_t dt_ai_provider_from_string(const char *str); + +/** Test if a provider is available at runtime (checks deps, not just compile-time). + * @return 1 if available, 0 if not. */ +int dt_ai_probe_provider(dt_ai_provider_t provider); + +/** + * @brief Graph Optimization Level + * + * Models with fully dynamic output shapes (e.g. SAM2 decoder) can fail + * under aggressive graph optimization because ONNX Runtime's shape + * inference mis-computes intermediate tensor sizes. Use DT_AI_OPT_BASIC + * for such models to avoid internal shape validation errors. + */ +typedef enum { + DT_AI_OPT_ALL = 0, ///< All optimizations (default, fastest) + DT_AI_OPT_BASIC = 1, ///< Basic only (constant folding, redundant node elimination) + DT_AI_OPT_DISABLED = 2, ///< No optimization (reserved for future use) +} dt_ai_opt_level_t; + +/** + * @brief Library Environment Handle + * Opaque handle representing the initialized AI library environment. + */ +typedef struct dt_ai_environment_t dt_ai_environment_t; + +/** + * @brief Execution Context Handle + * Opaque handle for a loaded model session. 
+ */ +typedef struct dt_ai_context_t dt_ai_context_t; + +/** + * @brief Model Metadata (ReadOnly) + */ +typedef struct dt_ai_model_info_t { + const char *id; ///< Unique ID (e.g. "nafnet-sidd") + const char *name; ///< Display name + const char *description; ///< Short description + const char *task_type; ///< e.g. "denoise", "inpainting" + const char *backend; ///< Backend type (e.g. "onnx") + int num_inputs; ///< Number of model inputs (default 1) +} dt_ai_model_info_t; + +/* --- Discovery --- */ + +/** + * @brief Initialize the library environment and scan for models. + * @param search_paths Semicolon-separated list of paths to scan. + * @return dt_ai_environment_t* Handle, or NULL on error. + */ +dt_ai_environment_t *dt_ai_env_init(const char *search_paths); + +/** + * @brief Get the number of discovered models. + */ +int dt_ai_get_model_count(dt_ai_environment_t *env); + +/** + * @brief Get model details by index. + * @param env The environment handle. + * @param index Index 0 to count-1. + * @return const dt_ai_model_info_t* Pointer to info struct. + */ +const dt_ai_model_info_t * +dt_ai_get_model_info_by_index(dt_ai_environment_t *env, int index); + +/** + * @brief Get model details by unique ID. + * @param env The environment handle. + * @param id The unique ID of the model. + * @return const dt_ai_model_info_t* Pointer to info struct. + */ +const dt_ai_model_info_t * +dt_ai_get_model_info_by_id(dt_ai_environment_t *env, const char *id); + +/** + * @brief Refresh the environment by rescanning model directories. + * @param env The environment handle to refresh. + * @note Call this after downloading new models. + */ +void dt_ai_env_refresh(dt_ai_environment_t *env); + +/** + * @brief Cleanup the library environment. + * @param env The environment handle to destroy. + */ +void dt_ai_env_destroy(dt_ai_environment_t *env); + +/** + * @brief Set the default execution provider for this environment. 
+ * When dt_ai_load_model / dt_ai_load_model_opts is called with + * DT_AI_PROVIDER_AUTO, the environment's provider is used instead. + * @param env The environment handle. + * @param provider The provider to use (DT_AI_PROVIDER_AUTO = platform auto-detect). + */ +void dt_ai_env_set_provider(dt_ai_environment_t *env, dt_ai_provider_t provider); + +/** + * @brief Get the default execution provider for this environment. + * @param env The environment handle. + * @return The currently set provider. + */ +dt_ai_provider_t dt_ai_env_get_provider(dt_ai_environment_t *env); + +/* --- Execution --- */ + +/** + * @brief Load a model for execution from the registry. + * @param env Library environment. + * @param model_id ID of the model to load. + * @param model_file Filename within the model directory (NULL = "model.onnx"). + * @param provider Execution provider to use for hardware acceleration. + * @return dt_ai_context_t* Context ready for inference, or NULL. + */ +dt_ai_context_t *dt_ai_load_model(dt_ai_environment_t *env, + const char *model_id, + const char *model_file, + dt_ai_provider_t provider); + +/** + * @brief Symbolic dimension override for models with dynamic shapes. + */ +typedef struct { + const char *name; ///< Symbolic dimension name (e.g. "num_labels") + int64_t value; ///< Concrete value to use +} dt_ai_dim_override_t; + +/** + * @brief Load a model with optimization options and symbolic dimension overrides. + * Dimension overrides fix shape inference for models with symbolic dims + * that prevent ONNX Runtime from resolving intermediate tensor shapes. + * @param env Library environment. + * @param model_id ID of the model to load. + * @param model_file Filename within the model directory (NULL = "model.onnx"). + * @param provider Execution provider to use for hardware acceleration. + * @param opt_level Graph optimization level. + * @param dim_overrides Array of symbolic dimension overrides (NULL = none). + * @param n_overrides Number of overrides. 
+ * @return dt_ai_context_t* Context ready for inference, or NULL. + */ +dt_ai_context_t *dt_ai_load_model_ext(dt_ai_environment_t *env, + const char *model_id, + const char *model_file, + dt_ai_provider_t provider, + dt_ai_opt_level_t opt_level, + const dt_ai_dim_override_t *dim_overrides, + int n_overrides); + +/** + * @brief Tensor Data Types + */ +typedef enum { + DT_AI_FLOAT = 1, + DT_AI_UINT8 = 2, + DT_AI_INT8 = 3, + DT_AI_INT32 = 4, + DT_AI_INT64 = 5, + DT_AI_FLOAT16 = 10 +} dt_ai_dtype_t; + +/** + * @brief Tensor descriptor for I/O + */ +typedef struct dt_ai_tensor_t { + void *data; ///< Pointer to raw data buffer + dt_ai_dtype_t type; ///< Data type of elements + int64_t *shape; ///< Array of dimensions + int ndim; ///< Number of dimensions +} dt_ai_tensor_t; + +/** + * @brief Run inference through the loaded model. + * @param ctx The AI context. + * @param inputs Array of input tensors. + * @param num_inputs Number of input tensors. + * @param outputs Array of output tensors. + * @param num_outputs Number of output tensors. + * @return int 0 on success, <0 on error. + */ +int dt_ai_run(dt_ai_context_t *ctx, dt_ai_tensor_t *inputs, + int num_inputs, dt_ai_tensor_t *outputs, + int num_outputs); + +/** + * @brief Get the number of model inputs. + * @param ctx The AI context. + * @return Number of inputs, or 0 if ctx is NULL. + */ +int dt_ai_get_input_count(dt_ai_context_t *ctx); + +/** + * @brief Get the number of model outputs. + * @param ctx The AI context. + * @return Number of outputs, or 0 if ctx is NULL. + */ +int dt_ai_get_output_count(dt_ai_context_t *ctx); + +/** + * @brief Get the name of a model input by index. + * @param ctx The AI context. + * @param index Input index (0-based). + * @return Input name string (owned by ctx, do not free), or NULL. + */ +const char *dt_ai_get_input_name(dt_ai_context_t *ctx, int index); + +/** + * @brief Get the data type of a model input by index. + * @param ctx The AI context. 
+ * @param index Input index (0-based). + * @return Data type, or DT_AI_FLOAT as fallback. + */ +dt_ai_dtype_t dt_ai_get_input_type(dt_ai_context_t *ctx, + int index); + +/** + * @brief Get the name of a model output by index. + * @param ctx The AI context. + * @param index Output index (0-based). + * @return Output name string (owned by ctx, do not free), or NULL. + */ +const char *dt_ai_get_output_name(dt_ai_context_t *ctx, + int index); + +/** + * @brief Get the data type of a model output by index. + * @param ctx The AI context. + * @param index Output index (0-based). + * @return Data type, or DT_AI_FLOAT as fallback. + */ +dt_ai_dtype_t dt_ai_get_output_type(dt_ai_context_t *ctx, + int index); + +/** + * @brief Get the shape of a model output by index. + * @param ctx The AI context. + * @param index Output index (0-based). + * @param shape Output array to fill with dimension sizes. + * @param max_dims Maximum number of dimensions to write. + * @return Number of dimensions, or -1 on error. + */ +int dt_ai_get_output_shape(dt_ai_context_t *ctx, int index, + int64_t *shape, int max_dims); + +/** + * @brief Unload a model and free execution context. + * @param ctx The AI context to unload. + */ +void dt_ai_unload_model(dt_ai_context_t *ctx); diff --git a/src/ai/backend_common.c b/src/ai/backend_common.c new file mode 100644 index 000000000000..ee72580620ab --- /dev/null +++ b/src/ai/backend_common.c @@ -0,0 +1,442 @@ +/* + This file is part of darktable, + Copyright (C) 2026 darktable developers. + + darktable is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + darktable is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with darktable. If not, see . +*/ + +#include "backend.h" +#include "common/darktable.h" +#include "control/conf.h" +#include +#include +#include + +// provider table + +// clang-format off +const dt_ai_provider_desc_t dt_ai_providers[DT_AI_PROVIDER_COUNT] = { + { DT_AI_PROVIDER_AUTO, "auto", "auto", + 1 }, + { DT_AI_PROVIDER_CPU, "CPU", "CPU", + 1 }, + { DT_AI_PROVIDER_COREML, "CoreML", "Apple CoreML", +#if defined(__APPLE__) + 1 +#else + 0 +#endif + }, + { DT_AI_PROVIDER_CUDA, "CUDA", "NVIDIA CUDA", +#if defined(__linux__) + 1 +#else + 0 +#endif + }, + { DT_AI_PROVIDER_MIGRAPHX, "MIGraphX", "AMD MIGraphX", +#if defined(__linux__) + 1 +#else + 0 +#endif + }, + { DT_AI_PROVIDER_OPENVINO, "OpenVINO", "Intel OpenVINO", +#if defined(__linux__) || (defined(__APPLE__) && defined(__x86_64__)) + 1 +#else + 0 +#endif + }, + { DT_AI_PROVIDER_DIRECTML, "DirectML", "Windows DirectML", +#if defined(_WIN32) + 1 +#else + 0 +#endif + }, +}; +// clang-format on + +struct dt_ai_environment_t +{ + GList *models; // list of dt_ai_model_info_t* + GHashTable *model_paths; // id -> path (string) + + // to keep pointers in dt_ai_model_info_t valid + GList *string_storage; // list of char* + + // remembered for refresh + char *search_paths; + + // default execution provider (read from config at init, override with dt_ai_env_set_provider) + dt_ai_provider_t provider; // DT_AI_PROVIDER_AUTO = platform auto-detect + + GMutex lock; // thread safety for model list access +}; + +static void _store_string(dt_ai_environment_t *env, const char *str, const char **out_ptr) +{ + char *copy = g_strdup(str); + env->string_storage = g_list_prepend(env->string_storage, copy); + *out_ptr = copy; +} + +static void _scan_directory(dt_ai_environment_t *env, const char *root_path) +{ + GDir *dir = g_dir_open(root_path, 0, NULL); + if(!dir) + return; + + const char *entry_name; + 
while((entry_name = g_dir_read_name(dir))) + { + char *full_path = g_build_filename(root_path, entry_name, NULL); + if(g_file_test(full_path, G_FILE_TEST_IS_DIR)) + { + char *config_path = g_build_filename(full_path, "config.json", NULL); + + if(g_file_test(config_path, G_FILE_TEST_EXISTS)) + { + JsonParser *parser = json_parser_new(); + GError *error = NULL; + + if(json_parser_load_from_file(parser, config_path, &error)) + { + JsonNode *root = json_parser_get_root(parser); + JsonObject *obj = json_node_get_object(root); + + const char *id = json_object_get_string_member(obj, "id"); + const char *name = json_object_get_string_member(obj, "name"); + const char *desc = json_object_has_member(obj, "description") + ? json_object_get_string_member(obj, "description") + : ""; + const char *task = json_object_has_member(obj, "task") + ? json_object_get_string_member(obj, "task") + : "general"; + const char *backend = json_object_has_member(obj, "backend") + ? json_object_get_string_member(obj, "backend") + : "onnx"; + + if(id && name) + { + // skip duplicate model IDs (first discovered wins) + if(g_hash_table_contains(env->model_paths, id)) + { + dt_print(DT_DEBUG_AI, "[darktable_ai] skipping duplicate model ID: %s", id); + } + else + { + dt_ai_model_info_t *info = g_new0(dt_ai_model_info_t, 1); + _store_string(env, id, &info->id); + _store_string(env, name, &info->name); + _store_string(env, desc, &info->description); + _store_string(env, task, &info->task_type); + _store_string(env, backend, &info->backend); + info->num_inputs = json_object_has_member(obj, "num_inputs") + ? 
(int)json_object_get_int_member(obj, "num_inputs") + : 1; + + env->models = g_list_prepend(env->models, info); + g_hash_table_insert( + env->model_paths, + g_strdup(info->id), + g_strdup(full_path)); + + dt_print(DT_DEBUG_AI, + "[darktable_ai] discovered: %s (%s, backend=%s)", + name, id, backend); + } + } + } + else + { + dt_print(DT_DEBUG_AI, "[darktable_ai] parse error: %s", error->message); + g_error_free(error); + } + g_object_unref(parser); + } + g_free(config_path); + } + g_free(full_path); + } + g_dir_close(dir); +} + +// scan custom search_paths + default config/data directories +static void _scan_all_paths(dt_ai_environment_t *env) +{ + if(env->search_paths) + { + char **tokens = g_strsplit(env->search_paths, ";", -1); + for(int i = 0; tokens[i] != NULL; i++) + { + _scan_directory(env, tokens[i]); + } + g_strfreev(tokens); + } + + // scan XDG data dir where the registry downloads/extracts models + char *datadir = g_build_filename(g_get_user_data_dir(), "darktable", "models", NULL); + _scan_directory(env, datadir); + g_free(datadir); +} + +// API implementation + +dt_ai_environment_t *dt_ai_env_init(const char *search_paths) +{ + dt_print(DT_DEBUG_AI, "[darktable_ai] dt_ai_env_init start."); + + dt_ai_environment_t *env = g_new0(dt_ai_environment_t, 1); + g_mutex_init(&env->lock); + env->model_paths = g_hash_table_new_full(g_str_hash, g_str_equal, g_free, g_free); + env->search_paths = g_strdup(search_paths); + + // read user's preferred execution provider from config + char *prov_str = dt_conf_get_string(DT_AI_CONF_PROVIDER); + env->provider = dt_ai_provider_from_string(prov_str); + g_free(prov_str); + + _scan_all_paths(env); + env->models = g_list_reverse(env->models); + + return env; +} + +int dt_ai_get_model_count(dt_ai_environment_t *env) +{ + if(!env) + return 0; + g_mutex_lock(&env->lock); + int count = g_list_length(env->models); + g_mutex_unlock(&env->lock); + return count; +} + +const dt_ai_model_info_t * 
+dt_ai_get_model_info_by_index(dt_ai_environment_t *env, int index)
+{
+  if(!env)
+    return NULL;
+  g_mutex_lock(&env->lock);
+  GList *item = g_list_nth(env->models, index);
+  const dt_ai_model_info_t *info = item ? (const dt_ai_model_info_t *)item->data : NULL;
+  g_mutex_unlock(&env->lock);
+  return info;
+}
+
+const dt_ai_model_info_t *
+dt_ai_get_model_info_by_id(dt_ai_environment_t *env, const char *id)
+{
+  if(!env || !id)
+    return NULL;
+  g_mutex_lock(&env->lock);
+  const dt_ai_model_info_t *result = NULL;
+  for(GList *l = env->models; l != NULL; l = l->next)
+  {
+    dt_ai_model_info_t *info = (dt_ai_model_info_t *)l->data;
+    if(strcmp(info->id, id) == 0)
+    {
+      result = info;
+      break;
+    }
+  }
+  g_mutex_unlock(&env->lock);
+  return result;
+}
+
+static void _free_model_info(gpointer data) { g_free(data); } // strings live in env->string_storage, freed separately
+
+void dt_ai_env_refresh(dt_ai_environment_t *env)
+{
+  if(!env)
+    return;
+
+  g_mutex_lock(&env->lock);
+
+  dt_print(DT_DEBUG_AI, "[darktable_ai] refreshing model list");
+
+  // clear existing data
+  g_list_free_full(env->models, _free_model_info);
+  env->models = NULL;
+
+  g_list_free_full(env->string_storage, g_free);
+  env->string_storage = NULL;
+
+  g_hash_table_remove_all(env->model_paths);
+
+  _scan_all_paths(env);
+  env->models = g_list_reverse(env->models); // scan prepends; restore discovery order like dt_ai_env_init()
+  dt_print(DT_DEBUG_AI,
+           "[darktable_ai] refresh complete, found %d models",
+           g_list_length(env->models));
+
+  g_mutex_unlock(&env->lock);
+}
+
+void dt_ai_env_destroy(dt_ai_environment_t *env)
+{
+  if(!env)
+    return;
+
+  g_list_free_full(env->models, _free_model_info);
+  g_list_free_full(env->string_storage, g_free);
+  g_hash_table_destroy(env->model_paths);
+  g_free(env->search_paths);
+  g_mutex_clear(&env->lock);
+
+  g_free(env);
+}
+
+void dt_ai_env_set_provider(dt_ai_environment_t *env, dt_ai_provider_t provider)
+{
+  if(!env)
+    return;
+  g_mutex_lock(&env->lock);
+  env->provider = provider;
+  g_mutex_unlock(&env->lock);
+}
+
+dt_ai_provider_t dt_ai_env_get_provider(dt_ai_environment_t *env)
+{
+  if(!env)
+    return 
DT_AI_PROVIDER_AUTO; + g_mutex_lock(&env->lock); + const dt_ai_provider_t p = env->provider; + g_mutex_unlock(&env->lock); + return p; +} + +// =backend-specific load (defined in backend_onnx.c) + +extern dt_ai_context_t * +dt_ai_onnx_load_ext(const char *model_dir, const char *model_file, + dt_ai_provider_t provider, dt_ai_opt_level_t opt_level, + const dt_ai_dim_override_t *dim_overrides, int n_overrides); + +// model loading with backend dispatch + +dt_ai_context_t *dt_ai_load_model(dt_ai_environment_t *env, + const char *model_id, + const char *model_file, + dt_ai_provider_t provider) +{ + return dt_ai_load_model_ext(env, model_id, model_file, provider, + DT_AI_OPT_ALL, NULL, 0); +} + +dt_ai_context_t *dt_ai_load_model_ext(dt_ai_environment_t *env, + const char *model_id, + const char *model_file, + dt_ai_provider_t provider, + dt_ai_opt_level_t opt_level, + const dt_ai_dim_override_t *dim_overrides, + int n_overrides) +{ + if(!env || !model_id) + return NULL; + + // resolve AUTO: re-read from config so preference changes take effect + // immediately without requiring app restart. Read config before acquiring + // env->lock to avoid lock-ordering issues with darktable's config lock + if(provider == DT_AI_PROVIDER_AUTO) + { + char *prov_str = dt_conf_get_string(DT_AI_CONF_PROVIDER); + provider = dt_ai_provider_from_string(prov_str); + g_free(prov_str); + } + + g_mutex_lock(&env->lock); + const char *model_dir_orig + = (const char *)g_hash_table_lookup(env->model_paths, model_id); + char *model_dir = model_dir_orig ? 
g_strdup(model_dir_orig) : NULL; + + const char *backend = "onnx"; + for(GList *l = env->models; l != NULL; l = l->next) + { + dt_ai_model_info_t *info = (dt_ai_model_info_t *)l->data; + if(strcmp(info->id, model_id) == 0) + { + backend = info->backend; + break; + } + } + char *backend_copy = g_strdup(backend); + g_mutex_unlock(&env->lock); + + if(!model_dir) + { + dt_print(DT_DEBUG_AI, "[darktable_ai] ID not found: %s", model_id); + g_free(backend_copy); + return NULL; + } + + dt_ai_context_t *ctx = NULL; + + if(strcmp(backend_copy, "onnx") == 0) + { + ctx = dt_ai_onnx_load_ext(model_dir, model_file, provider, opt_level, + dim_overrides, n_overrides); + } + else + { + dt_print(DT_DEBUG_AI, + "[darktable_ai] unknown backend '%s' for model '%s'", + backend_copy, model_id); + } + + g_free(model_dir); + g_free(backend_copy); + return ctx; +} + +// provider string conversion + +const char *dt_ai_provider_to_string(dt_ai_provider_t provider) +{ + for(int i = 0; i < DT_AI_PROVIDER_COUNT; i++) + { + if(dt_ai_providers[i].value == provider) + return dt_ai_providers[i].display_name; + } + return dt_ai_providers[0].display_name; // fallback to "auto" +} + +dt_ai_provider_t dt_ai_provider_from_string(const char *str) +{ + if(!str) + return DT_AI_PROVIDER_AUTO; + + // match against config_string (primary) and display_name + for(int i = 0; i < DT_AI_PROVIDER_COUNT; i++) + { + if(g_ascii_strcasecmp(str, dt_ai_providers[i].config_string) == 0) + return dt_ai_providers[i].value; + if(g_ascii_strcasecmp(str, dt_ai_providers[i].display_name) == 0) + return dt_ai_providers[i].value; + } + + // legacy alias: ROCm was renamed to MIGraphX + if(g_ascii_strcasecmp(str, "ROCm") == 0) + return DT_AI_PROVIDER_MIGRAPHX; + + return DT_AI_PROVIDER_AUTO; +} + +// clang-format off +// modelines: These editor modelines have been set for all relevant files by tools/update_modelines.py +// vim: shiftwidth=2 expandtab tabstop=2 cindent +// kate: tab-indents: off; indent-width 2; replace-tabs on; 
indent-mode cstyle; remove-trailing-spaces modified;
+// clang-format on
diff --git a/src/ai/backend_onnx.c b/src/ai/backend_onnx.c
new file mode 100644
index 000000000000..5c0a6a88ff72
--- /dev/null
+++ b/src/ai/backend_onnx.c
@@ -0,0 +1,1311 @@
+/*
+    This file is part of darktable,
+    Copyright (C) 2026 darktable developers.
+
+    darktable is free software: you can redistribute it and/or modify
+    it under the terms of the GNU General Public License as published by
+    the Free Software Foundation, either version 3 of the License, or
+    (at your option) any later version.
+
+    darktable is distributed in the hope that it will be useful,
+    but WITHOUT ANY WARRANTY; without even the implied warranty of
+    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+    GNU General Public License for more details.
+
+    You should have received a copy of the GNU General Public License
+    along with darktable. If not, see <https://www.gnu.org/licenses/>.
+*/
+
+#include "backend.h"
+#include "common/darktable.h"
+#include <gmodule.h>
+#include <inttypes.h>
+#include <onnxruntime_c_api.h>
+#include <stdint.h>
+#include <stdlib.h>
+#include <string.h>
+#ifdef _WIN32
+#include <windows.h>
+#else
+#include <fcntl.h>
+#include <unistd.h>
+#endif
+
+struct dt_ai_context_t
+{
+  // ONNX runtime C objects
+  OrtSession *session;
+  OrtMemoryInfo *memory_info;
+
+  // IO names
+  OrtAllocator *allocator;
+  char **input_names;
+  char **output_names;
+  size_t input_count;
+  dt_ai_dtype_t *input_types;
+  size_t output_count;
+  dt_ai_dtype_t *output_types;
+
+  // TRUE when any output has symbolic/dynamic shape dims.
+  // in that case dt_ai_run() lets ORT allocate outputs and copies back
+  gboolean dynamic_outputs;
+};
+
+// global singletons (initialized exactly once via g_once)
+// ORT requires at most one OrtEnv per process.
+static const OrtApi *g_ort = NULL;
+static GOnce g_ort_once = G_ONCE_INIT;
+static OrtEnv *g_env = NULL;
+static GOnce g_env_once = G_ONCE_INIT;
+
+#ifdef ORT_LAZY_LOAD
+// redirect fd 2 to /dev/null. returns the saved fd on success, -1 on failure.
+static int _stderr_suppress_begin(void) +{ + int saved = dup(STDERR_FILENO); + if(saved == -1) return -1; + int devnull = open("/dev/null", O_WRONLY); + if(devnull == -1) { close(saved); return -1; } + dup2(devnull, STDERR_FILENO); + close(devnull); + return saved; +} +// restore fd 2 from the saved fd returned by _stderr_suppress_begin. +static void _stderr_suppress_end(int saved) +{ + if(saved != -1) { dup2(saved, STDERR_FILENO); close(saved); } +} +#endif + +static gpointer _init_ort_api(gpointer data) +{ + (void)data; + const OrtApi *api = NULL; + +#ifdef ORT_LAZY_LOAD + // Ubuntu/Debian's system ORT links against libonnx, causing harmless but noisy + // "already registered" ONNX schema warnings when the library is first loaded. + // suppress them by loading ORT explicitly, with stderr temporarily redirected. + // G_MODULE_BIND_LAZY = RTLD_LAZY; default (no BIND_LOCAL) = RTLD_GLOBAL so + // provider symbols remain visible to the rest of the process via dlsym(NULL). + const int saved = _stderr_suppress_begin(); + // the handle is intentionally not stored: ORT must stay loaded for the process + // lifetime and g_module_close is never called, so the library stays resident. 
+ GModule *ort_mod = g_module_open(ORT_LIBRARY_PATH, G_MODULE_BIND_LAZY); + _stderr_suppress_end(saved); + + if(!ort_mod) + { + dt_print(DT_DEBUG_AI, + "[darktable_ai] failed to load ORT library '%s': %s", + ORT_LIBRARY_PATH, g_module_error()); + return NULL; + } + typedef const OrtApiBase *(*OrtGetApiBaseFn)(void); + OrtGetApiBaseFn get_api_base = NULL; + if(!g_module_symbol(ort_mod, "OrtGetApiBase", (gpointer *)&get_api_base) || !get_api_base) + { + dt_print(DT_DEBUG_AI, "[darktable_ai] OrtGetApiBase symbol not found"); + return NULL; + } + api = get_api_base()->GetApi(ORT_API_VERSION); +#else + api = OrtGetApiBase()->GetApi(ORT_API_VERSION); +#endif + + if(!api) + dt_print(DT_DEBUG_AI, "[darktable_ai] failed to init ONNX runtime API"); + else + g_ort = api; + return (gpointer)api; +} + +static gpointer _init_ort_env(gpointer data) +{ + (void)data; + OrtEnv *env = NULL; +#ifdef ORT_LAZY_LOAD + // ORT may emit additional schema-registration noise during env creation. + const int saved = _stderr_suppress_begin(); +#endif + OrtStatus *status = g_ort->CreateEnv(ORT_LOGGING_LEVEL_WARNING, "DarktableAI", &env); +#ifdef ORT_LAZY_LOAD + _stderr_suppress_end(saved); +#endif + if(status) + { + dt_print(DT_DEBUG_AI, + "[darktable_ai] failed to create ORT environment: %s", + g_ort->GetErrorMessage(status)); + g_ort->ReleaseStatus(status); + return NULL; + } + g_env = env; + return (gpointer)env; +} + +// map ONNX tensor element type to our dt_ai_dtype_t. 
+// returns TRUE on success, FALSE if the type is unsupported +static gboolean _map_onnx_type(ONNXTensorElementDataType onnx_type, dt_ai_dtype_t *out) +{ + switch(onnx_type) + { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: + *out = DT_AI_FLOAT; + return TRUE; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: + *out = DT_AI_FLOAT16; + return TRUE; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: + *out = DT_AI_UINT8; + return TRUE; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: + *out = DT_AI_INT8; + return TRUE; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: + *out = DT_AI_INT32; + return TRUE; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: + *out = DT_AI_INT64; + return TRUE; + default: + return FALSE; + } +} + +// compute total element count from shape dimensions with overflow checking. +// returns the product of all shape dimensions, or -1 if any dimension is +// non-positive or the multiplication would overflow int64_t +static int64_t _safe_element_count(const int64_t *shape, int ndim) +{ + int64_t count = 1; + for(int i = 0; i < ndim; i++) + { + if(shape[i] <= 0) + return -1; + if(count > INT64_MAX / shape[i]) + return -1; + count *= shape[i]; + } + return count; +} + +// map dt_ai_dtype_t to ONNX type and element size. 
+// returns TRUE on success, FALSE if the type is unsupported +static gboolean +_dtype_to_onnx(dt_ai_dtype_t dtype, ONNXTensorElementDataType *out_type, size_t *out_size) +{ + switch(dtype) + { + case DT_AI_FLOAT: + *out_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; + *out_size = sizeof(float); + return TRUE; + case DT_AI_FLOAT16: + *out_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; + *out_size = sizeof(uint16_t); + return TRUE; + case DT_AI_UINT8: + *out_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8; + *out_size = sizeof(uint8_t); + return TRUE; + case DT_AI_INT8: + *out_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8; + *out_size = sizeof(int8_t); + return TRUE; + case DT_AI_INT32: + *out_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32; + *out_size = sizeof(int32_t); + return TRUE; + case DT_AI_INT64: + *out_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; + *out_size = sizeof(int64_t); + return TRUE; + default: + return FALSE; + } +} + +// float16 conversion utilities +// based on: https://gist.github.com/rygorous/2156668 +// handles zero, denormals, and infinity correctly +static uint16_t _float_to_half(float f) +{ + uint32_t x; + memcpy(&x, &f, sizeof(x)); + uint32_t sign = (x >> 31) & 1; + uint32_t exp = (x >> 23) & 0xFF; + uint32_t mant = x & 0x7FFFFF; + + // handle zero and float32 denormals (too small for float16) + if(exp == 0) + return (uint16_t)(sign << 15); + + // handle infinity / NaN + if(exp == 255) + return (uint16_t)((sign << 15) | 0x7C00 | (mant ? 
1 : 0)); + + // re-bias exponent from float32 (bias 127) to float16 (bias 15) + const int new_exp = (int)exp - 127 + 15; + + if(new_exp <= 0) + { + // encode as float16 denormal: shift mantissa with implicit leading 1 + // the implicit 1 bit plus 10 mantissa bits, shifted right by (1 - new_exp) + const int shift = 1 - new_exp; + if(shift > 24) + return (uint16_t)(sign << 15); // too small even for denormal + const uint32_t full_mant = (1 << 23) | mant; // restore implicit leading 1 + const uint16_t half_mant = (uint16_t)(full_mant >> (13 + shift)); + return (uint16_t)((sign << 15) | half_mant); + } + else if(new_exp >= 31) + { + // overflow to infinity + return (uint16_t)((sign << 15) | 0x7C00); + } + + return (uint16_t)((sign << 15) | (new_exp << 10) | (mant >> 13)); +} + +static float _half_to_float(uint16_t h) +{ + uint32_t sign = (h >> 15) & 1; + uint32_t exp = (h >> 10) & 0x1F; + uint32_t mant = h & 0x3FF; + + if(exp == 0) + { + if(mant == 0) + { + // zero + uint32_t result = (sign << 31); + float f; + memcpy(&f, &result, 4); + return f; + } + // denormal: value = (-1)^sign * 2^(-14) * (mant / 1024) + // convert to float32 by normalizing: find leading 1 and shift + uint32_t m = mant; + int e = -1; + while(!(m & 0x400)) + { // shift until leading 1 reaches bit 10 + m <<= 1; + e--; + } + m &= 0x3FF; // remove the leading 1 + uint32_t new_exp = (uint32_t)(e + 127 - 14 + 1); + uint32_t result = (sign << 31) | (new_exp << 23) | (m << 13); + float f; + memcpy(&f, &result, 4); + return f; + } + else if(exp == 31) + { + // inf / NaN + uint32_t result = (sign << 31) | 0x7F800000 | (mant << 13); + float f; + memcpy(&f, &result, 4); + return f; + } + + // normalized + const uint32_t new_exp = exp + 127 - 15; + const uint32_t result = (sign << 31) | (new_exp << 23) | (mant << 13); + float f; + memcpy(&f, &result, sizeof(f)); + return f; +} + +// try to find and call an ORT execution provider function at runtime via +// dynamic symbol lookup (GModule/dlsym). 
returns TRUE if the provider was +// enabled successfully, FALSE otherwise. +// most providers take (OrtSessionOptions*, uint32_t device_id), but OpenVINO +// takes (OrtSessionOptions*, const char* device_type). pass device_type for +// string-argument providers, NULL for integer-argument ones +static gboolean _try_provider(OrtSessionOptions *session_opts, + const char *symbol_name, + const char *provider_name, + const char *device_type) +{ + OrtStatus *status = NULL; + gboolean ok = FALSE; + + dt_print(DT_DEBUG_AI, "[darktable_ai] attempting to enable %s...", provider_name); + +#ifdef _WIN32 + // on windows, we need to get the handle to onnxruntime.dll, not the main executable + HMODULE h = GetModuleHandleA("onnxruntime.dll"); + if(!h) + { + // if not already loaded, try to load it + h = LoadLibraryA("onnxruntime.dll"); + } + void *func_ptr = NULL; + if(h) + { + func_ptr = (void *)GetProcAddress(h, symbol_name); + // don't call FreeLibrary - we want to keep onnxruntime.dll loaded + } +#else + GModule *mod = g_module_open(NULL, 0); + void *func_ptr = NULL; + if(mod) + g_module_symbol(mod, symbol_name, &func_ptr); +#endif + + if(func_ptr) + { + if(device_type) + { + // string-argument providers (e.g. 
OpenVINO) + typedef OrtStatus *(*ProviderAppenderStr)(OrtSessionOptions *, const char *); + ProviderAppenderStr appender = (ProviderAppenderStr)func_ptr; + status = appender(session_opts, device_type); + } + else + { + // integer-argument providers (CUDA, CoreML, DML, MIGraphX, ROCm) + typedef OrtStatus *(*ProviderAppenderInt)(OrtSessionOptions *, uint32_t); + ProviderAppenderInt appender = (ProviderAppenderInt)func_ptr; + status = appender(session_opts, 0); + } + if(!status) + { + dt_print(DT_DEBUG_AI, "[darktable_ai] %s enabled successfully.", provider_name); + ok = TRUE; + } + else + { + dt_print(DT_DEBUG_AI, + "[darktable_ai] %s enable failed: %s", + provider_name, g_ort->GetErrorMessage(status)); + g_ort->ReleaseStatus(status); + } + } + else + { + dt_print(DT_DEBUG_AI, "[darktable_ai] %s provider not found.", provider_name); + } + +#ifndef _WIN32 + if(mod) + g_module_close(mod); +#endif + + return ok; +} + +static void +_enable_acceleration(OrtSessionOptions *session_opts, dt_ai_provider_t provider) +{ + switch(provider) + { + case DT_AI_PROVIDER_CPU: + // CPU only - don't enable any accelerator + dt_print(DT_DEBUG_AI, "[darktable_ai] using CPU only (no hardware acceleration)"); + break; + + case DT_AI_PROVIDER_COREML: +#if defined(__APPLE__) + _try_provider( + session_opts, + "OrtSessionOptionsAppendExecutionProvider_CoreML", + "Apple CoreML", NULL); +#else + dt_print(DT_DEBUG_AI, "[darktable_ai] apple CoreML not available on this platform"); +#endif + break; + + case DT_AI_PROVIDER_CUDA: + _try_provider(session_opts, "OrtSessionOptionsAppendExecutionProvider_CUDA", "NVIDIA CUDA", NULL); + break; + + case DT_AI_PROVIDER_MIGRAPHX: + // try MIGraphX first; fall back to ROCm for older ORT builds + if(!_try_provider(session_opts, "OrtSessionOptionsAppendExecutionProvider_MIGraphX", "AMD MIGraphX", NULL)) + _try_provider(session_opts, "OrtSessionOptionsAppendExecutionProvider_ROCM", "AMD ROCm (legacy)", NULL); + break; + + case DT_AI_PROVIDER_OPENVINO: + 
_try_provider(session_opts, "OrtSessionOptionsAppendExecutionProvider_OpenVINO", "Intel OpenVINO", "AUTO"); + break; + + case DT_AI_PROVIDER_DIRECTML: +#if defined(_WIN32) + _try_provider( + session_opts, + "OrtSessionOptionsAppendExecutionProvider_DML", + "Windows DirectML", NULL); +#else + dt_print(DT_DEBUG_AI, "[darktable_ai] windows DirectML not available on this platform"); +#endif + break; + + case DT_AI_PROVIDER_AUTO: + default: + // auto-detect best provider based on platform +#if defined(__APPLE__) + _try_provider( + session_opts, + "OrtSessionOptionsAppendExecutionProvider_CoreML", + "Apple CoreML", NULL); +#elif defined(_WIN32) + _try_provider( + session_opts, + "OrtSessionOptionsAppendExecutionProvider_DML", + "Windows DirectML", NULL); +#elif defined(__linux__) + // try CUDA first, then MIGraphX + if(!_try_provider( + session_opts, + "OrtSessionOptionsAppendExecutionProvider_CUDA", + "NVIDIA CUDA", NULL)) + { + if(!_try_provider( + session_opts, + "OrtSessionOptionsAppendExecutionProvider_MIGraphX", + "AMD MIGraphX", NULL)) + _try_provider( + session_opts, + "OrtSessionOptionsAppendExecutionProvider_ROCM", + "AMD ROCm (legacy)", NULL); + } +#endif + break; + } +} + +// provider probe + +int dt_ai_probe_provider(dt_ai_provider_t provider) +{ + // AUTO and CPU are always available + if(provider == DT_AI_PROVIDER_AUTO || provider == DT_AI_PROVIDER_CPU) + return 1; + + // ensure ORT API is initialized + g_once(&g_ort_once, _init_ort_api, NULL); + if(!g_ort) return 0; + + g_once(&g_env_once, _init_ort_env, NULL); + if(!g_env) return 0; + + // create temporary session options for the probe + OrtSessionOptions *opts = NULL; + OrtStatus *status = g_ort->CreateSessionOptions(&opts); + if(status) + { + g_ort->ReleaseStatus(status); + return 0; + } + + gboolean ok = FALSE; + + switch(provider) + { + case DT_AI_PROVIDER_COREML: + ok = _try_provider(opts, "OrtSessionOptionsAppendExecutionProvider_CoreML", "Apple CoreML", NULL); + break; + case DT_AI_PROVIDER_CUDA: 
+ ok = _try_provider(opts, "OrtSessionOptionsAppendExecutionProvider_CUDA", "NVIDIA CUDA", NULL); + break; + case DT_AI_PROVIDER_MIGRAPHX: + ok = _try_provider(opts, "OrtSessionOptionsAppendExecutionProvider_MIGraphX", "AMD MIGraphX", NULL) + || _try_provider(opts, "OrtSessionOptionsAppendExecutionProvider_ROCM", "AMD ROCm (legacy)", NULL); + break; + case DT_AI_PROVIDER_OPENVINO: + ok = _try_provider(opts, "OrtSessionOptionsAppendExecutionProvider_OpenVINO", "Intel OpenVINO", "AUTO"); + break; + case DT_AI_PROVIDER_DIRECTML: + ok = _try_provider(opts, "OrtSessionOptionsAppendExecutionProvider_DML", "Windows DirectML", NULL); + break; + default: + break; + } + + g_ort->ReleaseSessionOptions(opts); + return ok ? 1 : 0; +} + +// ONNX Model Loading + +// load ONNX model from model_dir/model_file with dimension overrides. +// if model_file is NULL, defaults to "model.onnx". +dt_ai_context_t * +dt_ai_onnx_load_ext(const char *model_dir, const char *model_file, + dt_ai_provider_t provider, dt_ai_opt_level_t opt_level, + const dt_ai_dim_override_t *dim_overrides, int n_overrides) +{ + if(!model_dir) + return NULL; + + char *onnx_path + = g_build_filename(model_dir, model_file ? 
model_file : "model.onnx", NULL); + if(!g_file_test(onnx_path, G_FILE_TEST_EXISTS)) + { + dt_print(DT_DEBUG_AI, "[darktable_ai] model file missing: %s", onnx_path); + g_free(onnx_path); + return NULL; + } + + // lazy init ORT API and shared environment on first load + g_once(&g_ort_once, _init_ort_api, NULL); + if(!g_ort) + { + g_free(onnx_path); + return NULL; + } + + g_once(&g_env_once, _init_ort_env, NULL); + if(!g_env) + { + g_free(onnx_path); + return NULL; + } + + dt_print(DT_DEBUG_AI, "[darktable_ai] loading: %s", onnx_path); + + dt_ai_context_t *ctx = g_new0(dt_ai_context_t, 1); + + OrtStatus *status; + OrtSessionOptions *session_opts; + status = g_ort->CreateSessionOptions(&session_opts); + if(status) + { + g_ort->ReleaseStatus(status); + g_free(onnx_path); + dt_ai_unload_model(ctx); + return NULL; + } + + // optimize: use all available cores (intra-op parallelism) +#ifdef _WIN32 + SYSTEM_INFO sysinfo; + GetSystemInfo(&sysinfo); + const long num_cores = MAX(1, sysinfo.dwNumberOfProcessors); +#else + const long num_cores = MAX(1, sysconf(_SC_NPROCESSORS_ONLN)); +#endif + + status = g_ort->SetIntraOpNumThreads(session_opts, (int)num_cores); + if(status) + { + g_ort->ReleaseStatus(status); + g_ort->ReleaseSessionOptions(session_opts); + g_free(onnx_path); + dt_ai_unload_model(ctx); + return NULL; + } + + const GraphOptimizationLevel ort_opt + = (opt_level == DT_AI_OPT_DISABLED) ? ORT_DISABLE_ALL + : (opt_level == DT_AI_OPT_BASIC) ? 
ORT_ENABLE_BASIC + : ORT_ENABLE_ALL; + status = g_ort->SetSessionGraphOptimizationLevel(session_opts, ort_opt); + if(status) + { + g_ort->ReleaseStatus(status); + g_ort->ReleaseSessionOptions(session_opts); + g_free(onnx_path); + dt_ai_unload_model(ctx); + return NULL; + } + + // override symbolic dimensions (fixes shape inference for dynamic-shape models) + for(int i = 0; i < n_overrides; i++) + { + if(!dim_overrides[i].name) continue; + status = g_ort->AddFreeDimensionOverrideByName(session_opts, + dim_overrides[i].name, + dim_overrides[i].value); + if(status) + { + dt_print(DT_DEBUG_AI, "[darktable_ai] dim override '%s' failed: %s", + dim_overrides[i].name, g_ort->GetErrorMessage(status)); + g_ort->ReleaseStatus(status); + } + } + + // optimize: enable hardware acceleration + _enable_acceleration(session_opts, provider); + +#ifdef _WIN32 + // on windows, CreateSession expects a wide character string + wchar_t *onnx_path_wide = (wchar_t *)g_utf8_to_utf16(onnx_path, -1, NULL, NULL, NULL); + status = g_ort->CreateSession(g_env, onnx_path_wide, session_opts, &ctx->session); +#else + status = g_ort->CreateSession(g_env, onnx_path, session_opts, &ctx->session); +#endif + + // if accelerated provider failed, fall back to CPU-only + if(status && provider != DT_AI_PROVIDER_CPU) + { + dt_print(DT_DEBUG_AI, + "[darktable_ai] accelerated session failed: %s — falling back to CPU", + g_ort->GetErrorMessage(status)); + g_ort->ReleaseStatus(status); + g_ort->ReleaseSessionOptions(session_opts); + + status = g_ort->CreateSessionOptions(&session_opts); + if(status) + { + g_ort->ReleaseStatus(status); +#ifdef _WIN32 + g_free(onnx_path_wide); +#endif + g_free(onnx_path); + dt_ai_unload_model(ctx); + return NULL; + } + status = g_ort->SetIntraOpNumThreads(session_opts, (int)num_cores); + if(status) g_ort->ReleaseStatus(status); + status = g_ort->SetSessionGraphOptimizationLevel(session_opts, ort_opt); + if(status) g_ort->ReleaseStatus(status); + for(int i = 0; i < n_overrides; i++) 
+ { + if(!dim_overrides[i].name) continue; + status = g_ort->AddFreeDimensionOverrideByName( + session_opts, dim_overrides[i].name, dim_overrides[i].value); + if(status) g_ort->ReleaseStatus(status); + } + // CPU-only: no _enable_acceleration call +#ifdef _WIN32 + status = g_ort->CreateSession(g_env, onnx_path_wide, session_opts, &ctx->session); +#else + status = g_ort->CreateSession(g_env, onnx_path, session_opts, &ctx->session); +#endif + } + +#ifdef _WIN32 + g_free(onnx_path_wide); +#endif + g_ort->ReleaseSessionOptions(session_opts); + g_free(onnx_path); + + if(status) + { + dt_print(DT_DEBUG_AI, + "[darktable_ai] failed to create session: %s", + g_ort->GetErrorMessage(status)); + g_ort->ReleaseStatus(status); + dt_ai_unload_model(ctx); + return NULL; + } + + status + = g_ort->CreateCpuMemoryInfo(OrtArenaAllocator, OrtMemTypeDefault, &ctx->memory_info); + if(status) + { + g_ort->ReleaseStatus(status); + dt_ai_unload_model(ctx); + return NULL; + } + + // resolve IO names + status = g_ort->GetAllocatorWithDefaultOptions(&ctx->allocator); + if(status) + { + g_ort->ReleaseStatus(status); + dt_ai_unload_model(ctx); + return NULL; + } + + status = g_ort->SessionGetInputCount(ctx->session, &ctx->input_count); + if(status) + { + g_ort->ReleaseStatus(status); + dt_ai_unload_model(ctx); + return NULL; + } + + status = g_ort->SessionGetOutputCount(ctx->session, &ctx->output_count); + if(status) + { + g_ort->ReleaseStatus(status); + dt_ai_unload_model(ctx); + return NULL; + } + + ctx->input_names = g_new0(char *, ctx->input_count); + ctx->input_types = g_new0(dt_ai_dtype_t, ctx->input_count); + for(size_t i = 0; i < ctx->input_count; i++) + { + status + = g_ort->SessionGetInputName(ctx->session, i, ctx->allocator, &ctx->input_names[i]); + if(status) + { + g_ort->ReleaseStatus(status); + dt_ai_unload_model(ctx); + return NULL; + } + + // get input type + OrtTypeInfo *typeinfo = NULL; + status = g_ort->SessionGetInputTypeInfo(ctx->session, i, &typeinfo); + if(status) + { + 
g_ort->ReleaseStatus(status); + dt_ai_unload_model(ctx); + return NULL; + } + const OrtTensorTypeAndShapeInfo *tensor_info = NULL; + status = g_ort->CastTypeInfoToTensorInfo(typeinfo, &tensor_info); + if(status) + { + g_ort->ReleaseStatus(status); + g_ort->ReleaseTypeInfo(typeinfo); + dt_ai_unload_model(ctx); + return NULL; + } + ONNXTensorElementDataType type; + status = g_ort->GetTensorElementType(tensor_info, &type); + if(status) + { + g_ort->ReleaseStatus(status); + g_ort->ReleaseTypeInfo(typeinfo); + dt_ai_unload_model(ctx); + return NULL; + } + + if(!_map_onnx_type(type, &ctx->input_types[i])) + { + dt_print(DT_DEBUG_AI, + "[darktable_ai] unsupported ONNX input type %d for input %zu", + type, i); + g_ort->ReleaseTypeInfo(typeinfo); + dt_ai_unload_model(ctx); + return NULL; + } + + g_ort->ReleaseTypeInfo(typeinfo); + } + + ctx->output_names = g_new0(char *, ctx->output_count); + ctx->output_types = g_new0(dt_ai_dtype_t, ctx->output_count); + for(size_t i = 0; i < ctx->output_count; i++) + { + status = g_ort->SessionGetOutputName( + ctx->session, + i, + ctx->allocator, + &ctx->output_names[i]); + if(status) + { + g_ort->ReleaseStatus(status); + dt_ai_unload_model(ctx); + return NULL; + } + + // get output type + OrtTypeInfo *typeinfo = NULL; + status = g_ort->SessionGetOutputTypeInfo(ctx->session, i, &typeinfo); + if(status) + { + g_ort->ReleaseStatus(status); + dt_ai_unload_model(ctx); + return NULL; + } + const OrtTensorTypeAndShapeInfo *tensor_info = NULL; + status = g_ort->CastTypeInfoToTensorInfo(typeinfo, &tensor_info); + if(status) + { + g_ort->ReleaseStatus(status); + g_ort->ReleaseTypeInfo(typeinfo); + dt_ai_unload_model(ctx); + return NULL; + } + ONNXTensorElementDataType type; + status = g_ort->GetTensorElementType(tensor_info, &type); + if(status) + { + g_ort->ReleaseStatus(status); + g_ort->ReleaseTypeInfo(typeinfo); + dt_ai_unload_model(ctx); + return NULL; + } + + if(!_map_onnx_type(type, &ctx->output_types[i])) + { + dt_print(DT_DEBUG_AI, + 
"[darktable_ai] unsupported ONNX output type %d for output %zu", + type, i); + g_ort->ReleaseTypeInfo(typeinfo); + dt_ai_unload_model(ctx); + return NULL; + } + + g_ort->ReleaseTypeInfo(typeinfo); + } + + // detect dynamic output shapes (any dim <= 0 means symbolic/unknown). + // when detected, dt_ai_run() will let ORT allocate outputs during + // execution and copy the results back to the caller's buffer + ctx->dynamic_outputs = FALSE; + for(size_t i = 0; i < ctx->output_count; i++) + { + int64_t shape[16]; + int ndim = dt_ai_get_output_shape(ctx, (int)i, shape, 16); + if(ndim > 0) + { + for(int d = 0; d < ndim; d++) + { + if(shape[d] <= 0) + { + ctx->dynamic_outputs = TRUE; + dt_print(DT_DEBUG_AI, + "[darktable_ai] output[%zu] has dynamic dims — using ORT-allocated outputs", + i); + break; + } + } + } + if(ctx->dynamic_outputs) break; + } + + return ctx; +} + +int dt_ai_run( + dt_ai_context_t *ctx, + dt_ai_tensor_t *inputs, + int num_inputs, + dt_ai_tensor_t *outputs, + int num_outputs) +{ + if(!ctx || !ctx->session) + return -1; + if(num_inputs != ctx->input_count || num_outputs != ctx->output_count) + { + dt_print(DT_DEBUG_AI, + "[darktable_ai] IO count mismatch: expected %zu/%zu, got %d/%d", + ctx->input_count, ctx->output_count, num_inputs, num_outputs); + return -2; + } + + // run + OrtStatus *status = NULL; + int ret = 0; + + // track temporary buffers to free later + void **temp_input_buffers = g_new0(void *, num_inputs); + + // create input tensors + OrtValue **input_tensors = g_new0(OrtValue *, num_inputs); + OrtValue **output_tensors = g_new0(OrtValue *, num_outputs); + const char **input_names = (const char **)ctx->input_names; // cast for Run() + + for(int i = 0; i < num_inputs; i++) + { + const int64_t element_count = _safe_element_count(inputs[i].shape, inputs[i].ndim); + if(element_count < 0) + { + dt_print(DT_DEBUG_AI, + "[darktable_ai] invalid or overflowing shape for input[%d]", i); + ret = -4; + goto cleanup; + } + + ONNXTensorElementDataType 
onnx_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; + size_t type_size = sizeof(float); + void *data_ptr = inputs[i].data; + + // check for type mismatch (float -> half) + if(inputs[i].type == DT_AI_FLOAT && ctx->input_types[i] == DT_AI_FLOAT16) + { + dt_print(DT_DEBUG_AI, + "[darktable_ai] auto-converting input[%d] float32 -> float16", i); + // auto-convert float32 -> float16 + onnx_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; + type_size = sizeof(uint16_t); // half is 2 bytes + + if((size_t)element_count > SIZE_MAX / type_size) + { + dt_print(DT_DEBUG_AI, "[darktable_ai] tensor size overflow for input[%d]", i); + ret = -4; + goto cleanup; + } + uint16_t *half_data = g_malloc(element_count * type_size); + const float *src = (const float *)inputs[i].data; + for(int64_t k = 0; k < element_count; k++) + { + half_data[k] = _float_to_half(src[k]); + } + + data_ptr = half_data; + temp_input_buffers[i] = half_data; + } + else + { + if(!_dtype_to_onnx(inputs[i].type, &onnx_type, &type_size)) + { + dt_print(DT_DEBUG_AI, + "[darktable_ai] unsupported input type %d for input[%d]", + inputs[i].type, i); + ret = -4; + goto cleanup; + } + } + + if((size_t)element_count > SIZE_MAX / type_size) + { + dt_print(DT_DEBUG_AI, "[darktable_ai] tensor size overflow for input[%d]", i); + ret = -4; + goto cleanup; + } + status = g_ort->CreateTensorWithDataAsOrtValue( + ctx->memory_info, + data_ptr, + element_count * type_size, + inputs[i].shape, + inputs[i].ndim, + onnx_type, + &input_tensors[i]); + + if(status) + { + dt_print(DT_DEBUG_AI, + "[darktable_ai] CreateTensor input[%d] fail: %s", + i, g_ort->GetErrorMessage(status)); + g_ort->ReleaseStatus(status); + ret = -4; + goto cleanup; + } + } + + // create output tensors + const char **output_names = (const char **)ctx->output_names; + + for(int i = 0; i < num_outputs; i++) + { + // dynamic outputs or float16 mismatch: let ORT allocate during Run() + if(ctx->dynamic_outputs + || (outputs[i].type == DT_AI_FLOAT && ctx->output_types[i] == 
             DT_AI_FLOAT16))
    {
      // output will be allocated by ORT itself (dynamic shape or f16 auto-convert);
      // leave the slot NULL so Run() lets the runtime allocate it
      output_tensors[i] = NULL;
      continue;
    }

    const int64_t element_count = _safe_element_count(outputs[i].shape, outputs[i].ndim);
    if(element_count < 0)
    {
      dt_print(DT_DEBUG_AI,
               "[darktable_ai] invalid or overflowing shape for output[%d]", i);
      ret = -4;
      goto cleanup;
    }

    ONNXTensorElementDataType onnx_type;
    size_t type_size;

    if(!_dtype_to_onnx(outputs[i].type, &onnx_type, &type_size))
    {
      dt_print(DT_DEBUG_AI,
               "[darktable_ai] unsupported output type %d for output[%d]",
               outputs[i].type, i);
      ret = -4;
      goto cleanup;
    }

    // guard the byte-size multiplication below against size_t overflow
    if((size_t)element_count > SIZE_MAX / type_size)
    {
      dt_print(DT_DEBUG_AI, "[darktable_ai] tensor size overflow for output[%d]", i);
      ret = -4;
      goto cleanup;
    }
    // wrap the caller-provided buffer; ORT does not take ownership of the data
    status = g_ort->CreateTensorWithDataAsOrtValue(
        ctx->memory_info,
        outputs[i].data,
        element_count * type_size,
        outputs[i].shape,
        outputs[i].ndim,
        onnx_type,
        &output_tensors[i]);

    if(status)
    {
      dt_print(DT_DEBUG_AI,
               "[darktable_ai] CreateTensor output[%d] fail: %s",
               i, g_ort->GetErrorMessage(status));
      g_ort->ReleaseStatus(status);
      ret = -4;
      goto cleanup;
    }
  }

  // run
  status = g_ort->Run(ctx->session,
                      NULL,
                      input_names,
                      (const OrtValue *const *)input_tensors,
                      num_inputs,
                      output_names,
                      num_outputs,
                      output_tensors);

  if(status)
  {
    dt_print(DT_DEBUG_AI, "[darktable_ai] run error: %s", g_ort->GetErrorMessage(status));
    g_ort->ReleaseStatus(status);
    ret = -3;
  }
  else
  {
    // post-run: copy data from ORT-allocated outputs to caller's buffers.
    // this handles both dynamic-shape models (where we can't pre-allocate
    // because ORT's shape inference disagrees with the actual output shape)
    // and Float16→Float auto-conversion
    for(int i = 0; i < num_outputs; i++)
    {
      const gboolean ort_allocated = ctx->dynamic_outputs
        || (outputs[i].type == DT_AI_FLOAT && ctx->output_types[i] == DT_AI_FLOAT16);
      if(!ort_allocated || !output_tensors[i]) continue;

      void *raw_data = NULL;
      status = g_ort->GetTensorMutableData(output_tensors[i], &raw_data);
      if(status)
      {
        dt_print(DT_DEBUG_AI,
                 "[darktable_ai] GetTensorMutableData output[%d] failed: %s",
                 i, g_ort->GetErrorMessage(status));
        g_ort->ReleaseStatus(status);
        // best-effort: skip this output, still try the remaining ones
        continue;
      }

      // query ORT's actual tensor size to avoid reading past its allocation.
      // the caller's expected shape may differ from what ORT produced
      // (e.g., dynamic-shape models)
      OrtTensorTypeAndShapeInfo *tensor_info = NULL;
      status = g_ort->GetTensorTypeAndShape(output_tensors[i], &tensor_info);
      if(status)
      {
        dt_print(DT_DEBUG_AI, "[darktable_ai] GetTensorTypeAndShape output[%d] failed: %s",
                 i, g_ort->GetErrorMessage(status));
        g_ort->ReleaseStatus(status);
        continue;
      }
      // update caller's shape array with actual ORT output dimensions.
      // this is essential for dynamic-shape models where the caller's
      // pre-assumed shape may differ from what ORT actually produced.
      // only done when ORT's rank fits into the caller's shape array.
      size_t actual_ndim = 0;
      OrtStatus *dim_st = g_ort->GetDimensionsCount(tensor_info, &actual_ndim);
      if(!dim_st && actual_ndim > 0 && (int)actual_ndim <= outputs[i].ndim)
      {
        OrtStatus *get_st = g_ort->GetDimensions(tensor_info, outputs[i].shape, actual_ndim);
        if(!get_st)
          outputs[i].ndim = (int)actual_ndim;
        else
          g_ort->ReleaseStatus(get_st);
      }
      if(dim_st) g_ort->ReleaseStatus(dim_st);

      size_t ort_element_count = 0;
      status = g_ort->GetTensorShapeElementCount(tensor_info, &ort_element_count);
      g_ort->ReleaseTensorTypeAndShapeInfo(tensor_info);
      if(status)
      {
        dt_print(DT_DEBUG_AI, "[darktable_ai] GetTensorShapeElementCount output[%d] failed: %s",
                 i, g_ort->GetErrorMessage(status));
        g_ort->ReleaseStatus(status);
        continue;
      }

      const int64_t caller_count
        = _safe_element_count(outputs[i].shape, outputs[i].ndim);
      if(caller_count < 0)
      {
        dt_print(DT_DEBUG_AI,
                 "[darktable_ai] invalid shape for output[%d] post-copy", i);
        continue;
      }

      // use the smaller of ORT's actual size and caller's expected size
      // so we never read past ORT's allocation nor write past the caller's
      const int64_t element_count = ((int64_t)ort_element_count < caller_count)
        ?
          (int64_t)ort_element_count
        : caller_count;

      if(element_count != caller_count)
      {
        dt_print(DT_DEBUG_AI,
                 "[darktable_ai] output[%d] shape mismatch: ORT has %zu elements, "
                 "caller expects %" PRId64,
                 i, ort_element_count, caller_count);
      }

      if(ctx->output_types[i] == DT_AI_FLOAT16 && outputs[i].type == DT_AI_FLOAT)
      {
        // float16 → float conversion
        uint16_t *half_data = (uint16_t *)raw_data;
        float *dst = (float *)outputs[i].data;
        for(int64_t k = 0; k < element_count; k++)
          dst[k] = _half_to_float(half_data[k]);
      }
      else
      {
        // same-type copy from ORT allocation to caller's buffer
        ONNXTensorElementDataType onnx_type;
        size_t type_size;
        if(!_dtype_to_onnx(outputs[i].type, &onnx_type, &type_size))
        {
          dt_print(DT_DEBUG_AI,
                   "[darktable_ai] unknown dtype %d for output[%d] post-copy",
                   outputs[i].type, i);
          continue;
        }
        memcpy(outputs[i].data, raw_data, element_count * type_size);
      }
    }
  }

cleanup:
  // cleanup OrtValues (wrappers only, data is owned by caller)
  for(int i = 0; i < num_inputs; i++)
    if(input_tensors[i])
      g_ort->ReleaseValue(input_tensors[i]);
  for(int i = 0; i < num_outputs; i++)
    if(output_tensors[i])
      g_ort->ReleaseValue(output_tensors[i]);

  // free temp input buffers
  for(int i = 0; i < num_inputs; i++)
  {
    if(temp_input_buffers[i])
      g_free(temp_input_buffers[i]);
  }
  g_free(temp_input_buffers);

  g_free(input_tensors);
  g_free(output_tensors);

  return ret;
}

// number of model inputs, 0 for a NULL context
int dt_ai_get_input_count(dt_ai_context_t *ctx)
{
  return ctx ? (int)ctx->input_count : 0;
}

// number of model outputs, 0 for a NULL context
int dt_ai_get_output_count(dt_ai_context_t *ctx)
{
  return ctx ?
           (int)ctx->output_count : 0;
}

// tensor name of input at `index`, or NULL if ctx is NULL / index out of range
const char *dt_ai_get_input_name(dt_ai_context_t *ctx, int index)
{
  if(!ctx || index < 0 || (size_t)index >= ctx->input_count)
    return NULL;
  return ctx->input_names[index];
}

// element type of input at `index`; falls back to DT_AI_FLOAT when out of range
dt_ai_dtype_t dt_ai_get_input_type(dt_ai_context_t *ctx, int index)
{
  if(!ctx || index < 0 || (size_t)index >= ctx->input_count)
    return DT_AI_FLOAT;
  return ctx->input_types[index];
}

// tensor name of output at `index`, or NULL if ctx is NULL / index out of range
const char *dt_ai_get_output_name(dt_ai_context_t *ctx, int index)
{
  if(!ctx || index < 0 || (size_t)index >= ctx->output_count)
    return NULL;
  return ctx->output_names[index];
}

// element type of output at `index`; falls back to DT_AI_FLOAT when out of range
dt_ai_dtype_t dt_ai_get_output_type(dt_ai_context_t *ctx, int index)
{
  if(!ctx || index < 0 || (size_t)index >= ctx->output_count)
    return DT_AI_FLOAT;
  return ctx->output_types[index];
}

// query the static (graph-declared) shape of output `index`.
// copies up to max_dims dimensions into `shape` and returns the full rank
// (which may exceed max_dims), or -1 on any error or rank > 16.
// note: dynamic dimensions are reported as declared by the model graph.
int dt_ai_get_output_shape(dt_ai_context_t *ctx, int index,
                           int64_t *shape, int max_dims)
{
  if(!ctx || !ctx->session || index < 0 || (size_t)index >= ctx->output_count
     || !shape || max_dims <= 0)
    return -1;

  OrtTypeInfo *typeinfo = NULL;
  OrtStatus *status = g_ort->SessionGetOutputTypeInfo(ctx->session, index, &typeinfo);
  if(status)
  {
    g_ort->ReleaseStatus(status);
    return -1;
  }

  // tensor_info is a view into typeinfo — only typeinfo must be released
  const OrtTensorTypeAndShapeInfo *tensor_info = NULL;
  status = g_ort->CastTypeInfoToTensorInfo(typeinfo, &tensor_info);
  if(status)
  {
    g_ort->ReleaseStatus(status);
    g_ort->ReleaseTypeInfo(typeinfo);
    return -1;
  }

  size_t ndim = 0;
  status = g_ort->GetDimensionsCount(tensor_info, &ndim);
  if(status)
  {
    g_ort->ReleaseStatus(status);
    g_ort->ReleaseTypeInfo(typeinfo);
    return -1;
  }

  const int dims = (int)ndim < max_dims ?
                   (int)ndim : max_dims;
  // fetch the whole shape into a fixed stack buffer, then copy the prefix
  int64_t full_shape[16];
  if(ndim > 16)
  {
    g_ort->ReleaseTypeInfo(typeinfo);
    return -1;
  }

  status = g_ort->GetDimensions(tensor_info, full_shape, ndim);
  g_ort->ReleaseTypeInfo(typeinfo);
  if(status)
  {
    g_ort->ReleaseStatus(status);
    return -1;
  }

  memcpy(shape, full_shape, dims * sizeof(int64_t));
  return (int)ndim;
}

// release an inference context: ORT session, memory info,
// allocator-owned IO name strings, and our own arrays
void dt_ai_unload_model(dt_ai_context_t *ctx)
{
  if(ctx)
  {
    if(ctx->session)
      g_ort->ReleaseSession(ctx->session);
    // note: OrtEnv is a shared singleton (g_env), not per-context
    if(ctx->memory_info)
      g_ort->ReleaseMemoryInfo(ctx->memory_info);

    // release IO names using the allocator that created them
    if(ctx->allocator)
    {
      for(size_t i = 0; i < ctx->input_count; i++)
      {
        if(ctx->input_names[i])
          ctx->allocator->Free(ctx->allocator, ctx->input_names[i]);
      }
      for(size_t i = 0; i < ctx->output_count; i++)
      {
        if(ctx->output_names[i])
          ctx->allocator->Free(ctx->allocator, ctx->output_names[i]);
      }
    }

    g_free(ctx->input_names);
    g_free(ctx->output_names);
    g_free(ctx->input_types);
    g_free(ctx->output_types);
    g_free(ctx);
  }
}

// clang-format off
// modelines: These editor modelines have been set for all relevant files by tools/update_modelines.py
// vim: shiftwidth=2 expandtab tabstop=2 cindent
// kate: tab-indents: off; indent-width 2; replace-tabs on; indent-mode cstyle; remove-trailing-spaces modified;
// clang-format on
diff --git a/src/common/ai_models.c b/src/common/ai_models.c
new file mode 100644
index 000000000000..67dc066210b8
--- /dev/null
+++ b/src/common/ai_models.c
@@ -0,0 +1,1567 @@
/*
    This file is part of darktable,
    Copyright (C) 2026 darktable developers.

    darktable is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.
+ + darktable is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with darktable. If not, see . +*/ + +#include "common/ai_models.h" +#include "common/darktable.h" +#include "common/curl_tools.h" +#include "common/file_location.h" +#include "control/control.h" +#include "control/jobs.h" + +#include +#include +#include +#include +#include +#include + +#ifdef _WIN32 +// windows doesn't have realpath, use _fullpath instead +static inline char *realpath(const char *path, char *resolved_path) +{ + (void)resolved_path; // ignored, always allocate + return _fullpath(NULL, path, _MAX_PATH); +} +#endif + +// config keys +#define CONF_AI_ENABLED "plugins/ai/enabled" +#define CONF_AI_REPOSITORY "plugins/ai/repository" +#define CONF_MODEL_ENABLED_PREFIX "plugins/ai/models/" +#define CONF_ACTIVE_MODEL_PREFIX "plugins/ai/models/active/" + +static void _model_free(dt_ai_model_t *model) +{ + if(!model) + return; + g_free(model->id); + g_free(model->name); + g_free(model->description); + g_free(model->task); + g_free(model->github_asset); + g_free(model->checksum); + g_free(model); +} + +static dt_ai_model_t *_model_copy(const dt_ai_model_t *src) +{ + if(!src) + return NULL; + dt_ai_model_t *copy = g_new0(dt_ai_model_t, 1); + copy->id = g_strdup(src->id); + copy->name = g_strdup(src->name); + copy->description = g_strdup(src->description); + copy->task = g_strdup(src->task); + copy->github_asset = g_strdup(src->github_asset); + copy->checksum = g_strdup(src->checksum); + copy->is_default = src->is_default; + copy->enabled = src->enabled; + copy->status = src->status; + copy->download_progress = src->download_progress; + return copy; +} + +void dt_ai_model_free(dt_ai_model_t *model) { _model_free(model); } + +static dt_ai_model_t 
*_model_new(void) +{ + dt_ai_model_t *model = g_new0(dt_ai_model_t, 1); + model->enabled = TRUE; + model->status = DT_AI_MODEL_NOT_DOWNLOADED; + return model; +} + +static gboolean _ensure_directory(const char *path) +{ + if(g_file_test(path, G_FILE_TEST_IS_DIR)) + return TRUE; + return g_mkdir_with_parents(path, 0755) == 0; +} + +#ifdef HAVE_AI_DOWNLOAD +// curl write callback that appends to a GString (capped at 1 MB) +static size_t _curl_write_string(void *ptr, size_t size, size_t nmemb, void *userdata) +{ + GString *buf = (GString *)userdata; + const size_t bytes = size * nmemb; + if(buf->len + bytes > 1024 * 1024) return 0; // abort transfer + g_string_append_len(buf, (const char *)ptr, bytes); + return bytes; +} + +/** + * @brief extract "major.minor.patch" from darktable_package_version. + * + * darktable_package_version looks like "5.5.0+156~gabcdef-dirty" or "5.4.0". + * we extract the leading "X.Y.Z" portion. + * + * @return newly allocated string "X.Y.Z", or NULL on parse failure. + */ +static char *_get_darktable_version_prefix(void) +{ + int major = 0, minor = 0, patch = 0; + if(sscanf(darktable_package_version, "%d.%d.%d", &major, &minor, &patch) == 3) + return g_strdup_printf("%d.%d.%d", major, minor, patch); + return NULL; +} + +/** + * @brief query the github api to find the latest model release compatible + * with the current darktable version. + * + * looks for releases tagged "vX.Y.Z" or "vX.Y.Z.N" where X.Y.Z matches + * the darktable version. returns the tag with the highest revision number. + * + * @param repository github "owner/repo" string + * @return newly allocated tag string (e.g. "v5.5.0.2"), or NULL if none found. 
+ */ +static char *_find_latest_compatible_release(const char *repository, char **error_msg) +{ + if(error_msg) *error_msg = NULL; + + char *dt_version = _get_darktable_version_prefix(); + if(!dt_version) + return NULL; + + char *api_url = g_strdup_printf( + "https://api.github.com/repos/%s/releases?per_page=100", + repository); + + CURL *curl = curl_easy_init(); + if(!curl) + { + g_free(api_url); + g_free(dt_version); + return NULL; + } + dt_curl_init(curl, FALSE); + + GString *response = g_string_new(NULL); + curl_easy_setopt(curl, CURLOPT_URL, api_url); + curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, _curl_write_string); + curl_easy_setopt(curl, CURLOPT_WRITEDATA, response); + curl_easy_setopt(curl, CURLOPT_TIMEOUT, 30L); + + struct curl_slist *headers = NULL; + headers = curl_slist_append(headers, "Accept: application/vnd.github+json"); + curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers); + + CURLcode res = curl_easy_perform(curl); + long http_code = 0; + curl_easy_getinfo(curl, CURLINFO_RESPONSE_CODE, &http_code); + curl_slist_free_all(headers); + curl_easy_cleanup(curl); + g_free(api_url); + + if(res != CURLE_OK || http_code != 200) + { + dt_print(DT_DEBUG_AI, + "[ai_models] github api request failed: curl=%d, http=%ld", + res, http_code); + if(error_msg) + { + if(res != CURLE_OK) + *error_msg = g_strdup_printf(_("network error: %s"), curl_easy_strerror(res)); + else if(http_code == 404) + *error_msg = g_strdup_printf(_("model repository \"%s\" not found"), repository); + else if(http_code == 403) + *error_msg = g_strdup(_("GitHub API rate limit exceeded, try again later")); + else + *error_msg = g_strdup_printf(_("GitHub API error (HTTP %ld)"), http_code); + } + g_string_free(response, TRUE); + g_free(dt_version); + return NULL; + } + + // parse json array of releases + JsonParser *parser = json_parser_new(); + if(!json_parser_load_from_data(parser, response->str, response->len, NULL)) + { + g_object_unref(parser); + g_string_free(response, TRUE); + 
g_free(dt_version); + return NULL; + } + g_string_free(response, TRUE); + + JsonNode *root = json_parser_get_root(parser); + if(!root || !JSON_NODE_HOLDS_ARRAY(root)) + { + g_object_unref(parser); + g_free(dt_version); + return NULL; + } + + // build prefix to match: accept both "vX.Y.Z" and "X.Y.Z" tag formats + size_t ver_len = strlen(dt_version); + + char *best_tag = NULL; + int best_revision = -1; // -1 means no revision suffix (e.g., "5.5.0" itself) + + JsonArray *releases = json_node_get_array(root); + guint len = json_array_get_length(releases); + for(guint i = 0; i < len; i++) + { + JsonNode *node = json_array_get_element(releases, i); + if(!JSON_NODE_HOLDS_OBJECT(node)) + continue; + JsonObject *rel = json_node_get_object(node); + + if(!json_object_has_member(rel, "tag_name")) + continue; + const char *tag = json_object_get_string_member(rel, "tag_name"); + if(!tag) + continue; + + // skip any non-digit prefix (e.g. "v", "release-") to extract X.Y.Z.W + const char *ver_part = tag; + while(*ver_part && !g_ascii_isdigit(*ver_part)) + ver_part++; + if(!*ver_part) + continue; + + if(strncmp(ver_part, dt_version, ver_len) != 0) + continue; + + // tag matches version prefix. check what follows: + // "X.Y.Z" (exact) -> revision = 0 + // "X.Y.Z.N" -> revision = N + const char *suffix = ver_part + ver_len; + int revision = 0; + if(suffix[0] == '\0') + { + revision = 0; + } + else if(suffix[0] == '.' 
&& suffix[1] >= '0' && suffix[1] <= '9') + { + revision = atoi(suffix + 1); + } + else + { + continue; // doesn't match pattern + } + + if(revision > best_revision) + { + best_revision = revision; + g_free(best_tag); + best_tag = g_strdup(tag); + } + } + + g_free(dt_version); + g_object_unref(parser); + + if(best_tag) + dt_print(DT_DEBUG_AI, + "[ai_models] found compatible release: %s", + best_tag); + else + dt_print(DT_DEBUG_AI, + "[ai_models] no compatible release found for darktable %s", + darktable_package_version); + + return best_tag; +} + +/** + * @brief fetch the SHA256 digest for a release asset from the GitHub API + * + * queries /repos/{repo}/releases/tags/{tag}, iterates the assets array, + * and returns the "digest" field for the asset whose "name" matches + * + * @param repository github "owner/repo" string + * @param release_tag release tag (e.g. "5.5.0.1") + * @param asset_name asset filename (e.g. "denoise-nafnet.zip") + * @return newly allocated string "SHA256:...", or NULL if not found + */ +static char *_fetch_asset_digest( + const char *repository, + const char *release_tag, + const char *asset_name) +{ + char *api_url = g_strdup_printf( + "https://api.github.com/repos/%s/releases/tags/%s", + repository, + release_tag + ); + + CURL *curl = curl_easy_init(); + if(!curl) + { + g_free(api_url); + return NULL; + } + dt_curl_init(curl, FALSE); + + GString *response = g_string_new(NULL); + curl_easy_setopt(curl, CURLOPT_URL, api_url); + curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, _curl_write_string); + curl_easy_setopt(curl, CURLOPT_WRITEDATA, response); + curl_easy_setopt(curl, CURLOPT_TIMEOUT, 30L); + + struct curl_slist *headers = NULL; + headers = curl_slist_append(headers, "Accept: application/vnd.github+json"); + curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers); + + CURLcode res = curl_easy_perform(curl); + long http_code = 0; + curl_easy_getinfo(curl, CURLINFO_RESPONSE_CODE, &http_code); + curl_slist_free_all(headers); + 
curl_easy_cleanup(curl); + g_free(api_url); + + if(res != CURLE_OK || http_code != 200) + { + dt_print(DT_DEBUG_AI, + "[ai_models] failed to fetch release metadata: curl=%d, http=%ld", + res, http_code); + g_string_free(response, TRUE); + return NULL; + } + + // parse the release json object + JsonParser *parser = json_parser_new(); + if(!json_parser_load_from_data(parser, response->str, response->len, NULL)) + { + g_object_unref(parser); + g_string_free(response, TRUE); + return NULL; + } + g_string_free(response, TRUE); + + JsonNode *root = json_parser_get_root(parser); + if(!root || !JSON_NODE_HOLDS_OBJECT(root)) + { + g_object_unref(parser); + return NULL; + } + + JsonObject *release = json_node_get_object(root); + if(!json_object_has_member(release, "assets")) + { + g_object_unref(parser); + return NULL; + } + + char *digest = NULL; + JsonArray *assets = json_object_get_array_member(release, "assets"); + guint len = json_array_get_length(assets); + for(guint i = 0; i < len; i++) + { + JsonNode *node = json_array_get_element(assets, i); + if(!JSON_NODE_HOLDS_OBJECT(node)) + continue; + JsonObject *asset_obj = json_node_get_object(node); + + if(!json_object_has_member(asset_obj, "name")) + continue; + const char *name = json_object_get_string_member(asset_obj, "name"); + if(g_strcmp0(name, asset_name) != 0) + continue; + + if(json_object_has_member(asset_obj, "digest")) + { + const char *d = json_object_get_string_member(asset_obj, "digest"); + if(d && g_str_has_prefix(d, "sha256:")) + { + digest = g_strdup(d); + dt_print(DT_DEBUG_AI, "[ai_models] asset %s digest: %s", asset_name, digest); + } + } + break; + } + + g_object_unref(parser); + + if(!digest) + dt_print(DT_DEBUG_AI, + "[ai_models] no digest found for asset %s in release %s", + asset_name, release_tag); + + return digest; +} +#endif // HAVE_AI_DOWNLOAD + +// core API + +dt_ai_registry_t *dt_ai_models_init(void) +{ + dt_ai_registry_t *registry = g_new0(dt_ai_registry_t, 1); + g_mutex_init(®istry->lock); 
+ + // set up directories — models use user data dir so they are shared + // across --configdir instances (not tied to a specific config) + char cachedir[PATH_MAX] = {0}; + dt_loc_get_user_cache_dir(cachedir, sizeof(cachedir)); + + registry->models_dir = g_build_filename(g_get_user_data_dir(), "darktable", "models", NULL); + registry->cache_dir = g_build_filename(cachedir, "ai_downloads", NULL); + + // ensure directories exist + _ensure_directory(registry->models_dir); + _ensure_directory(registry->cache_dir); + + // load settings from config + registry->ai_enabled = dt_conf_get_bool(CONF_AI_ENABLED); + + char *provider_str = dt_conf_get_string(DT_AI_CONF_PROVIDER); + registry->provider = dt_ai_provider_from_string(provider_str); + g_free(provider_str); + + dt_print(DT_DEBUG_AI, + "[ai_models] initialized: models_dir=%s, cache_dir=%s", + registry->models_dir, registry->cache_dir); + + return registry; +} + +static dt_ai_model_t *_parse_model_json(JsonObject *obj) +{ + if(!json_object_has_member(obj, "id") || !json_object_has_member(obj, "name")) + return NULL; + + dt_ai_model_t *model = _model_new(); + model->id = g_strdup(json_object_get_string_member(obj, "id")); + model->name = g_strdup(json_object_get_string_member(obj, "name")); + + if(json_object_has_member(obj, "description")) + model->description = g_strdup(json_object_get_string_member(obj, "description")); + if(json_object_has_member(obj, "task")) + model->task = g_strdup(json_object_get_string_member(obj, "task")); + if(json_object_has_member(obj, "github_asset")) + model->github_asset = g_strdup(json_object_get_string_member(obj, "github_asset")); + if(json_object_has_member(obj, "checksum")) + model->checksum = g_strdup(json_object_get_string_member(obj, "checksum")); + if(json_object_has_member(obj, "default")) + model->is_default = json_object_get_boolean_member(obj, "default"); + + return model; +} + +gboolean dt_ai_models_load_registry(dt_ai_registry_t *registry) +{ + if(!registry) + return FALSE; 
+ + // find the registry json file in the data directory + char datadir[PATH_MAX] = {0}; + dt_loc_get_datadir(datadir, sizeof(datadir)); + char *registry_path = g_build_filename(datadir, "ai_models.json", NULL); + + if(!g_file_test(registry_path, G_FILE_TEST_EXISTS)) + { + dt_print(DT_DEBUG_AI, "[ai_models] registry file not found: %s", registry_path); + g_free(registry_path); + return FALSE; + } + + GError *error = NULL; + JsonParser *parser = json_parser_new(); + + if(!json_parser_load_from_file(parser, registry_path, &error)) + { + dt_print(DT_DEBUG_AI, + "[ai_models] failed to parse registry: %s", + error ? error->message : "unknown error"); + if(error) + g_error_free(error); + g_object_unref(parser); + g_free(registry_path); + return FALSE; + } + + JsonNode *root = json_parser_get_root(parser); + if(!JSON_NODE_HOLDS_OBJECT(root)) + { + dt_print(DT_DEBUG_AI, "[ai_models] registry root is not an object"); + g_object_unref(parser); + g_free(registry_path); + return FALSE; + } + + JsonObject *root_obj = json_node_get_object(root); + + g_mutex_lock(®istry->lock); + + // clear existing models + g_list_free_full(registry->models, (GDestroyNotify)_model_free); + registry->models = NULL; + + // parse repository - config overrides json default + g_free(registry->repository); + registry->repository = NULL; + + if(dt_conf_key_exists(CONF_AI_REPOSITORY)) + registry->repository = dt_conf_get_string(CONF_AI_REPOSITORY); + + dt_print(DT_DEBUG_AI, + "[ai_models] using repository: %s", + registry->repository ? 
registry->repository : "(none)"); + + // parse models array + if(json_object_has_member(root_obj, "models")) + { + JsonArray *models_arr = json_object_get_array_member(root_obj, "models"); + guint len = json_array_get_length(models_arr); + + for(guint i = 0; i < len; i++) + { + JsonNode *node = json_array_get_element(models_arr, i); + if(!JSON_NODE_HOLDS_OBJECT(node)) + continue; + + dt_ai_model_t *model = _parse_model_json(json_node_get_object(node)); + if(model) + { + // load enabled state from user config + char *conf_key + = g_strdup_printf("%s%s/enabled", CONF_MODEL_ENABLED_PREFIX, model->id); + if(dt_conf_key_exists(conf_key)) + model->enabled = dt_conf_get_bool(conf_key); + g_free(conf_key); + + registry->models = g_list_prepend(registry->models, model); + dt_print(DT_DEBUG_AI, + "[ai_models] loaded model: %s (%s)", + model->name, model->id); + } + } + } + + // reverse to restore original json order (we used prepend for O(1) insertion) + registry->models = g_list_reverse(registry->models); + + const int model_count = g_list_length(registry->models); + + g_mutex_unlock(®istry->lock); + + dt_print(DT_DEBUG_AI, + "[ai_models] registry loaded: %d models from %s", + model_count, registry_path); + + g_object_unref(parser); + g_free(registry_path); + + // check which models are actually downloaded + dt_ai_models_refresh_status(registry); + + return TRUE; +} + +// validate that a model_id is a plain directory name with no path separators +// or ".." components that could escape the models directory +static gboolean _valid_model_id(const char *model_id); +static dt_ai_model_t *_find_model_unlocked(dt_ai_registry_t *registry, + const char *model_id); + +// parse a local model's config.json into a dt_ai_model_t +// uses directory name as fallback for id/name. 
no github_asset or checksum +static dt_ai_model_t *_parse_local_model_config(const char *config_path, + const char *dir_name) +{ + JsonParser *parser = json_parser_new(); + GError *error = NULL; + + if(!json_parser_load_from_file(parser, config_path, &error)) + { + dt_print(DT_DEBUG_AI, "[ai_models] failed to parse %s: %s", + config_path, error ? error->message : "unknown"); + if(error) g_error_free(error); + g_object_unref(parser); + return NULL; + } + + JsonNode *root = json_parser_get_root(parser); + if(!JSON_NODE_HOLDS_OBJECT(root)) + { + g_object_unref(parser); + return NULL; + } + + JsonObject *obj = json_node_get_object(root); + + const char *id = json_object_has_member(obj, "id") + ? json_object_get_string_member(obj, "id") : dir_name; + const char *name = json_object_has_member(obj, "name") + ? json_object_get_string_member(obj, "name") : dir_name; + + if(!id || !id[0]) + { + g_object_unref(parser); + return NULL; + } + + dt_ai_model_t *model = _model_new(); + model->id = g_strdup(id); + model->name = g_strdup(name); + + if(json_object_has_member(obj, "description")) + model->description = g_strdup(json_object_get_string_member(obj, "description")); + if(json_object_has_member(obj, "task")) + model->task = g_strdup(json_object_get_string_member(obj, "task")); + + // no github_asset, no checksum — local-only model + model->enabled = TRUE; + + g_object_unref(parser); + return model; +} + +static gboolean _valid_model_id(const char *model_id) +{ + if(!model_id || !model_id[0]) + return FALSE; + if(strchr(model_id, '/') || strchr(model_id, '\\')) + return FALSE; + if(strcmp(model_id, "..") == 0 || strcmp(model_id, ".") == 0) + return FALSE; + return TRUE; +} + +void dt_ai_models_refresh_status(dt_ai_registry_t *registry) +{ + if(!registry) + return; + + g_mutex_lock(®istry->lock); + + // remove previously-discovered local models (no github_asset) + // these will be re-discovered from disk below if still present + GList *l = registry->models; + while(l) + { + 
GList *next = g_list_next(l); + dt_ai_model_t *model = (dt_ai_model_t *)l->data; + if(!model->github_asset) + { + _model_free(model); + registry->models = g_list_delete_link(registry->models, l); + } + l = next; + } + + // pass 1: update status for registry models + for(GList *l2 = registry->models; l2; l2 = g_list_next(l2)) + { + dt_ai_model_t *model = (dt_ai_model_t *)l2->data; + + // skip models with invalid ids (path traversal protection) + if(!_valid_model_id(model->id)) + continue; + + // check if model directory exists and contains required files + char *model_dir = g_build_filename(registry->models_dir, model->id, NULL); + char *config_path = g_build_filename(model_dir, "config.json", NULL); + + if(g_file_test(model_dir, G_FILE_TEST_IS_DIR) + && g_file_test(config_path, G_FILE_TEST_EXISTS)) + { + model->status = DT_AI_MODEL_DOWNLOADED; + } + else + { + model->status = DT_AI_MODEL_NOT_DOWNLOADED; + } + + g_free(config_path); + g_free(model_dir); + } + + // pass 2: discover locally-installed models not in registry + if(registry->models_dir) + { + GDir *dir = g_dir_open(registry->models_dir, 0, NULL); + if(dir) + { + const char *entry_name; + while((entry_name = g_dir_read_name(dir))) + { + if(!_valid_model_id(entry_name)) + continue; + + // skip if already in registry (e.g. 
downloaded via ai_models.json) + if(_find_model_unlocked(registry, entry_name)) + continue; + + char *model_dir = g_build_filename(registry->models_dir, entry_name, NULL); + char *config_path = g_build_filename(model_dir, "config.json", NULL); + + if(g_file_test(model_dir, G_FILE_TEST_IS_DIR) + && g_file_test(config_path, G_FILE_TEST_EXISTS)) + { + dt_ai_model_t *model = _parse_local_model_config(config_path, entry_name); + if(model) + { + model->status = DT_AI_MODEL_DOWNLOADED; + registry->models = g_list_append(registry->models, model); + dt_print(DT_DEBUG_AI, + "[ai_models] discovered local model: %s (%s)", + model->name, model->id); + } + } + + g_free(config_path); + g_free(model_dir); + } + g_dir_close(dir); + } + } + + g_mutex_unlock(®istry->lock); +} + +void dt_ai_models_cleanup(dt_ai_registry_t *registry) +{ + if(!registry) + return; + + g_mutex_lock(®istry->lock); + g_list_free_full(registry->models, (GDestroyNotify)_model_free); + registry->models = NULL; + g_mutex_unlock(®istry->lock); + + g_mutex_clear(®istry->lock); + + g_free(registry->repository); + g_free(registry->models_dir); + g_free(registry->cache_dir); + g_free(registry); +} + +// internal: find model by id without locking. 
caller must hold registry->lock +// returns direct pointer to registry-owned model (not a copy) +static dt_ai_model_t * +_find_model_unlocked(dt_ai_registry_t *registry, const char *model_id) +{ + for(GList *l = registry->models; l; l = g_list_next(l)) + { + dt_ai_model_t *model = (dt_ai_model_t *)l->data; + if(g_strcmp0(model->id, model_id) == 0) + return model; + } + return NULL; +} + +int dt_ai_models_get_count(dt_ai_registry_t *registry) +{ + if(!registry) + return 0; + g_mutex_lock(®istry->lock); + int count = g_list_length(registry->models); + g_mutex_unlock(®istry->lock); + return count; +} + +dt_ai_model_t *dt_ai_models_get_by_index(dt_ai_registry_t *registry, int index) +{ + if(!registry || index < 0) + return NULL; + g_mutex_lock(®istry->lock); + dt_ai_model_t *model = g_list_nth_data(registry->models, index); + dt_ai_model_t *copy = _model_copy(model); + g_mutex_unlock(®istry->lock); + return copy; +} + +dt_ai_model_t *dt_ai_models_get_by_id(dt_ai_registry_t *registry, const char *model_id) +{ + if(!registry || !model_id) + return NULL; + g_mutex_lock(®istry->lock); + dt_ai_model_t *model = _find_model_unlocked(registry, model_id); + dt_ai_model_t *copy = _model_copy(model); + g_mutex_unlock(®istry->lock); + return copy; +} + +#ifdef HAVE_AI_DOWNLOAD + +typedef struct dt_ai_download_data_t +{ + dt_ai_registry_t *registry; + char *model_id; // owned copy of model id (safe to use without lock) + dt_ai_progress_callback callback; + gpointer user_data; + FILE *file; + const gboolean *cancel_flag; // optional: set to non-NULL to enable cancellation +} dt_ai_download_data_t; + +static size_t _curl_write_callback(void *ptr, size_t size, size_t nmemb, void *data) +{ + dt_ai_download_data_t *dl = (dt_ai_download_data_t *)data; + return fwrite(ptr, size, nmemb, dl->file); +} + +static int _curl_progress_callback( + void *clientp, + curl_off_t dltotal, + curl_off_t dlnow, + curl_off_t ultotal, + curl_off_t ulnow) +{ + dt_ai_download_data_t *dl = 
(dt_ai_download_data_t *)clientp; + + // check for cancellation + if(dl->cancel_flag && g_atomic_int_get(dl->cancel_flag)) + return 1; // non-zero aborts the transfer + + if(dltotal > 0) + { + double progress = (double)dlnow / (double)dltotal; + + // update model progress under lock + g_mutex_lock(&dl->registry->lock); + dt_ai_model_t *m = _find_model_unlocked(dl->registry, dl->model_id); + if(m) + m->download_progress = progress; + g_mutex_unlock(&dl->registry->lock); + + if(dl->callback) + dl->callback(dl->model_id, progress, dl->user_data); + } + return 0; +} + +static gboolean _verify_checksum(const char *filepath, const char *expected) +{ + if(!expected || !g_str_has_prefix(expected, "sha256:")) + { + dt_print(DT_DEBUG_AI, "[ai_models] no valid checksum provided - rejecting download"); + return FALSE; // reject files without a valid checksum + } + + const char *expected_hash = expected + 7; // skip "sha256:" + + GChecksum *checksum = g_checksum_new(G_CHECKSUM_SHA256); + if(!checksum) + return FALSE; + + // stream file in chunks to avoid loading entire file into memory + FILE *f = g_fopen(filepath, "rb"); + if(!f) + { + dt_print(DT_DEBUG_AI, "[ai_models] failed to open file for checksum: %s", filepath); + g_checksum_free(checksum); + return FALSE; + } + + guchar buf[65536]; + size_t n; + while((n = fread(buf, 1, sizeof(buf), f)) > 0) + g_checksum_update(checksum, buf, n); + fclose(f); + + const gchar *computed = g_checksum_get_string(checksum); + gboolean match = g_ascii_strcasecmp(computed, expected_hash) == 0; + + if(!match) + { + dt_print(DT_DEBUG_AI, + "[ai_models] checksum mismatch: expected %s, got %s", + expected_hash, computed); + } + + g_checksum_free(checksum); + return match; +} +#endif //HAVE_AI_DOWNLOAD + +static gboolean _extract_zip(const char *zippath, const char *destdir) +{ + struct archive *a = archive_read_new(); + struct archive *ext = archive_write_disk_new(); + struct archive_entry *entry; + int r; + gboolean success = TRUE; + + 
archive_read_support_format_zip(a); + archive_write_disk_set_options( + ext, + ARCHIVE_EXTRACT_TIME | ARCHIVE_EXTRACT_PERM | ARCHIVE_EXTRACT_SECURE_SYMLINKS + | ARCHIVE_EXTRACT_SECURE_NODOTDOT); + + if((r = archive_read_open_filename(a, zippath, 10240)) != ARCHIVE_OK) + { + dt_print(DT_DEBUG_AI, + "[ai_models] failed to open archive: %s", + archive_error_string(a)); + archive_read_free(a); + archive_write_free(ext); + return FALSE; + } + + _ensure_directory(destdir); + + // resolve destdir to a canonical path for path traversal validation + char *real_destdir = realpath(destdir, NULL); + if(!real_destdir) + { + dt_print(DT_DEBUG_AI, "[ai_models] failed to resolve destdir: %s", destdir); + archive_read_close(a); + archive_read_free(a); + archive_write_free(ext); + return FALSE; + } + + const size_t destdir_len = strlen(real_destdir); + + while(archive_read_next_header(a, &entry) == ARCHIVE_OK) + { + const char *entry_name = archive_entry_pathname(entry); + + // reject entries with path traversal components + if(g_strstr_len(entry_name, -1, "..") != NULL) + { + dt_print(DT_DEBUG_AI, + "[ai_models] skipping suspicious archive entry: %s", + entry_name); + continue; + } + + // build full path in destination + char *full_path = g_build_filename(real_destdir, entry_name, NULL); + + // verify the resolved path is within destdir + char *real_full_path = realpath(full_path, NULL); + // for new files, realpath returns NULL; check the parent directory instead + if(!real_full_path) + { + char *parent = g_path_get_dirname(full_path); + _ensure_directory(parent); + char *real_parent = realpath(parent, NULL); + g_free(parent); + if( + !real_parent || strncmp(real_parent, real_destdir, destdir_len) != 0 + || (real_parent[destdir_len] != '/' && real_parent[destdir_len] != '\\' + && real_parent[destdir_len] != '\0')) + { + dt_print(DT_DEBUG_AI, "[ai_models] path traversal blocked: %s", entry_name); + free(real_parent); + g_free(full_path); + continue; + } + free(real_parent); + } + 
else + { + if( + strncmp(real_full_path, real_destdir, destdir_len) != 0 + || (real_full_path[destdir_len] != '/' && real_full_path[destdir_len] != '\\' + && real_full_path[destdir_len] != '\0')) + { + dt_print(DT_DEBUG_AI, "[ai_models] path traversal blocked: %s", entry_name); + free(real_full_path); + g_free(full_path); + continue; + } + free(real_full_path); + } + + archive_entry_set_pathname(entry, full_path); + + r = archive_write_header(ext, entry); + if(r == ARCHIVE_OK) + { + const void *buff; + size_t size; + la_int64_t offset; + + while(archive_read_data_block(a, &buff, &size, &offset) == ARCHIVE_OK) + { + if(archive_write_data_block(ext, buff, size, offset) != ARCHIVE_OK) + { + dt_print(DT_DEBUG_AI, "[ai_models] write error: %s", archive_error_string(ext)); + success = FALSE; + break; + } + } + if(archive_write_finish_entry(ext) != ARCHIVE_OK) + success = FALSE; + } + else + { + dt_print(DT_DEBUG_AI, + "[ai_models] write header error: %s", + archive_error_string(ext)); + success = FALSE; + } + + g_free(full_path); + + if(!success) + break; + } + + free(real_destdir); + archive_read_close(a); + archive_read_free(a); + archive_write_close(ext); + archive_write_free(ext); + + return success; +} + +// install a local .dtmodel file (zip archive) into the models directory. +// returns error message (caller must free) or NULL on success. 
+char *dt_ai_models_install_local(dt_ai_registry_t *registry, const char *filepath) +{ + if(!registry || !filepath) + return g_strdup(_("invalid parameters")); + + if(!g_file_test(filepath, G_FILE_TEST_IS_REGULAR)) + return g_strdup_printf(_("file not found: %s"), filepath); + + if(!_extract_zip(filepath, registry->models_dir)) + return g_strdup(_("failed to extract model archive")); + + // rescan models directory to pick up newly installed model + dt_ai_models_refresh_status(registry); + + dt_print(DT_DEBUG_AI, "[ai_models] model installed from: %s", filepath); + + return NULL; // success +} + +#ifdef HAVE_AI_DOWNLOAD +// synchronous download - returns error message or NULL on success +char *dt_ai_models_download_sync( + dt_ai_registry_t *registry, + const char *model_id, + dt_ai_progress_callback callback, + gpointer user_data, + const gboolean *cancel_flag) +{ + dt_print(DT_DEBUG_AI, + "[ai_models] download requested for: %s", + model_id ? model_id : "(null)"); + + if(!registry || !model_id) + return g_strdup(_("invalid parameters")); + + // lock once to validate, copy immutable fields, and set status + g_mutex_lock(®istry->lock); + dt_ai_model_t *model = _find_model_unlocked(registry, model_id); + if(!model) + { + g_mutex_unlock(®istry->lock); + return g_strdup(_("model not found in registry")); + } + + if(!model->github_asset) + { + g_mutex_unlock(®istry->lock); + return g_strdup(_("model has no download asset defined")); + } + + // validate asset filename: reject path separators and query strings + if( + strchr(model->github_asset, '/') || strchr(model->github_asset, '\\') + || strchr(model->github_asset, '?') || strchr(model->github_asset, '#') + || strstr(model->github_asset, "..")) + { + g_mutex_unlock(®istry->lock); + return g_strdup(_("invalid asset filename")); + } + + if(model->status == DT_AI_MODEL_DOWNLOADING) + { + g_mutex_unlock(®istry->lock); + return g_strdup(_("model is already downloading")); + } + model->status = DT_AI_MODEL_DOWNLOADING; + 
model->download_progress = 0.0; + + // copy fields we need outside the lock (repository can be replaced by reload) + char *asset = g_strdup(model->github_asset); + char *checksum_copy = g_strdup(model->checksum); + char *repository = g_strdup(registry->repository); + g_mutex_unlock(®istry->lock); + +// helper macro: set model status under lock and return error +// uses _find_model_unlocked to avoid keeping a stale pointer +#define SET_STATUS_AND_RETURN(status_val, err_expr) \ + do \ + { \ + g_mutex_lock(®istry->lock); \ + dt_ai_model_t *_m = _find_model_unlocked(registry, model_id); \ + if(_m) \ + _m->status = (status_val); \ + g_mutex_unlock(®istry->lock); \ + g_free(asset); \ + g_free(checksum_copy); \ + g_free(repository); \ + return (err_expr); \ + } while(0) + + // validate repository format (must be "owner/repo" with safe characters) + if( + !repository + || !g_regex_match_simple("^[a-zA-Z0-9._-]+/[a-zA-Z0-9._-]+$", repository, 0, 0)) + { + SET_STATUS_AND_RETURN(DT_AI_MODEL_ERROR, g_strdup(_("invalid repository format"))); + } + + { + char *ver = _get_darktable_version_prefix(); + dt_print(DT_DEBUG_AI, "[ai_models] repository: %s", repository); + dt_print(DT_DEBUG_AI, + "[ai_models] darktable version: %s (full: %s)", + ver ? ver : "unknown", darktable_package_version); + g_free(ver); + } + + // find the latest compatible release for this darktable version + char *release_error = NULL; + char *release_tag = _find_latest_compatible_release(repository, &release_error); + if(!release_tag) + { + if(release_error) + { + SET_STATUS_AND_RETURN(DT_AI_MODEL_ERROR, release_error); + } + else + { + char *dt_ver = _get_darktable_version_prefix(); + char *err = g_strdup_printf( + _("no compatible ai model release found for darktable %s"), + dt_ver ? 
dt_ver : darktable_package_version); + g_free(dt_ver); + SET_STATUS_AND_RETURN(DT_AI_MODEL_ERROR, err); + } + } + + // fetch sha256 digest from github releases api if not already known + if(!checksum_copy || !g_str_has_prefix(checksum_copy, "sha256:")) + { + g_free(checksum_copy); + checksum_copy = _fetch_asset_digest(repository, release_tag, asset); + if(!checksum_copy) + { + g_free(release_tag); + SET_STATUS_AND_RETURN( + DT_AI_MODEL_ERROR, + g_strdup_printf(_("could not obtain checksum for %s — " + "refusing to download without integrity verification"), + asset)); + } + } + + // build github download url using local copies (not model pointer) + char *url = g_strdup_printf( + "https://github.com/%s/releases/download/%s/%s", + repository, + release_tag, + asset); + g_free(release_tag); + + if(!url) + { + SET_STATUS_AND_RETURN(DT_AI_MODEL_ERROR, g_strdup(_("failed to build download url"))); + } + + dt_print(DT_DEBUG_AI, "[ai_models] downloading: %s", url); + + // prepare download path using local copy + char *download_path = g_build_filename(registry->cache_dir, asset, NULL); + + FILE *file = g_fopen(download_path, "wb"); + if(!file) + { + char *err = g_strdup_printf(_("failed to create file: %s"), download_path); + g_free(download_path); + g_free(url); + SET_STATUS_AND_RETURN(DT_AI_MODEL_ERROR, err); + } + + // set up download data (uses model_id copy, not model pointer) + dt_ai_download_data_t dl = { + .registry = registry, + .model_id = (char *)model_id, + .callback = callback, + .user_data = user_data, + .file = file, + .cancel_flag = cancel_flag}; + + // initialize curl + CURL *curl = curl_easy_init(); + if(!curl) + { + fclose(file); + g_free(download_path); + g_free(url); + SET_STATUS_AND_RETURN( + DT_AI_MODEL_ERROR, + g_strdup(_("failed to initialize download"))); + } + dt_curl_init(curl, FALSE); + + curl_easy_setopt(curl, CURLOPT_URL, url); + curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, _curl_write_callback); + curl_easy_setopt(curl, CURLOPT_WRITEDATA, 
&dl); + curl_easy_setopt(curl, CURLOPT_XFERINFOFUNCTION, _curl_progress_callback); + curl_easy_setopt(curl, CURLOPT_XFERINFODATA, &dl); + curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 0L); + + CURLcode res = curl_easy_perform(curl); + + fclose(file); + + char *error = NULL; + + if(res != CURLE_OK) + { + if(res == CURLE_ABORTED_BY_CALLBACK) + error = g_strdup(_("download cancelled")); + else + error = g_strdup_printf(_("download failed: %s"), curl_easy_strerror(res)); + g_unlink(download_path); + } + else + { + // check http response code + long http_code = 0; + curl_easy_getinfo(curl, CURLINFO_RESPONSE_CODE, &http_code); + + if(http_code != 200) + { + error = g_strdup_printf(_("http error: %ld"), http_code); + g_unlink(download_path); + } + } + + curl_easy_cleanup(curl); + g_free(url); + + if(error) + { + g_free(download_path); + SET_STATUS_AND_RETURN(DT_AI_MODEL_ERROR, error); + } + + // verify checksum if available (fetched from github api or stored in registry) + if(checksum_copy && g_str_has_prefix(checksum_copy, "sha256:")) + { + if(!_verify_checksum(download_path, checksum_copy)) + { + g_unlink(download_path); + g_free(download_path); + SET_STATUS_AND_RETURN( + DT_AI_MODEL_ERROR, + g_strdup(_("checksum verification failed"))); + } + } + else + { + // should not reach here — checksum is now required before download + g_unlink(download_path); + g_free(download_path); + SET_STATUS_AND_RETURN( + DT_AI_MODEL_ERROR, + g_strdup_printf(_("no checksum available for %s — " + "refusing to install without integrity verification"), + asset)); + } + + // extract to models directory (zip already contains model_id folder) + if(!_extract_zip(download_path, registry->models_dir)) + { + g_unlink(download_path); + g_free(download_path); + SET_STATUS_AND_RETURN(DT_AI_MODEL_ERROR, g_strdup(_("failed to extract archive"))); + } + + // clean up downloaded zip + g_unlink(download_path); + g_free(download_path); + + // mark success + g_mutex_lock(®istry->lock); + dt_ai_model_t *m = 
_find_model_unlocked(registry, model_id); + if(m) + { + m->status = DT_AI_MODEL_DOWNLOADED; + m->download_progress = 1.0; + } + g_mutex_unlock(®istry->lock); + + dt_print(DT_DEBUG_AI, "[ai_models] download complete: %s", model_id); + + // final callback + if(callback) + callback(model_id, 1.0, user_data); + + g_free(asset); + g_free(checksum_copy); + g_free(repository); + +#undef SET_STATUS_AND_RETURN + + return NULL; // success +} + +// wrapper that returns boolean for compatibility +gboolean dt_ai_models_download( + dt_ai_registry_t *registry, + const char *model_id, + dt_ai_progress_callback callback, + gpointer user_data) +{ + char *error = dt_ai_models_download_sync(registry, model_id, callback, user_data, NULL); + if(error) + { + dt_print(DT_DEBUG_AI, "[ai_models] download error: %s", error); + g_free(error); + return FALSE; + } + return TRUE; +} + +gboolean dt_ai_models_download_default( + dt_ai_registry_t *registry, + dt_ai_progress_callback callback, + gpointer user_data) +{ + if(!registry) + return FALSE; + + // collect ids while holding lock, then download without lock + GList *ids = NULL; + g_mutex_lock(®istry->lock); + for(GList *l = registry->models; l; l = g_list_next(l)) + { + dt_ai_model_t *model = (dt_ai_model_t *)l->data; + if(model->is_default && model->status == DT_AI_MODEL_NOT_DOWNLOADED) + ids = g_list_prepend(ids, g_strdup(model->id)); + } + g_mutex_unlock(®istry->lock); + + gboolean any_started = FALSE; + for(GList *l = ids; l; l = g_list_next(l)) + { + if(dt_ai_models_download(registry, (const char *)l->data, callback, user_data)) + any_started = TRUE; + } + g_list_free_full(ids, g_free); + return any_started; +} + +gboolean dt_ai_models_download_all( + dt_ai_registry_t *registry, + dt_ai_progress_callback callback, + gpointer user_data) +{ + if(!registry) + return FALSE; + + // collect ids while holding lock, then download without lock + GList *ids = NULL; + g_mutex_lock(®istry->lock); + for(GList *l = registry->models; l; l = 
g_list_next(l)) + { + dt_ai_model_t *model = (dt_ai_model_t *)l->data; + if(model->status == DT_AI_MODEL_NOT_DOWNLOADED) + ids = g_list_prepend(ids, g_strdup(model->id)); + } + g_mutex_unlock(®istry->lock); + + gboolean any_started = FALSE; + for(GList *l = ids; l; l = g_list_next(l)) + { + if(dt_ai_models_download(registry, (const char *)l->data, callback, user_data)) + any_started = TRUE; + } + g_list_free_full(ids, g_free); + return any_started; +} +#endif // HAVE_AI_DOWNLOAD + +static gboolean _rmdir_recursive(const char *path) +{ + if(!g_file_test(path, G_FILE_TEST_IS_DIR)) + { + g_unlink(path); + return TRUE; + } + + GDir *dir = g_dir_open(path, 0, NULL); + if(!dir) + return FALSE; + + const gchar *name; + while((name = g_dir_read_name(dir))) + { + char *child = g_build_filename(path, name, NULL); + if(g_file_test(child, G_FILE_TEST_IS_SYMLINK)) + g_unlink(child); // remove the symlink itself, never follow + else if(g_file_test(child, G_FILE_TEST_IS_DIR)) + _rmdir_recursive(child); + else + g_unlink(child); + g_free(child); + } + g_dir_close(dir); + return g_rmdir(path) == 0; +} + +gboolean dt_ai_models_delete(dt_ai_registry_t *registry, const char *model_id) +{ + if(!registry || !_valid_model_id(model_id)) + return FALSE; + + // check model exists + g_mutex_lock(®istry->lock); + dt_ai_model_t *model = _find_model_unlocked(registry, model_id); + if(!model) + { + g_mutex_unlock(®istry->lock); + return FALSE; + } + g_mutex_unlock(®istry->lock); + + char *model_dir = g_build_filename(registry->models_dir, model_id, NULL); + _rmdir_recursive(model_dir); + g_free(model_dir); + + char *task_copy = NULL; + g_mutex_lock(®istry->lock); + model = _find_model_unlocked(registry, model_id); + if(model) + { + model->status = DT_AI_MODEL_NOT_DOWNLOADED; + model->download_progress = 0.0; + if(model->task) + task_copy = g_strdup(model->task); + } + g_mutex_unlock(®istry->lock); + + // clear active status if this was the active model for its task + if(task_copy) + { + char 
*active = dt_ai_models_get_active_for_task(task_copy); + if(active && strcmp(active, model_id) == 0) + dt_ai_models_set_active_for_task(task_copy, NULL); + g_free(active); + g_free(task_copy); + } + + return TRUE; +} + +// configuration + +void dt_ai_models_set_enabled( + dt_ai_registry_t *registry, + const char *model_id, + gboolean enabled) +{ + if(!registry || !model_id) + return; + + g_mutex_lock(®istry->lock); + dt_ai_model_t *model = _find_model_unlocked(registry, model_id); + if(model) + model->enabled = enabled; + g_mutex_unlock(®istry->lock); + + if(!model) + return; + + // persist to config + char *conf_key = g_strdup_printf("%s%s/enabled", CONF_MODEL_ENABLED_PREFIX, model_id); + dt_conf_set_bool(conf_key, enabled); + g_free(conf_key); +} + +char *dt_ai_models_get_active_for_task(const char *task) +{ + if(!task || !task[0]) + return NULL; + + // 1. check central config key + char *conf_key = g_strdup_printf("%s%s", CONF_ACTIVE_MODEL_PREFIX, task); + if(dt_conf_key_exists(conf_key)) + { + char *model_id = dt_conf_get_string(conf_key); + g_free(conf_key); + if(model_id && model_id[0]) + return model_id; + g_free(model_id); + return NULL; + } + g_free(conf_key); + + // 2. 
fall back to the default downloaded model for this task + if(darktable.ai_registry) + { + char *result = NULL; + g_mutex_lock(&darktable.ai_registry->lock); + for(GList *l = darktable.ai_registry->models; l; l = g_list_next(l)) + { + dt_ai_model_t *m = (dt_ai_model_t *)l->data; + if(m->task && strcmp(m->task, task) == 0 + && m->is_default && m->status == DT_AI_MODEL_DOWNLOADED) + { + result = g_strdup(m->id); + break; + } + } + g_mutex_unlock(&darktable.ai_registry->lock); + + if(result) + { + dt_ai_models_set_active_for_task(task, result); + return result; + } + } + + return NULL; +} + +void dt_ai_models_set_active_for_task(const char *task, const char *model_id) +{ + if(!task || !task[0]) + return; + + char *conf_key = g_strdup_printf("%s%s", CONF_ACTIVE_MODEL_PREFIX, task); + dt_conf_set_string(conf_key, model_id ? model_id : ""); + g_free(conf_key); +} + +char *dt_ai_models_get_path(dt_ai_registry_t *registry, const char *model_id) +{ + if(!registry || !_valid_model_id(model_id)) + return NULL; + + g_mutex_lock(®istry->lock); + dt_ai_model_t *model = _find_model_unlocked(registry, model_id); + gboolean downloaded = model && model->status == DT_AI_MODEL_DOWNLOADED; + g_mutex_unlock(®istry->lock); + + if(!downloaded) + return NULL; + + return g_build_filename(registry->models_dir, model_id, NULL); +} + +// clang-format off +// modelines: These editor modelines have been set for all relevant files by tools/update_modelines.py +// vim: shiftwidth=2 expandtab tabstop=2 cindent +// kate: tab-indents: off; indent-width 2; replace-tabs on; indent-mode cstyle; remove-trailing-spaces modified; +// clang-format on diff --git a/src/common/ai_models.h b/src/common/ai_models.h new file mode 100644 index 000000000000..7f282d6fdf97 --- /dev/null +++ b/src/common/ai_models.h @@ -0,0 +1,253 @@ +/* + This file is part of darktable, + Copyright (C) 2026 darktable developers. 
+ + darktable is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + darktable is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with darktable. If not, see . +*/ + +#pragma once + +#include +#include "ai/backend.h" + +// Ensure PATH_MAX is defined on all platforms +#ifndef PATH_MAX +#ifdef _WIN32 +#define PATH_MAX _MAX_PATH +#else +#define PATH_MAX 4096 +#endif +#endif + +/** + * @brief Model download/availability status + */ +typedef enum dt_ai_model_status_t { + DT_AI_MODEL_NOT_DOWNLOADED = 0, + DT_AI_MODEL_DOWNLOADING, + DT_AI_MODEL_DOWNLOADED, + DT_AI_MODEL_ERROR, +} dt_ai_model_status_t; + +/** + * @brief Information about a single AI model + */ +typedef struct dt_ai_model_t { + char *id; // Unique identifier (e.g. "nafnet-sidd-width32") + char *name; // Display name + char *description; // Short description + char *task; // Task type: "denoise", "upscale", etc. 
+ char *github_asset; // Asset filename in GitHub release + char *checksum; // SHA256 checksum (format: "sha256:...") + gboolean is_default; // TRUE if model is a default model for its task + gboolean enabled; // User preference (stored in darktablerc) + dt_ai_model_status_t status; + double download_progress; // 0.0 to 1.0 during download +} dt_ai_model_t; + +/** + * @brief Progress callback for download operations + * @param model_id The model being downloaded + * @param progress Download progress 0.0 to 1.0 + * @param user_data User-provided data + */ +typedef void (*dt_ai_progress_callback)(const char *model_id, double progress, + gpointer user_data); + +/** + * @brief AI Models Registry + * Central registry for managing AI models + */ +typedef struct dt_ai_registry_t { + GList *models; // List of dt_ai_model_t* + char *repository; // GitHub repository (e.g. "darktable-org/darktable-ai") + char *models_dir; // Path to user's models directory + char *cache_dir; // Path to download cache directory + gboolean ai_enabled; // Global AI enable/disable + dt_ai_provider_t provider; // Selected execution provider + GMutex lock; // Thread safety for registry access +} dt_ai_registry_t; + +// --- Core API --- + +/** + * @brief Initialize the AI models registry + * @return New registry instance, or NULL on error + */ +dt_ai_registry_t *dt_ai_models_init(void); + +/** + * @brief Load model registry from JSON file + * @param registry The registry to populate + * @return TRUE on success + */ +gboolean dt_ai_models_load_registry(dt_ai_registry_t *registry); + +/** + * @brief Scan models directory and update download status + * @param registry The registry to update + */ +void dt_ai_models_refresh_status(dt_ai_registry_t *registry); + +/** + * @brief Clean up and free the registry + * @param registry The registry to destroy + */ +void dt_ai_models_cleanup(dt_ai_registry_t *registry); + +// --- Model Access --- +// All get functions return a COPY of the model. 
Caller must free with dt_ai_model_free(). +// This ensures thread safety: no lock is exposed to callers, preventing deadlocks. + +/** + * @brief Free a model copy returned by get_by_index/get_by_id + * @param model The model to free (may be NULL) + */ +void dt_ai_model_free(dt_ai_model_t *model); + +/** + * @brief Get number of models in registry + * @param registry The registry + * @return Number of models + */ +int dt_ai_models_get_count(dt_ai_registry_t *registry); + +/** + * @brief Get model by index (returns a copy, caller must free with dt_ai_model_free) + * @param registry The registry + * @param index Index 0 to count-1 + * @return Model copy (caller owns), or NULL + */ +dt_ai_model_t *dt_ai_models_get_by_index(dt_ai_registry_t *registry, int index); + +/** + * @brief Get model by ID (returns a copy, caller must free with dt_ai_model_free) + * @param registry The registry + * @param model_id The unique model ID + * @return Model copy (caller owns), or NULL + */ +dt_ai_model_t *dt_ai_models_get_by_id(dt_ai_registry_t *registry, + const char *model_id); + +// --- Install Operations --- + +/** + * @brief Install a model from a local .dtmodel file + * @param registry The registry + * @param filepath Path to the .dtmodel file (zip archive) + * @return Error message (caller must free) or NULL on success + */ +char *dt_ai_models_install_local(dt_ai_registry_t *registry, const char *filepath); + +#ifdef HAVE_AI_DOWNLOAD +// --- Download Operations --- + +/** + * @brief Download a specific model synchronously + * @param registry The registry + * @param model_id The model to download + * @param callback Progress callback (may be NULL) + * @param user_data Data for callback + * @param cancel_flag Pointer to boolean checked for cancellation (may be NULL) + * @return Error message (caller must free) or NULL on success + */ +char *dt_ai_models_download_sync(dt_ai_registry_t *registry, const char *model_id, + dt_ai_progress_callback callback, + gpointer user_data, + const 
gboolean *cancel_flag); + +/** + * @brief Download a specific model (convenience wrapper) + * @param registry The registry + * @param model_id The model to download + * @param callback Progress callback (may be NULL) + * @param user_data Data for callback + * @return TRUE on success + */ +gboolean dt_ai_models_download(dt_ai_registry_t *registry, const char *model_id, + dt_ai_progress_callback callback, + gpointer user_data); + +/** + * @brief Download all default models (runs in background) + * @param registry The registry + * @param callback Progress callback (may be NULL) + * @param user_data Data for callback + * @return TRUE if downloads started successfully + */ +gboolean dt_ai_models_download_default(dt_ai_registry_t *registry, + dt_ai_progress_callback callback, + gpointer user_data); + +/** + * @brief Download all models (runs in background) + * @param registry The registry + * @param callback Progress callback (may be NULL) + * @param user_data Data for callback + * @return TRUE if downloads started successfully + */ +gboolean dt_ai_models_download_all(dt_ai_registry_t *registry, + dt_ai_progress_callback callback, + gpointer user_data); +#endif /* HAVE_AI_DOWNLOAD */ + +/** + * @brief Delete a downloaded model + * @param registry The registry + * @param model_id The model to delete + * @return TRUE on success + */ +gboolean dt_ai_models_delete(dt_ai_registry_t *registry, const char *model_id); + +// --- Configuration --- + +/** + * @brief Set model enabled state (persisted to config) + * @param registry The registry + * @param model_id The model ID + * @param enabled Whether the model should be enabled + */ +void dt_ai_models_set_enabled(dt_ai_registry_t *registry, const char *model_id, + gboolean enabled); + +/** + * @brief Get the active model ID for a task. + * + * Looks up the centralized config key `plugins/ai/models/active/{task}`. + * If not set, falls back to legacy consumer config keys, then to the + * default downloaded model for the task. 
+ * + * @param task The task type (e.g. "mask", "denoise") + * @return Newly allocated model ID string (caller must free), or NULL if none active + */ +char *dt_ai_models_get_active_for_task(const char *task); + +/** + * @brief Set the active model for a task (exclusive — clears previous). + * + * Persists to `plugins/ai/models/active/{task}` in darktablerc. + * Pass model_id=NULL to clear (disable the task). + * + * @param task The task type (e.g. "mask", "denoise") + * @param model_id The model ID to activate, or NULL to clear + */ +void dt_ai_models_set_active_for_task(const char *task, const char *model_id); + +/** + * @brief Get the path to a downloaded model's directory + * @param registry The registry + * @param model_id The model ID + * @return Path string (caller must free), or NULL if not downloaded + */ +char *dt_ai_models_get_path(dt_ai_registry_t *registry, const char *model_id); diff --git a/src/common/darktable.c b/src/common/darktable.c index d1fae9e8e29b..15280c79ebec 100644 --- a/src/common/darktable.c +++ b/src/common/darktable.c @@ -53,6 +53,9 @@ #include "common/undo.h" #include "common/gimp.h" #include "common/pfm.h" +#ifdef HAVE_AI +#include "common/ai_models.h" +#endif #include "control/conf.h" #include "control/control.h" #include "control/crawler.h" @@ -1118,6 +1121,7 @@ int dt_init(int argc, !strcmp(darg, "pipe") ? DT_DEBUG_PIPE : !strcmp(darg, "expose") ? DT_DEBUG_EXPOSE : !strcmp(darg, "picker") ? DT_DEBUG_PICKER : + !strcmp(darg, "ai") ? DT_DEBUG_AI : // AI related stuff. 
0; if(dadd) darktable.unmuted |= dadd; @@ -1629,6 +1633,13 @@ int dt_init(int argc, // get the list of color profiles darktable.color_profiles = dt_colorspaces_init(); +#ifdef HAVE_AI + // initialize AI models registry + darktable.ai_registry = dt_ai_models_init(); + if(darktable.ai_registry) + dt_ai_models_load_registry(darktable.ai_registry); +#endif + // initialize datetime data dt_datetime_init(); @@ -2173,6 +2184,10 @@ void dt_cleanup() dt_mipmap_cache_cleanup(); dt_colorspaces_cleanup(darktable.color_profiles); +#ifdef HAVE_AI + dt_ai_models_cleanup(darktable.ai_registry); + darktable.ai_registry = NULL; +#endif dt_conf_cleanup(darktable.conf); free(darktable.conf); darktable.conf = NULL; diff --git a/src/common/darktable.h b/src/common/darktable.h index 35ee37b9291b..a7914e78fcbe 100644 --- a/src/common/darktable.h +++ b/src/common/darktable.h @@ -281,6 +281,9 @@ struct dt_bauhaus_t; struct dt_undo_t; struct dt_colorspaces_t; struct dt_l10n_t; +#ifdef HAVE_AI +struct dt_ai_registry_t; +#endif typedef float dt_boundingbox_t[4]; //(x,y) of upperleft, then (x,y) of lowerright typedef float dt_pickerbox_t[8]; @@ -318,6 +321,7 @@ typedef enum dt_debug_thread_t DT_DEBUG_PIPE = 1 << 25, DT_DEBUG_EXPOSE = 1 << 26, DT_DEBUG_PICKER = 1 << 27, + DT_DEBUG_AI = 1 << 28, DT_DEBUG_ALL = 0xffffffff & ~DT_DEBUG_VERBOSE, DT_DEBUG_COMMON = DT_DEBUG_OPENCL | DT_DEBUG_DEV | DT_DEBUG_MASKS | DT_DEBUG_PARAMS | DT_DEBUG_IMAGEIO | DT_DEBUG_PIPE, DT_DEBUG_RESTRICT = DT_DEBUG_VERBOSE | DT_DEBUG_PERF, @@ -430,6 +434,9 @@ typedef struct darktable_t struct dt_backthumb_t backthumbs; struct dt_gimp_t gimp; struct dt_splash_t splash; +#ifdef HAVE_AI + struct dt_ai_registry_t *ai_registry; +#endif } darktable_t; typedef struct diff --git a/src/control/signal.c b/src/control/signal.c index 9d4e465cbb49..43068d806101 100644 --- a/src/control/signal.c +++ b/src/control/signal.c @@ -144,6 +144,10 @@ static dt_signal_description _signal_description[DT_SIGNAL_COUNT] = { [DT_SIGNAL_PRESET_APPLIED] 
= { "dt-preset-applied", NULL, NULL, G_TYPE_NONE, g_cclosure_marshal_VOID__VOID, 0, NULL, NULL, FALSE }, + /* AI models related signals */ + [DT_SIGNAL_AI_MODELS_CHANGED] = { "dt-ai-models-changed", + NULL, NULL, G_TYPE_NONE, g_cclosure_marshal_VOID__VOID, 0, NULL, NULL, FALSE }, + /* Develop related signals */ [DT_SIGNAL_DEVELOP_INITIALIZE] = { "dt-develop-initialized", NULL, NULL, G_TYPE_NONE, g_cclosure_marshal_VOID__VOID, 0, NULL, NULL, FALSE }, diff --git a/src/control/signal.h b/src/control/signal.h index e8330d5d40ae..8fc84b8005b8 100644 --- a/src/control/signal.h +++ b/src/control/signal.h @@ -285,6 +285,11 @@ typedef enum dt_signal_t /* \brief This signal is raised after a preset has been applied */ DT_SIGNAL_PRESET_APPLIED, + /** \brief This signal is raised when AI models have been downloaded/changed + no param, no returned value + */ + DT_SIGNAL_AI_MODELS_CHANGED, + /* do not touch !*/ DT_SIGNAL_COUNT } dt_signal_t; diff --git a/src/gui/preferences.c b/src/gui/preferences.c index 9047840cc4af..a1f1e12a1a0c 100644 --- a/src/gui/preferences.c +++ b/src/gui/preferences.c @@ -31,6 +31,9 @@ #include "gui/draw.h" #include "gui/gtk.h" #include "gui/preferences.h" +#ifdef HAVE_AI +#include "gui/preferences_ai.h" +#endif #include "gui/presets.h" #include "libs/lib.h" #include "preferences_gen.h" @@ -605,6 +608,9 @@ void dt_gui_preferences_show() init_tab_generated(_preferences_dialog, stack); init_tab_accels(stack); init_tab_presets(stack); +#ifdef HAVE_AI + init_tab_ai(_preferences_dialog, stack); +#endif #ifdef USE_LUA GtkGrid* lua_grid = init_tab_lua(_preferences_dialog, stack); #endif diff --git a/src/gui/preferences_ai.c b/src/gui/preferences_ai.c new file mode 100644 index 000000000000..437aebd5c3c4 --- /dev/null +++ b/src/gui/preferences_ai.c @@ -0,0 +1,1161 @@ +/* + This file is part of darktable, + Copyright (C) 2026 darktable developers. 
+
+  darktable is free software: you can redistribute it and/or modify
+  it under the terms of the GNU General Public License as published by
+  the Free Software Foundation, either version 3 of the License, or
+  (at your option) any later version.
+
+  darktable is distributed in the hope that it will be useful,
+  but WITHOUT ANY WARRANTY; without even the implied warranty of
+  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+  GNU General Public License for more details.
+
+  You should have received a copy of the GNU General Public License
+  along with darktable. If not, see <https://www.gnu.org/licenses/>.
+*/
+
+#include "gui/preferences_ai.h"
+#include "bauhaus/bauhaus.h"
+#include "common/ai_models.h"
+#include "common/darktable.h"
+#include "control/conf.h"
+#include "control/signal.h"
+#include "gui/gtk.h"
+
+#include <gtk/gtk.h>
+
+// non-default indicator
+#define NON_DEF_CHAR "\xe2\x97\x8f"
+
+// update the non-default indicator dot for a boolean preference
+static void _update_bool_indicator(GtkWidget *indicator, const char *confkey)
+{
+  const gboolean current = dt_conf_get_bool(confkey);
+  const gboolean def = dt_confgen_get_bool(confkey, DT_DEFAULT);
+  if(current == def)
+  {
+    gtk_label_set_text(GTK_LABEL(indicator), "");
+    gtk_widget_set_tooltip_text(indicator, NULL);
+  }
+  else
+  {
+    gtk_label_set_text(GTK_LABEL(indicator), NON_DEF_CHAR);
+    gtk_widget_set_tooltip_text(indicator, _("this setting has been modified"));
+  }
+}
+
+// update the non-default indicator dot for a string preference
+static void _update_string_indicator(GtkWidget *indicator, const char *confkey)
+{
+  const gboolean is_default = dt_conf_is_default(confkey);
+  if(is_default)
+  {
+    gtk_label_set_text(GTK_LABEL(indicator), "");
+    gtk_widget_set_tooltip_text(indicator, NULL);
+  }
+  else
+  {
+    gtk_label_set_text(GTK_LABEL(indicator), NON_DEF_CHAR);
+    gtk_widget_set_tooltip_text(indicator, _("this setting has been modified"));
+  }
+}
+
+// create the indicator label widget
+static GtkWidget *_create_indicator(const char *confkey)
+{
+  const gboolean is_default = dt_conf_is_default(confkey);
+  GtkWidget *label;
+  if(is_default)
+    label = gtk_label_new("");
+  else
+  {
+    label = gtk_label_new(NON_DEF_CHAR);
+    gtk_widget_set_tooltip_text(label, _("this setting has been modified"));
+  }
+  gtk_widget_set_name(label, "preference_non_default");
+  return label;
+}
+
+// column indices for model list store
+enum
+{
+  COL_SELECTED,
+  COL_NAME,
+  COL_TASK,
+  COL_DESCRIPTION,
+  COL_ENABLED,
+  COL_ENABLED_SENSITIVE, // whether the enabled checkbox is clickable
+  COL_STATUS,
+  COL_DEFAULT,
+  COL_ID,
+  NUM_COLS
+};
+
+typedef struct dt_prefs_ai_data_t
+{
+  GtkWidget *enable_toggle;
+  GtkWidget *provider_combo;
+  GtkWidget *provider_indicator;
+  GtkWidget *provider_status;
+  GtkWidget *model_list;
+  GtkListStore *model_store;
+#ifdef HAVE_AI_DOWNLOAD
+  GtkWidget *download_selected_btn;
+  GtkWidget *download_default_btn;
+  GtkWidget *download_all_btn;
+#endif
+  GtkWidget *install_btn;
+  GtkWidget *delete_selected_btn;
+  GtkWidget *parent_dialog;
+  GtkWidget *select_all_toggle;
+} dt_prefs_ai_data_t;
+
+#ifdef HAVE_AI_DOWNLOAD
+// download dialog data
+typedef struct dt_download_dialog_t
+{
+  GtkWidget *dialog;
+  GtkWidget *progress_bar;
+  GtkWidget *status_label;
+  dt_prefs_ai_data_t *prefs_data;
+  char *model_id;
+  char *error;
+  double progress;
+  gboolean finished;
+  gboolean cancelled;
+  GMutex mutex;
+} dt_download_dialog_t;
+#endif
+
+// sort by task, then default (yes before no), then name
+static gint _model_sort_func(GtkTreeModel *model,
+                             GtkTreeIter *a,
+                             GtkTreeIter *b,
+                             gpointer user_data)
+{
+  gchar *task_a, *task_b, *default_a, *default_b, *name_a, *name_b;
+  gtk_tree_model_get(model, a, COL_TASK, &task_a, COL_DEFAULT, &default_a,
+                     COL_NAME, &name_a, -1);
+  gtk_tree_model_get(model, b, COL_TASK, &task_b, COL_DEFAULT, &default_b,
+                     COL_NAME, &name_b, -1);
+
+  int cmp = g_strcmp0(task_a, task_b);
+  if(cmp == 0)
+  {
+    // "yes" sorts before "no" (reverse alphabetical)
+    cmp =
g_strcmp0(default_b, default_a); + if(cmp == 0) + cmp = g_strcmp0(name_a, name_b); + } + + g_free(task_a); g_free(task_b); + g_free(default_a); g_free(default_b); + g_free(name_a); g_free(name_b); + return cmp; +} + +static const char *_status_to_string(dt_ai_model_status_t status) +{ + switch(status) + { + case DT_AI_MODEL_DOWNLOADED: + return _("downloaded"); + case DT_AI_MODEL_DOWNLOADING: + return _("downloading..."); + case DT_AI_MODEL_ERROR: + return _("error"); + default: + return _("not downloaded"); + } +} + +static void _refresh_model_list(dt_prefs_ai_data_t *data) +{ + if(!darktable.ai_registry) + { + dt_print(DT_DEBUG_AI, "[preferences_ai] registry is NULL"); + return; + } + + gtk_list_store_clear(data->model_store); + + dt_ai_models_refresh_status(darktable.ai_registry); + + const int count = dt_ai_models_get_count(darktable.ai_registry); + dt_print(DT_DEBUG_AI, "[preferences_ai] refreshing model list, count=%d", count); + for(int i = 0; i < count; i++) + { + dt_ai_model_t *model = dt_ai_models_get_by_index(darktable.ai_registry, i); + if(!model) + { + dt_print(DT_DEBUG_AI, "[preferences_ai] model at index %d is NULL", i); + continue; + } + dt_print( + DT_DEBUG_AI, + "[preferences_ai] adding model: %s", + model->id ? model->id : "(null)"); + + // check if this model is the active one for its task + const gboolean is_downloaded = (model->status == DT_AI_MODEL_DOWNLOADED); + gboolean is_active = FALSE; + if(model->task && model->task[0]) + { + char *active_id = dt_ai_models_get_active_for_task(model->task); + is_active = (active_id && g_strcmp0(active_id, model->id) == 0); + g_free(active_id); + } + + GtkTreeIter iter; + gtk_list_store_append(data->model_store, &iter); + gtk_list_store_set( + data->model_store, + &iter, + COL_SELECTED, + FALSE, + COL_ENABLED, + is_active, + COL_ENABLED_SENSITIVE, + is_downloaded, + COL_NAME, + model->name ? model->name : model->id, + COL_TASK, + model->task ? model->task : "", + COL_DESCRIPTION, + model->description ? 
model->description : "", + COL_STATUS, + _status_to_string(model->status), + COL_DEFAULT, + model->is_default ? _("yes") : _("no"), + COL_ID, + model->id, + -1); + dt_ai_model_free(model); + } + + // reset select-all toggle + if(data->select_all_toggle) + gtk_toggle_button_set_active(GTK_TOGGLE_BUTTON(data->select_all_toggle), FALSE); +} + +static void _on_enable_toggled(GtkWidget *widget, gpointer user_data) +{ + GtkWidget *indicator = GTK_WIDGET(user_data); + const gboolean enabled = gtk_toggle_button_get_active(GTK_TOGGLE_BUTTON(widget)); + dt_conf_set_bool("plugins/ai/enabled", enabled); + if(darktable.ai_registry) + { + g_mutex_lock(&darktable.ai_registry->lock); + darktable.ai_registry->ai_enabled = enabled; + g_mutex_unlock(&darktable.ai_registry->lock); + } + _update_bool_indicator(indicator, "plugins/ai/enabled"); +} + +// map combo box index to provider table index (skipping unavailable providers) +static int _combo_idx_to_provider(int combo_idx) +{ + int visible = -1; + for(int i = 0; i < DT_AI_PROVIDER_COUNT; i++) + { + if(!dt_ai_providers[i].available) continue; + if(++visible == combo_idx) + return i; + } + return 0; // fallback to AUTO +} + +// map provider enum value to combo box index +static int _provider_to_combo_idx(dt_ai_provider_t provider) +{ + int visible = -1; + for(int i = 0; i < DT_AI_PROVIDER_COUNT; i++) + { + if(!dt_ai_providers[i].available) continue; + visible++; + if(dt_ai_providers[i].value == provider) + return visible; + } + return 0; // fallback to first visible (AUTO) +} + +static void _update_provider_status(dt_prefs_ai_data_t *data, dt_ai_provider_t provider) +{ + if(!data->provider_status) return; + + if(provider == DT_AI_PROVIDER_AUTO || provider == DT_AI_PROVIDER_CPU + || dt_ai_probe_provider(provider)) + { + gtk_label_set_text(GTK_LABEL(data->provider_status), ""); + return; + } + + gtk_label_set_markup(GTK_LABEL(data->provider_status), + _("not available, will fall back to CPU")); +} + +static void 
_on_provider_changed(GtkWidget *widget, gpointer user_data) +{ + dt_prefs_ai_data_t *data = (dt_prefs_ai_data_t *)user_data; + const int combo_idx = dt_bauhaus_combobox_get(widget); + const int pi = _combo_idx_to_provider(combo_idx); + dt_conf_set_string(DT_AI_CONF_PROVIDER, dt_ai_providers[pi].config_string); + if(darktable.ai_registry) + { + g_mutex_lock(&darktable.ai_registry->lock); + darktable.ai_registry->provider = dt_ai_providers[pi].value; + g_mutex_unlock(&darktable.ai_registry->lock); + } + _update_string_indicator(data->provider_indicator, DT_AI_CONF_PROVIDER); + _update_provider_status(data, dt_ai_providers[pi].value); +} + +// double-click on label resets the enable toggle to default +static gboolean +_reset_enable_click(GtkWidget *label, GdkEventButton *event, GtkWidget *widget) +{ + if(event->type == GDK_2BUTTON_PRESS) + { + const gboolean def = dt_confgen_get_bool("plugins/ai/enabled", DT_DEFAULT); + gtk_toggle_button_set_active(GTK_TOGGLE_BUTTON(widget), def); + return TRUE; + } + return FALSE; +} + +// double-click on label resets the provider combo to default +static gboolean +_reset_provider_click(GtkWidget *label, GdkEventButton *event, GtkWidget *widget) +{ + if(event->type == GDK_2BUTTON_PRESS) + { + const char *def = dt_confgen_get(DT_AI_CONF_PROVIDER, DT_DEFAULT); + dt_ai_provider_t provider = dt_ai_provider_from_string(def); + dt_bauhaus_combobox_set(widget, _provider_to_combo_idx(provider)); + return TRUE; + } + return FALSE; +} + +static void _on_model_selection_toggled(GtkCellRendererToggle *cell, + gchar *path_string, + gpointer user_data) +{ + dt_prefs_ai_data_t *data = (dt_prefs_ai_data_t *)user_data; + + GtkTreeIter iter; + GtkTreePath *path = gtk_tree_path_new_from_string(path_string); + gtk_tree_model_get_iter(GTK_TREE_MODEL(data->model_store), &iter, path); + gtk_tree_path_free(path); + + gboolean selected; + gtk_tree_model_get( + GTK_TREE_MODEL(data->model_store), + &iter, + COL_SELECTED, + &selected, + -1); + + // toggle the 
value
+  gtk_list_store_set(data->model_store, &iter, COL_SELECTED, !selected, -1);
+}
+
+static void _on_enabled_toggled(GtkCellRendererToggle *cell,
+                                gchar *path_string,
+                                gpointer user_data)
+{
+  dt_prefs_ai_data_t *data = (dt_prefs_ai_data_t *)user_data;
+
+  GtkTreeIter iter;
+  GtkTreePath *path = gtk_tree_path_new_from_string(path_string);
+  gtk_tree_model_get_iter(GTK_TREE_MODEL(data->model_store), &iter, path);
+  gtk_tree_path_free(path);
+
+  gboolean currently_enabled;
+  gchar *model_id = NULL;
+  gchar *task = NULL;
+  gtk_tree_model_get(
+    GTK_TREE_MODEL(data->model_store),
+    &iter,
+    COL_ENABLED, &currently_enabled,
+    COL_ID, &model_id,
+    COL_TASK, &task,
+    -1);
+
+  if(!task || !task[0] || !model_id)
+  {
+    g_free(model_id);
+    g_free(task);
+    return;
+  }
+
+  if(currently_enabled)
+    dt_ai_models_set_active_for_task(task, NULL);
+  else
+    dt_ai_models_set_active_for_task(task, model_id);
+
+  g_free(model_id);
+  g_free(task);
+
+  // refresh to update all checkboxes (previous active model unchecks)
+  _refresh_model_list(data);
+
+  DT_CONTROL_SIGNAL_RAISE(DT_SIGNAL_AI_MODELS_CHANGED);
+}
+
+static void _on_select_all_toggled(GtkToggleButton *toggle, gpointer user_data)
+{
+  dt_prefs_ai_data_t *data = (dt_prefs_ai_data_t *)user_data;
+  const gboolean select_all = gtk_toggle_button_get_active(toggle);
+
+  GtkTreeIter iter;
+  gboolean valid
+    = gtk_tree_model_get_iter_first(GTK_TREE_MODEL(data->model_store), &iter);
+  while(valid)
+  {
+    gtk_list_store_set(data->model_store, &iter, COL_SELECTED, select_all, -1);
+    valid = gtk_tree_model_iter_next(GTK_TREE_MODEL(data->model_store), &iter);
+  }
+}
+
+static void _on_select_all_header_clicked(GtkWidget *button, gpointer user_data)
+{
+  dt_prefs_ai_data_t *data = (dt_prefs_ai_data_t *)user_data;
+  // only toggle if the click wasn't already handled by the checkbox itself.
+ // block the toggled signal to prevent double-fire, then toggle manually + g_signal_handlers_block_by_func(data->select_all_toggle, _on_select_all_toggled, data); + gboolean active + = gtk_toggle_button_get_active(GTK_TOGGLE_BUTTON(data->select_all_toggle)); + gtk_toggle_button_set_active(GTK_TOGGLE_BUTTON(data->select_all_toggle), !active); + g_signal_handlers_unblock_by_func( + data->select_all_toggle, + _on_select_all_toggled, + data); + + // now manually apply the selection since we blocked the signal + _on_select_all_toggled(GTK_TOGGLE_BUTTON(data->select_all_toggle), data); +} + +// collect selected model IDs from the list store +static GList *_get_selected_model_ids(dt_prefs_ai_data_t *data) +{ + GList *ids = NULL; + GtkTreeIter iter; + gboolean valid + = gtk_tree_model_get_iter_first(GTK_TREE_MODEL(data->model_store), &iter); + while(valid) + { + gboolean selected; + gchar *model_id; + gtk_tree_model_get( + GTK_TREE_MODEL(data->model_store), + &iter, + COL_SELECTED, + &selected, + COL_ID, + &model_id, + -1); + if(selected && model_id) + ids = g_list_append(ids, model_id); + else + g_free(model_id); + valid = gtk_tree_model_iter_next(GTK_TREE_MODEL(data->model_store), &iter); + } + return ids; +} + +#ifdef HAVE_AI_DOWNLOAD +// progress callback called from download thread +static void +_download_progress_callback(const char *model_id, double progress, gpointer user_data) +{ + dt_download_dialog_t *dl = (dt_download_dialog_t *)user_data; + g_mutex_lock(&dl->mutex); + dl->progress = progress; + g_mutex_unlock(&dl->mutex); +} + +// idle callback to update progress bar from main thread +static gboolean _update_progress_idle(gpointer user_data) +{ + dt_download_dialog_t *dl = (dt_download_dialog_t *)user_data; + + g_mutex_lock(&dl->mutex); + double progress = dl->progress; + gboolean finished = dl->finished; + char *error = dl->error ? 
g_strdup(dl->error) : NULL; + g_mutex_unlock(&dl->mutex); + + if(dl->dialog && GTK_IS_WIDGET(dl->dialog)) + { + gtk_progress_bar_set_fraction(GTK_PROGRESS_BAR(dl->progress_bar), progress); + + char *text = g_strdup_printf("%.0f%%", progress * 100.0); + gtk_progress_bar_set_text(GTK_PROGRESS_BAR(dl->progress_bar), text); + g_free(text); + } + + if(finished) + { + if(error) + { + // show error in dialog + if(dl->dialog && GTK_IS_WIDGET(dl->dialog)) + { + gtk_label_set_text(GTK_LABEL(dl->status_label), error); + gtk_widget_show(dl->status_label); + } + g_free(error); + } + else + { + // success - close dialog + if(dl->dialog && GTK_IS_WIDGET(dl->dialog)) + gtk_dialog_response(GTK_DIALOG(dl->dialog), GTK_RESPONSE_OK); + } + return G_SOURCE_REMOVE; + } + + g_free(error); + return G_SOURCE_CONTINUE; +} + +// download thread function +static gpointer _download_thread_func(gpointer user_data) +{ + dt_download_dialog_t *dl = (dt_download_dialog_t *)user_data; + + char *error = dt_ai_models_download_sync( + darktable.ai_registry, + dl->model_id, + _download_progress_callback, + dl, + &dl->cancelled); + + g_mutex_lock(&dl->mutex); + dl->error = error; + dl->finished = TRUE; + g_mutex_unlock(&dl->mutex); + + return NULL; +} + +// show modal download dialog for a single model +static gboolean +_download_model_with_dialog(dt_prefs_ai_data_t *data, const char *model_id) +{ + dt_ai_model_t *model = dt_ai_models_get_by_id(darktable.ai_registry, model_id); + if(!model) + return FALSE; + + // create dialog + GtkWidget *dialog = gtk_dialog_new_with_buttons( + _("downloading AI model"), + GTK_WINDOW(data->parent_dialog), + GTK_DIALOG_MODAL | GTK_DIALOG_DESTROY_WITH_PARENT, + _("_cancel"), + GTK_RESPONSE_CANCEL, + NULL); + + gtk_window_set_default_size(GTK_WINDOW(dialog), 400, -1); + + GtkWidget *content = gtk_dialog_get_content_area(GTK_DIALOG(dialog)); + gtk_container_set_border_width(GTK_CONTAINER(content), 10); + gtk_box_set_spacing(GTK_BOX(content), 10); + + // model name label 
(use fields from copy, then free it) + char *title + = g_strdup_printf(_("downloading: %s"), model->name ? model->name : model->id); + dt_ai_model_free(model); + GtkWidget *title_label = gtk_label_new(title); + g_free(title); + dt_gui_box_add(content, title_label); + + // progress bar + GtkWidget *progress_bar = gtk_progress_bar_new(); + gtk_progress_bar_set_show_text(GTK_PROGRESS_BAR(progress_bar), TRUE); + gtk_progress_bar_set_text(GTK_PROGRESS_BAR(progress_bar), "0%"); + dt_gui_box_add(content, progress_bar); + + // status label (for errors) + GtkWidget *status_label = gtk_label_new(""); + gtk_widget_set_no_show_all(status_label, TRUE); + dt_gui_box_add(content, status_label); + + gtk_widget_show_all(dialog); + + // set up download data (heap-allocated for thread safety) + dt_download_dialog_t *dl = g_new0(dt_download_dialog_t, 1); + g_mutex_init(&dl->mutex); + dl->dialog = dialog; + dl->progress_bar = progress_bar; + dl->status_label = status_label; + dl->prefs_data = data; + dl->model_id = g_strdup(model_id); + dl->progress = 0.0; + dl->finished = FALSE; + dl->cancelled = FALSE; + dl->error = NULL; + + // start download thread + GThread *thread = g_thread_new("ai-download", _download_thread_func, dl); + + // start progress update timer + guint timer_id = g_timeout_add(100, _update_progress_idle, dl); + + // run dialog + gint response = gtk_dialog_run(GTK_DIALOG(dialog)); + + // handle cancellation (atomic so curl progress callback can read it safely) + if(response == GTK_RESPONSE_CANCEL) + g_atomic_int_set(&dl->cancelled, TRUE); + + // wait for thread to finish — after this, dl->finished is TRUE + g_thread_join(thread); + + // remove the timer. Any already-dispatched idle callback will see + // dl->finished == TRUE and return G_SOURCE_REMOVE harmlessly. + g_source_remove(timer_id); + + // destroy the dialog before freeing dl so no straggling callback + // can touch destroyed widgets. 
+ gtk_widget_destroy(dialog); + dl->dialog = NULL; + + gboolean success = (dl->error == NULL); + + // notify modules that models have changed + if(success) + DT_CONTROL_SIGNAL_RAISE(DT_SIGNAL_AI_MODELS_CHANGED); + + // clean up + g_mutex_clear(&dl->mutex); + g_free(dl->model_id); + g_free(dl->error); + g_free(dl); + + return success; +} + +static void _on_download_selected(GtkButton *button, gpointer user_data) +{ + dt_prefs_ai_data_t *data = (dt_prefs_ai_data_t *)user_data; + + GList *ids = _get_selected_model_ids(data); + for(GList *l = ids; l; l = g_list_next(l)) + { + const char *id = (const char *)l->data; + dt_ai_model_t *model = dt_ai_models_get_by_id(darktable.ai_registry, id); + if(model) + { + gboolean need_download = (model->status == DT_AI_MODEL_NOT_DOWNLOADED); + dt_ai_model_free(model); + if(need_download && !_download_model_with_dialog(data, id)) + break; // stop on error or cancel + } + } + g_list_free_full(ids, g_free); + _refresh_model_list(data); +} + +static void _on_download_default(GtkButton *button, gpointer user_data) +{ + dt_prefs_ai_data_t *data = (dt_prefs_ai_data_t *)user_data; + + // download default models that need downloading + const int count = dt_ai_models_get_count(darktable.ai_registry); + for(int i = 0; i < count; i++) + { + dt_ai_model_t *model = dt_ai_models_get_by_index(darktable.ai_registry, i); + if(!model) + continue; + gboolean need_download + = (model->is_default && model->status == DT_AI_MODEL_NOT_DOWNLOADED); + char *id = need_download ? 
g_strdup(model->id) : NULL; + dt_ai_model_free(model); + if(need_download) + { + if(!_download_model_with_dialog(data, id)) + { + g_free(id); + break; // stop on error or cancel + } + g_free(id); + } + } + _refresh_model_list(data); +} + +static void _on_download_all(GtkButton *button, gpointer user_data) +{ + dt_prefs_ai_data_t *data = (dt_prefs_ai_data_t *)user_data; + + // download all models that need downloading + const int count = dt_ai_models_get_count(darktable.ai_registry); + for(int i = 0; i < count; i++) + { + dt_ai_model_t *model = dt_ai_models_get_by_index(darktable.ai_registry, i); + if(!model) + continue; + gboolean need_download = (model->status == DT_AI_MODEL_NOT_DOWNLOADED); + char *id = need_download ? g_strdup(model->id) : NULL; + dt_ai_model_free(model); + if(need_download) + { + if(!_download_model_with_dialog(data, id)) + { + g_free(id); + break; // stop on error or cancel + } + g_free(id); + } + } + _refresh_model_list(data); +} +#endif // HAVE_AI_DOWNLOAD + +static void _on_install_model(GtkButton *button, gpointer user_data) +{ + dt_prefs_ai_data_t *data = (dt_prefs_ai_data_t *)user_data; + + GtkWidget *dialog = gtk_file_chooser_dialog_new( + _("install AI model"), + GTK_WINDOW(data->parent_dialog), + GTK_FILE_CHOOSER_ACTION_OPEN, + _("_cancel"), GTK_RESPONSE_CANCEL, + _("_open"), GTK_RESPONSE_ACCEPT, + NULL); + + GtkFileFilter *filter = gtk_file_filter_new(); + gtk_file_filter_set_name(filter, _("AI model packages (*.dtmodel)")); + gtk_file_filter_add_pattern(filter, "*.dtmodel"); + gtk_file_chooser_add_filter(GTK_FILE_CHOOSER(dialog), filter); + + if(gtk_dialog_run(GTK_DIALOG(dialog)) == GTK_RESPONSE_ACCEPT) + { + char *filepath = gtk_file_chooser_get_filename(GTK_FILE_CHOOSER(dialog)); + gtk_widget_destroy(dialog); + + char *error = dt_ai_models_install_local(darktable.ai_registry, filepath); + if(error) + { + GtkWidget *err_dialog = gtk_message_dialog_new( + GTK_WINDOW(data->parent_dialog), + GTK_DIALOG_MODAL, + GTK_MESSAGE_ERROR, + 
GTK_BUTTONS_OK, + "%s", error); + gtk_dialog_run(GTK_DIALOG(err_dialog)); + gtk_widget_destroy(err_dialog); + g_free(error); + } + else + { + DT_CONTROL_SIGNAL_RAISE(DT_SIGNAL_AI_MODELS_CHANGED); + _refresh_model_list(data); + } + g_free(filepath); + } + else + { + gtk_widget_destroy(dialog); + } +} + +static void _on_delete_selected(GtkButton *button, gpointer user_data) +{ + dt_prefs_ai_data_t *data = (dt_prefs_ai_data_t *)user_data; + + // collect selected models that are downloaded + GList *ids = _get_selected_model_ids(data); + GList *to_delete = NULL; + int delete_count = 0; + + for(GList *l = ids; l; l = g_list_next(l)) + { + const char *id = (const char *)l->data; + dt_ai_model_t *model = dt_ai_models_get_by_id(darktable.ai_registry, id); + if(model) + { + if(model->status == DT_AI_MODEL_DOWNLOADED) + { + to_delete = g_list_append(to_delete, g_strdup(id)); + delete_count++; + } + dt_ai_model_free(model); + } + } + g_list_free_full(ids, g_free); + + if(delete_count == 0) + { + g_list_free_full(to_delete, g_free); + return; + } + + // confirm deletion + GtkWidget *confirm = gtk_message_dialog_new( + GTK_WINDOW(data->parent_dialog), + GTK_DIALOG_MODAL | GTK_DIALOG_DESTROY_WITH_PARENT, + GTK_MESSAGE_QUESTION, + GTK_BUTTONS_YES_NO, + ngettext("delete %d selected model?", "delete %d selected models?", delete_count), + delete_count); + + gint response = gtk_dialog_run(GTK_DIALOG(confirm)); + gtk_widget_destroy(confirm); + + if(response == GTK_RESPONSE_YES) + { + gboolean any_deleted = FALSE; + for(GList *l = to_delete; l; l = g_list_next(l)) + { + const char *model_id = (const char *)l->data; + if(dt_ai_models_delete(darktable.ai_registry, model_id)) + { + dt_print(DT_DEBUG_AI, "[preferences_ai] deleted model: %s", model_id); + any_deleted = TRUE; + } + } + + if(any_deleted) + DT_CONTROL_SIGNAL_RAISE(DT_SIGNAL_AI_MODELS_CHANGED); + + _refresh_model_list(data); + } + + g_list_free_full(to_delete, g_free); +} + +static void _on_refresh(GtkButton *button, gpointer 
user_data) +{ + dt_prefs_ai_data_t *data = (dt_prefs_ai_data_t *)user_data; + _refresh_model_list(data); +} + +void init_tab_ai(GtkWidget *dialog, GtkWidget *stack) +{ + dt_prefs_ai_data_t *data = g_new0(dt_prefs_ai_data_t, 1); + data->parent_dialog = dialog; + + // main vertical box holds two independent sections + GtkWidget *main_box = dt_gui_vbox(); + + // === "general" section with its own grid === + GtkWidget *general_grid = gtk_grid_new(); + gtk_grid_set_row_spacing(GTK_GRID(general_grid), DT_PIXEL_APPLY_DPI(3)); + gtk_grid_set_column_spacing(GTK_GRID(general_grid), DT_PIXEL_APPLY_DPI(5)); + + int row = 0; + + // "general" section header + { + GtkWidget *seclabel = gtk_label_new(_("general")); + GtkWidget *lbox = dt_gui_hbox(seclabel); + gtk_widget_set_name(lbox, "pref_section"); + gtk_grid_attach(GTK_GRID(general_grid), lbox, 0, row++, 3, 1); + } + + // enable AI toggle + GtkWidget *enable_label = gtk_label_new(_("enable AI features")); + gtk_widget_set_halign(enable_label, GTK_ALIGN_START); + GtkWidget *enable_labelev = gtk_event_box_new(); + gtk_widget_add_events(enable_labelev, GDK_BUTTON_PRESS_MASK); + gtk_container_add(GTK_CONTAINER(enable_labelev), enable_label); + gtk_event_box_set_visible_window(GTK_EVENT_BOX(enable_labelev), FALSE); + + GtkWidget *enable_indicator = _create_indicator("plugins/ai/enabled"); + data->enable_toggle = gtk_check_button_new(); + gtk_toggle_button_set_active( + GTK_TOGGLE_BUTTON(data->enable_toggle), + dt_conf_get_bool("plugins/ai/enabled")); + g_signal_connect( + data->enable_toggle, + "toggled", + G_CALLBACK(_on_enable_toggled), + enable_indicator); + g_signal_connect( + enable_labelev, + "button-press-event", + G_CALLBACK(_reset_enable_click), + data->enable_toggle); + gtk_grid_attach(GTK_GRID(general_grid), enable_labelev, 0, row, 1, 1); + gtk_grid_attach(GTK_GRID(general_grid), enable_indicator, 1, row, 1, 1); + gtk_grid_attach(GTK_GRID(general_grid), data->enable_toggle, 2, row++, 1, 1); + + // provider dropdown + 
GtkWidget *provider_label = gtk_label_new(_("execution provider")); + gtk_widget_set_halign(provider_label, GTK_ALIGN_START); + GtkWidget *provider_labelev = gtk_event_box_new(); + gtk_widget_add_events(provider_labelev, GDK_BUTTON_PRESS_MASK); + gtk_container_add(GTK_CONTAINER(provider_labelev), provider_label); + gtk_event_box_set_visible_window(GTK_EVENT_BOX(provider_labelev), FALSE); + + data->provider_indicator = _create_indicator(DT_AI_CONF_PROVIDER); + data->provider_combo = dt_bauhaus_combobox_new(NULL); + + // populate from central provider table, skipping unavailable providers + GString *tooltip = g_string_new(_("select hardware acceleration for AI inference:")); + for(int i = 0; i < DT_AI_PROVIDER_COUNT; i++) + { + if(!dt_ai_providers[i].available) continue; + if(dt_ai_providers[i].value == DT_AI_PROVIDER_AUTO) + dt_bauhaus_combobox_add(data->provider_combo, _("auto")); + else + dt_bauhaus_combobox_add(data->provider_combo, dt_ai_providers[i].display_name); + g_string_append_printf(tooltip, "\n- %s", dt_ai_providers[i].display_name); + } + + char *provider_str = dt_conf_get_string(DT_AI_CONF_PROVIDER); + dt_ai_provider_t provider = dt_ai_provider_from_string(provider_str); + g_free(provider_str); + dt_bauhaus_combobox_set(data->provider_combo, _provider_to_combo_idx(provider)); + + g_signal_connect( + data->provider_combo, + "value-changed", + G_CALLBACK(_on_provider_changed), + data); + g_signal_connect( + provider_labelev, + "button-press-event", + G_CALLBACK(_reset_provider_click), + data->provider_combo); + gtk_widget_set_tooltip_text(data->provider_combo, tooltip->str); + g_string_free(tooltip, TRUE); + data->provider_status = gtk_label_new(NULL); + gtk_label_set_use_markup(GTK_LABEL(data->provider_status), TRUE); + gtk_widget_set_halign(data->provider_status, GTK_ALIGN_START); + + gtk_grid_attach(GTK_GRID(general_grid), provider_labelev, 0, row, 1, 1); + gtk_grid_attach(GTK_GRID(general_grid), data->provider_indicator, 1, row, 1, 1); + 
gtk_grid_attach(GTK_GRID(general_grid), data->provider_combo, 2, row, 1, 1); + gtk_grid_attach(GTK_GRID(general_grid), data->provider_status, 3, row++, 1, 1); + + dt_gui_box_add(main_box, general_grid); + + // "models" section with its own grid + GtkWidget *models_grid = gtk_grid_new(); + gtk_grid_set_row_spacing(GTK_GRID(models_grid), DT_PIXEL_APPLY_DPI(3)); + gtk_grid_set_column_spacing(GTK_GRID(models_grid), DT_PIXEL_APPLY_DPI(5)); + + row = 0; + + // "models" section header + { + GtkWidget *seclabel = gtk_label_new(_("models")); + GtkWidget *lbox = dt_gui_hbox(seclabel); + gtk_widget_set_name(lbox, "pref_section"); + gtk_grid_attach(GTK_GRID(models_grid), lbox, 0, row++, 1, 1); + } + + // create model list store + data->model_store = gtk_list_store_new( + NUM_COLS, + G_TYPE_BOOLEAN, // selected + G_TYPE_STRING, // name + G_TYPE_STRING, // task + G_TYPE_STRING, // description + G_TYPE_BOOLEAN, // enabled + G_TYPE_BOOLEAN, // enabled_sensitive + G_TYPE_STRING, // status + G_TYPE_STRING, // default + G_TYPE_STRING); // id + + // sort by task, then default, then name + gtk_tree_sortable_set_default_sort_func( + GTK_TREE_SORTABLE(data->model_store), _model_sort_func, NULL, NULL); + gtk_tree_sortable_set_sort_column_id( + GTK_TREE_SORTABLE(data->model_store), + GTK_TREE_SORTABLE_DEFAULT_SORT_COLUMN_ID, GTK_SORT_ASCENDING); + + // create tree view + data->model_list = gtk_tree_view_new_with_model(GTK_TREE_MODEL(data->model_store)); + g_object_unref(data->model_store); // Tree view takes ownership + + // selection checkbox column (no title, with select-all checkbox in header) + GtkCellRenderer *toggle_renderer = gtk_cell_renderer_toggle_new(); + g_signal_connect( + toggle_renderer, + "toggled", + G_CALLBACK(_on_model_selection_toggled), + data); + GtkTreeViewColumn *select_col = gtk_tree_view_column_new_with_attributes( + "", + toggle_renderer, + "active", + COL_SELECTED, + NULL); + + // add select-all checkbox as column header widget + data->select_all_toggle = 
gtk_check_button_new(); + gtk_widget_set_tooltip_text(data->select_all_toggle, _("select/deselect all")); + g_signal_connect( + data->select_all_toggle, + "toggled", + G_CALLBACK(_on_select_all_toggled), + data); + gtk_widget_show(data->select_all_toggle); + gtk_tree_view_column_set_widget(select_col, data->select_all_toggle); + gtk_tree_view_column_set_clickable(select_col, TRUE); + + gtk_tree_view_append_column(GTK_TREE_VIEW(data->model_list), select_col); + + // connect to the header button's clicked signal so the checkbox toggles + // when clicking anywhere in the header area + GtkWidget *select_col_button = gtk_tree_view_column_get_button(select_col); + g_signal_connect( + select_col_button, + "clicked", + G_CALLBACK(_on_select_all_header_clicked), + data); + + // name column + GtkCellRenderer *text_renderer = gtk_cell_renderer_text_new(); + GtkTreeViewColumn *name_col = gtk_tree_view_column_new_with_attributes( + _("name"), + text_renderer, + "text", + COL_NAME, + NULL); + gtk_tree_view_column_set_expand(name_col, FALSE); + gtk_tree_view_append_column(GTK_TREE_VIEW(data->model_list), name_col); + + // task column + GtkTreeViewColumn *task_col = gtk_tree_view_column_new_with_attributes( + _("task"), + text_renderer, + "text", + COL_TASK, + NULL); + gtk_tree_view_append_column(GTK_TREE_VIEW(data->model_list), task_col); + + // description column + GtkTreeViewColumn *desc_col = gtk_tree_view_column_new_with_attributes( + _("description"), + text_renderer, + "text", + COL_DESCRIPTION, + NULL); + gtk_tree_view_column_set_expand(desc_col, TRUE); + gtk_tree_view_append_column(GTK_TREE_VIEW(data->model_list), desc_col); + + // enabled checkbox column (radio-button behavior per task) + GtkCellRenderer *enabled_renderer = gtk_cell_renderer_toggle_new(); + g_signal_connect( + enabled_renderer, + "toggled", + G_CALLBACK(_on_enabled_toggled), + data); + GtkTreeViewColumn *enabled_col = gtk_tree_view_column_new_with_attributes( + _("enabled"), + enabled_renderer, + 
"active", COL_ENABLED, + "sensitive", COL_ENABLED_SENSITIVE, + "activatable", COL_ENABLED_SENSITIVE, + NULL); + gtk_tree_view_append_column(GTK_TREE_VIEW(data->model_list), enabled_col); + + // status column + GtkTreeViewColumn *status_col = gtk_tree_view_column_new_with_attributes( + _("status"), + text_renderer, + "text", + COL_STATUS, + NULL); + gtk_tree_view_append_column(GTK_TREE_VIEW(data->model_list), status_col); + + // default column + GtkTreeViewColumn *default_col = gtk_tree_view_column_new_with_attributes( + _("default"), + text_renderer, + "text", + COL_DEFAULT, + NULL); + gtk_tree_view_append_column(GTK_TREE_VIEW(data->model_list), default_col); + + // scrolled window for the list + GtkWidget *scroll = gtk_scrolled_window_new(NULL, NULL); + gtk_scrolled_window_set_policy( + GTK_SCROLLED_WINDOW(scroll), + GTK_POLICY_AUTOMATIC, + GTK_POLICY_AUTOMATIC); + gtk_scrolled_window_set_min_content_height( + GTK_SCROLLED_WINDOW(scroll), + DT_PIXEL_APPLY_DPI(200)); + gtk_widget_set_hexpand(scroll, TRUE); + gtk_widget_set_vexpand(scroll, TRUE); + gtk_container_add(GTK_CONTAINER(scroll), data->model_list); + gtk_grid_attach(GTK_GRID(models_grid), scroll, 0, row++, 1, 1); + + // button box + GtkWidget *button_box = dt_gui_hbox(); + gtk_grid_attach(GTK_GRID(models_grid), button_box, 0, row++, 1, 1); + +#ifdef HAVE_AI_DOWNLOAD + // download selected button + data->download_selected_btn = gtk_button_new_with_label(_("download selected")); + g_signal_connect( + data->download_selected_btn, + "clicked", + G_CALLBACK(_on_download_selected), + data); + dt_gui_box_add(button_box, data->download_selected_btn); + + // download default button + data->download_default_btn = gtk_button_new_with_label(_("download default")); + g_signal_connect( + data->download_default_btn, + "clicked", + G_CALLBACK(_on_download_default), + data); + dt_gui_box_add(button_box, data->download_default_btn); + + // download all button + data->download_all_btn = gtk_button_new_with_label(_("download 
all")); + g_signal_connect(data->download_all_btn, "clicked", G_CALLBACK(_on_download_all), data); + dt_gui_box_add(button_box, data->download_all_btn); +#endif // HAVE_AI_DOWNLOAD + + // install model button + data->install_btn = gtk_button_new_with_label(_("install model")); + g_signal_connect(data->install_btn, "clicked", G_CALLBACK(_on_install_model), data); + dt_gui_box_add(button_box, data->install_btn); + + // delete selected button + data->delete_selected_btn = gtk_button_new_with_label(_("delete selected")); + g_signal_connect( + data->delete_selected_btn, + "clicked", + G_CALLBACK(_on_delete_selected), + data); + dt_gui_box_add(button_box, data->delete_selected_btn); + + // refresh button + GtkWidget *refresh_btn = gtk_button_new_with_label(_("refresh")); + g_signal_connect(refresh_btn, "clicked", G_CALLBACK(_on_refresh), data); + dt_gui_box_add(button_box, refresh_btn); + + dt_gui_box_add(main_box, models_grid); + + // wrap in a scrolled container like other tabs + GtkWidget *main_scroll = dt_gui_scroll_wrap(main_box); + GtkWidget *tab_box = dt_gui_vbox(main_scroll); + + // add to stack + gtk_stack_add_titled(GTK_STACK(stack), tab_box, "AI", _("AI")); + + // populate model list + _refresh_model_list(data); + + // store data pointer for cleanup (attach to container) + g_object_set_data_full(G_OBJECT(tab_box), "prefs-ai-data", data, g_free); +} + +// clang-format off +// modelines: These editor modelines have been set for all relevant files by tools/update_modelines.py +// vim: shiftwidth=2 expandtab tabstop=2 cindent +// kate: tab-indents: off; indent-width 2; replace-tabs on; indent-mode cstyle; remove-trailing-spaces modified; +// clang-format on diff --git a/src/gui/preferences_ai.h b/src/gui/preferences_ai.h new file mode 100644 index 000000000000..415629fa7014 --- /dev/null +++ b/src/gui/preferences_ai.h @@ -0,0 +1,28 @@ +/* + This file is part of darktable, + Copyright (C) 2026 darktable developers. 
+ + darktable is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + darktable is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with darktable. If not, see <https://www.gnu.org/licenses/>. +*/ + +#pragma once + +#include <gtk/gtk.h> + +/** + * @brief Initialize the AI preferences tab + * @param dialog The preferences dialog + * @param stack The GtkStack to add the tab to + */ +void init_tab_ai(GtkWidget *dialog, GtkWidget *stack); diff --git a/src/tests/unittests/CMakeLists.txt b/src/tests/unittests/CMakeLists.txt index affb343924cc..278d5669e0d5 100644 --- a/src/tests/unittests/CMakeLists.txt +++ b/src/tests/unittests/CMakeLists.txt @@ -1,5 +1,9 @@ add_subdirectory(iop) +if(BUILD_AI) + add_subdirectory(ai) +endif(BUILD_AI) + add_cmocka_test(test_sample SOURCES test_sample.c LINK_LIBRARIES cmocka) diff --git a/src/tests/unittests/ai/CMakeLists.txt b/src/tests/unittests/ai/CMakeLists.txt new file mode 100644 index 000000000000..b1e55f8b7e2c --- /dev/null +++ b/src/tests/unittests/ai/CMakeLists.txt @@ -0,0 +1,6 @@ +add_cmocka_test(test_ai_backend + SOURCES test_ai_backend.c dt_stubs.c + LINK_LIBRARIES darktable_ai cmocka::cmocka) + +target_compile_definitions(test_ai_backend PRIVATE + TEST_MODEL_DIR="${CMAKE_CURRENT_SOURCE_DIR}/models") diff --git a/src/tests/unittests/ai/dt_stubs.c b/src/tests/unittests/ai/dt_stubs.c new file mode 100644 index 000000000000..4266c2466057 --- /dev/null +++ b/src/tests/unittests/ai/dt_stubs.c @@ -0,0 +1,45 @@ +/* + Minimal stubs for darktable symbols required by libdarktable_ai. 
+ Provides just enough for the AI backend to link and run without + the full darktable runtime. +*/ + +#include <stdarg.h> +#include <stddef.h> +#include <stdint.h> +#include <stdio.h> +#include <glib.h> + +/* darktable global — the AI backend accesses darktable.unmuted + via the dt_debug_if macro. Set all debug bits so debug output is enabled. */ +char darktable[8192] __attribute__((aligned(16))); + +/* Enable all debug output: set unmuted to 0x7FFFFFFF. + darktable_t layout: dt_codepath_t (4 bytes), int32_t num_openmp_threads (4 bytes), + int32_t unmuted (offset 8). */ +__attribute__((constructor)) +static void _init_darktable_stub(void) +{ + /* Set unmuted field at offset 8 to 0x7FFFFFFF (every debug flag on) */ + int32_t *unmuted = (int32_t *)(darktable + 8); + *unmuted = 0x7FFFFFFF; +} + +void dt_print_ext(const char *msg, ...) +{ + va_list ap; + va_start(ap, msg); + vfprintf(stderr, msg, ap); + va_end(ap); + fputc('\n', stderr); +} + +gchar *dt_conf_get_string(const char *name) +{ + return g_strdup("cpu"); +} + +void dt_loc_get_user_config_dir(char *configdir, size_t bufsize) +{ + if(configdir && bufsize > 0) configdir[0] = '\0'; +} diff --git a/src/tests/unittests/ai/generate_test_model.py b/src/tests/unittests/ai/generate_test_model.py new file mode 100644 index 000000000000..68176801ddc5 --- /dev/null +++ b/src/tests/unittests/ai/generate_test_model.py @@ -0,0 +1,59 @@ +#!/usr/bin/env python3 +"""Generate a tiny ONNX test model for AI backend integration tests. 
+ +Model: y = x * 2 (element-wise multiply) +Input: 'x' float32 [1, 3, 4, 4] +Output: 'y' float32 [1, 3, 4, 4] + +Usage: python3 generate_test_model.py [output_dir] + Default output_dir: ./models +""" + +import json +import os +import sys + +import numpy as np +import onnx +from onnx import TensorProto, helper, numpy_helper + + +def main(): + out_dir = os.path.join(sys.argv[1] if len(sys.argv) > 1 else "models", + "test-multiply") + os.makedirs(out_dir, exist_ok=True) + + # Create model: y = x * 2 + X = helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 4, 4]) + Y = helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 4, 4]) + const_2 = numpy_helper.from_array( + np.array([2.0], dtype=np.float32), name="const_2" + ) + mul_node = helper.make_node("Mul", ["x", "const_2"], ["y"]) + graph = helper.make_graph([mul_node], "test_multiply", [X], [Y], [const_2]) + model = helper.make_model( + graph, opset_imports=[helper.make_opsetid("", 13)] + ) + # Force IR version 8 for compatibility with ONNX Runtime 1.x + model.ir_version = 8 + onnx.checker.check_model(model) + onnx.save(model, os.path.join(out_dir, "model.onnx")) + + # Create config.json + config = { + "id": "test-multiply", + "name": "Test Multiply", + "description": "Test model: y = x * 2", + "task": "test", + "backend": "onnx", + "num_inputs": 1, + } + with open(os.path.join(out_dir, "config.json"), "w") as f: + json.dump(config, f, indent=2) + f.write("\n") + + print(f"Generated test model in {out_dir}") + + +if __name__ == "__main__": + main() diff --git a/src/tests/unittests/ai/models/test-multiply/config.json b/src/tests/unittests/ai/models/test-multiply/config.json new file mode 100644 index 000000000000..260b8ad51759 --- /dev/null +++ b/src/tests/unittests/ai/models/test-multiply/config.json @@ -0,0 +1,8 @@ +{ + "id": "test-multiply", + "name": "Test Multiply", + "description": "Test model: y = x * 2", + "task": "test", + "backend": "onnx", + "num_inputs": 1 +} diff --git 
a/src/tests/unittests/ai/models/test-multiply/model.onnx b/src/tests/unittests/ai/models/test-multiply/model.onnx new file mode 100644 index 000000000000..b81cf6b0dca2 Binary files /dev/null and b/src/tests/unittests/ai/models/test-multiply/model.onnx differ diff --git a/src/tests/unittests/ai/test_ai_backend.c b/src/tests/unittests/ai/test_ai_backend.c new file mode 100644 index 000000000000..32e6089fbec7 --- /dev/null +++ b/src/tests/unittests/ai/test_ai_backend.c @@ -0,0 +1,400 @@ +/* + This file is part of darktable, + Copyright (C) 2026 darktable developers. + + darktable is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + darktable is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with darktable. If not, see <https://www.gnu.org/licenses/>. +*/ + +/* + * how to run these tests: + * + * cmake -B build -DBUILD_AI=ON + * cmake --build build --target test_ai_backend + * cd build && ctest -R test_ai_backend -V + * + * TEST_MODEL_DIR is set automatically by CMake to point at the test + * fixtures directory containing sample ONNX models + */ + +#include <stdarg.h> +#include <stddef.h> +#include <setjmp.h> +#include <stdint.h> +#include <stdio.h> +#include <string.h> + +#include <cmocka.h> + +#include "ai/backend.h" + +#ifndef TEST_MODEL_DIR +#error "TEST_MODEL_DIR must be defined by CMake" +#endif + +// shared environment for all tests +static dt_ai_environment_t *env; + +// setup / teardown + +static int group_setup(void **state) +{ + env = dt_ai_env_init(TEST_MODEL_DIR); + *state = env; + return env ? 
0 : -1; +} + +static int group_teardown(void **state) +{ + dt_ai_env_destroy(env); + env = NULL; + return 0; +} + +// test: environment init + +static void test_env_init(void **state) +{ + assert_non_null(env); +} + +// test: model discovery + +static void test_model_discovery(void **state) +{ + const int count = dt_ai_get_model_count(env); + assert_int_equal(count, 1); + + const dt_ai_model_info_t *info = dt_ai_get_model_info_by_index(env, 0); + assert_non_null(info); + assert_string_equal(info->id, "test-multiply"); + assert_string_equal(info->name, "Test Multiply"); + assert_string_equal(info->task_type, "test"); + assert_string_equal(info->backend, "onnx"); + assert_int_equal(info->num_inputs, 1); +} + +// test: model lookup by ID + +static void test_model_lookup(void **state) +{ + const dt_ai_model_info_t *info + = dt_ai_get_model_info_by_id(env, "test-multiply"); + assert_non_null(info); + assert_string_equal(info->id, "test-multiply"); + + // non-existent model + const dt_ai_model_info_t *none + = dt_ai_get_model_info_by_id(env, "does-not-exist"); + assert_null(none); +} + +// test: model load + +static void test_model_load(void **state) +{ + dt_ai_context_t *ctx + = dt_ai_load_model(env, "test-multiply", NULL, DT_AI_PROVIDER_CPU); + assert_non_null(ctx); + dt_ai_unload_model(ctx); +} + +// test: I/O introspection + +static void test_introspection(void **state) +{ + dt_ai_context_t *ctx + = dt_ai_load_model(env, "test-multiply", NULL, DT_AI_PROVIDER_CPU); + assert_non_null(ctx); + + assert_int_equal(dt_ai_get_input_count(ctx), 1); + assert_int_equal(dt_ai_get_output_count(ctx), 1); + + assert_string_equal(dt_ai_get_input_name(ctx, 0), "x"); + assert_string_equal(dt_ai_get_output_name(ctx, 0), "y"); + + assert_int_equal(dt_ai_get_input_type(ctx, 0), DT_AI_FLOAT); + assert_int_equal(dt_ai_get_output_type(ctx, 0), DT_AI_FLOAT); + + int64_t shape[8]; + const int ndim = dt_ai_get_output_shape(ctx, 0, shape, 8); + assert_int_equal(ndim, 4); + 
assert_int_equal(shape[0], 1); + assert_int_equal(shape[1], 3); + assert_int_equal(shape[2], 4); + assert_int_equal(shape[3], 4); + + dt_ai_unload_model(ctx); +} + +// test: inference + +static void test_inference(void **state) +{ + dt_ai_context_t *ctx + = dt_ai_load_model(env, "test-multiply", NULL, DT_AI_PROVIDER_CPU); + assert_non_null(ctx); + + // input: all 1.0 + const int n = 1 * 3 * 4 * 4; + float input_data[48]; /* 1*3*4*4 = 48 */ + for(int i = 0; i < n; i++) input_data[i] = 1.0f; + + int64_t in_shape[] = { 1, 3, 4, 4 }; + dt_ai_tensor_t input = { + .data = input_data, + .type = DT_AI_FLOAT, + .shape = in_shape, + .ndim = 4 + }; + + // output buffer + float output_data[48]; + memset(output_data, 0, sizeof(output_data)); + + int64_t out_shape[] = { 1, 3, 4, 4 }; + dt_ai_tensor_t output = { + .data = output_data, + .type = DT_AI_FLOAT, + .shape = out_shape, + .ndim = 4 + }; + + const int ret = dt_ai_run(ctx, &input, 1, &output, 1); + assert_int_equal(ret, 0); + + // y = x * 2 → all outputs should be 2.0 + for(int i = 0; i < n; i++) + { + assert_float_equal(output_data[i], 2.0f, 1e-6f); + } + + dt_ai_unload_model(ctx); +} + +// test: provider setting + +static void test_provider_change(void **state) +{ + // stub returns "cpu" → init should have set CPU + assert_int_equal(dt_ai_env_get_provider(env), DT_AI_PROVIDER_CPU); + + // change to CoreML + dt_ai_env_set_provider(env, DT_AI_PROVIDER_COREML); + assert_int_equal(dt_ai_env_get_provider(env), DT_AI_PROVIDER_COREML); + + // change to AUTO + dt_ai_env_set_provider(env, DT_AI_PROVIDER_AUTO); + assert_int_equal(dt_ai_env_get_provider(env), DT_AI_PROVIDER_AUTO); + + // restore CPU for remaining tests + dt_ai_env_set_provider(env, DT_AI_PROVIDER_CPU); + assert_int_equal(dt_ai_env_get_provider(env), DT_AI_PROVIDER_CPU); +} + +// test: unload + cleanup + +static void test_cleanup(void **state) +{ + dt_ai_context_t *ctx + = dt_ai_load_model(env, "test-multiply", NULL, DT_AI_PROVIDER_CPU); + assert_non_null(ctx); + + 
// unload should not crash + dt_ai_unload_model(ctx); + + // double-unload NULL should be safe + dt_ai_unload_model(NULL); +} + +// test: error paths — NULL and invalid arguments + +static void test_error_null_env(void **state) +{ + // NULL env should return NULL / 0, not crash + assert_null(dt_ai_load_model(NULL, "test-multiply", NULL, DT_AI_PROVIDER_CPU)); + assert_int_equal(dt_ai_get_model_count(NULL), 0); + assert_null(dt_ai_get_model_info_by_index(NULL, 0)); + assert_null(dt_ai_get_model_info_by_id(NULL, "test-multiply")); + assert_null(dt_ai_get_model_info_by_id(env, NULL)); +} + +static void test_error_bad_model_id(void **state) +{ + // non-existent model ID + dt_ai_context_t *ctx + = dt_ai_load_model(env, "no-such-model", NULL, DT_AI_PROVIDER_CPU); + assert_null(ctx); +} + +static void test_error_bad_model_file(void **state) +{ + // existing model ID but non-existent .onnx file + dt_ai_context_t *ctx + = dt_ai_load_model(env, "test-multiply", "nonexistent.onnx", DT_AI_PROVIDER_CPU); + assert_null(ctx); +} + +static void test_error_introspection_bounds(void **state) +{ + dt_ai_context_t *ctx + = dt_ai_load_model(env, "test-multiply", NULL, DT_AI_PROVIDER_CPU); + assert_non_null(ctx); + + // NULL context + assert_int_equal(dt_ai_get_input_count(NULL), 0); + assert_int_equal(dt_ai_get_output_count(NULL), 0); + assert_null(dt_ai_get_input_name(NULL, 0)); + assert_null(dt_ai_get_output_name(NULL, 0)); + + // out-of-range index + assert_null(dt_ai_get_input_name(ctx, 99)); + assert_null(dt_ai_get_output_name(ctx, -1)); + + // output shape with NULL shape array + assert_int_equal(dt_ai_get_output_shape(ctx, 0, NULL, 0), -1); + + // output shape with too-small buffer + int64_t shape[2]; + const int ndim = dt_ai_get_output_shape(ctx, 0, shape, 2); + /* should return actual ndim (4) but only write 2 elements */ + assert_int_equal(ndim, 4); + + dt_ai_unload_model(ctx); +} + +static void test_error_run_bad_args(void **state) +{ + // dt_ai_run with NULL context + float 
dummy[48]; + int64_t shape[] = { 1, 3, 4, 4 }; + dt_ai_tensor_t t = { .data = dummy, .type = DT_AI_FLOAT, .shape = shape, .ndim = 4 }; + assert_int_not_equal(dt_ai_run(NULL, &t, 1, &t, 1), 0); +} + +// test: provider string conversion + +static void test_provider_strings(void **state) +{ + // round-trip all known providers + for(int i = 0; i < DT_AI_PROVIDER_COUNT; i++) + { + const char *str = dt_ai_providers[i].config_string; + dt_ai_provider_t parsed = dt_ai_provider_from_string(str); + assert_int_equal(parsed, dt_ai_providers[i].value); + } + + // display name lookup + const char *cpu_name = dt_ai_provider_to_string(DT_AI_PROVIDER_CPU); + assert_non_null(cpu_name); + assert_string_equal(cpu_name, "CPU"); + + // unknown string falls back to AUTO + assert_int_equal(dt_ai_provider_from_string("bogus"), DT_AI_PROVIDER_AUTO); + assert_int_equal(dt_ai_provider_from_string(NULL), DT_AI_PROVIDER_AUTO); + assert_int_equal(dt_ai_provider_from_string(""), DT_AI_PROVIDER_AUTO); + + // provider table completeness + assert_int_equal(dt_ai_providers[0].value, DT_AI_PROVIDER_AUTO); + assert_int_equal(dt_ai_providers[DT_AI_PROVIDER_COUNT - 1].value, DT_AI_PROVIDER_DIRECTML); +} + +// test: env_refresh preserves discovered models ---- + +static void test_env_refresh(void **state) +{ + const int before = dt_ai_get_model_count(env); + dt_ai_env_refresh(env); + const int after = dt_ai_get_model_count(env); + assert_int_equal(before, after); + + // model is still findable after refresh + const dt_ai_model_info_t *info + = dt_ai_get_model_info_by_id(env, "test-multiply"); + assert_non_null(info); + assert_string_equal(info->id, "test-multiply"); +} + +// test: load with optimization levels + +static void test_load_opt_levels(void **state) +{ + // DT_AI_OPT_BASIC + dt_ai_context_t *ctx_basic + = dt_ai_load_model_ext(env, "test-multiply", NULL, + DT_AI_PROVIDER_CPU, DT_AI_OPT_BASIC, NULL, 0); + assert_non_null(ctx_basic); + + // verify inference still works with basic optimization + 
float in[48], out[48]; + for(int i = 0; i < 48; i++) in[i] = 3.0f; + int64_t shape[] = { 1, 3, 4, 4 }; + dt_ai_tensor_t inp = { .data = in, .type = DT_AI_FLOAT, .shape = shape, .ndim = 4 }; + dt_ai_tensor_t outp = { .data = out, .type = DT_AI_FLOAT, .shape = shape, .ndim = 4 }; + assert_int_equal(dt_ai_run(ctx_basic, &inp, 1, &outp, 1), 0); + assert_float_equal(out[0], 6.0f, 1e-6f); + dt_ai_unload_model(ctx_basic); + + // DT_AI_OPT_DISABLED + dt_ai_context_t *ctx_none + = dt_ai_load_model_ext(env, "test-multiply", NULL, + DT_AI_PROVIDER_CPU, DT_AI_OPT_DISABLED, NULL, 0); + assert_non_null(ctx_none); + dt_ai_unload_model(ctx_none); +} + +// test: env_init with empty/invalid path + +static void test_env_init_empty(void **state) +{ + // non-existent path: should succeed with 0 models + dt_ai_environment_t *e = dt_ai_env_init("/no/such/path/xyz"); + assert_non_null(e); + assert_int_equal(dt_ai_get_model_count(e), 0); + dt_ai_env_destroy(e); + + // NULL path: still creates env (scans default dirs only) + dt_ai_environment_t *e2 = dt_ai_env_init(NULL); + assert_non_null(e2); + dt_ai_env_destroy(e2); +} + +int main(int argc, char *argv[]) +{ + const struct CMUnitTest tests[] = { + cmocka_unit_test(test_env_init), + cmocka_unit_test(test_model_discovery), + cmocka_unit_test(test_model_lookup), + cmocka_unit_test(test_model_load), + cmocka_unit_test(test_introspection), + cmocka_unit_test(test_inference), + cmocka_unit_test(test_provider_change), + cmocka_unit_test(test_cleanup), + cmocka_unit_test(test_error_null_env), + cmocka_unit_test(test_error_bad_model_id), + cmocka_unit_test(test_error_bad_model_file), + cmocka_unit_test(test_error_introspection_bounds), + cmocka_unit_test(test_error_run_bad_args), + cmocka_unit_test(test_provider_strings), + cmocka_unit_test(test_env_refresh), + cmocka_unit_test(test_load_opt_levels), + cmocka_unit_test(test_env_init_empty), + }; + + return cmocka_run_group_tests(tests, group_setup, group_teardown); +} +// clang-format off +// 
modelines: These editor modelines have been set for all relevant files by tools/update_modelines.py +// vim: shiftwidth=2 expandtab tabstop=2 cindent +// kate: tab-indents: off; indent-width 2; replace-tabs on; indent-mode cstyle; remove-trailing-spaces modified; +// clang-format on