From f093056389c9bf38f69caff9effe69047d927bea Mon Sep 17 00:00:00 2001
From: Thomas Roche <2297889+thoroc@users.noreply.github.com>
Date: Wed, 15 Feb 2023 15:56:33 +0000
Subject: [PATCH] Feature/GitHub action Python formatting/linting #3 (#14)

* linting stage for the github action workflow

* linting stage for the github action workflow

* more complete gitignore

* formatting all python code with black

* restricting to ubuntu

* black and isort ran locally

* fixing compatibility in examples

* format/linting before pushing

* removed dead code

* sorting imports on 007

* removing docformatter

* removing pycodestyle

* adding the args options to the github action

* removed py3.10 as an option

* disabling docformatter

* disabling pycodestyle

* trying again to disable pycodestyle

* trying something new

* Revert "trying something new"

This reverts commit 6e0f6c3d9c56bebdfa788d9c3aa1631b251e4dff.

* getting rid of the magical action

* folding the actions together

* allowing end of file to be a blank line

* end of file blank line rule for flake8

* reverted python files to the upstream main state

* ran black locally

* ran isort locally

* ran autoflake locally

* reformatting via black

* missing import for load_model

* flake8 will ignore the line length for now

* f-string unneeded

* reverted changes to the init files in sdkit module

* installed dev requirements

* running isort
---
 .flake8 | 4 +
 .github/workflows/py-lint.yml | 59 ++++
 .gitignore | 163 ++++++++++-
 examples/001-generate-model_file_on_disk.py | 12 +-
 examples/002-generate-download_known_model.py | 16 +-
 ...generate-download_multiple_known_models.py | 10 +-
 examples/003-generate-custom_model.py | 14 +-
 examples/004-generate-custom_vae.py | 16 +-
 examples/005-generate-custom_hypernetwork.py | 25 +-
 examples/006-generate-change_models.py | 35 ++-
 examples/007-generate-and-filter.py | 24 +-
 examples/100-filter-fix_faces.py | 15 +-
 ...1-filter-fix_faces_download_known_model.py | 17 +-
 examples/102-filter-upscale.py | 17 +-
 examples/103-filter-multiple.py | 27 +-
 examples/200-train-merge_models.py | 8 +-
 examples/300-device-low-vram.py | 26 +-
 examples/301-device-run_on_different_gpu.py | 12 +-
 examples/302-device-run_on_cpu.py | 12 +-
 examples/304-device-run_on_multiple_gpus.py | 31 +-
 examples/305-device-full_precision.py | 12 +-
 examples/400-security-scan_model.py | 12 +-
 pyproject.toml | 20 ++
 requirements.txt | 0
 requirements_dev.txt | 5 +
 scripts/compare_benchmarks.py | 20 +-
 scripts/download_all_models.py | 35 ++-
 scripts/print_quick_hashes.py | 22 +-
 scripts/run_everything.py | 272 ++++++++++++------
 scripts/txt2img_original.py | 90 +++---
 sdkit/__init__.py | 28 +-
 sdkit/filter/apply_filters.py | 18 +-
 sdkit/filter/gfpgan.py | 23 +-
 sdkit/filter/realesrgan.py | 11 +-
 sdkit/generate/__init__.py | 4 +-
 sdkit/generate/image_generator.py | 139 +++++---
 sdkit/generate/prompt_parser.py | 31 +-
 sdkit/generate/sampler/default_samplers.py | 88 +++---
 sdkit/generate/sampler/k_samplers.py | 78 +++--
 sdkit/generate/sampler/sampler_main.py | 31 +-
 sdkit/models/__init__.py | 13 +-
 sdkit/models/model_downloader.py | 37 ++-
 sdkit/models/model_loader/__init__.py | 29 +-
 sdkit/models/model_loader/gfpgan.py | 14 +-
 .../model_loader/hypernetwork/__init__.py | 58 ++--
 .../model_loader/hypernetwork/hypernetwork.py | 54 +++-
 sdkit/models/model_loader/realesrgan.py | 23 +-
 .../model_loader/stable_diffusion/__init__.py | 80 ++++--
 .../stable_diffusion/optimizations.py | 117 +++++---
 sdkit/models/model_loader/vae.py | 21 +-
 sdkit/models/models_db/__init__.py | 14 +-
 sdkit/models/scan_models.py | 5 +-
 sdkit/train/merge_models.py | 35 +--
 sdkit/utils/__init__.py | 46 +--
 sdkit/utils/file_utils.py | 61 ++--
 sdkit/utils/hash_utils.py | 31 +-
 sdkit/utils/http_utils.py | 28 +-
 sdkit/utils/image_utils.py | 16 +-
 sdkit/utils/latent_utils.py | 25 +-
 sdkit/utils/memory_utils.py | 103 ++++---
 tests/vram_frees_after_image_generation.py | 36 ++-
 61 files changed, 1492 insertions(+), 836 deletions(-)
 create mode 100644 .flake8
 create mode 100644 .github/workflows/py-lint.yml
 create mode 100644 requirements.txt
 create mode 100644 requirements_dev.txt

diff --git a/.flake8 b/.flake8
new file mode 100644
index 0000000..1848441
--- /dev/null
+++ b/.flake8
@@ -0,0 +1,4 @@
+[flake8]
+max-line-length = 120
+extend-ignore = E203, E402, E722, W391
+per-file-ignores = __init__.py:F401
\ No newline at end of file
diff --git a/.github/workflows/py-lint.yml b/.github/workflows/py-lint.yml
new file mode 100644
index 0000000..0645b17
--- /dev/null
+++ b/.github/workflows/py-lint.yml
@@ -0,0 +1,59 @@
+name: Python Linting
+
+on:
+  push:
+    branches:
+      - "!main"
+  pull_request:
+    branches:
+      - main
+
+jobs:
+  linting:
+    name: Python Linting
+    runs-on: ${{ matrix.os }}
+    strategy:
+      fail-fast: true
+      matrix:
+        # apparently only ubuntu-latest is available to run a containerised job
+        os: [ubuntu-latest]
+        python-version: ["3.9"]
+        architecture: ["x64"]
+
+    steps:
+      #----------------------------------------------
+      # check-out repo
+      #----------------------------------------------
+      - name: Check out code
+        uses: actions/checkout@v3
+
+      #----------------------------------------------
+      # set-up python
+      #----------------------------------------------
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v4
+        with:
+          python-version: ${{ matrix.python-version }}
+
+      #----------------------------------------------
+      # Formatting & Linting
+      #----------------------------------------------
+      - name: Format and lint the code
+        run: |
+          pip install black isort autoflake flake8
+          black tests scripts sdkit examples --line-length 120 --include="\.py"
+          isort tests scripts sdkit examples --profile black
+          autoflake --in-place --remove-all-unused-imports --remove-unused-variables --recursive .
+          flake8 tests scripts sdkit examples --max-line-length 120 --extend-ignore=E203,E402,E501,E722,W391 --per-file-ignores=__init__.py:F401
+
+      #----------------------------------------------
+      # Committing all changes
+      #----------------------------------------------
+      - name: Commit changes
+        uses: EndBug/add-and-commit@v4
+        with:
+          author_name: ${{ github.actor }}
+          author_email: ${{ github.actor }}@users.noreply.github.com
+          message: "Code formatted and linted"
+          add: "."
+          branch: ${{ github.ref }}
diff --git a/.gitignore b/.gitignore
index 0e78141..68bc17f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,160 @@
-__pycache__
-sdkit.egg-info
-dist
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
diff --git a/examples/001-generate-model_file_on_disk.py b/examples/001-generate-model_file_on_disk.py
index 2a2bb29..390d8a4 100644
--- a/examples/001-generate-model_file_on_disk.py
+++ b/examples/001-generate-model_file_on_disk.py
@@ -1,18 +1,18 @@
 import sdkit
-from sdkit.models import load_model
 from sdkit.generate import generate_images
-from sdkit.utils import save_images, log
+from sdkit.models import load_model
+from sdkit.utils import log, save_images
 
 context = sdkit.Context()
 
 # set the path to the model file on the disk (.ckpt or .safetensors file)
-context.model_paths['stable-diffusion'] = 'D:\\path\\to\\512-base-ema.ckpt'
-load_model(context, 'stable-diffusion')
+context.model_paths["stable-diffusion"] = "D:\\path\\to\\512-base-ema.ckpt"
+load_model(context, "stable-diffusion")
 
 # generate the image
-images = generate_images(context, prompt='Photograph of an astronaut riding a horse', seed=42, width=512, height=512)
+images = generate_images(context, prompt="Photograph of an astronaut riding a horse", seed=42, width=512, height=512)
 
 # save the image
-save_images(images, dir_path='.')
+save_images(images, dir_path=".")
 
 log.info("Generated images!")
diff --git a/examples/002-generate-download_known_model.py b/examples/002-generate-download_known_model.py
index ccaf9ef..80c8eda 100644
--- a/examples/002-generate-download_known_model.py
+++ b/examples/002-generate-download_known_model.py
@@ -1,21 +1,23 @@
 import sdkit
-from sdkit.models import load_model, download_model, resolve_downloaded_model_path
 from sdkit.generate import generate_images
-from sdkit.utils import save_images, log
+from sdkit.models import download_model, load_model, resolve_downloaded_model_path
+from sdkit.utils import log, save_images
 
 context = sdkit.Context()
 
 # download the model (skips if already downloaded, resumes if downloaded partially)
-download_model(model_type='stable-diffusion', model_id='1.5-pruned-emaonly')
+download_model(model_type="stable-diffusion", model_id="1.5-pruned-emaonly")
 
 # set the path to the auto-downloaded model
-context.model_paths['stable-diffusion'] = resolve_downloaded_model_path(context, 'stable-diffusion', '1.5-pruned-emaonly')
-load_model(context, 'stable-diffusion')
+context.model_paths["stable-diffusion"] = resolve_downloaded_model_path(
+    context, "stable-diffusion", "1.5-pruned-emaonly"
+)
+load_model(context, "stable-diffusion")
 
 # generate the image
-images = generate_images(context, prompt='Photograph of an astronaut riding a horse', seed=42, width=512, height=512)
+images = generate_images(context, prompt="Photograph of an astronaut riding a horse", seed=42, width=512, height=512)
 
 # save the image
-save_images(images, dir_path='D:\\path\\to\\images\\directory')
+save_images(images, dir_path="D:\\path\\to\\images\\directory")
 
 log.info("Generated images!")
diff --git a/examples/002b-generate-download_multiple_known_models.py b/examples/002b-generate-download_multiple_known_models.py
index c705c7e..82323ec 100644
--- a/examples/002b-generate-download_multiple_known_models.py
+++ b/examples/002b-generate-download_multiple_known_models.py
@@ -1,7 +1,9 @@
 from sdkit.models import download_models
 
 # download all three models (skips if already downloaded, resumes if downloaded partially)
-download_models(models={
-    'stable-diffusion': ['1.4', '1.5-pruned-emaonly'],
-    'gfpgan': '1.3',
-})
+download_models(
+    models={
+        "stable-diffusion": ["1.4", "1.5-pruned-emaonly"],
+        "gfpgan": "1.3",
+    }
+)
diff --git a/examples/003-generate-custom_model.py b/examples/003-generate-custom_model.py
index b88cbba..7691afd 100644
--- a/examples/003-generate-custom_model.py
+++ b/examples/003-generate-custom_model.py
@@ -1,22 +1,22 @@
 import sdkit
-from sdkit.models import load_model
 from sdkit.generate import generate_images
-from sdkit.utils import save_images, log
+from sdkit.models import load_model
+from sdkit.utils import log, save_images
 
 context = sdkit.Context()
 
 # set the path to the custom model file on the disk
-context.model_paths['stable-diffusion'] = 'D:\\path\\to\\cmodelUpgradeStableD.safetensors'
-context.model_configs['stable-diffusion'] = 'D:\\path\\to\\Cmodelsafetensor.yaml'
+context.model_paths["stable-diffusion"] = "D:\\path\\to\\cmodelUpgradeStableD.safetensors"
+context.model_configs["stable-diffusion"] = "D:\\path\\to\\Cmodelsafetensor.yaml"
 # the yaml config file is required if it's an unknown model to use.
 # it is not necessary for known models present in the models_db.
-load_model(context, 'stable-diffusion')
+load_model(context, "stable-diffusion")
 
 # generate the image
-images = generate_images(context, prompt='Photograph of an astronaut riding a horse', seed=42, width=512, height=512)
+images = generate_images(context, prompt="Photograph of an astronaut riding a horse", seed=42, width=512, height=512)
 
 # save the image
-save_images(images, dir_path='D:\\path\\to\\images\\directory')
+save_images(images, dir_path="D:\\path\\to\\images\\directory")
 
 log.info("Generated images!")
diff --git a/examples/004-generate-custom_vae.py b/examples/004-generate-custom_vae.py
index b316d98..975e4c3 100644
--- a/examples/004-generate-custom_vae.py
+++ b/examples/004-generate-custom_vae.py
@@ -1,20 +1,20 @@
 import sdkit
-from sdkit.models import load_model
 from sdkit.generate import generate_images
-from sdkit.utils import save_images, log
+from sdkit.models import load_model
+from sdkit.utils import log, save_images
 
 context = sdkit.Context()
 
 # set the path to the model and VAE file on the disk
-context.model_paths['stable-diffusion'] = 'D:\\path\\to\\model.ckpt'
-context.model_paths['vae'] = 'D:\\path\\to\\vae.ckpt'
-load_model(context, 'stable-diffusion')
-load_model(context, 'vae')
+context.model_paths["stable-diffusion"] = "D:\\path\\to\\model.ckpt"
+context.model_paths["vae"] = "D:\\path\\to\\vae.ckpt"
+load_model(context, "stable-diffusion")
+load_model(context, "vae")
 
 # generate the image
 images = generate_images(context, prompt="Photograph of an astronaut riding a horse", seed=42, width=512, height=512)
 
 # save the image
-save_images(images, dir_path='D:\\path\\to\\images\\directory')
+save_images(images, dir_path="D:\\path\\to\\images\\directory")
 
-log.info('Generated images with a custom VAE!')
+log.info("Generated images with a custom VAE!")
diff --git a/examples/005-generate-custom_hypernetwork.py b/examples/005-generate-custom_hypernetwork.py
index 4b164a3..13fba5f 100644
--- a/examples/005-generate-custom_hypernetwork.py
+++ b/examples/005-generate-custom_hypernetwork.py
@@ -1,20 +1,27 @@
 import sdkit
-from sdkit.models import load_model
 from sdkit.generate import generate_images
-from sdkit.utils import save_images, log
+from sdkit.models import load_model
+from sdkit.utils import log, save_images
 
 context = sdkit.Context()
 
 # set the path to the model and hypernetwork file on the disk
-context.model_paths['stable-diffusion'] = 'D:\\path\\to\\model.ckpt'
-context.model_paths['hypernetwork'] = 'D:\\path\\to\\hypernetwork.pt'
-load_model(context, 'stable-diffusion')
-load_model(context, 'hypernetwork')
+context.model_paths["stable-diffusion"] = "D:\\path\\to\\model.ckpt" +context.model_paths["hypernetwork"] = "D:\\path\\to\\hypernetwork.pt" +load_model(context, "stable-diffusion") +load_model(context, "hypernetwork") # generate the image, hypernetwork_strength at 0.3 -images = generate_images(context, prompt="Photograph of an astronaut riding a horse", seed=42, width=512, height=512, hypernetwork_strength=0.3) +images = generate_images( + context, + prompt="Photograph of an astronaut riding a horse", + seed=42, + width=512, + height=512, + hypernetwork_strength=0.3, +) # save the image -save_images(images, dir_path='D:\\path\\to\\images\\directory') +save_images(images, dir_path="D:\\path\\to\\images\\directory") -log.info('Generated images with a custom VAE!') +log.info("Generated images with a custom VAE!") diff --git a/examples/006-generate-change_models.py b/examples/006-generate-change_models.py index c9cebf1..62fea59 100644 --- a/examples/006-generate-change_models.py +++ b/examples/006-generate-change_models.py @@ -3,28 +3,35 @@ # the unused hypernetwork and modelA.ckpt will be unloaded from memory automatically. import sdkit -from sdkit.models import load_model, unload_model from sdkit.generate import generate_images +from sdkit.models import load_model, unload_model from sdkit.utils import save_images context = sdkit.Context() # first image with modelA.ckpt, with hypernetwork -context.model_paths['stable-diffusion'] = 'D:\\path\\to\\modelA.ckpt' -context.model_paths['hypernetwork'] = 'D:\\path\\to\\hypernetwork.pt' -load_model(context, 'stable-diffusion') -load_model(context, 'hypernetwork') - -images = generate_images(context, prompt="Photograph of an astronaut riding a horse", seed=42, width=512, height=512, hypernetwork_strength=0.3) - -save_images(images, dir_path='D:\\path\\to\\images\\directory', file_name='image_modelA_with_hypernetwork') +context.model_paths["stable-diffusion"] = "D:\\path\\to\\modelA.ckpt" +context.model_paths["hypernetwork"] = "D:\\path\\to\\hypernetwork.pt" +load_model(context, "stable-diffusion") +load_model(context, "hypernetwork") + +images = generate_images( + context, + prompt="Photograph of an astronaut riding a horse", + seed=42, + width=512, + height=512, + hypernetwork_strength=0.3, +) + +save_images(images, dir_path="D:\\path\\to\\images\\directory", file_name="image_modelA_with_hypernetwork") # second image with modelB.ckpt, without hypernetwork -context.model_paths['stable-diffusion'] = 'D:\\path\\to\\modelB.ckpt' -context.model_paths['hypernetwork'] = None -load_model(context, 'stable-diffusion') -unload_model(context, 'hypernetwork') +context.model_paths["stable-diffusion"] = "D:\\path\\to\\modelB.ckpt" +context.model_paths["hypernetwork"] = None +load_model(context, "stable-diffusion") +unload_model(context, "hypernetwork") images = generate_images(context, prompt="Photograph of an astronaut riding a horse", seed=42, width=512, height=512) -save_images(images, dir_path='D:\\path\\to\\images\\directory', file_name='image_modelB_without_hypernetwork') +save_images(images, dir_path="D:\\path\\to\\images\\directory", file_name="image_modelB_without_hypernetwork") diff --git a/examples/007-generate-and-filter.py b/examples/007-generate-and-filter.py index 5fb0f98..e72db6e 100644 --- a/examples/007-generate-and-filter.py +++ b/examples/007-generate-and-filter.py @@ -1,21 +1,29 @@ import sdkit -from sdkit.generate import generate_images from sdkit.filter import apply_filters +from sdkit.generate import generate_images +from sdkit.models import 
load_model from sdkit.utils import save_images context = sdkit.Context() # setup model paths -context.model_paths['stable-diffusion'] = 'D:\\path\\to\\modelA.ckpt' -context.model_paths['gfpgan'] = 'C:\\path\\to\\gfpgan-1.3.pth' -load_model(context, 'stable-diffusion') -load_model(context, 'gfpgan') +context.model_paths["stable-diffusion"] = "D:\\path\\to\\modelA.ckpt" +context.model_paths["gfpgan"] = "C:\\path\\to\\gfpgan-1.3.pth" +load_model(context, "stable-diffusion") +load_model(context, "gfpgan") # generate image -images = generate_images(context, prompt="Photograph of an astronaut riding a horse", seed=42, width=512, height=512, hypernetwork_strength=0.3) +images = generate_images( + context, + prompt="Photograph of an astronaut riding a horse", + seed=42, + width=512, + height=512, + hypernetwork_strength=0.3, +) # apply filter -images_face_fixed = apply_filters(context, filters='gfpgan', images=images) +images_face_fixed = apply_filters(context, filters="gfpgan", images=images) # save images -save_images(images_face_fixed, dir_path='D:\\path\\to\\images\\directory', file_name='image_with_face_fix') +save_images(images_face_fixed, dir_path="D:\\path\\to\\images\\directory", file_name="image_with_face_fix") diff --git a/examples/100-filter-fix_faces.py b/examples/100-filter-fix_faces.py index ac62552..b7490c6 100644 --- a/examples/100-filter-fix_faces.py +++ b/examples/100-filter-fix_faces.py @@ -1,17 +1,18 @@ +from PIL import Image + import sdkit -from sdkit.models import load_model from sdkit.filter import apply_filters -from PIL import Image +from sdkit.models import load_model context = sdkit.Context() -image = Image.open('photo of a man.jpg') +image = Image.open("photo of a man.jpg") # set the path to the model file on the disk -context.model_paths['gfpgan'] = 'C:\\path\\to\\gfpgan-1.3.pth' -load_model(context, 'gfpgan') +context.model_paths["gfpgan"] = "C:\\path\\to\\gfpgan-1.3.pth" +load_model(context, "gfpgan") # apply the filter -image_face_fixed = apply_filters(context, 'gfpgan', image) +image_face_fixed = apply_filters(context, "gfpgan", image) # save the filtered image -image_face_fixed.save('man_face_fixed.jpg') +image_face_fixed.save("man_face_fixed.jpg") diff --git a/examples/101-filter-fix_faces_download_known_model.py b/examples/101-filter-fix_faces_download_known_model.py index e2cd3f3..08f297b 100644 --- a/examples/101-filter-fix_faces_download_known_model.py +++ b/examples/101-filter-fix_faces_download_known_model.py @@ -1,20 +1,21 @@ +from PIL import Image + import sdkit -from sdkit.models import download_model, resolve_downloaded_model_path, load_model from sdkit.filter import apply_filters -from PIL import Image +from sdkit.models import download_model, load_model, resolve_downloaded_model_path context = sdkit.Context() -image = Image.open('photo of a man.jpg') +image = Image.open("photo of a man.jpg") # download the model (skips if already downloaded, resumes if downloaded partially) -download_model(model_type='gfpgan', model_id='1.3') +download_model(model_type="gfpgan", model_id="1.3") # set the path to the auto-downloaded model -context.model_paths['gfpgan'] = resolve_downloaded_model_path(context, 'gfpgan', '1.3') -load_model(context, 'gfpgan') +context.model_paths["gfpgan"] = resolve_downloaded_model_path(context, "gfpgan", "1.3") +load_model(context, "gfpgan") # apply the filter -image_face_fixed = apply_filters(context, 'gfpgan', image) +image_face_fixed = apply_filters(context, "gfpgan", image) # save the filtered image 
-image_face_fixed.save('man_face_fixed.jpg')
+image_face_fixed.save("man_face_fixed.jpg")
diff --git a/examples/102-filter-upscale.py b/examples/102-filter-upscale.py
index 5af3cdd..5e19e39 100644
--- a/examples/102-filter-upscale.py
+++ b/examples/102-filter-upscale.py
@@ -1,18 +1,19 @@
+from PIL import Image
+
 import sdkit
-from sdkit.models import load_model
 from sdkit.filter import apply_filters
-from PIL import Image
+from sdkit.models import load_model
 
 context = sdkit.Context()
 
-image = Image.open('photo of a man.jpg')
+image = Image.open("photo of a man.jpg")
 
 # set the path to the model file on the disk
-context.model_paths['realesrgan'] = 'C:\\path\\to\\RealESRGAN_x4plus.pth'
-load_model(context, 'realesrgan')
+context.model_paths["realesrgan"] = "C:\\path\\to\\RealESRGAN_x4plus.pth"
+load_model(context, "realesrgan")
 
 # apply the filter
-scale = 4 # or 2
-image_upscaled = apply_filters(context, 'realesrgan', image, scale=scale)
+scale = 4  # or 2
+image_upscaled = apply_filters(context, "realesrgan", image, scale=scale)
 
 # save the filtered image
-image_upscaled.save('man_upscaled.jpg')
+image_upscaled.save("man_upscaled.jpg")
diff --git a/examples/103-filter-multiple.py b/examples/103-filter-multiple.py
index e364eb5..8afe7a2 100644
--- a/examples/103-filter-multiple.py
+++ b/examples/103-filter-multiple.py
@@ -1,25 +1,26 @@
+from PIL import Image
+
 import sdkit
-from sdkit.models import load_model
 from sdkit.filter import apply_filters
-from PIL import Image
+from sdkit.models import load_model
 
 context = sdkit.Context()
 
-image = Image.open('photo of a man.jpg')
+image = Image.open("photo of a man.jpg")
 
 # set the path to the model files on the disk
 context.model_paths = {
-    'gfpgan': 'C:\\path\\to\\gfpgan-1.3.pth',
-    'realesrgan': 'C:\\path\\to\\realesrgan.pth',
+    "gfpgan": "C:\\path\\to\\gfpgan-1.3.pth",
+    "realesrgan": "C:\\path\\to\\realesrgan.pth",
 }
-load_model(context, 'gfpgan')
-load_model(context, 'realesrgan')
+load_model(context, "gfpgan")
+load_model(context, "realesrgan")
 
 # apply the filters
-image_face_fixed = apply_filters(context, 'gfpgan', image)
-image_scaled_up = apply_filters(context, 'realesrgan', image)
-image_face_fixed_and_scaled_up = apply_filters(context, ['gfpgan', 'realesrgan'], image)
+image_face_fixed = apply_filters(context, "gfpgan", image)
+image_scaled_up = apply_filters(context, "realesrgan", image)
+image_face_fixed_and_scaled_up = apply_filters(context, ["gfpgan", "realesrgan"], image)
 
 # save the filtered images
-image_face_fixed.save('man_face_fixed.jpg')
-image_scaled_up.save('man_scaled_up.jpg')
-image_face_fixed_and_scaled_up.save('man_face_fixed_and_scaled_up.jpg')
+image_face_fixed.save("man_face_fixed.jpg")
+image_scaled_up.save("man_scaled_up.jpg")
+image_face_fixed_and_scaled_up.save("man_face_fixed_and_scaled_up.jpg")
diff --git a/examples/200-train-merge_models.py b/examples/200-train-merge_models.py
index bda2857..cd3ca89 100644
--- a/examples/200-train-merge_models.py
+++ b/examples/200-train-merge_models.py
@@ -1,9 +1,9 @@
 from sdkit.train import merge_models
 
 merge_models(
-    model0_path='D:\\path\\to\\model_a.ckpt',
-    model1_path='D:\\path\\to\\model_b.ckpt',
+    model0_path="D:\\path\\to\\model_a.ckpt",
+    model1_path="D:\\path\\to\\model_b.ckpt",
     ratio=0.3,
-    out_path='D:\\path\\to\\merged_model.safetensors',
-    use_fp16=True
+    out_path="D:\\path\\to\\merged_model.safetensors",
+    use_fp16=True,
 )
diff --git a/examples/300-device-low-vram.py b/examples/300-device-low-vram.py
index 36d4834..72cd3b6 100644
--- a/examples/300-device-low-vram.py
+++ b/examples/300-device-low-vram.py
@@ -1,29 +1,29 @@
 import sdkit
-from sdkit.models import load_model
 from sdkit.generate import generate_images
+from sdkit.models import load_model
 
 # this example will generate an image, using 3 different VRAM configurations.
 
 context = sdkit.Context()
-context.model_paths['stable-diffusion'] = 'D:\\path\\to\\sd-v1-4.ckpt'
-load_model(context, 'stable-diffusion')
+context.model_paths["stable-diffusion"] = "D:\\path\\to\\sd-v1-4.ckpt"
+load_model(context, "stable-diffusion")
 
 # TEST 1 - default (balanced) VRAM optimizations (much lower VRAM usage, performance is nearly as fast as max)
-images = generate_images(context, prompt='Photograph of an astronaut riding a horse', seed=42, width=512, height=512)
-images[0].save('image1.jpg')
+images = generate_images(context, prompt="Photograph of an astronaut riding a horse", seed=42, width=512, height=512)
+images[0].save("image1.jpg")
 
 # TEST 2 - no VRAM optimizations (maximum VRAM usage, fastest performance)
-context.vram_usage_level = 'high'
-load_model(context, 'stable-diffusion') # reload the model, to apply the change to VRAM optimization
+context.vram_usage_level = "high"
+load_model(context, "stable-diffusion")  # reload the model, to apply the change to VRAM optimization
 
-images = generate_images(context, prompt='Photograph of an astronaut riding a horse', seed=42, width=512, height=512)
-images[0].save('image2.jpg')
+images = generate_images(context, prompt="Photograph of an astronaut riding a horse", seed=42, width=512, height=512)
+images[0].save("image2.jpg")
 
 # TEST 3 - lowest VRAM usage, slowest performance (for GPUs with less than 4gb of VRAM)
-context.vram_usage_level = 'low'
-load_model(context, 'stable-diffusion') # reload the model, to apply the change to VRAM optimization
+context.vram_usage_level = "low"
+load_model(context, "stable-diffusion")  # reload the model, to apply the change to VRAM optimization
 
-images = generate_images(context, prompt='Photograph of an astronaut riding a horse', seed=42, width=512, height=512)
-images[0].save('image3.jpg')
+images = generate_images(context, prompt="Photograph of an astronaut riding a horse", seed=42, width=512, height=512)
+images[0].save("image3.jpg")
diff --git a/examples/301-device-run_on_different_gpu.py b/examples/301-device-run_on_different_gpu.py
index 744bc19..2796621 100644
--- a/examples/301-device-run_on_different_gpu.py
+++ b/examples/301-device-run_on_different_gpu.py
@@ -1,13 +1,13 @@
 import sdkit
-from sdkit.models import load_model
 from sdkit.generate import generate_images
+from sdkit.models import load_model
 
 context = sdkit.Context()
-context.model_paths['stable-diffusion'] = 'D:\\path\\to\\sd-v1-4.ckpt'
-context.device = 'cuda:1' # assuming the PC has a second GPU with the id 'cuda:1'
+context.model_paths["stable-diffusion"] = "D:\\path\\to\\sd-v1-4.ckpt"
+context.device = "cuda:1"  # assuming the PC has a second GPU with the id 'cuda:1'
 
-load_model(context, 'stable-diffusion')
+load_model(context, "stable-diffusion")
 
 # generate image
-images = generate_images(context, prompt='Photograph of an astronaut riding a horse', seed=42, width=512, height=512)
-images[0].save('image_from_second_gpu.jpg')
+images = generate_images(context, prompt="Photograph of an astronaut riding a horse", seed=42, width=512, height=512)
+images[0].save("image_from_second_gpu.jpg")
diff --git a/examples/302-device-run_on_cpu.py b/examples/302-device-run_on_cpu.py
index 097d0b8..b7904fa 100644
--- a/examples/302-device-run_on_cpu.py
+++ b/examples/302-device-run_on_cpu.py
@@ -1,13 +1,13 @@
 import sdkit
-from sdkit.models import load_model
 from sdkit.generate import generate_images
+from sdkit.models import load_model
 
 context = sdkit.Context()
-context.model_paths['stable-diffusion'] = 'D:\\path\\to\\sd-v1-4.ckpt'
-context.device = 'cpu'
+context.model_paths["stable-diffusion"] = "D:\\path\\to\\sd-v1-4.ckpt"
+context.device = "cpu"
 
-load_model(context, 'stable-diffusion')
+load_model(context, "stable-diffusion")
 
 # generate image
-images = generate_images(context, prompt='Photograph of an astronaut riding a horse', seed=42, width=512, height=512)
-images[0].save('image_from_cpu.jpg')
+images = generate_images(context, prompt="Photograph of an astronaut riding a horse", seed=42, width=512, height=512)
+images[0].save("image_from_cpu.jpg")
diff --git a/examples/304-device-run_on_multiple_gpus.py b/examples/304-device-run_on_multiple_gpus.py
index f980e41..5470074 100644
--- a/examples/304-device-run_on_multiple_gpus.py
+++ b/examples/304-device-run_on_multiple_gpus.py
@@ -1,35 +1,40 @@
+import threading
+
 import sdkit
-from sdkit.models import load_model
 from sdkit.generate import generate_images
+from sdkit.models import load_model
 from sdkit.utils import log
-import threading
 
 
 def render_thread(device):
-    log.info(f'starting on device {device}')
+    log.info(f"starting on device {device}")
 
     context = sdkit.Context()
-    context.model_paths['stable-diffusion'] = 'D:\\path\\to\\sd-v1-4.ckpt'
+    context.model_paths["stable-diffusion"] = "D:\\path\\to\\sd-v1-4.ckpt"
     context.device = device
 
-    load_model(context, 'stable-diffusion')
+    load_model(context, "stable-diffusion")
 
     # generate image
-    log.info(f'generating on device {device}')
-    images = generate_images(context, prompt='Photograph of an astronaut riding a horse', seed=42, width=512, height=512)
-    images[0].save(f'image_from_{device}.jpg')
+    log.info(f"generating on device {device}")
+    images = generate_images(
+        context, prompt="Photograph of an astronaut riding a horse", seed=42, width=512, height=512
+    )
+    images[0].save(f"image_from_{device}.jpg")
+
+    log.info(f"finished generating on device {device}")
 
-    log.info(f'finished generating on device {device}')
 
 def start_thread(device):
-    thread = threading.Thread(target=render_thread, kwargs={'device': device})
+    thread = threading.Thread(target=render_thread, kwargs={"device": device})
     thread.daemon = True
-    thread.name = f'SD-{device}'
+    thread.name = f"SD-{device}"
    thread.start()
     return thread
 
+
 # assuming the PC has two CUDA-compatible GPUs, start on the first two GPUs: cuda:0 and cuda:1
-t0 = start_thread('cuda:0')
-t1 = start_thread('cuda:1')
+t0 = start_thread("cuda:0")
+t1 = start_thread("cuda:1")
 
 t0.join()
 t1.join()
diff --git a/examples/305-device-full_precision.py b/examples/305-device-full_precision.py
index 0826027..9d9e529 100644
--- a/examples/305-device-full_precision.py
+++ b/examples/305-device-full_precision.py
@@ -1,13 +1,13 @@
 import sdkit
-from sdkit.models import load_model
 from sdkit.generate import generate_images
+from sdkit.models import load_model
 
 context = sdkit.Context()
-context.half_precision = False # loads in full precision (i.e. float32, instead of float16). consumes more VRAM
-context.model_paths['stable-diffusion'] = 'D:\\path\\to\\sd-v1-4.ckpt'
+context.half_precision = False  # loads in full precision (i.e. float32, instead of float16). consumes more VRAM
+context.model_paths["stable-diffusion"] = "D:\\path\\to\\sd-v1-4.ckpt"
 
-load_model(context, 'stable-diffusion')
+load_model(context, "stable-diffusion")
 
 # generate image
-images = generate_images(context, prompt='Photograph of an astronaut riding a horse', seed=42, width=512, height=512)
-images[0].save(f'image.jpg')
+images = generate_images(context, prompt="Photograph of an astronaut riding a horse", seed=42, width=512, height=512)
+images[0].save("image.jpg")
diff --git a/examples/400-security-scan_model.py b/examples/400-security-scan_model.py
index ed8cc9b..2b60e25 100644
--- a/examples/400-security-scan_model.py
+++ b/examples/400-security-scan_model.py
@@ -6,10 +6,16 @@
 # lots of models, and then disable model-scanning each time a model is loaded
 # to improve load time.
 
-model_path = 'D:\\path\\to\\malicious_model.ckpt'
+model_path = "D:\\path\\to\\malicious_model.ckpt"
 
 scan_result = scan_model(model_path)
 if scan_result.issues_count > 0 or scan_result.infected_files > 0:
-    log.warn(f":warning: [bold red]Scan %s: %d scanned, %d issue, %d infected.[/bold red]" % (model_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files))
+    log.warn(
+        ":warning: [bold red]Scan %s: %d scanned, %d issue, %d infected.[/bold red]"
+        % (model_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files)
+    )
 else:
-    log.debug("Scan %s: [green]%d scanned, %d issue, %d infected.[/green]" % (model_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files))
+    log.debug(
+        "Scan %s: [green]%d scanned, %d issue, %d infected.[/green]"
+        % (model_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files)
+    )
diff --git a/pyproject.toml b/pyproject.toml
index c85de9e..49a31a1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -23,3 +23,23 @@ keywords = ["stable diffusion", "ai", "art"]
 [project.urls]
 "Homepage" = "https://github.com/easydiffusion/sdkit"
 "Bug Tracker" = "https://github.com/easydiffusion/sdkit/issues"
+
+[tool.isort]
+profile = "black"
+
+[tool.black]
+line-length = 120
+include = '\.pyi?$'
+exclude = '''
+/(
+    \.git
+  | \.hg
+  | \.mypy_cache
+  | \.tox
+  | \.venv
+  | _build
+  | buck-out
+  | build
+  | dist
+)/
+'''
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..e69de29
diff --git a/requirements_dev.txt b/requirements_dev.txt
new file mode 100644
index 0000000..e20fa42
--- /dev/null
+++ b/requirements_dev.txt
@@ -0,0 +1,5 @@
+-r requirements.txt
+black==23.1.0
+isort==5.12.0
+flake8==6.0.0
+autoflake==2.0.1
diff --git a/scripts/compare_benchmarks.py b/scripts/compare_benchmarks.py
index bce8e3e..72849bb 100644
--- a/scripts/compare_benchmarks.py
+++ b/scripts/compare_benchmarks.py
@@ -1,22 +1,22 @@
 import sys
-import pandas
+
 # import seaborn as sns
-import matplotlib.pyplot as plt
+import pandas
 
 stable = pandas.read_csv(sys.argv[1])
 beta = pandas.read_csv(sys.argv[2])
 
-del stable['vram_usage'], stable['ram_usage']
-del beta['vram_usage'], beta['ram_usage']
+del stable["vram_usage"], stable["ram_usage"]
+del beta["vram_usage"], beta["ram_usage"]
 
-del stable['render_test'], stable['vram_tp90'], stable['vram_spike_test'], stable['overall_status']
-del beta['render_test'], beta['vram_tp90'], beta['vram_spike_test'], beta['overall_status']
+del stable["render_test"], stable["vram_tp90"], stable["vram_spike_test"], stable["overall_status"]
+del beta["render_test"], beta["vram_tp90"], beta["vram_spike_test"], beta["overall_status"]
 
-bad = (stable['max_vram (GB)'] < beta['max_vram (GB)']) # & (stable['model_filename'] == 'sd-v1-4.ckpt')
-good = (stable['max_vram (GB)'] > beta['max_vram (GB)']) # & (stable['model_filename'] == 'sd-v1-4.ckpt')
+bad = stable["max_vram (GB)"] < beta["max_vram (GB)"]  # & (stable['model_filename'] == 'sd-v1-4.ckpt')
+good = stable["max_vram (GB)"] > beta["max_vram (GB)"]  # & (stable['model_filename'] == 'sd-v1-4.ckpt')
 
-print('bad results', bad)
-print('good results', good)
+print("bad results", bad)
+print("good results", good)
 
 # combined = pandas.concat([stable, beta], axis=0, ignore_index=False)
 # combined['version'] = (len(stable)*('v2.4',) + len(beta)*('v2.5',))
diff --git a/scripts/download_all_models.py b/scripts/download_all_models.py
index e40c441..5064db3 100644
--- a/scripts/download_all_models.py
+++ b/scripts/download_all_models.py
@@ -1,21 +1,30 @@
+import argparse
 import os
 import sys
-import argparse
+
 import requests
 
 # args
 parser = argparse.ArgumentParser()
-parser.add_argument('--models-dir', required=True, help="Folder path where the models will be downloaded, with a subdir for each model type")
-parser.add_argument('--hash-only', action='store_true', default=False, help="Don't download, just calculate the hashes of the models in the downloaded dir")
+parser.add_argument(
+    "--models-dir",
+    required=True,
+    help="Folder path where the models will be downloaded, with a subdir for each model type",
+)
+parser.add_argument(
+    "--hash-only",
+    action="store_true",
+    default=False,
+    help="Don't download, just calculate the hashes of the models in the downloaded dir",
+)
 args = parser.parse_args()
 
 if len(sys.argv) < 2:
-    print('Error: need to provide a folder path as the first argument')
+    print("Error: need to provide a folder path as the first argument")
     exit(1)
 
 # setup
-from sdkit.models import download_model, resolve_downloaded_model_path
-from sdkit.models import get_models_db
+from sdkit.models import download_model, get_models_db, resolve_downloaded_model_path
 from sdkit.utils import hash_file_quick
 
 db = get_models_db()
@@ -31,16 +40,20 @@
     if not args.hash_only:
         download_model(model_type, model_id, download_base_dir=download_dir)
 
-    model_path = resolve_downloaded_model_path(model_type=model_type, model_id=model_id, download_base_dir=download_dir)
+    model_path = resolve_downloaded_model_path(
+        model_type=model_type, model_id=model_id, download_base_dir=download_dir
+    )
     quick_hash = hash_file_quick(model_path)
 
     model_info = db[model_type][model_id]
-    expected_quick_hash = model_info['quick_hash']
-    expected_size = int(requests.get(model_info['url'], stream=True).headers['content-length'])
+    expected_quick_hash = model_info["quick_hash"]
+    expected_size = int(requests.get(model_info["url"], stream=True).headers["content-length"])
     actual_size = os.path.getsize(model_path)
 
     if quick_hash != expected_quick_hash:
-        print(f'''ERROR! {model_type} {model_id}:
+        print(
+            f"""ERROR! {model_type} {model_id}:
               expected hash:\t{expected_quick_hash}
               actual:\t\t{quick_hash}
               expected size:\t{expected_size}
-              actual size:\t\t{actual_size}''')
+              actual size:\t\t{actual_size}"""
+        )
diff --git a/scripts/print_quick_hashes.py b/scripts/print_quick_hashes.py
index d0437c2..873e47c 100644
--- a/scripts/print_quick_hashes.py
+++ b/scripts/print_quick_hashes.py
@@ -1,14 +1,18 @@
-'''
+"""
 Utility script for calculating quick hashes for all the entries
 in the models db.
 
 Usage: python print_quick_hashes.py --help
-'''
+"""
 
 import argparse
 
 # args
 parser = argparse.ArgumentParser(description="arg parser")
-parser.add_argument("--diff-only", action="store_true", help="Only show entries if the calculated quick-hash doesn't match the stored quick-hash")
+parser.add_argument(
+    "--diff-only",
+    action="store_true",
+    help="Only show entries if the calculated quick-hash doesn't match the stored quick-hash",
+)
 parser.set_defaults(diff_only=False)
 args = parser.parse_args()
 
@@ -20,17 +24,17 @@
 hashes_found = {}
 
 if args.diff_only:
-    print('Printing quick-hashes for only those URLs that do not match the configured quick-hash')
+    print("Printing quick-hashes for only those URLs that do not match the configured quick-hash")
 
 for model_type, models in models_db.items():
-    print(f'{model_type} models:')
+    print(f"{model_type} models:")
     for model_id, model_info in models.items():
-        url = model_info['url']
+        url = model_info["url"]
         quick_hash = hash_url_quick(url)
-        if not args.diff_only or quick_hash != model_info.get('quick_hash'):
-            print(f'{model_id} = {quick_hash}')
+        if not args.diff_only or quick_hash != model_info.get("quick_hash"):
+            print(f"{model_id} = {quick_hash}")
 
         if quick_hash in hashes_found:
-            print(f'HASH CONFLICT! {quick_hash} already maps to {hashes_found[quick_hash]}')
+            print(f"HASH CONFLICT! {quick_hash} already maps to {hashes_found[quick_hash]}")
         else:
             hashes_found[quick_hash] = url
diff --git a/scripts/run_everything.py b/scripts/run_everything.py
index eae6f4c..3f179a4 100644
--- a/scripts/run_everything.py
+++ b/scripts/run_everything.py
@@ -1,67 +1,110 @@
-'''
+"""
 Runs the desired Stable Diffusion models against the desired samplers.
 
 Saves the output images to disk, along with the peak RAM and VRAM usage,
 as well as the sampling performance.
-'''
+"""
+import argparse
 import os
 import time
-import argparse
-from sdkit.utils import log, get_device_usage
+
+from sdkit.utils import get_device_usage, log
 
 # args
 parser = argparse.ArgumentParser()
-parser.add_argument('--prompt', type=str, default="Photograph of an astronaut riding a horse", help="Prompt to use for generating the image")
-parser.add_argument('--seed', type=int, default=42, help="Seed to use for generating the image")
-parser.add_argument('--models-dir', type=str, required=True, help="Path to the directory containing the Stable Diffusion models")
-parser.add_argument('--out-dir', type=str, required=True, help="Path to the directory to save the generated images and test results")
-parser.add_argument('--models', default='all', help="Comma-separated list of model filenames (without spaces) to test. Default: all")
-parser.add_argument('--exclude-models', default=set(), help="Comma-separated list of model filenames (without spaces) to skip. Supports wildcards (without commas), for e.g. --exclude-models *.safetensors, or --exclude-models sd-1-4*")
-parser.add_argument('--samplers', default='all', help="Comma-separated list of sampler names (without spaces) to test. Default: all")
-parser.add_argument('--exclude-samplers', default=set(), help="Comma-separated list of sampler names (without spaces) to skip")
-parser.add_argument('--init-image', default=None, help="Path to an initial image to use. Only works with DDIM sampler (for now).")
-parser.add_argument('--vram-usage-levels', default='balanced', help="Comma-separated list of VRAM usage levels. Allowed values: low, balanced, high")
-parser.add_argument('--skip-completed', default=False, help="Skips a model or sampler if it has already been tested (i.e. an output image exists for it)")
-parser.add_argument('--steps', default=25, type=int, help="Number of inference steps to run for each sampler")
-parser.add_argument('--sizes', default='auto', type=str, help="Comma-separated list of image sizes (width x height). No spaces. E.g. 512x512 or 512x512,1024x768. Defaults to what the model needs (512x512 or 768x768, if the model requires 768)")
-parser.add_argument('--device', default='cuda:0', type=str, help="Specify the device to run on. E.g. cpu or cuda:0 or cuda:1 etc")
-parser.add_argument('--live-perf', action="store_true", help="Print the RAM and VRAM usage stats every few seconds")
+parser.add_argument(
+    "--prompt",
+    type=str,
+    default="Photograph of an astronaut riding a horse",
+    help="Prompt to use for generating the image",
+)
+parser.add_argument("--seed", type=int, default=42, help="Seed to use for generating the image")
+parser.add_argument(
+    "--models-dir", type=str, required=True, help="Path to the directory containing the Stable Diffusion models"
+)
+parser.add_argument(
+    "--out-dir", type=str, required=True, help="Path to the directory to save the generated images and test results"
+)
+parser.add_argument(
+    "--models", default="all", help="Comma-separated list of model filenames (without spaces) to test. Default: all"
+)
+parser.add_argument(
+    "--exclude-models",
+    default=set(),
+    help="Comma-separated list of model filenames (without spaces) to skip. Supports wildcards (without commas), e.g. --exclude-models *.safetensors, or --exclude-models sd-1-4*",
+)
+parser.add_argument(
+    "--samplers", default="all", help="Comma-separated list of sampler names (without spaces) to test. Default: all"
+)
+parser.add_argument(
+    "--exclude-samplers", default=set(), help="Comma-separated list of sampler names (without spaces) to skip"
+)
+parser.add_argument(
+    "--init-image", default=None, help="Path to an initial image to use. Only works with DDIM sampler (for now)."
+)
+parser.add_argument(
+    "--vram-usage-levels",
+    default="balanced",
+    help="Comma-separated list of VRAM usage levels. Allowed values: low, balanced, high",
+)
+parser.add_argument(
+    "--skip-completed",
+    default=False,
+    help="Skips a model or sampler if it has already been tested (i.e. an output image exists for it)",
+)
+parser.add_argument("--steps", default=25, type=int, help="Number of inference steps to run for each sampler")
+parser.add_argument(
+    "--sizes",
+    default="auto",
+    type=str,
+    help="Comma-separated list of image sizes (width x height). No spaces. E.g. 512x512 or 512x512,1024x768. Defaults to what the model needs (512x512 or 768x768, if the model requires 768)",
+)
+parser.add_argument(
+    "--device", default="cuda:0", type=str, help="Specify the device to run on. E.g. cpu or cuda:0 or cuda:1 etc"
+)
+parser.add_argument("--live-perf", action="store_true", help="Print the RAM and VRAM usage stats every few seconds")
 parser.set_defaults(live_perf=False)
 args = parser.parse_args()
 
-if args.models != 'all': args.models = set(args.models.split(','))
-if args.exclude_models != set() and '*' not in args.exclude_models: args.exclude_models = set(args.exclude_models.split(','))
-if args.samplers != 'all': args.samplers = set(args.samplers.split(','))
-if args.exclude_samplers != set(): args.exclude_samplers = set(args.exclude_samplers.split(','))
-if args.sizes != 'auto': args.sizes = [tuple(map(lambda x: int(x), size.split('x'))) for size in args.sizes.split(',')]
+if args.models != "all":
+    args.models = set(args.models.split(","))
+if args.exclude_models != set() and "*" not in args.exclude_models:
+    args.exclude_models = set(args.exclude_models.split(","))
+if args.samplers != "all":
+    args.samplers = set(args.samplers.split(","))
+if args.exclude_samplers != set():
+    args.exclude_samplers = set(args.exclude_samplers.split(","))
+if args.sizes != "auto":
+    args.sizes = [tuple(map(lambda x: int(x), size.split("x"))) for size in args.sizes.split(",")]
 
 # setup
-log.info('Starting..')
-from sdkit.models import load_model
+log.info("Starting..")
 from sdkit.generate.sampler import default_samplers, k_samplers
+from sdkit.models import load_model
 
-sd_models = set([f for f in os.listdir(args.models_dir) if os.path.splitext(f)[1] in ('.ckpt', '.safetensors')])
+sd_models = set([f for f in os.listdir(args.models_dir) if os.path.splitext(f)[1] in (".ckpt", ".safetensors")])
 all_samplers = set(default_samplers.samplers.keys()) | set(k_samplers.samplers.keys())
-args.vram_usage_levels = args.vram_usage_levels.split(',')
+args.vram_usage_levels = args.vram_usage_levels.split(",")
 
-if isinstance(args.exclude_models, str) and '*' in args.exclude_models:
+if isinstance(args.exclude_models, str) and "*" in args.exclude_models:
     import fnmatch
+
     args.exclude_models = set(fnmatch.filter(sd_models, args.exclude_models))
 
-models_to_test = sd_models if args.models == 'all' else args.models
+models_to_test = sd_models if args.models == "all" else args.models
 models_to_test -= args.exclude_models
-samplers_to_test = all_samplers if args.samplers == 'all' else args.samplers
+samplers_to_test = all_samplers if args.samplers == "all" else args.samplers
 samplers_to_test -= args.exclude_samplers
 vram_usage_levels_to_test = args.vram_usage_levels
 
 if args.init_image is not None:
     if not os.path.exists(args.init_image):
-        log.error(f'Error! Could not an initial image at the path specified: {args.init_image}')
+        log.error(f"Error! Could not find an initial image at the path specified: {args.init_image}")
         exit(1)
-    if samplers_to_test != {'ddim'}:
+    if samplers_to_test != {"ddim"}:
         log.error('We only support the "ddim" sampler for img2img right now!')
         exit(1)
-    all_samplers = {'ddim'}
+    all_samplers = {"ddim"}
 
 # setup the test
 from sdkit import Context
@@ -69,19 +112,34 @@
 from sdkit.models import get_model_info_from_db
 from sdkit.utils import hash_file_quick
 
-perf_results = [['model_filename', 'vram_usage_level', 'sampler_name', 'max_ram (GB)', 'max_vram (GB)', 'image_size', 'time_taken (s)', 'speed (it/s)', 'render_test', 'ram_usage', 'vram_usage']]
-perf_results_file = f'perf_results_{time.time()}.csv'
+perf_results = [
+    [
+        "model_filename",
+        "vram_usage_level",
+        "sampler_name",
+        "max_ram (GB)",
+        "max_vram (GB)",
+        "image_size",
+        "time_taken (s)",
+        "speed (it/s)",
+        "render_test",
+        "ram_usage",
+        "vram_usage",
+    ]
+]
+perf_results_file = f"perf_results_{time.time()}.csv"
 
 # print test info
-log.info('---')
-log.info(f'Models actually being tested: {models_to_test}')
-log.info(f'Samplers actually being tested: {samplers_to_test}')
-log.info(f'VRAM usage levels being tested: {vram_usage_levels_to_test}')
-log.info(f'Image sizes being tested: {args.sizes}')
-log.info('---')
-log.info(f'Available models: {sd_models}')
-log.info(f'Available samplers: {all_samplers}')
-log.info('---')
+log.info("---")
+log.info(f"Models actually being tested: {models_to_test}")
+log.info(f"Samplers actually being tested: {samplers_to_test}")
+log.info(f"VRAM usage levels being tested: {vram_usage_levels_to_test}")
+log.info(f"Image sizes being tested: {args.sizes}")
+log.info("---")
+log.info(f"Available models: {sd_models}")
+log.info(f"Available samplers: {all_samplers}")
+log.info("---")
+
 
 # run the test
 def run_test():
@@ -89,7 +147,7 @@ def run_test():
         model_dir_path = os.path.join(args.out_dir, model_filename)
 
         if args.skip_completed and is_model_already_tested(model_dir_path):
-            log.info(f'skipping model {model_filename} since it has already been processed at {model_dir_path}')
+            log.info(f"skipping model {model_filename} since it has already been processed at {model_dir_path}")
             continue
 
         for vram_usage_level in vram_usage_levels_to_test:
@@ -102,24 +160,33 @@ def run_test():
             os.makedirs(out_dir_path, exist_ok=True)
 
             try:
-                context.model_paths['stable-diffusion'] = os.path.join(args.models_dir, model_filename)
-                load_model(context, 'stable-diffusion', scan_model=False)
+                context.model_paths["stable-diffusion"] = os.path.join(args.models_dir, model_filename)
+                load_model(context, "stable-diffusion", scan_model=False)
             except Exception as e:
                 log.exception(e)
-                perf_results.append([model_filename, vram_usage_level, 'n/a', 'n/a', 'n/a', 'n/a', 'n/a', 'n/a', False, [], []])
+                perf_results.append(
+                    [model_filename, vram_usage_level, "n/a", "n/a", "n/a", "n/a", "n/a", "n/a", False, [], []]
+                )
                 log_perf_results()
                 continue
 
-            # run a warm-up, before running the actual samplers
            log.info("Warming up..")
             try:
-                generate_images(context, prompt='Photograph of an astronaut riding a horse', num_inference_steps=10, seed=42, width=512, height=512, sampler_name='euler_a')
+                generate_images(
+                    context,
+                    prompt="Photograph of an astronaut riding a horse",
+                    num_inference_steps=10,
+                    seed=42,
+                    width=512,
+                    height=512,
+                    sampler_name="euler_a",
+                )
             except:
                 pass
 
-            if args.sizes == 'auto':
-                min_size = get_min_size(context.model_paths['stable-diffusion'])
+            if args.sizes == "auto":
+                min_size = get_min_size(context.model_paths["stable-diffusion"])
                 sizes = [(min_size, min_size)]
             else:
                 sizes = args.sizes
@@ -130,29 +197,35 @@
 
         del context
 
+
 def run_samplers(context, model_filename, out_dir_path, width, height, vram_usage_level):
-    from threading import Thread, Event
     from queue import Queue
+    from threading import Event, Thread
 
     for sampler_name in samplers_to_test:
         # setup
-        img_path = os.path.join(out_dir_path, f'{sampler_name}_0.jpeg')
+        img_path = os.path.join(out_dir_path, f"{sampler_name}_0.jpeg")
         if args.skip_completed and os.path.exists(img_path):
-            log.info(f'skipping sampler {sampler_name} since it has already been processed at {img_path}')
+            log.info(f"skipping sampler {sampler_name} since it has already been processed at {img_path}")
             continue
 
-        log.info(f'Model: {model_filename}, Sampler: {sampler_name}, Size: {width}x{height}, VRAM Usage Level: {vram_usage_level}')
+        log.info(
+            f"Model: {model_filename}, Sampler: {sampler_name}, Size: {width}x{height}, VRAM Usage Level: {vram_usage_level}"
+        )
 
         # start profiling
         prof_thread_stop_event = Event()
         ram_usage = Queue()
        vram_usage = Queue()
-        prof_thread = Thread(target=profiling_thread, kwargs={
-            'device': context.device,
-            'prof_thread_stop_event': prof_thread_stop_event,
-            'ram_usage': ram_usage,
-            'vram_usage': vram_usage,
-        })
+        prof_thread = Thread(
+            target=profiling_thread,
+            kwargs={
+                "device": context.device,
+                "prof_thread_stop_event": prof_thread_stop_event,
+                "ram_usage": ram_usage,
+                "vram_usage": vram_usage,
+            },
+        )
         prof_thread.start()
 
         t, speed = time.time(), 0
@@ -164,7 +237,8 @@ def run_samplers(context, model_filename, out_dir_path, width, height, vram_usag
                 prompt=args.prompt,
                 seed=args.seed,
                 num_inference_steps=args.steps,
-                width=width, height=height,
+                width=width,
+                height=height,
                 sampler_name=sampler_name,
                 init_image=args.init_image,
             )
@@ -182,10 +256,25 @@
         prof_thread_stop_event.set()
         prof_thread.join()
 
-        perf_results.append([model_filename, vram_usage_level, sampler_name, f'{max(ram_usage.queue):.1f}', f'{max(vram_usage.queue):.1f}', f'{width}x{height}', f'{t:.1f}', f'{speed:.1f}', render_success, list(ram_usage.queue), list(vram_usage.queue)])
+        perf_results.append(
+            [
+                model_filename,
+                vram_usage_level,
+                sampler_name,
+                f"{max(ram_usage.queue):.1f}",
+                f"{max(vram_usage.queue):.1f}",
+                f"{width}x{height}",
+                f"{t:.1f}",
+                f"{speed:.1f}",
+                render_success,
+                list(ram_usage.queue),
+                list(vram_usage.queue),
+            ]
+        )
 
         log_perf_results()
 
+
 def profiling_thread(device, prof_thread_stop_event, ram_usage, vram_usage):
     import time
 
@@ -197,60 +286,65 @@
 
         time.sleep(1)
 
+
 def is_model_already_tested(out_dir_path):
     if not os.path.exists(out_dir_path):
         return False
 
-    sampler_files = list(map(lambda x: os.path.join(out_dir_path, f'{x}_0.jpeg'), samplers_to_test))
+    sampler_files = list(map(lambda x: os.path.join(out_dir_path, f"{x}_0.jpeg"), samplers_to_test))
     images_exist = list(map(lambda x: os.path.exists(x), sampler_files))
     return all(images_exist)
 
+
 def get_min_size(model_path, default_size=512):
     model_info = get_model_info_from_db(quick_hash=hash_file_quick(model_path))
-    return model_info['metadata']['min_size'] if model_info is not None else default_size
+    return model_info["metadata"]["min_size"] if model_info is not None else default_size
 
+
 def log_perf_results():
-    import pandas as pd
-    import numpy as np
     from importlib.metadata import version
 
-    pd.set_option('display.max_rows', 1000)
+    import numpy as np
+    import pandas as pd
 
-    print('\n-- Performance summary --')
+    pd.set_option("display.max_rows", 1000)
+
+    print("\n-- Performance summary --")
     print(f"sdkit version: {version('sdkit')}")
     print(f"stable-diffusion-sdkit version: {version('stable-diffusion-sdkit')}")
-    print(f'Device: {args.device}')
-    print(f'Num inference steps: {args.steps}')
-    print('')
+    print(f"Device: {args.device}")
+    print(f"Num inference steps: {args.steps}")
+    print("")
 
     df = pd.DataFrame(data=perf_results)
     df = df.rename(columns=df.iloc[0]).drop(df.index[0])
-    df = df.sort_values(by=['image_size', 'model_filename'], ascending=False)
+    df = df.sort_values(by=["image_size", "model_filename"], ascending=False)
     df = df.reset_index(drop=True)
 
-    df['vram_tp90'] = df['vram_usage'].apply(lambda x: np.percentile(x, 90))
-    df['vram_tp100'] = df['vram_usage'].apply(lambda x: np.percentile(x, 100))
-    df['vram_spike_test'] = abs((df['vram_tp100'] - df['vram_tp90'])) < 0.5 # okay with a spike of 500 MB
-    df['vram_tp90'] = df['vram_tp90'].apply(lambda x: f'{x:.1f}')
-    df['overall_status'] = df['render_test'] & df['vram_spike_test']
+    df["vram_tp90"] = df["vram_usage"].apply(lambda x: np.percentile(x, 90))
+    df["vram_tp100"] = df["vram_usage"].apply(lambda x: np.percentile(x, 100))
+    df["vram_spike_test"] = abs((df["vram_tp100"] - df["vram_tp90"])) < 0.5  # okay with a spike of 500 MB
+    df["vram_tp90"] = df["vram_tp90"].apply(lambda x: f"{x:.1f}")
+    df["overall_status"] = df["render_test"] & df["vram_spike_test"]
 
-    df['vram_spike_test'] = df['vram_spike_test'].apply(lambda is_pass: 'pass' if is_pass else 'FAIL')
-    df['render_test'] = df['render_test'].apply(lambda is_pass: 'pass' if is_pass else 'FAIL')
-    df['overall_status'] = df['overall_status'].apply(lambda is_pass: 'pass' if is_pass else 'FAIL')
+    df["vram_spike_test"] = df["vram_spike_test"].apply(lambda is_pass: "pass" if is_pass else "FAIL")
+    df["render_test"] = df["render_test"].apply(lambda is_pass: "pass" if is_pass else "FAIL")
+    df["overall_status"] = df["overall_status"].apply(lambda is_pass: "pass" if is_pass else "FAIL")
 
-    del df['vram_tp100']
+    del df["vram_tp100"]
 
     out_file = os.path.join(args.out_dir, perf_results_file)
     df.to_csv(out_file, index=False)
 
     # print the summary
-    del df['vram_usage']
-    del df['ram_usage']
+    del df["vram_usage"]
+    del df["ram_usage"]
     print(df)
-    print('')
+    print("")
+
+    print(f"Written the performance summary to {out_file}\n")
 
-    print(f'Written the performance summary to {out_file}\n')
+
 run_test()
diff --git a/scripts/txt2img_original.py b/scripts/txt2img_original.py
index c4d32cd..7871232 100644
--- a/scripts/txt2img_original.py
+++ b/scripts/txt2img_original.py
@@ -1,31 +1,30 @@
-'''
+"""
 A simplified version of the original txt2img.py script
 that is included with Stable Diffusion 2.0.
 
 Useful for testing responses and memory usage
 against the original script.
-'''
+"""
+
+from itertools import islice
 
-import torch
 import numpy as np
+import torch
+from einops import rearrange
+from ldm.models.diffusion.plms import PLMSSampler
+from ldm.util import instantiate_from_config
 from omegaconf import OmegaConf
 from PIL import Image
-from tqdm import tqdm, trange
-from itertools import islice
-from einops import rearrange
 from pytorch_lightning import seed_everything
 from torch import autocast
-from contextlib import nullcontext
-
-from ldm.util import instantiate_from_config
-from ldm.models.diffusion.ddim import DDIMSampler
-from ldm.models.diffusion.plms import PLMSSampler
-from ldm.models.diffusion.dpm_solver import DPMSolverSampler
+from tqdm import tqdm, trange
 
 torch.set_grad_enabled(False)
 
+
 def chunk(it, size):
     it = iter(it)
     return iter(lambda: tuple(islice(it, size)), ())
 
+
 def load_model_from_config(config, ckpt, verbose=False):
     print(f"Loading model from {ckpt}")
     pl_sd = torch.load(ckpt, map_location="cpu")
@@ -45,11 +44,12 @@ def load_model_from_config(config, ckpt, verbose=False):
     model.eval()
     return model
 
+
 def main():
     seed_everything(42)
 
     config = OmegaConf.load("path/to/models/stable-diffusion/v1-inference.yaml")
-    model = load_model_from_config(config, f"path/to/models/stable-diffusion/sd-v1-4.ckpt")
+    model = load_model_from_config(config, "path/to/models/stable-diffusion/sd-v1-4.ckpt")
 
     device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
     model = model.to(device)
@@ -67,37 +67,37 @@ def main():
 
     for i in range(4):
         precision_scope = autocast
-        with torch.no_grad(), \
-            precision_scope("cuda"), \
-            model.ema_scope():
-                for n in trange(1, desc="Sampling"):
-                    for prompts in tqdm(data, desc="data"):
-                        uc = None
-                        uc = model.get_learned_conditioning(batch_size * [""])
-                        if isinstance(prompts, tuple):
-                            prompts = list(prompts)
-                        c = model.get_learned_conditioning(prompts)
-                        shape = [4, 2048 // 8, 2048 // 8]
-                        try:
-                            samples, _ = sampler.sample(S=1,
-                                conditioning=c,
-                                batch_size=1,
-                                shape=shape,
-                                verbose=False,
-                                unconditional_guidance_scale=7.5,
-                                unconditional_conditioning=uc,
-                                eta=0.,
-                                x_T=start_code)
-
-                            x_samples = model.decode_first_stage(samples)
-                            x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
-
-                            for x_sample in x_samples:
-                                x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
-                                img = Image.fromarray(x_sample.astype(np.uint8))
-                                base_count += 1
-                                sample_count += 1
-                        except Exception as e:
-                            print(e)
+        with torch.no_grad(), precision_scope("cuda"), model.ema_scope():
+            for n in trange(1, desc="Sampling"):
+                for prompts in tqdm(data, desc="data"):
+                    uc = None
+                    uc = model.get_learned_conditioning(batch_size * [""])
+                    if isinstance(prompts, tuple):
+                        prompts = list(prompts)
+                    c = model.get_learned_conditioning(prompts)
+                    shape = [4, 2048 // 8, 2048 // 8]
+                    try:
+                        samples, _ = sampler.sample(
+                            S=1,
+                            conditioning=c,
+                            batch_size=1,
+                            shape=shape,
+                            verbose=False,
+                            unconditional_guidance_scale=7.5,
+                            unconditional_conditioning=uc,
+                            eta=0.0,
+                            x_T=start_code,
+                        )
+
+                        x_samples = model.decode_first_stage(samples)
+                        x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
+
+                        for x_sample in x_samples:
+                            x_sample = 255.0 * rearrange(x_sample.cpu().numpy(), "c h w -> h w c")
+                            Image.fromarray(x_sample.astype(np.uint8))
+                            sample_count += 1
+                    except Exception as e:
+                        print(e)
+
 
 main()
diff --git a/sdkit/__init__.py b/sdkit/__init__.py
index 51dc31e..b283749 100644
--- a/sdkit/__init__.py
+++ b/sdkit/__init__.py
@@ -1,8 +1,9 @@
 from threading import local
 
+
 class Context(local):
     def __init__(self) -> None:
-        self._device: str = 'cuda:0'
+        self._device: str = "cuda:0"
         self._half_precision: bool = True
         self._vram_usage_level = None
 
@@ -12,7 +13,7 @@ def __init__(self) -> None:
         self.device_name: str = None
         self.vram_optimizations: set = set()
-        '''
+        """
         **Do not change this unless you know what you're doing!** Instead set `context.vram_usage_level` to `'low'`, `'balanced'` or `'high'`.
 
         Possible values:
@@ -32,8 +33,8 @@ def __init__(self) -> None:
         it consumes about 1 GB more than `'KEEP_ENTIRE_MODEL_IN_CPU'`, for a much faster rendering performance.
 
         * `'SET_ATTENTION_STEP_TO_16'`: Very useful! Lowest GPU memory utilization, but slowest performance.
-        '''
-        self.vram_usage_level = 'balanced'
+        """
+        self.vram_usage_level = "balanced"
 
         # hacky approach, but we need to enforce full precision for some devices
         # we also need to force full precision for these devices (haven't implemented this yet):
@@ -45,9 +46,10 @@ def device(self):
     @device.setter
     def device(self, d):
         self._device = d
-        if d == 'cpu':
+        if d == "cpu":
             from sdkit.utils import log
-            log.info('forcing full precision for device: cpu')
+
+            log.info("forcing full precision for device: cpu")
             self._half_precision = False
 
     @property
@@ -56,7 +58,7 @@ def half_precision(self):
 
     @half_precision.setter
     def half_precision(self, h):
-        self._half_precision = h if self._device != 'cpu' else False
+        self._half_precision = h if self._device != "cpu" else False
 
     @property
     def vram_usage_level(self):
@@ -66,9 +68,9 @@ def vram_usage_level(self):
     def vram_usage_level(self, level):
         self._vram_usage_level = level
 
-        if level == 'low':
-            self.vram_optimizations = {'KEEP_ENTIRE_MODEL_IN_CPU', 'SET_ATTENTION_STEP_TO_16'}
-        elif level == 'balanced':
-            self.vram_optimizations = {'KEEP_FS_AND_CS_IN_CPU', 'SET_ATTENTION_STEP_TO_16'}
-        elif level == 'high':
-            self.vram_optimizations = {'SET_ATTENTION_STEP_TO_2'}
+        if level == "low":
+            self.vram_optimizations = {"KEEP_ENTIRE_MODEL_IN_CPU", "SET_ATTENTION_STEP_TO_16"}
+        elif level == "balanced":
+            self.vram_optimizations = {"KEEP_FS_AND_CS_IN_CPU", "SET_ATTENTION_STEP_TO_16"}
+        elif level == "high":
+            self.vram_optimizations = {"SET_ATTENTION_STEP_TO_2"}
diff --git a/sdkit/filter/apply_filters.py b/sdkit/filter/apply_filters.py
index 29300c1..4553cda 100644
--- a/sdkit/filter/apply_filters.py
+++ b/sdkit/filter/apply_filters.py
@@ -1,31 +1,33 @@
-from . import gfpgan, realesrgan
-
 from sdkit import Context
-from sdkit.utils import base64_str_to_img, log, gc
+from sdkit.utils import base64_str_to_img, gc, log
+
+from . import gfpgan, realesrgan
 
 filter_modules = {
-    'gfpgan': gfpgan,
-    'realesrgan': realesrgan,
+    "gfpgan": gfpgan,
+    "realesrgan": realesrgan,
 }
 
+
 def apply_filters(context: Context, filters, images, **kwargs):
-    '''
+    """
     * context: Context
     * filters: filter_type (string) or list of strings
     * images: str or PIL.Image or list of str/PIL.Image - image to filter. if a string is passed, it needs to be a base64-encoded image
 
     returns: PIL.Image - filtered image
-    '''
+    """
     images = images if isinstance(images, list) else [images]
     filters = filters if isinstance(filters, list) else [filters]
 
     return [apply_filter_single_image(context, filters, image, **kwargs) for image in images]
 
+
 def apply_filter_single_image(context, filters, image, **kwargs):
     image = base64_str_to_img(image) if isinstance(image, str) else image
 
     for filter_type in filters:
-        log.info(f'Applying {filter_type}...')
+        log.info(f"Applying {filter_type}...")
         gc(context)
         image = filter_modules[filter_type].apply(context, image, **kwargs)
diff --git a/sdkit/filter/gfpgan.py b/sdkit/filter/gfpgan.py
index 36f625b..d3150e8 100644
--- a/sdkit/filter/gfpgan.py
+++ b/sdkit/filter/gfpgan.py
@@ -1,24 +1,29 @@
-import torch
+from threading import Lock
+
 import numpy as np
+import torch
 from PIL import Image
-from threading import Lock
 
 from sdkit import Context
 
-gfpgan_temp_device_lock = Lock() # workaround: gfpgan currently can only start on one device at a time.
+gfpgan_temp_device_lock = Lock()  # workaround: gfpgan currently can only start on one device at a time.
+
 
 def apply(context: Context, image, **kwargs):
     # This lock is only ever used here. No need to use timeout for the request. Should never deadlock.
-    with gfpgan_temp_device_lock: # Wait for any other devices to complete before starting.
+    with gfpgan_temp_device_lock:  # Wait for any other devices to complete before starting.
         # hack for a bug in facexlib: https://github.com/xinntao/facexlib/pull/19/files
         from facexlib.detection import retinaface
+
         retinaface.device = torch.device(context.device)
 
-        image = image.convert('RGB')
-        image = np.array(image, dtype=np.uint8)[...,::-1]
+        image = image.convert("RGB")
+        image = np.array(image, dtype=np.uint8)[..., ::-1]
 
-        _, _, output = context.models['gfpgan'].enhance(image, has_aligned=False, only_center_face=False, paste_back=True)
-        output = output[:,:,::-1]
+        _, _, output = context.models["gfpgan"].enhance(
+            image, has_aligned=False, only_center_face=False, paste_back=True
+        )
+        output = output[:, :, ::-1]
         output = Image.fromarray(output)
 
-        return output
\ No newline at end of file
+        return output
diff --git a/sdkit/filter/realesrgan.py b/sdkit/filter/realesrgan.py
index a7f72c8..f8a7d2a 100644
--- a/sdkit/filter/realesrgan.py
+++ b/sdkit/filter/realesrgan.py
@@ -3,12 +3,13 @@
 
 from sdkit import Context
 
+
 def apply(context: Context, image, scale=4, **kwargs):
-    image = image.convert('RGB')
-    image = np.array(image, dtype=np.uint8)[...,::-1]
+    image = image.convert("RGB")
+    image = np.array(image, dtype=np.uint8)[..., ::-1]
 
-    output, _ = context.models['realesrgan'].enhance(image, outscale=scale)
-    output = output[:,:,::-1]
+    output, _ = context.models["realesrgan"].enhance(image, outscale=scale)
+    output = output[:, :, ::-1]
     output = Image.fromarray(output)
 
-    return output
\ No newline at end of file
+    return output
diff --git a/sdkit/generate/__init__.py b/sdkit/generate/__init__.py
index 11517ac..683614b 100644
--- a/sdkit/generate/__init__.py
+++ b/sdkit/generate/__init__.py
@@ -1,3 +1 @@
-from .image_generator import (
-    generate_images,
-)
+from .image_generator import generate_images
diff --git a/sdkit/generate/image_generator.py b/sdkit/generate/image_generator.py
index 87c3d7a..9c1f358 100644
--- a/sdkit/generate/image_generator.py
+++ b/sdkit/generate/image_generator.py
@@ -1,40 +1,42 @@
+from contextlib import nullcontext
+
 import torch
-from tqdm import trange
 from pytorch_lightning import seed_everything
-from contextlib import nullcontext
+from tqdm import trange
 
 from sdkit import Context
-from sdkit.utils import latent_samples_to_images, base64_str_to_img, get_image_latent_and_mask, apply_color_profile
-from sdkit.utils import gc
+from sdkit.utils import (
+    apply_color_profile,
+    base64_str_to_img,
+    gc,
+    get_image_latent_and_mask,
+    latent_samples_to_images,
+)
 
 from .prompt_parser import get_cond_and_uncond
 from .sampler import make_samples
 
+
 def generate_images(
-        context: Context,
-        prompt: str = "",
-        negative_prompt: str = "",
-
-        seed: int = 42,
-        width: int = 512,
-        height: int = 512,
-
-        num_outputs: int = 1,
-        num_inference_steps: int = 25,
-        guidance_scale: float = 7.5,
-
-        init_image = None,
-        init_image_mask = None,
-        prompt_strength: float = 0.8,
-        preserve_init_image_color_profile = False,
-
-        sampler_name: str = "euler_a", # "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms",
-                                       # "dpm_solver_stability", "dpmpp_2s_a", "dpmpp_2m", "dpmpp_sde", "dpm_fast"
-                                       # "dpm_adaptive"
-        hypernetwork_strength: float = 0,
-
-        callback=None,
-    ):
+    context: Context,
+    prompt: str = "",
+    negative_prompt: str = "",
+    seed: int = 42,
+    width: int = 512,
+    height: int = 512,
+    num_outputs: int = 1,
+    num_inference_steps: int = 25,
+    guidance_scale: float = 7.5,
+    init_image=None,
+    init_image_mask=None,
+    prompt_strength: float = 0.8,
+    preserve_init_image_color_profile=False,
+    sampler_name: str = "euler_a",  # "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms",
+    # "dpm_solver_stability", "dpmpp_2s_a", "dpmpp_2m", "dpmpp_sde", "dpm_fast"
+    # "dpm_adaptive"
+    hypernetwork_strength: float = 0,
+    callback=None,
+):
     req_args = locals()
 
     try:
@@ -43,27 +45,29 @@ def generate_images(
         seed_everything(seed)
         precision_scope = torch.autocast if context.half_precision and context.device != "cpu" else nullcontext
 
-        if 'stable-diffusion' not in context.models:
-            raise RuntimeError("The model for Stable Diffusion has not been loaded yet! If you've tried to load it, please check the logs above this message for errors (while loading the model).")
+        if "stable-diffusion" not in context.models:
+            raise RuntimeError(
+                "The model for Stable Diffusion has not been loaded yet! If you've tried to load it, please check the logs above this message for errors (while loading the model)."
+            )
 
-        model = context.models['stable-diffusion']
-        if 'hypernetwork' in context.models:
-            context.models['hypernetwork']['hypernetwork_strength'] = hypernetwork_strength
+        model = context.models["stable-diffusion"]
+        if "hypernetwork" in context.models:
+            context.models["hypernetwork"]["hypernetwork_strength"] = hypernetwork_strength
 
         with precision_scope("cuda"):
             cond, uncond = get_cond_and_uncond(prompt, negative_prompt, num_outputs, model)
 
         generate_fn = txt2img if init_image is None else img2img
         common_sampler_params = {
-            'context': context,
-            'sampler_name': sampler_name,
-            'seed': seed,
-            'batch_size': num_outputs,
-            'shape': [4, height // 8, width // 8],
-            'cond': cond,
-            'uncond': uncond,
-            'guidance_scale': guidance_scale,
-            'callback': callback,
+            "context": context,
+            "sampler_name": sampler_name,
+            "seed": seed,
+            "batch_size": num_outputs,
+            "shape": [4, height // 8, width // 8],
+            "cond": cond,
+            "uncond": uncond,
+            "guidance_scale": guidance_scale,
+            "callback": callback,
         }
 
         with torch.no_grad(), precision_scope("cuda"):
@@ -75,27 +79,47 @@ def generate_images(
     finally:
         context.init_image_latent, context.init_image_mask_tensor = None, None
 
+
 def txt2img(params: dict, context: Context, num_inference_steps, **kwargs):
-    params.update({
-        'steps': num_inference_steps,
-    })
+    params.update(
+        {
+            "steps": num_inference_steps,
+        }
+    )
 
     samples = make_samples(**params)
     return latent_samples_to_images(context, samples)
 
-def img2img(params: dict, context: Context, num_inference_steps, num_outputs, width, height, init_image, init_image_mask, prompt_strength, preserve_init_image_color_profile, **kwargs):
+
+def img2img(
+    params: dict,
+    context: Context,
+    num_inference_steps,
+    num_outputs,
+    width,
+    height,
+    init_image,
+    init_image_mask,
+    prompt_strength,
+    preserve_init_image_color_profile,
+    **kwargs,
+):
     init_image = get_image(init_image)
     init_image_mask = get_image(init_image_mask)
 
-    if not hasattr(context, 'init_image_latent') or context.init_image_latent is None:
-        context.init_image_latent, context.init_image_mask_tensor = get_image_latent_and_mask(context, init_image, init_image_mask, width, height, num_outputs)
-
-    params.update({
-        'steps': num_inference_steps,
-        'init_image_latent': context.init_image_latent,
-        'mask': context.init_image_mask_tensor,
-        'prompt_strength': prompt_strength,
-    })
+    if not hasattr(context, "init_image_latent") or context.init_image_latent is None:
+        context.init_image_latent, context.init_image_mask_tensor = get_image_latent_and_mask(
+            context, init_image, init_image_mask, width, height, num_outputs
+        )
+
+    params.update(
+        {
+            "steps": num_inference_steps,
+            "init_image_latent": context.init_image_latent,
+            "mask": context.init_image_mask_tensor,
+            "prompt_strength": prompt_strength,
+        }
+    )
 
     samples = make_samples(**params)
     images = latent_samples_to_images(context, samples)
@@ -106,14 +130,17 @@ def img2img(params: dict, context: Context, num_inference_steps, num_outputs, wi
 
     return images
 
+
 def get_image(img):
     if not isinstance(img, str):
         return img
 
-    if img.startswith('data:image'):
+    if img.startswith("data:image"):
        return base64_str_to_img(img)
 
     import os
+
     if os.path.exists(img):
         from PIL import Image
+
         return Image.open(img)
diff --git a/sdkit/generate/prompt_parser.py b/sdkit/generate/prompt_parser.py
index f53679c..cb05229 100644
--- a/sdkit/generate/prompt_parser.py
+++ b/sdkit/generate/prompt_parser.py
@@ -2,12 +2,14 @@
 
 from sdkit.utils import log
 
+
 def get_cond_and_uncond(prompt, negative_prompt, batch_size, model):
     cond = parse_prompt(prompt, batch_size, model)
     uncond = parse_prompt(negative_prompt, batch_size, model)
 
     return cond, uncond
 
+
 def parse_prompt(prompt, batch_size, model):
     """
     Requires model to be on the device
@@ -18,16 +20,19 @@ def parse_prompt(prompt, batch_size, model):
     weights_sum = sum(weights)
 
     for i, subprompt in enumerate(subprompts):
-        result = torch.add(result, model.get_learned_conditioning(batch_size * [subprompt]), alpha=weights[i] / weights_sum)
+        result = torch.add(
+            result, model.get_learned_conditioning(batch_size * [subprompt]), alpha=weights[i] / weights_sum
+        )
 
     if len(subprompts) == 0:
         result = empty_result
 
     return result
 
+
 def split_weighted_subprompts(text):
     """
-    grabs all text up to the first occurrence of ':' 
+    grabs all text up to the first occurrence of ':'
     uses the grabbed text as a sub-prompt, and takes the value following ':' as weight
     if ':' has no value defined, defaults to 1.0
     repeats until no text remaining
@@ -37,35 +42,35 @@ def split_weighted_subprompts(text):
     weights = []
     while remaining > 0:
         if ":" in text:
-            idx = text.index(":") # first occurrence from start
+            idx = text.index(":")  # first occurrence from start
             # grab up to index as sub-prompt
             prompt = text[:idx]
             remaining -= idx
             # remove from main text
-            text = text[idx+1:]
-            # find value for weight 
+            text = text[idx + 1 :]
+            # find value for weight
             if " " in text:
-                idx = text.index(" ") # first occurence
-            else: # no space, read to end
+                idx = text.index(" ")  # first occurrence
+            else:  # no space, read to end
                 idx = len(text)
             if idx != 0:
                 try:
                     weight = float(text[:idx])
-                except: # couldn't treat as float
+                except:  # couldn't treat as float
                     log.warn(f"Warning: '{text[:idx]}' is not a value, are you missing a space?")
                     weight = 1.0
-            else: # no value found
+            else:  # no value found
                 weight = 1.0
             # remove from main text
             remaining -= idx
-            text = text[idx+1:]
+            text = text[idx + 1 :]
             # append the sub-prompt and its weight
             prompts.append(prompt)
             weights.append(weight)
-        else: # no : found
-            if len(text) > 0: # there is still text though
+        else:  # no : found
+            if len(text) > 0:  # there is still text though
                 # take remainder as weight 1
                 prompts.append(text)
                 weights.append(1.0)
             remaining = 0
-    return prompts, weights
\ No newline at end of file
+    return prompts, weights
diff --git a/sdkit/generate/sampler/default_samplers.py b/sdkit/generate/sampler/default_samplers.py
index fea0530..73aeced 100644
--- a/sdkit/generate/sampler/default_samplers.py
+++ b/sdkit/generate/sampler/default_samplers.py
@@ -1,63 +1,85 @@
 import torch
-from torch import Tensor
-
 from ldm.models.diffusion.ddim import DDIMSampler
-from ldm.models.diffusion.plms import PLMSSampler
 from ldm.models.diffusion.dpm_solver import DPMSolverSampler
+from ldm.models.diffusion.plms import PLMSSampler
+from torch import Tensor
 
 from sdkit import Context
 
 samplers = {
-    'ddim': DDIMSampler,
-    'plms': PLMSSampler,
-    'dpm_solver_stability': DPMSolverSampler,
+    "ddim": DDIMSampler,
+    "plms": PLMSSampler,
+    "dpm_solver_stability": DPMSolverSampler,
 }
 
-def sample(context: Context, sampler_name:str=None, noise: Tensor=None, batch_size: int=1, shape: tuple=(), steps: int=50, cond: Tensor=None, uncond: Tensor=None, guidance_scale: float=0.8, callback=None, **kwargs):
-    model = context.models['stable-diffusion']
-    sample_fn = _sample_txt2img if 'init_image_latent' not in kwargs else _sample_img2img
+
+def sample(
+    context: Context,
+    sampler_name: str = None,
+    noise: Tensor = None,
+    batch_size: int = 1,
+    shape: tuple = (),
+    steps: int = 50,
+    cond: Tensor = None,
+    uncond: Tensor = None,
+    guidance_scale: float = 0.8,
+    callback=None,
+    **kwargs,
+):
+    model = context.models["stable-diffusion"]
+    sample_fn = _sample_txt2img if "init_image_latent" not in kwargs else _sample_img2img
 
     common_params = {
-        'S': steps,
-        'batch_size': batch_size,
-        'shape': shape,
-        'conditioning': cond,
-        'verbose': False,
-        'unconditional_guidance_scale': guidance_scale,
-        'unconditional_conditioning': uncond,
-        'eta': 0.,
-        'img_callback': callback,
+        "S": steps,
+        "batch_size": batch_size,
+        "shape": shape,
+        "conditioning": cond,
+        "verbose": False,
+        "unconditional_guidance_scale": guidance_scale,
+        "unconditional_conditioning": uncond,
+        "eta": 0.0,
+        "img_callback": callback,
     }
 
     samples, _ = sample_fn(model, sampler_name, noise, steps, batch_size, common_params.copy(), **kwargs)
     return samples
 
+
 def _sample_txt2img(model, sampler_name, noise, steps, batch_size, params, **kwargs):
     sampler = samplers[sampler_name](model)
 
-    params.update({
-        'x_T': noise,
-    })
+    params.update(
+        {
+            "x_T": noise,
+        }
+    )
 
     return sampler.sample(**params)
 
+
 def _sample_img2img(model, sampler_name, noise, steps, batch_size, params, **kwargs):
     sampler = DDIMSampler(model)
 
-    actual_inference_steps = int(steps * kwargs['prompt_strength'])
-    init_image_latent = kwargs['init_image_latent']
-    mask = kwargs.get('mask')
+    actual_inference_steps = int(steps * kwargs["prompt_strength"])
+    init_image_latent = kwargs["init_image_latent"]
+    mask = kwargs.get("mask")
 
-    sampler.make_schedule(ddim_num_steps=steps, ddim_eta=0., verbose=False)
-    z_enc = sampler.stochastic_encode(init_image_latent, torch.tensor([actual_inference_steps] * batch_size).to(model.device), noise=noise)
+    sampler.make_schedule(ddim_num_steps=steps, ddim_eta=0.0, verbose=False)
+    z_enc = sampler.stochastic_encode(
+        init_image_latent, torch.tensor([actual_inference_steps] * batch_size).to(model.device), noise=noise
+    )
 
-    sampler.make_schedule = (lambda **kwargs: kwargs) # we've already called this, don't call this again from within the sampler
+    sampler.make_schedule = (
+        lambda **kwargs: kwargs
+    )  # we've already called this, don't call this again from within the sampler
     sampler.ddim_timesteps = sampler.ddim_timesteps[:actual_inference_steps]
 
-    params.update({
-        'S': actual_inference_steps,
-        'x_T': z_enc,
-        'x0': init_image_latent if mask is not None else None,
-        'mask': mask,
-    })
+    params.update(
+        {
+            "S": actual_inference_steps,
+            "x_T": z_enc,
+            "x0": init_image_latent if mask is not None else None,
+            "mask": mask,
+        }
+    )
 
     return sampler.sample(**params)
diff --git a/sdkit/generate/sampler/k_samplers.py b/sdkit/generate/sampler/k_samplers.py
index b003f4f..fab74c9 100644
--- a/sdkit/generate/sampler/k_samplers.py
+++ b/sdkit/generate/sampler/k_samplers.py
@@ -1,57 +1,73 @@
+import k_diffusion.external
+import k_diffusion.sampling as k_samplers
 import torch
 import torch.nn as nn
 from torch import Tensor
 
-import k_diffusion.sampling as k_samplers
-import k_diffusion.external
-
 from sdkit import Context
 
 samplers = {
-    'euler_a': k_samplers.sample_euler_ancestral,
-    'euler': k_samplers.sample_euler,
-    'lms': k_samplers.sample_lms,
-    'heun': k_samplers.sample_heun,
-    'dpm2': k_samplers.sample_dpm_2,
-    'dpm2_a': k_samplers.sample_dpm_2_ancestral,
-    'dpmpp_2s_a': k_samplers.sample_dpmpp_2s_ancestral,
-    'dpmpp_2m': k_samplers.sample_dpmpp_2m,
-    'dpmpp_sde': k_samplers.sample_dpmpp_sde,
-    'dpm_fast': k_samplers.sample_dpm_fast,
-    'dpm_adaptive': k_samplers.sample_dpm_adaptive,
+    "euler_a": k_samplers.sample_euler_ancestral,
+    "euler": k_samplers.sample_euler,
+    "lms": k_samplers.sample_lms,
+    "heun": k_samplers.sample_heun,
+    "dpm2": k_samplers.sample_dpm_2,
+    "dpm2_a": k_samplers.sample_dpm_2_ancestral,
+    "dpmpp_2s_a": k_samplers.sample_dpmpp_2s_ancestral,
+    "dpmpp_2m": k_samplers.sample_dpmpp_2m,
+    "dpmpp_sde": k_samplers.sample_dpmpp_sde,
+    "dpm_fast": k_samplers.sample_dpm_fast,
+    "dpm_adaptive": k_samplers.sample_dpm_adaptive,
 }
 
-def sample(context: Context, sampler_name:str=None, noise: Tensor=None, batch_size: int=1, shape: tuple=(), steps: int=50, cond: Tensor=None, uncond: Tensor=None, guidance_scale: float=0.8, callback=None, **kwargs):
-    model = context.models['stable-diffusion']
-    denoiser = k_diffusion.external.CompVisVDenoiser if model.parameterization == 'v' else k_diffusion.external.CompVisDenoiser
+
+def sample(
+    context: Context,
+    sampler_name: str = None,
+    noise: Tensor = None,
+    batch_size: int = 1,
+    shape: tuple = (),
+    steps: int = 50,
+    cond: Tensor = None,
+    uncond: Tensor = None,
+    guidance_scale: float = 0.8,
+    callback=None,
+    **kwargs,
+):
+    model = context.models["stable-diffusion"]
+    denoiser = (
+        k_diffusion.external.CompVisVDenoiser if model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
+    )
 
     wrapped_model = DenoiserWrap(denoiser(model))
     sigmas = wrapped_model.inner_model.get_sigmas(steps)
     sample_fn = samplers.get(sampler_name)
 
-    x_latent = noise # because we only use DDIM for img2img
+    x_latent = noise  # because we only use DDIM for img2img
     x_latent *= sigmas[0]
 
     params = {
-        'model': wrapped_model,
-        'x': x_latent,
-        'callback': (lambda info: callback(info['x'], info['i'])) if callback is not None else None,
-        'extra_args': {
-            'uncond': uncond,
-            'cond': cond,
-            'guidance_scale': guidance_scale,
-        }
+        "model": wrapped_model,
+        "x": x_latent,
+        "callback": (lambda info: callback(info["x"], info["i"])) if callback is not None else None,
+        "extra_args": {
+            "uncond": uncond,
+            "cond": cond,
+            "guidance_scale": guidance_scale,
+        },
     }
 
-    if sampler_name in ('dpm_fast', 'dpm_adaptive'):
-        params['sigma_min'] = sigmas[-2] # sigmas is sorted. the last element is 0, which isn't allowed
-        params['sigma_max'] = sigmas[0]
+    if sampler_name in ("dpm_fast", "dpm_adaptive"):
+        params["sigma_min"] = sigmas[-2]  # sigmas is sorted. the last element is 0, which isn't allowed
+        params["sigma_max"] = sigmas[0]
 
-    if sampler_name == 'dpm_fast': params['n'] = steps - 1
+    if sampler_name == "dpm_fast":
+        params["n"] = steps - 1
     else:
-        params['sigmas'] = sigmas
+        params["sigmas"] = sigmas
 
     return sample_fn(**params)
 
+
 # based on https://github.com/XmYx/waifu-diffusion-gradio-hosted-by-colab-en/blob/main/scripts/kdiff.py#L109
 class DenoiserWrap(nn.Module):
     def __init__(self, model):
diff --git a/sdkit/generate/sampler/sampler_main.py b/sdkit/generate/sampler/sampler_main.py
index 2e388fe..7cdb746 100644
--- a/sdkit/generate/sampler/sampler_main.py
+++ b/sdkit/generate/sampler/sampler_main.py
@@ -6,7 +6,20 @@
 
 from . import default_samplers, k_samplers
 
-def make_samples(context: Context, sampler_name: str=None, seed: int=42, batch_size: int=1, shape: tuple=(), steps: int=50, cond: Tensor=None, uncond: Tensor=None, guidance_scale: float=0.8, callback=None, **kwargs):
+
+def make_samples(
+    context: Context,
+    sampler_name: str = None,
+    seed: int = 42,
+    batch_size: int = 1,
+    shape: tuple = (),
+    steps: int = 50,
+    cond: Tensor = None,
+    uncond: Tensor = None,
+    guidance_scale: float = 0.8,
+    callback=None,
+    **kwargs,
+):
     """
     Common args:
     * context: Context
@@ -21,20 +34,26 @@ def make_samples(context: Context, sampler_name: str=None, seed: int=42, batch_s
     * callback: function - signature: `callback(x_samples: Tensor, i: int)`
 
     additional args for txt2img:
-    
+
     additional args for img2img:
     * init_image_latent: Tensor
     * mask: Tensor
     * prompt_strength: float - between 0 and 1. Use 0 to ignore the prompt entirely, or 1 to ignore the init image entirely
     """
     sampler_module = None
-    if sampler_name in default_samplers.samplers: sampler_module = default_samplers
-    if sampler_name in k_samplers.samplers: sampler_module = k_samplers
-    if sampler_module is None: raise RuntimeError(f'Unknown sampler "{sampler_name}"!')
+    if sampler_name in default_samplers.samplers:
+        sampler_module = default_samplers
+    if sampler_name in k_samplers.samplers:
+        sampler_module = k_samplers
+    if sampler_module is None:
+        raise RuntimeError(f'Unknown sampler "{sampler_name}"!')
 
     noise = make_some_noise(seed, batch_size, shape, context.device)
 
-    return sampler_module.sample(context, sampler_name, noise, batch_size, shape, steps, cond, uncond, guidance_scale, callback, **kwargs)
+    return sampler_module.sample(
+        context, sampler_name, noise, batch_size, shape, steps, cond, uncond, guidance_scale, callback, **kwargs
+    )
+
 
 def make_some_noise(seed, batch_size, shape, device):
     b1, b2, b3 = shape
diff --git a/sdkit/models/__init__.py b/sdkit/models/__init__.py
index 81724a1..417c868 100644
--- a/sdkit/models/__init__.py
+++ b/sdkit/models/__init__.py
@@ -1,17 +1,8 @@
-from .model_loader import (
-    load_model,
-    unload_model,
-)
-
-from .models_db import (
-    get_model_info_from_db,
-    get_models_db,
-)
-
 from .model_downloader import (
     download_model,
     download_models,
     resolve_downloaded_model_path,
 )
-
+from .model_loader import load_model, unload_model
+from .models_db import get_model_info_from_db, get_models_db
 from .scan_models import scan_model
diff --git a/sdkit/models/model_downloader.py b/sdkit/models/model_downloader.py
index cafea01..7466f80 100644
--- a/sdkit/models/model_downloader.py
+++ b/sdkit/models/model_downloader.py
@@ -3,8 +3,9 @@
 
 from sdkit.utils import download_file, log
 
-def download_models(models: dict, download_base_dir: str=None, subdir_for_model_type=True):
-    '''
+
+def download_models(models: dict, download_base_dir: str = None, subdir_for_model_type=True):
+    """
     Downloads the requested models (and config files) based on the SDKit models database.
     Resumes incomplete downloads, and shows a progress bar.
@@ -18,15 +19,16 @@ def download_models(models: dict, download_base_dir: str=None, subdir_for_model_
     * subdir_for_model_type: bool - default True. Saves the downloaded model in a subdirectory (named with the model_type).
       For e.g. if `download_base_dir` is `D:\\models`, then a `stable-diffusion` type model is downloaded to `D:\\models\\stable-diffusion`,
       a `hypernetwork` type model is downloaded to `D:\\models\\hypernetwork` and so on.
-    '''
+    """
     for model_type, model_ids in models.items():
         model_ids = model_ids if isinstance(model_ids, list) else [model_ids]
 
         for model_id in model_ids:
             download_model(model_type, model_id, download_base_dir, subdir_for_model_type)
 
-def download_model(model_type: str, model_id: str, download_base_dir: str=None, subdir_for_model_type=True):
-    '''
+
+def download_model(model_type: str, model_id: str, download_base_dir: str = None, subdir_for_model_type=True):
+    """
     Downloads the requested model (and config file) based on the SDKit models database.
     Resumes incomplete downloads, and shows a progress bar.
@@ -39,14 +41,14 @@ def download_model(model_type: str, model_id: str, download_base_dir: str=None,
     * subdir_for_model_type: bool - default True. Saves the downloaded model in a subdirectory (named with the model_type).
       For e.g. if `download_base_dir` is `D:\\models`, then a `stable-diffusion` type model is downloaded to `D:\\models\\stable-diffusion`,
       a `hypernetwork` type model is downloaded to `D:\\models\\hypernetwork` and so on.
-    '''
+    """
     download_base_dir = get_actual_base_dir(model_type, download_base_dir, subdir_for_model_type)
 
     try:
-        model_url, model_file_name = get_url_and_filename(model_type, model_id, url_key='url')
-        config_url, config_file_name = get_url_and_filename(model_type, model_id, url_key='config_url')
+        model_url, model_file_name = get_url_and_filename(model_type, model_id, url_key="url")
+        config_url, config_file_name = get_url_and_filename(model_type, model_id, url_key="config_url")
         if model_url is None:
-            log.warn(f'No download url found for model {model_type} {model_id}')
+            log.warn(f"No download url found for model {model_type} {model_id}")
             return
 
         os.makedirs(download_base_dir, exist_ok=True)
@@ -59,8 +61,11 @@ def download_model(model_type: str, model_id: str, download_base_dir: str=None,
     except Exception as e:
         log.exception(e)
 
-def resolve_downloaded_model_path(model_type: str, model_id: str, download_base_dir: str=None, subdir_for_model_type=True):
-    '''
+
+def resolve_downloaded_model_path(
+    model_type: str, model_id: str, download_base_dir: str = None, subdir_for_model_type=True
+):
+    """
     Gets the path to the downloaded model file. Returns `None` if a file doesn't exist at the calculated path.
 
     Args:
@@ -72,21 +77,23 @@ def resolve_downloaded_model_path(model_type: str, model_id: str, download_base_
     * subdir_for_model_type: bool - default True. Saves the downloaded model in a subdirectory (named with the model_type).
       For e.g. if `download_base_dir` is `D:\\models`, then a `stable-diffusion` type model is downloaded to `D:\\models\\stable-diffusion`,
       a `hypernetwork` type model is downloaded to `D:\\models\\hypernetwork` and so on.
-    '''
+    """
     download_base_dir = get_actual_base_dir(model_type, download_base_dir, subdir_for_model_type)
-    _, file_name = get_url_and_filename(model_type, model_id, url_key='url')
+    _, file_name = get_url_and_filename(model_type, model_id, url_key="url")
     if file_name is None:
         return
 
     file_path = os.path.join(download_base_dir, file_name)
     return file_path if os.path.exists(file_path) else None
 
+
 def get_actual_base_dir(model_type, download_base_dir, subdir_for_model_type):
-    download_base_dir = os.path.join('~', '.cache', 'sdkit') if download_base_dir is None else download_base_dir
+    download_base_dir = os.path.join("~", ".cache", "sdkit") if download_base_dir is None else download_base_dir
     download_base_dir = os.path.join(download_base_dir, model_type) if subdir_for_model_type else download_base_dir
     return os.path.abspath(download_base_dir)
 
-def get_url_and_filename(model_type, model_id, url_key='url'):
+
+def get_url_and_filename(model_type, model_id, url_key="url"):
     from sdkit.models import get_model_info_from_db
 
     model_info = get_model_info_from_db(model_type=model_type, model_id=model_id)
diff --git a/sdkit/models/model_loader/__init__.py b/sdkit/models/model_loader/__init__.py
index 408a171..c4bbba4 100644
--- a/sdkit/models/model_loader/__init__.py
+++ b/sdkit/models/model_loader/__init__.py
@@ -1,17 +1,17 @@
-from . import hypernetwork
-from . import stable_diffusion, gfpgan, realesrgan, vae
-
 from sdkit import Context
 from sdkit.utils import gc, log
 
+from . import gfpgan, hypernetwork, realesrgan, stable_diffusion, vae
+
 models = {
-    'stable-diffusion': stable_diffusion,
-    'gfpgan': gfpgan,
-    'realesrgan': realesrgan,
-    'vae': vae,
-    'hypernetwork': hypernetwork,
+    "stable-diffusion": stable_diffusion,
+    "gfpgan": gfpgan,
+    "realesrgan": realesrgan,
+    "vae": vae,
+    "hypernetwork": hypernetwork,
 }
 
+
 def load_model(context: Context, model_type: str, **kwargs):
     if context.model_paths.get(model_type) is None:
         return
@@ -19,16 +19,17 @@ def load_model(context: Context, model_type: str, **kwargs):
     if model_type in context.models:
         unload_model(context, model_type)
 
-    log.info(f'loading {model_type} model from {context.model_paths.get(model_type)} to device: {context.device}')
+    log.info(f"loading {model_type} model from {context.model_paths.get(model_type)} to device: {context.device}")
 
     context.models[model_type] = models[model_type].load_model(context, **kwargs)
 
-    log.info(f'loaded {model_type} model from {context.model_paths.get(model_type)} to device: {context.device}')
+    log.info(f"loaded {model_type} model from {context.model_paths.get(model_type)} to device: {context.device}")
 
     # reload dependent models
-    if model_type == 'stable-diffusion':
-        load_model(context, 'vae')
-        load_model(context, 'hypernetwork')
+    if model_type == "stable-diffusion":
+        load_model(context, "vae")
+        load_model(context, "hypernetwork")
+
 
 def unload_model(context: Context, model_type: str, **kwargs):
     if model_type not in context.models:
@@ -38,4 +39,4 @@ def unload_model(context: Context, model_type: str, **kwargs):
     models[model_type].unload_model(context)
     gc(context)
 
-    log.info(f'unloaded {model_type} model from device: {context.device}')
+    log.info(f"unloaded {model_type} model from device: {context.device}")
diff --git a/sdkit/models/model_loader/gfpgan.py b/sdkit/models/model_loader/gfpgan.py
index d194592..aee9364 100644
--- a/sdkit/models/model_loader/gfpgan.py
+++ b/sdkit/models/model_loader/gfpgan.py
@@ -3,14 +3,24 @@
 
 from sdkit import Context
 
+
 def load_model(context: Context, **kwargs):
-    model_path = context.model_paths.get('gfpgan')
+    model_path = context.model_paths.get("gfpgan")
 
     # hack for a bug in facexlib: https://github.com/xinntao/facexlib/pull/19/files
     from facexlib.detection import retinaface
+
     retinaface.device = torch.device(context.device)
 
-    return GFPGANer(device=torch.device(context.device), model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
+    return GFPGANer(
+        device=torch.device(context.device),
+        model_path=model_path,
+        upscale=1,
+        arch="clean",
+        channel_multiplier=2,
+        bg_upsampler=None,
+    )
+
 
 def unload_model(context: Context, **kwargs):
     pass
diff --git a/sdkit/models/model_loader/hypernetwork/__init__.py b/sdkit/models/model_loader/hypernetwork/__init__.py
index ff8bf1c..7f5183e 100644
--- a/sdkit/models/model_loader/hypernetwork/__init__.py
+++ b/sdkit/models/model_loader/hypernetwork/__init__.py
@@ -1,33 +1,55 @@
 import traceback
 
 from sdkit import Context
-from sdkit.utils import log, load_tensor_file
+from sdkit.utils import load_tensor_file, log
+
 
 def load_model(context: Context, **kwargs):
     from .hypernetwork import HypernetworkModule, override_attention_context_kv
-    model_path = context.model_paths.get('hypernetwork')
+
+    model_path = context.model_paths.get("hypernetwork")
 
     try:
         state_dict = load_tensor_file(model_path)
 
-        layer_structure = state_dict.get('layer_structure', [1, 2, 1])
-        activation_func = state_dict.get('activation_func', None)
-        weight_init = state_dict.get('weight_initialization', 'Normal')
-        add_layer_norm = state_dict.get('is_layer_norm', False)
-        use_dropout = state_dict.get('use_dropout', False)
-        activate_output = state_dict.get('activate_output', True)
-        last_layer_dropout = state_dict.get('last_layer_dropout', False)
+        layer_structure = state_dict.get("layer_structure", [1, 2, 1])
+        activation_func = state_dict.get("activation_func", None)
+        weight_init = state_dict.get("weight_initialization", "Normal")
+        add_layer_norm = state_dict.get("is_layer_norm", False)
+        use_dropout = state_dict.get("use_dropout", False)
+        activate_output = state_dict.get("activate_output", True)
+        last_layer_dropout = state_dict.get("last_layer_dropout", False)
 
-        layers = {'hypernetwork_strength': 0}
+        layers = {"hypernetwork_strength": 0}
         for size, sd in state_dict.items():
             if type(size) == int:
                 layers[size] = (
-                    HypernetworkModule(size, sd[0], layer_structure, activation_func, weight_init, add_layer_norm,
-                                       use_dropout, activate_output, last_layer_dropout=last_layer_dropout,
-                                       model=layers, device=context.device),
-                    HypernetworkModule(size, sd[1], layer_structure, activation_func, weight_init, add_layer_norm,
-                                       use_dropout, activate_output, last_layer_dropout=last_layer_dropout,
-                                       model=layers, device=context.device),
+                    HypernetworkModule(
+                        size,
+                        sd[0],
+                        layer_structure,
+                        activation_func,
+                        weight_init,
+                        add_layer_norm,
+                        use_dropout,
+                        activate_output,
+                        last_layer_dropout=last_layer_dropout,
+                        model=layers,
+                        device=context.device,
+                    ),
+                    HypernetworkModule(
+                        size,
+                        sd[1],
+                        layer_structure,
+                        activation_func,
+                        weight_init,
+                        add_layer_norm,
+                        use_dropout,
+                        activate_output,
+                        last_layer_dropout=last_layer_dropout,
+                        model=layers,
+                        device=context.device,
+                    ),
                 )
 
         override_attention_context_kv(layers)
@@ -35,8 +57,10 @@ def load_model(context: Context, **kwargs):
         return layers
     except:
         log.error(traceback.format_exc())
-        log.error(f'Could not load hypernetwork: {model_path}')
+        log.error(f"Could not load hypernetwork: {model_path}")
+
 
 def unload_model(context: Context, **kwargs):
     from .hypernetwork import override_attention_context_kv
+
     override_attention_context_kv(None)
diff --git a/sdkit/models/model_loader/hypernetwork/hypernetwork.py b/sdkit/models/model_loader/hypernetwork/hypernetwork.py
index fa75746..e86a7c7 100644
--- a/sdkit/models/model_loader/hypernetwork/hypernetwork.py
+++ b/sdkit/models/model_loader/hypernetwork/hypernetwork.py
@@ -2,8 +2,10 @@
 # which was a cut down version of https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/c9a2cfdf2a53d37c2de1908423e4f548088667ef/modules/hypernetworks/hypernetwork.py
 
 import inspect
+
 import torch
 
+
 class HypernetworkModule(torch.nn.Module):
     multiplier = 0.5
     activation_dict = {
@@ -15,10 +17,28 @@ class HypernetworkModule(torch.nn.Module):
         "tanh": torch.nn.Tanh,
         "sigmoid": torch.nn.Sigmoid,
     }
-    activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
-
-    def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',
-                 add_layer_norm=False, use_dropout=False, activate_output=False, last_layer_dropout=False, model=None, device='cuda'):
+    activation_dict.update(
+        {
+            cls_name.lower(): cls_obj
+            for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation)
+            if inspect.isclass(cls_obj) and cls_obj.__module__ == "torch.nn.modules.activation"
+        }
+    )
+
+    def __init__(
+        self,
+        dim,
+        state_dict=None,
+        layer_structure=None,
+        activation_func=None,
+        weight_init="Normal",
+        add_layer_norm=False,
+        use_dropout=False,
+        activate_output=False,
+        last_layer_dropout=False,
+        model=None,
+        device="cuda",
+    ):
         super().__init__()
 
         self.model = model
@@ -29,21 +49,24 @@ def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=N
         linears = []
         for i in range(len(layer_structure) - 1):
-            # Add a fully-connected layer
-            linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
+            linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i + 1])))
 
             # Add an activation func except last layer
-            if activation_func == "linear" or activation_func is None or (i >= len(layer_structure) - 2 and not activate_output):
+            if (
+                activation_func == "linear"
+                or activation_func is None
+                or (i >= len(layer_structure) - 2 and not activate_output)
+            ):
                 pass
             elif activation_func in self.activation_dict:
                 linears.append(self.activation_dict[activation_func]())
             else:
-                raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')
+                raise RuntimeError(f"hypernetwork uses an unsupported activation function: {activation_func}")
 
             # Add layer normalization
             if add_layer_norm:
-                linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
+                linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i + 1])))
 
             # Add dropout except last layer
             if use_dropout and (i < len(layer_structure) - 3 or last_layer_dropout and i < len(layer_structure) - 2):
@@ -58,10 +81,10 @@ def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=N
 
     def fix_old_state_dict(self, state_dict):
         changes = {
-            'linear1.bias': 'linear.0.bias',
-            'linear1.weight': 'linear.0.weight',
-            'linear2.bias': 'linear.1.bias',
-            'linear2.weight': 'linear.1.weight',
+            "linear1.bias": "linear.0.bias",
+            "linear1.weight": "linear.0.weight",
+            "linear2.bias": "linear.1.bias",
+            "linear2.weight": "linear.1.weight",
         }
 
         for fr, to in changes.items():
@@ -73,7 +96,8 @@ def fix_old_state_dict(self, state_dict):
             state_dict[to] = x
 
     def forward(self, x: torch.Tensor):
-        return x + self.linear(x) * self.model['hypernetwork_strength']
+        return x + self.linear(x) * self.model["hypernetwork_strength"]
+
 
 def apply_hypernetwork(hypernetwork, attention_context, layer=None):
     hypernetwork_layers = hypernetwork.get(attention_context.shape[2], None)
@@ -89,8 +113,10 @@ def apply_hypernetwork(hypernetwork, attention_context, layer=None):
     attention_context_v = hypernetwork_layers[1](attention_context)
     return attention_context_k, attention_context_v
 
+
 def override_attention_context_kv(hypernetwork_model):
     import sdkit.models.model_loader.stable_diffusion.optimizations as sd_model_optimizer
+
     def get_context_kv(attention_context):
         if hypernetwork_model is None:
             return attention_context, attention_context
diff --git a/sdkit/models/model_loader/realesrgan.py b/sdkit/models/model_loader/realesrgan.py
index 45415f9..d36a215 100644
--- a/sdkit/models/model_loader/realesrgan.py
+++ b/sdkit/models/model_loader/realesrgan.py
@@ -1,30 +1,37 @@
-import torch
 import os
+
+import torch
 from basicsr.archs.rrdbnet_arch import RRDBNet
 from realesrgan import RealESRGANer
 
 from sdkit import Context
 
+
 def load_model(context: Context, **kwargs):
-    model_path = context.model_paths.get('realesrgan')
+    model_path = context.model_paths.get("realesrgan")
 
     RealESRGAN_models = {
-        'RealESRGAN_x4plus': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4),
-        'RealESRGAN_x4plus_anime_6B': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
+        "RealESRGAN_x4plus": RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4),
+        "RealESRGAN_x4plus_anime_6B": RRDBNet(
+            num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4
+        ),
     }
 
     model_to_use = os.path.basename(model_path)
     model_to_use, _ = os.path.splitext(model_to_use)
     model_to_use = RealESRGAN_models[model_to_use]
 
-    half = context.half_precision if context.device != 'cpu' else False
-    model = RealESRGANer(device=torch.device(context.device), scale=4, model_path=model_path, model=model_to_use, pre_pad=0, half=half)
-    if context.device == 'cpu':
-        model.model.to('cpu')
+    half = context.half_precision if context.device != "cpu" else False
+    model = RealESRGANer(
+        device=torch.device(context.device), scale=4, model_path=model_path, model=model_to_use, pre_pad=0, half=half
+    )
+    if context.device == "cpu":
+        model.model.to("cpu")
 
     model.model.name = model_to_use
 
     return model
 
+
 def unload_model(context: Context, **kwargs):
     pass
diff --git a/sdkit/models/model_loader/stable_diffusion/__init__.py b/sdkit/models/model_loader/stable_diffusion/__init__.py
index 8e87ab1..ea80b2f 100644
--- a/sdkit/models/model_loader/stable_diffusion/__init__.py
+++ b/sdkit/models/model_loader/stable_diffusion/__init__.py
@@ -1,54 +1,66 @@
 import os
-from omegaconf import OmegaConf
-from ldm.util import instantiate_from_config
-from transformers import logging as tr_logging
-from torch.nn.functional import silu
-import ldm.modules.attention
-import ldm.modules.diffusionmodules.model
 import tempfile
-from urllib.parse import urlparse
 from pathlib import Path
+from urllib.parse import urlparse
+
+import ldm.modules.attention
+import ldm.modules.diffusionmodules.model
+from ldm.util import instantiate_from_config
+from omegaconf import OmegaConf
+from torch.nn.functional import silu
+from transformers import logging as tr_logging
 
 from sdkit import Context
-from sdkit.utils import load_tensor_file, save_tensor_file, hash_file_quick, download_file, log
+from sdkit.utils import (
+    download_file,
+    hash_file_quick,
+    load_tensor_file,
+    log,
+    save_tensor_file,
+)
+
+tr_logging.set_verbosity_error()  # suppress unnecessary logging
 
-tr_logging.set_verbosity_error() # suppress unnecessary logging
+
 def load_model(context: Context, scan_model=True, check_for_config_with_same_name=True, **kwargs):
-    from . import optimizations
     from sdkit.models import scan_model as scan_model_fn
-    model_path = context.model_paths.get('stable-diffusion')
+
+    from . import optimizations
+
+    model_path = context.model_paths.get("stable-diffusion")
 
     config_file_path = get_model_config_file(context, check_for_config_with_same_name)
 
     if scan_model:
         scan_result = scan_model_fn(model_path)
         if scan_result.issues_count > 0 or scan_result.infected_files > 0:
-            raise Exception(f'Model scan failed! Potentially infected model: {model_path}')
+            raise Exception(f"Model scan failed! Potentially infected model: {model_path}")
 
     # load the model file
     sd = load_tensor_file(model_path)
-    sd = sd['state_dict'] if 'state_dict' in sd else sd
+    sd = sd["state_dict"] if "state_dict" in sd else sd
 
     # try to guess the config, if no config file was given
     # check if a key specific to SD 2.0 is missing
-    if config_file_path is None and 'cond_stage_model.model.ln_final.bias' not in sd.keys():
+    if config_file_path is None and "cond_stage_model.model.ln_final.bias" not in sd.keys():
         # try using an SD 1.4 config
         from sdkit.models import get_model_info_from_db
-        sd_v1_4_info = get_model_info_from_db(model_type='stable-diffusion', model_id='1.4')
+
+        sd_v1_4_info = get_model_info_from_db(model_type="stable-diffusion", model_id="1.4")
         config_file_path = resolve_model_config_file_path(sd_v1_4_info, model_path)
 
     # load the config
     if config_file_path is None:
-        raise Exception(f'Unknown model! No config file path specified in context.model_configs for the "stable-diffusion" model!')
+        raise Exception(
+            'Unknown model! No config file path specified in context.model_configs for the "stable-diffusion" model!'
+        )
 
-    log.info(f'using config: {config_file_path}')
+    log.info(f"using config: {config_file_path}")
     config = OmegaConf.load(config_file_path)
     config.model.params.unet_config.params.use_fp16 = context.half_precision
 
-    extra_config = config.get('extra', {})
-    attn_precision = extra_config.get('attn_precision', 'fp16' if context.half_precision else 'fp32')
-    log.info(f'using attn_precision: {attn_precision}')
+    extra_config = config.get("extra", {})
+    attn_precision = extra_config.get("attn_precision", "fp16" if context.half_precision else "fp32")
+    log.info(f"using attn_precision: {attn_precision}")
 
     # instantiate the model
     model = instantiate_from_config(config.model)
@@ -61,30 +73,36 @@ def load_model(context: Context, scan_model=True, check_for_config_with_same_nam
     del sd
 
     # optimize CrossAttention.forward() for faster performance, and lower VRAM usage
-    ldm.modules.attention.CrossAttention.forward = optimizations.make_attn_forward(context, attn_precision=attn_precision)
+    ldm.modules.attention.CrossAttention.forward = optimizations.make_attn_forward(
+        context, attn_precision=attn_precision
+    )
     ldm.modules.diffusionmodules.model.nonlinearity = silu
 
     # save the model vae into a temp folder (used for restoring the default VAE, if a custom VAE is unloaded)
-    save_tensor_file(model.first_stage_model.state_dict(), os.path.join(tempfile.gettempdir(), 'sd-base-vae.safetensors'))
+    save_tensor_file(
+        model.first_stage_model.state_dict(), os.path.join(tempfile.gettempdir(), "sd-base-vae.safetensors")
+    )
 
     # optimizations.print_model_size_breakdown(model)
 
     return model
 
+
 def unload_model(context: Context, **kwargs):
-    context.module_in_gpu = None # don't keep a dangling reference, prevents gc
+    context.module_in_gpu = None  # don't keep a dangling reference, prevents gc
+
 
 def get_model_config_file(context: Context, check_for_config_with_same_name):
     from sdkit.models import get_model_info_from_db
 
-    if context.model_configs.get('stable-diffusion') is not None:
-        return context.model_configs['stable-diffusion']
+    if context.model_configs.get("stable-diffusion") is not None:
+        return context.model_configs["stable-diffusion"]
 
-    model_path = context.model_paths['stable-diffusion']
+    model_path = context.model_paths["stable-diffusion"]
 
     if check_for_config_with_same_name:
         model_name_path = os.path.splitext(model_path)[0]
-        model_config_path = f'{model_name_path}.yaml'
+        model_config_path = f"{model_name_path}.yaml"
         if os.path.exists(model_config_path):
             return model_config_path
 
@@ -93,14 +111,15 @@ def get_model_config_file(context: Context, check_for_config_with_same_name):
 
     return resolve_model_config_file_path(model_info, model_path)
 
+
 def resolve_model_config_file_path(model_info, model_path):
     if model_info is None:
         return
-    config_url = model_info.get('config_url')
+    config_url = model_info.get("config_url")
     if config_url is None:
         return
 
-    if config_url.startswith('http'):
+    if config_url.startswith("http"):
         config_file_name = os.path.basename(urlparse(config_url).path)
         model_dir_name = os.path.dirname(model_path)
         config_file_path = os.path.join(model_dir_name, config_file_name)
@@ -109,7 +128,8 @@ def resolve_model_config_file_path(model_info, model_path):
             download_file(config_url, config_file_path)
     else:
         from sdkit.models import models_db
+
         models_db_path = Path(models_db.__file__).parent
-        config_file_path = models_db_path/config_url
+        config_file_path = models_db_path / config_url
 
     return config_file_path
diff --git a/sdkit/models/model_loader/stable_diffusion/optimizations.py b/sdkit/models/model_loader/stable_diffusion/optimizations.py
index 367bb7a..d47300b 100644
--- a/sdkit/models/model_loader/stable_diffusion/optimizations.py
+++ b/sdkit/models/model_loader/stable_diffusion/optimizations.py
@@ -1,12 +1,14 @@
-import torch
-from torch import einsum
 import math
-
-from ldm.util import default
+
+import torch
 from einops import rearrange
+from ldm.util import default
+from torch import einsum
 
 from sdkit import Context
 from sdkit.utils import log
 
+
 def send_to_device(context: Context, model):
     """
     Sends the model to the device, based on the VRAM optimizations set in
@@ -15,13 +17,13 @@ def send_to_device(context: Context, model):
     Please see the documentation for `diffusionkit.types.Context.vram_optimizations`
     for a summary of the logic used for VRAM optimizations
     """
-    if len(context.vram_optimizations) == 0 or context.device == 'cpu':
-        log.info('No VRAM optimizations being applied')
+    if len(context.vram_optimizations) == 0 or context.device == "cpu":
+        log.info("No VRAM optimizations being applied")
         model.to(context.device)
         model.cond_stage_model.device = context.device
         return
 
-    log.info(f'VRAM Optimizations: {context.vram_optimizations}')
+    log.info(f"VRAM Optimizations: {context.vram_optimizations}")
 
     # based on the approach at https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/master/modules/lowvram.py
     # the idea is to keep only one module in the GPU at a time, depending on the desired optimization level
@@ -38,21 +40,30 @@ def move_to_gpu(module, _):
             return
 
         if context.module_in_gpu is not None:
-            context.module_in_gpu.to('cpu')
-            log.debug(f"moved {getattr(context.module_in_gpu, 'log_name', context.module_in_gpu.__class__.__name__)} to cpu")
+            context.module_in_gpu.to("cpu")
+            log.debug(
+                f"moved {getattr(context.module_in_gpu, 'log_name', context.module_in_gpu.__class__.__name__)} to cpu"
+            )
 
         module.to(context.device)
-        if module == model.cond_stage_model: module.device = context.device
+        if module == model.cond_stage_model:
+            module.device = context.device
         context.module_in_gpu = module
-        log.debug(f"moved {getattr(context.module_in_gpu, 'log_name', context.module_in_gpu.__class__.__name__)} to GPU")
+        log.debug(
+            f"moved {getattr(context.module_in_gpu, 'log_name', context.module_in_gpu.__class__.__name__)} to GPU"
+        )
 
     def wrap_fs_fn(fn, model_to_move):
         def wrap(x):
            move_to_gpu(model_to_move, None)
            return fn(x)
+
         return wrap
 
-    if 'KEEP_FS_AND_CS_IN_CPU' in context.vram_optimizations or 'KEEP_ENTIRE_MODEL_IN_CPU' in context.vram_optimizations:
+    if (
+        "KEEP_FS_AND_CS_IN_CPU" in context.vram_optimizations
+        or "KEEP_ENTIRE_MODEL_IN_CPU" in context.vram_optimizations
+    ):
         # move the FS, CS and the main model to CPU. And send only the overall reference to the correct device
         tmp = model.cond_stage_model, model.first_stage_model, model.model
         model.cond_stage_model, model.first_stage_model, model.model = (None,) * 3
@@ -60,16 +71,18 @@ def wrap(x):
         model.cond_stage_model, model.first_stage_model, model.model = tmp
 
         # set forward_pre_hook (a feature of torch NN module) to move each module to the GPU only when required
-        model.first_stage_model.log_name = 'model.first_stage_model'
+        model.first_stage_model.log_name = "model.first_stage_model"
         model.first_stage_model.register_forward_pre_hook(move_to_gpu)
         model.first_stage_model.encode = wrap_fs_fn(model.first_stage_model.encode, model.first_stage_model)
         model.first_stage_model.decode = wrap_fs_fn(model.first_stage_model.decode, model.first_stage_model)
 
-        model.cond_stage_model.log_name = 'model.cond_stage_model'
+        model.cond_stage_model.log_name = "model.cond_stage_model"
         model.cond_stage_model.register_forward_pre_hook(move_to_gpu)
         model.cond_stage_model.forward = wrap_fs_fn(model.cond_stage_model.forward, model.cond_stage_model)
 
-    if 'KEEP_ENTIRE_MODEL_IN_CPU' in context.vram_optimizations: # apply the same approach, but to the individual blocks in model
+    if (
+        "KEEP_ENTIRE_MODEL_IN_CPU" in context.vram_optimizations
+    ):  # apply the same approach, but to the individual blocks in model
         d = model.model.diffusion_model
         tmp = d.input_blocks, d.middle_block, d.output_blocks, d.time_embed
@@ -77,61 +90,66 @@ def wrap(x):
         model.model.to(context.device)
         d.input_blocks, d.middle_block, d.output_blocks, d.time_embed = tmp
 
-        d.time_embed.log_name = 'model.model.diffusion_model.time_embed'
+        d.time_embed.log_name = "model.model.diffusion_model.time_embed"
         d.time_embed.register_forward_pre_hook(move_to_gpu)
 
         for i, block in enumerate(d.input_blocks):
-            block.log_name = f'model.model.diffusion_model.input_blocks[{i}]'
+            block.log_name = f"model.model.diffusion_model.input_blocks[{i}]"
             block.register_forward_pre_hook(move_to_gpu)
 
-        d.middle_block.log_name = 'model.model.diffusion_model.middle_block'
+        d.middle_block.log_name = "model.model.diffusion_model.middle_block"
         d.middle_block.register_forward_pre_hook(move_to_gpu)
 
         for i, block in enumerate(d.output_blocks):
-            block.log_name = f'model.model.diffusion_model.output_blocks[{i}]'
+            block.log_name = f"model.model.diffusion_model.output_blocks[{i}]"
             block.register_forward_pre_hook(move_to_gpu)
     else:
         model.model.to(context.device)
 
-    if 'KEEP_ENTIRE_MODEL_IN_CPU' not in context.vram_optimizations and 'KEEP_FS_AND_CS_IN_CPU' not in context.vram_optimizations:
+    if (
+        "KEEP_ENTIRE_MODEL_IN_CPU" not in context.vram_optimizations
+        and "KEEP_FS_AND_CS_IN_CPU" not in context.vram_optimizations
+    ):
         model.to(context.device)
         model.cond_stage_model.device = context.device
 
+
 def get_context_kv(attention_context):
     return attention_context, attention_context
 
+
 # modified version of https://github.com/Doggettx/stable-diffusion/blob/main/ldm/modules/attention.py#L170
 # faster iterations/sec than the default SD implementation, and consumes far less VRAM
 # On a 3060 12 GB (with the sd-v1-4.ckpt model):
 # - without this code, the standard SD sampler runs at 4.5 it/sec, and consumes ~6.6 GB of VRAM
 # - using this code makes the sampler run at 5.6 to 5.9 it/sec, and consume ~3.6 GB of VRAM on lower-end PCs, and ~4.9 GB on higher-end PCs
-def make_attn_forward(context: Context, attn_precision='fp16'):
+def make_attn_forward(context: Context, attn_precision="fp16"):
     app_context = context
 
     def get_steps(q, k):
-        if context.device == 'cpu' or 'SET_ATTENTION_STEP_TO_2' in context.vram_optimizations:
+        if context.device == "cpu" or "SET_ATTENTION_STEP_TO_2" in context.vram_optimizations:
             return 2
-        elif 'SET_ATTENTION_STEP_TO_4' in context.vram_optimizations:
-            return 4 # use for balanced
-        elif 'SET_ATTENTION_STEP_TO_6' in context.vram_optimizations:
+        elif "SET_ATTENTION_STEP_TO_4" in context.vram_optimizations:
+            return 4  # use for balanced
+        elif "SET_ATTENTION_STEP_TO_6" in context.vram_optimizations:
             return 6
-        elif 'SET_ATTENTION_STEP_TO_8' in context.vram_optimizations:
+        elif "SET_ATTENTION_STEP_TO_8" in context.vram_optimizations:
            return 8
-        elif 'SET_ATTENTION_STEP_TO_16' in context.vram_optimizations:
+        elif "SET_ATTENTION_STEP_TO_16" in context.vram_optimizations:
            return 16
-        elif 'SET_ATTENTION_STEP_TO_24' in context.vram_optimizations:
-            return 24 # use for low
+        elif "SET_ATTENTION_STEP_TO_24" in context.vram_optimizations:
+            return 24  # use for low
 
         # figure out the available memory
         stats = torch.cuda.memory_stats(q.device)
-        mem_active = stats['active_bytes.all.current']
-        mem_reserved = stats['reserved_bytes.all.current']
+        mem_active = stats["active_bytes.all.current"]
+        mem_reserved = stats["reserved_bytes.all.current"]
         mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
         mem_free_torch = mem_reserved - mem_active
         mem_free_total = mem_free_cuda + mem_free_torch
 
         # figure out the required memory
-        gb = 1024 ** 3
+        gb = 1024**3
         tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
         modifier = 3 if q.element_size() == 2 else 2.5
         mem_required = tensor_size * modifier
@@ -142,8 +160,10 @@ def get_steps(q, k):
 
         if steps > 64:
             max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
-            raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
-                               f'Need: {mem_required / 64 / gb:0.1f} GB free, Have:{mem_free_total / gb:0.1f} GB free')
+            raise RuntimeError(
+                f"Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). 
" + f"Need: {mem_required / 64 / gb:0.1f} GB free, Have:{mem_free_total / gb:0.1f} GB free" + ) return steps @@ -158,10 +178,10 @@ def forward(self, x, context=None, mask=None): k_in *= self.scale del context, x - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q_in, k_in, v_in)) del q_in, k_in, v_in - autocast_device = 'cpu' if app_context.device == 'cpu' else 'cuda' # doesn't accept (or need) 'cuda:N' + autocast_device = "cpu" if app_context.device == "cpu" else "cuda" # doesn't accept (or need) 'cuda:N' r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) steps = get_steps(q, k) @@ -171,25 +191,26 @@ def forward(self, x, context=None, mask=None): if attn_precision == "fp32": with torch.autocast(enabled=False, device_type=autocast_device): q, k = q.float(), k.float() - s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) + s1 = einsum("b i d, b j d -> b i j", q[:, i:end], k) else: - s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) + s1 = einsum("b i d, b j d -> b i j", q[:, i:end], k) s2 = s1.softmax(dim=-1, dtype=q.dtype) del s1 - r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) + r1[:, i:end] = einsum("b i j, b j d -> b i d", s2, v) del s2 del q, k, v - r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) + r2 = rearrange(r1, "(b h) n d -> b n (h d)", h=h) del r1 return self.to_out(r2) return forward + def print_model_size_breakdown(model): """ Useful debugging function for analyzing the memory usage of a model @@ -204,28 +225,30 @@ def mb(n_bytes): size_input, size_middle, size_output = 0, 0, 0 for key, val in model.model.diffusion_model.state_dict().items(): s = val.element_size() * val.nelement() - if 'input' in key: + if "input" in key: size_input += s - elif 'middle' in key: + elif "middle" in key: size_middle += s - elif 'output' in key: + elif "output" in key: size_output += s - log.info(f'model.diffusion_model (input, middle, output blocks): {mb(size_input)} Mb, {mb(size_middle)} Mb, {mb(size_output)} Mb') - log.info(f'model.diffusion_model (total): {mb(size_input + size_middle + size_output)} Mb') + log.info( + f"model.diffusion_model (input, middle, output blocks): {mb(size_input)} Mb, {mb(size_middle)} Mb, {mb(size_output)} Mb" + ) + log.info(f"model.diffusion_model (total): {mb(size_input + size_middle + size_output)} Mb") # modelFS sizeFS = 0 for _, val in model.first_stage_model.state_dict().items(): sizeFS += val.element_size() * val.nelement() - log.info(f'model.first_stage_model: {mb(sizeFS)} Mb') + log.info(f"model.first_stage_model: {mb(sizeFS)} Mb") # modelCS sizeCS = 0 for _, val in model.cond_stage_model.state_dict().items(): sizeCS += val.element_size() * val.nelement() - log.info(f'model.cond_stage_model: {mb(sizeCS)} Mb') + log.info(f"model.cond_stage_model: {mb(sizeCS)} Mb") - log.info(f'model (TOTAL): {mb(size_input + size_middle + size_output + sizeFS + sizeCS)} Mb') + log.info(f"model (TOTAL): {mb(size_input + size_middle + size_output + sizeFS + sizeCS)} Mb") diff --git a/sdkit/models/model_loader/vae.py b/sdkit/models/model_loader/vae.py index b2b2749..e3dbfb0 100644 --- a/sdkit/models/model_loader/vae.py +++ b/sdkit/models/model_loader/vae.py @@ -1,9 +1,9 @@ import os -import traceback import tempfile +import traceback from sdkit import Context -from sdkit.utils import log, load_tensor_file +from sdkit.utils import load_tensor_file, log """ The VAE model overwrites the state_dict of 
model.first_stage_model. @@ -12,8 +12,9 @@ and restore that copy if the custom VAE is unloaded. """ + def load_model(context: Context, **kwargs): - vae_model_path = context.model_paths.get('vae') + vae_model_path = context.model_paths.get("vae") try: vae = load_tensor_file(vae_model_path) @@ -25,25 +26,29 @@ def load_model(context: Context, **kwargs): _set_vae(context, vae_dict) del vae_dict - return {} # we don't need to access this again + return {} # we don't need to access this again except: log.error(traceback.format_exc()) - log.error(f'Could not load VAE: {vae_model_path}') + log.error(f"Could not load VAE: {vae_model_path}") + def move_model_to_cpu(context: Context): pass + def unload_model(context: Context, **kwargs): base_vae = _get_base_model_vae(context) _set_vae(context, base_vae) + def _set_vae(context: Context, vae: dict): - if 'stable-diffusion' not in context.models: + if "stable-diffusion" not in context.models: return - model = context.models['stable-diffusion'] + model = context.models["stable-diffusion"] model.first_stage_model.load_state_dict(vae, strict=False) + def _get_base_model_vae(context: Context): - base_vae = os.path.join(tempfile.gettempdir(), 'sd-base-vae.safetensors') + base_vae = os.path.join(tempfile.gettempdir(), "sd-base-vae.safetensors") return load_tensor_file(base_vae) diff --git a/sdkit/models/models_db/__init__.py b/sdkit/models/models_db/__init__.py index d13dee9..410c919 100644 --- a/sdkit/models/models_db/__init__.py +++ b/sdkit/models/models_db/__init__.py @@ -4,6 +4,7 @@ db = None index = None + def get_models_db(): global db @@ -12,11 +13,15 @@ def get_models_db(): db_path = Path(__file__).parent db = {} - with open(db_path/'stable_diffusion.json') as f: db['stable-diffusion'] = json.load(f) - with open(db_path/'gfpgan.json') as f: db['gfpgan'] = json.load(f) - with open(db_path/'realesrgan.json') as f: db['realesrgan'] = json.load(f) + with open(db_path / "stable_diffusion.json") as f: + db["stable-diffusion"] = json.load(f) + with open(db_path / "gfpgan.json") as f: + db["gfpgan"] = json.load(f) + with open(db_path / "realesrgan.json") as f: + db["realesrgan"] = json.load(f) return db + def get_model_info_from_db(quick_hash=None, model_type=None, model_id=None): db = get_models_db() @@ -29,11 +34,12 @@ def get_model_info_from_db(quick_hash=None, model_type=None, model_id=None): m = db.get(model_type, {}) return m.get(model_id) + def rebuild_index(): global index db = get_models_db() index = {} for _, m in db.items(): - module_index = {info.get('quick_hash'): info for _, info in m.items()} + module_index = {info.get("quick_hash"): info for _, info in m.items()} index.update(module_index) diff --git a/sdkit/models/scan_models.py b/sdkit/models/scan_models.py index 01f3462..349de91 100644 --- a/sdkit/models/scan_models.py +++ b/sdkit/models/scan_models.py @@ -1,7 +1,8 @@ import picklescan.scanner + def scan_model(file_path): - ''' + """ Uses `picklescan.scanner.scan_file_path()` to scan and return the results. 
- ''' + """ return picklescan.scanner.scan_file_path(file_path) diff --git a/sdkit/train/merge_models.py b/sdkit/train/merge_models.py index aa6afa8..3b115a8 100644 --- a/sdkit/train/merge_models.py +++ b/sdkit/train/merge_models.py @@ -1,61 +1,62 @@ # loosely inspired by https://github.com/lodimasq/batch-checkpoint-merger/blob/master/batch_checkpoint_merger/main.py#L71 -from sdkit.utils import load_tensor_file, save_tensor_file, log +from sdkit.utils import load_tensor_file, log, save_tensor_file + def merge_models(model0_path: str, model1_path: str, ratio: float, out_path: str, use_fp16=True): - ''' + """ Merges (using weighted sum) and writes to the `out_path`. * model0, model1 - the first and second model files to be merged * ratio - the ratio of the second model. 1 means only the second model will be used. 0 means only the first model will be used. - ''' + """ - log.info(f'[cyan]Merge models:[/cyan] Merging {model0_path} and {model1_path}, ratio {ratio}') + log.info(f"[cyan]Merge models:[/cyan] Merging {model0_path} and {model1_path}, ratio {ratio}") merged = merge_two_models(model0_path, model1_path, ratio, use_fp16) - log.info(f'[cyan]Merge models:[/cyan] ... saving as {out_path}') + log.info(f"[cyan]Merge models:[/cyan] ... saving as {out_path}") if out_path.lower().endswith(".safetensors"): # elldrethSOg4060Mix_v10.ckpt contains a state_dict key among all the tensors, but safetensors # assumes that all entries are tensors, not dicts => remove the key - if 'state_dict' in merged: - del merged['state_dict'] + if "state_dict" in merged: + del merged["state_dict"] save_tensor_file(merged, out_path) else: - save_tensor_file({'state_dict':merged}, out_path) + save_tensor_file({"state_dict": merged}, out_path) + # do this pair-wise, to avoid having to load all the models into memory def merge_two_models(model0, model1, alpha, use_fp16=True): - ''' + """ Returns a tensor containing the merged model. Uses weighted-sum. * model0, model1 - the first and second model files to be merged * alpha - a float between [0, 1]. 0 means only model0 will be used, 1 means only model1. - + If model0 is a tensor, then model0 will be over-written with the merged data, and the same model0 reference will be returned. 
- ''' + """ model0_file = load_tensor_file(model0) if isinstance(model0, str) else model0 model1_file = load_tensor_file(model1) if isinstance(model1, str) else model1 - model0 = model0_file['state_dict'] if 'state_dict' in model0_file else model0_file - model1 = model1_file['state_dict'] if 'state_dict' in model1_file else model1_file + model0 = model0_file["state_dict"] if "state_dict" in model0_file else model0_file + model1 = model1_file["state_dict"] if "state_dict" in model1_file else model1_file # common weights for key in model0.keys(): - if 'model' in key and key in model1: + if "model" in key and key in model1: model0[key] = (1 - alpha) * model0[key] + alpha * model1[key] for key in model1.keys(): - if 'model' in key and key not in model0: + if "model" in key and key not in model0: model0[key] = model1[key] - if use_fp16: for key, val in model0.items(): - if 'model' in key: + if "model" in key: model0[key] = val.half() # unload model1 from memory diff --git a/sdkit/utils/__init__.py b/sdkit/utils/__init__.py index 1b0e6cb..6f9d08f 100644 --- a/sdkit/utils/__init__.py +++ b/sdkit/utils/__init__.py @@ -1,53 +1,33 @@ import logging -log = logging.getLogger('sdkit') -LOG_FORMAT = '%(asctime)s.%(msecs)03d %(levelname)s %(threadName)s %(message)s' -logging.basicConfig( - level=logging.INFO, - format=LOG_FORMAT, - datefmt="%X" -) - -from .file_utils import ( - load_tensor_file, - save_tensor_file, - save_images, - save_dicts, -) - -from .hash_utils import ( - hash_bytes, - hash_url_quick, - hash_file_quick, -) +log = logging.getLogger("sdkit") +LOG_FORMAT = "%(asctime)s.%(msecs)03d %(levelname)s %(threadName)s %(message)s" +logging.basicConfig(level=logging.INFO, format=LOG_FORMAT, datefmt="%X") +from .file_utils import load_tensor_file, save_dicts, save_images, save_tensor_file +from .hash_utils import hash_bytes, hash_file_quick, hash_url_quick +from .http_utils import download_file from .image_utils import ( - img_to_base64_str, - img_to_buffer, - buffer_to_base64_str, + apply_color_profile, base64_str_to_buffer, base64_str_to_img, + buffer_to_base64_str, + img_to_base64_str, + img_to_buffer, resize_img, - apply_color_profile, ) - from .latent_utils import ( - img_to_tensor, get_image_latent_and_mask, + img_to_tensor, latent_samples_to_images, ) - from .memory_utils import ( gc, get_device_usage, + get_object_id, + get_tensors_in_memory, print_largest_tensors_in_memory, print_tensor_info, - get_object_id, record_tensor_name, - get_tensors_in_memory, take_memory_snapshot, ) - -from .http_utils import ( - download_file, -) \ No newline at end of file diff --git a/sdkit/utils/file_utils.py b/sdkit/utils/file_utils.py index f2e62c9..d02de93 100644 --- a/sdkit/utils/file_utils.py +++ b/sdkit/utils/file_utils.py @@ -1,26 +1,30 @@ -import os import json -import torch -import safetensors.torch +import os + import piexif import piexif.helper -from PIL import Image, PngImagePlugin +import safetensors.torch +import torch +from PIL import Image from PIL.PngImagePlugin import PngInfo + def load_tensor_file(path): if path.lower().endswith(".safetensors"): return safetensors.torch.load_file(path, device="cpu") else: return torch.load(path, map_location="cpu") + def save_tensor_file(data, path): if path.lower().endswith(".safetensors"): return safetensors.torch.save_file(data, path, metadata={"format": "pt"}) else: return torch.save(data, path) -def save_images(images: list, dir_path: str, file_name='image', output_format='JPEG', output_quality=75): - ''' + +def save_images(images: list, dir_path: 
str, file_name="image", output_format="JPEG", output_quality=75): + """ * images: a list of PIL.Image images to save * dir_path: the directory path where the images will be saved * file_name: the file name to save. Can be a string or a function. @@ -29,17 +33,19 @@ def save_images(images: list, dir_path: str, file_name='image', output_format='J and the returned value will be used as the actual file name. e.g. `def fn(i): return f"foo{i}"` * output_format: 'JPEG' or 'PNG' * output_quality: an integer between 0 and 100, used for JPEG - ''' - if dir_path is None: return + """ + if dir_path is None: + return os.makedirs(dir_path, exist_ok=True) for i, img in enumerate(images): - actual_file_name = file_name(i) if callable(file_name) else f'{file_name}_{i}' + actual_file_name = file_name(i) if callable(file_name) else f"{file_name}_{i}" path = os.path.join(dir_path, actual_file_name) - img.save(f'{path}.{output_format.lower()}', quality=output_quality) + img.save(f"{path}.{output_format.lower()}", quality=output_quality) + -def save_dicts(entries: list, dir_path: str, file_name='data', output_format='txt', file_format=''): - ''' +def save_dicts(entries: list, dir_path: str, file_name="data", output_format="txt", file_format=""): + """ * entries: a list of dictionaries * dir_path: the directory path where the files will be saved * file_name: the file name to save. Can be a string or a function. @@ -48,30 +54,33 @@ def save_dicts(entries: list, dir_path: str, file_name='data', output_format='tx and the returned value will be used as the actual file name. e.g. `def fn(i): return f"foo{i}"` * output_format: 'txt', 'json', or 'embed' if 'embed', the metadata will be embedded in PNG files in tEXt chunks, and as EXIF UserComment for JPEG files - ''' - if dir_path is None: return + """ + if dir_path is None: + return os.makedirs(dir_path, exist_ok=True) for i, metadata in enumerate(entries): - actual_file_name = file_name(i) if callable(file_name) else f'{file_name}_{i}' + actual_file_name = file_name(i) if callable(file_name) else f"{file_name}_{i}" path = os.path.join(dir_path, actual_file_name) - if output_format.lower() == 'embed' and file_format.lower() == 'png': - targetImage = Image.open(f'{path}.{file_format.lower()}') + if output_format.lower() == "embed" and file_format.lower() == "png": + targetImage = Image.open(f"{path}.{file_format.lower()}") embedded_metadata = PngInfo() for key, val in metadata.items(): embedded_metadata.add_text(key, str(val)) - targetImage.save(f'{path}.{file_format.lower()}', pnginfo=embedded_metadata) - elif output_format.lower() == 'embed' and file_format.lower() == 'jpeg': - targetImage = Image.open(f'{path}.{file_format.lower()}') + targetImage.save(f"{path}.{file_format.lower()}", pnginfo=embedded_metadata) + elif output_format.lower() == "embed" and file_format.lower() == "jpeg": + targetImage = Image.open(f"{path}.{file_format.lower()}") user_comment = json.dumps(metadata) - exif_dict = {'Exif': { piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(user_comment, encoding="unicode")}} + exif_dict = { + "Exif": {piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(user_comment, encoding="unicode")} + } exif_bytes = piexif.dump(exif_dict) - targetImage.save(f'{path}.{file_format.lower()}', exif=exif_bytes) + targetImage.save(f"{path}.{file_format.lower()}", exif=exif_bytes) else: - with open(f'{path}.{output_format.lower()}', 'w', encoding='utf-8') as f: + with open(f"{path}.{output_format.lower()}", "w", 
encoding="utf-8") as f: + if output_format.lower() == "txt": for key, val in metadata.items(): - f.write(f'{key}: {val}\n') - elif output_format.lower() == 'json': + f.write(f"{key}: {val}\n") + elif output_format.lower() == "json": json.dump(metadata, f, indent=2) diff --git a/sdkit/utils/hash_utils.py b/sdkit/utils/hash_utils.py index 86a41e3..0bc3c8e 100644 --- a/sdkit/utils/hash_utils.py +++ b/sdkit/utils/hash_utils.py @@ -1,20 +1,23 @@ -import os import hashlib +import os + import requests + def hash_url_quick(url): from sdkit.utils import log - log.debug(f'hashing url: {url}') + + log.debug(f"hashing url: {url}") def get_size(): res = requests.get(url, stream=True) - size = int(res.headers['content-length']) # fail loudly if the url doesn't return a content-length header - log.debug(f'total size: {size}') + size = int(res.headers["content-length"]) # fail loudly if the url doesn't return a content-length header + log.debug(f"total size: {size}") return size def read_bytes(offset: int, count: int): res = requests.get(url, headers={"Range": f"bytes={offset}-{offset+count-1}"}) - log.debug(f'read byte range. offset: {offset}, count: {count}, actual count: {len(res.content)}') + log.debug(f"read byte range. offset: {offset}, count: {count}, actual count: {len(res.content)}") return res.content return compute_quick_hash( @@ -22,20 +25,22 @@ def read_bytes(offset: int, count: int): read_bytes_fn=read_bytes, ) + def hash_file_quick(file_path): from sdkit.utils import log - log.debug(f'hashing file: {file_path}') + + log.debug(f"hashing file: {file_path}") def get_size(): size = os.path.getsize(file_path) - log.debug(f'total size: {size}') + log.debug(f"total size: {size}") return size def read_bytes(offset: int, count: int): - with open(file_path, 'rb') as f: + with open(file_path, "rb") as f: f.seek(offset) bytes = f.read(count) - log.debug(f'read byte range. offset: {offset}, count: {count}, actual count: {len(bytes)}') + log.debug(f"read byte range. offset: {offset}, count: {count}, actual count: {len(bytes)}") return bytes return compute_quick_hash( @@ -43,8 +48,9 @@ def read_bytes(offset: int, count: int): read_bytes_fn=read_bytes, ) + def compute_quick_hash(total_size_fn, read_bytes_fn): - ''' + """ quick-hash logic: - read 64k chunks from the start, middle and end, and hash them - start offset: 1 MB @@ -52,14 +58,15 @@ def compute_quick_hash(total_size_fn, read_bytes_fn): - end offset: -1 MB Do not use if the file size is less than 3 MB - ''' + """ total_size = total_size_fn() start_bytes = read_bytes_fn(offset=0x100000, count=0x10000) - middle_bytes = read_bytes_fn(offset=int(total_size/2), count=0x10000) + middle_bytes = read_bytes_fn(offset=int(total_size / 2), count=0x10000) end_bytes = read_bytes_fn(offset=total_size - 0x100000, count=0x10000) return hash_bytes(start_bytes + middle_bytes + end_bytes) + def hash_bytes(bytes): return hashlib.sha256(bytes).hexdigest() diff --git a/sdkit/utils/http_utils.py b/sdkit/utils/http_utils.py index ebcdf53..74c53cb 100644 --- a/sdkit/utils/http_utils.py +++ b/sdkit/utils/http_utils.py @@ -1,29 +1,35 @@ -import requests import os -from tqdm import tqdm from shutil import copyfileobj +import requests +from tqdm import tqdm + + def download_file(url: str, out_path: str): - ''' + """ Features: * Downloads large files (without storing them in memory) * Resumes downloads from the bytes it has downloaded already * Shows a progress bar The remote server needs to support the `Range` header, for resume to work. 
- ''' + """ from sdkit.utils import log start_offset = 0 if not os.path.exists(out_path) else os.path.getsize(out_path) res = requests.get(url, stream=True) - if not res.ok: return - total_bytes = int(res.headers.get('Content-Length', '0')) + if not res.ok: + return + total_bytes = int(res.headers.get("Content-Length", "0")) - res = requests.get(url, stream=True, headers={'Range': f'bytes={start_offset}-', 'Accept-Encoding': 'identity'}) - if not res.ok: return + res = requests.get(url, stream=True, headers={"Range": f"bytes={start_offset}-", "Accept-Encoding": "identity"}) + if not res.ok: + return - write_mode = 'wb' if start_offset == 0 else 'ab' + write_mode = "wb" if start_offset == 0 else "ab" - log.info(f'Downloading {url} to {out_path}') - with open(out_path, write_mode) as f, tqdm.wrapattr(res.raw, 'read', initial=start_offset, total=total_bytes, desc='Downloading', colour='green') as res_stream: + log.info(f"Downloading {url} to {out_path}") + with open(out_path, write_mode) as f, tqdm.wrapattr( + res.raw, "read", initial=start_offset, total=total_bytes, desc="Downloading", colour="green" + ) as res_stream: copyfileobj(res_stream, f) diff --git a/sdkit/utils/image_utils.py b/sdkit/utils/image_utils.py index deed5d2..00da245 100644 --- a/sdkit/utils/image_utils.py +++ b/sdkit/utils/image_utils.py @@ -1,15 +1,18 @@ -import numpy as np -import cv2 -from skimage import exposure import base64 from io import BytesIO + +import cv2 +import numpy as np from PIL import Image +from skimage import exposure + # https://stackoverflow.com/a/61114178 def img_to_base64_str(img, output_format="PNG", output_quality=75): buffered = img_to_buffer(img, output_format, output_quality=output_quality) return buffer_to_base64_str(buffered, output_format) + def img_to_buffer(img, output_format="PNG", output_quality=75): buffered = BytesIO() if output_format.upper() == "JPEG": @@ -19,6 +22,7 @@ def img_to_buffer(img, output_format="PNG", output_quality=75): buffered.seek(0) return buffered + def buffer_to_base64_str(buffered, output_format="PNG"): buffered.seek(0) img_byte = buffered.getvalue() @@ -26,18 +30,21 @@ def buffer_to_base64_str(buffered, output_format="PNG"): img_str = f"data:{mime_type};base64," + base64.b64encode(img_byte).decode() return img_str + def base64_str_to_buffer(img_str): mime_type = "image/png" if img_str.startswith("data:image/png;") else "image/jpeg" - img_str = img_str[len(f"data:{mime_type};base64,"):] + img_str = img_str[len(f"data:{mime_type};base64,") :] data = base64.b64decode(img_str) buffered = BytesIO(data) return buffered + def base64_str_to_img(img_str): buffered = base64_str_to_buffer(img_str) img = Image.open(buffered) return img + def resize_img(img: Image, desired_width, desired_height, clamp_to_64=False): w, h = img.size @@ -49,6 +56,7 @@ def resize_img(img: Image, desired_width, desired_height, clamp_to_64=False): return img.resize((w, h), resample=Image.Resampling.LANCZOS) + def apply_color_profile(orig_image: Image, image_to_modify: Image): reference = cv2.cvtColor(np.asarray(orig_image), cv2.COLOR_RGB2LAB) image_to_modify = cv2.cvtColor(np.asarray(image_to_modify), cv2.COLOR_RGB2LAB) diff --git a/sdkit/utils/latent_utils.py b/sdkit/utils/latent_utils.py index ac33926..2bc6e60 100644 --- a/sdkit/utils/latent_utils.py +++ b/sdkit/utils/latent_utils.py @@ -1,10 +1,11 @@ import numpy as np import torch +from einops import rearrange, repeat from PIL import Image, ImageOps -from einops import repeat, rearrange from sdkit import Context + def img_to_tensor(img: Image, 
batch_size, device, half_precision: bool, shift_range=False, unsqueeze=False): if img is None: return None @@ -12,7 +13,7 @@ def img_to_tensor(img: Image, batch_size, device, half_precision: bool, shift_ra img = np.array(img).astype(np.float32) / 255.0 img = img[None].transpose(0, 3, 1, 2) img = torch.from_numpy(img) - img = 2. * img - 1. if shift_range else img + img = 2.0 * img - 1.0 if shift_range else img img = img.to(device) if device != "cpu" and half_precision: @@ -21,39 +22,43 @@ def img_to_tensor(img: Image, batch_size, device, half_precision: bool, shift_ra if unsqueeze: img = img[0][0].unsqueeze(0).repeat(4, 1, 1).unsqueeze(0) - img = repeat(img, '1 ... -> b ...', b=batch_size) + img = repeat(img, "1 ... -> b ...", b=batch_size) return img + def get_image_latent_and_mask(context: Context, image: Image, mask: Image, desired_width, desired_height, batch_size): """ Assumes model is on the correct device """ from .image_utils import resize_img - if image is None or 'stable-diffusion' not in context.models: + + if image is None or "stable-diffusion" not in context.models: return None, None - model = context.models['stable-diffusion'] + model = context.models["stable-diffusion"] - image = image.convert('RGB') + image = image.convert("RGB") image = resize_img(image, desired_width, desired_height, clamp_to_64=True) image = img_to_tensor(image, batch_size, context.device, context.half_precision, shift_range=True) - image = model.get_first_stage_encoding(model.encode_first_stage(image)) # move to latent space + image = model.get_first_stage_encoding(model.encode_first_stage(image)) # move to latent space if mask is None: return image, None - mask = mask.convert('RGB') + mask = mask.convert("RGB") mask = resize_img(mask, image.shape[3], image.shape[2]) mask = ImageOps.invert(mask) mask = img_to_tensor(mask, batch_size, context.device, context.half_precision, unsqueeze=True) return image, mask + def latent_samples_to_images(context: Context, samples): - model = context.models['stable-diffusion'] + model = context.models["stable-diffusion"] - if context.half_precision and samples.dtype != torch.float16: samples = samples.half() + if context.half_precision and samples.dtype != torch.float16: + samples = samples.half() samples = model.decode_first_stage(samples) samples = torch.clamp((samples + 1.0) / 2.0, min=0.0, max=1.0) diff --git a/sdkit/utils/memory_utils.py b/sdkit/utils/memory_utils.py index acfa28f..49e0f27 100644 --- a/sdkit/utils/memory_utils.py +++ b/sdkit/utils/memory_utils.py @@ -1,27 +1,30 @@ -import torch -import psutil import base64 from functools import reduce from gc import collect, get_objects, get_referrers +import psutil +import torch + from sdkit import Context tensor_ids_snapshot = None recorded_tensor_names = {} + def gc(context: Context): collect() - if context.device == 'cpu': + if context.device == "cpu": return torch.cuda.empty_cache() torch.cuda.ipc_collect() + def get_device_usage(device, log_info=False): cpu_used = psutil.cpu_percent() ram_used, ram_total = psutil.virtual_memory().used, psutil.virtual_memory().total - vram_free, vram_total = torch.cuda.mem_get_info(device) if device != 'cpu' else (0, 0) - vram_used = (vram_total - vram_free) + vram_free, vram_total = torch.cuda.mem_get_info(device) if device != "cpu" else (0, 0) + vram_used = vram_total - vram_free ram_used /= 1024**3 ram_total /= 1024**3 @@ -30,95 +33,106 @@ def get_device_usage(device, log_info=False): if log_info: from sdkit.utils import log - msg = f'CPU utilization: {cpu_used:.1f}%, System 
RAM used: {ram_used:.1f} of {ram_total:.1f} GiB' - if device != 'cpu': msg += f', GPU RAM used ({device}): {vram_used:.1f} of {vram_total:.1f} GiB' + + msg = f"CPU utilization: {cpu_used:.1f}%, System RAM used: {ram_used:.1f} of {ram_total:.1f} GiB" + if device != "cpu": + msg += f", GPU RAM used ({device}): {vram_used:.1f} of {vram_total:.1f} GiB" log.info(msg) return cpu_used, ram_used, ram_total, vram_used, vram_total + def get_object_id(o): - ''' + """ Returns a more-readable object id than the long number returned by the inbuilt `id()` function. Internally, this calls `id()` and converts the number to a base64 string. - ''' + """ obj_id = id(o) - obj_id = base64.b64encode(obj_id.to_bytes(8, 'big')).decode() - return obj_id.translate({43:None, 47:None, 61:None})[-8:] + obj_id = base64.b64encode(obj_id.to_bytes(8, "big")).decode() + return obj_id.translate({43: None, 47: None, 61: None})[-8:] -def record_tensor_name(t, name='t', log_info=False): - ''' + +def record_tensor_name(t, name="t", log_info=False): + """ Records a name for the given tensor object. Helpful while investigating the source of memory leaks. For example, you can record variables from across the codebase, and see which one is leaking by calling `print_largest_tensors_in_memory()` or `take_memory_snapshot()` after calling `gc()` (for garbage-collection). - + `print_largest_tensors_in_memory()` and `take_memory_snapshot()` print the recorded names for each tensor, if available. - ''' + """ obj_id = get_object_id(t) recorded_tensor_names[obj_id] = [] if obj_id not in recorded_tensor_names else recorded_tensor_names[obj_id] recorded_tensor_names[obj_id].append(name) - if log_info: print_tensor_info(t, name) + if log_info: + print_tensor_info(t, name) -def print_tensor_info(t, name='t'): - 'Prints a summary of the tensor, for e.g. its size, shape, data type, device etc.' + +def print_tensor_info(t, name="t"): + "Prints a summary of the tensor, e.g. its size, shape, data type, device, etc." from sdkit.utils import log obj_id = get_object_id(t) - obj_size = t.nelement() * t.element_size() / 1024**2 # MiB - log.info(f' {name} id: {obj_id}, size: {obj_size} MiB, shape: {t.shape}, requires_grad: {t.requires_grad}, type: {t.dtype}, device: {t.device}') + obj_size = t.nelement() * t.element_size() / 1024**2 # MiB + log.info( + f" {name} id: {obj_id}, size: {obj_size} MiB, shape: {t.shape}, requires_grad: {t.requires_grad}, type: {t.dtype}, device: {t.device}" + ) + def get_tensors_in_memory(device): - ''' + """ Returns the list of all the tensor objects in memory on the given device. **Warning: Do not keep a reference to the returned list longer than necessary, since that will prevent garbage-collection of all the tensors in memory.** - ''' + """ device = torch.device(device) if isinstance(device, str) else device tensors = [] objs_in_mem = get_objects() for obj in objs_in_mem: try: - if (torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data))) and obj.device == device: + if (torch.is_tensor(obj) or (hasattr(obj, "data") and torch.is_tensor(obj.data))) and obj.device == device: tensors.append(obj) except: pass return tensors + def print_largest_tensors_in_memory(device, num=10): - ''' + """ Prints a list of the largest tensors on the given device. Choose the number of objects displayed with the `num` argument. Prints the recorded names for each tensor, if recorded using `record_tensor_name()`. See also: `take_memory_snapshot()` which is probably more useful for investigating memory leaks. 
- ''' + """ entries, total_mem = _get_tensor_entries(device) n = len(entries) - entries = entries[:min(num, n)] + entries = entries[: min(num, n)] - print(f'== {num} largest tensors on {device} ==') + print(f"== {num} largest tensors on {device} ==") print(_fmt_tensors_summary(entries)) - print('---') - print(f'Total memory occupied on {device} by {n} tensors: {total_mem:.1f} MiB') + print("---") + print(f"Total memory occupied on {device} by {n} tensors: {total_mem:.1f} MiB") + def take_memory_snapshot(device, print_snapshot=True): - ''' + """ Records and prints a list of new tensors (in the device) since the last snapshot (created by calling `take_memory_snapshot()`). Prints the recorded names for each tensor, if recorded using `record_tensor_name()`. See also: `print_largest_tensors_in_memory()`. - ''' + """ global tensor_ids_snapshot - is_first_snapshot = (tensor_ids_snapshot is None) + is_first_snapshot = tensor_ids_snapshot is None tensor_ids_snapshot = set() if is_first_snapshot else tensor_ids_snapshot # take the snapshot @@ -132,22 +146,23 @@ def take_memory_snapshot(device, print_snapshot=True): return new_tensor_entries = [entry for entry in entries if entry[0] in new_tensor_ids] - new_tensors_total_mem = reduce(lambda sum, entry: sum + entry[1], new_tensor_entries, 0) # MiB + new_tensors_total_mem = reduce(lambda sum, entry: sum + entry[1], new_tensor_entries, 0) # MiB print(new_tensors_total_mem) num_new_tensors = len(new_tensor_ids) - print(f'== {num_new_tensors} new tensors this snapshot on {device} ==') + print(f"== {num_new_tensors} new tensors this snapshot on {device} ==") print(_fmt_tensors_summary(new_tensor_entries)) - print('---') - print(f'Total memory occupied on {device} by {len(entries)} tensors: {total_mem:.1f} MiB') - print(f'{num_new_tensors} new tensors added {new_tensors_total_mem:.1f} MiB this frame') + print("---") + print(f"Total memory occupied on {device} by {len(entries)} tensors: {total_mem:.1f} MiB") + print(f"{num_new_tensors} new tensors added {new_tensors_total_mem:.1f} MiB this frame") + def _get_tensor_entries(device, sorted_by_size=True): entries = [] tensors = get_tensors_in_memory(device) total_mem = 0 for t in tensors: - size = t.nelement() * t.element_size() / 1024**2 # MiB + size = t.nelement() * t.element_size() / 1024**2 # MiB obj_id = get_object_id(t) entry = [obj_id, size, t.shape, len(get_referrers(t)), t.requires_grad, t.dtype] entries.append(entry) @@ -155,15 +170,19 @@ def _get_tensor_entries(device, sorted_by_size=True): del tensors - if sorted_by_size: entries.sort(key=lambda x: x[0], reverse=True) + if sorted_by_size: + entries.sort(key=lambda x: x[0], reverse=True) return entries, total_mem + def _fmt_tensors_summary(entries): summary = [] for i, o in enumerate(entries): obj_id, size, shape, n_referrers, requires_grad, dtype = o - known_names = f' ({recorded_tensor_names[obj_id]})' if obj_id in recorded_tensor_names else '' - summary.append(f'{i+1}. Id: {obj_id}{known_names}, Size: {size:.1f} MiB, Shape: {shape}, Referrers: {n_referrers}, requires_grad: {requires_grad}, dtype: {dtype}') + known_names = f" ({recorded_tensor_names[obj_id]})" if obj_id in recorded_tensor_names else "" + summary.append( + f"{i+1}. 
Id: {obj_id}{known_names}, Size: {size:.1f} MiB, Shape: {shape}, Referrers: {n_referrers}, requires_grad: {requires_grad}, dtype: {dtype}" + ) - return '\n'.join(summary) + return "\n".join(summary) diff --git a/tests/vram_frees_after_image_generation.py b/tests/vram_frees_after_image_generation.py index 3ca7ce1..c5d90bc 100644 --- a/tests/vram_frees_after_image_generation.py +++ b/tests/vram_frees_after_image_generation.py @@ -1,27 +1,29 @@ -import os import argparse +import os from collections import namedtuple parser = argparse.ArgumentParser() -parser.add_argument('--models-dir', type=str, required=True, help="Path to the directory containing the Stable Diffusion models") +parser.add_argument( + "--models-dir", type=str, required=True, help="Path to the directory containing the Stable Diffusion models" +) args = parser.parse_args() from sdkit import Context -from sdkit.models import load_model from sdkit.generate import generate_images -from sdkit.utils import log, get_device_usage +from sdkit.models import load_model +from sdkit.utils import get_device_usage, log -DeviceUsage = namedtuple('DeviceUsage', ['cpu_used', 'ram_used', 'ram_total', 'vram_used', 'vram_total']) +DeviceUsage = namedtuple("DeviceUsage", ["cpu_used", "ram_used", "ram_total", "vram_used", "vram_total"]) c = Context() -log.info('Starting..') +log.info("Starting..") usage_start = DeviceUsage(*get_device_usage(c.device, log_info=True)) -c.model_paths['stable-diffusion'] = os.path.join(args.models_dir, 'sd-v1-4.ckpt') -load_model(c, 'stable-diffusion') +c.model_paths["stable-diffusion"] = os.path.join(args.models_dir, "sd-v1-4.ckpt") +load_model(c, "stable-diffusion") -log.info('Loaded the model..') +log.info("Loaded the model..") usage_model_load = DeviceUsage(*get_device_usage(c.device, log_info=True)) try: @@ -29,16 +31,20 @@ except Exception as e: log.exception(e) -log.info('Generated the image..') +log.info("Generated the image..") usage_after_render = DeviceUsage(*get_device_usage(c.device, log_info=True)) -print('') -log.info(f'VRAM trend: {usage_start.vram_used:.1f} (start) GiB to {usage_model_load.vram_used:.1f} GiB (before render) to {usage_after_render.vram_used:.1f} GiB (after render)') -print('') +print("") +log.info( + f"VRAM trend: {usage_start.vram_used:.1f} GiB (start) to {usage_model_load.vram_used:.1f} GiB (before render) to {usage_after_render.vram_used:.1f} GiB (after render)" +) +print("") max_expected_vram = usage_model_load.vram_used + 0.3 if usage_after_render.vram_used > max_expected_vram: - log.error(f'Test failed! VRAM after render was expected to be below {max_expected_vram:.1f} GiB, but was {usage_after_render.vram_used:.1f} GiB!') + log.error( + f"Test failed! VRAM after render was expected to be below {max_expected_vram:.1f} GiB, but was {usage_after_render.vram_used:.1f} GiB!" + ) exit(1) else: - log.info('Test passed!') + log.info("Test passed!")
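
Two short sketches follow for reviewers skimming the diff; they are illustrative only and not part of the patch. First, the quick-hash scheme that sdkit/utils/hash_utils.py implements (64 KiB chunks read at +1 MiB, the midpoint, and 1 MiB before EOF, hashed together with SHA-256) is easier to picture standalone. This is a minimal sketch assuming a local file; the `quick_hash` helper name and the example path are hypothetical, and the real code abstracts the reads behind `total_size_fn`/`read_bytes_fn` so the same logic also runs over HTTP Range requests:

    import hashlib
    import os

    def quick_hash(file_path: str) -> str:
        # mirrors compute_quick_hash(): three 64 KiB (0x10000) chunks
        total_size = os.path.getsize(file_path)
        if total_size < 3 * 0x100000:
            raise ValueError("quick-hash is not meaningful for files under 3 MB")

        def read_chunk(offset: int, count: int = 0x10000) -> bytes:
            with open(file_path, "rb") as f:
                f.seek(offset)
                return f.read(count)

        start = read_chunk(0x100000)             # 64 KiB starting at +1 MiB
        middle = read_chunk(total_size // 2)     # 64 KiB from the midpoint
        end = read_chunk(total_size - 0x100000)  # 64 KiB starting 1 MiB before EOF
        return hashlib.sha256(start + middle + end).hexdigest()

    # usage (hypothetical path): print(quick_hash("models/sd-v1-4.ckpt"))

Second, most of the optimizations.py hunk is line-wrapping around one idea: keep the model's sub-modules on the CPU and use torch's forward-pre-hooks to pull exactly one module onto the GPU just before it runs. A toy sketch of that pattern, under the assumption of a few stand-in Linear blocks (the real code also patches encode/decode/forward and tracks the resident module on the Context):

    import torch

    module_in_gpu = None  # the single module currently resident on the GPU
    device = "cuda" if torch.cuda.is_available() else "cpu"

    def move_to_gpu(module, _inputs):
        # forward_pre_hook: runs just before module.forward()
        global module_in_gpu
        if module_in_gpu is module:
            return
        if module_in_gpu is not None:
            module_in_gpu.to("cpu")  # evict the previous resident first
        module.to(device)
        module_in_gpu = module

    blocks = torch.nn.ModuleList(torch.nn.Linear(8, 8) for _ in range(3))
    for block in blocks:
        block.register_forward_pre_hook(move_to_gpu)

    x = torch.randn(1, 8, device=device)
    for block in blocks:
        x = block(x)  # each block occupies the GPU only for its own step

The trade-off is the one the module comments describe: peak VRAM drops to roughly the largest single module, at the cost of a CPU-GPU transfer per hooked module per forward pass.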