-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgenerate.py
68 lines (60 loc) · 2.01 KB
/
generate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import jax
import ml_collections as mlc
import optax
import char_diffusion as cd
from char_diffusion import utils
from char_diffusion import configs
from char_diffusion.diffusion import get_schedule
def generate(config: mlc.ConfigDict, num_samples: int = 8):
    """Restore a trained char-diffusion checkpoint and print samples from it.

    A ``cd.UNet1d`` and an Adam optimizer are built from ``config`` and their
    freshly initialised state is handed to ``utils.load_state_dict`` as the
    tree structure to restore into; the returned (restored) values replace
    them. ``cd.CharDiffusion.generate`` is then run and both the raw sample
    IDs and their ``utils.decode`` results are printed.

    Args:
        config: Run configuration. Reads ``checkpoint_path``, ``seed``, the
            ``model.*`` fields (base_channels, bit_width, num_res_blocks,
            num_heads, num_steps, use_self_cond, schedule, seq_len,
            num_gen_steps, time_delta) and the ``optim.*`` fields (lr,
            adam_beta1, adam_beta2).
        num_samples: Number of sequences to sample. Defaults to 8.

    Returns:
        A list with one ``utils.decode`` result per generated sample.

    Raises:
        ValueError: If ``config.checkpoint_path`` is unset/empty, or if
            ``num_samples`` is less than 1.
    """
    # Explicit checks instead of ``assert`` (which ``python -O`` strips);
    # ``not`` also rejects the empty-string placeholder, not just None.
    if not config.checkpoint_path:
        raise ValueError("Must provide a checkpoint path to generate samples.")
    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1, got {num_samples}.")

    key = jax.random.PRNGKey(config.seed)
    net = cd.UNet1d(
        in_channels=1,
        model_channels=config.model.base_channels,
        key=key,
        bit_width=config.model.bit_width,
        num_res_blocks=config.model.num_res_blocks,
        num_heads=config.model.num_heads,
        num_groups=4,
        attn_resolutions=(False, False, True),
        channel_mult=(1, 2, 4),
    )
    optim = optax.adam(
        config.optim.lr,
        b1=config.optim.adam_beta1,
        b2=config.optim.adam_beta2,
        eps=1e-8,
    )
    # Initial optimizer/step state only provides the structure that the
    # checkpoint is loaded into; the values are overwritten below.
    optim_state = optim.init(net)
    step_state = 0
    net, optim_state, step_state = utils.load_state_dict(
        path=config.checkpoint_path,
        tree=(net, optim_state, step_state),
    )

    diffuser = cd.CharDiffusion(
        num_steps=config.model.num_steps,
        use_self_cond=config.model.use_self_cond,
        gamma_schedule=get_schedule(config.model.schedule),
        optim=optim,
    )

    # NOTE(review): ``key`` was already used to initialise ``net`` above.
    # That init is replaced by the checkpoint, and the split order is kept
    # as-is so a given seed reproduces the same samples as before.
    key, gen_key = jax.random.split(key)
    generation = diffuser.generate(
        net,
        shape=(num_samples, config.model.bit_width, config.model.seq_len),
        num_steps=config.model.num_gen_steps,
        bit_width=config.model.bit_width,
        key=gen_key,
        time_delta=config.model.time_delta,
    )

    # ``squeeze(1)`` assumes axis 1 of the output has size 1 — TODO confirm
    # against CharDiffusion.generate. ``jax.device_get`` replaces the removed
    # ``Array.device_buffer`` / ``to_py()`` APIs for copying to host NumPy.
    generation = jax.device_get(generation.squeeze(1))
    decoded = [utils.decode(g) for g in generation]

    print(f"Generation IDs:\n{generation}")
    print(f"Generations:\n{decoded}")
    return decoded
if __name__ == "__main__":
    # Script entry point: sample from a checkpoint with the base config plus
    # the generation overrides below. The checkpoint path can now be given
    # on the command line instead of being edited into the source.
    import argparse

    parser = argparse.ArgumentParser(
        description="Sample text from a trained char-diffusion checkpoint."
    )
    parser.add_argument(
        "checkpoint_path",
        nargs="?",
        default="",
        help="Path to the saved (net, optim_state, step) checkpoint.",
    )
    args = parser.parse_args()

    config = configs.char_diffusion_base_config()
    config.seed = 9999
    config.model.num_gen_steps = 2_000
    config.model.schedule = "cosine"
    # Defaults to "" (the previous hard-coded value) when no path is given.
    config.checkpoint_path = args.checkpoint_path
    generate(config)