diff --git a/src/sisl/viz/.coverage b/src/sisl/viz/.coverage index b486dd959..d52efb704 100644 Binary files a/src/sisl/viz/.coverage and b/src/sisl/viz/.coverage differ diff --git a/src/sisl/viz/figure/plotly.py b/src/sisl/viz/figure/plotly.py index 264aa8a2c..88ddc0887 100644 --- a/src/sisl/viz/figure/plotly.py +++ b/src/sisl/viz/figure/plotly.py @@ -120,8 +120,9 @@ def _iter_subplots(self, plot_actions): action_name = action['method'] if action_name.startswith("draw_"): action = {**action, "kwargs": {**action.get("kwargs", {}), **row_col_kwargs}} + action['kwargs']['meta'] = {**action['kwargs'].get('meta', {}), "i_plot": i} elif action_name.startswith("set_ax"): - action = {**action, "kwargs": {**action.get("kwargs", {}), "_active_axes": active_axes}} + action = {**action, "kwargs": {**action.get("kwargs", {}), "_active_axes": active_axes}} sanitized_section_actions.append(action) @@ -157,12 +158,28 @@ def _iter_multiaxis(self, plot_actions): action_name = action['method'] if action_name.startswith("draw_"): action = {**action, "kwargs": {**action.get("kwargs", {}), **active_axes_kwargs}} + action['kwargs']['meta'] = {**action['kwargs'].get('meta', {}), "i_plot": i} elif action_name.startswith("set_ax"): action = {**action, "kwargs": {**action.get("kwargs", {}), "_active_axes": active_axes}} sanitized_section_actions.append(action) yield sanitized_section_actions + + def _iter_same_axes(self, plot_actions): + + for i, section_actions in enumerate(plot_actions): + + sanitized_section_actions = [] + for action in section_actions: + action_name = action['method'] + if action_name.startswith("draw_"): + action = {**action, "kwargs": action.get("kwargs", {})} + action['kwargs']['meta'] = {**action['kwargs'].get('meta', {}), "i_plot": i} + + sanitized_section_actions.append(action) + + yield sanitized_section_actions def _init_figure_animated(self, frame_names: Optional[Sequence[str]] = None, frame_duration: int = 500, transition: int = 300, redraw: bool = False, **kwargs): self._animation_settings = { @@ -190,7 +207,14 @@ def _iter_animation(self, plot_actions): frames = [] for i, section_actions in enumerate(plot_actions): - yield section_actions + sanitized_section_actions = [] + for action in section_actions: + action_name = action['method'] + if action_name.startswith("draw_"): + action = {**action, "kwargs": action.get("kwargs", {})} + action['kwargs']['meta'] = {**action['kwargs'].get('meta', {}), "i_plot": i} + + yield sanitized_section_actions # Create a frame and append it frames.append(go.Frame(name=frame_names[i],data=self.figure.data, layout=self.figure.layout)) diff --git a/src/sisl/viz/plots/merged.py b/src/sisl/viz/plots/merged.py index a9640ffb8..0cc1fbf65 100644 --- a/src/sisl/viz/plots/merged.py +++ b/src/sisl/viz/plots/merged.py @@ -6,7 +6,7 @@ def merge_plots(*figures: Figure, - composite_method: Optional[Literal["multiple", "subplots", "multiple_x", "multiple_y", "animation"]] = None, + composite_method: Optional[Literal["multiple", "subplots", "multiple_x", "multiple_y", "animation"]] = "multiple", backend: Literal["plotly", "matplotlib", "py3dmol", "blender"] = "plotly", **kwargs ) -> Figure: diff --git a/src/sisl/viz/processors/orbital.py b/src/sisl/viz/processors/orbital.py index 18495856c..8edba6aa1 100644 --- a/src/sisl/viz/processors/orbital.py +++ b/src/sisl/viz/processors/orbital.py @@ -560,10 +560,13 @@ def reduce_orbital_data(orbital_data: Union[DataArray, Dataset], groups: Sequenc if isinstance(sanitize_group, OrbitalQueriesManager): sanitize_group = sanitize_group.sanitize_query - if geometry is None: - def _sanitize_group(group): - group = group.copy() - group = sanitize_group(group) + data_spin = orbital_data.attrs.get("spin", Spin("")) + + def _sanitize_group(group): + group = group.copy() + group = sanitize_group(group) + + if geometry is None: orbitals = group.get('orbitals') try: group['orbitals'] = np.array(orbitals, dtype=int) @@ -571,29 +574,48 @@ def _sanitize_group(group): except: raise SislError("A geometry was neither provided nor found in the xarray object. Therefore we can't" f" convert the provided atom selection ({orbitals}) to an array of integers.") + else: + group["orbitals"] = geometry._sanitize_orbs(group["orbitals"]) - group['selector'] = group['orbitals'] - if spin_reduce is not None and spin_dim in orbital_data.dims: - group['selector'] = (group['selector'], group.get('spin')) - group['reduce_func'] = (group.get('reduce_func', reduce_func), spin_reduce) + group['selector'] = group['orbitals'] - return group - else: - def _sanitize_group(group): - group = group.copy() - group = sanitize_group(group) - group["orbitals"] = geometry._sanitize_orbs(group["orbitals"]) - group['selector'] = group['orbitals'] - if spin_reduce is not None and spin_dim in orbital_data.dims: - group['selector'] = (group['selector'], group.get('spin')) - group['reduce_func'] = (group.get('reduce_func', reduce_func), spin_reduce) - return group + req_spin = group.get("spin") + if req_spin is None and data_spin.is_polarized and spin_dim in orbital_data.coords: + if spin_reduce is None: + group['spin'] = original_spin_coord + else: + group['spin'] = [0, 1] + + if (spin_reduce is not None or group.get("spin") is not None) and spin_dim in orbital_data.dims: + group['selector'] = (group['selector'], group.get('spin')) + group['reduce_func'] = (group.get('reduce_func', reduce_func), spin_reduce) + + return group + + original_spin_coord = None + if data_spin.is_polarized and spin_dim in orbital_data.coords: + + if not isinstance(orbital_data, (DataArray, Dataset)): + orbital_data = orbital_data._data + + original_spin_coord = orbital_data.coords[spin_dim].values + + if "total" in orbital_data.coords['spin']: + spin_up = ((orbital_data.sel(spin="total") - orbital_data.sel(spin="z")) / 2).assign_coords(spin=0) + spin_down = ((orbital_data.sel(spin="total") + orbital_data.sel(spin="z")) / 2).assign_coords(spin=1) + + orbital_data = xarray.concat([orbital_data, spin_up, spin_down], "spin") + else: + total = orbital_data.sum(spin_dim).assign_coords(spin="total") + z = (orbital_data.sel(spin=0) - orbital_data.sel(spin=1)).assign_coords(spin="z") + + orbital_data = xarray.concat([total, z, orbital_data], "spin") # If a reduction for spin was requested, then pass the two different functions to reduce # each coordinate. reduce_funcs = reduce_func reduce_dims = orb_dim - if spin_reduce is not None and spin_dim in orbital_data.dims: + if (spin_reduce is not None or data_spin.is_polarized) and spin_dim in orbital_data.dims: reduce_funcs = (reduce_func, spin_reduce) reduce_dims = (orb_dim, spin_dim) diff --git a/src/sisl/viz/processors/tests/test_orbital.py b/src/sisl/viz/processors/tests/test_orbital.py index a540ab734..198174a13 100644 --- a/src/sisl/viz/processors/tests/test_orbital.py +++ b/src/sisl/viz/processors/tests/test_orbital.py @@ -196,6 +196,16 @@ def test_reduce_orbital_data(geometry, spin): with pytest.raises(SislError): reduced = reduce_orbital_data(data_no_geometry, [{"name": "all"}] ) +def test_reduce_orbital_data_spin(geometry, spin): + + data = PDOSData.toy_example(geometry=geometry, spin=spin)._data + + if spin.is_polarized: + sel_total = reduce_orbital_data(data, [{"name": "all", "spin": "total"}] ) + red_total = reduce_orbital_data(data, [{"name": "all"}], spin_reduce=np.sum) + + assert np.allclose(sel_total.values, red_total.values) + def test_atom_data_from_orbital_data(geometry: Geometry, spin): data = PDOSData.toy_example(geometry=geometry, spin=spin)._data diff --git a/src/sisl/viz/processors/xarray.py b/src/sisl/viz/processors/xarray.py index 8966af80f..0c6744a14 100644 --- a/src/sisl/viz/processors/xarray.py +++ b/src/sisl/viz/processors/xarray.py @@ -106,9 +106,13 @@ def group_reduce(data: Union[DataArray, Dataset, XarrayData], groups: Sequence[G empty = False for dim in reduce_dim: selected = getattr(group_vals, dim, []) - empty = len(selected) == 0 - if empty: - break + try: + empty = len(selected) == 0 + if empty: + break + except TypeError: + # selected is a scalar + ... if empty: # Handle the case where the selection found no matches. @@ -128,6 +132,8 @@ def group_reduce(data: Union[DataArray, Dataset, XarrayData], groups: Sequence[G if not isinstance(reduce_funcs, tuple): reduce_funcs = tuple([reduce_funcs] * len(reduce_dim)) for dim, func in zip(reduce_dim, reduce_funcs): + if func is None or (reduce_dim not in group_vals.dims and reduce_dim in group_vals.coords): + continue group_vals = group_vals.reduce(func, dim=dim)