diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6189d5d..c419fde 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,10 +6,10 @@ repos: - id: end-of-file-fixer - id: trailing-whitespace - repo: https://github.com/psf/black - rev: 24.1.1 + rev: 24.2.0 hooks: - id: black - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.2.1 + rev: v0.2.2 hooks: - id: ruff diff --git a/README.md b/README.md index 29534d6..abd1010 100644 --- a/README.md +++ b/README.md @@ -23,6 +23,7 @@ Some key takeaways for the new `stable_diffusion.json`: - `stable_diffusion_2_768` - `stable_diffusion_2_512` - `stable_diffusion_xl` + - `stable_cascade` - An MD5 sum is no longer included. All models (of all types) will have an SHA included from now on. - `download` entries optionally contain a new key, `known_slow_download`, which indicates this download host is known to be slow at times. diff --git a/horde_model_reference/legacy/classes/legacy_converters.py b/horde_model_reference/legacy/classes/legacy_converters.py index c8f9221..6ef476c 100644 --- a/horde_model_reference/legacy/classes/legacy_converters.py +++ b/horde_model_reference/legacy/classes/legacy_converters.py @@ -219,6 +219,10 @@ def config_record_pre_parse( error = f"{model_record_key} has a download record without a sha256sum." self.add_validation_error_to_log(model_record_key=model_record_key, error=error) parsed_download_record.sha256sum = "FIXME" + + if download.get("file_type") and download["file_type"] == "ckpt": + parsed_download_record.file_type = download["file_type"] + parsed_record_config_download_list.append(parsed_download_record) return parsed_record_config_download_list @@ -581,6 +585,8 @@ def convert_legacy_baseline(self, baseline: str): baseline = "stable_diffusion_2_512" elif baseline == "stable_diffusion_xl": baseline = "stable_diffusion_xl" + elif baseline == "stable_cascade": + baseline = "stable_cascade" return baseline def create_showcase_folder(self, showcase_foldername: str) -> None: diff --git a/horde_model_reference/legacy/classes/raw_legacy_model_database_records.py b/horde_model_reference/legacy/classes/raw_legacy_model_database_records.py index 916aefc..10946aa 100644 --- a/horde_model_reference/legacy/classes/raw_legacy_model_database_records.py +++ b/horde_model_reference/legacy/classes/raw_legacy_model_database_records.py @@ -15,6 +15,7 @@ class RawLegacy_DownloadRecord(BaseModel): file_name: str file_path: str file_url: str + file_type: str | None = None class RawLegacy_FileRecord(BaseModel): @@ -23,6 +24,7 @@ class RawLegacy_FileRecord(BaseModel): path: str md5sum: str | None = None sha256sum: str | None = None + file_type: str | None = None class FEATURE_SUPPORTED(StrEnum): diff --git a/horde_model_reference/legacy/classes/staging_model_database_records.py b/horde_model_reference/legacy/classes/staging_model_database_records.py index 073af6a..d8d8b9b 100644 --- a/horde_model_reference/legacy/classes/staging_model_database_records.py +++ b/horde_model_reference/legacy/classes/staging_model_database_records.py @@ -6,8 +6,11 @@ # These classes will persist until the legacy model reference is fully deprecated. +from __future__ import annotations + from collections.abc import Mapping +from loguru import logger from pydantic import BaseModel, ConfigDict, model_validator from horde_model_reference.model_reference_records import MODEL_PURPOSE @@ -42,6 +45,7 @@ class StagingLegacy_Config_DownloadRecord(BaseModel): file_name: str file_path: str = "" + file_type: str | None = None file_url: str sha256sum: str | None = None known_slow_download: bool | None = False @@ -62,9 +66,17 @@ class StagingLegacy_Generic_ModelRecord(BaseModel): config: dict[str, list[StagingLegacy_Config_FileRecord | StagingLegacy_Config_DownloadRecord]] available: bool | None = None - purpose: MODEL_PURPOSE | None = None + purpose: MODEL_PURPOSE | str | None = None features_not_supported: list[str] | None = None + @model_validator(mode="after") + def validator_known_purpose(self) -> StagingLegacy_Generic_ModelRecord: + """Check if the purpose is known.""" + if self.purpose is not None and str(self.purpose) not in MODEL_PURPOSE.__members__: + logger.warning(f"Unknown purpose {self.purpose} for model {self.name}") + + return self + class Legacy_CLIP_ModelRecord(StagingLegacy_Generic_ModelRecord): """A model entry in the legacy model reference.""" diff --git a/horde_model_reference/meta_consts.py b/horde_model_reference/meta_consts.py index 2f66199..51c2f98 100644 --- a/horde_model_reference/meta_consts.py +++ b/horde_model_reference/meta_consts.py @@ -86,6 +86,7 @@ class STABLE_DIFFUSION_BASELINE_CATEGORY(StrEnum): stable_diffusion_2_768 = auto() stable_diffusion_2_512 = auto() stable_diffusion_xl = auto() + stable_cascade = auto() MODEL_PURPOSE_LOOKUP: dict[MODEL_REFERENCE_CATEGORY, MODEL_PURPOSE] = { diff --git a/horde_model_reference/model_reference_records.py b/horde_model_reference/model_reference_records.py index 72a009e..fcc1aa7 100644 --- a/horde_model_reference/model_reference_records.py +++ b/horde_model_reference/model_reference_records.py @@ -5,6 +5,7 @@ import urllib.parse from collections.abc import Mapping +from loguru import logger from pydantic import ( BaseModel, ConfigDict, @@ -31,6 +32,7 @@ class DownloadRecord(BaseModel): # TODO Rename? (record to subrecord?) """The fully qualified URL to download the file from.""" sha256sum: str """The sha256sum of the file.""" + file_type: str | None = None known_slow_download: bool | None = None """Whether the download is known to be slow or not.""" @@ -47,11 +49,19 @@ class Generic_ModelRecord(BaseModel): config: dict[str, list[DownloadRecord]] """A dictionary of any configuration files and information on where to download the model file(s).""" - purpose: MODEL_PURPOSE + purpose: MODEL_PURPOSE | str """The purpose of the model.""" features_not_supported: list[str] | None = None + @model_validator(mode="after") + def validator_known_purpose(self) -> Generic_ModelRecord: + """Check if the purpose is known.""" + if str(self.purpose) not in MODEL_PURPOSE.__members__: + logger.warning(f"Unknown purpose {self.purpose} for model {self.name}") + + return self + class StableDiffusion_ModelRecord(Generic_ModelRecord): """A model entry in the model reference.""" @@ -60,7 +70,7 @@ class StableDiffusion_ModelRecord(Generic_ModelRecord): inpainting: bool | None = False """If this is an inpainting model or not.""" - baseline: STABLE_DIFFUSION_BASELINE_CATEGORY + baseline: STABLE_DIFFUSION_BASELINE_CATEGORY | str """The model on which this model is based.""" tags: list[str] | None = [] """Any tags associated with the model which may be useful for searching.""" @@ -75,12 +85,12 @@ class StableDiffusion_ModelRecord(Generic_ModelRecord): nsfw: bool """Whether the model is NSFW or not.""" - style: MODEL_STYLE | None = None + style: MODEL_STYLE | str | None = None """The style of the model.""" size_on_disk_bytes: int | None = None - @model_validator(mode="after") # type: ignore # FIXME + @model_validator(mode="after") def validator_set_arrays_to_empty_if_none(self) -> StableDiffusion_ModelRecord: """Set any `None` values to empty lists.""" if self.tags is None: @@ -91,6 +101,17 @@ def validator_set_arrays_to_empty_if_none(self) -> StableDiffusion_ModelRecord: self.trigger = [] return self + @model_validator(mode="after") + def validator_is_baseline_and_style_known(self) -> StableDiffusion_ModelRecord: + """Check if the baseline is known.""" + if str(self.baseline) not in STABLE_DIFFUSION_BASELINE_CATEGORY.__members__: + logger.warning(f"Unknown baseline {self.baseline} for model {self.name}") + + if self.style is not None and str(self.style) not in MODEL_STYLE.__members__: + logger.warning(f"Unknown style {self.style} for model {self.name}") + + return self + class CLIP_ModelRecord(Generic_ModelRecord): pretrained_name: str | None = None @@ -98,7 +119,15 @@ class CLIP_ModelRecord(Generic_ModelRecord): class ControlNet_ModelRecord(Generic_ModelRecord): - style: CONTROLNET_STYLE | None = None + style: CONTROLNET_STYLE | str | None = None + + @model_validator(mode="after") + def validator_is_style_known(self) -> ControlNet_ModelRecord: + """Check if the style is known.""" + if self.style is not None and str(self.style) not in CONTROLNET_STYLE.__members__: + logger.warning(f"Unknown style {self.style} for model {self.name}") + + return self class Generic_ModelReference(RootModel[Mapping[str, Generic_ModelRecord]]): diff --git a/legacy_stable_diffusion.schema.json b/legacy_stable_diffusion.schema.json index f76d20a..ead3eab 100644 --- a/legacy_stable_diffusion.schema.json +++ b/legacy_stable_diffusion.schema.json @@ -25,6 +25,18 @@ "file_url": { "title": "File Url", "type": "string" + }, + "file_type": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "title": "File Type" } }, "required": [ @@ -65,6 +77,18 @@ ], "default": null, "title": "Sha256Sum" + }, + "file_type": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "title": "File Type" } }, "required": [ diff --git a/stable_diffusion.example.json b/stable_diffusion.example.json index 720871a..870b459 100644 --- a/stable_diffusion.example.json +++ b/stable_diffusion.example.json @@ -9,6 +9,7 @@ "file_name": "example_general_model.ckpt", "file_url": "https://www.some_website.com/a_different_name_on_the_website.ckpt", "sha256sum": "DEADBEEFDEADBEEFDEADBEEFDEADBEEFDEADBEEFDEADBEEFDEADBEEFDEADBEEF", + "file_type": null, "known_slow_download": null } ] @@ -44,6 +45,7 @@ "file_name": "example_general_model.ckpt", "file_url": "https://www.some_website.com/a_different_name_on_the_website.ckpt", "sha256sum": "DEADBEEFDEADBEEFDEADBEEFDEADBEEFDEADBEEFDEADBEEFDEADBEEFDEADBEEF", + "file_type": null, "known_slow_download": null } ] diff --git a/stable_diffusion.schema.json b/stable_diffusion.schema.json index 9980654..f69df4e 100644 --- a/stable_diffusion.schema.json +++ b/stable_diffusion.schema.json @@ -15,6 +15,18 @@ "title": "Sha256Sum", "type": "string" }, + "file_type": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "title": "File Type" + }, "known_slow_download": { "anyOf": [ { @@ -66,7 +78,8 @@ "stable_diffusion_1", "stable_diffusion_2_768", "stable_diffusion_2_512", - "stable_diffusion_xl" + "stable_diffusion_xl", + "stable_cascade" ], "title": "STABLE_DIFFUSION_BASELINE_CATEGORY", "type": "string" @@ -113,7 +126,15 @@ "type": "object" }, "purpose": { - "$ref": "#/$defs/MODEL_PURPOSE" + "anyOf": [ + { + "$ref": "#/$defs/MODEL_PURPOSE" + }, + { + "type": "string" + } + ], + "title": "Purpose" }, "features_not_supported": { "anyOf": [ @@ -143,7 +164,15 @@ "title": "Inpainting" }, "baseline": { - "$ref": "#/$defs/STABLE_DIFFUSION_BASELINE_CATEGORY" + "anyOf": [ + { + "$ref": "#/$defs/STABLE_DIFFUSION_BASELINE_CATEGORY" + }, + { + "type": "string" + } + ], + "title": "Baseline" }, "tags": { "anyOf": [ @@ -223,11 +252,15 @@ { "$ref": "#/$defs/MODEL_STYLE" }, + { + "type": "string" + }, { "type": "null" } ], - "default": null + "default": null, + "title": "Style" }, "size_on_disk_bytes": { "anyOf": [ diff --git a/tests/test_records.py b/tests/test_records.py new file mode 100644 index 0000000..6901a31 --- /dev/null +++ b/tests/test_records.py @@ -0,0 +1,22 @@ +from horde_model_reference.model_reference_records import DownloadRecord, StableDiffusion_ModelRecord + + +def test_stable_diffusion_model_record(): + """Tests the StableDiffusion_ModelRecord class.""" + # Create a record + StableDiffusion_ModelRecord( + name="test_name", + description="test_description", + version="test_version", + style="test_style", + purpose="test_purpose", + inpainting=False, + baseline="test_baseline", + tags=["test_tag"], + nsfw=False, + config={ + "test_config": [ + DownloadRecord(file_name="test_file_name", file_url="test_file_url", sha256sum="test_sha256sum"), + ], + }, + )