Commit 53e9b3a
move data sampling before data processing occurs; this is faster and avoids unnecessary waiting
borauyar committed Apr 6, 2024
1 parent f2db9d6 commit 53e9b3a
Showing 2 changed files with 16 additions and 17 deletions.
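The reasoning behind the move, in miniature: the cleanup steps that follow import (NA handling, variance and correlation filtering, scaling) all scale with the number of samples, so drawing the subsample right after the raw tables are read means the rest of the pipeline never touches the discarded samples. A toy sketch of the ordering, not flexynesis code; the table sizes and the variance threshold are made up:

```python
import numpy as np
import pandas as pd

# Hypothetical features-x-samples table: 5000 features, 1000 samples.
full = pd.DataFrame(np.random.rand(5000, 1000))

# New order: subsample the sample columns first, then filter; the
# variance scan now runs over 100 samples instead of 1000.
subset = full.sample(100, axis=1, random_state=0)
kept = subset[subset.var(axis=1) > 1e-5]
```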
flexynesis/__main__.py (7 changes: 2 additions & 5 deletions)

```diff
@@ -187,13 +187,10 @@ class AvailableModels(NamedTuple):
         graph=graph,
         processed_dir = '_'.join(['processed', args.prefix]),
         string_organism=args.string_organism,
-        string_node_name=args.string_node_name)
+        string_node_name=args.string_node_name,
+        downsample = args.subsample)
     train_dataset, test_dataset = data_importer.import_data(force = True)
 
-    if args.subsample > 0:
-        print("[INFO] Randomly drawing",args.subsample,"samples for training")
-        train_dataset = flexynesis.downsample(train_dataset, N = args.subsample)
-
     # print feature logs to file (we use these tables to track which features are dropped/selected and why)
     feature_logs = data_importer.feature_logs
     for key in feature_logs.keys():
```
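At the CLI level the change amounts to forwarding the existing `--subsample` value into the importer instead of post-processing the imported dataset. A minimal sketch of the new call pattern, assuming `DataImporter` is exposed at the package level; the path and data types below are hypothetical placeholders:

```python
import flexynesis

importer = flexynesis.DataImporter(
    path="my_cohort",            # hypothetical folder containing train/test tables
    data_types=["gex", "cnv"],   # hypothetical omics layers
    downsample=500,              # was: flexynesis.downsample(train_dataset, N=500) after import
)
# Subsampling now happens inside import_data, right after the raw tables are read.
train_dataset, test_dataset = importer.import_data(force=True)
```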
flexynesis/data.py (26 changes: 14 additions & 12 deletions)

```diff
@@ -98,7 +98,7 @@ class DataImporter:
 
     def __init__(self, path, data_types, processed_dir="processed", log_transform = False, concatenate = False, restrict_to_features = None, min_features=None,
                  top_percentile=20, correlation_threshold = 0.9, variance_threshold=1e-5, na_threshold=0.1,
-                 graph=None, string_organism=9606, string_node_name="gene_name", transform=None):
+                 graph=None, string_organism=9606, string_node_name="gene_name", transform=None, downsample=0):
         self.path = path
         self.data_types = data_types
         self.processed_dir = os.path.join(self.path, processed_dir)
```
```diff
@@ -115,6 +115,7 @@ def __init__(self, path, data_types, processed_dir="processed", log_transform =
         self.scalers = None
         # initialize data transformers
         self.transformers = None
+        self.downsample = downsample
 
         self.graph = graph
         if self.graph is not None:
```
```diff
@@ -179,6 +180,10 @@ def import_data(self, force=False):
         train_dat = self.read_data(training_path)
         test_dat = self.read_data(testing_path)
 
+        if self.downsample > 0:
+            print("[INFO] Randomly drawing",self.downsample,"samples for training")
+            train_dat = self.subsample(train_dat, self.downsample)
+
         if self.restrict_to_features is not None:
             train_dat = self.filter_by_features(train_dat, self.restrict_to_features)
             test_dat = self.filter_by_features(test_dat, self.restrict_to_features)
```
```diff
@@ -287,6 +292,14 @@ def read_data(self, folder_path):
             print(f"[INFO] Importing {file_path}...")
             data[file_name] = pd.read_csv(file_path, index_col=0)
         return data
 
+    # randomly draw N samples; return subset of dat (output of read_data)
+    def subsample(self, dat, N):
+        clin = dat['clin'].sample(N)
+        dat_sub = {x: dat[x][clin.index] for x in self.data_types}
+        dat_sub['clin'] = clin
+        return dat_sub
+
+
     def filter_by_features(self, dat, features):
         """
```
```diff
@@ -954,14 +967,3 @@ def split_by_median(tensor_dict):
             # If tensor is not numerical, leave it as it is
             new_dict[key] = tensor
         return new_dict
-
-
-# downsample a given MultiOmicDataset
-def downsample(dataset, N = None):
-    idx = np.random.choice(range(0,len(dataset)), N, replace = False)
-    sub = dataset[idx]
-
-    sub_dataset = MultiomicDataset(dat = sub[0], ann = sub[1], variable_types = dataset.variable_types,
-                                   features = dataset.features, samples = [dataset.samples[i] for i in idx],
-                                   label_mappings=dataset.label_mappings, feature_ann=dataset.feature_ann)
-    return sub_dataset
-
```