From 373675190b98d57d57481901830ecc33711b4cc5 Mon Sep 17 00:00:00 2001 From: db0 Date: Thu, 22 Feb 2024 18:25:29 +0100 Subject: [PATCH 1/3] feat: support stable cascade --- README.md | 1 + horde_model_reference/legacy/classes/legacy_converters.py | 2 ++ horde_model_reference/meta_consts.py | 1 + stable_diffusion.schema.json | 3 ++- 4 files changed, 6 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 29534d6..abd1010 100644 --- a/README.md +++ b/README.md @@ -23,6 +23,7 @@ Some key takeaways for the new `stable_diffusion.json`: - `stable_diffusion_2_768` - `stable_diffusion_2_512` - `stable_diffusion_xl` + - `stable_cascade` - An MD5 sum is no longer included. All models (of all types) will have an SHA included from now on. - `download` entries optionally contain a new key, `known_slow_download`, which indicates this download host is known to be slow at times. diff --git a/horde_model_reference/legacy/classes/legacy_converters.py b/horde_model_reference/legacy/classes/legacy_converters.py index c8f9221..670ef41 100644 --- a/horde_model_reference/legacy/classes/legacy_converters.py +++ b/horde_model_reference/legacy/classes/legacy_converters.py @@ -581,6 +581,8 @@ def convert_legacy_baseline(self, baseline: str): baseline = "stable_diffusion_2_512" elif baseline == "stable_diffusion_xl": baseline = "stable_diffusion_xl" + elif baseline == "stable_cascade": + baseline = "stable_cascade" return baseline def create_showcase_folder(self, showcase_foldername: str) -> None: diff --git a/horde_model_reference/meta_consts.py b/horde_model_reference/meta_consts.py index 2f66199..51c2f98 100644 --- a/horde_model_reference/meta_consts.py +++ b/horde_model_reference/meta_consts.py @@ -86,6 +86,7 @@ class STABLE_DIFFUSION_BASELINE_CATEGORY(StrEnum): stable_diffusion_2_768 = auto() stable_diffusion_2_512 = auto() stable_diffusion_xl = auto() + stable_cascade = auto() MODEL_PURPOSE_LOOKUP: dict[MODEL_REFERENCE_CATEGORY, MODEL_PURPOSE] = { diff --git a/stable_diffusion.schema.json b/stable_diffusion.schema.json index 9980654..6332f2d 100644 --- a/stable_diffusion.schema.json +++ b/stable_diffusion.schema.json @@ -66,7 +66,8 @@ "stable_diffusion_1", "stable_diffusion_2_768", "stable_diffusion_2_512", - "stable_diffusion_xl" + "stable_diffusion_xl", + "stable_cascade" ], "title": "STABLE_DIFFUSION_BASELINE_CATEGORY", "type": "string" From 5a843cabb7df02a9cc8f864ac0d98666769dd4d7 Mon Sep 17 00:00:00 2001 From: tazlin Date: Thu, 22 Feb 2024 15:09:00 -0500 Subject: [PATCH 2/3] fix: allow unknown enum values; explicitly support `file_type` --- .pre-commit-config.yaml | 4 +- .../legacy/classes/legacy_converters.py | 4 ++ .../raw_legacy_model_database_records.py | 2 + .../classes/staging_model_database_records.py | 14 ++++++- .../model_reference_records.py | 37 ++++++++++++++++-- legacy_stable_diffusion.schema.json | 24 ++++++++++++ stable_diffusion.example.json | 2 + stable_diffusion.schema.json | 38 +++++++++++++++++-- tests/test_records.py | 22 +++++++++++ 9 files changed, 137 insertions(+), 10 deletions(-) create mode 100644 tests/test_records.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6189d5d..c419fde 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,10 +6,10 @@ repos: - id: end-of-file-fixer - id: trailing-whitespace - repo: https://github.com/psf/black - rev: 24.1.1 + rev: 24.2.0 hooks: - id: black - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.2.1 + rev: v0.2.2 hooks: - id: ruff diff --git a/horde_model_reference/legacy/classes/legacy_converters.py b/horde_model_reference/legacy/classes/legacy_converters.py index 670ef41..6ef476c 100644 --- a/horde_model_reference/legacy/classes/legacy_converters.py +++ b/horde_model_reference/legacy/classes/legacy_converters.py @@ -219,6 +219,10 @@ def config_record_pre_parse( error = f"{model_record_key} has a download record without a sha256sum." self.add_validation_error_to_log(model_record_key=model_record_key, error=error) parsed_download_record.sha256sum = "FIXME" + + if download.get("file_type") and download["file_type"] == "ckpt": + parsed_download_record.file_type = download["file_type"] + parsed_record_config_download_list.append(parsed_download_record) return parsed_record_config_download_list diff --git a/horde_model_reference/legacy/classes/raw_legacy_model_database_records.py b/horde_model_reference/legacy/classes/raw_legacy_model_database_records.py index 916aefc..10946aa 100644 --- a/horde_model_reference/legacy/classes/raw_legacy_model_database_records.py +++ b/horde_model_reference/legacy/classes/raw_legacy_model_database_records.py @@ -15,6 +15,7 @@ class RawLegacy_DownloadRecord(BaseModel): file_name: str file_path: str file_url: str + file_type: str | None = None class RawLegacy_FileRecord(BaseModel): @@ -23,6 +24,7 @@ class RawLegacy_FileRecord(BaseModel): path: str md5sum: str | None = None sha256sum: str | None = None + file_type: str | None = None class FEATURE_SUPPORTED(StrEnum): diff --git a/horde_model_reference/legacy/classes/staging_model_database_records.py b/horde_model_reference/legacy/classes/staging_model_database_records.py index 073af6a..d8d8b9b 100644 --- a/horde_model_reference/legacy/classes/staging_model_database_records.py +++ b/horde_model_reference/legacy/classes/staging_model_database_records.py @@ -6,8 +6,11 @@ # These classes will persist until the legacy model reference is fully deprecated. +from __future__ import annotations + from collections.abc import Mapping +from loguru import logger from pydantic import BaseModel, ConfigDict, model_validator from horde_model_reference.model_reference_records import MODEL_PURPOSE @@ -42,6 +45,7 @@ class StagingLegacy_Config_DownloadRecord(BaseModel): file_name: str file_path: str = "" + file_type: str | None = None file_url: str sha256sum: str | None = None known_slow_download: bool | None = False @@ -62,9 +66,17 @@ class StagingLegacy_Generic_ModelRecord(BaseModel): config: dict[str, list[StagingLegacy_Config_FileRecord | StagingLegacy_Config_DownloadRecord]] available: bool | None = None - purpose: MODEL_PURPOSE | None = None + purpose: MODEL_PURPOSE | str | None = None features_not_supported: list[str] | None = None + @model_validator(mode="after") + def validator_known_purpose(self) -> StagingLegacy_Generic_ModelRecord: + """Check if the purpose is known.""" + if self.purpose is not None and str(self.purpose) not in MODEL_PURPOSE.__members__: + logger.warning(f"Unknown purpose {self.purpose} for model {self.name}") + + return self + class Legacy_CLIP_ModelRecord(StagingLegacy_Generic_ModelRecord): """A model entry in the legacy model reference.""" diff --git a/horde_model_reference/model_reference_records.py b/horde_model_reference/model_reference_records.py index 72a009e..275f4e5 100644 --- a/horde_model_reference/model_reference_records.py +++ b/horde_model_reference/model_reference_records.py @@ -5,6 +5,7 @@ import urllib.parse from collections.abc import Mapping +from loguru import logger from pydantic import ( BaseModel, ConfigDict, @@ -31,6 +32,7 @@ class DownloadRecord(BaseModel): # TODO Rename? (record to subrecord?) """The fully qualified URL to download the file from.""" sha256sum: str """The sha256sum of the file.""" + file_type: str | None = None known_slow_download: bool | None = None """Whether the download is known to be slow or not.""" @@ -47,11 +49,19 @@ class Generic_ModelRecord(BaseModel): config: dict[str, list[DownloadRecord]] """A dictionary of any configuration files and information on where to download the model file(s).""" - purpose: MODEL_PURPOSE + purpose: MODEL_PURPOSE | str """The purpose of the model.""" features_not_supported: list[str] | None = None + @model_validator(mode="after") + def validator_known_purpose(self) -> Generic_ModelRecord: + """Check if the purpose is known.""" + if str(self.purpose) not in MODEL_PURPOSE.__members__: + logger.warning(f"Unknown purpose {self.purpose} for model {self.name}") + + return self + class StableDiffusion_ModelRecord(Generic_ModelRecord): """A model entry in the model reference.""" @@ -60,7 +70,7 @@ class StableDiffusion_ModelRecord(Generic_ModelRecord): inpainting: bool | None = False """If this is an inpainting model or not.""" - baseline: STABLE_DIFFUSION_BASELINE_CATEGORY + baseline: STABLE_DIFFUSION_BASELINE_CATEGORY | str """The model on which this model is based.""" tags: list[str] | None = [] """Any tags associated with the model which may be useful for searching.""" @@ -75,7 +85,7 @@ class StableDiffusion_ModelRecord(Generic_ModelRecord): nsfw: bool """Whether the model is NSFW or not.""" - style: MODEL_STYLE | None = None + style: MODEL_STYLE | str | None = None """The style of the model.""" size_on_disk_bytes: int | None = None @@ -91,6 +101,17 @@ def validator_set_arrays_to_empty_if_none(self) -> StableDiffusion_ModelRecord: self.trigger = [] return self + @model_validator(mode="after") + def validator_is_baseline_and_style_known(self) -> StableDiffusion_ModelRecord: + """Check if the baseline is known.""" + if str(self.baseline) not in STABLE_DIFFUSION_BASELINE_CATEGORY.__members__: + logger.warning(f"Unknown baseline {self.baseline} for model {self.name}") + + if self.style is not None and str(self.style) not in MODEL_STYLE.__members__: # type: ignore # FIXME + logger.warning(f"Unknown style {self.style} for model {self.name}") + + return self + class CLIP_ModelRecord(Generic_ModelRecord): pretrained_name: str | None = None @@ -98,7 +119,15 @@ class CLIP_ModelRecord(Generic_ModelRecord): class ControlNet_ModelRecord(Generic_ModelRecord): - style: CONTROLNET_STYLE | None = None + style: CONTROLNET_STYLE | str | None = None + + @model_validator(mode="after") + def validator_is_style_known(self) -> ControlNet_ModelRecord: + """Check if the style is known.""" + if self.style is not None and str(self.style) not in CONTROLNET_STYLE.__members__: + logger.warning(f"Unknown style {self.style} for model {self.name}") + + return self class Generic_ModelReference(RootModel[Mapping[str, Generic_ModelRecord]]): diff --git a/legacy_stable_diffusion.schema.json b/legacy_stable_diffusion.schema.json index f76d20a..ead3eab 100644 --- a/legacy_stable_diffusion.schema.json +++ b/legacy_stable_diffusion.schema.json @@ -25,6 +25,18 @@ "file_url": { "title": "File Url", "type": "string" + }, + "file_type": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "title": "File Type" } }, "required": [ @@ -65,6 +77,18 @@ ], "default": null, "title": "Sha256Sum" + }, + "file_type": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "title": "File Type" } }, "required": [ diff --git a/stable_diffusion.example.json b/stable_diffusion.example.json index 720871a..870b459 100644 --- a/stable_diffusion.example.json +++ b/stable_diffusion.example.json @@ -9,6 +9,7 @@ "file_name": "example_general_model.ckpt", "file_url": "https://www.some_website.com/a_different_name_on_the_website.ckpt", "sha256sum": "DEADBEEFDEADBEEFDEADBEEFDEADBEEFDEADBEEFDEADBEEFDEADBEEFDEADBEEF", + "file_type": null, "known_slow_download": null } ] @@ -44,6 +45,7 @@ "file_name": "example_general_model.ckpt", "file_url": "https://www.some_website.com/a_different_name_on_the_website.ckpt", "sha256sum": "DEADBEEFDEADBEEFDEADBEEFDEADBEEFDEADBEEFDEADBEEFDEADBEEFDEADBEEF", + "file_type": null, "known_slow_download": null } ] diff --git a/stable_diffusion.schema.json b/stable_diffusion.schema.json index 6332f2d..f69df4e 100644 --- a/stable_diffusion.schema.json +++ b/stable_diffusion.schema.json @@ -15,6 +15,18 @@ "title": "Sha256Sum", "type": "string" }, + "file_type": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "title": "File Type" + }, "known_slow_download": { "anyOf": [ { @@ -114,7 +126,15 @@ "type": "object" }, "purpose": { - "$ref": "#/$defs/MODEL_PURPOSE" + "anyOf": [ + { + "$ref": "#/$defs/MODEL_PURPOSE" + }, + { + "type": "string" + } + ], + "title": "Purpose" }, "features_not_supported": { "anyOf": [ @@ -144,7 +164,15 @@ "title": "Inpainting" }, "baseline": { - "$ref": "#/$defs/STABLE_DIFFUSION_BASELINE_CATEGORY" + "anyOf": [ + { + "$ref": "#/$defs/STABLE_DIFFUSION_BASELINE_CATEGORY" + }, + { + "type": "string" + } + ], + "title": "Baseline" }, "tags": { "anyOf": [ @@ -224,11 +252,15 @@ { "$ref": "#/$defs/MODEL_STYLE" }, + { + "type": "string" + }, { "type": "null" } ], - "default": null + "default": null, + "title": "Style" }, "size_on_disk_bytes": { "anyOf": [ diff --git a/tests/test_records.py b/tests/test_records.py new file mode 100644 index 0000000..6901a31 --- /dev/null +++ b/tests/test_records.py @@ -0,0 +1,22 @@ +from horde_model_reference.model_reference_records import DownloadRecord, StableDiffusion_ModelRecord + + +def test_stable_diffusion_model_record(): + """Tests the StableDiffusion_ModelRecord class.""" + # Create a record + StableDiffusion_ModelRecord( + name="test_name", + description="test_description", + version="test_version", + style="test_style", + purpose="test_purpose", + inpainting=False, + baseline="test_baseline", + tags=["test_tag"], + nsfw=False, + config={ + "test_config": [ + DownloadRecord(file_name="test_file_name", file_url="test_file_url", sha256sum="test_sha256sum"), + ], + }, + ) From ea6fa66b172496653052b40382e9c4efbfe686e0 Mon Sep 17 00:00:00 2001 From: tazlin Date: Sat, 24 Feb 2024 08:26:08 -0500 Subject: [PATCH 3/3] chore: remove now obsolete mypy directives The issue that these `# type: ignore` directives were silencing appears to have been long ago fixed --- horde_model_reference/model_reference_records.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/horde_model_reference/model_reference_records.py b/horde_model_reference/model_reference_records.py index 275f4e5..fcc1aa7 100644 --- a/horde_model_reference/model_reference_records.py +++ b/horde_model_reference/model_reference_records.py @@ -90,7 +90,7 @@ class StableDiffusion_ModelRecord(Generic_ModelRecord): size_on_disk_bytes: int | None = None - @model_validator(mode="after") # type: ignore # FIXME + @model_validator(mode="after") def validator_set_arrays_to_empty_if_none(self) -> StableDiffusion_ModelRecord: """Set any `None` values to empty lists.""" if self.tags is None: @@ -107,7 +107,7 @@ def validator_is_baseline_and_style_known(self) -> StableDiffusion_ModelRecord: if str(self.baseline) not in STABLE_DIFFUSION_BASELINE_CATEGORY.__members__: logger.warning(f"Unknown baseline {self.baseline} for model {self.name}") - if self.style is not None and str(self.style) not in MODEL_STYLE.__members__: # type: ignore # FIXME + if self.style is not None and str(self.style) not in MODEL_STYLE.__members__: logger.warning(f"Unknown style {self.style} for model {self.name}") return self