diff --git a/docs/contributing/EXTERNAL_PROVIDERS.md b/docs/contributing/EXTERNAL_PROVIDERS.md new file mode 100644 index 00000000000..989698c4ba7 --- /dev/null +++ b/docs/contributing/EXTERNAL_PROVIDERS.md @@ -0,0 +1,129 @@ +# External Provider Integration + +This guide covers: + +1. Adding a new **external model** (most common; existing provider). +2. Adding a brand-new **external provider** (adapter + config + UI wiring). + +## 1) Add a New External Model (Existing Provider) + +For provider-backed models (for example, OpenAI or Gemini), the source of truth is +`invokeai/backend/model_manager/starter_models.py`. + +### Required model fields + +Define a `StarterModel` with: + +- `base=BaseModelType.External` +- `type=ModelType.ExternalImageGenerator` +- `format=ModelFormat.ExternalApi` +- `source="external:///"` +- `name`, `description` +- `capabilities=ExternalModelCapabilities(...)` +- optional `default_settings=ExternalApiModelDefaultSettings(...)` + +Example: + +```python +new_external_model = StarterModel( + name="Provider Model Name", + base=BaseModelType.External, + source="external://openai/my-model-id", + description=( + "Provider model (external API). " + "Requires a configured OpenAI API key and may incur provider usage costs." + ), + type=ModelType.ExternalImageGenerator, + format=ModelFormat.ExternalApi, + capabilities=ExternalModelCapabilities( + modes=["txt2img", "img2img", "inpaint"], + supports_negative_prompt=False, + supports_seed=False, + supports_guidance=False, + supports_steps=False, + supports_reference_images=True, + max_images_per_request=4, + ), + default_settings=ExternalApiModelDefaultSettings( + width=1024, + height=1024, + num_images=1, + ), +) +``` + +Then append it to `STARTER_MODELS`. 
+ +### Required description text + +External starter model descriptions must clearly state: + +- an API key is required +- usage may incur provider-side costs + +### Capabilities must be accurate + +These flags directly control UI visibility and request payload fields: + +- `supports_negative_prompt` +- `supports_seed` +- `supports_guidance` +- `supports_steps` +- `supports_reference_images` + +`supports_steps` is especially important: if `False`, steps are hidden for that model and `steps` is sent as `null`. + +### Source string stability + +Starter overrides are matched by `source` (`external://provider/model-id`). Keep this stable: + +- runtime capability/default overrides depend on it +- installation detection in starter-model APIs depends on it + +`STARTER_MODELS` enforces unique `source` values with an assertion. + +### Install behavior notes + +- External starter models are managed in **External Providers** setup (not the regular Starter Models tab). +- External starter models auto-install when a provider is configured. +- Removing a provider API key removes installed external models for that provider. + +## 2) Credentials and Config + +External provider API keys are stored separately from `invokeai.yaml`: + +- default file: `~/invokeai/api_keys.yaml` +- resolved path: `/api_keys.yaml` + +Non-secret provider settings (for example base URL overrides) stay in `invokeai.yaml`. + +Environment variables are still supported, e.g.: + +- `INVOKEAI_EXTERNAL_GEMINI_API_KEY` +- `INVOKEAI_EXTERNAL_OPENAI_API_KEY` + +## 3) Add a New Provider (Only If Needed) + +If your model uses a provider that is not already integrated: + +1. Add config fields in `invokeai/app/services/config/config_default.py` + `external__api_key` and optional `external__base_url`. +2. Add provider field mapping in `invokeai/app/api/routers/app_info.py` + (`EXTERNAL_PROVIDER_FIELDS`). +3. 
Implement provider adapter in `invokeai/app/services/external_generation/providers/` + by subclassing `ExternalProvider`. +4. Register the provider in `invokeai/app/api/dependencies.py` when building + `ExternalGenerationService`. +5. Add starter model entries using `source="external:///"`. +6. Optional UI ordering tweak: + `invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/ExternalProviders/ExternalProvidersForm.tsx` + (`PROVIDER_SORT_ORDER`). + +## 4) Optional Manual Installation + +You can also install external models directly via: + +`POST /api/v2/models/install?source=external:///` + +If omitted, `path`, `source`, and `hash` are auto-populated for external model configs. +Set capabilities conservatively; the external generation service enforces capability checks at runtime. diff --git a/docs/contributing/index.md b/docs/contributing/index.md index 79c1082746d..b8002a18024 100644 --- a/docs/contributing/index.md +++ b/docs/contributing/index.md @@ -8,6 +8,10 @@ We welcome contributions, whether features, bug fixes, code cleanup, testing, co If you’d like to help with development, please see our [development guide](contribution_guides/development.md). +## External Providers + +If you are adding external image generation providers or configs, see our [external provider integration guide](EXTERNAL_PROVIDERS.md). + **New Contributors:** If you’re unfamiliar with contributing to open source projects, take a look at our [new contributor guide](contribution_guides/newContributorChecklist.md). 
## Nodes diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index 339a0ceadb4..ce7333bc446 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -16,6 +16,8 @@ from invokeai.app.services.config.config_default import InvokeAIAppConfig from invokeai.app.services.download.download_default import DownloadQueueService from invokeai.app.services.events.events_fastapievents import FastAPIEventService +from invokeai.app.services.external_generation.external_generation_default import ExternalGenerationService +from invokeai.app.services.external_generation.providers import GeminiProvider, OpenAIProvider from invokeai.app.services.image_files.image_files_disk import DiskImageFileStorage from invokeai.app.services.image_records.image_records_sqlite import SqliteImageRecordStorage from invokeai.app.services.images.images_default import ImageService @@ -145,13 +147,22 @@ def initialize( ), ) download_queue_service = DownloadQueueService(app_config=configuration, event_bus=events) - model_images_service = ModelImageFileStorageDisk(model_images_folder / "model_images") + model_record_service = ModelRecordServiceSQL(db=db, logger=logger) model_manager = ModelManagerService.build_model_manager( app_config=configuration, - model_record_service=ModelRecordServiceSQL(db=db, logger=logger), + model_record_service=model_record_service, download_queue=download_queue_service, events=events, ) + external_generation = ExternalGenerationService( + providers={ + GeminiProvider.provider_id: GeminiProvider(app_config=configuration, logger=logger), + OpenAIProvider.provider_id: OpenAIProvider(app_config=configuration, logger=logger), + }, + logger=logger, + record_store=model_record_service, + ) + model_images_service = ModelImageFileStorageDisk(model_images_folder / "model_images") model_relationships = ModelRelationshipsService() model_relationship_records = SqliteModelRelationshipRecordStorage(db=db) names = SimpleNameService() 
@@ -184,6 +195,7 @@ def initialize( model_relationships=model_relationships, model_relationship_records=model_relationship_records, download_queue=download_queue_service, + external_generation=external_generation, names=names, performance_statistics=performance_statistics, session_processor=session_processor, diff --git a/invokeai/app/api/routers/app_info.py b/invokeai/app/api/routers/app_info.py index d8f3bb2f807..c0c002ac877 100644 --- a/invokeai/app/api/routers/app_info.py +++ b/invokeai/app/api/routers/app_info.py @@ -1,15 +1,29 @@ +import locale from enum import Enum from importlib.metadata import distributions +from pathlib import Path as FilePath +from threading import Lock import torch -from fastapi import Body +import yaml +from fastapi import Body, HTTPException, Path from fastapi.routing import APIRouter from pydantic import BaseModel, Field from invokeai.app.api.dependencies import ApiDependencies -from invokeai.app.services.config.config_default import InvokeAIAppConfig, get_config +from invokeai.app.services.config.config_default import ( + DefaultInvokeAIAppConfig, + EXTERNAL_API_KEY_FIELDS, + InvokeAIAppConfig, + get_config, + load_external_api_keys, + load_and_migrate_config, +) +from invokeai.app.services.external_generation.external_generation_common import ExternalProviderStatus from invokeai.app.services.invocation_cache.invocation_cache_common import InvocationCacheStatus +from invokeai.app.services.model_records.model_records_base import UnknownModelException from invokeai.backend.image_util.infill_methods.patchmatch import PatchMatch +from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType from invokeai.backend.util.logging import logging from invokeai.version import __version__ @@ -41,7 +55,7 @@ async def get_version() -> AppVersion: async def get_app_deps() -> dict[str, str]: deps: dict[str, str] = {dist.metadata["Name"]: dist.version for dist in distributions()} try: - cuda = torch.version.cuda or "N/A" + cuda = 
getattr(getattr(torch, "version", None), "cuda", None) or "N/A" # pyright: ignore[reportAttributeAccessIssue] except Exception: cuda = "N/A" @@ -64,6 +78,30 @@ class InvokeAIAppConfigWithSetFields(BaseModel): config: InvokeAIAppConfig = Field(description="The InvokeAI App Config") +class ExternalProviderStatusModel(BaseModel): + provider_id: str = Field(description="The external provider identifier") + configured: bool = Field(description="Whether credentials are configured for the provider") + message: str | None = Field(default=None, description="Optional provider status detail") + + +class ExternalProviderConfigUpdate(BaseModel): + api_key: str | None = Field(default=None, description="API key for the external provider") + base_url: str | None = Field(default=None, description="Optional base URL override for the provider") + + +class ExternalProviderConfigModel(BaseModel): + provider_id: str = Field(description="The external provider identifier") + api_key_configured: bool = Field(description="Whether an API key is configured") + base_url: str | None = Field(default=None, description="Optional base URL override") + + +EXTERNAL_PROVIDER_FIELDS: dict[str, tuple[str, str]] = { + "gemini": ("external_gemini_api_key", "external_gemini_base_url"), + "openai": ("external_openai_api_key", "external_openai_base_url"), +} +_EXTERNAL_PROVIDER_CONFIG_LOCK = Lock() + + @app_router.get( "/runtime_config", operation_id="get_runtime_config", status_code=200, response_model=InvokeAIAppConfigWithSetFields ) @@ -72,6 +110,166 @@ async def get_runtime_config() -> InvokeAIAppConfigWithSetFields: return InvokeAIAppConfigWithSetFields(set_fields=config.model_fields_set, config=config) +@app_router.get( + "/external_providers/status", + operation_id="get_external_provider_statuses", + status_code=200, + response_model=list[ExternalProviderStatusModel], +) +async def get_external_provider_statuses() -> list[ExternalProviderStatusModel]: + statuses = 
ApiDependencies.invoker.services.external_generation.get_provider_statuses() + return [status_to_model(status) for status in statuses.values()] + + +@app_router.get( + "/external_providers/config", + operation_id="get_external_provider_configs", + status_code=200, + response_model=list[ExternalProviderConfigModel], +) +async def get_external_provider_configs() -> list[ExternalProviderConfigModel]: + config = get_config() + return [_build_external_provider_config(provider_id, config) for provider_id in EXTERNAL_PROVIDER_FIELDS] + + +@app_router.post( + "/external_providers/config/{provider_id}", + operation_id="set_external_provider_config", + status_code=200, + response_model=ExternalProviderConfigModel, +) +async def set_external_provider_config( + provider_id: str = Path(description="The external provider identifier"), + update: ExternalProviderConfigUpdate = Body(description="External provider configuration settings"), +) -> ExternalProviderConfigModel: + api_key_field, base_url_field = _get_external_provider_fields(provider_id) + updates: dict[str, str | None] = {} + + if update.api_key is not None: + api_key = update.api_key.strip() + updates[api_key_field] = api_key or None + if update.base_url is not None: + base_url = update.base_url.strip() + updates[base_url_field] = base_url or None + + if not updates: + raise HTTPException(status_code=400, detail="No external provider config fields provided") + + api_key_removed = update.api_key is not None and updates.get(api_key_field) is None + _apply_external_provider_update(updates) + if api_key_removed: + _remove_external_models_for_provider(provider_id) + return _build_external_provider_config(provider_id, get_config()) + + +@app_router.delete( + "/external_providers/config/{provider_id}", + operation_id="reset_external_provider_config", + status_code=200, + response_model=ExternalProviderConfigModel, +) +async def reset_external_provider_config( + provider_id: str = Path(description="The external provider 
identifier"), +) -> ExternalProviderConfigModel: + api_key_field, base_url_field = _get_external_provider_fields(provider_id) + _apply_external_provider_update({api_key_field: None, base_url_field: None}) + _remove_external_models_for_provider(provider_id) + return _build_external_provider_config(provider_id, get_config()) + + +def status_to_model(status: ExternalProviderStatus) -> ExternalProviderStatusModel: + return ExternalProviderStatusModel( + provider_id=status.provider_id, + configured=status.configured, + message=status.message, + ) + + +def _get_external_provider_fields(provider_id: str) -> tuple[str, str]: + if provider_id not in EXTERNAL_PROVIDER_FIELDS: + raise HTTPException(status_code=404, detail=f"Unknown external provider '{provider_id}'") + return EXTERNAL_PROVIDER_FIELDS[provider_id] + + +def _write_external_api_keys_file(api_keys_file_path: FilePath, api_keys: dict[str, str]) -> None: + if not api_keys: + if api_keys_file_path.exists(): + api_keys_file_path.unlink() + return + + api_keys_file_path.parent.mkdir(parents=True, exist_ok=True) + with open(api_keys_file_path, "w", encoding=locale.getpreferredencoding()) as api_keys_file: + yaml.safe_dump(api_keys, api_keys_file, sort_keys=False) + + +def _apply_external_provider_update(updates: dict[str, str | None]) -> None: + with _EXTERNAL_PROVIDER_CONFIG_LOCK: + runtime_config = get_config() + config_path = runtime_config.config_file_path + api_keys_file_path = runtime_config.api_keys_file_path + if config_path.exists(): + file_config = load_and_migrate_config(config_path) + else: + file_config = DefaultInvokeAIAppConfig() + + runtime_config.update_config(updates) + key_fields = set(EXTERNAL_API_KEY_FIELDS) + key_updates = {field: value for field, value in updates.items() if field in key_fields} + non_key_updates = {field: value for field, value in updates.items() if field not in key_fields} + + if non_key_updates: + file_config.update_config(non_key_updates) + + persisted_api_keys = 
load_external_api_keys(api_keys_file_path) + for field_name in EXTERNAL_API_KEY_FIELDS: + file_value = getattr(file_config, field_name, None) + if field_name not in persisted_api_keys and isinstance(file_value, str) and file_value.strip(): + persisted_api_keys[field_name] = file_value + + for field_name, value in key_updates.items(): + if value is None: + persisted_api_keys.pop(field_name, None) + else: + persisted_api_keys[field_name] = value + + _write_external_api_keys_file(api_keys_file_path, persisted_api_keys) + + for field_name in EXTERNAL_API_KEY_FIELDS: + setattr(file_config, field_name, None) + + file_config_to_write = type(file_config).model_validate( + file_config.model_dump(exclude_unset=True, exclude_none=True) + ) + file_config_to_write.write_file(config_path, as_example=False) + + +def _build_external_provider_config(provider_id: str, config: InvokeAIAppConfig) -> ExternalProviderConfigModel: + api_key_field, base_url_field = _get_external_provider_fields(provider_id) + return ExternalProviderConfigModel( + provider_id=provider_id, + api_key_configured=bool(getattr(config, api_key_field)), + base_url=getattr(config, base_url_field), + ) + + +def _remove_external_models_for_provider(provider_id: str) -> None: + model_manager = ApiDependencies.invoker.services.model_manager + external_models = model_manager.store.search_by_attr( + base_model=BaseModelType.External, + model_type=ModelType.ExternalImageGenerator, + ) + + for model in external_models: + if getattr(model, "provider_id", None) != provider_id: + continue + try: + model_manager.install.delete(model.key) + except UnknownModelException: + logging.warning(f"External model key '{model.key}' was already removed while resetting '{provider_id}'") + except Exception as error: + logging.warning(f"Failed removing external model key '{model.key}' for '{provider_id}': {error}") + + @app_router.get( "/logging", operation_id="get_log_level", diff --git a/invokeai/app/api/routers/model_manager.py 
b/invokeai/app/api/routers/model_manager.py index a1f6b3a744a..351948d9001 100644 --- a/invokeai/app/api/routers/model_manager.py +++ b/invokeai/app/api/routers/model_manager.py @@ -30,6 +30,7 @@ ) from invokeai.app.services.orphaned_models import OrphanedModelInfo from invokeai.app.util.suppress_output import SuppressOutput +from invokeai.backend.model_manager.configs.external_api import ExternalApiModelConfig from invokeai.backend.model_manager.configs.factory import AnyModelConfig, ModelConfigFactory from invokeai.backend.model_manager.configs.main import ( Main_Checkpoint_SD1_Config, @@ -75,8 +76,7 @@ class CacheType(str, Enum): def add_cover_image_to_model_config(config: AnyModelConfig, dependencies: Type[ApiDependencies]) -> AnyModelConfig: """Add a cover image URL to a model configuration.""" cover_image = dependencies.invoker.services.model_images.get_url(config.key) - config.cover_image = cover_image - return config + return config.model_copy(update={"cover_image": cover_image}) ############################################################################## @@ -145,8 +145,19 @@ async def list_model_records( found_models.extend( record_store.search_by_attr(model_type=model_type, model_name=model_name, model_format=model_format) ) - for model in found_models: + for index, model in enumerate(found_models): model = add_cover_image_to_model_config(model, ApiDependencies) + if isinstance(model, ExternalApiModelConfig): + starter_match = next((starter for starter in STARTER_MODELS if starter.source == model.source), None) + if starter_match is not None: + model_updates: dict[str, object] = {} + if starter_match.capabilities is not None: + model_updates["capabilities"] = starter_match.capabilities + if starter_match.default_settings is not None: + model_updates["default_settings"] = starter_match.default_settings + if model_updates: + model = model.model_copy(update=model_updates) + found_models[index] = model return ModelsList(models=found_models) @@ -166,6 +177,8 
@@ async def list_missing_models() -> ModelsList: missing_models: list[AnyModelConfig] = [] for model_config in record_store.all_models(): + if model_config.base == BaseModelType.External or model_config.format == ModelFormat.ExternalApi: + continue if not (models_path / model_config.path).resolve().exists(): missing_models.append(model_config) @@ -250,7 +263,8 @@ async def reidentify_model( result.config.name = config.name result.config.description = config.description result.config.cover_image = config.cover_image - result.config.trigger_phrases = config.trigger_phrases + if hasattr(result.config, "trigger_phrases") and hasattr(config, "trigger_phrases"): + result.config.trigger_phrases = config.trigger_phrases result.config.source = config.source result.config.source_type = config.source_type diff --git a/invokeai/app/invocations/external_image_generation.py b/invokeai/app/invocations/external_image_generation.py new file mode 100644 index 00000000000..c70ecb40795 --- /dev/null +++ b/invokeai/app/invocations/external_image_generation.py @@ -0,0 +1,148 @@ +from typing import Any + +from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation +from invokeai.app.invocations.fields import ( + FieldDescriptions, + ImageField, + InputField, + MetadataField, + WithBoard, + WithMetadata, +) +from invokeai.app.invocations.model import ModelIdentifierField +from invokeai.app.invocations.primitives import ImageCollectionOutput +from invokeai.app.services.external_generation.external_generation_common import ( + ExternalGenerationRequest, + ExternalGenerationResult, + ExternalReferenceImage, +) +from invokeai.app.services.shared.invocation_context import InvocationContext +from invokeai.backend.model_manager.configs.external_api import ExternalApiModelConfig, ExternalGenerationMode +from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType + + +@invocation( + "external_image_generation", + title="External Image Generation", 
+ tags=["external", "generation"], + category="image", + version="1.0.0", +) +class ExternalImageGenerationInvocation(BaseInvocation, WithMetadata, WithBoard): + """Generate images using an external provider.""" + + model: ModelIdentifierField = InputField( + description=FieldDescriptions.main_model, + ui_model_base=[BaseModelType.External], + ui_model_type=[ModelType.ExternalImageGenerator], + ui_model_format=[ModelFormat.ExternalApi], + ) + mode: ExternalGenerationMode = InputField(default="txt2img", description="Generation mode") + prompt: str = InputField(description="Prompt") + negative_prompt: str | None = InputField(default=None, description="Negative prompt") + seed: int | None = InputField(default=None, description=FieldDescriptions.seed) + num_images: int = InputField(default=1, gt=0, description="Number of images to generate") + width: int = InputField(default=1024, gt=0, description=FieldDescriptions.width) + height: int = InputField(default=1024, gt=0, description=FieldDescriptions.height) + steps: int | None = InputField(default=None, gt=0, description=FieldDescriptions.steps) + guidance: float | None = InputField(default=None, ge=0, description="Guidance strength") + init_image: ImageField | None = InputField(default=None, description="Init image for img2img/inpaint") + mask_image: ImageField | None = InputField(default=None, description="Mask image for inpaint") + reference_images: list[ImageField] = InputField(default=[], description="Reference images") + reference_image_weights: list[float] | None = InputField(default=None, description="Reference image weights") + reference_image_modes: list[str] | None = InputField(default=None, description="Reference image modes") + + def invoke(self, context: InvocationContext) -> ImageCollectionOutput: + model_config = context.models.get_config(self.model) + if not isinstance(model_config, ExternalApiModelConfig): + raise ValueError("Selected model is not an external API model") + + init_image = None + if 
self.init_image is not None: + init_image = context.images.get_pil(self.init_image.image_name, mode="RGB") + + mask_image = None + if self.mask_image is not None: + mask_image = context.images.get_pil(self.mask_image.image_name, mode="L") + + if self.reference_image_weights is not None and len(self.reference_image_weights) != len(self.reference_images): + raise ValueError("reference_image_weights must match reference_images length") + + if self.reference_image_modes is not None and len(self.reference_image_modes) != len(self.reference_images): + raise ValueError("reference_image_modes must match reference_images length") + + reference_images: list[ExternalReferenceImage] = [] + for index, image_field in enumerate(self.reference_images): + reference_image = context.images.get_pil(image_field.image_name, mode="RGB") + weight = None + mode = None + if self.reference_image_weights is not None: + weight = self.reference_image_weights[index] + if self.reference_image_modes is not None: + mode = self.reference_image_modes[index] + reference_images.append(ExternalReferenceImage(image=reference_image, weight=weight, mode=mode)) + + request = ExternalGenerationRequest( + model=model_config, + mode=self.mode, + prompt=self.prompt, + negative_prompt=self.negative_prompt, + seed=self.seed, + num_images=self.num_images, + width=self.width, + height=self.height, + steps=self.steps, + guidance=self.guidance, + init_image=init_image, + mask_image=mask_image, + reference_images=reference_images, + metadata=self._build_request_metadata(), + ) + + result = context._services.external_generation.generate(request) + + outputs: list[ImageField] = [] + for generated in result.images: + metadata = self._build_output_metadata(model_config, result, generated.seed) + image_dto = context.images.save(image=generated.image, metadata=metadata) + outputs.append(ImageField(image_name=image_dto.image_name)) + + return ImageCollectionOutput(collection=outputs) + + def _build_request_metadata(self) -> 
dict[str, Any] | None: + if self.metadata is None: + return None + return self.metadata.root + + def _build_output_metadata( + self, + model_config: ExternalApiModelConfig, + result: ExternalGenerationResult, + image_seed: int | None, + ) -> MetadataField | None: + metadata: dict[str, Any] = {} + + if self.metadata is not None: + metadata.update(self.metadata.root) + + metadata.update( + { + "external_provider": model_config.provider_id, + "external_model_id": model_config.provider_model_id, + } + ) + + provider_request_id = getattr(result, "provider_request_id", None) + if provider_request_id: + metadata["external_request_id"] = provider_request_id + + provider_metadata = getattr(result, "provider_metadata", None) + if provider_metadata: + metadata["external_provider_metadata"] = provider_metadata + + if image_seed is not None: + metadata["external_seed"] = image_seed + + if not metadata: + return None + return MetadataField(root=metadata) diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py index 2cc2aaf273c..b2e27d30ccf 100644 --- a/invokeai/app/services/config/config_default.py +++ b/invokeai/app/services/config/config_default.py @@ -22,6 +22,7 @@ from invokeai.frontend.cli.arg_parser import InvokeAIArgs INIT_FILE = Path("invokeai.yaml") +API_KEYS_FILE = Path("api_keys.yaml") DB_FILE = Path("invokeai.db") LEGACY_INIT_FILE = Path("invokeai.init") PRECISION = Literal["auto", "float16", "bfloat16", "float32"] @@ -30,6 +31,7 @@ LOG_FORMAT = Literal["plain", "color", "syslog", "legacy"] LOG_LEVEL = Literal["debug", "info", "warning", "error", "critical"] CONFIG_SCHEMA_VERSION = "4.0.2" +EXTERNAL_API_KEY_FIELDS = ("external_gemini_api_key", "external_openai_api_key") class URLRegexTokenPair(BaseModel): @@ -111,6 +113,10 @@ class InvokeAIAppConfig(BaseSettings): unsafe_disable_picklescan: UNSAFE. Disable the picklescan security check during model installation. 
Recommended only for development and testing purposes. This will allow arbitrary code execution during model installation, so should never be used in production. allow_unknown_models: Allow installation of models that we are unable to identify. If enabled, models will be marked as `unknown` in the database, and will not have any metadata associated with them. If disabled, unknown models will be rejected during installation. multiuser: Enable multiuser support. When disabled, the application runs in single-user mode using a default system account with administrator privileges. When enabled, requires user authentication and authorization. + external_gemini_api_key: API key for Gemini image generation. + external_openai_api_key: API key for OpenAI image generation. + external_gemini_base_url: Base URL override for Gemini image generation. + external_openai_base_url: Base URL override for OpenAI image generation. """ _root: Optional[Path] = PrivateAttr(default=None) @@ -207,6 +213,16 @@ class InvokeAIAppConfig(BaseSettings): # MULTIUSER multiuser: bool = Field(default=False, description="Enable multiuser support. When disabled, the application runs in single-user mode using a default system account with administrator privileges. When enabled, requires user authentication and authorization.") + # EXTERNAL PROVIDERS + external_gemini_api_key: Optional[str] = Field(default=None, description="API key for Gemini image generation.") + external_openai_api_key: Optional[str] = Field(default=None, description="API key for OpenAI image generation.") + external_gemini_base_url: Optional[str] = Field( + default=None, description="Base URL override for Gemini image generation." + ) + external_openai_base_url: Optional[str] = Field( + default=None, description="Base URL override for OpenAI image generation." 
+ ) + # fmt: on model_config = SettingsConfigDict(env_prefix="INVOKEAI_", env_ignore_empty=True) @@ -288,6 +304,13 @@ def config_file_path(self) -> Path: assert resolved_path is not None return resolved_path + @property + def api_keys_file_path(self) -> Path: + """Path to api_keys.yaml, resolved to an absolute path..""" + resolved_path = self._resolve(API_KEYS_FILE) + assert resolved_path is not None + return resolved_path + @property def outputs_path(self) -> Optional[Path]: """Path to the outputs directory, resolved to an absolute path..""" @@ -500,6 +523,36 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig: raise RuntimeError(f"Failed to load config file {config_path}: {e}") from e +def load_external_api_keys(api_keys_file_path: Path) -> dict[str, str]: + """Load external provider API keys from a dedicated YAML file.""" + if not api_keys_file_path.exists(): + return {} + + with open(api_keys_file_path, "rt", encoding=locale.getpreferredencoding()) as file: + loaded_api_keys: Any = yaml.safe_load(file) + + if loaded_api_keys is None: + return {} + + if not isinstance(loaded_api_keys, dict): + raise RuntimeError(f"Failed to load api keys file {api_keys_file_path}: expected a mapping") + + parsed_api_keys: dict[str, str] = {} + for field_name in EXTERNAL_API_KEY_FIELDS: + value = loaded_api_keys.get(field_name) + if value is None: + continue + if not isinstance(value, str): + raise RuntimeError( + f"Failed to load api keys file {api_keys_file_path}: value for '{field_name}' must be a string" + ) + stripped_value = value.strip() + if stripped_value: + parsed_api_keys[field_name] = stripped_value + + return parsed_api_keys + + @lru_cache(maxsize=1) def get_config() -> InvokeAIAppConfig: """Get the global singleton app config. 
@@ -516,6 +569,7 @@ def get_config() -> InvokeAIAppConfig: """ # This object includes environment variables, as parsed by pydantic-settings config = InvokeAIAppConfig() + env_fields_set = set(config.model_fields_set) args = InvokeAIArgs.args @@ -577,4 +631,11 @@ def get_config() -> InvokeAIAppConfig: default_config = DefaultInvokeAIAppConfig() default_config.write_file(config.config_file_path, as_example=False) + api_keys_from_file = load_external_api_keys(config.api_keys_file_path) + if api_keys_from_file: + # API keys file should take precedence over invokeai.yaml, but not over environment variables. + api_keys_to_apply = {key: value for key, value in api_keys_from_file.items() if key not in env_fields_set} + if api_keys_to_apply: + config.update_config(api_keys_to_apply, clobber=True) + return config diff --git a/invokeai/app/services/external_generation/__init__.py b/invokeai/app/services/external_generation/__init__.py new file mode 100644 index 00000000000..692da64643a --- /dev/null +++ b/invokeai/app/services/external_generation/__init__.py @@ -0,0 +1,23 @@ +from invokeai.app.services.external_generation.external_generation_base import ( + ExternalGenerationServiceBase, + ExternalProvider, +) +from invokeai.app.services.external_generation.external_generation_common import ( + ExternalGenerationRequest, + ExternalGenerationResult, + ExternalGeneratedImage, + ExternalProviderStatus, + ExternalReferenceImage, +) +from invokeai.app.services.external_generation.external_generation_default import ExternalGenerationService + +__all__ = [ + "ExternalGenerationRequest", + "ExternalGenerationResult", + "ExternalGeneratedImage", + "ExternalGenerationService", + "ExternalGenerationServiceBase", + "ExternalProvider", + "ExternalProviderStatus", + "ExternalReferenceImage", +] diff --git a/invokeai/app/services/external_generation/errors.py b/invokeai/app/services/external_generation/errors.py new file mode 100644 index 00000000000..9980b39bc43 --- /dev/null +++ 
b/invokeai/app/services/external_generation/errors.py @@ -0,0 +1,18 @@ +class ExternalGenerationError(Exception): + """Base error for external generation.""" + + +class ExternalProviderNotFoundError(ExternalGenerationError): + """Raised when no provider is registered for a model.""" + + +class ExternalProviderNotConfiguredError(ExternalGenerationError): + """Raised when a provider is missing required credentials.""" + + +class ExternalProviderCapabilityError(ExternalGenerationError): + """Raised when a request is not supported by provider capabilities.""" + + +class ExternalProviderRequestError(ExternalGenerationError): + """Raised when a provider rejects the request or returns an error.""" diff --git a/invokeai/app/services/external_generation/external_generation_base.py b/invokeai/app/services/external_generation/external_generation_base.py new file mode 100644 index 00000000000..2145ff5ca42 --- /dev/null +++ b/invokeai/app/services/external_generation/external_generation_base.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from logging import Logger + +from invokeai.app.services.config import InvokeAIAppConfig +from invokeai.app.services.external_generation.external_generation_common import ( + ExternalGenerationRequest, + ExternalGenerationResult, + ExternalProviderStatus, +) + + +class ExternalProvider(ABC): + provider_id: str + + def __init__(self, app_config: InvokeAIAppConfig, logger: Logger) -> None: + self._app_config = app_config + self._logger = logger + + @abstractmethod + def is_configured(self) -> bool: + raise NotImplementedError + + @abstractmethod + def generate(self, request: ExternalGenerationRequest) -> ExternalGenerationResult: + raise NotImplementedError + + def get_status(self) -> ExternalProviderStatus: + return ExternalProviderStatus(provider_id=self.provider_id, configured=self.is_configured()) + + +class ExternalGenerationServiceBase(ABC): + @abstractmethod + def generate(self, request: 
ExternalGenerationRequest) -> ExternalGenerationResult: + raise NotImplementedError + + @abstractmethod + def get_provider_statuses(self) -> dict[str, ExternalProviderStatus]: + raise NotImplementedError diff --git a/invokeai/app/services/external_generation/external_generation_common.py b/invokeai/app/services/external_generation/external_generation_common.py new file mode 100644 index 00000000000..c1e2f4706f5 --- /dev/null +++ b/invokeai/app/services/external_generation/external_generation_common.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + +from PIL.Image import Image as PILImageType + +from invokeai.backend.model_manager.configs.external_api import ExternalApiModelConfig, ExternalGenerationMode + + +@dataclass(frozen=True) +class ExternalReferenceImage: + image: PILImageType + weight: float | None = None + mode: str | None = None + + +@dataclass(frozen=True) +class ExternalGenerationRequest: + model: ExternalApiModelConfig + mode: ExternalGenerationMode + prompt: str + negative_prompt: str | None + seed: int | None + num_images: int + width: int + height: int + steps: int | None + guidance: float | None + init_image: PILImageType | None + mask_image: PILImageType | None + reference_images: list[ExternalReferenceImage] + metadata: dict[str, Any] | None + + +@dataclass(frozen=True) +class ExternalGeneratedImage: + image: PILImageType + seed: int | None = None + + +@dataclass(frozen=True) +class ExternalGenerationResult: + images: list[ExternalGeneratedImage] + seed_used: int | None = None + provider_request_id: str | None = None + provider_metadata: dict[str, Any] | None = None + content_filters: dict[str, str] | None = None + + +@dataclass(frozen=True) +class ExternalProviderStatus: + provider_id: str + configured: bool + message: str | None = None diff --git a/invokeai/app/services/external_generation/external_generation_default.py 
b/invokeai/app/services/external_generation/external_generation_default.py new file mode 100644 index 00000000000..ff54d714762 --- /dev/null +++ b/invokeai/app/services/external_generation/external_generation_default.py @@ -0,0 +1,335 @@ +from __future__ import annotations + +from logging import Logger +from typing import TYPE_CHECKING + +from PIL import Image +from PIL.Image import Image as PILImageType + +from invokeai.app.services.external_generation.errors import ( + ExternalProviderCapabilityError, + ExternalProviderNotConfiguredError, + ExternalProviderNotFoundError, +) +from invokeai.app.services.external_generation.external_generation_base import ( + ExternalGenerationServiceBase, + ExternalProvider, +) +from invokeai.app.services.external_generation.external_generation_common import ( + ExternalGeneratedImage, + ExternalGenerationRequest, + ExternalGenerationResult, + ExternalProviderStatus, +) +from invokeai.backend.model_manager.configs.external_api import ExternalApiModelConfig, ExternalImageSize +from invokeai.backend.model_manager.starter_models import STARTER_MODELS + +if TYPE_CHECKING: + from invokeai.app.services.model_records import ModelRecordServiceBase + + +class ExternalGenerationService(ExternalGenerationServiceBase): + def __init__( + self, + providers: dict[str, ExternalProvider], + logger: Logger, + record_store: ModelRecordServiceBase | None = None, + ) -> None: + self._providers = providers + self._logger = logger + self._record_store = record_store + + def generate(self, request: ExternalGenerationRequest) -> ExternalGenerationResult: + provider = self._providers.get(request.model.provider_id) + if provider is None: + raise ExternalProviderNotFoundError(f"No external provider registered for '{request.model.provider_id}'") + + if not provider.is_configured(): + raise ExternalProviderNotConfiguredError(f"Provider '{request.model.provider_id}' is missing credentials") + + request = self._refresh_model_capabilities(request) + 
resize_to_original_inpaint_size = _get_resize_target_for_inpaint(request) + request = self._bucket_request(request) + + self._validate_request(request) + result = provider.generate(request) + + if resize_to_original_inpaint_size is None: + return result + + width, height = resize_to_original_inpaint_size + return _resize_result_images(result, width, height) + + def get_provider_statuses(self) -> dict[str, ExternalProviderStatus]: + return {provider_id: provider.get_status() for provider_id, provider in self._providers.items()} + + def _validate_request(self, request: ExternalGenerationRequest) -> None: + capabilities = request.model.capabilities + + self._logger.debug( + "Validating external request provider=%s model=%s mode=%s supported=%s", + request.model.provider_id, + request.model.provider_model_id, + request.mode, + capabilities.modes, + ) + + if request.mode not in capabilities.modes: + raise ExternalProviderCapabilityError(f"Mode '{request.mode}' is not supported by {request.model.name}") + + if request.negative_prompt and not capabilities.supports_negative_prompt: + raise ExternalProviderCapabilityError(f"Negative prompts are not supported by {request.model.name}") + + if request.seed is not None and not capabilities.supports_seed: + raise ExternalProviderCapabilityError(f"Seed control is not supported by {request.model.name}") + + if request.guidance is not None and not capabilities.supports_guidance: + raise ExternalProviderCapabilityError(f"Guidance is not supported by {request.model.name}") + + if request.reference_images and not capabilities.supports_reference_images: + raise ExternalProviderCapabilityError(f"Reference images are not supported by {request.model.name}") + + if capabilities.max_reference_images is not None: + if len(request.reference_images) > capabilities.max_reference_images: + raise ExternalProviderCapabilityError( + f"{request.model.name} supports at most {capabilities.max_reference_images} reference images" + ) + + if 
capabilities.max_images_per_request is not None and request.num_images > capabilities.max_images_per_request: + raise ExternalProviderCapabilityError( + f"{request.model.name} supports at most {capabilities.max_images_per_request} images per request" + ) + + if capabilities.max_image_size is not None: + if request.width > capabilities.max_image_size.width or request.height > capabilities.max_image_size.height: + raise ExternalProviderCapabilityError( + f"{request.model.name} supports a maximum size of {capabilities.max_image_size.width}x{capabilities.max_image_size.height}" + ) + + if capabilities.allowed_aspect_ratios: + aspect_ratio = _format_aspect_ratio(request.width, request.height) + if aspect_ratio not in capabilities.allowed_aspect_ratios: + size_ratio = None + if capabilities.aspect_ratio_sizes: + size_ratio = _ratio_for_size(request.width, request.height, capabilities.aspect_ratio_sizes) + if size_ratio is None or size_ratio not in capabilities.allowed_aspect_ratios: + ratio_label = size_ratio or aspect_ratio + raise ExternalProviderCapabilityError( + f"{request.model.name} does not support aspect ratio {ratio_label}" + ) + + required_modes = capabilities.input_image_required_for or ["img2img", "inpaint"] + if request.mode in required_modes and request.init_image is None: + raise ExternalProviderCapabilityError( + f"Mode '{request.mode}' requires an init image for {request.model.name}" + ) + + if request.mode == "inpaint" and request.mask_image is None: + raise ExternalProviderCapabilityError( + f"Mode '{request.mode}' requires a mask image for {request.model.name}" + ) + + def _refresh_model_capabilities(self, request: ExternalGenerationRequest) -> ExternalGenerationRequest: + if self._record_store is None: + return request + + try: + record = self._record_store.get_model(request.model.key) + except Exception: + record = None + + if not isinstance(record, ExternalApiModelConfig): + return request + + if record.key != request.model.key: + return request + 
+ if record.provider_id != request.model.provider_id: + return request + + if record.provider_model_id != request.model.provider_model_id: + return request + + record = _apply_starter_overrides(record) + + if record == request.model: + return request + + return ExternalGenerationRequest( + model=record, + mode=request.mode, + prompt=request.prompt, + negative_prompt=request.negative_prompt, + seed=request.seed, + num_images=request.num_images, + width=request.width, + height=request.height, + steps=request.steps, + guidance=request.guidance, + init_image=request.init_image, + mask_image=request.mask_image, + reference_images=request.reference_images, + metadata=request.metadata, + ) + + def _bucket_request(self, request: ExternalGenerationRequest) -> ExternalGenerationRequest: + capabilities = request.model.capabilities + if not capabilities.allowed_aspect_ratios: + return request + + aspect_ratio = _format_aspect_ratio(request.width, request.height) + size = None + if capabilities.aspect_ratio_sizes: + size = capabilities.aspect_ratio_sizes.get(aspect_ratio) + + if size is not None: + if request.width == size.width and request.height == size.height: + return request + return self._bucket_to_size(request, size.width, size.height, aspect_ratio) + + if aspect_ratio in capabilities.allowed_aspect_ratios: + return request + + if not capabilities.aspect_ratio_sizes: + return request + + closest = _select_closest_ratio( + request.width, + request.height, + capabilities.allowed_aspect_ratios, + ) + if closest is None: + return request + + size = capabilities.aspect_ratio_sizes.get(closest) + if size is None: + return request + + return self._bucket_to_size(request, size.width, size.height, closest) + + def _bucket_to_size( + self, + request: ExternalGenerationRequest, + width: int, + height: int, + ratio: str, + ) -> ExternalGenerationRequest: + self._logger.info( + "Bucketing external request provider=%s model=%s %sx%s -> %sx%s (ratio %s)", + request.model.provider_id, + 
request.model.provider_model_id, + request.width, + request.height, + width, + height, + ratio, + ) + + return ExternalGenerationRequest( + model=request.model, + mode=request.mode, + prompt=request.prompt, + negative_prompt=request.negative_prompt, + seed=request.seed, + num_images=request.num_images, + width=width, + height=height, + steps=request.steps, + guidance=request.guidance, + init_image=_resize_image(request.init_image, width, height, "RGB"), + mask_image=_resize_image(request.mask_image, width, height, "L"), + reference_images=request.reference_images, + metadata=request.metadata, + ) + + +def _format_aspect_ratio(width: int, height: int) -> str: + divisor = _gcd(width, height) + return f"{width // divisor}:{height // divisor}" + + +def _select_closest_ratio(width: int, height: int, ratios: list[str]) -> str | None: + ratio = width / height + parsed: list[tuple[str, float]] = [] + for value in ratios: + parsed_ratio = _parse_ratio(value) + if parsed_ratio is not None: + parsed.append((value, parsed_ratio)) + if not parsed: + return None + return min(parsed, key=lambda item: abs(item[1] - ratio))[0] + + +def _ratio_for_size(width: int, height: int, sizes: dict[str, ExternalImageSize]) -> str | None: + for ratio, size in sizes.items(): + if size.width == width and size.height == height: + return ratio + return None + + +def _parse_ratio(value: str) -> float | None: + if ":" not in value: + return None + left, right = value.split(":", 1) + try: + numerator = float(left) + denominator = float(right) + except ValueError: + return None + if denominator == 0: + return None + return numerator / denominator + + +def _gcd(a: int, b: int) -> int: + while b: + a, b = b, a % b + return a + + +def _resize_image(image: PILImageType | None, width: int, height: int, mode: str) -> PILImageType | None: + if image is None: + return None + if image.width == width and image.height == height: + return image + return image.convert(mode).resize((width, height), 
Image.Resampling.LANCZOS) + + +def _get_resize_target_for_inpaint(request: ExternalGenerationRequest) -> tuple[int, int] | None: + if request.mode != "inpaint" or request.init_image is None: + return None + return request.init_image.width, request.init_image.height + + +def _resize_result_images(result: ExternalGenerationResult, width: int, height: int) -> ExternalGenerationResult: + resized_images = [ + ExternalGeneratedImage( + image=generated.image + if generated.image.width == width and generated.image.height == height + else generated.image.resize((width, height), Image.Resampling.LANCZOS), + seed=generated.seed, + ) + for generated in result.images + ] + return ExternalGenerationResult( + images=resized_images, + seed_used=result.seed_used, + provider_request_id=result.provider_request_id, + provider_metadata=result.provider_metadata, + content_filters=result.content_filters, + ) + + +def _apply_starter_overrides(model: ExternalApiModelConfig) -> ExternalApiModelConfig: + source = model.source or f"external://{model.provider_id}/{model.provider_model_id}" + starter_match = next((starter for starter in STARTER_MODELS if starter.source == source), None) + if starter_match is None: + return model + updates: dict[str, object] = {} + if starter_match.capabilities is not None: + updates["capabilities"] = starter_match.capabilities + if starter_match.default_settings is not None: + updates["default_settings"] = starter_match.default_settings + if not updates: + return model + return model.model_copy(update=updates) diff --git a/invokeai/app/services/external_generation/image_utils.py b/invokeai/app/services/external_generation/image_utils.py new file mode 100644 index 00000000000..a23c1f11d66 --- /dev/null +++ b/invokeai/app/services/external_generation/image_utils.py @@ -0,0 +1,19 @@ +from __future__ import annotations + +import base64 +import io + +from PIL import Image +from PIL.Image import Image as PILImageType + + +def encode_image_base64(image: PILImageType, 
format: str = "PNG") -> str: + buffer = io.BytesIO() + image.save(buffer, format=format) + return base64.b64encode(buffer.getvalue()).decode("ascii") + + +def decode_image_base64(encoded: str) -> PILImageType: + data = base64.b64decode(encoded) + image = Image.open(io.BytesIO(data)) + return image.convert("RGB") diff --git a/invokeai/app/services/external_generation/providers/__init__.py b/invokeai/app/services/external_generation/providers/__init__.py new file mode 100644 index 00000000000..9e380fca1e1 --- /dev/null +++ b/invokeai/app/services/external_generation/providers/__init__.py @@ -0,0 +1,4 @@ +from invokeai.app.services.external_generation.providers.gemini import GeminiProvider +from invokeai.app.services.external_generation.providers.openai import OpenAIProvider + +__all__ = ["GeminiProvider", "OpenAIProvider"] diff --git a/invokeai/app/services/external_generation/providers/gemini.py b/invokeai/app/services/external_generation/providers/gemini.py new file mode 100644 index 00000000000..4d43431a14a --- /dev/null +++ b/invokeai/app/services/external_generation/providers/gemini.py @@ -0,0 +1,249 @@ +from __future__ import annotations + +import json +import uuid + +import requests +from PIL.Image import Image as PILImageType + +from invokeai.app.services.external_generation.errors import ExternalProviderRequestError +from invokeai.app.services.external_generation.external_generation_base import ExternalProvider +from invokeai.app.services.external_generation.external_generation_common import ( + ExternalGeneratedImage, + ExternalGenerationRequest, + ExternalGenerationResult, +) +from invokeai.app.services.external_generation.image_utils import decode_image_base64, encode_image_base64 + + +class GeminiProvider(ExternalProvider): + provider_id = "gemini" + _SYSTEM_INSTRUCTION = ( + "You are an image generation model. Always respond with an image based on the user's prompt. " + "Do not return text-only responses. 
If the user input is not an edit instruction, " + "interpret it as a request to create a new image." + ) + + def is_configured(self) -> bool: + return bool(self._app_config.external_gemini_api_key) + + def generate(self, request: ExternalGenerationRequest) -> ExternalGenerationResult: + api_key = self._app_config.external_gemini_api_key + if not api_key: + raise ExternalProviderRequestError("Gemini API key is not configured") + + base_url = (self._app_config.external_gemini_base_url or "https://generativelanguage.googleapis.com").rstrip( + "/" + ) + if not base_url.endswith("/v1") and not base_url.endswith("/v1beta"): + base_url = f"{base_url}/v1beta" + model_id = request.model.provider_model_id.removeprefix("models/") + endpoint = f"{base_url}/models/{model_id}:generateContent" + + request_parts: list[dict[str, object]] = [] + + if request.init_image is not None: + request_parts.append( + { + "inlineData": { + "mimeType": "image/png", + "data": encode_image_base64(request.init_image), + } + } + ) + + request_parts.append({"text": request.prompt}) + + for reference in request.reference_images: + request_parts.append( + { + "inlineData": { + "mimeType": "image/png", + "data": encode_image_base64(reference.image), + } + } + ) + + generation_config: dict[str, object] = { + "candidateCount": request.num_images, + "responseModalities": ["IMAGE"], + } + aspect_ratio = _select_aspect_ratio( + request.width, + request.height, + request.model.capabilities.allowed_aspect_ratios, + ) + system_instruction = self._SYSTEM_INSTRUCTION + if request.init_image is not None: + system_instruction = ( + f"{system_instruction} An input image is provided. " + "Treat the prompt as an edit instruction and modify the image accordingly. " + "Do not return the original image unchanged." + ) + if aspect_ratio is not None: + system_instruction = f"{system_instruction} Use an aspect ratio of {aspect_ratio}." 
+ + payload: dict[str, object] = { + "systemInstruction": {"parts": [{"text": system_instruction}]}, + "contents": [{"role": "user", "parts": request_parts}], + "generationConfig": generation_config, + } + + self._dump_debug_payload("request", payload) + + response = requests.post( + endpoint, + params={"key": api_key}, + json=payload, + timeout=120, + ) + + if not response.ok: + raise ExternalProviderRequestError( + f"Gemini request failed with status {response.status_code} for model '{model_id}': {response.text}" + ) + + data = response.json() + self._dump_debug_payload("response", data) + if not isinstance(data, dict): + raise ExternalProviderRequestError("Gemini response payload was not a JSON object") + images: list[ExternalGeneratedImage] = [] + text_parts: list[str] = [] + finish_messages: list[str] = [] + candidates = data.get("candidates") + if not isinstance(candidates, list): + raise ExternalProviderRequestError("Gemini response payload missing candidates") + for candidate in candidates: + if not isinstance(candidate, dict): + continue + finish_message = candidate.get("finishMessage") + finish_reason = candidate.get("finishReason") + if isinstance(finish_message, str): + finish_messages.append(finish_message) + elif isinstance(finish_reason, str): + finish_messages.append(f"Finish reason: {finish_reason}") + for part in _iter_response_parts(candidate): + inline_data = part.get("inline_data") or part.get("inlineData") + if isinstance(inline_data, dict): + encoded = inline_data.get("data") + if encoded: + image = decode_image_base64(encoded) + images.append(ExternalGeneratedImage(image=image, seed=request.seed)) + self._dump_debug_image(image) + continue + file_data = part.get("fileData") or part.get("file_data") + if isinstance(file_data, dict): + file_uri = file_data.get("fileUri") or file_data.get("file_uri") + if isinstance(file_uri, str) and file_uri: + raise ExternalProviderRequestError( + f"Gemini returned fileUri instead of inline image data: 
{file_uri}" + ) + text = part.get("text") + if isinstance(text, str): + text_parts.append(text) + + if not images: + self._logger.error("Gemini response contained no images: %s", data) + detail = "" + if finish_messages: + combined = " ".join(message.strip() for message in finish_messages if message.strip()) + if combined: + detail = f" Response status: {combined[:500]}" + elif text_parts: + combined = " ".join(text_parts).strip() + if combined: + detail = f" Response text: {combined[:500]}" + raise ExternalProviderRequestError(f"Gemini response contained no images.{detail}") + + return ExternalGenerationResult( + images=images, + seed_used=request.seed, + provider_metadata={"model": request.model.provider_model_id}, + ) + + def _dump_debug_payload(self, label: str, payload: object) -> None: + """TODO: remove debug payload dump once Gemini is stable.""" + try: + outputs_path = self._app_config.outputs_path + if outputs_path is None: + return + debug_dir = outputs_path / "external_debug" / "gemini" + debug_dir.mkdir(parents=True, exist_ok=True) + path = debug_dir / f"{label}_{uuid.uuid4().hex}.json" + path.write_text(json.dumps(payload, indent=2, default=str), encoding="utf-8") + except Exception as exc: + self._logger.debug("Failed to write Gemini debug payload: %s", exc) + + def _dump_debug_image(self, image: "PILImageType") -> None: + """TODO: remove debug image dump once Gemini is stable.""" + try: + outputs_path = self._app_config.outputs_path + if outputs_path is None: + return + debug_dir = outputs_path / "external_debug" / "gemini" + debug_dir.mkdir(parents=True, exist_ok=True) + path = debug_dir / f"decoded_{uuid.uuid4().hex}.png" + image.save(path, format="PNG") + except Exception as exc: + self._logger.debug("Failed to write Gemini debug image: %s", exc) + + +def _iter_response_parts(candidate: dict[str, object]) -> list[dict[str, object]]: + content = candidate.get("content") + if isinstance(content, dict): + content_parts = content.get("parts") + if 
isinstance(content_parts, list): + return [part for part in content_parts if isinstance(part, dict)] + contents = candidate.get("contents") + if isinstance(contents, list): + parts: list[dict[str, object]] = [] + for item in contents: + if not isinstance(item, dict): + continue + item_parts = item.get("parts") + if isinstance(item_parts, list): + parts.extend([part for part in item_parts if isinstance(part, dict)]) + if parts: + return parts + return [] + + +def _select_aspect_ratio(width: int, height: int, allowed: list[str] | None) -> str | None: + if width <= 0 or height <= 0: + return None + ratio = width / height + default_ratio = _format_aspect_ratio(width, height) + if not allowed: + return default_ratio + parsed = [(value, _parse_ratio(value)) for value in allowed] + filtered = [(value, parsed_ratio) for value, parsed_ratio in parsed if parsed_ratio is not None] + if not filtered: + return default_ratio + return min(filtered, key=lambda item: abs(item[1] - ratio))[0] + + +def _format_aspect_ratio(width: int, height: int) -> str | None: + if width <= 0 or height <= 0: + return None + divisor = _gcd(width, height) + return f"{width // divisor}:{height // divisor}" + + +def _parse_ratio(value: str) -> float | None: + if ":" not in value: + return None + left, right = value.split(":", 1) + try: + numerator = float(left) + denominator = float(right) + except ValueError: + return None + if denominator == 0: + return None + return numerator / denominator + + +def _gcd(a: int, b: int) -> int: + while b: + a, b = b, a % b + return a diff --git a/invokeai/app/services/external_generation/providers/openai.py b/invokeai/app/services/external_generation/providers/openai.py new file mode 100644 index 00000000000..f06491a225b --- /dev/null +++ b/invokeai/app/services/external_generation/providers/openai.py @@ -0,0 +1,116 @@ +from __future__ import annotations + +import io + +import requests +from PIL.Image import Image as PILImageType + +from 
invokeai.app.services.external_generation.errors import ExternalProviderRequestError +from invokeai.app.services.external_generation.external_generation_base import ExternalProvider +from invokeai.app.services.external_generation.external_generation_common import ( + ExternalGeneratedImage, + ExternalGenerationRequest, + ExternalGenerationResult, +) +from invokeai.app.services.external_generation.image_utils import decode_image_base64 + + +class OpenAIProvider(ExternalProvider): + provider_id = "openai" + + def is_configured(self) -> bool: + return bool(self._app_config.external_openai_api_key) + + def generate(self, request: ExternalGenerationRequest) -> ExternalGenerationResult: + api_key = self._app_config.external_openai_api_key + if not api_key: + raise ExternalProviderRequestError("OpenAI API key is not configured") + + size = f"{request.width}x{request.height}" + base_url = (self._app_config.external_openai_base_url or "https://api.openai.com").rstrip("/") + headers = {"Authorization": f"Bearer {api_key}"} + + use_edits_endpoint = request.mode != "txt2img" or bool(request.reference_images) + + if not use_edits_endpoint: + payload: dict[str, object] = { + "prompt": request.prompt, + "n": request.num_images, + "size": size, + "response_format": "b64_json", + } + if request.seed is not None: + payload["seed"] = request.seed + response = requests.post( + f"{base_url}/v1/images/generations", + headers=headers, + json=payload, + timeout=120, + ) + else: + images: list[PILImageType] = [] + if request.init_image is not None: + images.append(request.init_image) + images.extend(reference.image for reference in request.reference_images) + if not images: + raise ExternalProviderRequestError( + "OpenAI image edits require at least one image (init image or reference image)" + ) + + files: list[tuple[str, tuple[str, io.BytesIO, str]]] = [] + image_field_name = "image" if len(images) == 1 else "image[]" + for index, image in enumerate(images): + image_buffer = io.BytesIO() 
+ image.save(image_buffer, format="PNG") + image_buffer.seek(0) + files.append((image_field_name, (f"image_{index}.png", image_buffer, "image/png"))) + + if request.mask_image is not None: + mask_buffer = io.BytesIO() + request.mask_image.save(mask_buffer, format="PNG") + mask_buffer.seek(0) + files.append(("mask", ("mask.png", mask_buffer, "image/png"))) + + data: dict[str, object] = { + "prompt": request.prompt, + "n": request.num_images, + "size": size, + "response_format": "b64_json", + } + response = requests.post( + f"{base_url}/v1/images/edits", + headers=headers, + data=data, + files=files, + timeout=120, + ) + + if not response.ok: + raise ExternalProviderRequestError( + f"OpenAI request failed with status {response.status_code}: {response.text}" + ) + + payload = response.json() + if not isinstance(payload, dict): + raise ExternalProviderRequestError("OpenAI response payload was not a JSON object") + images: list[ExternalGeneratedImage] = [] + data_items = payload.get("data") + if not isinstance(data_items, list): + raise ExternalProviderRequestError("OpenAI response payload missing image data") + for item in data_items: + if not isinstance(item, dict): + continue + encoded = item.get("b64_json") + if not encoded: + continue + images.append(ExternalGeneratedImage(image=decode_image_base64(encoded), seed=request.seed)) + + if not images: + raise ExternalProviderRequestError("OpenAI response contained no images") + + return ExternalGenerationResult( + images=images, + seed_used=request.seed, + provider_request_id=response.headers.get("x-request-id"), + provider_metadata={"model": request.model.provider_model_id}, + ) diff --git a/invokeai/app/services/invocation_services.py b/invokeai/app/services/invocation_services.py index 7a33f49940c..2c95f87b41d 100644 --- a/invokeai/app/services/invocation_services.py +++ b/invokeai/app/services/invocation_services.py @@ -21,6 +21,7 @@ from invokeai.app.services.config import InvokeAIAppConfig from 
invokeai.app.services.download import DownloadQueueServiceBase from invokeai.app.services.events.events_base import EventServiceBase + from invokeai.app.services.external_generation.external_generation_base import ExternalGenerationServiceBase from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase from invokeai.app.services.image_records.image_records_base import ImageRecordStorageBase from invokeai.app.services.images.images_base import ImageServiceABC @@ -63,6 +64,7 @@ def __init__( model_relationships: "ModelRelationshipsServiceABC", model_relationship_records: "ModelRelationshipRecordStorageBase", download_queue: "DownloadQueueServiceBase", + external_generation: "ExternalGenerationServiceBase", performance_statistics: "InvocationStatsServiceBase", session_queue: "SessionQueueBase", session_processor: "SessionProcessorBase", @@ -94,6 +96,7 @@ def __init__( self.model_relationships = model_relationships self.model_relationship_records = model_relationship_records self.download_queue = download_queue + self.external_generation = external_generation self.performance_statistics = performance_statistics self.session_queue = session_queue self.session_processor = session_processor diff --git a/invokeai/app/services/model_install/model_install_common.py b/invokeai/app/services/model_install/model_install_common.py index 1006135a95e..f223c4698c2 100644 --- a/invokeai/app/services/model_install/model_install_common.py +++ b/invokeai/app/services/model_install/model_install_common.py @@ -139,12 +139,27 @@ def __str__(self) -> str: return str(self.url) -ModelSource = Annotated[Union[LocalModelSource, HFModelSource, URLModelSource], Field(discriminator="type")] +class ExternalModelSource(StringLikeSource): + """An external provider model identifier.""" + + provider_id: str + provider_model_id: str + type: Literal["external"] = "external" + + def __str__(self) -> str: + return f"external://{self.provider_id}/{self.provider_model_id}" + + 
+ModelSource = Annotated[ + Union[LocalModelSource, HFModelSource, URLModelSource, ExternalModelSource], + Field(discriminator="type"), +] MODEL_SOURCE_TO_TYPE_MAP = { URLModelSource: ModelSourceType.Url, HFModelSource: ModelSourceType.HFRepoID, LocalModelSource: ModelSourceType.Path, + ExternalModelSource: ModelSourceType.External, } diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index c47267eab5f..24ba2ebed0b 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -28,6 +28,7 @@ from invokeai.app.services.model_install.model_install_base import ModelInstallServiceBase from invokeai.app.services.model_install.model_install_common import ( MODEL_SOURCE_TO_TYPE_MAP, + ExternalModelSource, HFModelSource, InstallStatus, InvalidModelConfigException, @@ -37,10 +38,15 @@ StringLikeSource, URLModelSource, ) -from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase +from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase, UnknownModelException from invokeai.app.services.model_records.model_records_base import ModelRecordChanges from invokeai.app.util.misc import get_iso_timestamp from invokeai.backend.model_manager.configs.base import Checkpoint_Config_Base +from invokeai.backend.model_manager.configs.external_api import ( + ExternalApiModelConfig, + ExternalApiModelDefaultSettings, + ExternalModelCapabilities, +) from invokeai.backend.model_manager.configs.factory import ( AnyModelConfig, ModelConfigFactory, @@ -55,7 +61,13 @@ ) from invokeai.backend.model_manager.metadata.metadata_base import HuggingFaceMetadata from invokeai.backend.model_manager.search import ModelSearch -from invokeai.backend.model_manager.taxonomy import ModelRepoVariant, ModelSourceType +from invokeai.backend.model_manager.taxonomy import ( + 
BaseModelType, + ModelFormat, + ModelRepoVariant, + ModelSourceType, + ModelType, +) from invokeai.backend.model_manager.util.lora_metadata_extractor import apply_lora_metadata from invokeai.backend.util import InvokeAILogger from invokeai.backend.util.catch_sigint import catch_sigint @@ -451,6 +463,9 @@ def import_model(self, source: ModelSource, config: Optional[ModelRecordChanges] install_job = self._import_from_hf(source, config) elif isinstance(source, URLModelSource): install_job = self._import_from_url(source, config) + elif isinstance(source, ExternalModelSource): + install_job = self._import_external_model(source, config) + self._put_in_queue(install_job) else: raise ValueError(f"Unsupported model source: '{type(source)}'") @@ -748,7 +763,13 @@ def _guess_source(self, source: str) -> ModelSource: source_obj: Optional[StringLikeSource] = None source_stripped = source.strip('"') - if Path(source_stripped).exists(): # A local file or directory + if source_stripped.startswith("external://"): + external_id = source_stripped.removeprefix("external://") + provider_id, _, provider_model_id = external_id.partition("/") + if not provider_id or not provider_model_id: + raise ValueError(f"Invalid external model source: '{source_stripped}'") + source_obj = ExternalModelSource(provider_id=provider_id, provider_model_id=provider_model_id) + elif Path(source_stripped).exists(): # A local file or directory source_obj = LocalModelSource(path=Path(source_stripped)) elif match := re.match(hf_repoid_re, source): source_obj = HFModelSource( @@ -840,6 +861,9 @@ def _install_next_item(self) -> None: self._logger.info(f"Installer thread {threading.get_ident()} exiting") def _register_or_install(self, job: ModelInstallJob) -> None: + if isinstance(job.source, ExternalModelSource): + self._register_external_model(job) + return # local jobs will be in waiting state, remote jobs will be downloading state job.total_bytes = self._stat_size(job.local_path) job.bytes = job.total_bytes @@ 
-860,6 +884,71 @@ def _register_or_install(self, job: ModelInstallJob) -> None: job.config_out = self.record_store.get_model(key) self._signal_job_completed(job) + def _register_external_model(self, job: ModelInstallJob) -> None: + job.total_bytes = 0 + job.bytes = 0 + self._signal_job_running(job) + job.config_in.source = str(job.source) + job.config_in.source_type = MODEL_SOURCE_TO_TYPE_MAP[job.source.__class__] + + provider_id = job.source.provider_id + provider_model_id = job.source.provider_model_id + capabilities = job.config_in.capabilities or ExternalModelCapabilities() + default_settings = ( + job.config_in.default_settings + if isinstance(job.config_in.default_settings, ExternalApiModelDefaultSettings) + else None + ) + name = job.config_in.name or f"{provider_id} {provider_model_id}" + key = job.config_in.key or slugify(f"{provider_id}-{provider_model_id}") + + existing_external = next( + ( + model + for model in self.record_store.search_by_attr( + base_model=BaseModelType.External, model_type=ModelType.ExternalImageGenerator + ) + if isinstance(model, ExternalApiModelConfig) + and model.provider_id == provider_id + and model.provider_model_id == provider_model_id + ), + None, + ) + + if existing_external is not None: + key = existing_external.key + else: + try: + self.record_store.get_model(key) + raise DuplicateModelException( + f"Model key '{key}' already exists. Provide a different key to install this external model." 
+ ) + except UnknownModelException: + pass + + config = ExternalApiModelConfig( + key=key, + name=name, + description=job.config_in.description, + provider_id=provider_id, + provider_model_id=provider_model_id, + capabilities=capabilities, + default_settings=default_settings, + source=str(job.source), + source_type=MODEL_SOURCE_TO_TYPE_MAP[job.source.__class__], + path="", + hash="", + file_size=0, + ) + + if existing_external is not None: + self.record_store.replace_model(existing_external.key, config) + else: + self.record_store.add_model(config) + + job.config_out = self.record_store.get_model(config.key) + self._signal_job_completed(job) + def _set_error(self, install_job: ModelInstallJob, excp: Exception) -> None: multifile_download_job = install_job._multifile_job if multifile_download_job and any( @@ -895,6 +984,8 @@ def _scan_for_missing_models(self) -> list[AnyModelConfig]: """Scan the models directory for missing models and return a list of them.""" missing_models: list[AnyModelConfig] = [] for model_config in self.record_store.all_models(): + if model_config.base == BaseModelType.External or model_config.format == ModelFormat.ExternalApi: + continue if not (self.app_config.models_path / model_config.path).resolve().exists(): missing_models.append(model_config) return missing_models @@ -1036,6 +1127,19 @@ def _import_from_url( remote_files=remote_files, ) + def _import_external_model( + self, + source: ExternalModelSource, + config: Optional[ModelRecordChanges] = None, + ) -> ModelInstallJob: + return ModelInstallJob( + id=self._next_id(), + source=source, + config_in=config or ModelRecordChanges(), + local_path=self._app_config.models_path, + inplace=True, + ) + def _import_remote_model( self, source: HFModelSource | URLModelSource, diff --git a/invokeai/app/services/model_records/model_records_base.py b/invokeai/app/services/model_records/model_records_base.py index 96e12d3b0a3..318ebb000e6 100644 --- 
a/invokeai/app/services/model_records/model_records_base.py +++ b/invokeai/app/services/model_records/model_records_base.py @@ -13,6 +13,10 @@ from invokeai.app.services.shared.pagination import PaginatedResults from invokeai.app.util.model_exclude_null import BaseModelExcludeNull from invokeai.backend.model_manager.configs.controlnet import ControlAdapterDefaultSettings +from invokeai.backend.model_manager.configs.external_api import ( + ExternalApiModelDefaultSettings, + ExternalModelCapabilities, +) from invokeai.backend.model_manager.configs.factory import AnyModelConfig from invokeai.backend.model_manager.configs.lora import LoraModelDefaultSettings from invokeai.backend.model_manager.configs.main import MainModelDefaultSettings @@ -86,8 +90,19 @@ class ModelRecordChanges(BaseModelExcludeNull): file_size: Optional[int] = Field(description="Size of model file", default=None) format: Optional[str] = Field(description="format of model file", default=None) trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None) - default_settings: Optional[MainModelDefaultSettings | LoraModelDefaultSettings | ControlAdapterDefaultSettings] = ( - Field(description="Default settings for this model", default=None) + default_settings: Optional[ + MainModelDefaultSettings + | LoraModelDefaultSettings + | ControlAdapterDefaultSettings + | ExternalApiModelDefaultSettings + ] = Field(description="Default settings for this model", default=None) + + # External API model changes + provider_id: Optional[str] = Field(description="External provider identifier", default=None) + provider_model_id: Optional[str] = Field(description="External provider model identifier", default=None) + capabilities: Optional[ExternalModelCapabilities] = Field( + description="External model capabilities", + default=None, ) cpu_only: Optional[bool] = Field(description="Whether this model should run on CPU only", default=None) diff --git 
a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 67e3c99f1ad..e38766d5ba2 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -388,6 +388,8 @@ def load( submodel_type = submodel_type or identifier.submodel_type model = self._services.model_manager.store.get_model(identifier.key) + self._raise_if_external(model) + message = f"Loading model {model.name}" if submodel_type: message += f" ({submodel_type.value})" @@ -417,12 +419,18 @@ def load_by_attrs( if len(configs) > 1: raise ValueError(f"More than one model found with name {name}, base {base}, and type {type}") + self._raise_if_external(configs[0]) message = f"Loading model {name}" if submodel_type: message += f" ({submodel_type.value})" self._util.signal_progress(message) return self._services.model_manager.load.load_model(configs[0], submodel_type) + @staticmethod + def _raise_if_external(model: AnyModelConfig) -> None: + if model.base == BaseModelType.External or model.format == ModelFormat.ExternalApi: + raise ValueError("External API models cannot be loaded from disk") + def get_config(self, identifier: Union[str, "ModelIdentifierField"]) -> AnyModelConfig: """Get a model's config. 
diff --git a/invokeai/backend/model_manager/configs/external_api.py b/invokeai/backend/model_manager/configs/external_api.py index e69de29bb2d..5720a51e4e0 100644 --- a/invokeai/backend/model_manager/configs/external_api.py +++ b/invokeai/backend/model_manager/configs/external_api.py @@ -0,0 +1,81 @@ +from typing import Literal, Self + +from pydantic import BaseModel, ConfigDict, Field, model_validator + +from invokeai.backend.model_manager.configs.base import Config_Base +from invokeai.backend.model_manager.configs.identification_utils import NotAMatchError +from invokeai.backend.model_manager.model_on_disk import ModelOnDisk +from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelSourceType, ModelType + +ExternalGenerationMode = Literal["txt2img", "img2img", "inpaint"] +ExternalMaskFormat = Literal["alpha", "binary", "none"] + + +class ExternalImageSize(BaseModel): + width: int = Field(gt=0) + height: int = Field(gt=0) + + model_config = ConfigDict(extra="forbid") + + +class ExternalModelCapabilities(BaseModel): + modes: list[ExternalGenerationMode] = Field(default_factory=lambda: ["txt2img"]) + supports_reference_images: bool = Field(default=False) + supports_negative_prompt: bool = Field(default=True) + supports_seed: bool = Field(default=False) + supports_guidance: bool = Field(default=False) + supports_steps: bool = Field(default=False) + max_images_per_request: int | None = Field(default=None, gt=0) + max_image_size: ExternalImageSize | None = Field(default=None) + allowed_aspect_ratios: list[str] | None = Field(default=None) + aspect_ratio_sizes: dict[str, ExternalImageSize] | None = Field(default=None) + max_reference_images: int | None = Field(default=None, gt=0) + mask_format: ExternalMaskFormat = Field(default="none") + input_image_required_for: list[ExternalGenerationMode] | None = Field(default=None) + + model_config = ConfigDict(extra="forbid") + + +class ExternalApiModelDefaultSettings(BaseModel): + width: int | None = 
Field(default=None, gt=0) + height: int | None = Field(default=None, gt=0) + steps: int | None = Field(default=None, gt=0) + guidance: float | None = Field(default=None, gt=0) + num_images: int | None = Field(default=None, gt=0) + + model_config = ConfigDict(extra="forbid") + + +class ExternalApiModelConfig(Config_Base): + base: Literal[BaseModelType.External] = Field(default=BaseModelType.External) + type: Literal[ModelType.ExternalImageGenerator] = Field(default=ModelType.ExternalImageGenerator) + format: Literal[ModelFormat.ExternalApi] = Field(default=ModelFormat.ExternalApi) + + provider_id: str = Field(min_length=1, description="External provider ID") + provider_model_id: str = Field(min_length=1, description="Provider-specific model ID") + capabilities: ExternalModelCapabilities = Field(description="Provider capability matrix") + default_settings: ExternalApiModelDefaultSettings | None = Field(default=None) + tags: list[str] | None = Field(default=None) + is_default: bool = Field(default=False) + + source_type: ModelSourceType = Field(default=ModelSourceType.External) + path: str = Field(default="") + source: str = Field(default="") + hash: str = Field(default="") + file_size: int = Field(default=0, ge=0) + + model_config = ConfigDict(extra="forbid") + + @model_validator(mode="after") + def _populate_external_fields(self) -> "ExternalApiModelConfig": + if not self.path: + self.path = f"external://{self.provider_id}/{self.provider_model_id}" + if not self.source: + self.source = self.path + if not self.hash: + self.hash = f"external:{self.provider_id}:{self.provider_model_id}" + return self + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, object]) -> Self: + raise NotAMatchError("external API models are not probed from disk") diff --git a/invokeai/backend/model_manager/configs/factory.py b/invokeai/backend/model_manager/configs/factory.py index 7702d4a5d9b..81464a1a971 100644 --- 
a/invokeai/backend/model_manager/configs/factory.py +++ b/invokeai/backend/model_manager/configs/factory.py @@ -26,6 +26,7 @@ ControlNet_Diffusers_SD2_Config, ControlNet_Diffusers_SDXL_Config, ) +from invokeai.backend.model_manager.configs.external_api import ExternalApiModelConfig from invokeai.backend.model_manager.configs.flux_redux import FLUXRedux_Checkpoint_Config from invokeai.backend.model_manager.configs.identification_utils import NotAMatchError from invokeai.backend.model_manager.configs.ip_adapter import ( @@ -256,6 +257,7 @@ Annotated[SigLIP_Diffusers_Config, SigLIP_Diffusers_Config.get_tag()], Annotated[FLUXRedux_Checkpoint_Config, FLUXRedux_Checkpoint_Config.get_tag()], Annotated[LlavaOnevision_Diffusers_Config, LlavaOnevision_Diffusers_Config.get_tag()], + Annotated[ExternalApiModelConfig, ExternalApiModelConfig.get_tag()], # Unknown model (fallback) Annotated[Unknown_Config, Unknown_Config.get_tag()], ], diff --git a/invokeai/backend/model_manager/starter_models.py b/invokeai/backend/model_manager/starter_models.py index ef7cd80cd29..c0b877c9e35 100644 --- a/invokeai/backend/model_manager/starter_models.py +++ b/invokeai/backend/model_manager/starter_models.py @@ -2,6 +2,11 @@ from pydantic import BaseModel +from invokeai.backend.model_manager.configs.external_api import ( + ExternalApiModelDefaultSettings, + ExternalImageSize, + ExternalModelCapabilities, +) from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType @@ -13,6 +18,8 @@ class StarterModelWithoutDependencies(BaseModel): type: ModelType format: Optional[ModelFormat] = None is_installed: bool = False + capabilities: ExternalModelCapabilities | None = None + default_settings: ExternalApiModelDefaultSettings | None = None # allows us to track what models a user has installed across name changes within starter models # if you update a starter model name, please add the old one to this list for that starter model previous_names: list[str] = [] @@ -862,6 
+869,108 @@ class StarterModelBundle(BaseModel): ) # endregion +# region External API +gemini_flash_image = StarterModel( + name="Gemini 2.5 Flash Image", + base=BaseModelType.External, + source="external://gemini/gemini-2.5-flash-image", + description="Google Gemini 2.5 Flash image generation model (external API). Requires a configured Gemini API key and may incur provider usage costs.", + type=ModelType.ExternalImageGenerator, + format=ModelFormat.ExternalApi, + capabilities=ExternalModelCapabilities( + modes=["txt2img", "img2img", "inpaint"], + supports_negative_prompt=True, + supports_seed=True, + supports_guidance=True, + supports_reference_images=True, + max_images_per_request=1, + allowed_aspect_ratios=[ + "1:1", + "2:3", + "3:2", + "3:4", + "4:3", + "4:5", + "5:4", + "9:16", + "16:9", + "21:9", + ], + aspect_ratio_sizes={ + "1:1": ExternalImageSize(width=1024, height=1024), + "2:3": ExternalImageSize(width=832, height=1248), + "3:2": ExternalImageSize(width=1248, height=832), + "3:4": ExternalImageSize(width=864, height=1184), + "4:3": ExternalImageSize(width=1184, height=864), + "4:5": ExternalImageSize(width=896, height=1152), + "5:4": ExternalImageSize(width=1152, height=896), + "9:16": ExternalImageSize(width=768, height=1344), + "16:9": ExternalImageSize(width=1344, height=768), + "21:9": ExternalImageSize(width=1536, height=672), + }, + ), + default_settings=ExternalApiModelDefaultSettings(width=1024, height=1024, num_images=1), +) +gemini_pro_image_preview = StarterModel( + name="Gemini 3 Pro Image Preview", + base=BaseModelType.External, + source="external://gemini/gemini-3-pro-image-preview", + description="Google Gemini 3 Pro image generation preview model (external API). 
Requires a configured Gemini API key and may incur provider usage costs.", + type=ModelType.ExternalImageGenerator, + format=ModelFormat.ExternalApi, + capabilities=ExternalModelCapabilities( + modes=["txt2img", "img2img", "inpaint"], + supports_negative_prompt=True, + supports_seed=True, + supports_guidance=True, + supports_reference_images=True, + max_images_per_request=1, + allowed_aspect_ratios=[ + "1:1", + "2:3", + "3:2", + "3:4", + "4:3", + "4:5", + "5:4", + "9:16", + "16:9", + "21:9", + ], + aspect_ratio_sizes={ + "1:1": ExternalImageSize(width=1024, height=1024), + "2:3": ExternalImageSize(width=832, height=1248), + "3:2": ExternalImageSize(width=1248, height=832), + "3:4": ExternalImageSize(width=864, height=1184), + "4:3": ExternalImageSize(width=1184, height=864), + "4:5": ExternalImageSize(width=896, height=1152), + "5:4": ExternalImageSize(width=1152, height=896), + "9:16": ExternalImageSize(width=768, height=1344), + "16:9": ExternalImageSize(width=1344, height=768), + "21:9": ExternalImageSize(width=1536, height=672), + }, + ), + default_settings=ExternalApiModelDefaultSettings(width=1024, height=1024, num_images=1), +) +openai_gpt_image_1 = StarterModel( + name="ChatGPT Image", + base=BaseModelType.External, + source="external://openai/gpt-image-1", + description="OpenAI GPT-Image-1 image generation model (external API). Requires a configured OpenAI API key and may incur provider usage costs.", + type=ModelType.ExternalImageGenerator, + format=ModelFormat.ExternalApi, + capabilities=ExternalModelCapabilities( + modes=["txt2img", "img2img", "inpaint"], + supports_negative_prompt=True, + supports_seed=True, + supports_guidance=True, + supports_reference_images=True, + max_images_per_request=1, + ), + default_settings=ExternalApiModelDefaultSettings(width=1024, height=1024, num_images=1), +) +# endregion + # List of starter models, displayed on the frontend. # The order/sort of this list is not changed by the frontend - set it how you want it here. 
STARTER_MODELS: list[StarterModel] = [ @@ -957,6 +1066,9 @@ class StarterModelBundle(BaseModel): z_image_qwen3_encoder_quantized, z_image_controlnet_union, z_image_controlnet_tile, + gemini_flash_image, + gemini_pro_image_preview, + openai_gpt_image_1, ] sd1_bundle: list[StarterModel] = [ diff --git a/invokeai/backend/model_manager/taxonomy.py b/invokeai/backend/model_manager/taxonomy.py index c002418a6bd..4bf3461a8b9 100644 --- a/invokeai/backend/model_manager/taxonomy.py +++ b/invokeai/backend/model_manager/taxonomy.py @@ -52,6 +52,8 @@ class BaseModelType(str, Enum): """Indicates the model is associated with CogView 4 model architecture.""" ZImage = "z-image" """Indicates the model is associated with Z-Image model architecture, including Z-Image-Turbo.""" + External = "external" + """Indicates the model is hosted by an external provider.""" Unknown = "unknown" """Indicates the model's base architecture is unknown.""" @@ -76,6 +78,7 @@ class ModelType(str, Enum): SigLIP = "siglip" FluxRedux = "flux_redux" LlavaOnevision = "llava_onevision" + ExternalImageGenerator = "external_image_generator" Unknown = "unknown" @@ -170,6 +173,7 @@ class ModelFormat(str, Enum): BnbQuantizedLlmInt8b = "bnb_quantized_int8b" BnbQuantizednf4b = "bnb_quantized_nf4b" GGUFQuantized = "gguf_quantized" + ExternalApi = "external_api" Unknown = "unknown" @@ -198,6 +202,7 @@ class ModelSourceType(str, Enum): Path = "path" Url = "url" HFRepoID = "hf_repo_id" + External = "external" class FluxLoRAFormat(str, Enum): diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index c28df6ee383..218244d204d 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -993,6 +993,22 @@ "fileSize": "File Size", "filterModels": "Filter models", "fluxRedux": "FLUX Redux", + "externalImageGenerator": "External Image Generator", + "externalProviders": "External Providers", + "externalSetupTitle": "External 
Providers Setup", + "externalSetupDescription": "Connect an API key to enable external image generation. External starter models auto-install when a provider is configured.", + "externalInstallDefaults": "Auto-install starter models", + "externalProvidersUnavailable": "External providers are not available in this build.", + "externalSetupFooter": "An API key is required. External providers use remote APIs; usage may incur provider-side costs.", + "externalProviderCardDescription": "Configure {{providerId}} credentials for external image generation.", + "externalApiKey": "API Key", + "externalApiKeyPlaceholder": "Paste your API key", + "externalApiKeyPlaceholderSet": "API key configured", + "externalApiKeyHelper": "Stored in a separate api_keys.yaml file on the InvokeAI server.", + "externalBaseUrl": "Base URL (optional)", + "externalBaseUrlPlaceholder": "https://...", + "externalBaseUrlHelper": "Override the default API base URL if needed.", + "externalResetHelper": "Clear API key and base URL.", "height": "Height", "huggingFace": "HuggingFace", "huggingFacePlaceholder": "owner/model-name", @@ -1060,6 +1076,21 @@ "modelUpdated": "Model Updated", "modelUpdateFailed": "Model Update Failed", "name": "Name", + "externalProvider": "External Provider", + "externalCapabilities": "External Capabilities", + "externalDefaults": "External Defaults", + "providerId": "Provider ID", + "providerModelId": "Provider Model ID", + "supportedModes": "Supported Modes", + "supportsNegativePrompt": "Supports Negative Prompt", + "supportsReferenceImages": "Supports Reference Images", + "supportsSeed": "Supports Seed", + "supportsGuidance": "Supports Guidance", + "maxImagesPerRequest": "Max Images Per Request", + "maxReferenceImages": "Max Reference Images", + "maxImageWidth": "Max Image Width", + "maxImageHeight": "Max Image Height", + "numImages": "Num Images", "modelPickerFallbackNoModelsInstalled": "No models installed.", "modelPickerFallbackNoModelsInstalled2": "Visit the Model Manager to install models.", 
"noModelsInstalledDesc1": "Install models with the", @@ -1102,6 +1133,7 @@ "urlDescription": "Install models from a URL or local file path. Perfect for specific models you want to add.", "huggingFaceDescription": "Browse and install models directly from HuggingFace repositories.", "scanFolderDescription": "Scan a local folder to automatically detect and install models.", + "externalDescription": "Connect a Gemini or OpenAI API key to enable external generation. Usage may incur provider-side costs.", "recommendedModels": "Recommended Models", "exploreStarter": "Or browse all available starter models", "quickStart": "Quick Start Bundles", @@ -1575,7 +1607,11 @@ "intermediatesCleared_one": "Cleared {{count}} Intermediate", "intermediatesCleared_other": "Cleared {{count}} Intermediates", "intermediatesClearedFailed": "Problem Clearing Intermediates", - "reloadingIn": "Reloading in" + "reloadingIn": "Reloading in", + "externalProviders": "External Providers", + "externalProviderConfigured": "Configured", + "externalProviderNotConfigured": "API Key Required", + "externalProviderNotConfiguredHint": "Add your API key in Model Manager or the server config to enable this provider." 
}, "toast": { "addedToBoard": "Added to board {{name}}'s assets", diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts index 20057472ca8..ed2c67d5292 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts @@ -168,7 +168,7 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) = } } - if (SUPPORTS_REF_IMAGES_BASE_MODELS.includes(newModel.base)) { + if (newModel.base !== 'external' && SUPPORTS_REF_IMAGES_BASE_MODELS.includes(newModel.base)) { // Handle incompatible reference image models - switch to first compatible model, with some smart logic // to choose the best available model based on the new main model. const allRefImageModels = selectGlobalRefImageModels(state).filter(({ base }) => base === newBase); diff --git a/invokeai/frontend/web/src/features/controlLayers/components/RefImage/RefImagePreview.tsx b/invokeai/frontend/web/src/features/controlLayers/components/RefImage/RefImagePreview.tsx index 84c1b2fc37b..ddbdb8b131c 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/RefImage/RefImagePreview.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/RefImage/RefImagePreview.tsx @@ -15,6 +15,7 @@ import { getGlobalReferenceImageWarnings } from 'features/controlLayers/store/va import { memo, useCallback, useEffect, useMemo, useState } from 'react'; import { PiExclamationMarkBold, PiEyeSlashBold, PiImageBold } from 'react-icons/pi'; import { useImageDTOFromCroppableImage } from 'services/api/endpoints/images'; +import { isExternalApiModelConfig } from 'services/api/types'; import { RefImageWarningTooltipContent } from './RefImageWarningTooltipContent'; @@ -71,18 +72,19 @@ export const RefImagePreview = 
memo(() => { const selectedEntityId = useAppSelector(selectSelectedRefEntityId); const isPanelOpen = useAppSelector(selectIsRefImagePanelOpen); const [showWeightDisplay, setShowWeightDisplay] = useState(false); + const isExternalModel = !!mainModelConfig && isExternalApiModelConfig(mainModelConfig); const imageDTO = useImageDTOFromCroppableImage(entity.config.image); const sx = useMemo(() => { - if (!isIPAdapterConfig(entity.config)) { + if (!isIPAdapterConfig(entity.config) || isExternalModel) { return baseSx; } return getImageSxWithWeight(entity.config.weight); - }, [entity.config]); + }, [entity.config, isExternalModel]); useEffect(() => { - if (!isIPAdapterConfig(entity.config)) { + if (!isIPAdapterConfig(entity.config) || isExternalModel) { return; } setShowWeightDisplay(true); @@ -92,7 +94,7 @@ export const RefImagePreview = memo(() => { return () => { window.clearTimeout(timeout); }; - }, [entity.config]); + }, [entity.config, isExternalModel]); const warnings = useMemo(() => { return getGlobalReferenceImageWarnings(entity, mainModelConfig); @@ -154,7 +156,7 @@ export const RefImagePreview = memo(() => { ) : ( )} - {isIPAdapterConfig(entity.config) && ( + {isIPAdapterConfig(entity.config) && !isExternalModel && ( { const selectConfig = useMemo(() => buildSelectConfig(id), [id]); const config = useAppSelector(selectConfig); const tab = useAppSelector(selectActiveTab); + const mainModelConfig = useAppSelector(selectMainModelConfig); const onChangeBeginEndStepPct = useCallback( (beginEndStepPct: [number, number]) => { @@ -120,9 +122,10 @@ const RefImageSettingsContent = memo(() => { ); const isFLUX = useAppSelector(selectIsFLUX); + const isExternalModel = !!mainModelConfig && isExternalApiModelConfig(mainModelConfig); - // FLUX.2 Klein has built-in reference image support - no model selector needed - const showModelSelector = !isFlux2ReferenceImageConfig(config); + // FLUX.2 Klein and external API models do not require a ref image model selection. 
+ const showModelSelector = !isFlux2ReferenceImageConfig(config) && !isExternalModel; return ( @@ -150,14 +153,14 @@ const RefImageSettingsContent = memo(() => { )} - {isIPAdapterConfig(config) && ( + {isIPAdapterConfig(config) && !isExternalModel && ( {!isFLUX && } )} - {isFLUXReduxConfig(config) && ( + {isFLUXReduxConfig(config) && !isExternalModel && ( { const bboxRect = selectBboxRect(store.getState()); const { x, y } = bboxRect; - const imageObject = imageDTOToImageObject(imageDTO); + const imageObject = imageDTOToImageObject(imageDTO, { usePixelBbox: false }); + const scale = Math.min(bboxRect.width / imageDTO.width, bboxRect.height / imageDTO.height); + const scaledWidth = Math.round(imageDTO.width * scale); + const scaledHeight = Math.round(imageDTO.height * scale); + const position = { + x: x + Math.round((bboxRect.width - scaledWidth) / 2), + y: y + Math.round((bboxRect.height - scaledHeight) / 2), + }; const selectedEntityIdentifier = selectSelectedEntityIdentifier(store.getState()); const overrides: Partial = { - position: { x, y }, + position, objects: [imageObject], }; diff --git a/invokeai/frontend/web/src/features/controlLayers/components/StagingArea/shared.test.ts b/invokeai/frontend/web/src/features/controlLayers/components/StagingArea/shared.test.ts index f16b9023164..9268fc7570f 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/StagingArea/shared.test.ts +++ b/invokeai/frontend/web/src/features/controlLayers/components/StagingArea/shared.test.ts @@ -181,6 +181,33 @@ describe('StagingAreaApi Utility Functions', () => { expect(result).toBe('first-image.png'); }); + it('should return first image from image collections', () => { + const queueItem: S['SessionQueueItem'] = { + item_id: 1, + status: 'completed', + priority: 0, + destination: 'test-session', + created_at: '2024-01-01T00:00:00Z', + updated_at: '2024-01-01T00:00:00Z', + started_at: '2024-01-01T00:00:01Z', + completed_at: '2024-01-01T00:01:00Z', + error: null, + 
session: { + id: 'test-session', + source_prepared_mapping: { + canvas_output: ['output-node-id'], + }, + results: { + 'output-node-id': { + images: [{ image_name: 'first.png' }, { image_name: 'second.png' }], + }, + }, + }, + } as unknown as S['SessionQueueItem']; + + expect(getOutputImageName(queueItem)).toBe('first.png'); + }); + it('should handle empty session mapping', () => { const queueItem: S['SessionQueueItem'] = { item_id: 1, diff --git a/invokeai/frontend/web/src/features/controlLayers/components/StagingArea/shared.ts b/invokeai/frontend/web/src/features/controlLayers/components/StagingArea/shared.ts index fe98408df58..1fe461e9993 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/StagingArea/shared.ts +++ b/invokeai/frontend/web/src/features/controlLayers/components/StagingArea/shared.ts @@ -1,4 +1,4 @@ -import { isImageField } from 'features/nodes/types/common'; +import { isImageField, isImageFieldCollection } from 'features/nodes/types/common'; import { isCanvasOutputNodeId } from 'features/nodes/util/graph/graphBuilderUtils'; import type { S } from 'services/api/types'; import { formatProgressMessage } from 'services/events/stores'; @@ -29,6 +29,9 @@ export const getOutputImageName = (item: S['SessionQueueItem']) => { if (isImageField(value)) { return value.image_name; } + if (isImageFieldCollection(value)) { + return value[0]?.image_name ?? 
null; + } } return null; diff --git a/invokeai/frontend/web/src/features/controlLayers/hooks/saveCanvasHooks.ts b/invokeai/frontend/web/src/features/controlLayers/hooks/saveCanvasHooks.ts index 6b089c4592b..1b95aef000b 100644 --- a/invokeai/frontend/web/src/features/controlLayers/hooks/saveCanvasHooks.ts +++ b/invokeai/frontend/web/src/features/controlLayers/hooks/saveCanvasHooks.ts @@ -45,7 +45,7 @@ import { toast } from 'features/toast/toast'; import { useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { serializeError } from 'serialize-error'; -import type { ImageDTO } from 'services/api/types'; +import { type ImageDTO, isExternalApiModelConfig } from 'services/api/types'; import type { JsonObject } from 'type-fest'; const log = logger('canvas'); @@ -90,7 +90,7 @@ const useSaveCanvas = ({ region, saveToGallery, toastOk, toastError, onSave, wit metadata.negative_prompt = selectNegativePrompt(state); metadata.seed = selectSeed(state); const model = selectMainModelConfig(state); - if (model) { + if (model && !isExternalApiModelConfig(model)) { metadata.model = Graph.getModelMetadataField(model); } } diff --git a/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.test.ts b/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.test.ts new file mode 100644 index 00000000000..66d17cd3608 --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.test.ts @@ -0,0 +1,103 @@ +import { zModelIdentifierField } from 'features/nodes/types/common'; +import type { + ExternalApiModelConfig, + ExternalApiModelDefaultSettings, + ExternalImageSize, + ExternalModelCapabilities, +} from 'services/api/types'; +import { describe, expect, it } from 'vitest'; + +import { + selectModelSupportsGuidance, + selectModelSupportsNegativePrompt, + selectModelSupportsRefImages, + selectModelSupportsSeed, + selectModelSupportsSteps, +} from './paramsSlice'; + +const createExternalConfig = 
(capabilities: ExternalModelCapabilities): ExternalApiModelConfig => { + const maxImageSize: ExternalImageSize = { width: 1024, height: 1024 }; + const defaultSettings: ExternalApiModelDefaultSettings = { width: 1024, height: 1024, steps: 30 }; + + return { + key: 'external-test', + hash: 'external:openai:gpt-image-1', + path: 'external://openai/gpt-image-1', + file_size: 0, + name: 'External Test', + description: null, + source: 'external://openai/gpt-image-1', + source_type: 'url', + source_api_response: null, + cover_image: null, + base: 'external', + type: 'external_image_generator', + format: 'external_api', + provider_id: 'openai', + provider_model_id: 'gpt-image-1', + capabilities: { ...capabilities, max_image_size: maxImageSize }, + default_settings: defaultSettings, + tags: ['external'], + is_default: false, + }; +}; + +describe('paramsSlice selectors for external models', () => { + it('uses external capabilities for negative prompt support', () => { + const config = createExternalConfig({ + modes: ['txt2img'], + supports_negative_prompt: true, + supports_reference_images: false, + }); + const model = zModelIdentifierField.parse(config); + + expect(selectModelSupportsNegativePrompt.resultFunc(model, config)).toBe(true); + }); + + it('uses external capabilities for ref image support', () => { + const config = createExternalConfig({ + modes: ['txt2img'], + supports_negative_prompt: false, + supports_reference_images: false, + }); + const model = zModelIdentifierField.parse(config); + + expect(selectModelSupportsRefImages.resultFunc(model, config)).toBe(false); + }); + + it('uses external capabilities for guidance support', () => { + const config = createExternalConfig({ + modes: ['txt2img'], + supports_negative_prompt: true, + supports_reference_images: false, + supports_guidance: true, + }); + const model = zModelIdentifierField.parse(config); + + expect(selectModelSupportsGuidance.resultFunc(model, config)).toBe(true); + }); + + it('uses external 
capabilities for seed support', () => { + const config = createExternalConfig({ + modes: ['txt2img'], + supports_negative_prompt: true, + supports_reference_images: false, + supports_seed: false, + }); + const model = zModelIdentifierField.parse(config); + + expect(selectModelSupportsSeed.resultFunc(model, config)).toBe(false); + }); + + it('uses external capabilities for steps support', () => { + const config = createExternalConfig({ + modes: ['txt2img'], + supports_negative_prompt: true, + supports_reference_images: false, + supports_steps: false, + }); + const model = zModelIdentifierField.parse(config); + + expect(selectModelSupportsSteps.resultFunc(model, config)).toBe(false); + }); +}); diff --git a/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts index 8dcd93cc5de..a01e1e424cb 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts @@ -21,6 +21,7 @@ import { SUPPORTS_OPTIMIZED_DENOISING_BASE_MODELS, SUPPORTS_REF_IMAGES_BASE_MODELS, } from 'features/modelManagerV2/models'; +import type { BaseModelType } from 'features/nodes/types/common'; import { CLIP_SKIP_MAP } from 'features/parameters/types/constants'; import type { ParameterCanvasCoherenceMode, @@ -43,7 +44,8 @@ import type { } from 'features/parameters/types/parameterSchemas'; import { getGridSize, getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension'; import { modelConfigsAdapterSelectors, selectModelConfigsQuery } from 'services/api/endpoints/models'; -import { isNonRefinerMainModelConfig } from 'services/api/types'; +import type { AnyModelConfigWithExternal } from 'services/api/types'; +import { isExternalApiModelConfig, isNonRefinerMainModelConfig } from 'services/api/types'; import { assert } from 'tsafe'; const slice = createSlice({ @@ -317,7 +319,7 @@ const slice = 
createSlice({ //#region Dimensions sizeRecalled: (state, action: PayloadAction<{ width: number; height: number }>) => { const { width, height } = action.payload; - const gridSize = getGridSize(state.model?.base); + const gridSize = getGridSize(state.model?.base as BaseModelType | undefined); state.dimensions.width = Math.max(roundDownToMultiple(width, gridSize), 64); state.dimensions.height = Math.max(roundDownToMultiple(height, gridSize), 64); state.dimensions.aspectRatio.value = state.dimensions.width / state.dimensions.height; @@ -326,7 +328,7 @@ const slice = createSlice({ }, widthChanged: (state, action: PayloadAction<{ width: number; updateAspectRatio?: boolean; clamp?: boolean }>) => { const { width, updateAspectRatio, clamp } = action.payload; - const gridSize = getGridSize(state.model?.base); + const gridSize = getGridSize(state.model?.base as BaseModelType | undefined); state.dimensions.width = clamp ? Math.max(roundDownToMultiple(width, gridSize), 64) : width; if (state.dimensions.aspectRatio.isLocked) { @@ -344,7 +346,7 @@ const slice = createSlice({ }, heightChanged: (state, action: PayloadAction<{ height: number; updateAspectRatio?: boolean; clamp?: boolean }>) => { const { height, updateAspectRatio, clamp } = action.payload; - const gridSize = getGridSize(state.model?.base); + const gridSize = getGridSize(state.model?.base as BaseModelType | undefined); state.dimensions.height = clamp ? 
Math.max(roundDownToMultiple(height, gridSize), 64) : height; if (state.dimensions.aspectRatio.isLocked) { @@ -374,7 +376,7 @@ const slice = createSlice({ const { width, height } = calculateNewSize( state.dimensions.aspectRatio.value, state.dimensions.width * state.dimensions.height, - state.model?.base + state.model?.base as BaseModelType | undefined ); state.dimensions.width = width; state.dimensions.height = height; @@ -391,7 +393,7 @@ const slice = createSlice({ const { width, height } = calculateNewSize( state.dimensions.aspectRatio.value, state.dimensions.width * state.dimensions.height, - state.model?.base + state.model?.base as BaseModelType | undefined ); state.dimensions.width = width; state.dimensions.height = height; @@ -399,12 +401,12 @@ const slice = createSlice({ } }, sizeOptimized: (state) => { - const optimalDimension = getOptimalDimension(state.model?.base); + const optimalDimension = getOptimalDimension(state.model?.base as BaseModelType | undefined); if (state.dimensions.aspectRatio.isLocked) { const { width, height } = calculateNewSize( state.dimensions.aspectRatio.value, optimalDimension * optimalDimension, - state.model?.base + state.model?.base as BaseModelType | undefined ); state.dimensions.width = width; state.dimensions.height = height; @@ -415,13 +417,19 @@ const slice = createSlice({ } }, syncedToOptimalDimension: (state) => { - const optimalDimension = getOptimalDimension(state.model?.base); - - if (!getIsSizeOptimal(state.dimensions.width, state.dimensions.height, state.model?.base)) { + const optimalDimension = getOptimalDimension(state.model?.base as BaseModelType | undefined); + + if ( + !getIsSizeOptimal( + state.dimensions.width, + state.dimensions.height, + state.model?.base as BaseModelType | undefined + ) + ) { const bboxDims = calculateNewSize( state.dimensions.aspectRatio.value, optimalDimension * optimalDimension, - state.model?.base + state.model?.base as BaseModelType | undefined ); state.dimensions.width = 
bboxDims.width; state.dimensions.height = bboxDims.height; @@ -456,6 +464,9 @@ const hasModelClipSkip = (model: ParameterModel | null) => { }; const getModelMaxClipSkip = (model: ParameterModel) => { + if (model.base === 'external') { + return undefined; + } if (model.base === 'sdxl') { // We don't support user-defined CLIP skip for SDXL because it doesn't do anything useful return 0; @@ -594,6 +605,7 @@ export const selectIsSD3 = createParamsSelector((params) => params.model?.base = export const selectIsCogView4 = createParamsSelector((params) => params.model?.base === 'cogview4'); export const selectIsZImage = createParamsSelector((params) => params.model?.base === 'z-image'); export const selectIsFlux2 = createParamsSelector((params) => params.model?.base === 'flux2'); +export const selectIsExternal = createParamsSelector((params) => params.model?.base === 'external'); export const selectIsFluxKontext = createParamsSelector((params) => { if (params.model?.base === 'flux' && params.model?.name.toLowerCase().includes('kontext')) { return true; @@ -638,19 +650,83 @@ export const selectOptimizedDenoisingEnabled = createParamsSelector((params) => export const selectPositivePrompt = createParamsSelector((params) => params.positivePrompt); export const selectNegativePrompt = createParamsSelector((params) => params.negativePrompt); export const selectNegativePromptWithFallback = createParamsSelector((params) => params.negativePrompt ?? ''); +export const selectModelConfig = createSelector( + selectModelConfigsQuery, + selectParamsSlice, + (modelConfigs, { model }) => { + if (!modelConfigs.data) { + return null; + } + if (!model) { + return null; + } + return ( + (modelConfigsAdapterSelectors.selectById(modelConfigs.data, model.key) as + | AnyModelConfigWithExternal + | undefined) ?? 
null + ); + } +); export const selectHasNegativePrompt = createParamsSelector((params) => params.negativePrompt !== null); export const selectModelSupportsNegativePrompt = createSelector( selectModel, - (model) => !!model && SUPPORTS_NEGATIVE_PROMPT_BASE_MODELS.includes(model.base) -); -export const selectModelSupportsRefImages = createSelector( - selectModel, - (model) => !!model && SUPPORTS_REF_IMAGES_BASE_MODELS.includes(model.base) + selectModelConfig, + (model, modelConfig) => { + if (!model) { + return false; + } + if (modelConfig && isExternalApiModelConfig(modelConfig)) { + return modelConfig.capabilities.supports_negative_prompt ?? false; + } + if (model.base === 'external') { + return false; + } + return SUPPORTS_NEGATIVE_PROMPT_BASE_MODELS.includes(model.base); + } ); +export const selectModelSupportsRefImages = createSelector(selectModel, selectModelConfig, (model, modelConfig) => { + if (!model) { + return false; + } + if (modelConfig && isExternalApiModelConfig(modelConfig)) { + return modelConfig.capabilities.supports_reference_images ?? false; + } + if (model.base === 'external') { + return false; + } + return SUPPORTS_REF_IMAGES_BASE_MODELS.includes(model.base); +}); export const selectModelSupportsOptimizedDenoising = createSelector( selectModel, - (model) => !!model && SUPPORTS_OPTIMIZED_DENOISING_BASE_MODELS.includes(model.base) + (model) => !!model && model.base !== 'external' && SUPPORTS_OPTIMIZED_DENOISING_BASE_MODELS.includes(model.base) ); +export const selectModelSupportsGuidance = createSelector(selectModel, selectModelConfig, (model, modelConfig) => { + if (!model) { + return false; + } + if (modelConfig && isExternalApiModelConfig(modelConfig)) { + return modelConfig.capabilities.supports_guidance ?? 
false; + } + return true; +}); +export const selectModelSupportsSeed = createSelector(selectModel, selectModelConfig, (model, modelConfig) => { + if (!model) { + return false; + } + if (modelConfig && isExternalApiModelConfig(modelConfig)) { + return modelConfig.capabilities.supports_seed ?? false; + } + return true; +}); +export const selectModelSupportsSteps = createSelector(selectModel, selectModelConfig, (model, modelConfig) => { + if (!model) { + return false; + } + if (modelConfig && isExternalApiModelConfig(modelConfig)) { + return modelConfig.capabilities.supports_steps ?? false; + } + return true; +}); export const selectScheduler = createParamsSelector((params) => params.scheduler); export const selectFluxScheduler = createParamsSelector((params) => params.fluxScheduler); export const selectFluxDypePreset = createParamsSelector((params) => params.fluxDypePreset); @@ -693,24 +769,23 @@ export const selectHeight = createParamsSelector((params) => params.dimensions.h export const selectAspectRatioID = createParamsSelector((params) => params.dimensions.aspectRatio.id); export const selectAspectRatioValue = createParamsSelector((params) => params.dimensions.aspectRatio.value); export const selectAspectRatioIsLocked = createParamsSelector((params) => params.dimensions.aspectRatio.isLocked); +export const selectAllowedAspectRatioIDs = createSelector(selectModelConfig, (modelConfig) => { + if (!modelConfig || !isExternalApiModelConfig(modelConfig)) { + return null; + } + const allowed = modelConfig.capabilities.allowed_aspect_ratios; + return allowed?.length ? 
allowed : null; +}); -export const selectMainModelConfig = createSelector( - selectModelConfigsQuery, - selectParamsSlice, - (modelConfigs, { model }) => { - if (!modelConfigs.data) { - return null; - } - if (!model) { - return null; - } - const modelConfig = modelConfigsAdapterSelectors.selectById(modelConfigs.data, model.key); - if (!modelConfig) { - return null; - } - if (!isNonRefinerMainModelConfig(modelConfig)) { - return null; - } +export const selectMainModelConfig = createSelector(selectModelConfig, (modelConfig) => { + if (!modelConfig) { + return null; + } + if (isExternalApiModelConfig(modelConfig)) { return modelConfig; } -); + if (!isNonRefinerMainModelConfig(modelConfig)) { + return null; + } + return modelConfig; +}); diff --git a/invokeai/frontend/web/src/features/controlLayers/store/selectors.ts b/invokeai/frontend/web/src/features/controlLayers/store/selectors.ts index 5c0abfdb892..cec75a394bc 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/selectors.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/selectors.ts @@ -13,6 +13,7 @@ import type { CanvasRegionalGuidanceState, CanvasState, } from 'features/controlLayers/store/types'; +import type { BaseModelType } from 'features/nodes/types/common'; import { getGridSize, getOptimalDimension } from 'features/parameters/util/optimalDimension'; import type { Equals } from 'tsafe'; import { assert } from 'tsafe'; @@ -74,7 +75,7 @@ export const selectHasEntities = createSelector(selectEntityCountAll, (count) => * Selects the optimal dimension for the canvas based on the currently-selected model */ export const selectOptimalDimension = createSelector(selectParamsSlice, (params): number => { - const modelBase = params.model?.base; + const modelBase = params.model?.base as BaseModelType | undefined; return getOptimalDimension(modelBase ?? 
null); }); @@ -82,7 +83,7 @@ export const selectOptimalDimension = createSelector(selectParamsSlice, (params) * Selects the grid size for the canvas based on the currently-selected model */ export const selectGridSize = createSelector(selectParamsSlice, (params): number => { - const modelBase = params.model?.base; + const modelBase = params.model?.base as BaseModelType | undefined; return getGridSize(modelBase ?? null); }); diff --git a/invokeai/frontend/web/src/features/controlLayers/store/validators.ts b/invokeai/frontend/web/src/features/controlLayers/store/validators.ts index 3406e9e7ee6..15a986c8018 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/validators.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/validators.ts @@ -6,7 +6,11 @@ import type { RefImageState, } from 'features/controlLayers/store/types'; import type { ModelIdentifierField } from 'features/nodes/types/common'; -import type { AnyModelConfig, MainModelConfig } from 'services/api/types'; +import { + type AnyModelConfigWithExternal, + isExternalApiModelConfig, + type MainOrExternalModelConfig, +} from 'services/api/types'; const WARNINGS = { UNSUPPORTED_MODEL: 'controlLayers.warnings.unsupportedModel', @@ -28,7 +32,7 @@ type WarningTKey = (typeof WARNINGS)[keyof typeof WARNINGS]; export const getRegionalGuidanceWarnings = ( entity: CanvasRegionalGuidanceState, - model: MainModelConfig | null | undefined + model: MainOrExternalModelConfig | null | undefined ): WarningTKey[] => { const warnings: WarningTKey[] = []; @@ -90,8 +94,8 @@ export const getRegionalGuidanceWarnings = ( }; export const areBasesCompatibleForRefImage = ( - first?: ModelIdentifierField | AnyModelConfig | null, - second?: ModelIdentifierField | AnyModelConfig | null + first?: ModelIdentifierField | AnyModelConfigWithExternal | null, + second?: ModelIdentifierField | AnyModelConfigWithExternal | null ): boolean => { if (!first || !second) { return false; @@ -112,11 +116,19 @@ export const 
areBasesCompatibleForRefImage = ( export const getGlobalReferenceImageWarnings = ( entity: RefImageState, - model: MainModelConfig | null | undefined + model: MainOrExternalModelConfig | null | undefined ): WarningTKey[] => { const warnings: WarningTKey[] = []; if (model) { + if (isExternalApiModelConfig(model)) { + if (!entity.config.image) { + // No image selected + warnings.push(WARNINGS.IP_ADAPTER_NO_IMAGE_SELECTED); + } + return warnings; + } + if (model.base === 'sd-3' || model.base === 'sd-2') { // Unsupported model architecture warnings.push(WARNINGS.UNSUPPORTED_MODEL); @@ -147,7 +159,7 @@ export const getGlobalReferenceImageWarnings = ( export const getControlLayerWarnings = ( entity: CanvasControlLayerState, - model: MainModelConfig | null | undefined + model: MainOrExternalModelConfig | null | undefined ): WarningTKey[] => { const warnings: WarningTKey[] = []; @@ -181,7 +193,7 @@ export const getControlLayerWarnings = ( export const getRasterLayerWarnings = ( _entity: CanvasRasterLayerState, - _model: MainModelConfig | null | undefined + _model: MainOrExternalModelConfig | null | undefined ): WarningTKey[] => { const warnings: WarningTKey[] = []; @@ -192,7 +204,7 @@ export const getRasterLayerWarnings = ( export const getInpaintMaskWarnings = ( _entity: CanvasInpaintMaskState, - _model: MainModelConfig | null | undefined + _model: MainOrExternalModelConfig | null | undefined ): WarningTKey[] => { const warnings: WarningTKey[] = []; diff --git a/invokeai/frontend/web/src/features/modelManagerV2/models.ts b/invokeai/frontend/web/src/features/modelManagerV2/models.ts index 7b5a08adfe2..4cd3c7eaad0 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/models.ts +++ b/invokeai/frontend/web/src/features/modelManagerV2/models.ts @@ -1,10 +1,11 @@ import type { AnyModelVariant, BaseModelType, ModelFormat, ModelType } from 'features/nodes/types/common'; -import type { AnyModelConfig } from 'services/api/types'; import { + type AnyModelConfig, 
isCLIPEmbedModelConfig, isCLIPVisionModelConfig, isControlLoRAModelConfig, isControlNetModelConfig, + isExternalApiModelConfig, isFluxReduxModelConfig, isIPAdapterModelConfig, isLLaVAModelConfig, @@ -121,6 +122,11 @@ export const MODEL_CATEGORIES: Record = { i18nKey: 'modelManager.llavaOnevision', filter: isLLaVAModelConfig, }, + external_image_generator: { + category: 'external_image_generator', + i18nKey: 'modelManager.externalImageGenerator', + filter: isExternalApiModelConfig, + }, }; export const MODEL_CATEGORIES_AS_LIST = objectEntries(MODEL_CATEGORIES).map(([category, { i18nKey, filter }]) => ({ @@ -132,7 +138,7 @@ export const MODEL_CATEGORIES_AS_LIST = objectEntries(MODEL_CATEGORIES).map(([ca /** * Mapping of model base to its color */ -export const MODEL_BASE_TO_COLOR: Record = { +export const MODEL_BASE_TO_COLOR: Record = { any: 'base', 'sd-1': 'green', 'sd-2': 'teal', @@ -143,13 +149,14 @@ export const MODEL_BASE_TO_COLOR: Record = { flux2: 'gold', cogview4: 'red', 'z-image': 'cyan', + external: 'orange', unknown: 'red', }; /** * Mapping of model type to human readable name */ -export const MODEL_TYPE_TO_LONG_NAME: Record = { +export const MODEL_TYPE_TO_LONG_NAME: Record = { main: 'Main', vae: 'VAE', lora: 'LoRA', @@ -167,13 +174,14 @@ export const MODEL_TYPE_TO_LONG_NAME: Record = { clip_embed: 'CLIP Embed', siglip: 'SigLIP', flux_redux: 'FLUX Redux', + external_image_generator: 'External Image Generator', unknown: 'Unknown', }; /** * Mapping of model base to human readable name */ -export const MODEL_BASE_TO_LONG_NAME: Record = { +export const MODEL_BASE_TO_LONG_NAME: Record = { any: 'Any', 'sd-1': 'Stable Diffusion 1.x', 'sd-2': 'Stable Diffusion 2.x', @@ -184,13 +192,14 @@ export const MODEL_BASE_TO_LONG_NAME: Record = { flux2: 'FLUX.2', cogview4: 'CogView4', 'z-image': 'Z-Image', + external: 'External', unknown: 'Unknown', }; /** * Mapping of model base to short human readable name */ -export const MODEL_BASE_TO_SHORT_NAME: Record = { +export const 
MODEL_BASE_TO_SHORT_NAME: Record = { any: 'Any', 'sd-1': 'SD1.X', 'sd-2': 'SD2.X', @@ -201,6 +210,7 @@ export const MODEL_BASE_TO_SHORT_NAME: Record = { flux2: 'FLUX.2', cogview4: 'CogView4', 'z-image': 'Z-Image', + external: 'External', unknown: 'Unknown', }; @@ -222,12 +232,13 @@ export const MODEL_VARIANT_TO_LONG_NAME: Record = { qwen3_8b: 'Qwen3 8B', }; -export const MODEL_FORMAT_TO_LONG_NAME: Record = { +export const MODEL_FORMAT_TO_LONG_NAME: Record = { omi: 'OMI', diffusers: 'Diffusers', checkpoint: 'Checkpoint', lycoris: 'LyCORIS', onnx: 'ONNX', + external_api: 'External API', olive: 'Olive', embedding_file: 'Embedding (file)', embedding_folder: 'Embedding (folder)', diff --git a/invokeai/frontend/web/src/features/modelManagerV2/store/installModelsStore.ts b/invokeai/frontend/web/src/features/modelManagerV2/store/installModelsStore.ts index b99a1212fec..79b7bfe31a7 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/store/installModelsStore.ts +++ b/invokeai/frontend/web/src/features/modelManagerV2/store/installModelsStore.ts @@ -1,13 +1,14 @@ import { atom } from 'nanostores'; -type InstallModelsTabName = 'launchpad' | 'urlOrLocal' | 'huggingface' | 'scanFolder' | 'starterModels'; +type InstallModelsTabName = 'launchpad' | 'urlOrLocal' | 'huggingface' | 'external' | 'scanFolder' | 'starterModels'; const TAB_TO_INDEX_MAP: Record = { launchpad: 0, urlOrLocal: 1, huggingface: 2, - scanFolder: 3, - starterModels: 4, + external: 3, + scanFolder: 4, + starterModels: 5, }; export const setInstallModelsTabByName = (tab: InstallModelsTabName) => { diff --git a/invokeai/frontend/web/src/features/modelManagerV2/store/modelManagerV2Slice.ts b/invokeai/frontend/web/src/features/modelManagerV2/store/modelManagerV2Slice.ts index 092998d0c31..44df38d9112 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/store/modelManagerV2Slice.ts +++ b/invokeai/frontend/web/src/features/modelManagerV2/store/modelManagerV2Slice.ts @@ -7,7 +7,10 @@ import { 
zModelType } from 'features/nodes/types/common'; import { assert } from 'tsafe'; import z from 'zod'; -const zModelCategoryType = zModelType.exclude(['onnx']).or(z.literal('refiner')); +const zModelCategoryType = zModelType + .exclude(['onnx']) + .or(z.literal('refiner')) + .or(z.literal('external_image_generator')); export type ModelCategoryType = z.infer; const zFilterableModelType = zModelCategoryType.or(z.literal('missing')); diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/ExternalProviders/ExternalProvidersForm.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/ExternalProviders/ExternalProvidersForm.tsx new file mode 100644 index 00000000000..59fb868a50d --- /dev/null +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/ExternalProviders/ExternalProvidersForm.tsx @@ -0,0 +1,266 @@ +import { + Badge, + Button, + Card, + Flex, + FormControl, + FormHelperText, + FormLabel, + Heading, + Input, + Text, + Tooltip, +} from '@invoke-ai/ui-library'; +import { useStore } from '@nanostores/react'; +import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent'; +import { useBuildModelInstallArg } from 'features/modelManagerV2/hooks/useBuildModelsToInstall'; +import { useInstallModel } from 'features/modelManagerV2/hooks/useInstallModel'; +import { $installModelsTabIndex } from 'features/modelManagerV2/store/installModelsStore'; +import type { ChangeEvent } from 'react'; +import { memo, useCallback, useEffect, useMemo, useState } from 'react'; +import { useTranslation } from 'react-i18next'; +import { PiCheckBold, PiWarningBold } from 'react-icons/pi'; +import { + useGetExternalProviderConfigsQuery, + useResetExternalProviderConfigMutation, + useSetExternalProviderConfigMutation, +} from 'services/api/endpoints/appInfo'; +import { useGetStarterModelsQuery } from 'services/api/endpoints/models'; +import type { ExternalProviderConfig, StarterModel } from 
'services/api/types'; + +const PROVIDER_SORT_ORDER = ['gemini', 'openai']; + +type ProviderCardProps = { + provider: ExternalProviderConfig; + onInstallModels: (providerId: string) => void; +}; + +type UpdatePayload = { + provider_id: string; + api_key?: string; + base_url?: string; +}; + +export const ExternalProvidersForm = memo(() => { + const { t } = useTranslation(); + const { data, isLoading } = useGetExternalProviderConfigsQuery(); + const { data: starterModels } = useGetStarterModelsQuery(); + const [installModel] = useInstallModel(); + const { getIsInstalled, buildModelInstallArg } = useBuildModelInstallArg(); + const tabIndex = useStore($installModelsTabIndex); + + const externalModelsByProvider = useMemo(() => { + const groups = new Map(); + for (const model of starterModels?.starter_models ?? []) { + if (!model.source.startsWith('external://')) { + continue; + } + const providerId = model.source.replace('external://', '').split('/')[0]; + if (!providerId) { + continue; + } + const models = groups.get(providerId) ?? 
[]; + models.push(model); + groups.set(providerId, models); + } + + for (const [providerId, models] of groups.entries()) { + models.sort((a, b) => a.name.localeCompare(b.name)); + groups.set(providerId, models); + } + + return groups; + }, [starterModels]); + + const handleInstallProviderModels = useCallback( + (providerId: string) => { + const models = externalModelsByProvider.get(providerId); + if (!models?.length) { + return; + } + const modelsToInstall = models.filter((model) => !getIsInstalled(model)).map(buildModelInstallArg); + modelsToInstall.forEach((model) => installModel(model)); + }, + [buildModelInstallArg, externalModelsByProvider, getIsInstalled, installModel] + ); + + const sortedProviders = useMemo(() => { + if (!data) { + return []; + } + return [...data].sort((a, b) => { + const aIndex = PROVIDER_SORT_ORDER.indexOf(a.provider_id); + const bIndex = PROVIDER_SORT_ORDER.indexOf(b.provider_id); + if (aIndex === -1 && bIndex === -1) { + return a.provider_id.localeCompare(b.provider_id); + } + if (aIndex === -1) { + return 1; + } + if (bIndex === -1) { + return -1; + } + return aIndex - bIndex; + }); + }, [data]); + + return ( + + + {t('modelManager.externalSetupTitle')} + {t('modelManager.externalSetupDescription')} + + + + {isLoading && {t('common.loading')}} + {!isLoading && sortedProviders.length === 0 && ( + {t('modelManager.externalProvidersUnavailable')} + )} + {sortedProviders.map((provider) => ( + + ))} + + + {tabIndex === 3 && ( + + {t('modelManager.externalSetupFooter')} + + )} + + ); +}); + +ExternalProvidersForm.displayName = 'ExternalProvidersForm'; + +const ProviderCard = memo(({ provider, onInstallModels }: ProviderCardProps) => { + const { t } = useTranslation(); + const [apiKey, setApiKey] = useState(''); + const [baseUrl, setBaseUrl] = useState(provider.base_url ?? 
''); + const [saveConfig, { isLoading }] = useSetExternalProviderConfigMutation(); + const [resetConfig, { isLoading: isResetting }] = useResetExternalProviderConfigMutation(); + + useEffect(() => { + setBaseUrl(provider.base_url ?? ''); + }, [provider.base_url]); + + const handleSave = useCallback(() => { + const trimmedApiKey = apiKey.trim(); + const trimmedBaseUrl = baseUrl.trim(); + const updatePayload: UpdatePayload = { + provider_id: provider.provider_id, + }; + if (trimmedApiKey) { + updatePayload.api_key = trimmedApiKey; + } + if (trimmedBaseUrl !== (provider.base_url ?? '')) { + updatePayload.base_url = trimmedBaseUrl; + } + + if (!updatePayload.api_key && updatePayload.base_url === undefined) { + return; + } + + saveConfig(updatePayload) + .unwrap() + .then((result) => { + if (result.api_key_configured) { + setApiKey(''); + onInstallModels(provider.provider_id); + } + if (result.base_url !== undefined) { + setBaseUrl(result.base_url ?? ''); + } + }); + }, [apiKey, baseUrl, onInstallModels, provider.base_url, provider.provider_id, saveConfig]); + + const handleReset = useCallback(() => { + resetConfig(provider.provider_id) + .unwrap() + .then((result) => { + setApiKey(''); + setBaseUrl(result.base_url ?? ''); + }); + }, [provider.provider_id, resetConfig]); + + const handleApiKeyChange = useCallback((event: ChangeEvent) => { + setApiKey(event.target.value); + }, []); + + const handleBaseUrlChange = useCallback((event: ChangeEvent) => { + setBaseUrl(event.target.value); + }, []); + + const statusBadge = provider.api_key_configured ? 
( + + + {t('settings.externalProviderConfigured')} + + ) : ( + + + {t('settings.externalProviderNotConfigured')} + + ); + + return ( + + + + + {provider.provider_id} + + + {t('modelManager.externalProviderCardDescription', { providerId: provider.provider_id })} + + + {statusBadge} + + + + {t('modelManager.externalApiKey')} + + {t('modelManager.externalApiKeyHelper')} + + + {t('modelManager.externalBaseUrl')} + + {t('modelManager.externalBaseUrlHelper')} + + + + + + + + + + ); +}); + +ProviderCard.displayName = 'ProviderCard'; diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/LaunchpadForm/LaunchpadForm.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/LaunchpadForm/LaunchpadForm.tsx index fc99bcec7bf..591c61a4b23 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/LaunchpadForm/LaunchpadForm.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/LaunchpadForm/LaunchpadForm.tsx @@ -6,7 +6,7 @@ import { StarterBundleButton } from 'features/modelManagerV2/subpanels/AddModelP import { StarterBundleTooltipContentCompact } from 'features/modelManagerV2/subpanels/AddModelPanel/StarterModels/StarterBundleTooltipContentCompact'; import { memo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; -import { PiFolderOpenBold, PiLinkBold, PiStarBold } from 'react-icons/pi'; +import { PiFolderOpenBold, PiLinkBold, PiPlugBold, PiStarBold } from 'react-icons/pi'; import { SiHuggingface } from 'react-icons/si'; import { useGetStarterModelsQuery } from 'services/api/endpoints/models'; @@ -28,6 +28,10 @@ export const LaunchpadForm = memo(() => { setInstallModelsTabByName('scanFolder'); }, []); + const navigateToExternalTab = useCallback(() => { + setInstallModelsTabByName('external'); + }, []); + const navigateToStarterModelsTab = useCallback(() => { setInstallModelsTabByName('starterModels'); }, []); @@ -63,6 +67,12 @@ export const 
LaunchpadForm = memo(() => { title={t('modelManager.scanFolder')} description={t('modelManager.launchpad.scanFolderDescription')} /> + {/* Recommended Section */} diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/ModelInstallQueue/ModelInstallQueueItem.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/ModelInstallQueue/ModelInstallQueueItem.tsx index bd39030cdd5..e4ae5155276 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/ModelInstallQueue/ModelInstallQueueItem.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/ModelInstallQueue/ModelInstallQueueItem.tsx @@ -172,11 +172,15 @@ export const ModelInstallQueueItem = memo((props: ModelListItemProps) => { return installJob.source.url; case 'local': return installJob.source.path; + case 'external': + return `external://${installJob.source.provider_id}/${installJob.source.provider_model_id}`; default: return t('common.unknown'); } }, [installJob.source]); + const configuredName = installJob.config_in?.name; + const modelName = useMemo(() => { switch (installJob.source.type) { case 'hf': { @@ -187,13 +191,15 @@ export const ModelInstallQueueItem = memo((props: ModelListItemProps) => { return repo_id; } case 'url': - return installJob.source.url.split('/').slice(-1)[0] ?? t('common.unknown'); + return configuredName ?? installJob.source.url.split('/').slice(-1)[0] ?? t('common.unknown'); case 'local': - return installJob.source.path.split('\\').slice(-1)[0] ?? t('common.unknown'); + return configuredName ?? installJob.source.path.split('\\').slice(-1)[0] ?? t('common.unknown'); + case 'external': + return configuredName ?? `${installJob.source.provider_id}/${installJob.source.provider_model_id}`; default: - return t('common.unknown'); + return configuredName ?? 
t('common.unknown'); } - }, [installJob.source]); + }, [configuredName, installJob.source]); const progressValue = useMemo(() => { if (installJob.status === 'completed' || installJob.status === 'error' || installJob.status === 'cancelled') { diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/StarterModels/StarterModelsResults.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/StarterModels/StarterModelsResults.tsx index 86350c54c42..20b52cbf940 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/StarterModels/StarterModelsResults.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/StarterModels/StarterModelsResults.tsx @@ -21,6 +21,9 @@ export const StarterModelsResults = memo(({ results }: StarterModelsResultsProps const filteredResults = useMemo(() => { return results.starter_models.filter((result) => { + if (result.source.startsWith('external://')) { + return false; + } const trimmedSearchTerm = searchTerm.trim().toLowerCase(); const matchStrings = [ result.name.toLowerCase(), diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/InstallModels.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/InstallModels.tsx index 9039c0f85f4..5bc4c9713fc 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/InstallModels.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/InstallModels.tsx @@ -2,18 +2,18 @@ import type { SystemStyleObject } from '@invoke-ai/ui-library'; import { Box, Flex, Heading, Tab, TabList, TabPanel, TabPanels, Tabs } from '@invoke-ai/ui-library'; import { useStore } from '@nanostores/react'; import { $installModelsTabIndex } from 'features/modelManagerV2/store/installModelsStore'; +import { ExternalProvidersForm } from 'features/modelManagerV2/subpanels/AddModelPanel/ExternalProviders/ExternalProvidersForm'; +import { HuggingFaceForm } from 
'features/modelManagerV2/subpanels/AddModelPanel/HuggingFaceFolder/HuggingFaceForm'; +import { InstallModelForm } from 'features/modelManagerV2/subpanels/AddModelPanel/InstallModelForm'; +import { LaunchpadForm } from 'features/modelManagerV2/subpanels/AddModelPanel/LaunchpadForm/LaunchpadForm'; +import { ModelInstallQueue } from 'features/modelManagerV2/subpanels/AddModelPanel/ModelInstallQueue/ModelInstallQueue'; +import { ScanModelsForm } from 'features/modelManagerV2/subpanels/AddModelPanel/ScanFolder/ScanFolderForm'; import { StarterModelsForm } from 'features/modelManagerV2/subpanels/AddModelPanel/StarterModels/StarterModelsForm'; import { memo } from 'react'; import { useTranslation } from 'react-i18next'; -import { PiCubeBold, PiFolderOpenBold, PiLinkSimpleBold, PiShootingStarBold } from 'react-icons/pi'; +import { PiCubeBold, PiFolderOpenBold, PiLinkSimpleBold, PiPlugBold, PiShootingStarBold } from 'react-icons/pi'; import { SiHuggingface } from 'react-icons/si'; -import { HuggingFaceForm } from './AddModelPanel/HuggingFaceFolder/HuggingFaceForm'; -import { InstallModelForm } from './AddModelPanel/InstallModelForm'; -import { LaunchpadForm } from './AddModelPanel/LaunchpadForm/LaunchpadForm'; -import { ModelInstallQueue } from './AddModelPanel/ModelInstallQueue/ModelInstallQueue'; -import { ScanModelsForm } from './AddModelPanel/ScanFolder/ScanFolderForm'; - const installModelsTabSx: SystemStyleObject = { display: 'flex', gap: 2, @@ -61,6 +61,10 @@ export const InstallModels = memo(() => { {t('modelManager.huggingFace')} + + + {t('modelManager.externalProviders')} + {t('modelManager.scanFolder')} @@ -80,6 +84,9 @@ export const InstallModels = memo(() => { + + + diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelFormatBadge.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelFormatBadge.tsx index 7d44ee54637..9728802c8bd 100644 --- 
a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelFormatBadge.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelFormatBadge.tsx @@ -3,10 +3,10 @@ import type { ModelFormat } from 'features/nodes/types/common'; import { memo } from 'react'; type Props = { - format: ModelFormat; + format: ModelFormat | 'external_api'; }; -const FORMAT_NAME_MAP: Record = { +const FORMAT_NAME_MAP: Record = { diffusers: 'diffusers', lycoris: 'lycoris', checkpoint: 'checkpoint', @@ -19,12 +19,13 @@ const FORMAT_NAME_MAP: Record = { bnb_quantized_nf4b: 'quantized', gguf_quantized: 'gguf', omi: 'omi', + external_api: 'external_api', unknown: 'unknown', olive: 'olive', onnx: 'onnx', }; -const FORMAT_COLOR_MAP: Record = { +const FORMAT_COLOR_MAP: Record = { diffusers: 'base', omi: 'base', lycoris: 'base', @@ -40,6 +41,7 @@ const FORMAT_COLOR_MAP: Record = { unknown: 'red', olive: 'base', onnx: 'base', + external_api: 'base', }; const ModelFormatBadge = ({ format }: Props) => { diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/BaseModelSelect.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/BaseModelSelect.tsx index 8235d26efef..e4c8752e569 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/BaseModelSelect.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/BaseModelSelect.tsx @@ -5,7 +5,7 @@ import { MODEL_BASE_TO_LONG_NAME } from 'features/modelManagerV2/models'; import { useCallback, useMemo } from 'react'; import type { Control } from 'react-hook-form'; import { useController } from 'react-hook-form'; -import type { UpdateModelArg } from 'services/api/endpoints/models'; +import type { UpdateModelBody } from 'services/api/types'; import { objectEntries } from 'tsafe'; const options: ComboboxOption[] = objectEntries(MODEL_BASE_TO_LONG_NAME).map(([value, label]) => 
({ @@ -14,7 +14,7 @@ const options: ComboboxOption[] = objectEntries(MODEL_BASE_TO_LONG_NAME).map(([v })); type Props = { - control: Control; + control: Control; }; const BaseModelSelect = ({ control }: Props) => { diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/ModelFormatSelect.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/ModelFormatSelect.tsx index 1057ab7784c..2bd3eb954e5 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/ModelFormatSelect.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/ModelFormatSelect.tsx @@ -5,7 +5,7 @@ import { MODEL_FORMAT_TO_LONG_NAME } from 'features/modelManagerV2/models'; import { useCallback, useMemo } from 'react'; import type { Control } from 'react-hook-form'; import { useController } from 'react-hook-form'; -import type { UpdateModelArg } from 'services/api/endpoints/models'; +import type { UpdateModelBody } from 'services/api/types'; import { objectEntries } from 'tsafe'; const options: ComboboxOption[] = objectEntries(MODEL_FORMAT_TO_LONG_NAME).map(([value, label]) => ({ @@ -14,7 +14,7 @@ const options: ComboboxOption[] = objectEntries(MODEL_FORMAT_TO_LONG_NAME).map(( })); type Props = { - control: Control; + control: Control; }; const ModelFormatSelect = ({ control }: Props) => { diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/ModelTypeSelect.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/ModelTypeSelect.tsx index 44b41f01518..b35ce7f96df 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/ModelTypeSelect.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/ModelTypeSelect.tsx @@ -5,7 +5,7 @@ import { MODEL_TYPE_TO_LONG_NAME } from 'features/modelManagerV2/models'; import { useCallback, useMemo } from 'react'; import 
type { Control } from 'react-hook-form'; import { useController } from 'react-hook-form'; -import type { UpdateModelArg } from 'services/api/endpoints/models'; +import type { UpdateModelBody } from 'services/api/types'; import { objectEntries } from 'tsafe'; const options: ComboboxOption[] = objectEntries(MODEL_TYPE_TO_LONG_NAME).map(([value, label]) => ({ @@ -14,7 +14,7 @@ const options: ComboboxOption[] = objectEntries(MODEL_TYPE_TO_LONG_NAME).map(([v })); type Props = { - control: Control; + control: Control; }; const ModelTypeSelect = ({ control }: Props) => { diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/ModelVariantSelect.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/ModelVariantSelect.tsx index 52eb2a4749d..d8e8c6a5b8a 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/ModelVariantSelect.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/ModelVariantSelect.tsx @@ -5,13 +5,13 @@ import { MODEL_VARIANT_TO_LONG_NAME } from 'features/modelManagerV2/models'; import { useCallback, useMemo } from 'react'; import type { Control } from 'react-hook-form'; import { useController } from 'react-hook-form'; -import type { UpdateModelArg } from 'services/api/endpoints/models'; +import type { UpdateModelBody } from 'services/api/types'; import { objectEntries } from 'tsafe'; const options: ComboboxOption[] = objectEntries(MODEL_VARIANT_TO_LONG_NAME).map(([value, label]) => ({ label, value })); type Props = { - control: Control; + control: Control; }; const ModelVariantSelect = ({ control }: Props) => { diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/PredictionTypeSelect.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/PredictionTypeSelect.tsx index dcef95b4243..593bc4c4136 100644 --- 
a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/PredictionTypeSelect.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/PredictionTypeSelect.tsx @@ -4,7 +4,7 @@ import { typedMemo } from 'common/util/typedMemo'; import { useCallback, useMemo } from 'react'; import type { Control } from 'react-hook-form'; import { useController } from 'react-hook-form'; -import type { UpdateModelArg } from 'services/api/endpoints/models'; +import type { UpdateModelBody } from 'services/api/types'; const options: ComboboxOption[] = [ { value: 'none', label: '-' }, @@ -14,7 +14,7 @@ const options: ComboboxOption[] = [ ]; type Props = { - control: Control; + control: Control; }; const PredictionTypeSelect = ({ control }: Props) => { diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Model.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Model.tsx index 0393e322335..4a35a140e18 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Model.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Model.tsx @@ -5,6 +5,7 @@ import { memo, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { PiExclamationMarkBold } from 'react-icons/pi'; import { modelConfigsAdapterSelectors, useGetModelConfigsQuery } from 'services/api/endpoints/models'; +import type { AnyModelConfigWithExternal } from 'services/api/types'; import { ModelEdit } from './ModelEdit'; import { ModelView } from './ModelView'; @@ -21,7 +22,9 @@ export const Model = memo(() => { if (selectedModelKey === null) { return null; } - const modelConfig = modelConfigsAdapterSelectors.selectById(modelConfigs, selectedModelKey); + const modelConfig = modelConfigsAdapterSelectors.selectById(modelConfigs, selectedModelKey) as + | AnyModelConfigWithExternal + | undefined; if (!modelConfig) { return null; diff --git 
a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelDeleteButton.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelDeleteButton.tsx index 92096e7454e..db6ba93cfe9 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelDeleteButton.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelDeleteButton.tsx @@ -7,11 +7,11 @@ import { memo, type MouseEvent, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; import { PiTrashSimpleBold } from 'react-icons/pi'; import { useDeleteModelsMutation } from 'services/api/endpoints/models'; -import type { AnyModelConfig } from 'services/api/types'; +import type { AnyModelConfigWithExternal } from 'services/api/types'; type Props = { showLabel?: boolean; - modelConfig: AnyModelConfig; + modelConfig: AnyModelConfigWithExternal; }; export const ModelDeleteButton = memo(({ showLabel = true, modelConfig }: Props) => { diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelEdit.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelEdit.tsx index d845eca3eec..272642c8867 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelEdit.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelEdit.tsx @@ -15,12 +15,18 @@ import { useAppDispatch } from 'app/store/storeHooks'; import { setSelectedModelMode } from 'features/modelManagerV2/store/modelManagerV2Slice'; import { ModelHeader } from 'features/modelManagerV2/subpanels/ModelPanel/ModelHeader'; import { toast } from 'features/toast/toast'; -import { memo, useCallback } from 'react'; -import { type SubmitHandler, useForm } from 'react-hook-form'; +import { memo, useCallback, useMemo } from 'react'; +import { type SubmitHandler, useForm, useWatch } from 'react-hook-form'; import { useTranslation } from 'react-i18next'; import 
{ PiCheckBold, PiXBold } from 'react-icons/pi'; import { type UpdateModelArg, useUpdateModelMutation } from 'services/api/endpoints/models'; -import type { AnyModelConfig } from 'services/api/types'; +import { + type AnyModelConfigWithExternal, + type ExternalApiModelDefaultSettings, + type ExternalModelCapabilities, + isExternalApiModelConfig, + type UpdateModelBody, +} from 'services/api/types'; import BaseModelSelect from './Fields/BaseModelSelect'; import ModelFormatSelect from './Fields/ModelFormatSelect'; @@ -30,7 +36,14 @@ import PredictionTypeSelect from './Fields/PredictionTypeSelect'; import { ModelFooter } from './ModelFooter'; type Props = { - modelConfig: AnyModelConfig; + modelConfig: AnyModelConfigWithExternal; +}; + +type ModelEditFormValues = UpdateModelBody & { + capabilities?: ExternalModelCapabilities; + provider_id?: string; + provider_model_id?: string; + default_settings?: ExternalApiModelDefaultSettings | null; }; const stringFieldOptions = { @@ -41,19 +54,54 @@ export const ModelEdit = memo(({ modelConfig }: Props) => { const { t } = useTranslation(); const [updateModel, { isLoading: isSubmitting }] = useUpdateModelMutation(); const dispatch = useAppDispatch(); + const isExternal = useMemo(() => isExternalApiModelConfig(modelConfig), [modelConfig]); - const form = useForm({ - defaultValues: modelConfig, + const form = useForm({ + defaultValues: modelConfig as unknown as ModelEditFormValues, mode: 'onChange', }); - const onSubmit = useCallback>( + const externalModes = useWatch({ + control: form.control, + name: 'capabilities.modes', + }) as ExternalModelCapabilities['modes'] | undefined; + + const modeSet = useMemo(() => new Set(externalModes ?? []), [externalModes]); + + const toggleMode = useCallback( + (mode: ExternalModelCapabilities['modes'][number]) => { + const nextModes = modeSet.has(mode) + ? externalModes?.filter((value) => value !== mode) + : [...(externalModes ?? []), mode]; + form.setValue('capabilities.modes', nextModes ?? 
[], { shouldDirty: true, shouldValidate: true }); + }, + [externalModes, form, modeSet] + ); + + const handleToggleTxt2Img = useCallback(() => toggleMode('txt2img'), [toggleMode]); + const handleToggleImg2Img = useCallback(() => toggleMode('img2img'), [toggleMode]); + const handleToggleInpaint = useCallback(() => toggleMode('inpaint'), [toggleMode]); + + const parseOptionalNumber = useCallback((value: string | null | undefined) => { + if (value === null || value === undefined || value === '') { + return null; + } + if (typeof value !== 'string') { + return Number.isNaN(Number(value)) ? null : Number(value); + } + if (value.trim() === '') { + return null; + } + const parsed = Number(value); + return Number.isNaN(parsed) ? null : parsed; + }, []); + + const onSubmit = useCallback>( (values) => { const responseBody: UpdateModelArg = { key: modelConfig.key, - body: values, + body: values as UpdateModelBody, }; - updateModel(responseBody) .unwrap() .then((payload) => { @@ -160,6 +208,144 @@ export const ModelEdit = memo(({ modelConfig }: Props) => { + {isExternal && ( + <> + + {t('modelManager.externalProvider')} + + + + {t('modelManager.providerId')} + + + + {t('modelManager.providerModelId')} + + + + + {t('modelManager.externalCapabilities')} + + + + {t('modelManager.supportedModes')} + + + txt2img + + + img2img + + + inpaint + + + + + {t('modelManager.supportsNegativePrompt')} + + + + {t('modelManager.supportsReferenceImages')} + + + + {t('modelManager.supportsSeed')} + + + + {t('modelManager.supportsGuidance')} + + + + {t('modelManager.maxImagesPerRequest')} + + + + {t('modelManager.maxReferenceImages')} + + + + {t('modelManager.maxImageWidth')} + + + + {t('modelManager.maxImageHeight')} + + + + + {t('modelManager.externalDefaults')} + + + + {t('modelManager.width')} + + + + {t('modelManager.height')} + + + + {t('parameters.steps')} + + + + {t('parameters.guidance')} + + + + {t('modelManager.numImages')} + + + + + )} diff --git 
a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelFooter.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelFooter.tsx index f31609c4017..38135972e24 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelFooter.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelFooter.tsx @@ -1,7 +1,7 @@ import { Flex, Heading, type SystemStyleObject } from '@invoke-ai/ui-library'; import { memo } from 'react'; import { useTranslation } from 'react-i18next'; -import type { AnyModelConfig } from 'services/api/types'; +import type { AnyModelConfigWithExternal } from 'services/api/types'; import { ModelConvertButton } from './ModelConvertButton'; import { ModelDeleteButton } from './ModelDeleteButton'; @@ -20,7 +20,7 @@ const footerRowSx: SystemStyleObject = { }; type Props = { - modelConfig: AnyModelConfig; + modelConfig: AnyModelConfigWithExternal; isEditing: boolean; }; diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelHeader.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelHeader.tsx index 1c1a05dbd02..2d6827ee6cf 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelHeader.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelHeader.tsx @@ -4,10 +4,10 @@ import ModelImageUpload from 'features/modelManagerV2/subpanels/ModelPanel/Field import type { PropsWithChildren } from 'react'; import { memo } from 'react'; import { useTranslation } from 'react-i18next'; -import type { AnyModelConfig } from 'services/api/types'; +import type { AnyModelConfigWithExternal } from 'services/api/types'; type Props = PropsWithChildren<{ - modelConfig: AnyModelConfig; + modelConfig: AnyModelConfigWithExternal; }>; export const ModelHeader = memo(({ modelConfig, children }: Props) => { diff --git 
a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelReidentifyButton.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelReidentifyButton.tsx index 31334c0510d..6136c71f985 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelReidentifyButton.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelReidentifyButton.tsx @@ -4,15 +4,18 @@ import { memo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; import { PiSparkleFill } from 'react-icons/pi'; import { useReidentifyModelMutation } from 'services/api/endpoints/models'; -import type { AnyModelConfig } from 'services/api/types'; +import { type AnyModelConfigWithExternal, isExternalApiModelConfig } from 'services/api/types'; + +import { isExternalModel } from './isExternalModel'; interface Props { - modelConfig: AnyModelConfig; + modelConfig: AnyModelConfigWithExternal; } export const ModelReidentifyButton = memo(({ modelConfig }: Props) => { const { t } = useTranslation(); const [reidentifyModel, { isLoading }] = useReidentifyModelMutation(); + const isExternal = isExternalApiModelConfig(modelConfig) || isExternalModel(modelConfig.path); const onClick = useCallback(() => { reidentifyModel({ key: modelConfig.key }) @@ -40,6 +43,10 @@ export const ModelReidentifyButton = memo(({ modelConfig }: Props) => { }); }, [modelConfig.key, reidentifyModel, t]); + if (isExternal) { + return null; + } + return (