From b1d27b52515bf43acbacf70c61384f488dfa2975 Mon Sep 17 00:00:00 2001 From: Pol Febrer Date: Fri, 29 Sep 2023 17:29:00 +0200 Subject: [PATCH 1/3] bug: fix for spin polarized selections --- src/sisl/viz/processors/orbital.py | 41 ++++++++++++++++++- src/sisl/viz/processors/tests/test_orbital.py | 10 +++++ src/sisl/viz/processors/xarray.py | 12 ++++-- 3 files changed, 58 insertions(+), 5 deletions(-) diff --git a/src/sisl/viz/processors/orbital.py b/src/sisl/viz/processors/orbital.py index 18495856c..0c3431786 100644 --- a/src/sisl/viz/processors/orbital.py +++ b/src/sisl/viz/processors/orbital.py @@ -560,6 +560,8 @@ def reduce_orbital_data(orbital_data: Union[DataArray, Dataset], groups: Sequenc if isinstance(sanitize_group, OrbitalQueriesManager): sanitize_group = sanitize_group.sanitize_query + data_spin = orbital_data.attrs.get("spin", Spin("")) + if geometry is None: def _sanitize_group(group): group = group.copy() @@ -571,6 +573,13 @@ def _sanitize_group(group): except: raise SislError("A geometry was neither provided nor found in the xarray object. Therefore we can't" f" convert the provided atom selection ({orbitals}) to an array of integers.") + + req_spin = group.get("spin") + if req_spin is None and data_spin.is_polarized and spin_dim in orbital_data.coords: + if spin_reduce is None: + group['spin'] = original_spin_coord + else: + group['spin'] = [0, 1] group['selector'] = group['orbitals'] if spin_reduce is not None and spin_dim in orbital_data.dims: @@ -584,16 +593,44 @@ def _sanitize_group(group): group = sanitize_group(group) group["orbitals"] = geometry._sanitize_orbs(group["orbitals"]) group['selector'] = group['orbitals'] - if spin_reduce is not None and spin_dim in orbital_data.dims: + + req_spin = group.get("spin") + if req_spin is None and data_spin.is_polarized and spin_dim in orbital_data.coords: + if spin_reduce is None: + group['spin'] = original_spin_coord + else: + group['spin'] = [0, 1] + + if (spin_reduce is not None or group.get("spin") is not None) and spin_dim in orbital_data.dims: group['selector'] = (group['selector'], group.get('spin')) group['reduce_func'] = (group.get('reduce_func', reduce_func), spin_reduce) + return group + original_spin_coord = None + if data_spin.is_polarized and spin_dim in orbital_data.coords: + + if not isinstance(orbital_data, (DataArray, Dataset)): + orbital_data = orbital_data._data + + original_spin_coord = orbital_data.coords[spin_dim].values + + if "total" in orbital_data.coords['spin']: + spin_up = ((orbital_data.sel(spin="total") - orbital_data.sel(spin="z")) / 2).assign_coords(spin=0) + spin_down = ((orbital_data.sel(spin="total") + orbital_data.sel(spin="z")) / 2).assign_coords(spin=1) + + orbital_data = xarray.concat([orbital_data, spin_up, spin_down], "spin") + else: + total = orbital_data.sum(spin_dim).assign_coords(spin="total") + z = (orbital_data.sel(spin=0) - orbital_data.sel(spin=1)).assign_coords(spin="z") + + orbital_data = xarray.concat([total, z, orbital_data], "spin") + # If a reduction for spin was requested, then pass the two different functions to reduce # each coordinate. reduce_funcs = reduce_func reduce_dims = orb_dim - if spin_reduce is not None and spin_dim in orbital_data.dims: + if (spin_reduce is not None or data_spin.is_polarized) and spin_dim in orbital_data.dims: reduce_funcs = (reduce_func, spin_reduce) reduce_dims = (orb_dim, spin_dim) diff --git a/src/sisl/viz/processors/tests/test_orbital.py b/src/sisl/viz/processors/tests/test_orbital.py index a540ab734..198174a13 100644 --- a/src/sisl/viz/processors/tests/test_orbital.py +++ b/src/sisl/viz/processors/tests/test_orbital.py @@ -196,6 +196,16 @@ def test_reduce_orbital_data(geometry, spin): with pytest.raises(SislError): reduced = reduce_orbital_data(data_no_geometry, [{"name": "all"}] ) +def test_reduce_orbital_data_spin(geometry, spin): + + data = PDOSData.toy_example(geometry=geometry, spin=spin)._data + + if spin.is_polarized: + sel_total = reduce_orbital_data(data, [{"name": "all", "spin": "total"}] ) + red_total = reduce_orbital_data(data, [{"name": "all"}], spin_reduce=np.sum) + + assert np.allclose(sel_total.values, red_total.values) + def test_atom_data_from_orbital_data(geometry: Geometry, spin): data = PDOSData.toy_example(geometry=geometry, spin=spin)._data diff --git a/src/sisl/viz/processors/xarray.py b/src/sisl/viz/processors/xarray.py index 8966af80f..0c6744a14 100644 --- a/src/sisl/viz/processors/xarray.py +++ b/src/sisl/viz/processors/xarray.py @@ -106,9 +106,13 @@ def group_reduce(data: Union[DataArray, Dataset, XarrayData], groups: Sequence[G empty = False for dim in reduce_dim: selected = getattr(group_vals, dim, []) - empty = len(selected) == 0 - if empty: - break + try: + empty = len(selected) == 0 + if empty: + break + except TypeError: + # selected is a scalar + ... if empty: # Handle the case where the selection found no matches. @@ -128,6 +132,8 @@ def group_reduce(data: Union[DataArray, Dataset, XarrayData], groups: Sequence[G if not isinstance(reduce_funcs, tuple): reduce_funcs = tuple([reduce_funcs] * len(reduce_dim)) for dim, func in zip(reduce_dim, reduce_funcs): + if func is None or (reduce_dim not in group_vals.dims and reduce_dim in group_vals.coords): + continue group_vals = group_vals.reduce(func, dim=dim) From 6d132c2d1de4a1d56cdbccff9e2f48fc0ca2c0da Mon Sep 17 00:00:00 2001 From: Pol Febrer Date: Fri, 29 Sep 2023 17:54:06 +0200 Subject: [PATCH 2/3] enh: added plot index by default to merged plots in plotly --- src/sisl/viz/.coverage | Bin 106496 -> 106496 bytes src/sisl/viz/figure/plotly.py | 28 ++++++++++++++++++++++++++-- src/sisl/viz/plots/merged.py | 2 +- 3 files changed, 27 insertions(+), 3 deletions(-) diff --git a/src/sisl/viz/.coverage b/src/sisl/viz/.coverage index b486dd959aac99a14655a10de4b0191bc366fdb0..d52efb7045eb31b2598df56485585fddc65ed89e 100644 GIT binary patch delta 5386 zcmYjV30zdw8vo9{cjnIAS-~YG1>6Ni0aq4bM3j9I1r!m`M8gzZQ=b}hs`q-meCddI z(k#tf0#r~EQbcKerWu$DiGpOMi76zx0P4K&-ofnf=wabEYPSdM$CbY7F4uoa=ik6B!ZCNJX`U0(Xlz^!VcDcyg4HDMo( zaEP>}2Y*Q(#*c#COf-{JoeyauHZlbs!fu;wG@n5GvQ3Lf;MVCx$NB~lgQ|}J$onab zgtIg!62NK#h=v^xB9h8Qk=VjWuzr5lf|biMM=V(rOMA1!qe&o(4<^oh9PPy##*+!` z*ka;b{)RJYGuZNfC=c&PJSEl9aDh58pAa&JWeg)~Y2}KnMayvu_*A~MpTcCF!-!c` z7A8>Bb4_h_&OG@nD9xi6yYzNC5s8B%?2iG@EE|UtJvp3LQ|rgFy@*&o%!^#%<`#kJ zpefa)H6An0H<}D54T}vf`fv1W^j^9fx-GiNQoFQUiqy7h_h@HooiwL4>omi~i{h)I zi@HKRM-8gIsu?PSa6(urxbrpqoBT-b8n>B?CLN@JB*SAUve_o_u{6dJ1+6dXPNJEs zFG*2FVYmjc^YLU7tMVlpREm_9!r0CfB8qV|;xD;w?wss-awsOzaAt8PllVv)#;krM zaC5r)FJTOw!hC#4h%r4qD?4jNdO8{;(8)^rSb8>*OzF!q7i2A8k+}>d5j2!hA508N z49s~niAM)QWC!BOL^j8XjA!xwM9aKliHH)5!+*C$M8%f)5d$Aiga4*w4n?g<8YDXv z$u#&Mv&&SZQ%qF0Fa*@dOVB#a_m^`OWh2z9ep`%$< z?=HFSSrJX9qw+G^9Y#`Dy@XR=`3m7^@AJ%F&9F{%=%x2_FIR;SBRVlsb|R6VPDiku z>}Qv&9*^1))XNb|B|MvAh(?U2!eO$)*O7P&D&Uj&1cP z-N;Z@!lTnjYglX;k@%@}2y-1prZVdU61jCI(P9<`DFq_US&_3CQy5DJ%IgU%V+I-> zz-*y&17zpASe}8ziS%bBCfrzQ)0$LF-|*S!O$_ZP7mL!DBCC-PjHm7nRm{5}_p8Wq z!ZFja)UE5-)tPb;##2{^o;87(SZfN=zAUL%fft+;q%_G{ds_Rt)?d@2*{ey{IOzxJ zuIe@$zc4OVP1ae&D&c`ps2*-ARyCT8;wxfz^+`0F#&_}&{72kvZX@SSYRO9CBDEM@ z4QCDO48v7Se^tNHl%_tbVm4bQqm!_K{){E5LWV#m$|x{}Sz}30a)NnYso?OS~#jG$;VXd$pKa}6B#c~BHdL*lLe|{6O*vacSRGC zI}Hj*Dc0Cq8sT$9`R6QbB#r8nR0Ls&8$zOhMT3`^z;;DpI$EQN(cuoBWo8YD*q>=p zWFI+~clfGb-nDjhxjmE^I8q=-$?vaX(y2btFM?2JOD5|TAm(73!^O`QlWrWmm$&4+ zpWN;O3yvfo2&Oa`V6(Z&o8YrXEj@!D_+QuKM{W2U#=tiU*HJN8Njmb89FS1v9*z%% z9SH)KE#3hS=qf2A*(EE)lI<)#3hyRklkxA^1YiFZVnM+t4!)Alh!-$cQ*AZ`(}-sQ zm%%6vsw7Px@PvF&Ur}u}c&j3XKZPCo2JxJFgZ^KFML$N@tjpJhN^R0d;yNi-`$YAN zI7A(wEz%}y)%>sg+x&2@hFiw<(=QCH>;eWeYOp>`kYM zG6E-&=?Z8vLk=RuH1=O_GMS~ybb@RhB6TG@flht?eudfAA>770;I=Er(KJrBjz-9i zl_5JOTfWjLI1Hv6rUuj3rV7(LCTdDHjWX$sPmC?btHwRXca3itlZ@fUK%>#{yWwYi zui0iuH6$4N8oC)I{Vjd1zFc3Ze@h>(7jzA}V>*)r+QI64b)0&LDq0xF?dIOW_m?-} z0vtvhDlb`rZ)#Cv4Fb_-8{)}RyWO6vhb4LdHCWd)xs_}Cyu356>AzN8M=&grfFtC< zx;xEH_#(Rtra0-*ZlK;=GIBxNV|)9OrVTH#me)w5`JZ}d0;t)v(du(Id-MviydjsA z7(MEK{$qDlKabCaO(AF2N|}$3g4trlJL9mzQ29e6PXia1m%WR9L3%TB6A4C#bS=6j z-8x;c)FS;VRZ4rMrBb34B?U@O+D`3d?HX;4c8)emYtfF?G-@ttPGNHwX!10^nto!w z__nxOTqMpF6UE8mIB}@>qG(XJtG`s2sSDIg)fwtUwU^pmbxU{ zKZGiwOn6BM6TAWi4}s?&@i+J%_|yCezJ%Y$Z{`>B(|I4hAFtsqbGx`4ZY~$V4d!|u zwhbZ7Z8Op4Zhrh&?Mz=0IPE~VeK+;t{nY9)?W+&nO}JCg zT=0)ejYUaCiOC0>ACw(0%zmhnXcnq#TgR_D=fxV%Bo!;I)&xZU&6ar3=tv5!}oMxLTI<y^@!gNq6e6(xO` zcB!(gq8{hMve~!Vs-LzOH|Oto`A(pzC(T6l;qdEN>$S3jDM#lO%&Z(6m6)G9$Ux_z z#07rYer4c1%AE@KURe0pNr04Joa|krNs93_WQTXJL;}o3+_heqRL>n zo^7q{zuO$rZ~C32WUoPV4ss*mr@Q2{iuX4h^~zsSo-}pa(Xy0lXG#NtmoF&Z@!`1FEF5q_efNs4)`{|mioB9gvv$&)~xClmuT zOZc?2_5k;7@pQThU3P_w&6Q?&9JZ36xp`Z!FHVn_AcF&I5^U)KFTf#WTRx1~EP=JQ z2U-dq@x-!`gXF!i=zea;C*b?R=qtCPmv8V+ixcTIJkJOh29KZ#NTtFL^}%N7)?~Lk zoKrjmYm413I>@ehmo%A9mEF<6`Qx465O+uhXef`t(@S8>g_-0}Rn=qXw4ilS)oyVB zD`rLe9{031t8%UP{we03Z%ZbX9HfydPq6f+@yeZR>Q&s299*eC<>}U+w|h3`&V-U!|w2k@mef#=GN;F8ZKk0dNZ!bNafQw#G8UG~*@kB=h z70nY1BLO%A4ZxKPRAA9+p@x^JzhZgZ^NnzgNd0hOgkx$g9gmbVeCDYe>S3hgaM1_K z7Z!L}tQPl2S~^ydmK~XLfokX&tgY9LaI}Nn+fGtvr0m-7aem_-U$wg#!s~0}_CEgP z!SuC-Z6E)JPYMs8)x9P)4mZ-#%9fvwAgK^cz&(wQk}K37N;^CU{HMDQ44S`r%jSX+ z>}Y-2b+6C6?|f9miryzdx!tHQS`LOIu{}lu)L~#P;3 zksb)eo}-lUq`Ix0h;)Radgy5)^^(sV3I|&s++F?K#RoUfvD|4bfy1J1~CCOo#TJ|#9ME@wP3=hQm*9ySg&{1);8wLzeinfdBfYvbII@5cb6Ub zZqrupDXri8Z7h%xq4~u1HY^u;v>v;igUGPbYqquJ@;l9ICaB!#U_2@a_T9bV-X5eV z33mdX6g51$zI?=D2{d9iI!JMN@2h7!75VQuO*eDNP&yFV@vx_C>57d|S< z)V?n7Z2_x%9Si(h@S*M~n+v--&M2=H8mk5<$~gQ3O)i)@*sQ#2Xn>#r$Khn{LEX^E z6LzK?t6o+4G?~By37Uej#a)pf2_M$amID;vRe_sWV^r4K(%Eqb>x+_Jx$XJr6#ke( z`=Xu^KHy?#A35;7V8^R%Z^b11+$<*2Nlg?b3&@r2qK_CS;S~{w*pj zfq@YNISo>RyRw_Sx%@A)2$3lnU6kEJW#0&K0<$c1jQ(7ID`@nzH+t9$wk7m-ap`C^ j(q6di1MeLP+TXM$v9eI2FCe2kF!ROFg3Ap%NE`ou;u?vh delta 5521 zcmY*d30zgx)?fP!_uPB#$YkKeRdGZHClnb3ML}dz859tlQ)GIIrKoUsef7+GCSe{m zC4H3|Bw&gdiBM*mnxFch25xueJ7CYybCupM$MlW2@Ix z-i_@W0|1V+C*B=0@MiJj@Gh9@@D5$fd(7!(v#D;cgQeU=T$Eq?5Pzj{JkcuMlZZ?j z_nOoHvKja@{2P2Q_XoF^n^WxE8p8?hd*`-xki8Yx4iaV)s-K&)WX`O`-nq-crfl~{ zCm~&lZtw0NlRWw1uw0QWB-Jg5{7OD1X~Yq}gyn2{^_oC?D)nw;lrqzkX!J|-7G^F% zH_^0*QY1W;Dp{k51LWE%skSuPA5U7K4 zI)sc+Yq&to&oy;V=Z4M2^t-2uEIjmI1g)z zih4^?G)4ec-b4~A&zYB*iIti{qm@J1WC|ZmqwIBOAb2O;$WNq`l&hI!A|FE|l`n!w zpzy~dLCVbGM3(2y%A17-u{4Z%h*cVgkvN=fESoJ*sT@v3B`APs(JJ(rl@ke3EF(y$ zVx2%FDUFW7L_(EbBguHBE|j?NDKuF5VgZ_O^CL!O#5gt3AQq@y)Fo%}?1g!=GSi<< zo}(8#o$uXHqafYo~=0KLUfUuGYJ74uOsJ}Htopbu}m44XR#8XceES4w-XFi!q%zPXj#Fi14DjT_= zJeU~CKszp&Go{giNJ?%r+Q!iVOfk?diXB7Dr2n(!lI2;;)JfxLKc*GOPoaGo&cc;T zDMXKHd8ld8XcB?R_F>7YQ3&SF{T4-g+cS?QV+`{bzPNNT+t@L*m%Vx{5L0^k!t8~4 z>FGEyiCT+We;Z||u&gXBR5W#e4rrxEG%;q#!UFJyV;~tdUua&|3=w}8OT>85B-G3A z%26^W=!ZD39QZiqeoSV#!Hm`#2KGU}PpWitkti(Ps)x!}=mnRS+4- z0j4uutWFfD7o(G)=h)$3zp_0UB@8+oR2Bbv#$SaA0}coKl((aBP~9?!8(En;n2ff= zD0>3Yl24VMF`eTKJQ~;|C4-?jzk&#I63eyU!qB23 zZ%hKidAjI&@-cH5rqm+<$cIV{i%cCfGe5i7V2t=*`62`r75Uo(&vj&hH!E2&B%U;> z@jnBS@gM64+lzBg``e>xm_)V;1`!4q6%7zNs}ijEW)df5VK~`oI&Rzy#?E5|Fvf9; zdo(eK?~fO#UOAmaf<%M_Y*Y?SAc<-UYKh`viOEjI>P}-Mqrnpi4@H)F6!}CYZgmW4 zMl06O6p&<|uLv?t7qM=JfK2gRh+N}D4k^Afuof!$%ncw@?42Aj0O(=f9Zl* zc{y2~gan_<(rYjfPJ*%3_?ppQY!lx$T-WxN{0wi4!wkXt`}z`1kUm;|ES;8j$w`{E zGN<`TSE)4oowm4Z8en4ib%xSd=$xkXkJPdI7bWDe3E(Eh8aXfcCckkXn?x+wYc z5h&Z!Nje@D$jevLNNn|uI5JyM7iVS?F)MGSVD{(K$|KQ-DT5PndCzu^TeOQ8&SmvI zV}}`?uh;VZ(_dHle!c!3VIE94~M8T;)s2fC-64@a2) zDP#w;oJ_uYreJ(ZRnglLUz}Eshh9UYCq9NN;jpolBedB zCP$+cKiN~XK&+lUkF1pN3ZSi62)4spn484iO= zFkd%cG9NL2Zhq5D%}M5AX1(c&>8|OL=>yY$Os|*{OyQ;glhJs~c+FUA+-OWOPBwNo zIvRC`8-^yse#6Ix*A3wYUVl-4KyT7@==y5bYhpD0#7Tl5x1C!UC{?3cUk$)5)=YY%Mefck-;6@<#)s>=8m zP($?PEC#gfZLl;d*Z=Cczh_9LgL1ckT(*qWfCXT|vG97=mq&K3;j1qdl1kIS2g?o) z$c>2fB_FQ)tz@iQp%(}iD|*7q3Nm}JmvWfct$P1j5=V%;BCR813iTbh-+$Jx)-RT? z%PsOZa-F6+9m zos{aNN@r*A_dP{kq)?`r z0;h9^Um^<%ct^SjJ05WBPWj{f^0v(@1JB(jy=evt!{BC1bz9A?@^1Z`N=i#glWx~$ zR3$HMD$UL*tGV=oOc$zder|5wKf;?Xz)mXMC@*ZE{-mi6w%tG1p8r>N&XbpDTgs}6 zCuxss|EM^0?!?1C3Y&A%8}IiRPG3Ze1<=-#dm;J3iTfYTIG?hjEPBPHtWD?BE_laJ zNmyDoBc-}5ZED?jJ`XB(Jlc7xqoSk3lg>x8M7Z9PyCSjd-RP#JO~=oc)tq}N=|Wje zQqz9VRqf};bR^Wi0n}{b8)0A{MeFp zA~SiulaWDpNJ3#AcowdMMq?4iy3Gi-f!<7$QPvY$S{|0Y)RbNIV3C$iM~MZ_XXyLTY1o_q-{H+8S9|o%fkGuW3+(Z*m*d3Y|FWuR?JfXJKL=#j&lcjlpP0F&8Av7KZ zv*EjSgB}kKskGR_@x%2ENAyr1>5*_dz9Z*>6Sli{y)ybiNApeq+lb%Y=v1^B3TI2f zw#dbn1h7;r1Q)26VO>-log!lPJ>krwG%d7$P{V>*9c7p`x4Q_w1VTLFxNOq{T7k8} zW13(RXl8p*?c>1hyMiEivAK(8msE*L9R3oTfWekd?{Z1 zQ(~@vGqXDsmb%ecHL8Z2BKJLqWvv_JepiG?4LgI04!^$GYMiGK(9BjKV`<>GXB;E(s`vm^^W=WlV!v*fi zkmyFE)Y(pEEXtOxI6p+RtMLGY7)^8%4t9qVr3JSMKpZjCNEGyh+ADd@%gttUL;eC6+cW}49`4><>nWnqS68H2APMDuRRfuijlH~P zz16yPNzbXw=XcAGv>mkv-n!)BNZb3PN@{nNoSQ1s@i@6N94X!3!VRZk*jWHyRo;C6 z{Oh4k`+~s|*j6|CcFGt6zo-9>_KINx=~6gOwK-hc=MDxp)*jrK_gX&xA#~^dF*%3+ zelNf6{HEALKc@ef=e@}%;CsJSGyct=Ivd;j$g(@dS^)nZ+~$2@C7G^AH?9x(aCs9J zFNDQxyBhfVXWYyyL%NSL9VIdiRp+O!ytjR_S1Ale0C%YFX((15nT#K5stBuBeXtD+>&C zv}#k63?Ur+E>nN(I>TozR!7|Jqp)Lw0~#G2i7iK{?h4vMiTbJH{q0&h0$UcS`aQ^j z4##G1*q2egWbF?;NYJX4=`fTz!`{-CDF!_qik&X-Y4X&wroIw>dotI13i?xDY{fw3 zp;70DR-{&pe6uz`tp8r8`&Fd}jpyWbU)fqWUT?m#6~8vA5A#q66^-}Xw(t`5R$Y|G zIZ-cc^@H7q(wZLCUMYK0m(uaL(1i|0=^*&z=FPH#iTx6;2Czy)f`NLX+!e|)KDt?2 zEJ<{b>ZX))qXV%u6h5xZ(7nP`L%_zVkJZ8uIsnIJ!>$6qLiKFrtXzl{7Y=_6+5ZPN z4}4%V1hHKFn)Ae2^Ou{!kJCqhO=>$$~$YJzltN)tePX zY!(~q;|>2XThv7%OB*j{1M9$|ny@z5!%+J+wJ)gl9Cip4a@I83n?*1h{+nj3-*Gpi zp~~k`C~n#aZrxne9|iBA&o{!@G&T^dD8fD(tA4|ueLvVldt$3Qlq8Z=!2NF3(Hu-clapNso@W{Hgb4o${iqfE_zb5D5BCdB0e5+SZ7bn^c zJ@kY3VqAs;oWqbpU=aRisjEHR%{5gHxP+YV*lQ)By1$X m9ujp^W82hIM_tuWHYQsBrKc}oy9+25@VnEa+PH<>=l=)FiKU+a diff --git a/src/sisl/viz/figure/plotly.py b/src/sisl/viz/figure/plotly.py index 264aa8a2c..88ddc0887 100644 --- a/src/sisl/viz/figure/plotly.py +++ b/src/sisl/viz/figure/plotly.py @@ -120,8 +120,9 @@ def _iter_subplots(self, plot_actions): action_name = action['method'] if action_name.startswith("draw_"): action = {**action, "kwargs": {**action.get("kwargs", {}), **row_col_kwargs}} + action['kwargs']['meta'] = {**action['kwargs'].get('meta', {}), "i_plot": i} elif action_name.startswith("set_ax"): - action = {**action, "kwargs": {**action.get("kwargs", {}), "_active_axes": active_axes}} + action = {**action, "kwargs": {**action.get("kwargs", {}), "_active_axes": active_axes}} sanitized_section_actions.append(action) @@ -157,12 +158,28 @@ def _iter_multiaxis(self, plot_actions): action_name = action['method'] if action_name.startswith("draw_"): action = {**action, "kwargs": {**action.get("kwargs", {}), **active_axes_kwargs}} + action['kwargs']['meta'] = {**action['kwargs'].get('meta', {}), "i_plot": i} elif action_name.startswith("set_ax"): action = {**action, "kwargs": {**action.get("kwargs", {}), "_active_axes": active_axes}} sanitized_section_actions.append(action) yield sanitized_section_actions + + def _iter_same_axes(self, plot_actions): + + for i, section_actions in enumerate(plot_actions): + + sanitized_section_actions = [] + for action in section_actions: + action_name = action['method'] + if action_name.startswith("draw_"): + action = {**action, "kwargs": action.get("kwargs", {})} + action['kwargs']['meta'] = {**action['kwargs'].get('meta', {}), "i_plot": i} + + sanitized_section_actions.append(action) + + yield sanitized_section_actions def _init_figure_animated(self, frame_names: Optional[Sequence[str]] = None, frame_duration: int = 500, transition: int = 300, redraw: bool = False, **kwargs): self._animation_settings = { @@ -190,7 +207,14 @@ def _iter_animation(self, plot_actions): frames = [] for i, section_actions in enumerate(plot_actions): - yield section_actions + sanitized_section_actions = [] + for action in section_actions: + action_name = action['method'] + if action_name.startswith("draw_"): + action = {**action, "kwargs": action.get("kwargs", {})} + action['kwargs']['meta'] = {**action['kwargs'].get('meta', {}), "i_plot": i} + + yield sanitized_section_actions # Create a frame and append it frames.append(go.Frame(name=frame_names[i],data=self.figure.data, layout=self.figure.layout)) diff --git a/src/sisl/viz/plots/merged.py b/src/sisl/viz/plots/merged.py index a9640ffb8..0cc1fbf65 100644 --- a/src/sisl/viz/plots/merged.py +++ b/src/sisl/viz/plots/merged.py @@ -6,7 +6,7 @@ def merge_plots(*figures: Figure, - composite_method: Optional[Literal["multiple", "subplots", "multiple_x", "multiple_y", "animation"]] = None, + composite_method: Optional[Literal["multiple", "subplots", "multiple_x", "multiple_y", "animation"]] = "multiple", backend: Literal["plotly", "matplotlib", "py3dmol", "blender"] = "plotly", **kwargs ) -> Figure: From 898e76c75c5df4f54e7e5b2efb66f1200d12bf44 Mon Sep 17 00:00:00 2001 From: Pol Febrer Date: Sat, 30 Sep 2023 00:47:36 +0200 Subject: [PATCH 3/3] maint: single sanitize_group function in orbital reduce --- src/sisl/viz/processors/orbital.py | 51 +++++++++++------------------- 1 file changed, 18 insertions(+), 33 deletions(-) diff --git a/src/sisl/viz/processors/orbital.py b/src/sisl/viz/processors/orbital.py index 0c3431786..8edba6aa1 100644 --- a/src/sisl/viz/processors/orbital.py +++ b/src/sisl/viz/processors/orbital.py @@ -562,10 +562,11 @@ def reduce_orbital_data(orbital_data: Union[DataArray, Dataset], groups: Sequenc data_spin = orbital_data.attrs.get("spin", Spin("")) - if geometry is None: - def _sanitize_group(group): - group = group.copy() - group = sanitize_group(group) + def _sanitize_group(group): + group = group.copy() + group = sanitize_group(group) + + if geometry is None: orbitals = group.get('orbitals') try: group['orbitals'] = np.array(orbitals, dtype=int) @@ -573,39 +574,23 @@ def _sanitize_group(group): except: raise SislError("A geometry was neither provided nor found in the xarray object. Therefore we can't" f" convert the provided atom selection ({orbitals}) to an array of integers.") - - req_spin = group.get("spin") - if req_spin is None and data_spin.is_polarized and spin_dim in orbital_data.coords: - if spin_reduce is None: - group['spin'] = original_spin_coord - else: - group['spin'] = [0, 1] - - group['selector'] = group['orbitals'] - if spin_reduce is not None and spin_dim in orbital_data.dims: - group['selector'] = (group['selector'], group.get('spin')) - group['reduce_func'] = (group.get('reduce_func', reduce_func), spin_reduce) - - return group - else: - def _sanitize_group(group): - group = group.copy() - group = sanitize_group(group) + else: group["orbitals"] = geometry._sanitize_orbs(group["orbitals"]) - group['selector'] = group['orbitals'] - req_spin = group.get("spin") - if req_spin is None and data_spin.is_polarized and spin_dim in orbital_data.coords: - if spin_reduce is None: - group['spin'] = original_spin_coord - else: - group['spin'] = [0, 1] + group['selector'] = group['orbitals'] + + req_spin = group.get("spin") + if req_spin is None and data_spin.is_polarized and spin_dim in orbital_data.coords: + if spin_reduce is None: + group['spin'] = original_spin_coord + else: + group['spin'] = [0, 1] - if (spin_reduce is not None or group.get("spin") is not None) and spin_dim in orbital_data.dims: - group['selector'] = (group['selector'], group.get('spin')) - group['reduce_func'] = (group.get('reduce_func', reduce_func), spin_reduce) + if (spin_reduce is not None or group.get("spin") is not None) and spin_dim in orbital_data.dims: + group['selector'] = (group['selector'], group.get('spin')) + group['reduce_func'] = (group.get('reduce_func', reduce_func), spin_reduce) - return group + return group original_spin_coord = None if data_spin.is_polarized and spin_dim in orbital_data.coords: