Skip to content
This repository has been archived by the owner on Nov 15, 2018. It is now read-only.

Commit

Permalink
engine/dataset and engine/util modules refactoring, tests update
Browse files Browse the repository at this point in the history
  • Loading branch information
intsco committed Aug 11, 2016
1 parent 41be8e5 commit eda373d
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 10 deletions.
15 changes: 12 additions & 3 deletions sm/engine/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,6 @@ def _define_pixels_order(self):
pixel_indices = pixel_indices.astype(np.int32)
self.norm_img_pixel_inds = pixel_indices

self.sample_area_mask = np.zeros(ncols*nrows).astype(bool)
self.sample_area_mask[pixel_indices] = True

def get_norm_img_pixel_inds(self):
"""
Returns
Expand All @@ -78,6 +75,18 @@ def get_norm_img_pixel_inds(self):
"""
return self.norm_img_pixel_inds

def get_sample_area_mask(self):
"""
Returns
-------
: ndarray
One-dimensional bool array of pixel indices where spectra were sampled
"""
nrows, ncols = self.get_dims()
sample_area_mask = np.zeros(ncols * nrows).astype(bool)
sample_area_mask[self.norm_img_pixel_inds] = True
return sample_area_mask

def get_dims(self):
"""
Returns
Expand Down
2 changes: 1 addition & 1 deletion sm/engine/msm_basic/formula_img_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def sf_image_metrics(sf_images, sc, formulas, ds, ds_config):
"""
nrows, ncols = ds.get_dims()
empty_matrix = np.zeros((nrows, ncols))
compute_metrics = get_compute_img_metrics(ds.sample_area_mask, empty_matrix, ds_config['image_generation'])
compute_metrics = get_compute_img_metrics(ds.get_sample_area_mask(), empty_matrix, ds_config['image_generation'])
sf_add_ints_map_brcast = sc.broadcast(formulas.get_sf_peak_ints())
# sf_peak_ints_brcast = sc.broadcast(formulas.get_sf_peak_ints())

Expand Down
3 changes: 2 additions & 1 deletion sm/engine/tests/msm_basic/test_formula_img_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def test_get_compute_img_measures_pass(chaos_mock, image_corr_mock, pattern_matc
'q': 99.0
}
empty_matrix = np.zeros((2, 3))
compute_measures = get_compute_img_metrics(empty_matrix, img_gen_conf)
compute_measures = get_compute_img_metrics(np.ones(2*3).astype(bool), empty_matrix, img_gen_conf)

sf_iso_images = [csr_matrix([[0., 100., 100.], [10., 0., 3.]]),
csr_matrix([[0., 50., 50.], [0., 20., 0.]])]
Expand All @@ -38,6 +38,7 @@ def test_get_compute_img_measures_pass(chaos_mock, image_corr_mock, pattern_matc
def ds_formulas_images_mock():
ds_mock = MagicMock(spec=Dataset)
ds_mock.get_dims.return_value = (2, 3)
ds_mock.get_sample_area_mask.return_value = np.ones(2*3).astype(bool)

formulas_mock = MagicMock(spec=FormulasSegm)
formulas_mock.get_sf_peak_ints.return_value = {(0, '+H'): [100, 10, 1], (1, '+H'): [100, 10, 1]}
Expand Down
22 changes: 21 additions & 1 deletion sm/engine/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,27 @@

from sm.engine.dataset import Dataset
from sm.engine.util import SMConfig
#from sm.engine.work_dir import WorkDir
from sm.engine.work_dir import WorkDirManager
from sm.engine.tests.util import sm_config, ds_config, spark_context


def test_get_sample_area_mask_correctness(sm_config, ds_config, spark_context):
work_dir_man_mock = MagicMock(WorkDirManager)
work_dir_man_mock.ds_coord_path = '/ds_path'
work_dir_man_mock.txt_path = '/txt_path'

SMConfig._config_dict = sm_config

with patch('sm.engine.tests.util.SparkContext.textFile') as m:
m.return_value = spark_context.parallelize([
'0,0,0\n',
'2,1,1\n'])

ds = Dataset(spark_context, 'ds_name', '', 'input_path', ds_config, work_dir_man_mock, None)

#ds.norm_img_pixel_inds = np.array([0, 3])

assert tuple(ds.get_sample_area_mask()) == (True, False, False, True)


# def test_get_dims_2by3(spark_context, sm_config):
Expand Down
9 changes: 6 additions & 3 deletions sm/engine/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,12 @@ def get_conf(cls):
SM engine configuration
"""
if not cls._config_dict:
config_path = cls._path or join(proj_root(), 'conf', 'config.json')
with open(config_path) as f:
cls._config_dict = json.load(f)
try:
config_path = cls._path or join(proj_root(), 'conf', 'config.json')
with open(config_path) as f:
cls._config_dict = json.load(f)
except IOError as e:
logger.warn(e)
return cls._config_dict


Expand Down
1 change: 0 additions & 1 deletion tests/test_dataset_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ def test_save_ds_meta_ds_doesnt_exist(spark_context, create_test_db, drop_test_d

db.close()


# def test_save_ds_meta_ds_exists(spark_context, create_test_db, fill_test_db, drop_test_db, sm_config, ds_config):
# work_dir_mock = MagicMock(WorkDir)
# work_dir_mock.ds_coord_path = '/new_ds_path'
Expand Down

0 comments on commit eda373d

Please sign in to comment.