diff --git a/src/yt_napari/_data_model.py b/src/yt_napari/_data_model.py index ec9cdac..7c35c7b 100644 --- a/src/yt_napari/_data_model.py +++ b/src/yt_napari/_data_model.py @@ -58,6 +58,9 @@ class Region(BaseModel): (400, 400, 400), description="the resolution at which to sample between the edges.", ) + rescale: Optional[bool] = Field( + False, description="rescale the final image between 0,1" + ) class Slice(BaseModel): @@ -81,6 +84,9 @@ class Slice(BaseModel): periodic: Optional[bool] = Field( False, description="should the slice be periodic? default False." ) + rescale: Optional[bool] = Field( + False, description="rescale the final image between 0,1" + ) class SelectionObject(BaseModel): diff --git a/src/yt_napari/_model_ingestor.py b/src/yt_napari/_model_ingestor.py index e3c163d..ba663df 100644 --- a/src/yt_napari/_model_ingestor.py +++ b/src/yt_napari/_model_ingestor.py @@ -188,7 +188,7 @@ def align_sanitize_layers(self, layer_list: List[SpatialLayer]) -> List[Layer]: def selections_match(sel_1: Union[Slice, Region], sel_2: Union[Slice, Region]) -> bool: # compare selections, ignoring fields - if not type(sel_2) == type(sel_1): + if not type(sel_2) is type(sel_1): return False for attr in sel_1.__fields__.keys(): @@ -467,6 +467,9 @@ def _load_3D_regions( if field_container.take_log: data = np.log10(data) + if sel.rescale: + data = _linear_rescale(data) + # create a metadata dict and set a name fieldname = ":".join(field) md = create_metadata_dict(data, layer_domain, field_container.take_log) @@ -531,6 +534,15 @@ def _process_slice( return frb, layer_domain +def _linear_rescale(data, fill_inf=True): + # rescales an array between 0, 1 handling nans and infs + if fill_inf: + data[np.isinf(data)] = np.nan + data_max = np.nanmax(data) + data_min = np.nanmin(data) + return (data - data_min) / (data_max - data_min) + + def _load_2D_slices( ds, selections: SelectionObject, @@ -572,6 +584,9 @@ def _load_2D_slices( if field_container.take_log: data = np.log10(data) + if slice.rescale: + data = _linear_rescale(data) + # create a metadata dict and set a name fieldname = ":".join(field) md = create_metadata_dict(data, layer_domain, field_container.take_log) diff --git a/src/yt_napari/_tests/test_model_ingestor.py b/src/yt_napari/_tests/test_model_ingestor.py index 62de672..c7b9f2e 100644 --- a/src/yt_napari/_tests/test_model_ingestor.py +++ b/src/yt_napari/_tests/test_model_ingestor.py @@ -496,3 +496,27 @@ def test_yt_data_dir_check(tmp_path): files = _mi._generate_file_list("test_fi_blah_???") assert len(files) == nfiles ytcfg.set("yt", "test_data_dir", init_dir) + + +def test_linear_rescale(): + data = 10 * np.random.random((5, 5)) + rsc = _mi._linear_rescale(data) + assert rsc.min() == 0.0 + assert rsc.max() == 1.0 + + data[0, 0] = np.nan + rsc = _mi._linear_rescale(data) + assert np.nanmin(rsc) == 0.0 + assert np.nanmax(rsc) == 1.0 + + data[1, 1] = np.inf + rsc = _mi._linear_rescale(data) + assert np.nanmin(rsc) == 0.0 + assert np.nanmax(rsc) == 1.0 + + data = 10 * np.random.random((5, 5)) + data[1, 1] = np.inf + with pytest.warns(RuntimeWarning, match="invalid value"): + rsc = _mi._linear_rescale(data, fill_inf=False) + assert np.nanmin(rsc) == 0.0 + assert np.nanmax(rsc) == 0.0 diff --git a/src/yt_napari/_tests/test_regions_json.py b/src/yt_napari/_tests/test_regions_json.py new file mode 100644 index 0000000..39332c6 --- /dev/null +++ b/src/yt_napari/_tests/test_regions_json.py @@ -0,0 +1,35 @@ +import pytest + +from yt_napari._data_model import InputModel +from yt_napari._model_ingestor import _process_validated_model +from yt_napari._schema_version import schema_name + +jdicts = [] +jdicts.append( + { + "$schema": schema_name, + "datasets": [ + { + "filename": "_ytnapari_load_grid", + "selections": { + "regions": [ + { + "fields": [{"field_name": "density", "field_type": "gas"}], + "resolution": [10, 10, 10], + } + ] + }, + } + ], + } +) + + +@pytest.mark.parametrize("jdict", jdicts) +def test_load_region(jdict): + jdict["datasets"][0]["selections"]["regions"][0]["rescale"] = True + m = InputModel.parse_obj(jdict) + layers, _ = _process_validated_model(m) + im_data = layers[0][0] + assert im_data.min() == 0.0 + assert im_data.max() == 1.0 diff --git a/src/yt_napari/_tests/test_slices_json.py b/src/yt_napari/_tests/test_slices_json.py index cea9c73..d772856 100644 --- a/src/yt_napari/_tests/test_slices_json.py +++ b/src/yt_napari/_tests/test_slices_json.py @@ -58,3 +58,10 @@ def test_slice_load(yt_ugrid_ds_fn, jdict): layer_lists, _ = _process_validated_model(im) ref_layer = _choose_ref_layer(layer_lists) _ = ref_layer.align_sanitize_layers(layer_lists) + + jdict["datasets"][0]["selections"]["slices"][0]["rescale"] = True + im = InputModel.parse_obj(jdict) + layer_lists, _ = _process_validated_model(im) + im_data = layer_lists[0][0] + assert im_data.min() == 0 + assert im_data.max() == 1 diff --git a/src/yt_napari/viewer.py b/src/yt_napari/viewer.py index e4b614e..af03dcd 100644 --- a/src/yt_napari/viewer.py +++ b/src/yt_napari/viewer.py @@ -60,10 +60,14 @@ def _add_to_scene( take_log, colormap=None, link_to=None, + rescale=False, **kwargs, ): # adds any new data to the viewer + if rescale: + data = _mi._linear_rescale(data) + if colormap is None: colormap = "viridis" @@ -118,6 +122,7 @@ def add_region( take_log: Optional[bool] = None, colormap: Optional[str] = None, link_to: Optional[Union[str, Layer]] = None, + rescale: Optional[bool] = False, **kwargs, ): """ @@ -191,6 +196,7 @@ def add_region( take_log, colormap=colormap, link_to=link_to, + rescale=rescale, **kwargs, ) @@ -240,6 +246,7 @@ def add_slice( periodic: Optional[bool] = False, colormap: Optional[str] = None, link_to: Optional[Union[str, Layer]] = None, + rescale: Optional[bool] = False, **kwargs, ): """ @@ -313,6 +320,7 @@ def add_slice( take_log, colormap=colormap, link_to=link_to, + rescale=rescale, **kwargs, )