-
Notifications
You must be signed in to change notification settings - Fork 2.5k
feat(ui): model relationship management #7963
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
xiaden
wants to merge
2
commits into
invoke-ai:main
Choose a base branch
from
xiaden:feat(ui)-model-relationship-management
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+1,285
−9
Open
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,196 @@ | ||
"""FastAPI route for model relationship records.""" | ||
|
||
from fastapi import HTTPException, APIRouter, Path, Body, status | ||
from pydantic import BaseModel, Field | ||
from typing import List | ||
from invokeai.app.api.dependencies import ApiDependencies | ||
|
||
model_relationships_router = APIRouter( | ||
prefix="/v1/model_relationships", | ||
tags=["model_relationships"] | ||
) | ||
|
||
# === Schemas === | ||
|
||
class ModelRelationshipCreateRequest(BaseModel): | ||
model_key_1: str = Field(..., description="The key of the first model in the relationship", examples=[ | ||
"aa3b247f-90c9-4416-bfcd-aeaa57a5339e", | ||
"ac32b914-10ab-496e-a24a-3068724b9c35", | ||
"d944abfd-c7c3-42e2-a4ff-da640b29b8b4", | ||
"b1c2d3e4-f5a6-7890-abcd-ef1234567890", | ||
"12345678-90ab-cdef-1234-567890abcdef", | ||
"fedcba98-7654-3210-fedc-ba9876543210" | ||
]) | ||
model_key_2: str = Field(..., description="The key of the second model in the relationship", examples=[ | ||
"3bb7c0eb-b6c8-469c-ad8c-4d69c06075e4", | ||
"f0c3da4e-d9ff-42b5-a45c-23be75c887c9", | ||
"38170dd8-f1e5-431e-866c-2c81f1277fcc", | ||
"c57fea2d-7646-424c-b9ad-c0ba60fc68be", | ||
"10f7807b-ab54-46a9-ab03-600e88c630a1", | ||
"f6c1d267-cf87-4ee0-bee0-37e791eacab7" | ||
]) | ||
|
||
class ModelRelationshipBatchRequest(BaseModel): | ||
model_keys: List[str] = Field(..., description="List of model keys to fetch related models for", examples= | ||
[[ | ||
"aa3b247f-90c9-4416-bfcd-aeaa57a5339e", | ||
"ac32b914-10ab-496e-a24a-3068724b9c35", | ||
],[ | ||
"b1c2d3e4-f5a6-7890-abcd-ef1234567890", | ||
"12345678-90ab-cdef-1234-567890abcdef", | ||
"fedcba98-7654-3210-fedc-ba9876543210" | ||
],[ | ||
"3bb7c0eb-b6c8-469c-ad8c-4d69c06075e4", | ||
]]) | ||
|
||
# === Routes === | ||
|
||
@model_relationships_router.get( | ||
"/i/{model_key}", | ||
operation_id="get_related_models", | ||
response_model=list[str], | ||
responses={ | ||
200: { | ||
"description": "A list of related model keys was retrieved successfully", | ||
"content": { | ||
"application/json": { | ||
"example": [ | ||
"15e9eb28-8cfe-47c9-b610-37907a79fc3c", | ||
"71272e82-0e5f-46d5-bca9-9a61f4bd8a82", | ||
"a5d7cd49-1b98-4534-a475-aeee4ccf5fa2" | ||
] | ||
} | ||
}, | ||
}, | ||
404: {"description": "The specified model could not be found"}, | ||
422: {"description": "Validation error"}, | ||
}, | ||
) | ||
async def get_related_models( | ||
model_key: str = Path(..., description="The key of the model to get relationships for") | ||
) -> list[str]: | ||
""" | ||
Get a list of model keys related to a given model. | ||
""" | ||
try: | ||
return ApiDependencies.invoker.services.model_relationships.get_related_model_keys(model_key) | ||
except Exception as e: | ||
raise HTTPException(status_code=500, detail=str(e)) | ||
|
||
|
||
@model_relationships_router.post( | ||
"/", | ||
status_code=status.HTTP_204_NO_CONTENT, | ||
responses={ | ||
204: {"description": "The relationship was successfully created"}, | ||
400: {"description": "Invalid model keys or self-referential relationship"}, | ||
409: {"description": "The relationship already exists"}, | ||
422: {"description": "Validation error"}, | ||
500: {"description": "Internal server error"}, | ||
}, | ||
summary="Add Model Relationship", | ||
description="Creates a **bidirectional** relationship between two models, allowing each to reference the other as related.", | ||
) | ||
async def add_model_relationship( | ||
req: ModelRelationshipCreateRequest = Body(..., description="The model keys to relate") | ||
) -> None: | ||
""" | ||
Add a relationship between two models. | ||
Relationships are bidirectional and will be accessible from both models. | ||
- Raises 400 if keys are invalid or identical. | ||
- Raises 409 if the relationship already exists. | ||
""" | ||
try: | ||
if req.model_key_1 == req.model_key_2: | ||
raise HTTPException(status_code=400, detail="Cannot relate a model to itself.") | ||
|
||
ApiDependencies.invoker.services.model_relationships.add_model_relationship( | ||
req.model_key_1, | ||
req.model_key_2, | ||
) | ||
except ValueError as e: | ||
raise HTTPException(status_code=409, detail=str(e)) | ||
except Exception as e: | ||
raise HTTPException(status_code=500, detail=str(e)) | ||
|
||
|
||
@model_relationships_router.delete( | ||
"/", | ||
status_code=status.HTTP_204_NO_CONTENT, | ||
responses={ | ||
204: {"description": "The relationship was successfully removed"}, | ||
400: {"description": "Invalid model keys or self-referential relationship"}, | ||
404: {"description": "The relationship does not exist"}, | ||
422: {"description": "Validation error"}, | ||
500: {"description": "Internal server error"}, | ||
}, | ||
summary="Remove Model Relationship", | ||
description="Removes a **bidirectional** relationship between two models. The relationship must already exist." | ||
) | ||
async def remove_model_relationship( | ||
req: ModelRelationshipCreateRequest = Body(..., description="The model keys to disconnect") | ||
) -> None: | ||
""" | ||
Removes a bidirectional relationship between two model keys. | ||
- Raises 400 if attempting to unlink a model from itself. | ||
- Raises 404 if the relationship was not found. | ||
""" | ||
try: | ||
if req.model_key_1 == req.model_key_2: | ||
raise HTTPException(status_code=400, detail="Cannot unlink a model from itself.") | ||
|
||
ApiDependencies.invoker.services.model_relationships.remove_model_relationship( | ||
req.model_key_1, | ||
req.model_key_2, | ||
) | ||
except ValueError as e: | ||
raise HTTPException(status_code=404, detail=str(e)) | ||
except Exception as e: | ||
raise HTTPException(status_code=500, detail=str(e)) | ||
|
||
@model_relationships_router.post( | ||
"/batch", | ||
operation_id="get_related_models_batch", | ||
response_model=List[str], | ||
responses={ | ||
200: { | ||
"description": "Related model keys retrieved successfully", | ||
"content": { | ||
"application/json": { | ||
"example": [ | ||
"ca562b14-995e-4a42-90c1-9528f1a5921d", | ||
"cc0c2b8a-c62e-41d6-878e-cc74dde5ca8f", | ||
"18ca7649-6a9e-47d5-bc17-41ab1e8cec81", | ||
"7c12d1b2-0ef9-4bec-ba55-797b2d8f2ee1", | ||
"c382eaa3-0e28-4ab0-9446-408667699aeb", | ||
"71272e82-0e5f-46d5-bca9-9a61f4bd8a82", | ||
"a5d7cd49-1b98-4534-a475-aeee4ccf5fa2" | ||
] | ||
} | ||
} | ||
}, | ||
422: {"description": "Validation error"}, | ||
500: {"description": "Internal server error"}, | ||
}, | ||
summary="Get Related Model Keys (Batch)", | ||
description="Retrieves all **unique related model keys** for a list of given models. This is useful for contextual suggestions or filtering." | ||
) | ||
async def get_related_models_batch( | ||
req: ModelRelationshipBatchRequest = Body(..., description="Model keys to check for related connections") | ||
) -> list[str]: | ||
""" | ||
Accepts multiple model keys and returns a flat list of all unique related keys. | ||
Useful when working with multiple selections in the UI or cross-model comparisons. | ||
""" | ||
try: | ||
all_related: set[str] = set() | ||
for key in req.model_keys: | ||
related = ApiDependencies.invoker.services.model_relationships.get_related_model_keys(key) | ||
all_related.update(related) | ||
return list(all_related) | ||
except Exception as e: | ||
raise HTTPException(status_code=500, detail=str(e)) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
57 changes: 57 additions & 0 deletions
57
invokeai/app/services/model_relationship_records/model_relationship_records_base.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
from abc import ABC, abstractmethod | ||
from typing import TYPE_CHECKING | ||
|
||
if TYPE_CHECKING: | ||
from invokeai.backend.model_manager.config import AnyModelConfig | ||
|
||
class ModelRelationshipRecordStorageBase(ABC): | ||
"""Abstract base class for model-to-model relationship record storage.""" | ||
|
||
@abstractmethod | ||
def add_model_relationship(self, model_key_1: str, model_key_2: str) -> None: | ||
"""Creates a relationship between two models by keys.""" | ||
pass | ||
|
||
@abstractmethod | ||
def remove_model_relationship(self, model_key_1: str, model_key_2: str) -> None: | ||
"""Removes a relationship between two models by keys.""" | ||
pass | ||
|
||
@abstractmethod | ||
def get_related_model_keys(self, model_key: str) -> list[str]: | ||
"""Gets all models keys related to a given model key.""" | ||
pass | ||
|
||
@abstractmethod | ||
def get_related_model_keys_batch(self, model_keys: list[str]) -> list[str]: | ||
"""Get related model keys for multiple models given a list of keys.""" | ||
pass | ||
|
||
@abstractmethod | ||
def get_related_model_key_count(self, model_key: str) -> int: | ||
"""Gets the number of relations for a given model key.""" | ||
pass | ||
|
||
""" Below are methods that use ModelConfigs instead of model keys, as convenience methods. | ||
These methods are not required to be implemented, but they are potentially useful for later development. | ||
They are not used in the current codebase.""" | ||
|
||
@abstractmethod | ||
def add_relationship_from_models(self, model_1: "AnyModelConfig", model_2: "AnyModelConfig") -> None: | ||
"""Creates a relationship between two models using ModelConfigs.""" | ||
pass | ||
|
||
@abstractmethod | ||
def remove_relationship_from_models(self, model_1: "AnyModelConfig", model_2: "AnyModelConfig") -> None: | ||
"""Removes a relationship between two models using ModelConfigs.""" | ||
pass | ||
|
||
@abstractmethod | ||
def get_related_keys_from_model(self, model: "AnyModelConfig") -> list[str]: | ||
"""Gets all model keys related to a given model using it's config.""" | ||
pass | ||
|
||
@abstractmethod | ||
def get_related_model_key_count_from_model(self, model: "AnyModelConfig") -> int: | ||
"""Gets the number of relations for a given model config.""" | ||
pass |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we have a model config, we also definitely have the key - I don't think we will ever use these other methods. Lets omit the unused code in the ABC and impl.