-
Notifications
You must be signed in to change notification settings - Fork 1
/
repxl.py
92 lines (73 loc) · 3.5 KB
/
repxl.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
#
# Stable Diffusion Shell
# Tool to automate SD workflow. For now just re-sizes images.
#
# (c) in 2022 by Guido Appenzeller
import sys
import os
import click
import tqdm
from PIL import Image
from dotenv import load_dotenv, find_dotenv
from imgtools import prep_images
from finetune import train_model
def get_tmp_root(tmproot):
    """Return the root temporary directory, creating it if necessary.

    Args:
        tmproot: Path of the directory to use. When ``None`` (or an empty
            string) the directory ``tmp`` inside the current working
            directory is used instead.

    Returns:
        The path of the (now existing) root temporary directory.

    Raises:
        FileExistsError: If the path exists but is not a directory.
        OSError: If the directory cannot be created.
    """
    if not tmproot:
        tmproot = os.path.join(os.getcwd(), "tmp")
    # exist_ok=True replaces the separate exists() check, so two concurrent
    # invocations cannot race between the check and the creation.
    os.makedirs(tmproot, exist_ok=True)
    return tmproot
def get_tmp_dir(tmproot, name):
    """Create a fresh sub-directory ``name`` below the temporary root.

    Args:
        tmproot: Root temporary directory, resolved through
            ``get_tmp_root`` (so ``None`` selects the default root).
        name: Name of the sub-directory to create.

    Returns:
        The path of the newly created directory, or ``None`` when a
        directory (or file) with that name already exists. In that case a
        message is printed and nothing is created or modified.
    """
    tmp_dir = os.path.join(get_tmp_root(tmproot), name)
    try:
        # Attempt the creation directly instead of checking first; this
        # avoids a race between the existence check and makedirs().
        os.makedirs(tmp_dir)
    except FileExistsError:
        # The caller in this file passes the value of '--token' as `name`,
        # so that is the option the user has to change.
        print(f"Directory {tmp_dir} already exists. Delete or pick a different name with '--token <name>'")
        return None
    return tmp_dir
@click.group()
def repxl():
    """Command line tool to run SDXL on Replicate."""
    # Group callback: intentionally empty. The docstring above is shown by
    # click as the top-level help text; the sub-commands do the work.
@repxl.command()
@click.argument('prompt', type=click.STRING)
@click.option('--model', type=click.STRING, help="Name of the model to use in the format 'username/modelname'")
def render(prompt, model):
    """Render a new image from a trained model."""
    # Stub: for now this only echoes the prompt back to stdout. The
    # `model` option is accepted but not used yet.
    print(prompt)
@repxl.command()
@click.argument('srcdir', type=click.STRING)
@click.option('--tmpdir', type=click.STRING, default=None, help="Temporary directory for all training, default is ./tmp")
@click.option('--token', type=click.STRING, default="TOK", help="Token name we use for the training run, default is 'TOK'")
def prepare(srcdir, tmpdir, token):
    """Prepare images for fine-tuning (crop/convert/zip).
    SRCDIR: directory with images to prepare, must be .png./.jpeg/.jpg format.
    By default temporary directory is created in current directory."""
    # Function-scope import keeps this command self-contained; subprocess is
    # only needed here, to invoke the external `zip` tool.
    import subprocess

    tmp_root = get_tmp_root(tmpdir)
    work_dir = get_tmp_dir(tmp_root, token)
    if work_dir is None:
        # get_tmp_dir() has already printed why it refused (the directory
        # exists). Stop here instead of handing None to prep_images/zip.
        raise SystemExit(1)

    # NOTE(review): 1024, 1024 is presumably the target width/height and
    # iname the suffix of the written images -- confirm against imgtools.
    prep_images(srcdir, work_dir, 1024, 1024, iname=".src.jpg")

    # Create <tmp_root>/<token>.zip from the prepared directory. Arguments are
    # passed as a list (no shell), so paths containing spaces or shell
    # metacharacters are handled safely.
    archive = os.path.join(tmp_root, f"{token}.zip")
    try:
        result = subprocess.run(["zip", "-j", "-r", archive, work_dir], check=False)
    except FileNotFoundError as err:
        raise click.ClickException("the 'zip' command was not found on PATH") from err
    if result.returncode != 0:
        # Previously a failing zip went unnoticed; surface it to the user.
        raise click.ClickException(f"zip failed with exit status {result.returncode}")
@repxl.command()
@click.option('--token', type=click.STRING, default="TOK", help="Token name we use for the training run, default is 'TOK'")
@click.option('--tmpdir', type=click.STRING, default=None, help="Temporary directory for all training, default is ./tmp")
@click.option('--face-detection/--no-face-detection', default=True, help="Use face detection instead of masktarget, default is True.")
@click.option('--masktarget', type=click.STRING, default=None, help="Mask target for training, default is None. Disables face detection.")
@click.option('--captionprefix', type=click.STRING, default="a photo of", help="Prefix before the token, default is a 'a photo of'.")
@click.option('--dreambooth/--lora', default=False, help="Use dreambooth instead of the default LoRA")
@click.option('--model', type=click.STRING, default=None, help="Name of the model to push to, default is [token]-[lora/dreambooth]")
def train(token, tmpdir, face_detection, masktarget, captionprefix, dreambooth, model):
    """Fine-tune SDXL on Replicate.
    Training progress can be viewed on https://replicate.com/trainings"""
    tmp_root = get_tmp_root(tmpdir)

    if model is None:
        # Default model name: "<token>-dreambooth" or "<token>-lora".
        method = "dreambooth" if dreambooth else "lora"
        model = f"{token}-{method}"

    # The --masktarget help text promises that giving a mask target disables
    # face detection. Enforce that here so the documented contract holds
    # regardless of how train_model treats the combination of both settings.
    use_face_detection = face_detection and masktarget is None

    train_model(
        token,
        tmp_root,
        model,
        masktarget=masktarget,
        captionprefix=captionprefix,
        dreambooth=dreambooth,
        use_face_detection_instead=use_face_detection,
    )
# Register the sub-commands on the group. (The @repxl.command() decorators
# above normally register them already; adding them again by the same name
# is harmless.)
for _command in (render, prepare, train):
    repxl.add_command(_command)

if __name__ == '__main__':
    # Load settings from a .env file (if one is found) before running the CLI.
    load_dotenv(find_dotenv())
    repxl()