Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding phase retrieval #130

Open
wants to merge 159 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
159 commits
Select commit Hold shift + click to select a range
625d135
add dev configs
tobias-liaudat Jan 8, 2024
6e88956
config for fast test
tobias-liaudat Jan 8, 2024
38b0f34
fix configs
tobias-liaudat Jan 8, 2024
9e67e91
update (fix) configs
tobias-liaudat Jan 8, 2024
d16cd72
try again
tobias-liaudat Jan 8, 2024
a2f64a0
fix config problem
tobias-liaudat Jan 8, 2024
e45df03
Merge branch 'develop' of https://github.com/CosmoStat/wf-psf into u/…
tobias-liaudat Jan 8, 2024
bb9cba1
Add new Zernike projection function with obscurations
tobias-liaudat Jan 9, 2024
edccc88
Add new Zk projection algorithm
tobias-liaudat Jan 9, 2024
f3ee49d
Add default behaviour
tobias-liaudat Jan 9, 2024
7838cca
Improve comments and doc
tobias-liaudat Jan 9, 2024
5ee00bc
Fix problem with repeating seed for S init after projection
tobias-liaudat Jan 11, 2024
08c0aad
Fix small bug when random seed is not defined
tobias-liaudat Jan 12, 2024
20dcfdb
Add phase retrieval test configs
tobias-liaudat Jan 12, 2024
72e6f69
update dev configs
tobias-liaudat Jan 12, 2024
51e8b65
Add job for PR validation
tobias-liaudat Jan 12, 2024
9b5e817
Add log dir
tobias-liaudat Jan 12, 2024
e03a34f
Update configs
tobias-liaudat Jan 12, 2024
f60dcaf
update job
tobias-liaudat Jan 12, 2024
b1a4c1b
Add new configs and job in new dir
tobias-liaudat Jan 12, 2024
c7d2914
debug lines
tobias-liaudat Jan 12, 2024
9eef69b
try bug fix
tobias-liaudat Jan 12, 2024
a1dc1da
add new test
tobias-liaudat Jan 12, 2024
bce7c9b
fix paths
tobias-liaudat Jan 12, 2024
51d0d6d
erase debug lines
tobias-liaudat Jan 12, 2024
12598d8
fix another path
tobias-liaudat Jan 12, 2024
00faa96
update paths
tobias-liaudat Jan 12, 2024
d200f07
remove unused configs
tobias-liaudat Jan 12, 2024
d50b647
Add plot all config
tobias-liaudat Jan 15, 2024
a5b3f5c
fix job
tobias-liaudat Jan 15, 2024
d55f46e
fix conf
tobias-liaudat Jan 15, 2024
652d504
Fix config file
tobias-liaudat Jan 15, 2024
9c8ac6a
Fix path
tobias-liaudat Jan 15, 2024
d066fc3
Fix path
tobias-liaudat Jan 15, 2024
1716bfa
fix path
tobias-liaudat Jan 15, 2024
ac8d578
Update configs and job for standard sanity check run
tobias-liaudat Jan 15, 2024
3bb5779
update name
tobias-liaudat Jan 15, 2024
b2e5032
Remove projections
tobias-liaudat Jan 15, 2024
092198d
Update model with new projection procedure
tobias-liaudat Jan 16, 2024
256aa3c
Small bug fix when reset is True and project is False
tobias-liaudat Jan 16, 2024
abfe770
Remove repeated line (probably a typo)
tobias-liaudat Jan 16, 2024
0b19f88
Reproducibility change
tobias-liaudat Jan 16, 2024
2167eaa
Add PR validation configs
tobias-liaudat Jan 16, 2024
f8f6ccf
Add PR validation job
tobias-liaudat Jan 16, 2024
e8c2634
Merge branch 'develop' of https://github.com/CosmoStat/wf-psf into u/…
tobias-liaudat Mar 6, 2024
dcf6918
Merge branch 'develop' of https://github.com/CosmoStat/wf-psf into u/…
tobias-liaudat Mar 6, 2024
983c91b
remove duplicated comment
tobias-liaudat Mar 6, 2024
59ec709
add unit test for obscured zernike projection
tobias-liaudat Mar 6, 2024
cf0588b
apply black
tobias-liaudat Mar 6, 2024
d92f1f3
add unit test for PI_zernikes
tobias-liaudat Mar 6, 2024
af59920
black file
tobias-liaudat Mar 6, 2024
a2121eb
Merge branch 'develop' of https://github.com/CosmoStat/wf-psf into u/…
tobias-liaudat Mar 11, 2024
b8a661c
clean dev and validation files
tobias-liaudat Mar 11, 2024
d07e72d
remove TODO comment as it is already done
tobias-liaudat Mar 11, 2024
ffe365d
clean comment
tobias-liaudat Mar 13, 2024
7a9442a
clean comments
tobias-liaudat Mar 13, 2024
8bbd771
clean more comments
tobias-liaudat Mar 13, 2024
8fed999
black formatting
tobias-liaudat Mar 13, 2024
7aaea30
improve comment
tobias-liaudat Mar 18, 2024
48217ea
specify exception
tobias-liaudat Mar 18, 2024
d415010
add comment
tobias-liaudat Mar 18, 2024
09d881b
improve messages
tobias-liaudat Mar 18, 2024
93e0fec
add more descriptive name for PI_zernikes
tobias-liaudat Mar 18, 2024
ea09bf3
add comment
tobias-liaudat Mar 18, 2024
8737b24
improve comments
tobias-liaudat Mar 18, 2024
69a1f52
remove unused import
tobias-liaudat Mar 18, 2024
d501322
remove unused imports
tobias-liaudat Mar 18, 2024
ac2761f
remove unused import
tobias-liaudat Mar 18, 2024
8774691
black formatting
tobias-liaudat Mar 18, 2024
d0ee468
remove unnecessary dependency
tobias-liaudat Mar 18, 2024
44e55c8
apply black formatting
tobias-liaudat Mar 18, 2024
17d9a45
fix docstring
tobias-liaudat Mar 18, 2024
862ffd5
Modified to use PSF Factory Class
jeipollack Feb 7, 2024
1f5f1d7
Added Factory class and rm duplicate
jeipollack Feb 7, 2024
657f1d7
Refactored with DataHandler class and updated docstrings
jeipollack Feb 7, 2024
5cf3e2d
Updated to DataHandler class
jeipollack Feb 7, 2024
3a6c718
Reduced to single DataHandler class
jeipollack Feb 7, 2024
f3b2cfd
Updated docstring
jeipollack Feb 7, 2024
80de920
Updated docstring
jeipollack Feb 7, 2024
2f8a374
Reformatted with black
jeipollack Feb 7, 2024
d091af4
Removed attribute from class
jeipollack Feb 7, 2024
3cedf69
Added refactored v1 SpatialVaryingPSF module
jeipollack Feb 19, 2024
838d854
Updated docstring and comment
jeipollack Feb 19, 2024
6b12a4d
Added unit test package for sims
jeipollack Feb 26, 2024
62d992c
Unit tests for spatial_varying_psf module
jeipollack Feb 29, 2024
236cc96
Added spatial_varying_psf module
jeipollack Feb 29, 2024
9a83a7d
Added fixtures
jeipollack Feb 29, 2024
9508984
Reformatted module
jeipollack Feb 29, 2024
65818e5
Renamed module following pep8
jeipollack Mar 1, 2024
b0c013b
Added logger and changed arg names
jeipollack Mar 1, 2024
f61247c
Updated module description with new name of sim.spatial_varying_psf m…
jeipollack Mar 1, 2024
de3c424
Added missing doc strings
jeipollack Mar 5, 2024
b3fc264
Formatting and updated unit tests to check logging
jeipollack Mar 6, 2024
e5a4653
Renamed dict key training -> train
jeipollack Mar 7, 2024
a9da437
Fix run_config_name typo and DataHandler obj creation
jeipollack Mar 7, 2024
546ec42
Reformat with black
jeipollack Mar 7, 2024
a5126f6
Changed name to grid_size
jeipollack Mar 7, 2024
edd8ceb
Change to correct psf_factory_class method name
jeipollack Mar 7, 2024
362a739
Reformatted with black
jeipollack Mar 7, 2024
08e1026
Refactored to generic DataHandler
jeipollack Mar 7, 2024
19769a2
Change dataset name to
jeipollack Mar 7, 2024
6df68fa
Revert "Change dataset name to"
jeipollack Mar 11, 2024
964f19a
Added pytest-mock dependency to create mock api calls in unit tests
jeipollack Mar 12, 2024
a081817
Updated module with new data object
jeipollack Mar 12, 2024
539cc20
Updated conftest with new data key
jeipollack Mar 12, 2024
452abb9
Changed to absolute path
jeipollack Mar 12, 2024
a4b4dea
Added mock fixtures and unit tests for Data and TrainingHandlers
jeipollack Mar 12, 2024
238100d
Updated with new training_data and test_data namespaces
jeipollack Mar 12, 2024
352ef42
Corrected missing args due to new data_conf object
jeipollack Mar 12, 2024
c83e7bd
add black formatter to check/format the code to the ci.yml file
Feb 19, 2024
0edc700
Refactored with DataHandler class and updated docstrings
jeipollack Feb 7, 2024
38091c3
Removed print() from unit test
jeipollack Mar 13, 2024
a4bab87
Corrected name of arg
jeipollack Mar 13, 2024
cdcd0ff
Added missing args
jeipollack Mar 13, 2024
9f27358
Updated fixture for data_conf
jeipollack Mar 13, 2024
2cef6a6
Changed key train to training
jeipollack Mar 13, 2024
d4faaee
Moved build_psf_models() to psf_models.psf_models
jeipollack Mar 13, 2024
c9a4b95
Updated import build_psf_models from psf_models.psf_models
jeipollack Mar 13, 2024
f140870
Removed duplicated parametric model class
jeipollack Mar 13, 2024
281a995
Refactor: Updated class, variable, and function names to comply with …
jeipollack Mar 16, 2024
c0398d7
Modify class names to comply with PEP8
jeipollack Mar 18, 2024
48f2904
Change name of TestClass to RegisterConfigClass
jeipollack Mar 18, 2024
27ab641
Add module for psf physical model
jeipollack Mar 18, 2024
10021d0
Update doc strings
jeipollack Mar 18, 2024
279d44b
Move get_obs_positions utility function to data package
jeipollack Mar 19, 2024
956b2e7
Add test_data unit test package
jeipollack Mar 19, 2024
cfc9ed5
Rename tf_zernike_cube method to generate_zernike_maps_3d
jeipollack Mar 19, 2024
2ec9c4f
Refactor: Rename ground_truth_psf_model function to create_ground_tru…
jeipollack Mar 20, 2024
9b1dccd
Refactor: Remove methods, Improve docstrings, Set attributes in appro…
jeipollack Mar 20, 2024
8a1c6a1
Fix model_params.pupil_diameter typo
jeipollack Mar 20, 2024
18d34b3
Remove interpolation attributes from class
jeipollack Mar 20, 2024
160e45f
Add test_psf_models unit test package
jeipollack Mar 20, 2024
37f2053
Add get_zernike_prior method to retrieve zernike prior from a dataset
jeipollack Mar 20, 2024
5c1c973
Add fixture and unit tests for get_zernike_prior method
jeipollack Mar 20, 2024
8c674d3
Modify to use get_zernike_prior() to retrieve zernike_prior info
jeipollack Mar 20, 2024
ef93336
Refactor: Change/Add fixture and mocks for get_zernike_prior, fix moc…
jeipollack Mar 20, 2024
63078aa
Correct formatting with black
jeipollack Mar 20, 2024
8ea9404
Change zks_pad function name to pad_zernikes
jeipollack Mar 20, 2024
ec0cc86
Format with black
jeipollack Mar 20, 2024
bb27ce4
Refactor: Remove unused attributes, Add predict_zernikes, Improve doc…
jeipollack Mar 21, 2024
db844be
Add unit tests for pad_zernikes
jeipollack Mar 21, 2024
4330271
Refactor: Remove nzernikes attribute, Improve function names, Use zk_…
jeipollack Mar 21, 2024
01bca2c
Update to new class name TFPhysicalLayer
jeipollack Mar 22, 2024
cf1f93a
Replace compute_zernike with predict_zernike in methods predict_step,…
jeipollack Mar 22, 2024
caf6b46
Update call() docstring for TFPhysicalLayer for clarity
jeipollack Mar 22, 2024
9d76f55
Reformat with black
jeipollack Mar 22, 2024
c865134
Refactor: Added conditionals for pad_zernike and corrected typo in co…
jeipollack Mar 22, 2024
d9a2aa1
Add unit test for compute_zernikes
jeipollack Mar 22, 2024
3f47313
Refactor: Add conditional to pad_zernikes() to check shapes match
jeipollack Mar 25, 2024
8fa046a
Change tensors to same length in unit test, Add unit test to check mi…
jeipollack Mar 25, 2024
7410ed2
Merge branch 'feature/issues-116-and-123/psf-model-and-sims' into u/t…
Mar 27, 2024
b48f780
Update class name to PSFSimulator
Mar 27, 2024
0ddb218
Correct format with black
Mar 27, 2024
7dcf10d
Remove unwanted text
Mar 27, 2024
faea59a
Remove unused import
jeipollack Mar 27, 2024
8e3dc09
Change new funcs names to improve readability
jeipollack Mar 27, 2024
4dfeca2
Remove Duplicate Black formatter step from ci
Mar 27, 2024
68aad3f
Reformat file with black
Mar 27, 2024
4b9bdf7
Change names to improve readability; Correct typo
jeipollack Mar 27, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,14 @@ jobs:
- name: Test with pytest
run: python -m pytest

# Add Black formatter
- name: Install Black formatter
run: python -m pip install black

- name: Check code formatting with Black
run: black . --check --diff

linter_name:
name: Run Black formatter
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v2

- name: Run Black formatter
uses: rickstaa/action-black@v1
with:
black_args: ". --check"
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ release = [

test = [
"pytest>=7.0.0,<8.1",
"pytest-mock",
"pytest-cov",
"pytest-emoji",
"pytest-raises",
Expand Down
186 changes: 133 additions & 53 deletions src/wf_psf/data/training_preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,90 +9,170 @@
import numpy as np
import wf_psf.utils.utils as utils
import tensorflow as tf
import tensorflow_addons as tfa
import os


class TrainingDataHandler:
"""Training Data Handler.
class DataHandler:
"""Data Handler.

A class to manage training data.
This class manages loading and processing of training and testing data for use in machine learning models.
It provides methods to access and preprocess the data.

Parameters
----------
training_data_params: Recursive Namespace object
data_type: str
A string indicating type of data ("train" or "test").
data_params: Recursive Namespace object
Recursive Namespace object containing training data parameters
simPSF: PSFSimulator
An instance of the PSFSimulator class for simulating a PSF.
n_bins_lambda: int
The number of bins in wavelength.
init_flag: bool, optional
A flag indicating whether to perform initialization steps upon object creation.
If True (default), the dataset is loaded and processed. If False, initialization
steps are skipped, and manual initialization is required.

Attributes
----------
dataset: dict
A dictionary containing the loaded dataset, including positions and stars/noisy_stars.
simPSF: object
PSFSimulator instance
An instance of the SimPSFToolkit class for simulating PSF.
n_bins_lambda: int
Number of bins in wavelength
The number of bins in wavelength.
sed_data: tf.Tensor
A TensorFlow tensor containing the SED data for training/testing.
init_flag: bool, optional
A flag used to control initialization steps. If True, initialization is performed
upon object creation.


"""

def __init__(self, training_data_params, simPSF, n_bins_lambda):
self.training_data_params = training_data_params
self.train_dataset = np.load(
os.path.join(
self.training_data_params.data_dir, self.training_data_params.file
),
def __init__(self, data_type, data_params, simPSF, n_bins_lambda, init_flag=True):
self.data_params = data_params.__dict__[data_type]
self.dataset = None
self.simPSF = simPSF
self.n_bins_lambda = n_bins_lambda
self.sed_data = None
self.initialize(init_flag)

def load_dataset(self):
"""Load dataset.

Load the dataset based on the specified data type.

"""
self.dataset = np.load(
os.path.join(self.data_params.data_dir, self.data_params.file),
allow_pickle=True,
)[()]
self.train_dataset["positions"] = tf.convert_to_tensor(
self.train_dataset["positions"], dtype=tf.float32
self.dataset["positions"] = tf.convert_to_tensor(
self.dataset["positions"], dtype=tf.float32
)
self.train_dataset["noisy_stars"] = tf.convert_to_tensor(
self.train_dataset["noisy_stars"], dtype=tf.float32
)
self.simPSF = simPSF
self.n_bins_lambda = n_bins_lambda
if "train" in self.data_params.file:
self.dataset["noisy_stars"] = tf.convert_to_tensor(
self.dataset["noisy_stars"], dtype=tf.float32
)
elif "test" in self.data_params.file:
self.dataset["stars"] = tf.convert_to_tensor(
self.dataset["stars"], dtype=tf.float32
)

def process_sed_data(self):
"""Process SED Data.

A method to generate and process SED data.

"""
self.sed_data = [
utils.generate_SED_elems_in_tensorflow(
_sed, self.simPSF, n_bins=self.n_bins_lambda, tf_dtype=tf.float64
)
for _sed in self.train_dataset["SEDs"]
for _sed in self.dataset["SEDs"]
]
self.sed_data = tf.convert_to_tensor(self.sed_data, dtype=tf.float32)
self.sed_data = tf.transpose(self.sed_data, perm=[0, 2, 1])

def initialize(self, init_flag):
"""Initialize.

Initialize the DataHandler instance by loading and processing the dataset,
if the init_flag is True.

class TestDataHandler:
"""Test Data Handler.
Parameters
----------
init_flag : bool
A flag indicating whether to perform initialization steps. If True,
the dataset is loaded and processed. If False, initialization steps
are skipped.

A class to handle test data for model validation.
"""
if init_flag:
self.load_dataset()
self.process_sed_data()


def get_obs_positions(data):
"""Get observed positions from the provided dataset.

This method concatenates the positions of the stars from both the training
and test datasets to obtain the observed positions.

Parameters
----------
test_data_params: Recursive Namespace object
Recursive Namespace object containing test data parameters
simPSF: object
PSFSimulator instance
n_bins_lambda: int
Number of bins in wavelength
data : DataConfigHandler
Object containing training and test datasets.

Returns
-------
tf.Tensor
Tensor containing the observed positions of the stars.

Notes
-----
The observed positions are obtained by concatenating the positions of stars
from both the training and test datasets along the 0th axis.

"""
obs_positions = np.concatenate(
(
data.training_data.dataset["positions"],
data.test_data.dataset["positions"],
),
axis=0,
)
return tf.convert_to_tensor(obs_positions, dtype=tf.float32)

def __init__(self, test_data_params, simPSF, n_bins_lambda):
self.test_data_params = test_data_params
self.test_dataset = np.load(
os.path.join(self.test_data_params.data_dir, self.test_data_params.file),
allow_pickle=True,
)[()]
self.test_dataset["stars"] = tf.convert_to_tensor(
self.test_dataset["stars"], dtype=tf.float32
)
self.test_dataset["positions"] = tf.convert_to_tensor(
self.test_dataset["positions"], dtype=tf.float32
)

# Prepare validation data inputs
self.simPSF = simPSF
self.n_bins_lambda = n_bins_lambda
def get_zernike_prior(data):
"""Get Zernike priors from the provided dataset.

self.sed_data = [
utils.generate_SED_elems_in_tensorflow(
_sed, self.simPSF, n_bins=self.n_bins_lambda, tf_dtype=tf.float64
)
for _sed in self.test_dataset["SEDs"]
]
self.sed_data = tf.convert_to_tensor(self.sed_data, dtype=tf.float32)
self.sed_data = tf.transpose(self.sed_data, perm=[0, 2, 1])
This method concatenates the Zernike priors from both the training
and test datasets.

Parameters
----------
data : DataConfigHandler
Object containing training and test datasets.

Returns
-------
tf.Tensor
Tensor containing the observed positions of the stars.

Notes
-----
The Zernike prior are obtained by concatenating the Zernike priors
from both the training and test datasets along the 0th axis.

"""
zernike_prior = np.concatenate(
(
data.training_data.dataset["zernike_prior"],
data.test_data.dataset["zernike_prior"],
),
axis=0,
)
return tf.convert_to_tensor(zernike_prior, dtype=tf.float32)
Loading
Loading