Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support stable cascade; feat: support API model changes to enum values in the future #75

Merged
merged 3 commits into from
Feb 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ repos:
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/psf/black
rev: 24.1.1
rev: 24.2.0
hooks:
- id: black
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.2.1
rev: v0.2.2
hooks:
- id: ruff
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ Some key takeaways for the new `stable_diffusion.json`:
- `stable_diffusion_2_768`
- `stable_diffusion_2_512`
- `stable_diffusion_xl`
- `stable_cascade`
- An MD5 sum is no longer included. All models (of all types) will have an SHA included from now on.
- `download` entries optionally contain a new key, `known_slow_download`, which indicates this download host is known to be slow at times.

Expand Down
6 changes: 6 additions & 0 deletions horde_model_reference/legacy/classes/legacy_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,10 @@ def config_record_pre_parse(
error = f"{model_record_key} has a download record without a sha256sum."
self.add_validation_error_to_log(model_record_key=model_record_key, error=error)
parsed_download_record.sha256sum = "FIXME"

if download.get("file_type") and download["file_type"] == "ckpt":
parsed_download_record.file_type = download["file_type"]

parsed_record_config_download_list.append(parsed_download_record)

return parsed_record_config_download_list
Expand Down Expand Up @@ -581,6 +585,8 @@ def convert_legacy_baseline(self, baseline: str):
baseline = "stable_diffusion_2_512"
elif baseline == "stable_diffusion_xl":
baseline = "stable_diffusion_xl"
elif baseline == "stable_cascade":
baseline = "stable_cascade"
return baseline

def create_showcase_folder(self, showcase_foldername: str) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class RawLegacy_DownloadRecord(BaseModel):
file_name: str
file_path: str
file_url: str
file_type: str | None = None


class RawLegacy_FileRecord(BaseModel):
Expand All @@ -23,6 +24,7 @@ class RawLegacy_FileRecord(BaseModel):
path: str
md5sum: str | None = None
sha256sum: str | None = None
file_type: str | None = None


class FEATURE_SUPPORTED(StrEnum):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,11 @@

# These classes will persist until the legacy model reference is fully deprecated.

from __future__ import annotations

from collections.abc import Mapping

from loguru import logger
from pydantic import BaseModel, ConfigDict, model_validator

from horde_model_reference.model_reference_records import MODEL_PURPOSE
Expand Down Expand Up @@ -42,6 +45,7 @@ class StagingLegacy_Config_DownloadRecord(BaseModel):

file_name: str
file_path: str = ""
file_type: str | None = None
file_url: str
sha256sum: str | None = None
known_slow_download: bool | None = False
Expand All @@ -62,9 +66,17 @@ class StagingLegacy_Generic_ModelRecord(BaseModel):
config: dict[str, list[StagingLegacy_Config_FileRecord | StagingLegacy_Config_DownloadRecord]]
available: bool | None = None

purpose: MODEL_PURPOSE | None = None
purpose: MODEL_PURPOSE | str | None = None
features_not_supported: list[str] | None = None

@model_validator(mode="after")
def validator_known_purpose(self) -> StagingLegacy_Generic_ModelRecord:
"""Check if the purpose is known."""
if self.purpose is not None and str(self.purpose) not in MODEL_PURPOSE.__members__:
logger.warning(f"Unknown purpose {self.purpose} for model {self.name}")

return self


class Legacy_CLIP_ModelRecord(StagingLegacy_Generic_ModelRecord):
"""A model entry in the legacy model reference."""
Expand Down
1 change: 1 addition & 0 deletions horde_model_reference/meta_consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ class STABLE_DIFFUSION_BASELINE_CATEGORY(StrEnum):
stable_diffusion_2_768 = auto()
stable_diffusion_2_512 = auto()
stable_diffusion_xl = auto()
stable_cascade = auto()


MODEL_PURPOSE_LOOKUP: dict[MODEL_REFERENCE_CATEGORY, MODEL_PURPOSE] = {
Expand Down
39 changes: 34 additions & 5 deletions horde_model_reference/model_reference_records.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import urllib.parse
from collections.abc import Mapping

from loguru import logger
from pydantic import (
BaseModel,
ConfigDict,
Expand All @@ -31,6 +32,7 @@ class DownloadRecord(BaseModel): # TODO Rename? (record to subrecord?)
"""The fully qualified URL to download the file from."""
sha256sum: str
"""The sha256sum of the file."""
file_type: str | None = None
known_slow_download: bool | None = None
"""Whether the download is known to be slow or not."""

Expand All @@ -47,11 +49,19 @@ class Generic_ModelRecord(BaseModel):
config: dict[str, list[DownloadRecord]]
"""A dictionary of any configuration files and information on where to download the model file(s)."""

purpose: MODEL_PURPOSE
purpose: MODEL_PURPOSE | str
"""The purpose of the model."""

features_not_supported: list[str] | None = None

@model_validator(mode="after")
def validator_known_purpose(self) -> Generic_ModelRecord:
"""Check if the purpose is known."""
if str(self.purpose) not in MODEL_PURPOSE.__members__:
logger.warning(f"Unknown purpose {self.purpose} for model {self.name}")

return self


class StableDiffusion_ModelRecord(Generic_ModelRecord):
"""A model entry in the model reference."""
Expand All @@ -60,7 +70,7 @@ class StableDiffusion_ModelRecord(Generic_ModelRecord):

inpainting: bool | None = False
"""If this is an inpainting model or not."""
baseline: STABLE_DIFFUSION_BASELINE_CATEGORY
baseline: STABLE_DIFFUSION_BASELINE_CATEGORY | str
"""The model on which this model is based."""
tags: list[str] | None = []
"""Any tags associated with the model which may be useful for searching."""
Expand All @@ -75,12 +85,12 @@ class StableDiffusion_ModelRecord(Generic_ModelRecord):
nsfw: bool
"""Whether the model is NSFW or not."""

style: MODEL_STYLE | None = None
style: MODEL_STYLE | str | None = None
"""The style of the model."""

size_on_disk_bytes: int | None = None

@model_validator(mode="after") # type: ignore # FIXME
@model_validator(mode="after")
def validator_set_arrays_to_empty_if_none(self) -> StableDiffusion_ModelRecord:
"""Set any `None` values to empty lists."""
if self.tags is None:
Expand All @@ -91,14 +101,33 @@ def validator_set_arrays_to_empty_if_none(self) -> StableDiffusion_ModelRecord:
self.trigger = []
return self

@model_validator(mode="after")
def validator_is_baseline_and_style_known(self) -> StableDiffusion_ModelRecord:
"""Check if the baseline is known."""
if str(self.baseline) not in STABLE_DIFFUSION_BASELINE_CATEGORY.__members__:
logger.warning(f"Unknown baseline {self.baseline} for model {self.name}")

if self.style is not None and str(self.style) not in MODEL_STYLE.__members__:
logger.warning(f"Unknown style {self.style} for model {self.name}")

return self


class CLIP_ModelRecord(Generic_ModelRecord):
pretrained_name: str | None = None
# TODO docstring


class ControlNet_ModelRecord(Generic_ModelRecord):
style: CONTROLNET_STYLE | None = None
style: CONTROLNET_STYLE | str | None = None

@model_validator(mode="after")
def validator_is_style_known(self) -> ControlNet_ModelRecord:
"""Check if the style is known."""
if self.style is not None and str(self.style) not in CONTROLNET_STYLE.__members__:
logger.warning(f"Unknown style {self.style} for model {self.name}")

return self


class Generic_ModelReference(RootModel[Mapping[str, Generic_ModelRecord]]):
Expand Down
24 changes: 24 additions & 0 deletions legacy_stable_diffusion.schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,18 @@
"file_url": {
"title": "File Url",
"type": "string"
},
"file_type": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"default": null,
"title": "File Type"
}
},
"required": [
Expand Down Expand Up @@ -65,6 +77,18 @@
],
"default": null,
"title": "Sha256Sum"
},
"file_type": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"default": null,
"title": "File Type"
}
},
"required": [
Expand Down
2 changes: 2 additions & 0 deletions stable_diffusion.example.json
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
"file_name": "example_general_model.ckpt",
"file_url": "https://www.some_website.com/a_different_name_on_the_website.ckpt",
"sha256sum": "DEADBEEFDEADBEEFDEADBEEFDEADBEEFDEADBEEFDEADBEEFDEADBEEFDEADBEEF",
"file_type": null,
"known_slow_download": null
}
]
Expand Down Expand Up @@ -44,6 +45,7 @@
"file_name": "example_general_model.ckpt",
"file_url": "https://www.some_website.com/a_different_name_on_the_website.ckpt",
"sha256sum": "DEADBEEFDEADBEEFDEADBEEFDEADBEEFDEADBEEFDEADBEEFDEADBEEFDEADBEEF",
"file_type": null,
"known_slow_download": null
}
]
Expand Down
41 changes: 37 additions & 4 deletions stable_diffusion.schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,18 @@
"title": "Sha256Sum",
"type": "string"
},
"file_type": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"default": null,
"title": "File Type"
},
"known_slow_download": {
"anyOf": [
{
Expand Down Expand Up @@ -66,7 +78,8 @@
"stable_diffusion_1",
"stable_diffusion_2_768",
"stable_diffusion_2_512",
"stable_diffusion_xl"
"stable_diffusion_xl",
"stable_cascade"
],
"title": "STABLE_DIFFUSION_BASELINE_CATEGORY",
"type": "string"
Expand Down Expand Up @@ -113,7 +126,15 @@
"type": "object"
},
"purpose": {
"$ref": "#/$defs/MODEL_PURPOSE"
"anyOf": [
{
"$ref": "#/$defs/MODEL_PURPOSE"
},
{
"type": "string"
}
],
"title": "Purpose"
},
"features_not_supported": {
"anyOf": [
Expand Down Expand Up @@ -143,7 +164,15 @@
"title": "Inpainting"
},
"baseline": {
"$ref": "#/$defs/STABLE_DIFFUSION_BASELINE_CATEGORY"
"anyOf": [
{
"$ref": "#/$defs/STABLE_DIFFUSION_BASELINE_CATEGORY"
},
{
"type": "string"
}
],
"title": "Baseline"
},
"tags": {
"anyOf": [
Expand Down Expand Up @@ -223,11 +252,15 @@
{
"$ref": "#/$defs/MODEL_STYLE"
},
{
"type": "string"
},
{
"type": "null"
}
],
"default": null
"default": null,
"title": "Style"
},
"size_on_disk_bytes": {
"anyOf": [
Expand Down
22 changes: 22 additions & 0 deletions tests/test_records.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from horde_model_reference.model_reference_records import DownloadRecord, StableDiffusion_ModelRecord


def test_stable_diffusion_model_record():
"""Tests the StableDiffusion_ModelRecord class."""
# Create a record
StableDiffusion_ModelRecord(
name="test_name",
description="test_description",
version="test_version",
style="test_style",
purpose="test_purpose",
inpainting=False,
baseline="test_baseline",
tags=["test_tag"],
nsfw=False,
config={
"test_config": [
DownloadRecord(file_name="test_file_name", file_url="test_file_url", sha256sum="test_sha256sum"),
],
},
)
Loading