From f8c63c1a444126525bcb94cfc10ea7cf073c04d4 Mon Sep 17 00:00:00 2001 From: Lucian Hardy Date: Sun, 27 Apr 2025 01:24:30 +1000 Subject: [PATCH] feat(ui): model relationship management Adds full support for managing model-to-model relationships in the UI and backend. Introduces RelatedModels subpanel for linking and unlinking models in model management. - Adds REST API routes for adding, removing, and retrieving model relationships. - New database migration: creates model_relationships table for bidirectional links. - New service layer (model_relationships) for relationship management. - Updated frontend: Related models float to top of LoRA/Main grouped model comboboxes for quick access. - Added 'Show Only Related' toggle badge to MainModelPicker filter bar --- invokeai/app/api/dependencies.py | 6 + .../app/api/routers/model_relationships.py | 196 ++++++++++++ invokeai/app/api_app.py | 2 + invokeai/app/services/invocation_services.py | 6 + .../model_relationship_records_base.py | 57 ++++ .../model_relationship_records_sqlite.py | 89 ++++++ .../model_relationships_base.py | 42 +++ .../model_relationships_common.py | 8 + .../model_relationships_default.py | 30 ++ .../app/services/shared/sqlite/sqlite_util.py | 2 + .../migrations/migration_20.py | 37 +++ invokeai/frontend/web/public/locales/en.json | 2 + .../hooks/useRelatedGroupedModelCombobox.ts | 92 ++++++ .../src/common/hooks/useRelatedModelKeys.ts | 14 + .../src/common/hooks/useSelectedModelKeys.ts | 34 ++ .../features/lora/components/LoRASelect.tsx | 4 +- .../subpanels/ModelPanel/ModelView.tsx | 4 + .../subpanels/ModelPanel/RelatedModels.tsx | 300 ++++++++++++++++++ .../MainModel/ParamMainModelSelect.tsx | 4 +- .../MainModelPicker.tsx | 33 +- .../api/endpoints/modelRelationships.ts | 67 ++++ .../frontend/web/src/services/api/index.ts | 1 + .../frontend/web/src/services/api/schema.ts | 264 ++++++++++++++- 23 files changed, 1285 insertions(+), 9 deletions(-) create mode 100644 invokeai/app/api/routers/model_relationships.py create mode 100644 invokeai/app/services/model_relationship_records/model_relationship_records_base.py create mode 100644 invokeai/app/services/model_relationship_records/model_relationship_records_sqlite.py create mode 100644 invokeai/app/services/model_relationships/model_relationships_base.py create mode 100644 invokeai/app/services/model_relationships/model_relationships_common.py create mode 100644 invokeai/app/services/model_relationships/model_relationships_default.py create mode 100644 invokeai/app/services/shared/sqlite_migrator/migrations/migration_20.py create mode 100644 invokeai/frontend/web/src/common/hooks/useRelatedGroupedModelCombobox.ts create mode 100644 invokeai/frontend/web/src/common/hooks/useRelatedModelKeys.ts create mode 100644 invokeai/frontend/web/src/common/hooks/useSelectedModelKeys.ts create mode 100644 invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/RelatedModels.tsx create mode 100644 invokeai/frontend/web/src/services/api/endpoints/modelRelationships.ts diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index a5a7dbef9c7..83b8bb219d8 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -23,6 +23,8 @@ from invokeai.app.services.model_images.model_images_default import ModelImageFileStorageDisk from invokeai.app.services.model_manager.model_manager_default import ModelManagerService from invokeai.app.services.model_records.model_records_sql import ModelRecordServiceSQL +from invokeai.app.services.model_relationships.model_relationships_default import ModelRelationshipsService +from invokeai.app.services.model_relationship_records.model_relationship_records_sqlite import SqliteModelRelationshipRecordStorage from invokeai.app.services.names.names_default import SimpleNameService from invokeai.app.services.object_serializer.object_serializer_disk import ObjectSerializerDisk from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache @@ -136,6 +138,8 @@ def initialize( download_queue=download_queue_service, events=events, ) + model_relationships = ModelRelationshipsService() + model_relationship_records = SqliteModelRelationshipRecordStorage(db=db) names = SimpleNameService() performance_statistics = InvocationStatsService() session_processor = DefaultSessionProcessor(session_runner=DefaultSessionRunner()) @@ -161,6 +165,8 @@ def initialize( logger=logger, model_images=model_images_service, model_manager=model_manager, + model_relationships=model_relationships, + model_relationship_records=model_relationship_records, download_queue=download_queue_service, names=names, performance_statistics=performance_statistics, diff --git a/invokeai/app/api/routers/model_relationships.py b/invokeai/app/api/routers/model_relationships.py new file mode 100644 index 00000000000..d550cdcb787 --- /dev/null +++ b/invokeai/app/api/routers/model_relationships.py @@ -0,0 +1,196 @@ +"""FastAPI route for model relationship records.""" + +from fastapi import HTTPException, APIRouter, Path, Body, status +from pydantic import BaseModel, Field +from typing import List +from invokeai.app.api.dependencies import ApiDependencies + +model_relationships_router = APIRouter( + prefix="/v1/model_relationships", + tags=["model_relationships"] +) + +# === Schemas === + +class ModelRelationshipCreateRequest(BaseModel): + model_key_1: str = Field(..., description="The key of the first model in the relationship", examples=[ + "aa3b247f-90c9-4416-bfcd-aeaa57a5339e", + "ac32b914-10ab-496e-a24a-3068724b9c35", + "d944abfd-c7c3-42e2-a4ff-da640b29b8b4", + "b1c2d3e4-f5a6-7890-abcd-ef1234567890", + "12345678-90ab-cdef-1234-567890abcdef", + "fedcba98-7654-3210-fedc-ba9876543210" + ]) + model_key_2: str = Field(..., description="The key of the second model in the relationship", examples=[ + "3bb7c0eb-b6c8-469c-ad8c-4d69c06075e4", + "f0c3da4e-d9ff-42b5-a45c-23be75c887c9", + "38170dd8-f1e5-431e-866c-2c81f1277fcc", + "c57fea2d-7646-424c-b9ad-c0ba60fc68be", + "10f7807b-ab54-46a9-ab03-600e88c630a1", + "f6c1d267-cf87-4ee0-bee0-37e791eacab7" + ]) + +class ModelRelationshipBatchRequest(BaseModel): + model_keys: List[str] = Field(..., description="List of model keys to fetch related models for", examples= + [[ + "aa3b247f-90c9-4416-bfcd-aeaa57a5339e", + "ac32b914-10ab-496e-a24a-3068724b9c35", + ],[ + "b1c2d3e4-f5a6-7890-abcd-ef1234567890", + "12345678-90ab-cdef-1234-567890abcdef", + "fedcba98-7654-3210-fedc-ba9876543210" + ],[ + "3bb7c0eb-b6c8-469c-ad8c-4d69c06075e4", + ]]) + +# === Routes === + +@model_relationships_router.get( + "/i/{model_key}", + operation_id="get_related_models", + response_model=list[str], + responses={ + 200: { + "description": "A list of related model keys was retrieved successfully", + "content": { + "application/json": { + "example": [ + "15e9eb28-8cfe-47c9-b610-37907a79fc3c", + "71272e82-0e5f-46d5-bca9-9a61f4bd8a82", + "a5d7cd49-1b98-4534-a475-aeee4ccf5fa2" + ] + } + }, + }, + 404: {"description": "The specified model could not be found"}, + 422: {"description": "Validation error"}, + }, +) +async def get_related_models( + model_key: str = Path(..., description="The key of the model to get relationships for") + ) -> list[str]: + """ + Get a list of model keys related to a given model. + """ + try: + return ApiDependencies.invoker.services.model_relationships.get_related_model_keys(model_key) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@model_relationships_router.post( + "/", + status_code=status.HTTP_204_NO_CONTENT, + responses={ + 204: {"description": "The relationship was successfully created"}, + 400: {"description": "Invalid model keys or self-referential relationship"}, + 409: {"description": "The relationship already exists"}, + 422: {"description": "Validation error"}, + 500: {"description": "Internal server error"}, + }, + summary="Add Model Relationship", + description="Creates a **bidirectional** relationship between two models, allowing each to reference the other as related.", +) +async def add_model_relationship( + req: ModelRelationshipCreateRequest = Body(..., description="The model keys to relate") +) -> None: + """ + Add a relationship between two models. + + Relationships are bidirectional and will be accessible from both models. + + - Raises 400 if keys are invalid or identical. + - Raises 409 if the relationship already exists. + """ + try: + if req.model_key_1 == req.model_key_2: + raise HTTPException(status_code=400, detail="Cannot relate a model to itself.") + + ApiDependencies.invoker.services.model_relationships.add_model_relationship( + req.model_key_1, + req.model_key_2, + ) + except ValueError as e: + raise HTTPException(status_code=409, detail=str(e)) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@model_relationships_router.delete( + "/", + status_code=status.HTTP_204_NO_CONTENT, + responses={ + 204: {"description": "The relationship was successfully removed"}, + 400: {"description": "Invalid model keys or self-referential relationship"}, + 404: {"description": "The relationship does not exist"}, + 422: {"description": "Validation error"}, + 500: {"description": "Internal server error"}, + }, + summary="Remove Model Relationship", + description="Removes a **bidirectional** relationship between two models. The relationship must already exist." +) +async def remove_model_relationship( + req: ModelRelationshipCreateRequest = Body(..., description="The model keys to disconnect") +) -> None: + """ + Removes a bidirectional relationship between two model keys. + + - Raises 400 if attempting to unlink a model from itself. + - Raises 404 if the relationship was not found. + """ + try: + if req.model_key_1 == req.model_key_2: + raise HTTPException(status_code=400, detail="Cannot unlink a model from itself.") + + ApiDependencies.invoker.services.model_relationships.remove_model_relationship( + req.model_key_1, + req.model_key_2, + ) + except ValueError as e: + raise HTTPException(status_code=404, detail=str(e)) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + +@model_relationships_router.post( + "/batch", + operation_id="get_related_models_batch", + response_model=List[str], + responses={ + 200: { + "description": "Related model keys retrieved successfully", + "content": { + "application/json": { + "example": [ + "ca562b14-995e-4a42-90c1-9528f1a5921d", + "cc0c2b8a-c62e-41d6-878e-cc74dde5ca8f", + "18ca7649-6a9e-47d5-bc17-41ab1e8cec81", + "7c12d1b2-0ef9-4bec-ba55-797b2d8f2ee1", + "c382eaa3-0e28-4ab0-9446-408667699aeb", + "71272e82-0e5f-46d5-bca9-9a61f4bd8a82", + "a5d7cd49-1b98-4534-a475-aeee4ccf5fa2" + ] + } + } + }, + 422: {"description": "Validation error"}, + 500: {"description": "Internal server error"}, + }, + summary="Get Related Model Keys (Batch)", + description="Retrieves all **unique related model keys** for a list of given models. This is useful for contextual suggestions or filtering." +) +async def get_related_models_batch( + req: ModelRelationshipBatchRequest = Body(..., description="Model keys to check for related connections") + ) -> list[str]: + """ + Accepts multiple model keys and returns a flat list of all unique related keys. + + Useful when working with multiple selections in the UI or cross-model comparisons. + """ + try: + all_related: set[str] = set() + for key in req.model_keys: + related = ApiDependencies.invoker.services.model_relationships.get_related_model_keys(key) + all_related.update(related) + return list(all_related) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) \ No newline at end of file diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py index fda232496e7..22b77748cf2 100644 --- a/invokeai/app/api_app.py +++ b/invokeai/app/api_app.py @@ -22,6 +22,7 @@ download_queue, images, model_manager, + model_relationships, session_queue, style_presets, utilities, @@ -125,6 +126,7 @@ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint): app.include_router(images.images_router, prefix="/api") app.include_router(boards.boards_router, prefix="/api") app.include_router(board_images.board_images_router, prefix="/api") +app.include_router(model_relationships.model_relationships_router, prefix="/api") app.include_router(app_info.app_router, prefix="/api") app.include_router(session_queue.session_queue_router, prefix="/api") app.include_router(workflows.workflows_router, prefix="/api") diff --git a/invokeai/app/services/invocation_services.py b/invokeai/app/services/invocation_services.py index 933c57b4a08..3dbb2686adf 100644 --- a/invokeai/app/services/invocation_services.py +++ b/invokeai/app/services/invocation_services.py @@ -27,6 +27,8 @@ from invokeai.app.services.invocation_stats.invocation_stats_base import InvocationStatsServiceBase from invokeai.app.services.model_images.model_images_base import ModelImageFileStorageBase from invokeai.app.services.model_manager.model_manager_base import ModelManagerServiceBase + from invokeai.app.services.model_relationship_records.model_relationship_records_base import ModelRelationshipRecordStorageBase + from invokeai.app.services.model_relationships.model_relationships_base import ModelRelationshipsServiceABC from invokeai.app.services.names.names_base import NameServiceBase from invokeai.app.services.session_processor.session_processor_base import SessionProcessorBase from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase @@ -54,6 +56,8 @@ def __init__( logger: "Logger", model_images: "ModelImageFileStorageBase", model_manager: "ModelManagerServiceBase", + model_relationships: "ModelRelationshipsServiceABC", + model_relationship_records: "ModelRelationshipRecordStorageBase", download_queue: "DownloadQueueServiceBase", performance_statistics: "InvocationStatsServiceBase", session_queue: "SessionQueueBase", @@ -81,6 +85,8 @@ def __init__( self.logger = logger self.model_images = model_images self.model_manager = model_manager + self.model_relationships = model_relationships + self.model_relationship_records = model_relationship_records self.download_queue = download_queue self.performance_statistics = performance_statistics self.session_queue = session_queue diff --git a/invokeai/app/services/model_relationship_records/model_relationship_records_base.py b/invokeai/app/services/model_relationship_records/model_relationship_records_base.py new file mode 100644 index 00000000000..7921523db46 --- /dev/null +++ b/invokeai/app/services/model_relationship_records/model_relationship_records_base.py @@ -0,0 +1,57 @@ +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from invokeai.backend.model_manager.config import AnyModelConfig + +class ModelRelationshipRecordStorageBase(ABC): + """Abstract base class for model-to-model relationship record storage.""" + + @abstractmethod + def add_model_relationship(self, model_key_1: str, model_key_2: str) -> None: + """Creates a relationship between two models by keys.""" + pass + + @abstractmethod + def remove_model_relationship(self, model_key_1: str, model_key_2: str) -> None: + """Removes a relationship between two models by keys.""" + pass + + @abstractmethod + def get_related_model_keys(self, model_key: str) -> list[str]: + """Gets all models keys related to a given model key.""" + pass + + @abstractmethod + def get_related_model_keys_batch(self, model_keys: list[str]) -> list[str]: + """Get related model keys for multiple models given a list of keys.""" + pass + + @abstractmethod + def get_related_model_key_count(self, model_key: str) -> int: + """Gets the number of relations for a given model key.""" + pass + + """ Below are methods that use ModelConfigs instead of model keys, as convenience methods. + These methods are not required to be implemented, but they are potentially useful for later development. + They are not used in the current codebase.""" + + @abstractmethod + def add_relationship_from_models(self, model_1: "AnyModelConfig", model_2: "AnyModelConfig") -> None: + """Creates a relationship between two models using ModelConfigs.""" + pass + + @abstractmethod + def remove_relationship_from_models(self, model_1: "AnyModelConfig", model_2: "AnyModelConfig") -> None: + """Removes a relationship between two models using ModelConfigs.""" + pass + + @abstractmethod + def get_related_keys_from_model(self, model: "AnyModelConfig") -> list[str]: + """Gets all model keys related to a given model using it's config.""" + pass + + @abstractmethod + def get_related_model_key_count_from_model(self, model: "AnyModelConfig") -> int: + """Gets the number of relations for a given model config.""" + pass \ No newline at end of file diff --git a/invokeai/app/services/model_relationship_records/model_relationship_records_sqlite.py b/invokeai/app/services/model_relationship_records/model_relationship_records_sqlite.py new file mode 100644 index 00000000000..4f87f5ef4c2 --- /dev/null +++ b/invokeai/app/services/model_relationship_records/model_relationship_records_sqlite.py @@ -0,0 +1,89 @@ +from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase +import sqlite3 +from typing import cast, TYPE_CHECKING +from invokeai.app.services.model_relationship_records.model_relationship_records_base import ModelRelationshipRecordStorageBase +if TYPE_CHECKING: + from invokeai.backend.model_manager.config import AnyModelConfig + +class SqliteModelRelationshipRecordStorage(ModelRelationshipRecordStorageBase): + def __init__(self, db: SqliteDatabase) -> None: + super().__init__() + self._conn = db.conn + + def add_model_relationship(self, model_key_1: str, model_key_2: str) -> None: + if model_key_1 == model_key_2: + raise ValueError("Cannot relate a model to itself.") + a, b = sorted([model_key_1, model_key_2]) + try: + cursor = self._conn.cursor() + cursor.execute( + "INSERT OR IGNORE INTO model_relationships (model_key_1, model_key_2) VALUES (?, ?)", + (a, b), + ) + self._conn.commit() + except sqlite3.Error as e: + self._conn.rollback() + raise e + + def remove_model_relationship(self, model_key_1: str, model_key_2: str) -> None: + a, b = sorted([model_key_1, model_key_2]) + try: + cursor = self._conn.cursor() + cursor.execute( + "DELETE FROM model_relationships WHERE model_key_1 = ? AND model_key_2 = ?", + (a, b), + ) + self._conn.commit() + except sqlite3.Error as e: + self._conn.rollback() + raise e + + def get_related_model_keys(self, model_key: str) -> list[str]: + cursor = self._conn.cursor() + cursor.execute( + """ + SELECT model_key_2 FROM model_relationships WHERE model_key_1 = ? + UNION + SELECT model_key_1 FROM model_relationships WHERE model_key_2 = ? + """, + (model_key, model_key), + ) + return [row[0] for row in cursor.fetchall()] + + def get_related_model_keys_batch(self, model_keys: list[str]) -> list[str]: + cursor = self._conn.cursor() + + key_list = ','.join('?' for _ in model_keys) + cursor.execute(f""" + SELECT model_key_2 FROM model_relationships WHERE model_key_1 IN ({key_list}) + UNION + SELECT model_key_1 FROM model_relationships WHERE model_key_2 IN ({key_list}) + """, + model_keys + model_keys + ) + return [row[0] for row in cursor.fetchall()] + + def get_related_model_key_count(self, model_key: str) -> int: + cursor = self._conn.execute( + """ + SELECT COUNT(*) FROM ( + SELECT model_key_2 FROM model_relationships WHERE model_key_1 = ? + UNION + SELECT model_key_1 FROM model_relationships WHERE model_key_2 = ? + ) + """, + (model_key, model_key), + ) + return cast(int, cursor.fetchone()[0]) + + def add_relationship_from_models(self, model_1: "AnyModelConfig", model_2: "AnyModelConfig") -> None: + self.add_model_relationship(model_1.key, model_2.key) + + def remove_relationship_from_models(self, model_1: "AnyModelConfig", model_2: "AnyModelConfig") -> None: + self.remove_model_relationship(model_1.key, model_2.key) + + def get_related_keys_from_model(self, model: "AnyModelConfig") -> list[str]: + return self.get_related_model_keys(model.key) + + def get_related_model_key_count_from_model(self, model: "AnyModelConfig") -> int: + return self.get_related_model_key_count(model.key) \ No newline at end of file diff --git a/invokeai/app/services/model_relationships/model_relationships_base.py b/invokeai/app/services/model_relationships/model_relationships_base.py new file mode 100644 index 00000000000..b60404d5710 --- /dev/null +++ b/invokeai/app/services/model_relationships/model_relationships_base.py @@ -0,0 +1,42 @@ +from abc import ABC, abstractmethod + +from invokeai.backend.model_manager.config import AnyModelConfig + + +class ModelRelationshipsServiceABC(ABC): + """High-level service for managing model-to-model relationships.""" + + @abstractmethod + def add_model_relationship(self, model_key_1: str, model_key_2: str) -> None: + """Creates a relationship between two models keys.""" + pass + + @abstractmethod + def remove_model_relationship(self, model_key_1: str, model_key_2: str) -> None: + """Removes a relationship between two models keys.""" + pass + + @abstractmethod + def get_related_model_keys(self, model_key: str) -> list[str]: + """Gets all models keys related to a given model key.""" + pass + + @abstractmethod + def get_related_model_keys_batch(self, model_keys: list[str]) -> list[str]: + """Get related model keys for multiple models.""" + pass + + @abstractmethod + def add_relationship_from_models(self, model_1: AnyModelConfig, model_2: AnyModelConfig) -> None: + """Creates a relationship from model objects.""" + pass + + @abstractmethod + def remove_relationship_from_models(self, model_1: AnyModelConfig, model_2: AnyModelConfig) -> None: + """Removes a relationship from model objects.""" + pass + + @abstractmethod + def get_related_keys_from_model(self, model: AnyModelConfig) -> list[str]: + """Gets all model keys related to a given model object.""" + pass \ No newline at end of file diff --git a/invokeai/app/services/model_relationships/model_relationships_common.py b/invokeai/app/services/model_relationships/model_relationships_common.py new file mode 100644 index 00000000000..6170b549d92 --- /dev/null +++ b/invokeai/app/services/model_relationships/model_relationships_common.py @@ -0,0 +1,8 @@ +from invokeai.app.util.model_exclude_null import BaseModelExcludeNull +from datetime import datetime + + +class ModelRelationship(BaseModelExcludeNull): + model_key_1: str + model_key_2: str + created_at: datetime \ No newline at end of file diff --git a/invokeai/app/services/model_relationships/model_relationships_default.py b/invokeai/app/services/model_relationships/model_relationships_default.py new file mode 100644 index 00000000000..1e6f338661e --- /dev/null +++ b/invokeai/app/services/model_relationships/model_relationships_default.py @@ -0,0 +1,30 @@ +from invokeai.backend.model_manager.config import AnyModelConfig +from .model_relationships_base import ModelRelationshipsServiceABC +from invokeai.app.services.invoker import Invoker + +class ModelRelationshipsService(ModelRelationshipsServiceABC): + __invoker: Invoker + + def start(self, invoker: Invoker) -> None: + self.__invoker = invoker + + def add_model_relationship(self, model_key_1: str, model_key_2: str) -> None: + self.__invoker.services.model_relationship_records.add_model_relationship(model_key_1, model_key_2) + + def remove_model_relationship(self, model_key_1: str, model_key_2: str) -> None: + self.__invoker.services.model_relationship_records.remove_model_relationship(model_key_1, model_key_2) + + def get_related_model_keys(self, model_key: str) -> list[str]: + return self.__invoker.services.model_relationship_records.get_related_model_keys(model_key) + + def add_relationship_from_models(self, model_1: AnyModelConfig, model_2: AnyModelConfig) -> None: + self.add_model_relationship(model_1.key, model_2.key) + + def remove_relationship_from_models(self, model_1: AnyModelConfig, model_2: AnyModelConfig) -> None: + self.remove_model_relationship(model_1.key, model_2.key) + + def get_related_keys_from_model(self, model: AnyModelConfig) -> list[str]: + return self.get_related_model_keys(model.key) + + def get_related_model_keys_batch(self, model_keys: list[str]) -> list[str]: + return self.__invoker.services.model_relationship_records.get_related_model_keys_batch(model_keys) \ No newline at end of file diff --git a/invokeai/app/services/shared/sqlite/sqlite_util.py b/invokeai/app/services/shared/sqlite/sqlite_util.py index 233bb72cda2..7c825616c16 100644 --- a/invokeai/app/services/shared/sqlite/sqlite_util.py +++ b/invokeai/app/services/shared/sqlite/sqlite_util.py @@ -22,6 +22,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_17 import build_migration_17 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_18 import build_migration_18 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_19 import build_migration_19 +from invokeai.app.services.shared.sqlite_migrator.migrations.migration_20 import build_migration_20 from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator @@ -61,6 +62,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto migrator.register_migration(build_migration_17()) migrator.register_migration(build_migration_18()) migrator.register_migration(build_migration_19(app_config=config)) + migrator.register_migration(build_migration_20()) migrator.run_migrations() return db diff --git a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_20.py b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_20.py new file mode 100644 index 00000000000..6b2050fac0d --- /dev/null +++ b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_20.py @@ -0,0 +1,37 @@ +import sqlite3 + +from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration + +class Migration20Callback: + + def __call__(self, cursor: sqlite3.Cursor) -> None: + cursor.execute( + """ + -- many-to-many relationship table for models + CREATE TABLE IF NOT EXISTS model_relationships ( + -- model_key_1 and model_key_2 are the same as the key(primary key) in the models table + model_key_1 TEXT NOT NULL, + model_key_2 TEXT NOT NULL, + created_at TEXT DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), + PRIMARY KEY (model_key_1, model_key_2), + -- model_key_1 < model_key_2, to ensure uniqueness and prevent duplicates + FOREIGN KEY (model_key_1) REFERENCES models(id) ON DELETE CASCADE, + FOREIGN KEY (model_key_2) REFERENCES models(id) ON DELETE CASCADE + ); + """ + ) + cursor.execute( + """ + -- Creates an index to keep performance equal when searching for model_key_1 or model_key_2 + CREATE INDEX IF NOT EXISTS keyx_model_relationships_model_key_2 + ON model_relationships(model_key_2) + """ + ) + + +def build_migration_20() -> Migration: + return Migration( + from_version=19, + to_version=20, + callback=Migration20Callback(), + ) \ No newline at end of file diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index b885df54528..f6d07a4831d 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -840,6 +840,8 @@ "predictionType": "Prediction Type", "prune": "Prune", "pruneTooltip": "Prune finished imports from queue", + "relatedModels": "Related Models", + "showOnlyRelatedModels": "Related", "repo_id": "Repo ID", "repoVariant": "Repo Variant", "scanFolder": "Scan Folder", diff --git a/invokeai/frontend/web/src/common/hooks/useRelatedGroupedModelCombobox.ts b/invokeai/frontend/web/src/common/hooks/useRelatedGroupedModelCombobox.ts new file mode 100644 index 00000000000..af14f5460e6 --- /dev/null +++ b/invokeai/frontend/web/src/common/hooks/useRelatedGroupedModelCombobox.ts @@ -0,0 +1,92 @@ +import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library'; +import type { GroupBase } from 'chakra-react-select'; +import type { ModelIdentifierField } from 'features/nodes/types/common'; +import { useTranslation } from 'react-i18next'; +import type { AnyModelConfig } from 'services/api/types'; + +import { useGroupedModelCombobox } from './useGroupedModelCombobox'; +import { useRelatedModelKeys } from './useRelatedModelKeys'; +import { useSelectedModelKeys } from './useSelectedModelKeys'; + +type UseRelatedGroupedModelComboboxArg = { + modelConfigs: T[]; + selectedModel?: ModelIdentifierField | null; + onChange: (value: T | null) => void; + getIsDisabled?: (model: T) => boolean; + isLoading?: boolean; + groupByType?: boolean; +}; + +// Custom hook to overlay the grouped model combobox with related models on top! +// Cleaner than hooking into useGroupedModelCombobox with a flag to enable/disable the related models +// Also allows for related models to be shown conditionally with some pretty simple logic if it ends up as a config flag. + +type UseRelatedGroupedModelComboboxReturn = { + value: ComboboxOption | undefined | null; + options: GroupBase[]; + onChange: ComboboxOnChange; + placeholder: string; + noOptionsMessage: () => string; +}; + +export function useRelatedGroupedModelCombobox({ + modelConfigs, + selectedModel, + onChange, + isLoading = false, + getIsDisabled, + groupByType, +}: UseRelatedGroupedModelComboboxArg): UseRelatedGroupedModelComboboxReturn { + const { t } = useTranslation(); + + const selectedKeys = useSelectedModelKeys(); + + const relatedKeys = useRelatedModelKeys(selectedKeys); + + // Base grouped options + const base = useGroupedModelCombobox({ + modelConfigs, + selectedModel, + onChange, + getIsDisabled, + isLoading, + groupByType, + }); + + // If no related models selected, just return base + if (relatedKeys.size === 0) { + return base; + } + + const relatedOptions: ComboboxOption[] = []; + const updatedGroups: GroupBase[] = []; + + for (const group of base.options) { + const remainingOptions: ComboboxOption[] = []; + + for (const option of group.options) { + if (relatedKeys.has(option.value)) { + relatedOptions.push({ ...option, label: `* ${option.label}` }); + } else { + remainingOptions.push(option); + } + } + + if (remainingOptions.length > 0) { + updatedGroups.push({ + label: group.label, + options: remainingOptions, + }); + } + } + + const finalOptions: GroupBase[] = + relatedOptions.length > 0 + ? [{ label: t('modelManager.relatedModels'), options: relatedOptions }, ...updatedGroups] + : updatedGroups; + + return { + ...base, + options: finalOptions, + }; +} diff --git a/invokeai/frontend/web/src/common/hooks/useRelatedModelKeys.ts b/invokeai/frontend/web/src/common/hooks/useRelatedModelKeys.ts new file mode 100644 index 00000000000..fc0711b969e --- /dev/null +++ b/invokeai/frontend/web/src/common/hooks/useRelatedModelKeys.ts @@ -0,0 +1,14 @@ +import { useMemo } from 'react'; +import { useGetRelatedModelIdsBatchQuery } from 'services/api/endpoints/modelRelationships'; + +/** + * Fetches related model keys for a given set of selected model keys. + * Returns a Set for fast lookup. + */ +export const useRelatedModelKeys = (selectedKeys: Set) => { + const { data: related = [] } = useGetRelatedModelIdsBatchQuery([...selectedKeys], { + skip: selectedKeys.size === 0, + }); + + return useMemo(() => new Set(related), [related]); +}; diff --git a/invokeai/frontend/web/src/common/hooks/useSelectedModelKeys.ts b/invokeai/frontend/web/src/common/hooks/useSelectedModelKeys.ts new file mode 100644 index 00000000000..83e1d3ac597 --- /dev/null +++ b/invokeai/frontend/web/src/common/hooks/useSelectedModelKeys.ts @@ -0,0 +1,34 @@ +import { useAppSelector } from 'app/store/storeHooks'; + +/** + * Gathers all currently selected model keys from parameters and loras. + * This includes the main model, VAE, refiner model, controlnet, and loras. + */ +export const useSelectedModelKeys = () => { + return useAppSelector((state) => { + const keys = new Set(); + const main = state.params.model; + const vae = state.params.vae; + const refiner = state.params.refinerModel; + const controlnet = state.params.controlLora; + const loras = state.loras.loras.map((l) => l.model); + + if (main) { + keys.add(main.key); + } + if (vae) { + keys.add(vae.key); + } + if (refiner) { + keys.add(refiner.key); + } + if (controlnet) { + keys.add(controlnet.key); + } + for (const lora of loras) { + keys.add(lora.key); + } + + return keys; + }); +}; diff --git a/invokeai/frontend/web/src/features/lora/components/LoRASelect.tsx b/invokeai/frontend/web/src/features/lora/components/LoRASelect.tsx index 90dacac0fa4..c6e8091c824 100644 --- a/invokeai/frontend/web/src/features/lora/components/LoRASelect.tsx +++ b/invokeai/frontend/web/src/features/lora/components/LoRASelect.tsx @@ -3,7 +3,7 @@ import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library'; import { createSelector } from '@reduxjs/toolkit'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover'; -import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox'; +import { useRelatedGroupedModelCombobox } from 'common/hooks/useRelatedGroupedModelCombobox'; import { loraAdded, selectLoRAsSlice } from 'features/controlLayers/store/lorasSlice'; import { selectBase } from 'features/controlLayers/store/paramsSlice'; import { NavigateToModelManagerButton } from 'features/parameters/components/MainModel/NavigateToModelManagerButton'; @@ -38,7 +38,7 @@ const LoRASelect = () => { [dispatch] ); - const { options, onChange } = useGroupedModelCombobox({ + const { options, onChange } = useRelatedGroupedModelCombobox({ modelConfigs, getIsDisabled, onChange: _onChange, diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelView.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelView.tsx index 42e0689733e..cb17dcf00c0 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelView.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelView.tsx @@ -11,6 +11,7 @@ import type { AnyModelConfig } from 'services/api/types'; import { MainModelDefaultSettings } from './MainModelDefaultSettings/MainModelDefaultSettings'; import { ModelAttrView } from './ModelAttrView'; +import { RelatedModels } from './RelatedModels'; type Props = { modelConfig: AnyModelConfig; @@ -83,6 +84,9 @@ export const ModelView = memo(({ modelConfig }: Props) => { )} )} + + + ); diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/RelatedModels.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/RelatedModels.tsx new file mode 100644 index 00000000000..f1846c57d76 --- /dev/null +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/RelatedModels.tsx @@ -0,0 +1,300 @@ +/** + * RelatedModels.tsx + * + * Panel for managing and displaying model-to-model relationships. + * + * Allows adding/removing bidirectional links between models, organized visually + * with color-coded tags, dividers between types, and sorted dropdown selection. + */ + +import { + Box, + Button, + Combobox, + Divider, + Flex, + FormControl, + FormErrorMessage, + FormLabel, + Tag, + TagCloseButton, + TagLabel, + Tooltip, +} from '@invoke-ai/ui-library'; +import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox'; +import { memo, useCallback, useMemo, useState } from 'react'; +import { useTranslation } from 'react-i18next'; +import { PiPlusBold } from 'react-icons/pi'; +import { + useAddModelRelationshipMutation, + useGetRelatedModelIdsQuery, + useRemoveModelRelationshipMutation, +} from 'services/api/endpoints/modelRelationships'; +import { useGetModelConfigsQuery } from 'services/api/endpoints/models'; +import type { AnyModelConfig } from 'services/api/types'; + +type Props = { + modelConfig: AnyModelConfig; +}; + +// Determines if two models are compatible for relationship linking based on their base type. +// +// Models with a base of 'any' are considered universally compatible. +// This is a known flaw: 'any'-based links may allow relationships that are +// meaningless in practice and could bloat the database over time. +// +// TODO: In the future, refine this logic to more strictly validate +// relationships based on model types or actual usage patterns. +const isBaseCompatible = (a: AnyModelConfig, b: AnyModelConfig): boolean => { + if (a.base === 'any' || b.base === 'any') { + return true; + } + return a.base === b.base; +}; + +export const RelatedModels = memo(({ modelConfig }: Props) => { + const { t } = useTranslation(); + const [addModelRelationship, { isLoading: isAdding }] = useAddModelRelationshipMutation(); + const [removeModelRelationship, { isLoading: isRemoving }] = useRemoveModelRelationshipMutation(); + const isLoading = isAdding || isRemoving; + const [selectedKey, setSelectedKey] = useState(''); + const { data: modelConfigs } = useGetModelConfigsQuery(); + const { data: relatedModels = [] } = useGetRelatedModelIdsQuery(modelConfig.key); + const relatedIDs = useMemo(() => new Set(relatedModels), [relatedModels]); + // Used to prioritize certain model types in UI sorting + const MODEL_TYPE_PRIORITY = useMemo(() => ['main', 'lora'], []); + + //Get all modelConfigs that are not already related to the current model. + const availableModels = useMemo(() => { + if (!modelConfigs) { + return []; + } + + return Object.values(modelConfigs.entities).filter( + (m): m is AnyModelConfig => + !!m && + m.key !== modelConfig.key && + !relatedIDs.has(m.key) && + isBaseCompatible(modelConfig, m) && + !(modelConfig.type === 'main' && m.type === 'main') // still block main↔main + ); + }, [modelConfigs, modelConfig, relatedIDs]); + + // Tracks validation errors for current input (e.g., duplicate key or no selection). + const errors = useMemo(() => { + const errs: string[] = []; + if (!selectedKey) { + return errs; + } + if (relatedIDs.has(selectedKey)) { + errs.push('Item already promoted'); + } + return errs; + }, [selectedKey, relatedIDs]); + + // Handles linking a selected model to the current one via API. + const handleAdd = useCallback(async () => { + const target = availableModels.find((m) => m.key === selectedKey); + if (!target) { + return; + } + + setSelectedKey(''); + await Promise.all([addModelRelationship({ model_key_1: modelConfig.key, model_key_2: target.key })]); + }, [modelConfig, availableModels, addModelRelationship, selectedKey]); + + const { + options, + onChange: comboboxOnChange, + placeholder, + noOptionsMessage, + } = useGroupedModelCombobox({ + modelConfigs: availableModels, + selectedModel: null, + onChange: (model) => { + if (!model) { + return; + } + setSelectedKey(model.key); + }, + groupByType: true, + }); + + // Unlinks an existing related model via API. + const handleRemove = useCallback( + async (id: string) => { + const target = modelConfigs?.entities[id]; + if (!target) { + return; + } + + await Promise.all([removeModelRelationship({ model_key_1: modelConfig.key, model_key_2: target.key })]); + }, + [modelConfig, modelConfigs, removeModelRelationship] + ); + + // Finds the selected model's combobox option to control current dropdown state. + const selectedOption = useMemo(() => { + return options.flatMap((group) => group.options).find((o) => o.value === selectedKey) ?? null; + }, [selectedKey, options]); + + const makeRemoveHandler = useCallback((id: string) => () => handleRemove(id), [handleRemove]); + + // Defines custom tag colors for model types in the UI. + // + // The default UI color scheme (mostly grey and orange) felt too flat, + // so this mapping provides a slightly more expressive color flow. + // + // Note: This is purely aesthetic. Safe to remove if project preferences change. + const getModelTagColor = (type: string): string => { + switch (type) { + case 'main': + case 'checkpoint': + return 'orange'; + case 'lora': + case 'lycoris': + return 'purple'; + case 'embedding': + case 'embedding_file': + return 'teal'; + case 'vae': + return 'blue'; + case 'controlnet': + case 'ip_adapter': + case 't2i_adapter': + return 'cyan'; + case 'onnx': + case 'bnb_quantized_int8b': + case 'bnb_quantized_nf4b': + case 'gguf_quantized': + return 'pink'; + case 't5_encoder': + case 'clip_embed': + case 'clip_vision': + case 'siglip': + return 'green'; + default: + return 'base'; + } + }; + + // Force group priority order: Main first, then LoRA + const getTypeFromLabel = (label: string): string => label.split('/')[1]?.trim().toLowerCase() || ''; + + const sortedOptions = useMemo(() => { + return [...options].sort((a, b) => { + const aType = getTypeFromLabel(a.label ?? ''); + const bType = getTypeFromLabel(b.label ?? ''); + + const aIndex = MODEL_TYPE_PRIORITY.indexOf(aType); + const bIndex = MODEL_TYPE_PRIORITY.indexOf(bType); + + const aScore = aIndex === -1 ? 99 : aIndex; + const bScore = bIndex === -1 ? 99 : bIndex; + + return aScore - bScore; + }); + }, [options, MODEL_TYPE_PRIORITY]); + + return ( + + {t('modelManager.relatedModels')} + 0}> + + + + + {errors.map((error) => ( + {error} + ))} + + + + { + // Render the related model tags as styled components. + // + // Models are grouped visually by type, sorted with 'main' and 'lora' types at the front. + // A vertical Divider is inserted when the type changes between adjacent models. + // Tags include: + // - Colored background based on model type (via getModelTagColor) + // - Tooltip showing ": " + // - Ellipsis-truncated tag name for compact layout + // - A close button to remove the relationship + [...relatedModels] + .sort((aKey, bKey) => { + const a = modelConfigs?.entities[aKey]; + const b = modelConfigs?.entities[bKey]; + if (!a || !b) { + return 0; + } + + // Floats Mains and LoRAs to the front + const aPriority = MODEL_TYPE_PRIORITY.indexOf(a.type); + const bPriority = MODEL_TYPE_PRIORITY.indexOf(b.type); + + const aScore = aPriority === -1 ? 99 : aPriority; + const bScore = bPriority === -1 ? 99 : bPriority; + + return aScore - bScore || a.type.localeCompare(b.type) || a.name.localeCompare(b.name); + }) + .reduce((acc, id, index, arr) => { + const model = modelConfigs?.entities[id]; + if (!model) { + return acc; + } + + const modelName = model.name ?? id; + const modelType = model.type ?? 'unknown'; + const modelTypeLabel = modelType.replace(/_/g, ' ').replace(/\b\w/g, (c) => c.toUpperCase()); + + // Create a divider if the previous model is of a different type. Just a small dash of visual flair. + const prevId = index > 0 ? arr[index - 1] : undefined; + const prevModel = prevId ? modelConfigs?.entities[prevId] : null; + const needsDivider = prevModel && prevModel.type !== model.type; + + if (needsDivider) { + acc.push(); + } + + acc.push( + + + + {modelName} + + + + + ); + + return acc; + }, []) + } + + + + ); +}); + +RelatedModels.displayName = 'RelatedModels'; diff --git a/invokeai/frontend/web/src/features/parameters/components/MainModel/ParamMainModelSelect.tsx b/invokeai/frontend/web/src/features/parameters/components/MainModel/ParamMainModelSelect.tsx index ecebcb63a2a..9261cfb5880 100644 --- a/invokeai/frontend/web/src/features/parameters/components/MainModel/ParamMainModelSelect.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/MainModel/ParamMainModelSelect.tsx @@ -1,7 +1,7 @@ import { Box, Combobox, Flex, FormControl, FormLabel, Icon, Tooltip } from '@invoke-ai/ui-library'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover'; -import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox'; +import { useRelatedGroupedModelCombobox } from 'common/hooks/useRelatedGroupedModelCombobox'; import { selectModelKey } from 'features/controlLayers/store/paramsSlice'; import { zModelIdentifierField } from 'features/nodes/types/common'; import { NavigateToModelManagerButton } from 'features/parameters/components/MainModel/NavigateToModelManagerButton'; @@ -66,7 +66,7 @@ const ParamMainModelSelect = () => { [activeTabName] ); - const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({ + const { options, value, onChange, placeholder, noOptionsMessage } = useRelatedGroupedModelCombobox({ modelConfigs, selectedModel, onChange: _onChange, diff --git a/invokeai/frontend/web/src/features/settingsAccordions/components/GenerationSettingsAccordion/MainModelPicker.tsx b/invokeai/frontend/web/src/features/settingsAccordions/components/GenerationSettingsAccordion/MainModelPicker.tsx index f681fc26208..d71c7ec8225 100644 --- a/invokeai/frontend/web/src/features/settingsAccordions/components/GenerationSettingsAccordion/MainModelPicker.tsx +++ b/invokeai/frontend/web/src/features/settingsAccordions/components/GenerationSettingsAccordion/MainModelPicker.tsx @@ -22,6 +22,8 @@ import { InformationalPopover } from 'common/components/InformationalPopover/Inf import type { Group, ImperativeModelPickerHandle } from 'common/components/Picker/Picker'; import { getRegex, Picker, usePickerContext } from 'common/components/Picker/Picker'; import { useDisclosure } from 'common/hooks/useBoolean'; +import { useRelatedModelKeys } from 'common/hooks/useRelatedModelKeys'; +import { useSelectedModelKeys } from 'common/hooks/useSelectedModelKeys'; import { fixedForwardRef } from 'common/util/fixedForwardRef'; import { typedMemo } from 'common/util/typedMemo'; import { $installModelsTab } from 'features/modelManagerV2/subpanels/InstallModels'; @@ -55,6 +57,8 @@ type PickerExtraContext = { toggleBaseModelTypeFilter: (baseModelType: BaseModelType) => void; basesWithModels: BaseModelType[]; baseModelTypeFilters: BaseModelTypeFilters; + showOnlyRelated: boolean; + setShowOnlyRelated: (show: boolean) => void; }; const ModelManagerLink = memo((props: ButtonProps) => { @@ -86,6 +90,10 @@ NoOptionsFallback.displayName = 'NoOptionsFallback'; export const MainModelPicker = memo(() => { const { t } = useTranslation(); + const [showOnlyRelated, setShowOnlyRelated] = useState(false); + const selectedKeys = useSelectedModelKeys(); + const allRelatedModels = useRelatedModelKeys(selectedKeys); + const relatedKeys = useMemo(() => new Set(allRelatedModels), [allRelatedModels]); const [modelConfigs] = useMainModels(); const basesWithModels = useMemo(() => { const bases: BaseModelType[] = []; @@ -124,8 +132,8 @@ export const MainModelPicker = memo(() => { [basesWithModels] ); const extra = useMemo( - () => ({ toggleBaseModelTypeFilter, basesWithModels, baseModelTypeFilters }), - [toggleBaseModelTypeFilter, basesWithModels, baseModelTypeFilters] + () => ({ toggleBaseModelTypeFilter, basesWithModels, baseModelTypeFilters, showOnlyRelated, setShowOnlyRelated, }), + [toggleBaseModelTypeFilter, basesWithModels, baseModelTypeFilters, showOnlyRelated, setShowOnlyRelated] ); const grouped = useMemo[]>(() => { // When all groups are disabled, we show all models @@ -135,6 +143,9 @@ export const MainModelPicker = memo(() => { } = {}; for (const modelConfig of modelConfigs) { + if (showOnlyRelated && !relatedKeys.has(modelConfig.key)) { + continue; + } let group = groups[modelConfig.base]; if (!group && (baseModelTypeFilters[modelConfig.base] || areAllGroupsDisabled)) { group = { @@ -174,7 +185,7 @@ export const MainModelPicker = memo(() => { sortedGroups.push(...Object.values(groups)); return sortedGroups; - }, [baseModelTypeFilters, modelConfigs]); + }, [baseModelTypeFilters, modelConfigs, showOnlyRelated, relatedKeys]); const modelConfig = useSelectedModelConfig(); const popover = useDisclosure(false); const pickerRef = useRef(null); @@ -251,6 +262,9 @@ const SearchBarComponent = typedMemo( const onClearSearchTerm = useCallback(() => { setSearchTerm(''); }, [setSearchTerm]); + const onToggleShowOnlyRelated = useCallback(() => { + extra.setShowOnlyRelated(!extra.showOnlyRelated); + }, [extra]); return ( @@ -280,6 +294,19 @@ const SearchBarComponent = typedMemo( /> + + {t('modelManager.showOnlyRelatedModels')} + {extra.basesWithModels.map((base) => ( ))} diff --git a/invokeai/frontend/web/src/services/api/endpoints/modelRelationships.ts b/invokeai/frontend/web/src/services/api/endpoints/modelRelationships.ts new file mode 100644 index 00000000000..ffb815aac25 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/endpoints/modelRelationships.ts @@ -0,0 +1,67 @@ +/** + * modelRelationships.ts + * + * RTK Query API slice for managing model-to-model relationships. + * + * Endpoints provided: + * - Fetch related models for a single model + * - Add a relationship between two models + * - Remove a relationship between two models + * - Fetch related models for multiple models in batch + * + * Provides and invalidates cache tags for seamless UI updates after add/remove operations. + */ + +import { api } from '..'; + +const REL_TAG = 'ModelRelationships'; // Needed for UI updates on relationship changes. + +const modelRelationshipsApi = api.injectEndpoints({ + endpoints: (build) => ({ + getRelatedModelIds: build.query({ + query: (model_key) => `/api/v1/model_relationships/i/${model_key}`, + providesTags: (result, error, model_key) => [{ type: REL_TAG, id: model_key }], + }), + + addModelRelationship: build.mutation({ + query: (payload) => ({ + url: `/api/v1/model_relationships/`, + method: 'POST', + body: payload, + }), + invalidatesTags: (result, error, { model_key_1, model_key_2 }) => [ + { type: REL_TAG, id: model_key_1 }, + { type: REL_TAG, id: model_key_2 }, + ], + }), + + removeModelRelationship: build.mutation({ + query: (payload) => ({ + url: `/api/v1/model_relationships/`, + method: 'DELETE', + body: payload, + }), + invalidatesTags: (result, error, { model_key_1, model_key_2 }) => [ + { type: REL_TAG, id: model_key_1 }, + { type: REL_TAG, id: model_key_2 }, + ], + }), + + getRelatedModelIdsBatch: build.query({ + query: (model_keys) => ({ + url: `/api/v1/model_relationships/batch`, + method: 'POST', + body: { model_keys }, + }), + providesTags: (result, error, model_keys) => model_keys.map((key) => ({ type: 'ModelRelationships', id: key })), + }), + }), + overrideExisting: false, +}); + +export const { + useGetRelatedModelIdsQuery, + useAddModelRelationshipMutation, + useRemoveModelRelationshipMutation, + useGetRelatedModelIdsBatchQuery, +} = modelRelationshipsApi; diff --git a/invokeai/frontend/web/src/services/api/index.ts b/invokeai/frontend/web/src/services/api/index.ts index 8740e465b6f..7bc59202a46 100644 --- a/invokeai/frontend/web/src/services/api/index.ts +++ b/invokeai/frontend/web/src/services/api/index.ts @@ -34,6 +34,7 @@ const tagTypes = [ 'InvocationCacheStatus', 'ModelConfig', 'ModelInstalls', + 'ModelRelationships', 'ModelScanFolderResults', 'T2IAdapterModel', 'MainModel', diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index 7eb6a9e4689..ba2c63f73d3 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -867,6 +867,70 @@ export type paths = { patch?: never; trace?: never; }; + "/api/v1/model_relationships/i/{model_key}": { + parameters: { + query?: never; + header?: never; + path?: never; + cookie?: never; + }; + /** + * Get Related Models + * @description Get a list of model keys related to a given model. + */ + get: operations["get_related_models"]; + put?: never; + post?: never; + delete?: never; + options?: never; + head?: never; + patch?: never; + trace?: never; + }; + "/api/v1/model_relationships/": { + parameters: { + query?: never; + header?: never; + path?: never; + cookie?: never; + }; + get?: never; + put?: never; + /** + * Add Model Relationship + * @description Creates a **bidirectional** relationship between two models, allowing each to reference the other as related. + */ + post: operations["add_model_relationship_api_v1_model_relationships__post"]; + /** + * Remove Model Relationship + * @description Removes a **bidirectional** relationship between two models. The relationship must already exist. + */ + delete: operations["remove_model_relationship_api_v1_model_relationships__delete"]; + options?: never; + head?: never; + patch?: never; + trace?: never; + }; + "/api/v1/model_relationships/batch": { + parameters: { + query?: never; + header?: never; + path?: never; + cookie?: never; + }; + get?: never; + put?: never; + /** + * Get Related Model Keys (Batch) + * @description Retrieves all **unique related model keys** for a list of given models. This is useful for contextual suggestions or filtering. + */ + post: operations["get_related_models_batch"]; + delete?: never; + options?: never; + head?: never; + patch?: never; + trace?: never; + }; "/api/v1/app/version": { parameters: { query?: never; @@ -11922,14 +11986,14 @@ export type components = { * Convert Cache Dir * Format: path * @description Path to the converted models cache directory (DEPRECATED, but do not delete because it is needed for migration from previous versions). - * @default models/.convert_cache + * @default models\.convert_cache */ convert_cache_dir?: string; /** * Download Cache Dir * Format: path * @description Path to the directory that contains dynamically downloaded models. - * @default models/.download_cache + * @default models\.download_cache */ download_cache_dir?: string; /** @@ -16442,6 +16506,27 @@ export type components = { */ config_path?: string | null; }; + /** ModelRelationshipBatchRequest */ + ModelRelationshipBatchRequest: { + /** + * Model Keys + * @description List of model keys to fetch related models for + */ + model_keys: string[]; + }; + /** ModelRelationshipCreateRequest */ + ModelRelationshipCreateRequest: { + /** + * Model Key 1 + * @description The key of the first model in the relationship + */ + model_key_1: string; + /** + * Model Key 2 + * @description The key of the second model in the relationship + */ + model_key_2: string; + }; /** * ModelRepoVariant * @description Various hugging face variants on the diffusers format. @@ -23618,6 +23703,181 @@ export interface operations { }; }; }; + get_related_models: { + parameters: { + query?: never; + header?: never; + path: { + /** @description The key of the model to get relationships for */ + model_key: string; + }; + cookie?: never; + }; + requestBody?: never; + responses: { + /** @description A list of related model keys was retrieved successfully */ + 200: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": string[]; + }; + }; + /** @description The specified model could not be found */ + 404: { + headers: { + [name: string]: unknown; + }; + content?: never; + }; + /** @description Validation error */ + 422: { + headers: { + [name: string]: unknown; + }; + content?: never; + }; + }; + }; + add_model_relationship_api_v1_model_relationships__post: { + parameters: { + query?: never; + header?: never; + path?: never; + cookie?: never; + }; + requestBody: { + content: { + "application/json": components["schemas"]["ModelRelationshipCreateRequest"]; + }; + }; + responses: { + /** @description The relationship was successfully created */ + 204: { + headers: { + [name: string]: unknown; + }; + content?: never; + }; + /** @description Invalid model keys or self-referential relationship */ + 400: { + headers: { + [name: string]: unknown; + }; + content?: never; + }; + /** @description The relationship already exists */ + 409: { + headers: { + [name: string]: unknown; + }; + content?: never; + }; + /** @description Validation error */ + 422: { + headers: { + [name: string]: unknown; + }; + content?: never; + }; + /** @description Internal server error */ + 500: { + headers: { + [name: string]: unknown; + }; + content?: never; + }; + }; + }; + remove_model_relationship_api_v1_model_relationships__delete: { + parameters: { + query?: never; + header?: never; + path?: never; + cookie?: never; + }; + requestBody: { + content: { + "application/json": components["schemas"]["ModelRelationshipCreateRequest"]; + }; + }; + responses: { + /** @description The relationship was successfully removed */ + 204: { + headers: { + [name: string]: unknown; + }; + content?: never; + }; + /** @description Invalid model keys or self-referential relationship */ + 400: { + headers: { + [name: string]: unknown; + }; + content?: never; + }; + /** @description The relationship does not exist */ + 404: { + headers: { + [name: string]: unknown; + }; + content?: never; + }; + /** @description Validation error */ + 422: { + headers: { + [name: string]: unknown; + }; + content?: never; + }; + /** @description Internal server error */ + 500: { + headers: { + [name: string]: unknown; + }; + content?: never; + }; + }; + }; + get_related_models_batch: { + parameters: { + query?: never; + header?: never; + path?: never; + cookie?: never; + }; + requestBody: { + content: { + "application/json": components["schemas"]["ModelRelationshipBatchRequest"]; + }; + }; + responses: { + /** @description Related model keys retrieved successfully */ + 200: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": string[]; + }; + }; + /** @description Validation error */ + 422: { + headers: { + [name: string]: unknown; + }; + content?: never; + }; + /** @description Internal server error */ + 500: { + headers: { + [name: string]: unknown; + }; + content?: never; + }; + }; + }; app_version: { parameters: { query?: never;