Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Some fixes in sisl.viz #619

Merged
merged 3 commits into from
Sep 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file modified src/sisl/viz/.coverage
Binary file not shown.
28 changes: 26 additions & 2 deletions src/sisl/viz/figure/plotly.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,9 @@
action_name = action['method']
if action_name.startswith("draw_"):
action = {**action, "kwargs": {**action.get("kwargs", {}), **row_col_kwargs}}
action['kwargs']['meta'] = {**action['kwargs'].get('meta', {}), "i_plot": i}

Check warning on line 123 in src/sisl/viz/figure/plotly.py

View check run for this annotation

Codecov / codecov/patch

src/sisl/viz/figure/plotly.py#L123

Added line #L123 was not covered by tests
elif action_name.startswith("set_ax"):
action = {**action, "kwargs": {**action.get("kwargs", {}), "_active_axes": active_axes}}
action = {**action, "kwargs": {**action.get("kwargs", {}), "_active_axes": active_axes}}

Check warning on line 125 in src/sisl/viz/figure/plotly.py

View check run for this annotation

Codecov / codecov/patch

src/sisl/viz/figure/plotly.py#L125

Added line #L125 was not covered by tests

sanitized_section_actions.append(action)

Expand Down Expand Up @@ -157,12 +158,28 @@
action_name = action['method']
if action_name.startswith("draw_"):
action = {**action, "kwargs": {**action.get("kwargs", {}), **active_axes_kwargs}}
action['kwargs']['meta'] = {**action['kwargs'].get('meta', {}), "i_plot": i}

Check warning on line 161 in src/sisl/viz/figure/plotly.py

View check run for this annotation

Codecov / codecov/patch

src/sisl/viz/figure/plotly.py#L161

Added line #L161 was not covered by tests
elif action_name.startswith("set_ax"):
action = {**action, "kwargs": {**action.get("kwargs", {}), "_active_axes": active_axes}}

sanitized_section_actions.append(action)

yield sanitized_section_actions

def _iter_same_axes(self, plot_actions):

for i, section_actions in enumerate(plot_actions):

Check warning on line 171 in src/sisl/viz/figure/plotly.py

View check run for this annotation

Codecov / codecov/patch

src/sisl/viz/figure/plotly.py#L171

Added line #L171 was not covered by tests

sanitized_section_actions = []
for action in section_actions:
action_name = action['method']
if action_name.startswith("draw_"):
action = {**action, "kwargs": action.get("kwargs", {})}
action['kwargs']['meta'] = {**action['kwargs'].get('meta', {}), "i_plot": i}

Check warning on line 178 in src/sisl/viz/figure/plotly.py

View check run for this annotation

Codecov / codecov/patch

src/sisl/viz/figure/plotly.py#L173-L178

Added lines #L173 - L178 were not covered by tests

sanitized_section_actions.append(action)

Check warning on line 180 in src/sisl/viz/figure/plotly.py

View check run for this annotation

Codecov / codecov/patch

src/sisl/viz/figure/plotly.py#L180

Added line #L180 was not covered by tests

yield sanitized_section_actions

Check warning on line 182 in src/sisl/viz/figure/plotly.py

View check run for this annotation

Codecov / codecov/patch

src/sisl/viz/figure/plotly.py#L182

Added line #L182 was not covered by tests

def _init_figure_animated(self, frame_names: Optional[Sequence[str]] = None, frame_duration: int = 500, transition: int = 300, redraw: bool = False, **kwargs):
self._animation_settings = {
Expand Down Expand Up @@ -190,7 +207,14 @@
frames = []
for i, section_actions in enumerate(plot_actions):

yield section_actions
sanitized_section_actions = []
for action in section_actions:
action_name = action['method']
if action_name.startswith("draw_"):
action = {**action, "kwargs": action.get("kwargs", {})}
action['kwargs']['meta'] = {**action['kwargs'].get('meta', {}), "i_plot": i}

Check warning on line 215 in src/sisl/viz/figure/plotly.py

View check run for this annotation

Codecov / codecov/patch

src/sisl/viz/figure/plotly.py#L210-L215

Added lines #L210 - L215 were not covered by tests

yield sanitized_section_actions

Check warning on line 217 in src/sisl/viz/figure/plotly.py

View check run for this annotation

Codecov / codecov/patch

src/sisl/viz/figure/plotly.py#L217

Added line #L217 was not covered by tests

# Create a frame and append it
frames.append(go.Frame(name=frame_names[i],data=self.figure.data, layout=self.figure.layout))
Expand Down
2 changes: 1 addition & 1 deletion src/sisl/viz/plots/merged.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


def merge_plots(*figures: Figure,
composite_method: Optional[Literal["multiple", "subplots", "multiple_x", "multiple_y", "animation"]] = None,
composite_method: Optional[Literal["multiple", "subplots", "multiple_x", "multiple_y", "animation"]] = "multiple",
backend: Literal["plotly", "matplotlib", "py3dmol", "blender"] = "plotly",
**kwargs
) -> Figure:
Expand Down
62 changes: 42 additions & 20 deletions src/sisl/viz/processors/orbital.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,40 +560,62 @@
if isinstance(sanitize_group, OrbitalQueriesManager):
sanitize_group = sanitize_group.sanitize_query

if geometry is None:
def _sanitize_group(group):
group = group.copy()
group = sanitize_group(group)
data_spin = orbital_data.attrs.get("spin", Spin(""))

def _sanitize_group(group):
group = group.copy()
group = sanitize_group(group)

if geometry is None:
orbitals = group.get('orbitals')
try:
group['orbitals'] = np.array(orbitals, dtype=int)
assert orbitals.ndim == 1
except:
raise SislError("A geometry was neither provided nor found in the xarray object. Therefore we can't"
f" convert the provided atom selection ({orbitals}) to an array of integers.")
else:
group["orbitals"] = geometry._sanitize_orbs(group["orbitals"])

group['selector'] = group['orbitals']
if spin_reduce is not None and spin_dim in orbital_data.dims:
group['selector'] = (group['selector'], group.get('spin'))
group['reduce_func'] = (group.get('reduce_func', reduce_func), spin_reduce)
group['selector'] = group['orbitals']

return group
else:
def _sanitize_group(group):
group = group.copy()
group = sanitize_group(group)
group["orbitals"] = geometry._sanitize_orbs(group["orbitals"])
group['selector'] = group['orbitals']
if spin_reduce is not None and spin_dim in orbital_data.dims:
group['selector'] = (group['selector'], group.get('spin'))
group['reduce_func'] = (group.get('reduce_func', reduce_func), spin_reduce)
return group
req_spin = group.get("spin")
if req_spin is None and data_spin.is_polarized and spin_dim in orbital_data.coords:
if spin_reduce is None:
group['spin'] = original_spin_coord
else:
group['spin'] = [0, 1]

if (spin_reduce is not None or group.get("spin") is not None) and spin_dim in orbital_data.dims:
group['selector'] = (group['selector'], group.get('spin'))
group['reduce_func'] = (group.get('reduce_func', reduce_func), spin_reduce)

return group

original_spin_coord = None
if data_spin.is_polarized and spin_dim in orbital_data.coords:

if not isinstance(orbital_data, (DataArray, Dataset)):
orbital_data = orbital_data._data

Check warning on line 599 in src/sisl/viz/processors/orbital.py

View check run for this annotation

Codecov / codecov/patch

src/sisl/viz/processors/orbital.py#L599

Added line #L599 was not covered by tests

original_spin_coord = orbital_data.coords[spin_dim].values

if "total" in orbital_data.coords['spin']:
spin_up = ((orbital_data.sel(spin="total") - orbital_data.sel(spin="z")) / 2).assign_coords(spin=0)
spin_down = ((orbital_data.sel(spin="total") + orbital_data.sel(spin="z")) / 2).assign_coords(spin=1)

orbital_data = xarray.concat([orbital_data, spin_up, spin_down], "spin")
else:
total = orbital_data.sum(spin_dim).assign_coords(spin="total")
z = (orbital_data.sel(spin=0) - orbital_data.sel(spin=1)).assign_coords(spin="z")

Check warning on line 610 in src/sisl/viz/processors/orbital.py

View check run for this annotation

Codecov / codecov/patch

src/sisl/viz/processors/orbital.py#L609-L610

Added lines #L609 - L610 were not covered by tests

orbital_data = xarray.concat([total, z, orbital_data], "spin")

Check warning on line 612 in src/sisl/viz/processors/orbital.py

View check run for this annotation

Codecov / codecov/patch

src/sisl/viz/processors/orbital.py#L612

Added line #L612 was not covered by tests

# If a reduction for spin was requested, then pass the two different functions to reduce
# each coordinate.
reduce_funcs = reduce_func
reduce_dims = orb_dim
if spin_reduce is not None and spin_dim in orbital_data.dims:
if (spin_reduce is not None or data_spin.is_polarized) and spin_dim in orbital_data.dims:
reduce_funcs = (reduce_func, spin_reduce)
reduce_dims = (orb_dim, spin_dim)

Expand Down
10 changes: 10 additions & 0 deletions src/sisl/viz/processors/tests/test_orbital.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,16 @@ def test_reduce_orbital_data(geometry, spin):
with pytest.raises(SislError):
reduced = reduce_orbital_data(data_no_geometry, [{"name": "all"}] )

def test_reduce_orbital_data_spin(geometry, spin):

data = PDOSData.toy_example(geometry=geometry, spin=spin)._data

if spin.is_polarized:
sel_total = reduce_orbital_data(data, [{"name": "all", "spin": "total"}] )
red_total = reduce_orbital_data(data, [{"name": "all"}], spin_reduce=np.sum)

assert np.allclose(sel_total.values, red_total.values)

def test_atom_data_from_orbital_data(geometry: Geometry, spin):

data = PDOSData.toy_example(geometry=geometry, spin=spin)._data
Expand Down
12 changes: 9 additions & 3 deletions src/sisl/viz/processors/xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,13 @@
empty = False
for dim in reduce_dim:
selected = getattr(group_vals, dim, [])
empty = len(selected) == 0
if empty:
break
try:
empty = len(selected) == 0
if empty:
break
except TypeError:
# selected is a scalar
...

Check notice

Code scanning / CodeQL

Statement has no effect Note

This statement has no effect.

if empty:
# Handle the case where the selection found no matches.
Expand All @@ -128,6 +132,8 @@
if not isinstance(reduce_funcs, tuple):
reduce_funcs = tuple([reduce_funcs] * len(reduce_dim))
for dim, func in zip(reduce_dim, reduce_funcs):
if func is None or (reduce_dim not in group_vals.dims and reduce_dim in group_vals.coords):
continue
group_vals = group_vals.reduce(func, dim=dim)


Expand Down