Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/ai/backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,11 @@ typedef struct dt_ai_context_t dt_ai_context_t;
* @brief Model Metadata (ReadOnly)
*/
typedef struct dt_ai_model_info_t {
const char *id; ///< Unique ID (e.g. "nafnet-sidd")
const char *id; ///< Unique ID (e.g. "mask-object-segnext-b2hq")
const char *name; ///< Display name
const char *description; ///< Short description
const char *task_type; ///< e.g. "denoise", "inpainting"
const char *task_type; ///< e.g. "mask", "denoise"
const char *arch; ///< e.g. "sam2", "segnext"
const char *backend; ///< Backend type (e.g. "onnx")
int num_inputs; ///< Number of model inputs (default 1)
} dt_ai_model_info_t;
Expand Down
5 changes: 5 additions & 0 deletions src/ai/backend_common.c
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,11 @@ static void _scan_directory(dt_ai_environment_t *env, const char *root_path)
_store_string(env, name, &info->name);
_store_string(env, desc, &info->description);
_store_string(env, task, &info->task_type);

const char *arch = json_object_has_member(obj, "arch")
? json_object_get_string_member(obj, "arch")
: "";
_store_string(env, arch, &info->arch);
_store_string(env, backend, &info->backend);
info->num_inputs = json_object_has_member(obj, "num_inputs")
? (int)json_object_get_int_member(obj, "num_inputs")
Expand Down
16 changes: 8 additions & 8 deletions src/ai/segmentation.c
Original file line number Diff line number Diff line change
Expand Up @@ -300,20 +300,20 @@ dt_seg_context_t *dt_seg_load(dt_ai_environment_t *env, const char *model_id)
ctx->enc_order[0], ctx->enc_order[1], ctx->enc_order[2], ctx->enc_order[3],
ctx->n_enc_outputs);

// detect model type from decoder output count
// SAM: 3+ outputs (masks, iou_predictions, low_res_masks)
// SegNext: 1 output (mask)
const int n_dec_outputs = dt_ai_get_output_count(decoder);
// detect model type from arch field in model registry
const dt_ai_model_info_t *minfo
= dt_ai_get_model_info_by_id(env, model_id);
const char *arch = minfo ? minfo->arch : "";

if(n_dec_outputs >= 3)
if(strcmp(arch, "sam2") == 0)
ctx->model_type = DT_SEG_MODEL_SAM;
else if(n_dec_outputs == 1)
else if(strcmp(arch, "segnext") == 0)
ctx->model_type = DT_SEG_MODEL_SEGNEXT;
else
{
dt_print(DT_DEBUG_AI,
"[segmentation] decoder has %d outputs, unsupported for %s",
n_dec_outputs, model_id);
"[segmentation] unknown arch '%s' for %s",
arch, model_id);
dt_seg_free(ctx);
return NULL;
}
Expand Down
2 changes: 1 addition & 1 deletion src/common/ai_models.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ typedef enum dt_ai_model_status_t
*/
typedef struct dt_ai_model_t
{
char *id; // Unique identifier (e.g. "nafnet-sidd-width32")
char *id; // Unique identifier (e.g. "mask-object-segnext-b2hq")
char *name; // Display name
char *description; // Short description
char *task; // Task type: "denoise", "upscale", etc.
Expand Down
Loading