-
Notifications
You must be signed in to change notification settings - Fork 37
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
12 changed files
with
660 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
{ | ||
"git.ignoreLimitWarning": true | ||
} |
Empty file.
Submodule guided-diffusion
added at
22e0df
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
""" | ||
Train a diffusion model on images. | ||
""" | ||
|
||
import argparse | ||
|
||
from house_diffusion import dist_util, logger | ||
from house_diffusion.rplanhg_datasets import load_rplanhg_data | ||
from house_diffusion.resample import create_named_schedule_sampler | ||
from house_diffusion.script_util import ( | ||
model_and_diffusion_defaults, | ||
create_model_and_diffusion, | ||
args_to_dict, | ||
add_dict_to_argparser, | ||
update_arg_parser, | ||
) | ||
from house_diffusion.train_util import TrainLoop | ||
|
||
|
||
def main():
    """Entry point: set up distributed training, build the model/diffusion
    pair from command-line args, load the dataset, and run the train loop."""
    args = create_argparser().parse_args()
    update_arg_parser(args)

    # Distributed setup and logging must happen before model construction.
    dist_util.setup_dist()
    logger.configure()

    logger.log("creating model and diffusion...")
    model, diffusion = create_model_and_diffusion(
        **args_to_dict(args, model_and_diffusion_defaults().keys())
    )
    model.to(dist_util.dev())
    schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion)

    logger.log("creating data loader...")
    if args.dataset == 'rplan':
        data = load_rplanhg_data(
            batch_size=args.batch_size,
            analog_bit=args.analog_bit,
            target_set=args.target_set,
            set_name=args.set_name,
        )
    else:
        # Fail loudly with a real exception instead of `print` + `assert False`:
        # asserts are stripped when Python runs with -O, which would let
        # training continue with `data` undefined.
        raise ValueError(f"Unsupported dataset: {args.dataset!r} (expected 'rplan')")

    logger.log("training...")
    TrainLoop(
        model=model,
        diffusion=diffusion,
        data=data,
        batch_size=args.batch_size,
        microbatch=args.microbatch,
        lr=args.lr,
        ema_rate=args.ema_rate,
        log_interval=args.log_interval,
        save_interval=args.save_interval,
        resume_checkpoint=args.resume_checkpoint,
        use_fp16=args.use_fp16,
        fp16_scale_growth=args.fp16_scale_growth,
        schedule_sampler=schedule_sampler,
        weight_decay=args.weight_decay,
        lr_anneal_steps=args.lr_anneal_steps,
        analog_bit=args.analog_bit,
    ).run_loop()
|
||
|
||
def create_argparser():
    """Build the command-line parser for training.

    Training-loop defaults are overlaid with the model/diffusion defaults,
    then every key is registered as an argparse flag.
    """
    defaults = {
        "dataset": "",
        "schedule_sampler": "uniform",  # alternative: "loss-second-moment"
        "lr": 1e-4,
        "weight_decay": 0.0,
        "lr_anneal_steps": 0,
        "batch_size": 1,
        "microbatch": -1,  # -1 disables microbatches
        "ema_rate": "0.9999",  # comma-separated list of EMA values
        "log_interval": 10,
        "save_interval": 10000,
        "resume_checkpoint": "",
        "use_fp16": False,
        "fp16_scale_growth": 1e-3,
    }
    defaults.update(model_and_diffusion_defaults())
    parser = argparse.ArgumentParser()
    add_dict_to_argparser(parser, defaults)
    return parser
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
import json | ||
import numpy as np | ||
from glob import glob | ||
import os | ||
|
||
def reader(filename):
    """Load one RPLAN floorplan JSON and normalise its geometry.

    Coordinates are scaled from the 256x256 pixel grid into [0, 1] and then
    translated so the midpoint of the overall room bounding box sits at
    (0.5, 0.5).

    Args:
        filename: path to an RPLAN JSON file with keys 'boxes', 'edges',
            'room_type' and 'ed_rm'.

    Returns:
        Tuple of (rms_type, fp_eds, rms_bbs, eds_to_rms, eds_to_rms_tmp):
        room-type ids, (E, 4) edge-endpoint array, (R, 4) room bounding
        boxes, the raw edge->room mapping, and that mapping truncated to
        each edge's first room only.
    """
    with open(filename) as f:
        info = json.load(f)
    rms_type = info['room_type']
    eds_to_rms = info['ed_rm']
    rms_bbs = np.asarray(info['boxes']) / 256.0
    # Only the first four columns (x0, y0, x1, y1) of each edge are used.
    fp_eds = np.asarray(info['edges'])[:, :4] / 256.0
    # Re-centre everything so the floorplan bounding box is centred at 0.5.
    tl = np.min(rms_bbs[:, :2], 0)
    br = np.max(rms_bbs[:, 2:], 0)
    shift = (tl + br) / 2.0 - 0.5
    rms_bbs[:, :2] -= shift
    rms_bbs[:, 2:] -= shift
    fp_eds[:, :2] -= shift
    fp_eds[:, 2:] -= shift
    # Keep only the first room adjacent to each edge.  (The original also
    # counted non-17 room types into an `s_r` variable and shifted tl/br,
    # but none of those results were used — dead code removed.)
    eds_to_rms_tmp = [[rooms[0]] for rooms in eds_to_rms]
    return rms_type, fp_eds, rms_bbs, eds_to_rms, eds_to_rms_tmp
|
||
# Flat cleanup script: scan every RPLAN JSON, find files where some room has
# no adjacent edges (malformed floorplans), and delete them.
file_list = glob('rplan/*')

processed_files = 0
length_edges = []
for line in file_list:
    # `reader` already returns eds_to_rms_tmp, so the original's local
    # recomputation of it was redundant and has been dropped.
    rms_type, fp_eds, rms_bbs, eds_to_rms, eds_to_rms_tmp = reader(line)

    # For each room k, collect the indices of edges adjacent to it and
    # record (filename, edge-coordinate array).  dtype=object so that a
    # zero-edge room yields a 1-D (0,) array we can detect below.
    for k in range(len(rms_type)):
        eds = [l for l, e_map in enumerate(eds_to_rms_tmp) if k in e_map]
        length_edges.append(
            (line, np.array([fp_eds[l][:4] for l in eds], dtype=object))
        )

    # Progress monitoring.
    processed_files += 1
    if processed_files % 1000 == 0:
        print(f"Processed {processed_files} files.")

print(f"Finished processing {processed_files} files.")

# Structured array keeps the variable-length edge arrays as objects.
dtype = [('filename', 'U256'), ('edges', 'O')]
length_edges_structured = np.array(length_edges, dtype=dtype)

# A well-formed room produces a 2-D (n_edges, 4) array; any other shape
# flags the source file as malformed.
chk = [entry['edges'].shape for entry in length_edges_structured]
idx = [i for i, shape in enumerate(chk) if len(shape) != 2]
final = [name.replace('\n', '') for name in
         length_edges_structured[idx]['filename'].tolist()]

# Best-effort deletion of the malformed files.
for fin in final:
    try:
        os.remove(fin)
    except OSError:  # narrowed from a bare `except:`; report and keep going
        print(f"Failed to delete {fin}")

print("Verification: Basic check complete.")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
import os | ||
|
||
def get_image_number(name):
    """Numeric sort key: the integer preceding the first '.' in *name*."""
    stem, _, _ = name.partition(".")
    return int(stem)
|
||
def write_filenames_to_txt(directory, txt_file):
    """Write every entry of *directory* to *txt_file*, one name per line,
    sorted by the integer before the first '.' in each name."""
    ordered = sorted(os.listdir(directory), key=lambda name: int(name.split(".")[0]))
    with open(txt_file, 'w') as out:
        out.writelines(name + '\n' for name in ordered)
|
||
|
||
write_filenames_to_txt("rplan", "list.txt") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -87,4 +87,4 @@ def create_argparser(): | |
|
||
|
||
if __name__ == "__main__": | ||
main() | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
import json | ||
import numpy as np | ||
from glob import glob | ||
import os | ||
|
||
def reader(filename):
    """Load one RPLAN floorplan JSON and normalise its geometry.

    Coordinates are scaled from the 256x256 pixel grid into [0, 1] and then
    translated so the midpoint of the overall room bounding box sits at
    (0.5, 0.5).

    Args:
        filename: path to an RPLAN JSON file with keys 'boxes', 'edges',
            'room_type' and 'ed_rm'.

    Returns:
        Tuple of (rms_type, fp_eds, rms_bbs, eds_to_rms, eds_to_rms_tmp):
        room-type ids, (E, 4) edge-endpoint array, (R, 4) room bounding
        boxes, the raw edge->room mapping, and that mapping truncated to
        each edge's first room only.
    """
    with open(filename) as f:
        info = json.load(f)
    rms_type = info['room_type']
    eds_to_rms = info['ed_rm']
    rms_bbs = np.asarray(info['boxes']) / 256.0
    # Only the first four columns (x0, y0, x1, y1) of each edge are used.
    fp_eds = np.asarray(info['edges'])[:, :4] / 256.0
    # Re-centre everything so the floorplan bounding box is centred at 0.5.
    tl = np.min(rms_bbs[:, :2], 0)
    br = np.max(rms_bbs[:, 2:], 0)
    shift = (tl + br) / 2.0 - 0.5
    rms_bbs[:, :2] -= shift
    rms_bbs[:, 2:] -= shift
    fp_eds[:, :2] -= shift
    fp_eds[:, 2:] -= shift
    # Keep only the first room adjacent to each edge.  (The original also
    # counted non-17 room types into an `s_r` variable and shifted tl/br,
    # but none of those results were used — dead code removed.)
    eds_to_rms_tmp = [[rooms[0]] for rooms in eds_to_rms]
    return rms_type, fp_eds, rms_bbs, eds_to_rms, eds_to_rms_tmp
|
||
# Adjust the glob path to match where your JSON files are located | ||
file_list = glob('rplan/*') | ||
|
||
# Initialize a counter for progress monitoring | ||
processed_files = 0 | ||
|
||
out_size = 64 | ||
length_edges = [] | ||
subgraphs = [] | ||
for line in file_list: | ||
rms_type, fp_eds, rms_bbs, eds_to_rms, eds_to_rms_tmp = reader(line) | ||
|
||
eds_to_rms_tmp = [] | ||
for l in range(len(eds_to_rms)): | ||
eds_to_rms_tmp.append([eds_to_rms[l][0]]) | ||
|
||
rms_masks = [] | ||
im_size = 256 | ||
fp_mk = np.zeros((out_size, out_size)) | ||
nodes = rms_type | ||
for k in range(len(nodes)): | ||
eds = [] | ||
for l, e_map in enumerate(eds_to_rms_tmp): | ||
if k in e_map: | ||
eds.append(l) | ||
for eds_poly in [eds]: | ||
length_edges.append((line, np.array([fp_eds[l][:4] for l in eds_poly]))) | ||
|
||
# Progress monitoring | ||
processed_files += 1 | ||
if processed_files % 1000 == 0: | ||
print(f"Processed {processed_files} files.") | ||
|
||
# After processing all files | ||
print(f"Finished processing {processed_files} files.") | ||
|
||
chk = [x.shape for x in np.array(length_edges)[:, 1]] | ||
idx = [i for i, x in enumerate(chk) if len(x) != 2] | ||
final = np.array(length_edges)[idx][:, 0].tolist() | ||
final = [x.replace('\n', '') for x in final] | ||
|
||
# Attempt to delete files based on final list | ||
for fin in final: | ||
try: | ||
os.remove(fin) | ||
except: | ||
print(f"Failed to delete {fin}") | ||
|
||
# Simple verification example | ||
# Add your verification logic here as needed | ||
print("Verification: Basic check complete.") |
Empty file.