Skip to content

Commit

Permalink
add training code
Browse files Browse the repository at this point in the history
  • Loading branch information
kliyer-ai committed Oct 2, 2024
1 parent 554bc1b commit 2708246
Show file tree
Hide file tree
Showing 10 changed files with 536 additions and 89 deletions.
41 changes: 41 additions & 0 deletions configs/experiment/sample_struct_attn.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# @package _global_

defaults:
- /[email protected]: struct_attn
- override /lora/[email protected]: midas # hed
- override /model: sd15
- override /data: local
- _self_


size: 512
n_samples: 4

save_grid: true
log_cond: true

data:
caption_from_name: true
caption_prefix: "a picture of "
directories:
- data

model:
guidance_scale: 7.5

prompt: ''

lora:
struct:
cfg: false
# ckpt_path: checkpoints/sd15-hed-128-only-res
ckpt_path: checkpoints/sd15-depth-02-self
config:

c_dim: 128
rank: 0.2
adaption_mode: only_self

tag: struct

bf16: true
38 changes: 38 additions & 0 deletions configs/experiment/train_struct_sd15.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# @package _global_

defaults:
- /[email protected]: struct
- override /lora/[email protected]: midas
- override /model: sd15
- override /data: local
- _self_

data:
batch_size: 8
caption_from_name: true
caption_prefix: "a picture of "
directories:
- data

lora:
struct:
optimize: true


size: 512

log_c: true

val_batches: 4

learning_rate: 1e-4

ckpt_steps: 3000
val_steps: 3000

epochs: 10

prompt: null

# model:
# guidance_scale: 1.5
37 changes: 37 additions & 0 deletions configs/experiment/train_style_sd15.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# @package _global_

defaults:
- /[email protected]: style
- override /model: sd15
- override /data: local
- _self_

data:
batch_size: 8
caption_from_name: true
caption_prefix: "a picture of "
directories:
- data

val_batches: 1

lora:
style:
# rank: 208
# rank: 16
adaption_mode: only_cross
optimize: true

size: 512

learning_rate: 1e-4

ckpt_steps: 1000
val_steps: 1000

epochs: 100

prompt: null

# model:
# guidance_scale: 1.5
2 changes: 0 additions & 2 deletions configs/model/sd15.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,4 @@ defaults:
_target_: src.model.SD15
pipeline_type: diffusers.StableDiffusionPipeline
model_name: runwayml/stable-diffusion-v1-5
use_embeds: ${use_embeds}
dtype: fp32
local_files_only: ${local_files_only}
1 change: 0 additions & 1 deletion configs/model/sdxl.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,4 @@ defaults:
_target_: src.model.SDXL
pipeline_type: diffusers.StableDiffusionXLPipeline
model_name: stabilityai/stable-diffusion-xl-base-1.0
use_embeds: ${use_embeds}
local_files_only: ${local_files_only}
3 changes: 1 addition & 2 deletions configs/sample.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,4 @@ hydra:
job:
chdir: true

local_files_only: false
use_embeds: false
local_files_only: false
37 changes: 37 additions & 0 deletions configs/train.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
defaults:
- data: ???
- model: ???
- _self_
- experiment: null

size: ???
max_train_steps: null
epochs: 20
learning_rate: 1e-4

lr_warmup_steps: 0
lr_scheduler: constant

prompt: null
gradient_accumulation_steps: 1

ckpt_steps: 1000
val_steps: 1000
val_images: 4
seed: 42
n_samples: 4



tag: ''

local_files_only: false

hydra:
run:
dir: outputs/train/${tag}/runs/${now:%Y-%m-%d}/${now:%H-%M-%S}
sweep:
dir: outputs/train/${tag}/multiruns/${now:%Y-%m-%d}/${now:%H-%M-%S}
subdir: ${hydra.job.num}
job:
chdir: true
53 changes: 42 additions & 11 deletions src/data/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,22 +66,53 @@ def __getitem__(self, idx: int):


class ImageDataModule:
def __init__(self, directories: list[str], transform: list, batch_size: int = 32, caption_from_name: bool = False, caption_prefix: str = ""):
def __init__(
self,
directories: list[str],
transform: list,
val_directories: list[str] = [],
batch_size: int = 32,
val_batch_size: int = 1,
workers: int = 4,
val_workers: int = 1,
caption_from_name: bool = False,
caption_prefix: str = "",
):
super().__init__()

self.batch_size = batch_size
self.val_batch_size = val_batch_size
self.workers = workers
self.val_workers = val_workers

project_root = Path(os.path.abspath(__file__)).parent.parent.parent

self.train_dataset = ZipDataset(
[
ImageFolderDataset(
directory=Path(project_root, d),
transform=transforms.Compose(transform),
caption_from_name=caption_from_name,
caption_prefix=caption_prefix,
)
for d in directories
]
)

self.val_dataset = ZipDataset(
[ImageFolderDataset(directory=Path(project_root, d), transform=transforms.Compose(transform), caption_from_name=caption_from_name, caption_prefix=caption_prefix) for d in directories]
[
ImageFolderDataset(
directory=Path(project_root, d),
transform=transforms.Compose(transform),
caption_from_name=caption_from_name,
caption_prefix=caption_prefix,
)
for d in val_directories
]
)
self.batch_size = batch_size

def train_dataloader(self):
raise Exception("Not implemented")
return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.workers)

def val_dataloader(self):
return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False)

def test_dataloader(self):
return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False)

def predict_dataloader(self):
raise Exception("Not implemented")
return DataLoader(self.val_dataset, batch_size=self.val_batch_size, shuffle=False, num_workers=self.val_workers)
Loading

0 comments on commit 2708246

Please sign in to comment.