Skip to content

Commit

Permalink
Update ablation_data_loader.py
Browse files Browse the repository at this point in the history
  • Loading branch information
jayrn2 authored Oct 5, 2023
1 parent d3feeb1 commit fa006e9
Showing 1 changed file with 18 additions and 6 deletions.
24 changes: 18 additions & 6 deletions utils/data/ablation_data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@

from .ablation_generate_collages import generate_collages

def get_key(val,my_dict):
def get_key(val, my_dict):
for key, value in my_dict.items():
if val == value:
return key
print(f"Warning: {val} not found in dictionary")
return -1

class ablation_data_loader(torch.utils.data.Dataset):
Expand All @@ -29,17 +30,28 @@ def __init__(self, split='train', random_gen=None, num_candidates=5, transform_r
else:
self.random_gen = random_gen

valid_list = [['banded','blotchy','braided','bubbly','bumpy'],
# Original valid list used for the DTD dataset
''' valid_list = [['banded','blotchy','braided','bubbly','bumpy'],
['chequered','cobwebbed','cracked','crosshatched','crystalline'],
['dotted','fibrous','flecked','freckled','frilly'],
['gauzy','grid','grooved','honeycombed','interlaced'],
['waffled', 'potholed', 'pleated', 'meshed', 'spiralled']]

#dir = '/dataset/dtd/images'
'''

# Changes made to validation list for: UC Merced dataset
valid_list = [['agricultural', 'airplane', 'baseballdiamond', 'beach', 'buildings'],
['chaparral', 'denseresidential', 'forest', 'freeway', 'golfcourse'],
['harbor', 'intersection', 'mediumresidential', 'mobilehomepark', 'overpass'],
['parkinglot', 'river', 'runway', 'sparseresidential', 'storagetanks'],
['tenniscourt']]

# Windows file directory syntax
dir = 'C:\\MOST_training\\MOSTS\\dataset\\dtd\\images'
#dir = '/dataset/dtd/images'

# Path for DTD dataset images
# dir = 'C:\\Users\\AUVSL\\Documents\\Jay\\MOSTS\\dataset\\dtd\\images'

# Path for UC Merced outdoor landuse images:
dir = 'C:\\Users\\AUVSL\\Documents\\Jay\\MOSTS\\dataset\\UCMerced_LandUse\\Images'
idx_to_class, image_path_all = self.load_path(dir)
total_num_class = len(idx_to_class)

Expand Down

0 comments on commit fa006e9

Please sign in to comment.