Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support consts/stubs for TI and LoRa #168

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions create_example_json.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json

from horde_model_reference import MODEL_PURPOSE, MODEL_STYLE, STABLE_DIFFUSION_BASELINE_CATEGORY
from horde_model_reference import IMAGE_GENERATION_BASELINE, MODEL_PURPOSE, MODEL_STYLE
from horde_model_reference.legacy.classes.raw_legacy_model_database_records import (
RawLegacy_StableDiffusion_ModelReference,
)
Expand Down Expand Up @@ -33,7 +33,7 @@ def main():
style=MODEL_STYLE.generalist,
config={"download": [example_download_record]},
purpose=MODEL_PURPOSE.image_generation,
baseline=STABLE_DIFFUSION_BASELINE_CATEGORY.stable_diffusion_1,
baseline=IMAGE_GENERATION_BASELINE.stable_diffusion_1,
tags=["anime", "faces"],
showcases=[
"https://raw.githubusercontent.com/db0/AI-Horde-image-model-reference/main/showcase/test/test_general_01.png",
Expand All @@ -59,7 +59,7 @@ def main():
style=MODEL_STYLE.anime,
config={"download": [example_download_record_2]},
purpose=MODEL_PURPOSE.image_generation,
baseline=STABLE_DIFFUSION_BASELINE_CATEGORY.stable_diffusion_1,
baseline=IMAGE_GENERATION_BASELINE.stable_diffusion_1,
tags=["anime", "faces"],
showcases=[
"https://raw.githubusercontent.com/db0/AI-Horde-image-model-reference/main/showcase/test/anime_01.png",
Expand Down
4 changes: 2 additions & 2 deletions horde_model_reference/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@


from .meta_consts import ( # noqa: E402
IMAGE_GENERATION_BASELINE,
KNOWN_TAGS,
MODEL_PURPOSE,
MODEL_PURPOSE_LOOKUP,
MODEL_REFERENCE_CATEGORY,
MODEL_STYLE,
STABLE_DIFFUSION_BASELINE_CATEGORY,
)
from .path_consts import ( # noqa: E402
BASE_PATH,
Expand All @@ -26,7 +26,7 @@
"MODEL_PURPOSE",
"MODEL_PURPOSE_LOOKUP",
"MODEL_STYLE",
"STABLE_DIFFUSION_BASELINE_CATEGORY",
"IMAGE_GENERATION_BASELINE",
"BASE_PATH",
"DEFAULT_SHOWCASE_FOLDER_NAME",
"LEGACY_REFERENCE_FOLDER",
Expand Down
2 changes: 2 additions & 0 deletions horde_model_reference/legacy/convert_all_legacy_dbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ def convert_all_legacy_model_references(
non_generic_converter_categories = [
MODEL_REFERENCE_CATEGORY.stable_diffusion,
MODEL_REFERENCE_CATEGORY.clip,
MODEL_REFERENCE_CATEGORY.lora,
MODEL_REFERENCE_CATEGORY.ti,
]
generic_converted_categories = [x for x in MODEL_REFERENCE_CATEGORY if x not in non_generic_converter_categories]
for model_category in generic_converted_categories:
Expand Down
10 changes: 7 additions & 3 deletions horde_model_reference/legacy/legacy_download_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from loguru import logger

from horde_model_reference.legacy.convert_all_legacy_dbs import convert_all_legacy_model_references
from horde_model_reference.meta_consts import MODEL_REFERENCE_CATEGORY
from horde_model_reference.meta_consts import LOCAL_MODEL_REFERENCE_CATEGORIES, MODEL_REFERENCE_CATEGORY
from horde_model_reference.path_consts import (
BASE_PATH,
HORDE_PROXY_URL_BASE,
Expand Down Expand Up @@ -43,6 +43,12 @@ def download_legacy_model_reference(
model_category_name: MODEL_REFERENCE_CATEGORY,
override_existing: bool = False,
) -> pathlib.Path | None:
target_file_path = get_model_reference_file_path(model_category_name, base_path=self.legacy_path)

if model_category_name in LOCAL_MODEL_REFERENCE_CATEGORIES:
logger.debug(f"Skipping download of {model_category_name} reference file, as it is local.")
return target_file_path

response = requests.get(self.proxy_url + LEGACY_MODEL_GITHUB_URLS[model_category_name])
if response.status_code != 200:
logger.error(f"Failed to download {model_category_name} reference file.")
Expand All @@ -54,8 +60,6 @@ def download_legacy_model_reference(
logger.error(f"Failed to parse {model_category_name} reference file as JSON.")
return None

target_file_path = get_model_reference_file_path(model_category_name, base_path=self.legacy_path)

if target_file_path.exists() and not override_existing:
logger.debug(f"File {target_file_path} already exists, skipping download.")
return None
Expand Down
28 changes: 18 additions & 10 deletions horde_model_reference/meta_consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,14 @@ class MODEL_REFERENCE_CATEGORY(StrEnum):
safety_checker = auto()
stable_diffusion = auto()
miscellaneous = auto()
ti = auto()
lora = auto()


LOCAL_MODEL_REFERENCE_CATEGORIES = [
MODEL_REFERENCE_CATEGORY.ti,
MODEL_REFERENCE_CATEGORY.lora,
]


class MODEL_PURPOSE(StrEnum):
Expand All @@ -82,15 +90,15 @@ class MODEL_PURPOSE(StrEnum):
miscellaneous = auto()


class STABLE_DIFFUSION_BASELINE_CATEGORY(StrEnum):
class IMAGE_GENERATION_BASELINE(StrEnum):
"""An enum of all the image generation baselines."""

stable_diffusion_1 = auto()
stable_diffusion_2_768 = auto()
stable_diffusion_2_512 = auto()
stable_diffusion_xl = auto()
stable_cascade = auto()
flux_1 = auto() # TODO: Extract flux and create "IMAGE_GENERATION_BASELINE_CATEGORY" due to name inconsistency
flux_1 = auto()


MODEL_PURPOSE_LOOKUP: dict[MODEL_REFERENCE_CATEGORY, MODEL_PURPOSE] = {
Expand All @@ -105,18 +113,18 @@ class STABLE_DIFFUSION_BASELINE_CATEGORY(StrEnum):
MODEL_REFERENCE_CATEGORY.miscellaneous: MODEL_PURPOSE.miscellaneous,
}

STABLE_DIFFUSION_BASELINE_NATIVE_RESOLUTION_LOOKUP: dict[STABLE_DIFFUSION_BASELINE_CATEGORY, int] = {
STABLE_DIFFUSION_BASELINE_CATEGORY.stable_diffusion_1: 512,
STABLE_DIFFUSION_BASELINE_CATEGORY.stable_diffusion_2_768: 768,
STABLE_DIFFUSION_BASELINE_CATEGORY.stable_diffusion_2_512: 512,
STABLE_DIFFUSION_BASELINE_CATEGORY.stable_diffusion_xl: 1024,
STABLE_DIFFUSION_BASELINE_CATEGORY.stable_cascade: 1024,
STABLE_DIFFUSION_BASELINE_CATEGORY.flux_1: 1024,
STABLE_DIFFUSION_BASELINE_NATIVE_RESOLUTION_LOOKUP: dict[IMAGE_GENERATION_BASELINE, int] = {
IMAGE_GENERATION_BASELINE.stable_diffusion_1: 512,
IMAGE_GENERATION_BASELINE.stable_diffusion_2_768: 768,
IMAGE_GENERATION_BASELINE.stable_diffusion_2_512: 512,
IMAGE_GENERATION_BASELINE.stable_diffusion_xl: 1024,
IMAGE_GENERATION_BASELINE.stable_cascade: 1024,
IMAGE_GENERATION_BASELINE.flux_1: 1024,
}
"""The single-side preferred resolution for each known stable diffusion baseline."""


def get_baseline_native_resolution(baseline: STABLE_DIFFUSION_BASELINE_CATEGORY) -> int:
def get_baseline_native_resolution(baseline: IMAGE_GENERATION_BASELINE) -> int:
"""
Get the native resolution of a stable diffusion baseline.

Expand Down
12 changes: 6 additions & 6 deletions horde_model_reference/model_reference_records.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@
)

from horde_model_reference import (
IMAGE_GENERATION_BASELINE,
MODEL_PURPOSE,
MODEL_REFERENCE_CATEGORY,
MODEL_STYLE,
STABLE_DIFFUSION_BASELINE_CATEGORY,
)
from horde_model_reference.meta_consts import CONTROLNET_STYLE

Expand Down Expand Up @@ -70,7 +70,7 @@ class StableDiffusion_ModelRecord(Generic_ModelRecord):

inpainting: bool | None = False
"""If this is an inpainting model or not."""
baseline: STABLE_DIFFUSION_BASELINE_CATEGORY | str
baseline: IMAGE_GENERATION_BASELINE | str
"""The model on which this model is based."""
optimization: str | None = None
"""The optimization type of the model."""
Expand Down Expand Up @@ -108,7 +108,7 @@ def validator_set_arrays_to_empty_if_none(self) -> StableDiffusion_ModelRecord:
@model_validator(mode="after")
def validator_is_baseline_and_style_known(self) -> StableDiffusion_ModelRecord:
"""Check if the baseline is known."""
if str(self.baseline) not in STABLE_DIFFUSION_BASELINE_CATEGORY.__members__:
if str(self.baseline) not in IMAGE_GENERATION_BASELINE.__members__:
logger.warning(f"Unknown baseline {self.baseline} for model {self.name}")

if self.style is not None and str(self.style) not in MODEL_STYLE.__members__:
Expand Down Expand Up @@ -142,7 +142,7 @@ class Generic_ModelReference(RootModel[Mapping[str, Generic_ModelRecord]]):
class StableDiffusion_ModelReference(Generic_ModelReference):
"""The combined metadata and model list."""

_baseline: dict[STABLE_DIFFUSION_BASELINE_CATEGORY | str, int] = PrivateAttr(default_factory=dict)
_baseline: dict[IMAGE_GENERATION_BASELINE | str, int] = PrivateAttr(default_factory=dict)
"""A dictionary of all the baseline types and how many models use them."""
_styles: dict[MODEL_STYLE | str, int] = PrivateAttr(default_factory=dict)
"""A dictionary of all the styles and how many models use them."""
Expand Down Expand Up @@ -204,7 +204,7 @@ def rebuild_metadata(self) -> None:
self._models_hosts[host] = self._models_hosts.get(host, 0) + 1

@property
def baseline(self) -> dict[STABLE_DIFFUSION_BASELINE_CATEGORY | str, int]:
def baseline(self) -> dict[IMAGE_GENERATION_BASELINE | str, int]:
"""Return a dictionary of all the baseline types and how many models use them."""
self.check_was_models_modified()
return self._baseline
Expand Down Expand Up @@ -232,7 +232,7 @@ def models_names(self) -> set[str]:
"""Return a list of all the model names."""
return set(self.root.keys())

def get_model_baseline(self, model_name: str) -> STABLE_DIFFUSION_BASELINE_CATEGORY | str | None:
def get_model_baseline(self, model_name: str) -> IMAGE_GENERATION_BASELINE | str | None:
"""Return the baseline for a given model name."""
model: StableDiffusion_ModelRecord | None = self.root.get(model_name)

Expand Down
28 changes: 14 additions & 14 deletions stable_diffusion.schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,19 @@
"title": "DownloadRecord",
"type": "object"
},
"IMAGE_GENERATION_BASELINE": {
"description": "An enum of all the image generation baselines.",
"enum": [
"stable_diffusion_1",
"stable_diffusion_2_768",
"stable_diffusion_2_512",
"stable_diffusion_xl",
"stable_cascade",
"flux_1"
],
"title": "IMAGE_GENERATION_BASELINE",
"type": "string"
},
"MODEL_PURPOSE": {
"enum": [
"image_generation",
Expand All @@ -73,19 +86,6 @@
"title": "MODEL_STYLE",
"type": "string"
},
"STABLE_DIFFUSION_BASELINE_CATEGORY": {
"description": "An enum of all the image generation baselines.",
"enum": [
"stable_diffusion_1",
"stable_diffusion_2_768",
"stable_diffusion_2_512",
"stable_diffusion_xl",
"stable_cascade",
"flux_1"
],
"title": "STABLE_DIFFUSION_BASELINE_CATEGORY",
"type": "string"
},
"StableDiffusion_ModelRecord": {
"description": "A model entry in the model reference.",
"properties": {
Expand Down Expand Up @@ -168,7 +168,7 @@
"baseline": {
"anyOf": [
{
"$ref": "#/$defs/STABLE_DIFFUSION_BASELINE_CATEGORY"
"$ref": "#/$defs/IMAGE_GENERATION_BASELINE"
},
{
"type": "string"
Expand Down
Loading