Skip to content

Commit

Permalink
Merge pull request #84 from chrishavlin/rescale_option
Browse files Browse the repository at this point in the history
add a data rescale option
  • Loading branch information
chrishavlin authored Aug 29, 2023
2 parents 28bd9bb + e6012eb commit 4955871
Show file tree
Hide file tree
Showing 6 changed files with 96 additions and 1 deletion.
6 changes: 6 additions & 0 deletions src/yt_napari/_data_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ class Region(BaseModel):
(400, 400, 400),
description="the resolution at which to sample between the edges.",
)
rescale: Optional[bool] = Field(
False, description="rescale the final image between 0,1"
)


class Slice(BaseModel):
Expand All @@ -81,6 +84,9 @@ class Slice(BaseModel):
periodic: Optional[bool] = Field(
False, description="should the slice be periodic? default False."
)
rescale: Optional[bool] = Field(
False, description="rescale the final image between 0,1"
)


class SelectionObject(BaseModel):
Expand Down
17 changes: 16 additions & 1 deletion src/yt_napari/_model_ingestor.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def align_sanitize_layers(self, layer_list: List[SpatialLayer]) -> List[Layer]:

def selections_match(sel_1: Union[Slice, Region], sel_2: Union[Slice, Region]) -> bool:
# compare selections, ignoring fields
if not type(sel_2) == type(sel_1):
if not type(sel_2) is type(sel_1):
return False

for attr in sel_1.__fields__.keys():
Expand Down Expand Up @@ -467,6 +467,9 @@ def _load_3D_regions(
if field_container.take_log:
data = np.log10(data)

if sel.rescale:
data = _linear_rescale(data)

# create a metadata dict and set a name
fieldname = ":".join(field)
md = create_metadata_dict(data, layer_domain, field_container.take_log)
Expand Down Expand Up @@ -531,6 +534,15 @@ def _process_slice(
return frb, layer_domain


def _linear_rescale(data, fill_inf=True):
# rescales an array between 0, 1 handling nans and infs
if fill_inf:
data[np.isinf(data)] = np.nan
data_max = np.nanmax(data)
data_min = np.nanmin(data)
return (data - data_min) / (data_max - data_min)


def _load_2D_slices(
ds,
selections: SelectionObject,
Expand Down Expand Up @@ -572,6 +584,9 @@ def _load_2D_slices(
if field_container.take_log:
data = np.log10(data)

if slice.rescale:
data = _linear_rescale(data)

# create a metadata dict and set a name
fieldname = ":".join(field)
md = create_metadata_dict(data, layer_domain, field_container.take_log)
Expand Down
24 changes: 24 additions & 0 deletions src/yt_napari/_tests/test_model_ingestor.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,3 +496,27 @@ def test_yt_data_dir_check(tmp_path):
files = _mi._generate_file_list("test_fi_blah_???")
assert len(files) == nfiles
ytcfg.set("yt", "test_data_dir", init_dir)


def test_linear_rescale():
data = 10 * np.random.random((5, 5))
rsc = _mi._linear_rescale(data)
assert rsc.min() == 0.0
assert rsc.max() == 1.0

data[0, 0] = np.nan
rsc = _mi._linear_rescale(data)
assert np.nanmin(rsc) == 0.0
assert np.nanmax(rsc) == 1.0

data[1, 1] = np.inf
rsc = _mi._linear_rescale(data)
assert np.nanmin(rsc) == 0.0
assert np.nanmax(rsc) == 1.0

data = 10 * np.random.random((5, 5))
data[1, 1] = np.inf
with pytest.warns(RuntimeWarning, match="invalid value"):
rsc = _mi._linear_rescale(data, fill_inf=False)
assert np.nanmin(rsc) == 0.0
assert np.nanmax(rsc) == 0.0
35 changes: 35 additions & 0 deletions src/yt_napari/_tests/test_regions_json.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import pytest

from yt_napari._data_model import InputModel
from yt_napari._model_ingestor import _process_validated_model
from yt_napari._schema_version import schema_name

jdicts = []
jdicts.append(
{
"$schema": schema_name,
"datasets": [
{
"filename": "_ytnapari_load_grid",
"selections": {
"regions": [
{
"fields": [{"field_name": "density", "field_type": "gas"}],
"resolution": [10, 10, 10],
}
]
},
}
],
}
)


@pytest.mark.parametrize("jdict", jdicts)
def test_load_region(jdict):
jdict["datasets"][0]["selections"]["regions"][0]["rescale"] = True
m = InputModel.parse_obj(jdict)
layers, _ = _process_validated_model(m)
im_data = layers[0][0]
assert im_data.min() == 0.0
assert im_data.max() == 1.0
7 changes: 7 additions & 0 deletions src/yt_napari/_tests/test_slices_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,10 @@ def test_slice_load(yt_ugrid_ds_fn, jdict):
layer_lists, _ = _process_validated_model(im)
ref_layer = _choose_ref_layer(layer_lists)
_ = ref_layer.align_sanitize_layers(layer_lists)

jdict["datasets"][0]["selections"]["slices"][0]["rescale"] = True
im = InputModel.parse_obj(jdict)
layer_lists, _ = _process_validated_model(im)
im_data = layer_lists[0][0]
assert im_data.min() == 0
assert im_data.max() == 1
8 changes: 8 additions & 0 deletions src/yt_napari/viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,14 @@ def _add_to_scene(
take_log,
colormap=None,
link_to=None,
rescale=False,
**kwargs,
):
# adds any new data to the viewer

if rescale:
data = _mi._linear_rescale(data)

if colormap is None:
colormap = "viridis"

Expand Down Expand Up @@ -118,6 +122,7 @@ def add_region(
take_log: Optional[bool] = None,
colormap: Optional[str] = None,
link_to: Optional[Union[str, Layer]] = None,
rescale: Optional[bool] = False,
**kwargs,
):
"""
Expand Down Expand Up @@ -191,6 +196,7 @@ def add_region(
take_log,
colormap=colormap,
link_to=link_to,
rescale=rescale,
**kwargs,
)

Expand Down Expand Up @@ -240,6 +246,7 @@ def add_slice(
periodic: Optional[bool] = False,
colormap: Optional[str] = None,
link_to: Optional[Union[str, Layer]] = None,
rescale: Optional[bool] = False,
**kwargs,
):
"""
Expand Down Expand Up @@ -313,6 +320,7 @@ def add_slice(
take_log,
colormap=colormap,
link_to=link_to,
rescale=rescale,
**kwargs,
)

Expand Down

0 comments on commit 4955871

Please sign in to comment.