Update code and readme #10

Open · wants to merge 24 commits into base: `master`
123 changes: 116 additions & 7 deletions README.md
@@ -1,13 +1,122 @@
# BlendHunter
![blendhunter](https://user-images.githubusercontent.com/7417573/127934298-39734525-6325-4d98-900d-136227f03b38.png)

A deep learning tool for identifying blended galaxy images in survey data.

The `blendhunter` directory contains the code for data preparation and the network.

---
> Main contributors:
> - <a href="https://github.com/sfarrens" target="_blank" style="text-decoration:none; color: #F08080">Samuel Farrens</a>
> - <a href="https://github.com/ablacan" target="_blank" style="text-decoration:none; color: #F08080">Alice Lacan</a>
> - <a href="https://github.com/aguinot" target="_blank" style="text-decoration:none; color: #F08080">Axel Guinot</a>
> - <a href="https://github.com/andrevitorelli" target="_blank" style="text-decoration:none; color: #F08080">André Zamorano Vitorelli</a>
---

BlendHunter is a deep transfer learning based approach for the automated and robust identification of blended sources in galaxy survey data. See [Farrens et al. (2021)](...) for details.

## Dependencies
The following Python packages should be installed along with their dependencies:

- [Numpy](https://github.com/numpy/numpy)
- [ModOpt](https://github.com/CEA-COSMIC/ModOpt)
- [LMFIT](https://lmfit.github.io/lmfit-py/)
- [SF_Tools](https://github.com/sfarrens/sf_tools)
- [OpenCV](https://github.com/opencv/opencv-python)
- [TensorFlow](https://github.com/tensorflow/tensorflow)
- [SEP](https://github.com/kbarbary/sep/tree/v1.1.x)

## Local installation

```bash
$ git clone https://github.com/CosmoStat/BlendHunter.git
$ pip install -e .
```

## Reproducible Research

To repeat the experiments carried out in [Farrens et al. (2021)](...), or to carry out similar experiments on a different data set, you will need to go through the following steps.

### Download Data

You can download the parametric training data and realistic CFIS-like images [here]().

Alternatively, you can use your own data provided it is formatted in the same way.

### Configuration Setup

You will need to modify the `bhconfig.yml` file in the `data` directory, in particular to specify the paths to the input data and the location where outputs should be written.

The structure of this file is as follows:

```yml
out_path: ...
in_path: ...
cosmos_path: ...
noise_sigma:
- 5
- 10
- 15
- 20
- 25
- 30
- 35
n_noise_real: 10
sep_sample_range:
- 36000
- 40000
cosmos_sample_range:
- 0
- 10000
```

where:

- `out_path` specifies the path where outputs should be written.
- `in_path` specifies the path to the input parametric model training data.
- `cosmos_path` specifies the path to the input realistic testing data.
- `noise_sigma` specifies the list of noise standard deviations that should be added to the training data.
- `n_noise_real` specifies the number of noise realisations that should be made for each noise level.
- `sep_sample_range` specifies the range of objects in the training sample on which SEP should be run.
- `cosmos_sample_range` specifies the range of objects in the realistic testing sample on which BlendHunter and SEP should be run.
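
This file is read by the `BHConfig` helper class added in `blendhunter/config.py` (shown further down in this PR). A minimal sketch of loading these values, assuming the commands are run from the repository root:

```python
from blendhunter.config import BHConfig

# Read data/bhconfig.yml (the default path used by BHConfig)
bh_config = BHConfig().config

out_path = bh_config['out_path']
noise_sigma = bh_config['noise_sigma']
n_noise_real = bh_config['n_noise_real']

# e.g. 7 noise levels x 10 realisations = 70 independent data sets
print(f'{len(noise_sigma) * n_noise_real} noisy data sets will be '
      f'written to {out_path}')
```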

### Prepare Data

Once you have downloaded (or formatted) your data and updated the configuration file, you should run the following scripts from the `scripts` directory.

```bash
$ python scripts/create_directories.py
```

This will prepare directories in your output path to store all of the output products.
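
A hypothetical sketch of how such a directory tree could be built from the noise levels and realisations in the configuration (the sub-directory naming is an assumption, not the layout used by `create_directories.py`):

```python
import os
from blendhunter.config import BHConfig

bh_config = BHConfig().config

# Hypothetical layout: one output sub-directory per noise level and realisation
for sigma in bh_config['noise_sigma']:
    for real in range(bh_config['n_noise_real']):
        os.makedirs(os.path.join(bh_config['out_path'], f'bh_{sigma}_{real}'),
                    exist_ok=True)
```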

```bash
$ python scripts/prep_data.py
```

This will prepare the training and testing data sets by padding the images, adding noise and converting them to PNG files.

> :warning: Each noise level and realisation constitutes an independent data set and increases the amount of storage space required, *e.g.* 7 noise levels and 10 realisations will require 70 times the volume of the original data.
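
For illustration, a minimal sketch of the kind of per-image processing described above (the function name and parameter values are assumptions; the actual implementation lives in `blendhunter/data.py`):

```python
import numpy as np
import cv2


def pad_noise_and_save(image, path, index, sigma=5.0, min_shape=(48, 48)):
    """Add Gaussian noise, pad to a minimum stamp size and write a PNG.

    Illustrative sketch only; see blendhunter/data.py for the real pipeline.
    """
    noisy = image + np.random.normal(scale=sigma, size=image.shape)

    # Pad symmetrically if the stamp is smaller than the minimum shape
    pad = np.maximum(np.array(min_shape) - np.array(noisy.shape), 0)
    noisy = np.pad(noisy, ((pad[0] // 2, pad[0] - pad[0] // 2),
                           (pad[1] // 2, pad[1] - pad[1] // 2)))

    # Rescale to 8-bit grey levels before writing the PNG
    rescaled = 255 * (noisy - noisy.min()) / (noisy.max() - noisy.min())
    cv2.imwrite('{}/image_{:05d}.png'.format(path, index),
                rescaled.astype(np.uint8))
```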

### Run BlendHunter and SEP

Run BlendHunter to train the network on the parametric training data. This additionally tests the resulting weights by making predictions on a subsample of this data reserved for testing.

```bash
$ python scripts/run_bh.py
```
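
For context, the `BlendHunter` class in `blendhunter/network.py` wraps a pre-trained VGG16 network whose top layers are trained and then fine-tuned. A hedged sketch of how it might be driven (the constructor arguments and data paths are assumptions; `scripts/run_bh.py` is the authoritative entry point):

```python
from blendhunter.network import BlendHunter

# Assumption: BlendHunter can be constructed without mandatory arguments
bh = BlendHunter()

# Train the VGG16-based network on the prepared parametric data
# (epochs_top and epochs_fine mirror the defaults in network.py)
bh.train('/path/to/prepared/data', epochs_top=500, epochs_fine=80)

# Predict on the reserved test subsample
predictions = bh.predict(input_path='/path/to/prepared/data')
```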

Run SEP on the subsample of testing data for comparison.

```bash
$ python scripts/run_sep.py
```
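
For reference, a minimal sketch of how SEP can flag a postage stamp as blended by counting extracted sources (the threshold and the one-source criterion are illustrative; `scripts/run_sep.py` contains the actual settings):

```python
import numpy as np
import sep


def is_blended(image, threshold=1.5):
    """Return True if SEP detects more than one source in the stamp.

    Illustrative only; see scripts/run_sep.py for the settings used
    in the paper.
    """
    data = np.ascontiguousarray(image, dtype=np.float64)

    # Estimate and subtract the spatially varying background
    background = sep.Background(data)
    sources = sep.extract(data - background, threshold,
                          err=background.globalrms)

    return len(sources) > 1
```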

### Test on CFIS-like images

Run both BlendHunter and SEP on the realistic CFIS-like testing data.

```bash
$ python scripts/test_cosmos.py
```

### SExtractor

The scripts used to run SExtractor can be found in the `sextractor` directory.

### Notebooks

Finally, in the `notebooks` directory, you will find several Jupyter notebooks where you can reproduce the plots in the paper or make equivalent plots for a different data set.
28 changes: 28 additions & 0 deletions blendhunter/config.py
@@ -0,0 +1,28 @@
import os
import yaml


class BHConfig:
    """BlendHunter configuration handler.

    Reads the YAML configuration file and exposes its contents as a
    dictionary via the ``config`` attribute.
    """

    def __init__(self, config_file='data/bhconfig.yml'):

        self.config_file = config_file
        self._read_config()

    def _read_config(self):
        """Read the YAML configuration file if it exists."""

        if os.path.isfile(self.config_file):
            with open(self.config_file) as file:
                self.config = yaml.load(file, Loader=yaml.FullLoader)
        else:
            self.config = {}

    def _update_config(self):
        """Write the current configuration back to the YAML file."""

        with open(self.config_file, 'w') as file:
            yaml.dump(self.config, file)

    def _add_params(self, params):
        """Merge new parameters into the configuration and write them out."""

        self.config = {**self.config, **params}
        self._update_config()
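
A brief usage sketch of the helper above (the added parameter name is hypothetical and only illustrates how `_add_params` persists values back to the YAML file):

```python
from blendhunter.config import BHConfig

bhc = BHConfig('data/bhconfig.yml')
print(bhc.config['noise_sigma'])

# Merge a hypothetical extra parameter and write it back to bhconfig.yml
bhc._add_params({'padding_value': 0})
```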
82 changes: 46 additions & 36 deletions blendhunter/data.py
@@ -61,7 +61,6 @@ def __init__(self, images, output_path, train_fractions=(0.45, 0.45, 0.1),
self.blend_images = blend_images
self.blend_fractions = blend_fractions
self.blend_method = blend_method
self._image_num = 0

self._make_output_dirs()

@@ -230,8 +229,9 @@ def _write_images(self, images, path):
"""

min_shape = np.array([48, 48])
zero_pad = np.log10(images.shape[0]).astype(int) + 1

for image in images:
for image_num, image in enumerate(images):

image = self._rescale(image)

@@ -240,8 +240,8 @@
if np.sum(shape_diff) > 0:
image = self._pad(image, shape_diff)

cv2.imwrite('{}/image_{}.png'.format(path, self._image_num), image)
self._image_num += 1
cv2.imwrite('{0}/image_{2:0{1}d}.png'.format(path, zero_pad,
image_num), image)

def _write_data_set(self, data_list, path_list):
""" Write Data Set
@@ -279,7 +279,6 @@ def _write_labels(self, data_list):

np.save('{}/labels.npy'.format(self._test_path), labels)


def _write_positions(self, pos_list):

np.save('{}/positions.npy'.format(self._test_path), np.array(pos_list))
@@ -361,56 +360,67 @@ def generate(self):
test_set = np.vstack(test_set)
self._write_images(test_set, self._test_path)


def prep_axel(self, path_to_output=None, psf=None, param_1 = None, param_2=None, map=None):
def prep_axel(self, path_to_output=None, psf=None, param_1=None,
param_2=None, map=None):
# Can add parameters : psf, param_1, param_2, map

#self.images[0] = np.random.permutation(self.images[0])
#self.images[1] = np.random.permutation(self.images[1])
# self.images[0] = np.random.permutation(self.images[0])
# self.images[1] = np.random.permutation(self.images[1])

split1 = self._split_array(self.images[0], self.train_fractions)
split2 = self._split_array(self.images[1], self.train_fractions)

#Split fwhm
train_fractions=(0.45, 0.45, 0.1)
#psf_split1 = CreateTrainData._split_array(psf[0], train_fractions)
#psf_split2 = CreateTrainData._split_array(psf[1], train_fractions)
# Split fwhm
train_fractions = (0.45, 0.45, 0.1)
# psf_split1 = CreateTrainData._split_array(psf[0], train_fractions)
# psf_split2 = CreateTrainData._split_array(psf[1], train_fractions)

#Split shift params
#x_split = CreateTrainData._split_array(param_1[0], train_fractions)
#y_split = CreateTrainData._split_array(param_2[0], train_fractions)
# Split shift params
# x_split = CreateTrainData._split_array(param_1[0], train_fractions)
# y_split = CreateTrainData._split_array(param_2[0], train_fractions)

#Split segmentation map
#map_split = CreateTrainData._split_array(map[0], train_fractions)
# Split segmentation map
# map_split = CreateTrainData._split_array(map[0], train_fractions)

train_set = split1[0], split2[0]
valid_set = split1[1], split2[1]
test_set = split1[2], split2[2]


#test_psf = psf_split1[2], psf_split2[2]
#test_param_x = x_split[2]
#test_param_y = y_split[2]
#test_im_blended = split1[2] #blended test images
#test_im_nb = split2[2]
#test_map = map_split[2]
# test_psf = psf_split1[2], psf_split2[2]
# test_param_x = x_split[2]
# test_param_y = y_split[2]
# test_im_blended = split1[2] #blended test images
# test_im_nb = split2[2]
# test_map = map_split[2]

self._write_data_set(train_set, self._train_paths)
self._write_data_set(valid_set, self._valid_paths)
self._write_labels(test_set)
self._write_images(np.vstack(test_set), self._test_path)

#Save test_psf
#np.save(path_to_output+'/test_psf.npy', test_psf)
# Save test_psf
# np.save(path_to_output+'/test_psf.npy', test_psf)

# Save test_params
# np.save(path_to_output+'/test_param_x.npy', test_param_x)
# np.save(path_to_output+'/test_param_y.npy', test_param_y)

# Save blended test images
# np.save(path_to_output+'/gal_im_blended.npy', test_im_blended)
# np.save(path_to_output+'/gal_im_nb.npy', test_im_nb)
# np.save(path_to_output+'/test_images.npy', test_set)

# Save seg_map
# np.save(path_to_output+'/test_seg_map.npy', test_map)

def prep_cosmos(self):

#Save test_params
#np.save(path_to_output+'/test_param_x.npy', test_param_x)
#np.save(path_to_output+'/test_param_y.npy', test_param_y)
split1 = self._split_array(self.images[0], self.train_fractions)
split2 = self._split_array(self.images[1], self.train_fractions)

#Save blended test images
#np.save(path_to_output+'/gal_im_blended.npy', test_im_blended)
#np.save(path_to_output+'/gal_im_nb.npy', test_im_nb)
#np.save(path_to_output+'/test_images.npy', test_set)
train_set = split1[0], split2[0]
valid_set = split1[1], split2[1]
test_set = split1[2], split2[2]

#Save seg_map
#np.save(path_to_output+'/test_seg_map.npy', test_map)
self._write_labels(test_set)
self._write_images(np.vstack(test_set), self._test_path)
39 changes: 17 additions & 22 deletions blendhunter/network.py
@@ -13,16 +13,15 @@
import seaborn as sns
from time import time
from cv2 import imread
import keras
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Sequential, Model
from keras.layers import Dropout, Flatten, Dense, Input
from keras.applications import VGG16
from keras.optimizers import Adam, SGD
from keras.callbacks import ModelCheckpoint
from keras.callbacks import EarlyStopping
from keras.callbacks import ReduceLROnPlateau
from keras import regularizers
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Dropout, Flatten, Dense, Input
from tensorflow.keras.applications import VGG16
from tensorflow.keras.optimizers import Adam, SGD
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.callbacks import ReduceLROnPlateau
from tensorflow.keras import regularizers


class BlendHunter(object):
@@ -118,7 +117,7 @@ def _get_target_shape(self, image_path=None):
self._target_size = self._image_shape[:2]

def _load_generator(self, input_dir, batch_size=None,
class_mode=None, augmentation=True):
class_mode=None, augmentation=False):
""" Load Generator
Load files from an input directory into a Keras generator.
Parameters
@@ -261,7 +260,8 @@ def _load_features(self):
value[feature_name] = self._load_data(key, out_path)

@staticmethod
def _build_top_model(input_shape, dense_output=(256, 512, 1024), dropout=0.1):
def _build_top_model(input_shape, dense_output=(256, 512, 1024),
dropout=0.1):
""" Build Top Model
Build the fully connected layers of the network.
Parameters
@@ -431,7 +431,6 @@ def _fine_tune(self):
batch_size=self._batch_size_fine,
class_mode='binary')


callbacks = []
callbacks.append(ModelCheckpoint('{}.h5'.format(self._fine_tune_file),
monitor='val_loss', verbose=self._verbose,
@@ -444,37 +443,34 @@
cooldown=2, verbose=self._verbose))
# callbacks.append(LoggingCallback(filetxt=logfile, log=write_log)])

history_tune1 = model.fit_generator(train_gen, steps_per_epoch=train_gen.steps,
history_tune1 = model.fit_generator(
train_gen,
steps_per_epoch=train_gen.steps,
epochs=self._epochs_fine,
callbacks=callbacks,
validation_data=valid_gen,
validation_steps=valid_gen.steps,
verbose=self._verbose)
verbose=self._verbose
)

self._freeze_layers(model, 19)
model.layers[17].trainable = True
model.layers[18].trainable = True


model.compile(loss='binary_crossentropy',
optimizer=Adam(lr=1e-6, decay=1e-6),
metrics=['binary_accuracy'])


model.fit_generator(train_gen, steps_per_epoch=train_gen.steps,
epochs=self._epochs_fine,
callbacks=callbacks,
validation_data=valid_gen,
validation_steps=valid_gen.steps,
verbose=self._verbose)


model.save_weights('{}.h5'.format(self._final_model_file))
print(history_tune1.history.keys())




def train(self, input_path, get_features=True, train_top=True,
fine_tune=True, train_dir_name='train',
valid_dir_name='validation', epochs_top=500, epochs_fine=80,
@@ -598,7 +594,6 @@ def predict(self, input_path=None, input_path_keras=None, input_data=None,
verbose=self._verbose,
steps=test_gen.steps).flatten()


elif not isinstance(input_data, type(None)):

self._image_shape = input_data.shape[1:]