diff --git a/src/ai/backend.h b/src/ai/backend.h index c34878a1043d..3e24922e5231 100644 --- a/src/ai/backend.h +++ b/src/ai/backend.h @@ -93,10 +93,11 @@ typedef struct dt_ai_context_t dt_ai_context_t; * @brief Model Metadata (ReadOnly) */ typedef struct dt_ai_model_info_t { - const char *id; ///< Unique ID (e.g. "nafnet-sidd") + const char *id; ///< Unique ID (e.g. "mask-object-segnext-b2hq") const char *name; ///< Display name const char *description; ///< Short description - const char *task_type; ///< e.g. "denoise", "inpainting" + const char *task_type; ///< e.g. "mask", "denoise" + const char *arch; ///< e.g. "sam2", "segnext" const char *backend; ///< Backend type (e.g. "onnx") int num_inputs; ///< Number of model inputs (default 1) } dt_ai_model_info_t; diff --git a/src/ai/backend_common.c b/src/ai/backend_common.c index ee72580620ab..679efb007e82 100644 --- a/src/ai/backend_common.c +++ b/src/ai/backend_common.c @@ -143,6 +143,11 @@ static void _scan_directory(dt_ai_environment_t *env, const char *root_path) _store_string(env, name, &info->name); _store_string(env, desc, &info->description); _store_string(env, task, &info->task_type); + + const char *arch = json_object_has_member(obj, "arch") + ? json_object_get_string_member(obj, "arch") + : ""; + _store_string(env, arch, &info->arch); _store_string(env, backend, &info->backend); info->num_inputs = json_object_has_member(obj, "num_inputs") ? (int)json_object_get_int_member(obj, "num_inputs") diff --git a/src/ai/segmentation.c b/src/ai/segmentation.c index 6d5766889eda..b20ed2524640 100644 --- a/src/ai/segmentation.c +++ b/src/ai/segmentation.c @@ -300,20 +300,20 @@ dt_seg_context_t *dt_seg_load(dt_ai_environment_t *env, const char *model_id) ctx->enc_order[0], ctx->enc_order[1], ctx->enc_order[2], ctx->enc_order[3], ctx->n_enc_outputs); - // detect model type from decoder output count - // SAM: 3+ outputs (masks, iou_predictions, low_res_masks) - // SegNext: 1 output (mask) - const int n_dec_outputs = dt_ai_get_output_count(decoder); + // detect model type from arch field in model registry + const dt_ai_model_info_t *minfo + = dt_ai_get_model_info_by_id(env, model_id); + const char *arch = minfo ? minfo->arch : ""; - if(n_dec_outputs >= 3) + if(strcmp(arch, "sam2") == 0) ctx->model_type = DT_SEG_MODEL_SAM; - else if(n_dec_outputs == 1) + else if(strcmp(arch, "segnext") == 0) ctx->model_type = DT_SEG_MODEL_SEGNEXT; else { dt_print(DT_DEBUG_AI, - "[segmentation] decoder has %d outputs, unsupported for %s", - n_dec_outputs, model_id); + "[segmentation] unknown arch '%s' for %s", + arch, model_id); dt_seg_free(ctx); return NULL; } diff --git a/src/common/ai_models.h b/src/common/ai_models.h index 59eaee32f89f..79fb87092631 100644 --- a/src/common/ai_models.h +++ b/src/common/ai_models.h @@ -46,7 +46,7 @@ typedef enum dt_ai_model_status_t */ typedef struct dt_ai_model_t { - char *id; // Unique identifier (e.g. "nafnet-sidd-width32") + char *id; // Unique identifier (e.g. "mask-object-segnext-b2hq") char *name; // Display name char *description; // Short description char *task; // Task type: "denoise", "upscale", etc.