Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

5 evms integrate core framework #43

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
19 changes: 19 additions & 0 deletions jlab_datascience_toolkit/analyses/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from jlab_datascience_toolkit.utils.registration import register, make, list_registered_modules

# Registry of analysis modules: maps each public module id to the
# "module.path:ClassName" entry point that implements it. Registering here
# makes the analyses constructible via make() / discoverable via
# list_registered_modules().
_ANALYSIS_REGISTRY = {
    # Residual analyzer:
    "ResidualAnalyzer_v0": "jlab_datascience_toolkit.analyses.residual_analyzer:ResidualAnalyzer",
    # Data reconstruction:
    "DataReconstruction_v0": "jlab_datascience_toolkit.analyses.data_reconstruction:DataReconstruction",
    # Learning Curve Visualizer:
    "LearningCurveVisualizer_v0": "jlab_datascience_toolkit.analyses.learning_curve_visualizer:LearningCurveVisualizer",
}

for _module_id, _entry_point in _ANALYSIS_REGISTRY.items():
    register(id=_module_id, entry_point=_entry_point)
194 changes: 194 additions & 0 deletions jlab_datascience_toolkit/analyses/data_reconstruction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
from jlab_datascience_toolkit.core.jdst_analysis import JDSTAnalysis
import tensorflow as tf
import numpy as np
import os
import inspect
import yaml
import logging

class DataReconstruction(JDSTAnalysis):
    '''
    Simple module that passes input data x through a model:

    x_rec = model(x)

    where model can be a (variational) Autoencoder, U-Net, Diffusion model,...
    The data here is processed via the tf.data.Dataset system, in order to
    efficiently handle large data sets.

    Input(s):
    i) Numpy arrays / images
    ii) A trained model (or a list of models that are applied sequentially)

    Output(s):
    i) Dictionary with reconstructed images and (optional) original images
    '''

    # Initialize:
    #*********************************************
    def __init__(self, path_to_cfg, user_config=None):
        # NOTE: user_config defaults to None (not {}) to avoid the
        # mutable-default-argument pitfall; None means "no overrides".

        # Define the module and module name:
        self.module_name = "data_reconstruction"

        # Load the configuration:
        self.config = self.load_config(path_to_cfg, user_config or {})

        # Save this config, if a path is provided:
        if 'store_cfg_loc' in self.config:
            self.save_config(self.config['store_cfg_loc'])

        # General settings:
        self.output_loc = self.config['output_loc']
        self.data_store_name = self.config['data_store_name']

        # Data processing settings:
        self.buffer_size = self.config['buffer_size']
        self.n_analysis_samples = self.config['n_analysis_samples']
        self.analysis_sample_size = self.config['analysis_sample_size']

        # Get names of the data (data_names[0] -> originals, data_names[1] -> reconstructions):
        self.data_names = self.config['data_names']
        self.record_original_data = self.config['record_original_data']

        # Only persist results if a non-empty output location was configured:
        self.store_data = bool(self.output_loc)
        if self.store_data:
            os.makedirs(self.output_loc, exist_ok=True)
    #*********************************************

    # Check the input data type:
    #*********************************************
    def check_input_data_type(self, x=None, model_list=None):
        '''
        Verify that x is a numpy array and model_list is a non-empty list
        whose entries are all tf.keras.Model instances.

        Returns True if both checks pass, False otherwise (an error is logged).
        '''
        if model_list is None:
            model_list = []

        if isinstance(x, np.ndarray) and isinstance(model_list, list):
            # Non-empty list AND every entry is a keras model:
            return bool(model_list) and all(isinstance(m, tf.keras.Model) for m in model_list)
        else:
            logging.error(f">>> {self.module_name}: The provided data does not match the requirements. The first argument has to be a numpy array, Whereas the second argument should be a non-empty list with tf.keras.Model. Going to return None. <<<")
            return False
    #*********************************************

    # Provide information about this module:
    #*********************************************
    def get_info(self):
        '''Print the class docstring (module description).'''
        print(inspect.getdoc(self))
    #*********************************************

    # Handle configurations:
    #*********************************************
    # Load the config:
    def load_config(self, path_to_cfg, user_config):
        '''
        Read the yaml configuration at path_to_cfg and overwrite entries with
        the (optional) user-provided dictionary. Returns the merged dict.
        '''
        with open(path_to_cfg, 'r') as file:
            cfg = yaml.safe_load(file)

        # Overwrite config with user settings, if provided. An explicit type
        # check replaces the original bare except so genuine errors
        # (e.g. a corrupt yaml file) are no longer silently swallowed:
        if user_config:
            if isinstance(user_config, dict):
                cfg.update(user_config)
            else:
                logging.error(">>> " + self.module_name + ": Invalid user config. Please make sure that a dictionary is provided <<<")

        return cfg

    #-----------------------------

    # Store the config:
    def save_config(self, path_to_config):
        '''Dump the current configuration to a yaml file.'''
        with open(path_to_config, 'w') as file:
            yaml.dump(self.config, file)
    #*********************************************

    # Reconstruct the data:
    #*********************************************
    # First, we need a model prediction:
    def get_model_predictions(self, x, model_list):
        '''
        Chain the models: the prediction of model i becomes the input of
        model i+1. Returns the prediction of the last model (or x unchanged
        if model_list is empty, instead of raising a NameError as before).
        '''
        x_current = x
        for model in model_list:
            # predict_on_batch avoids the overhead of the full predict() loop:
            x_current = model.predict_on_batch(x_current)

        return x_current

    #------------------------------

    # Now run the reconstruction:
    def reconstruct_data(self, x, model_list):
        '''
        Run x batch-wise through the model chain via a tf.data pipeline and
        collect the reconstructions (and, if requested, the original inputs).

        Returns a dict: {data_names[1]: reconstructions,
                         data_names[0]: originals or None}.
        '''
        # Provide the option to only analyze a part of the initial data:
        n_ana_samples = x.shape[0]
        if self.n_analysis_samples > 0:
            n_ana_samples = self.n_analysis_samples

            # If we only analyze a (shuffled) fraction of the data, we must
            # record the original samples too, otherwise inputs and outputs
            # could not be matched afterwards:
            self.record_original_data = True

        # tf.data pipeline: shuffle, (optionally) subsample, then batch:
        tf_data = (
            tf.data.Dataset.from_tensor_slices(x)
            .shuffle(buffer_size=self.buffer_size)
            .take(n_ana_samples)
            .batch(self.analysis_sample_size)
        )

        # Second, make sure that we operate on a list of models:
        if not isinstance(model_list, list):
            model_list = [model_list]

        # Third, make some predictions, batch by batch:
        predictions = []
        inputs = []
        for sample in tf_data:
            predictions.append(self.get_model_predictions(sample, model_list))

            if self.record_original_data:
                inputs.append(sample)

        # Record everything:
        result_dict = {
            self.data_names[1]: np.concatenate(predictions, axis=0),
            self.data_names[0]: np.concatenate(inputs, axis=0) if self.record_original_data else None,
        }

        return result_dict
    #*********************************************

    # Run the analysis:
    #*********************************************
    def run(self, x, model_list):
        '''
        Main entry point: type-check the inputs, reconstruct the data and
        (optionally) store the result dictionary as a .npy file.

        Returns the result dictionary, or None if the type check fails.
        '''
        # Run type check first; bail out early on bad input:
        if not self.check_input_data_type(x, model_list):
            return None

        results = self.reconstruct_data(x, model_list)

        if self.store_data:
            # dtype=object because the payload is a dict, not a plain array:
            np.save(os.path.join(self.output_loc, self.data_store_name + ".npy"), np.array(results, dtype=object))

        return results
    #*********************************************

    # Save and load are not active here:
    #****************************
    def save(self):
        # Nothing to persist for this module.
        pass

    def load(self):
        # Nothing to restore for this module.
        pass
    #****************************

175 changes: 175 additions & 0 deletions jlab_datascience_toolkit/analyses/learning_curve_visualizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
from jlab_datascience_toolkit.core.jdst_analysis import JDSTAnalysis
import matplotlib.pyplot as plt
import os
import yaml
import inspect
import logging

class LearningCurveVisualizer(JDSTAnalysis):
    '''
    Simple class to visualize the learning curves produced during model training.

    Input(s):
    i) Dictionary with all loss curves

    Output(s):
    i) .png files visualizing the learning curves
    '''

    # Initialize:
    #*********************************************
    def __init__(self, path_to_cfg, user_config=None):
        # NOTE: user_config defaults to None (not {}) to avoid the
        # mutable-default-argument pitfall; None means "no overrides".

        # Set the name specific to this module:
        self.module_name = "learning_curve_visualizer"

        # Load the configuration:
        self.config = self.load_config(path_to_cfg, user_config or {})

        # Get plots that shall be produced (dict: plot key -> list of scores):
        self.plots = self.config['plots']
        # Get the corresponding plot labels, plot legends and the names of each individual plot:
        self.plot_labels = self.config['plot_labels']
        self.plot_legends = self.config['plot_legends']
        self.plot_names = self.config['plot_names']

        # Cosmetics:
        self.fig_size = self.config['fig_size']
        self.line_width = self.config['line_width']
        self.font_size = self.config['font_size']
        self.leg_font_size = self.config['leg_font_size']

        # Set font size globally for all figures produced by this module:
        plt.rcParams.update({'font.size': self.font_size})

        # Save this config, if a path is provided:
        if 'store_cfg_loc' in self.config:
            self.save_config(self.config['store_cfg_loc'])

        # Get the output location and create proper folders:
        self.output_loc = self.config['output_loc']
        self.plot_loc = os.path.join(self.output_loc, "learning_curves")

        os.makedirs(self.output_loc, exist_ok=True)
        os.makedirs(self.plot_loc, exist_ok=True)
    #*********************************************

    # Check the data type:
    #*********************************************
    def check_input_data_type(self, data):
        '''
        Verify that data is a non-empty dictionary.

        Returns True if it is, False otherwise (an error is logged).
        '''
        if not isinstance(data, dict):
            logging.error(f">>> {self.module_name}: The data type you provided {type(data)} is not a dictionary. Please check. Going to return None. <<<")
            return False

        # BUGFIX: the original tested bool(dict) -- the *builtin type*, which
        # is always truthy -- so empty dictionaries were never rejected.
        if not data:
            logging.error(f">>> {self.module_name}: Your dictionary {data} is empty. Please check. Going to return None. <<<")
            return False

        return True
    #*********************************************

    # Provide information about this module:
    #*********************************************
    def get_info(self):
        '''Print the class docstring (module description).'''
        print(inspect.getdoc(self))
    #*********************************************

    # Handle configurations:
    #*********************************************
    # Load the config:
    def load_config(self, path_to_cfg, user_config):
        '''
        Read the yaml configuration at path_to_cfg and overwrite entries with
        the (optional) user-provided dictionary. Returns the merged dict.
        '''
        with open(path_to_cfg, 'r') as file:
            cfg = yaml.safe_load(file)

        # Overwrite config with user settings, if provided. An explicit type
        # check replaces the original bare except so genuine errors
        # (e.g. a corrupt yaml file) are no longer silently swallowed:
        if user_config:
            if isinstance(user_config, dict):
                cfg.update(user_config)
            else:
                logging.error(">>> " + self.module_name + ": Invalid user config. Please make sure that a dictionary is provided <<<")

        return cfg

    #-----------------------------

    # Store the config:
    def save_config(self, path_to_config):
        '''Dump the current configuration to a yaml file.'''
        with open(path_to_config, 'w') as file:
            yaml.dump(self.config, file)
    #*********************************************

    # Run the entire analysis:
    #*********************************************
    # Produce a single plot, based on scores and legends:
    def produce_single_plot(self, history, scores, legend_entries, axis):
        '''
        Draw every score in 'scores' that exists in 'history' onto 'axis'.
        If legend entries are given, each curve is labeled and a legend is
        attached; scores missing from the history are silently skipped.
        '''
        if legend_entries is None:
            for score in scores:
                if score in history:
                    metric = history[score]
                    # Epochs are 1-based on the x-axis:
                    epochs = list(range(1, 1 + len(metric)))
                    axis.plot(epochs, metric, linewidth=self.line_width)
        else:
            for score, label in zip(scores, legend_entries):
                if score in history:
                    metric = history[score]
                    epochs = list(range(1, 1 + len(metric)))
                    axis.plot(epochs, metric, linewidth=self.line_width, label=label)
            axis.legend(fontsize=self.leg_font_size)

    def run(self, training_history):
        '''
        Produce one figure per entry in the 'plots' config and store each as
        a .png file. Returns None (the plots on disk are the output).

        Raises ValueError if the configured legends / labels are inconsistent
        with the scores. (The original used assert with logging.error as the
        message, which is stripped under python -O and produced
        AssertionError(None); an explicit raise is robust.)
        '''
        if not self.check_input_data_type(training_history):
            return None

        # Loop through all plots that we wish to produce:
        for plot in self.plots:
            # Create a canvas to draw on:
            fig, ax = plt.subplots(figsize=self.fig_size)

            scores = self.plots[plot]

            legend_entries = self.plot_legends.get(plot, None)
            labels = self.plot_labels.get(plot, None)
            name = self.plot_names.get(plot, None)

            if legend_entries is not None and len(legend_entries) != len(scores):
                msg = f">>> {self.module_name}: Number of legend entries {legend_entries} does not match the number of available score {scores} <<<"
                logging.error(msg)
                raise ValueError(msg)

            # Produce a nice plot:
            self.produce_single_plot(training_history, scores, legend_entries, ax)
            ax.grid(True)

            if labels is not None:
                if len(labels) != 2:
                    msg = f">>> {self.module_name}: Number of plot labels {labels} does not match exptected number of two entries <<<"
                    logging.error(msg)
                    raise ValueError(msg)

                # Add labels if available:
                ax.set_xlabel(labels[0])
                ax.set_ylabel(labels[1])

            # Store the figure somewhere (only if a name was configured):
            if name is not None:
                fig.savefig(os.path.join(self.plot_loc, name + ".png"))
            # Close the figure either way to avoid leaking figure objects:
            plt.close(fig)
    #*********************************************

    # Save and load are not active here:
    #*********************************************
    def save(self):
        # Nothing to persist for this module.
        pass

    def load(self):
        # Nothing to restore for this module.
        pass
    #*********************************************



Loading