Skip to content

Commit

Permalink
Update accel_sdxl_gen_img.py
Browse files Browse the repository at this point in the history
  • Loading branch information
DKnight54 authored Feb 1, 2025
1 parent efb3722 commit 03232a6
Showing 1 changed file with 37 additions and 35 deletions.
72 changes: 37 additions & 35 deletions accel_sdxl_gen_img.py
Original file line number Diff line number Diff line change
Expand Up @@ -2877,41 +2877,43 @@ def scale_and_round(x):
logger.info(f"batch_data line 2878: {len(batch_data)}")
batch_separated_list = []
logger.info(f"Device {distributed_state.device}, distributed_state.is_main_process 2878: {distributed_state.is_main_process}")
if len(batch_data) > 0:
if distributed_state.is_main_process:
unique_extinfo = list(set(extinfo))
logger.info(f"Device {distributed_state.device}, batch_data line 2880: {len(batch_data)}, len(unique_extinfo): {len(unique_extinfo)}")
# splits list of prompts into sublists where BatchDataExt ext is identical
for i in range(len(unique_extinfo)):
logger.info(f"Device {distributed_state.device}, batch_data line 2880: {len(batch_data)}, len(unique_extinfo): {i}")
templist = []
res = [j for j, val in enumerate(batch_data) if val.ext == unique_extinfo[i]]
for index in res:
templist.append(batch_data[index])
split_into_batches = get_batches(items=templist, batch_size=args.batch_size)
if(len(split_into_batches) % distributed_state.num_processes != 0):
#Distributes last round of batches across available processes if last round of batches less than available processes and there is more than one prompt in the last batch
sublist = []
for j in range(len(split_into_batches) % distributed_state.num_processes):
if len(split_into_batches) > 1 :
sublist.extend(split_into_batches.pop(-1))
elif len(split_into_batches) == 1 :
sublist.extend(split_into_batches.pop(-1))
listofbatches = []
n, m = divmod(len(sublist), distributed_state.num_processes)
split_into_batches.extend([sublist[i*n+min(i,m):(i+1)*n+min(i+1,m)] for i in range(distributed_state.num_processes)])
batch_separated_list.append(split_into_batches)
logger.info(f"batch_separated_list line 2901: {len(batch_separated_list)}, {distributed_state.num_processes}")
if distributed_state.num_processes > 1:
logger.info(f"batch_separated_list: {len(batch_separated_list)}")
templist = []
for i in range(distributed_state.num_processes):
templist.append(batch_separated_list[i :: distributed_state.num_processes])
logger.info(f"templist: {len(templist)}")
batch_separated_list = []
for sub_batch_list in templist:
batch_separated_list.extend(sub_batch_list)
logger.info(f"batch_separated_list: {len(batch_separated_list)}")
if len(batch_data) > 0 and distributed_state.is_main_process:
unique_extinfo = list(set(extinfo))
logger.info(f"Device {distributed_state.device}, batch_data line 2880: {len(batch_data)}, len(unique_extinfo): {len(unique_extinfo)}")
# splits list of prompts into sublists where BatchDataExt ext is identical
for i in range(len(unique_extinfo)):
logger.info(f"Device {distributed_state.device}, batch_data line 2880: {len(batch_data)}, len(unique_extinfo): {i}")
templist = []
res = [j for j, val in enumerate(batch_data) if val.ext == unique_extinfo[i]]
for index in res:
templist.append(batch_data[index])
split_into_batches = get_batches(items=templist, batch_size=args.batch_size)
if(len(split_into_batches) % distributed_state.num_processes != 0):
#Distributes last round of batches across available processes if last round of batches less than available processes and there is more than one prompt in the last batch
sublist = []
for j in range(len(split_into_batches) % distributed_state.num_processes):
if len(split_into_batches) > 1 :
sublist.extend(split_into_batches.pop(-1))
elif len(split_into_batches) == 1 :
sublist.extend(split_into_batches.pop(-1))
split_into_batches = []
n, m = divmod(len(sublist), distributed_state.num_processes)
split_into_batches.extend([sublist[i*n+min(i,m):(i+1)*n+min(i+1,m)] for i in range(distributed_state.num_processes)])
batch_separated_list.extend(split_into_batches)
logger.info(f"batch_separated_list line 2901: {len(batch_separated_list)}, {distributed_state.num_processes}")
if distributed_state.num_processes > 1:
logger.info(f"batch_separated_list: {len(batch_separated_list)}")

temp_list = []
for ext_batch in batch_separated_list:
for i in range(distributed_state.num_processes):
templist.append(ext_batch[i :: distributed_state.num_processes])
logger.info(f"templist: {len(temp_list)}")
batch_separated_list = []
for sub_batch_list in temp_list:
batch_separated_list.append(sub_batch_list)
logger.info(f"batch_separated_list: {len(batch_separated_list)}")
logger.info(f"sub_batch_list: {len(sub_batch_list)}")
distributed_state.wait_for_everyone()
batch_data = gather_object(batch_separated_list)
logger.info(f"batch_data line 2911: {len(batch_data)}")
Expand Down

0 comments on commit 03232a6

Please sign in to comment.