Skip to content

feat(ui): model relationship management #7963

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions invokeai/app/api/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
from invokeai.app.services.model_images.model_images_default import ModelImageFileStorageDisk
from invokeai.app.services.model_manager.model_manager_default import ModelManagerService
from invokeai.app.services.model_records.model_records_sql import ModelRecordServiceSQL
from invokeai.app.services.model_relationships.model_relationships_default import ModelRelationshipsService
from invokeai.app.services.model_relationship_records.model_relationship_records_sqlite import SqliteModelRelationshipRecordStorage
from invokeai.app.services.names.names_default import SimpleNameService
from invokeai.app.services.object_serializer.object_serializer_disk import ObjectSerializerDisk
from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache
Expand Down Expand Up @@ -136,6 +138,8 @@ def initialize(
download_queue=download_queue_service,
events=events,
)
model_relationships = ModelRelationshipsService()
model_relationship_records = SqliteModelRelationshipRecordStorage(db=db)
names = SimpleNameService()
performance_statistics = InvocationStatsService()
session_processor = DefaultSessionProcessor(session_runner=DefaultSessionRunner())
Expand All @@ -161,6 +165,8 @@ def initialize(
logger=logger,
model_images=model_images_service,
model_manager=model_manager,
model_relationships=model_relationships,
model_relationship_records=model_relationship_records,
download_queue=download_queue_service,
names=names,
performance_statistics=performance_statistics,
Expand Down
196 changes: 196 additions & 0 deletions invokeai/app/api/routers/model_relationships.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
"""FastAPI route for model relationship records."""

from fastapi import HTTPException, APIRouter, Path, Body, status
from pydantic import BaseModel, Field
from typing import List
from invokeai.app.api.dependencies import ApiDependencies

model_relationships_router = APIRouter(
prefix="/v1/model_relationships",
tags=["model_relationships"]
)

# === Schemas ===

class ModelRelationshipCreateRequest(BaseModel):
model_key_1: str = Field(..., description="The key of the first model in the relationship", examples=[
"aa3b247f-90c9-4416-bfcd-aeaa57a5339e",
"ac32b914-10ab-496e-a24a-3068724b9c35",
"d944abfd-c7c3-42e2-a4ff-da640b29b8b4",
"b1c2d3e4-f5a6-7890-abcd-ef1234567890",
"12345678-90ab-cdef-1234-567890abcdef",
"fedcba98-7654-3210-fedc-ba9876543210"
])
model_key_2: str = Field(..., description="The key of the second model in the relationship", examples=[
"3bb7c0eb-b6c8-469c-ad8c-4d69c06075e4",
"f0c3da4e-d9ff-42b5-a45c-23be75c887c9",
"38170dd8-f1e5-431e-866c-2c81f1277fcc",
"c57fea2d-7646-424c-b9ad-c0ba60fc68be",
"10f7807b-ab54-46a9-ab03-600e88c630a1",
"f6c1d267-cf87-4ee0-bee0-37e791eacab7"
])

class ModelRelationshipBatchRequest(BaseModel):
model_keys: List[str] = Field(..., description="List of model keys to fetch related models for", examples=
[[
"aa3b247f-90c9-4416-bfcd-aeaa57a5339e",
"ac32b914-10ab-496e-a24a-3068724b9c35",
],[
"b1c2d3e4-f5a6-7890-abcd-ef1234567890",
"12345678-90ab-cdef-1234-567890abcdef",
"fedcba98-7654-3210-fedc-ba9876543210"
],[
"3bb7c0eb-b6c8-469c-ad8c-4d69c06075e4",
]])

# === Routes ===

@model_relationships_router.get(
"/i/{model_key}",
operation_id="get_related_models",
response_model=list[str],
responses={
200: {
"description": "A list of related model keys was retrieved successfully",
"content": {
"application/json": {
"example": [
"15e9eb28-8cfe-47c9-b610-37907a79fc3c",
"71272e82-0e5f-46d5-bca9-9a61f4bd8a82",
"a5d7cd49-1b98-4534-a475-aeee4ccf5fa2"
]
}
},
},
404: {"description": "The specified model could not be found"},
422: {"description": "Validation error"},
},
)
async def get_related_models(
model_key: str = Path(..., description="The key of the model to get relationships for")
) -> list[str]:
"""
Get a list of model keys related to a given model.
"""
try:
return ApiDependencies.invoker.services.model_relationships.get_related_model_keys(model_key)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))


@model_relationships_router.post(
"/",
status_code=status.HTTP_204_NO_CONTENT,
responses={
204: {"description": "The relationship was successfully created"},
400: {"description": "Invalid model keys or self-referential relationship"},
409: {"description": "The relationship already exists"},
422: {"description": "Validation error"},
500: {"description": "Internal server error"},
},
summary="Add Model Relationship",
description="Creates a **bidirectional** relationship between two models, allowing each to reference the other as related.",
)
async def add_model_relationship(
req: ModelRelationshipCreateRequest = Body(..., description="The model keys to relate")
) -> None:
"""
Add a relationship between two models.
Relationships are bidirectional and will be accessible from both models.
- Raises 400 if keys are invalid or identical.
- Raises 409 if the relationship already exists.
"""
try:
if req.model_key_1 == req.model_key_2:
raise HTTPException(status_code=400, detail="Cannot relate a model to itself.")

ApiDependencies.invoker.services.model_relationships.add_model_relationship(
req.model_key_1,
req.model_key_2,
)
except ValueError as e:
raise HTTPException(status_code=409, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))


@model_relationships_router.delete(
"/",
status_code=status.HTTP_204_NO_CONTENT,
responses={
204: {"description": "The relationship was successfully removed"},
400: {"description": "Invalid model keys or self-referential relationship"},
404: {"description": "The relationship does not exist"},
422: {"description": "Validation error"},
500: {"description": "Internal server error"},
},
summary="Remove Model Relationship",
description="Removes a **bidirectional** relationship between two models. The relationship must already exist."
)
async def remove_model_relationship(
req: ModelRelationshipCreateRequest = Body(..., description="The model keys to disconnect")
) -> None:
"""
Removes a bidirectional relationship between two model keys.
- Raises 400 if attempting to unlink a model from itself.
- Raises 404 if the relationship was not found.
"""
try:
if req.model_key_1 == req.model_key_2:
raise HTTPException(status_code=400, detail="Cannot unlink a model from itself.")

ApiDependencies.invoker.services.model_relationships.remove_model_relationship(
req.model_key_1,
req.model_key_2,
)
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

@model_relationships_router.post(
"/batch",
operation_id="get_related_models_batch",
response_model=List[str],
responses={
200: {
"description": "Related model keys retrieved successfully",
"content": {
"application/json": {
"example": [
"ca562b14-995e-4a42-90c1-9528f1a5921d",
"cc0c2b8a-c62e-41d6-878e-cc74dde5ca8f",
"18ca7649-6a9e-47d5-bc17-41ab1e8cec81",
"7c12d1b2-0ef9-4bec-ba55-797b2d8f2ee1",
"c382eaa3-0e28-4ab0-9446-408667699aeb",
"71272e82-0e5f-46d5-bca9-9a61f4bd8a82",
"a5d7cd49-1b98-4534-a475-aeee4ccf5fa2"
]
}
}
},
422: {"description": "Validation error"},
500: {"description": "Internal server error"},
},
summary="Get Related Model Keys (Batch)",
description="Retrieves all **unique related model keys** for a list of given models. This is useful for contextual suggestions or filtering."
)
async def get_related_models_batch(
req: ModelRelationshipBatchRequest = Body(..., description="Model keys to check for related connections")
) -> list[str]:
"""
Accepts multiple model keys and returns a flat list of all unique related keys.
Useful when working with multiple selections in the UI or cross-model comparisons.
"""
try:
all_related: set[str] = set()
for key in req.model_keys:
related = ApiDependencies.invoker.services.model_relationships.get_related_model_keys(key)
all_related.update(related)
return list(all_related)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
2 changes: 2 additions & 0 deletions invokeai/app/api_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
download_queue,
images,
model_manager,
model_relationships,
session_queue,
style_presets,
utilities,
Expand Down Expand Up @@ -125,6 +126,7 @@ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
app.include_router(images.images_router, prefix="/api")
app.include_router(boards.boards_router, prefix="/api")
app.include_router(board_images.board_images_router, prefix="/api")
app.include_router(model_relationships.model_relationships_router, prefix="/api")
app.include_router(app_info.app_router, prefix="/api")
app.include_router(session_queue.session_queue_router, prefix="/api")
app.include_router(workflows.workflows_router, prefix="/api")
Expand Down
6 changes: 6 additions & 0 deletions invokeai/app/services/invocation_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
from invokeai.app.services.invocation_stats.invocation_stats_base import InvocationStatsServiceBase
from invokeai.app.services.model_images.model_images_base import ModelImageFileStorageBase
from invokeai.app.services.model_manager.model_manager_base import ModelManagerServiceBase
from invokeai.app.services.model_relationship_records.model_relationship_records_base import ModelRelationshipRecordStorageBase
from invokeai.app.services.model_relationships.model_relationships_base import ModelRelationshipsServiceABC
from invokeai.app.services.names.names_base import NameServiceBase
from invokeai.app.services.session_processor.session_processor_base import SessionProcessorBase
from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase
Expand Down Expand Up @@ -54,6 +56,8 @@ def __init__(
logger: "Logger",
model_images: "ModelImageFileStorageBase",
model_manager: "ModelManagerServiceBase",
model_relationships: "ModelRelationshipsServiceABC",
model_relationship_records: "ModelRelationshipRecordStorageBase",
download_queue: "DownloadQueueServiceBase",
performance_statistics: "InvocationStatsServiceBase",
session_queue: "SessionQueueBase",
Expand Down Expand Up @@ -81,6 +85,8 @@ def __init__(
self.logger = logger
self.model_images = model_images
self.model_manager = model_manager
self.model_relationships = model_relationships
self.model_relationship_records = model_relationship_records
self.download_queue = download_queue
self.performance_statistics = performance_statistics
self.session_queue = session_queue
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from invokeai.backend.model_manager.config import AnyModelConfig

class ModelRelationshipRecordStorageBase(ABC):
"""Abstract base class for model-to-model relationship record storage."""

@abstractmethod
def add_model_relationship(self, model_key_1: str, model_key_2: str) -> None:
"""Creates a relationship between two models by keys."""
pass

@abstractmethod
def remove_model_relationship(self, model_key_1: str, model_key_2: str) -> None:
"""Removes a relationship between two models by keys."""
pass

@abstractmethod
def get_related_model_keys(self, model_key: str) -> list[str]:
"""Gets all models keys related to a given model key."""
pass

@abstractmethod
def get_related_model_keys_batch(self, model_keys: list[str]) -> list[str]:
"""Get related model keys for multiple models given a list of keys."""
pass

@abstractmethod
def get_related_model_key_count(self, model_key: str) -> int:
"""Gets the number of relations for a given model key."""
pass

""" Below are methods that use ModelConfigs instead of model keys, as convenience methods.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we have a model config, we also definitely have the key - I don't think we will ever use these other methods. Lets omit the unused code in the ABC and impl.

These methods are not required to be implemented, but they are potentially useful for later development.
They are not used in the current codebase."""

@abstractmethod
def add_relationship_from_models(self, model_1: "AnyModelConfig", model_2: "AnyModelConfig") -> None:
"""Creates a relationship between two models using ModelConfigs."""
pass

@abstractmethod
def remove_relationship_from_models(self, model_1: "AnyModelConfig", model_2: "AnyModelConfig") -> None:
"""Removes a relationship between two models using ModelConfigs."""
pass

@abstractmethod
def get_related_keys_from_model(self, model: "AnyModelConfig") -> list[str]:
"""Gets all model keys related to a given model using it's config."""
pass

@abstractmethod
def get_related_model_key_count_from_model(self, model: "AnyModelConfig") -> int:
"""Gets the number of relations for a given model config."""
pass
Loading