
Commit

fixed some errors
adeerBB committed Apr 8, 2024
1 parent 24ff52c commit 63f4833
Showing 12 changed files with 660 additions and 5 deletions.
3 changes: 3 additions & 0 deletions .vscode/settings.json
@@ -0,0 +1,3 @@
{
"git.ignoreLimitWarning": true
}
Empty file.
1 change: 1 addition & 0 deletions guided-diffusion
Submodule guided-diffusion added at 22e0df
4 changes: 2 additions & 2 deletions house_diffusion/rplanhg_datasets.py
@@ -81,7 +81,7 @@ def make_non_manhattan(poly, polygon, house_poly):
return poly2

get_bin = lambda x, z: [int(y) for y in format(x, 'b').zfill(z)]
-get_one_hot = lambda x, z: np.eye(z)[x]
+get_one_hot = lambda x, z: np.eye(z)[min(x, z-1)]
class RPlanhgDataset(Dataset):
def __init__(self, set_name, analog_bit, target_set, non_manhattan=False):
super().__init__()
@@ -534,4 +534,4 @@ def reader(filename):
return rms_type,fp_eds,rms_bbs,eds_to_rms

if __name__ == '__main__':
-    dataset = RPlanhgDataset('eval', False, 8)
+    dataset = RPlanhgDataset('eval', False, 8)
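A quick illustration (not part of the commit) of what the clamped one-hot lambda changes: an out-of-range class index used to raise an IndexError, and is now mapped onto the last class.

import numpy as np

get_one_hot = lambda x, z: np.eye(z)[min(x, z - 1)]

print(get_one_hot(2, 4))  # [0. 0. 1. 0.]
print(get_one_hot(9, 4))  # [0. 0. 0. 1.] -- index clamped to z - 1 = 3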
378 changes: 378 additions & 0 deletions image_sample.py

Large diffs are not rendered by default.

90 changes: 90 additions & 0 deletions image_train.py
@@ -0,0 +1,90 @@
"""
Train a diffusion model on images.
"""

import argparse

from house_diffusion import dist_util, logger
from house_diffusion.rplanhg_datasets import load_rplanhg_data
from house_diffusion.resample import create_named_schedule_sampler
from house_diffusion.script_util import (
model_and_diffusion_defaults,
create_model_and_diffusion,
args_to_dict,
add_dict_to_argparser,
update_arg_parser,
)
from house_diffusion.train_util import TrainLoop


def main():
args = create_argparser().parse_args()
update_arg_parser(args)

dist_util.setup_dist()
logger.configure()

logger.log("creating model and diffusion...")
model, diffusion = create_model_and_diffusion(
**args_to_dict(args, model_and_diffusion_defaults().keys())
)
model.to(dist_util.dev())
schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion)

logger.log("creating data loader...")
    if args.dataset == 'rplan':
        data = load_rplanhg_data(
            batch_size=args.batch_size,
            analog_bit=args.analog_bit,
            target_set=args.target_set,
            set_name=args.set_name,
        )
    else:
        print('dataset does not exist!')
        assert False

logger.log("training...")
TrainLoop(
model=model,
diffusion=diffusion,
data=data,
batch_size=args.batch_size,
microbatch=args.microbatch,
lr=args.lr,
ema_rate=args.ema_rate,
log_interval=args.log_interval,
save_interval=args.save_interval,
resume_checkpoint=args.resume_checkpoint,
use_fp16=args.use_fp16,
fp16_scale_growth=args.fp16_scale_growth,
schedule_sampler=schedule_sampler,
weight_decay=args.weight_decay,
lr_anneal_steps=args.lr_anneal_steps,
analog_bit=args.analog_bit,
).run_loop()


def create_argparser():
defaults = dict(
        dataset='',
        schedule_sampler="uniform",  # or "loss-second-moment"
lr=1e-4,
weight_decay=0.0,
lr_anneal_steps=0,
batch_size=1,
microbatch=-1, # -1 disables microbatches
ema_rate="0.9999", # comma-separated list of EMA values
log_interval=10,
save_interval=10000,
resume_checkpoint="",
use_fp16=False,
fp16_scale_growth=1e-3,
)
parser = argparse.ArgumentParser()
defaults.update(model_and_diffusion_defaults())
add_dict_to_argparser(parser, defaults)
return parser


if __name__ == "__main__":
main()
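For context, add_dict_to_argparser is assumed to mirror the upstream guided-diffusion helper of the same name: each key in the defaults dict becomes a --key flag whose type is inferred from its default value. A minimal sketch under that assumption:

import argparse

def str2bool(v):
    # Parse "true"/"false"-style strings from the command line.
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if v.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("boolean value expected")

def add_dict_to_argparser(parser, default_dict):
    # One flag per defaults key; bools go through the string parser above.
    for k, v in default_dict.items():
        v_type = type(v)
        if isinstance(v, bool):
            v_type = str2bool
        parser.add_argument(f"--{k}", default=v, type=v_type)

Under that reading, python image_train.py --dataset rplan --batch_size 512 overrides the matching entries in the defaults dict.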
83 changes: 83 additions & 0 deletions json_fixer.py
@@ -0,0 +1,83 @@
import json
import numpy as np
from glob import glob
import os

def reader(filename):
with open(filename) as f:
info = json.load(f)
rms_bbs = np.asarray(info['boxes'])
fp_eds = info['edges']
rms_type = info['room_type']
eds_to_rms = info['ed_rm']
s_r = 0
for rmk in range(len(rms_type)):
if rms_type[rmk] != 17:
s_r = s_r + 1
rms_bbs = np.array(rms_bbs) / 256.0
fp_eds = np.array(fp_eds) / 256.0
fp_eds = fp_eds[:, :4]
tl = np.min(rms_bbs[:, :2], 0)
br = np.max(rms_bbs[:, 2:], 0)
shift = (tl + br) / 2.0 - 0.5
rms_bbs[:, :2] -= shift
rms_bbs[:, 2:] -= shift
fp_eds[:, :2] -= shift
fp_eds[:, 2:] -= shift
tl -= shift
br -= shift
eds_to_rms_tmp = []

for l in range(len(eds_to_rms)):
eds_to_rms_tmp.append([eds_to_rms[l][0]])

return rms_type, fp_eds, rms_bbs, eds_to_rms, eds_to_rms_tmp

file_list = glob('rplan/*')

processed_files = 0

out_size = 64
length_edges = []
subgraphs = []
for line in file_list:
rms_type, fp_eds, rms_bbs, eds_to_rms, eds_to_rms_tmp = reader(line)

eds_to_rms_tmp = []
for l in range(len(eds_to_rms)):
eds_to_rms_tmp.append([eds_to_rms[l][0]])

rms_masks = []
im_size = 256
fp_mk = np.zeros((out_size, out_size))
nodes = rms_type
for k in range(len(nodes)):
eds = []
for l, e_map in enumerate(eds_to_rms_tmp):
if k in e_map:
eds.append(l)
for eds_poly in [eds]:
length_edges.append((line, np.array([fp_eds[l][:4] for l in eds_poly], dtype=object)))

processed_files += 1
if processed_files % 1000 == 0:
print(f"Processed {processed_files} files.")

print(f"Finished processing {processed_files} files.")

# Convert length_edges to a structured array to handle variable-length sequences
dtype = [('filename', 'U256'), ('edges', 'O')]
length_edges_structured = np.array(length_edges, dtype=dtype)

chk = [x['edges'].shape for x in length_edges_structured]
idx = [i for i, x in enumerate(chk) if len(x) != 2]
final = length_edges_structured[idx]['filename'].tolist()
final = [x.replace('\n', '') for x in final]

for fin in final:
    try:
        os.remove(fin)
    except OSError:
        print(f"Failed to delete {fin}")

print("Verification: Basic check complete.")
14 changes: 14 additions & 0 deletions list_maker.py
@@ -0,0 +1,14 @@
import os

def get_image_number(name):
return int(name.split(".")[0])

def write_filenames_to_txt(directory, txt_file):
json_names = os.listdir(directory)
json_names = sorted(json_names, key=get_image_number)
with open(txt_file, 'w') as f:
for filename in json_names:
f.write(filename + '\n')


write_filenames_to_txt("rplan", "list.txt")
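For reference, a quick demonstration of why the numeric key is needed, using get_image_number from above; a plain sort would order the names lexicographically:

names = ["10.json", "2.json", "1.json"]
print(sorted(names))                        # ['1.json', '10.json', '2.json']
print(sorted(names, key=get_image_number))  # ['1.json', '2.json', '10.json']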
4 changes: 2 additions & 2 deletions requirements.txt
@@ -5,13 +5,13 @@ imageio==2.19.2
matplotlib==3.5.1
mpi4py==3.1.4
networkx==2.8.2
-numpy==1.21.5
+numpy==1.21.2
opencv_python==4.6.0.66
Pillow==9.4.0
pytorch_fid==0.3.0
setuptools==57.5.0
Shapely==1.8.4
tensorflow==2.11.0
-torch==2.0.0.dev20221212
+torch==2.0.0
tqdm==4.61.2
webcolors==1.12
2 changes: 1 addition & 1 deletion scripts/image_train.py
Expand Up @@ -87,4 +87,4 @@ def create_argparser():


if __name__ == "__main__":
-    main()
+    main()
86 changes: 86 additions & 0 deletions scripts/json_fixer.py
@@ -0,0 +1,86 @@
import json
import numpy as np
from glob import glob
import os

def reader(filename):
with open(filename) as f:
info = json.load(f)
rms_bbs = np.asarray(info['boxes'])
fp_eds = info['edges']
rms_type = info['room_type']
eds_to_rms = info['ed_rm']
s_r = 0
for rmk in range(len(rms_type)):
if rms_type[rmk] != 17:
s_r = s_r + 1
rms_bbs = np.array(rms_bbs) / 256.0
fp_eds = np.array(fp_eds) / 256.0
fp_eds = fp_eds[:, :4]
tl = np.min(rms_bbs[:, :2], 0)
br = np.max(rms_bbs[:, 2:], 0)
shift = (tl + br) / 2.0 - 0.5
rms_bbs[:, :2] -= shift
rms_bbs[:, 2:] -= shift
fp_eds[:, :2] -= shift
fp_eds[:, 2:] -= shift
tl -= shift
br -= shift
eds_to_rms_tmp = []

for l in range(len(eds_to_rms)):
eds_to_rms_tmp.append([eds_to_rms[l][0]])

return rms_type, fp_eds, rms_bbs, eds_to_rms, eds_to_rms_tmp

# Adjust the glob path to match where your JSON files are located
file_list = glob('rplan/*')

# Initialize a counter for progress monitoring
processed_files = 0

out_size = 64
length_edges = []
subgraphs = []
for line in file_list:
rms_type, fp_eds, rms_bbs, eds_to_rms, eds_to_rms_tmp = reader(line)

eds_to_rms_tmp = []
for l in range(len(eds_to_rms)):
eds_to_rms_tmp.append([eds_to_rms[l][0]])

rms_masks = []
im_size = 256
fp_mk = np.zeros((out_size, out_size))
nodes = rms_type
for k in range(len(nodes)):
eds = []
for l, e_map in enumerate(eds_to_rms_tmp):
if k in e_map:
eds.append(l)
for eds_poly in [eds]:
length_edges.append((line, np.array([fp_eds[l][:4] for l in eds_poly])))

# Progress monitoring
processed_files += 1
if processed_files % 1000 == 0:
print(f"Processed {processed_files} files.")

# After processing all files
print(f"Finished processing {processed_files} files.")

chk = [x.shape for x in np.array(length_edges, dtype=object)[:, 1]]
idx = [i for i, x in enumerate(chk) if len(x) != 2]
final = np.array(length_edges, dtype=object)[idx][:, 0].tolist()
final = [x.replace('\n', '') for x in final]

# Attempt to delete files based on final list
for fin in final:
    try:
        os.remove(fin)
    except OSError:
        print(f"Failed to delete {fin}")

# Simple verification example
# Add your verification logic here as needed
print("Verification: Basic check complete.")
Empty file.
