Skip to content

Commit

Permalink
Update train_ablation.py
Browse files Browse the repository at this point in the history
  • Loading branch information
jayrn2 authored Oct 5, 2023
1 parent f014b26 commit d3feeb1
Showing 1 changed file with 27 additions and 21 deletions.
48 changes: 27 additions & 21 deletions train_ablation.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def worker_init_fn(worker_id):

# here imma have to lower epoch to get my segmented images, it was set to 80 before

def main(seed=2022, epoches=5): #500
def main(seed=2022, epoches=50): #500
parser = argparse.ArgumentParser(description='ablation')
# dataset option
parser.add_argument('--model_name', type=str, default='mosts', choices=['mosts'], help='model name')
Expand Down Expand Up @@ -57,6 +57,7 @@ def main(seed=2022, epoches=5): #500
evaluator = Evaluator(num_class=data_val.split_point+1) # ignore background class

dir_name = 'log/' + str(args.data_loader) + '_' + str(args.model_name) + '_valid_group_' + str(args.valid_group)
# ablation_data_loader_mosts_valid_group_3

if not os.path.exists(dir_name):
#os.mkdir(dir_name)
Expand All @@ -76,6 +77,12 @@ def main(seed=2022, epoches=5): #500
# Complie model
model = models[args.model_name]()

# Load pretrained models for training with another dataset
# - path for pretrained set (DTD training set)
pre_trained_model_path = 'C:\\Users\\AUVSL\\Documents\\Jay\\MOSTS\\log\\ablation_data_loader_mosts_valid_group_3\\epoch_2023_09_29_10_42_09_texture.pth' #the one we wanna use for UC merced set next
model.load_state_dict(torch.load(pre_trained_model_path))


# CUDA init
if torch.cuda.device_count() > 1:
model = nn.DataParallel(model)
Expand Down Expand Up @@ -132,6 +139,7 @@ def main(seed=2022, epoches=5): #500
logging.info('iter:' + str(iteration) + " time:" + str(run_time) + " train loss = {:02.5f}".format(losses))
losses = 0
model_path = dir_name + '/epoch_{epoches}_texture.pth'.format(epoches=now_time)

print("Training progress: ",data_train.curriculum*100,"%")

# Model evaluation after one epoch
Expand Down Expand Up @@ -168,29 +176,27 @@ def main(seed=2022, epoches=5): #500

# looping through data
for idx in range(query.shape[0]):
# 1. first make directory to store images
# 2. convert the tensor into a numpy array then transpose it into matplot color format (CxHxW to HxWxC)
# 3. save the images in their respectve categories

# query image
os.makedirs('results/query', exist_ok=True)
query_img = query[idx].cpu().numpy().transpose(1, 2, 0)
plt.imsave(f'results/query/query_img_{epoch}_{idx}.png', query_img)

# reference image
os.makedirs('results/reference',exist_ok=True)
ref_img = reference[idx].cpu().numpy().transpose(1, 2, 0)
plt.imsave(f'results/reference/ref_img_{epoch}_{idx}.png', ref_img)

# ground truth label
os.makedirs('results/truth_label',exist_ok=True)
label_img = label[idx].cpu().numpy()
plt.imsave(f'results/truth_label/label_img_{epoch}_{idx}.png', label_img, cmap='viridis')
# query, reference, & ground truth labels do not change so we can just use first iteration
if epoch == 1:
# query image
os.makedirs('results/query', exist_ok=True)
query_img = query[idx].cpu().numpy().transpose(1, 2, 0)
plt.imsave(f'results/query/query_img_{epoch}_{idx}.png', query_img)

# reference image
os.makedirs('results/reference',exist_ok=True)
ref_img = reference[idx].cpu().numpy().transpose(1, 2, 0)
plt.imsave(f'results/reference/ref_img_{epoch}_{idx}.png', ref_img)

# ground truth label
os.makedirs('results/truth_label',exist_ok=True)
label_img = label[idx].cpu().numpy()
plt.imsave(f'results/truth_label/label_img_{epoch}_{idx}.png', label_img, cmap='plasma')

# predicted segmentation (already np array)
os.makedirs('results/predicted_segment',exist_ok=True)
pred_img = pred[idx]
plt.imsave(f'results/predicted_segment/pred_img_{epoch}_{idx}.png', pred_img, cmap='inferno')
plt.imsave(f'results/predicted_segment/pred_img_{epoch}_{idx}.png', pred_img, cmap='viridis')

label = label.cpu().numpy()
evaluator.add_batch(label, pred, image_class)
Expand All @@ -216,4 +222,4 @@ def main(seed=2022, epoches=5): #500
logging.info(IoU_final)

if __name__ == '__main__':
main()
main()

0 comments on commit d3feeb1

Please sign in to comment.