From 76b5ff6139a0e663f6083424e73322d2f00b7533 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Tue, 12 Nov 2024 03:46:47 -0500 Subject: [PATCH 01/14] fix(dpmodel): fix precision Signed-off-by: Jinzhe Zeng --- .github/workflows/test_python.yml | 1 + deepmd/dpmodel/common.py | 92 +++++++++++++++++++++++ deepmd/dpmodel/descriptor/dpa1.py | 3 + deepmd/dpmodel/descriptor/dpa2.py | 3 + deepmd/dpmodel/descriptor/se_e2_a.py | 14 ++-- deepmd/dpmodel/descriptor/se_r.py | 3 +- deepmd/dpmodel/descriptor/se_t.py | 4 +- deepmd/dpmodel/descriptor/se_t_tebd.py | 5 +- deepmd/dpmodel/fitting/general_fitting.py | 23 ++++-- deepmd/jax/env.py | 3 + 10 files changed, 131 insertions(+), 20 deletions(-) diff --git a/.github/workflows/test_python.yml b/.github/workflows/test_python.yml index ba8858d6b9..5f48b14131 100644 --- a/.github/workflows/test_python.yml +++ b/.github/workflows/test_python.yml @@ -62,6 +62,7 @@ jobs: env: NUM_WORKERS: 0 DP_TEST_TF2_ONLY: 1 + DP_DTYPE_PROMOTION_STRICT: 1 if: matrix.group == 1 - run: mv .test_durations .test_durations_${{ matrix.group }} - name: Upload partial durations diff --git a/deepmd/dpmodel/common.py b/deepmd/dpmodel/common.py index 6e6113b494..63f3a34105 100644 --- a/deepmd/dpmodel/common.py +++ b/deepmd/dpmodel/common.py @@ -3,8 +3,12 @@ ABC, abstractmethod, ) +from functools import ( + wraps, +) from typing import ( Any, + Callable, Optional, ) @@ -116,6 +120,94 @@ def to_numpy_array(x: Any) -> Optional[np.ndarray]: return np.from_dlpack(x) +def cast_precision(func: Callable[..., Any]) -> Callable[..., Any]: + """A decorator that casts and casts back the input + and output tensor of a method. + + The decorator should be used in a classmethod. + + The decorator will do the following thing: + (1) It casts input arrays from the global precision + to precision defined by property `precision`. + (2) It casts output arrays from `precision` to + the global precision. + (3) It checks inputs and outputs and only casts when + input or output is an array and its dtype matches + the global precision and `precision`, respectively. + If it does not match (e.g. it is an integer), the decorator + will do nothing on it. + + The decorator supports the array API. + + Returns + ------- + Callable + a decorator that casts and casts back the input and + output array of a method + + Examples + -------- + >>> class A: + ... def __init__(self): + ... self.precision = "float32" + ... + ... @cast_precision + ... def f(x: Array, y: Array) -> Array: + ... return x**2 + y + """ + + @wraps(func) + def wrapper(self, *args, **kwargs): + # only convert tensors + returned_tensor = func( + self, + *[safe_cast_array(vv, "global", self.precision) for vv in args], + **{ + kk: safe_cast_array(vv, "global", self.precision) + for kk, vv in kwargs.items() + }, + ) + if isinstance(returned_tensor, tuple): + return tuple( + safe_cast_array(vv, self.precision, "global") for vv in returned_tensor + ) + else: + return safe_cast_array(returned_tensor, self.precision, "global") + + return wrapper + + +def safe_cast_array( + input: np.ndarray, from_precision: str, to_precision: str +) -> np.ndarray: + """Convert an array from a precision to another precision. + + If input is not an array or without the specific precision, the method will not + cast it. + + Array API is supported. + + Parameters + ---------- + input : tf.Tensor + Input tensor + from_precision : str + Array data type that is casted from + to_precision : str + Array data type that casts to + + Returns + ------- + tf.Tensor + casted Tensor + """ + if array_api_compat.is_array_api_obj(input): + xp = array_api_compat.array_namespace(input) + if input.dtype == get_xp_precision(xp, from_precision): + return xp.astype(input, get_xp_precision(xp, to_precision)) + return input + + __all__ = [ "GLOBAL_NP_FLOAT_PRECISION", "GLOBAL_ENER_FLOAT_PRECISION", diff --git a/deepmd/dpmodel/descriptor/dpa1.py b/deepmd/dpmodel/descriptor/dpa1.py index 259593e731..d21fc492c3 100644 --- a/deepmd/dpmodel/descriptor/dpa1.py +++ b/deepmd/dpmodel/descriptor/dpa1.py @@ -19,6 +19,7 @@ xp_take_along_axis, ) from deepmd.dpmodel.common import ( + cast_precision, to_numpy_array, ) from deepmd.dpmodel.utils import ( @@ -329,6 +330,7 @@ def __init__( self.tebd_dim = tebd_dim self.concat_output_tebd = concat_output_tebd self.trainable = trainable + self.precision = precision def get_rcut(self) -> float: """Returns the cut-off radius.""" @@ -448,6 +450,7 @@ def change_type_map( obj["davg"] = obj["davg"][remap_index] obj["dstd"] = obj["dstd"][remap_index] + @cast_precision def call( self, coord_ext, diff --git a/deepmd/dpmodel/descriptor/dpa2.py b/deepmd/dpmodel/descriptor/dpa2.py index 097be2ef09..d82f136a9f 100644 --- a/deepmd/dpmodel/descriptor/dpa2.py +++ b/deepmd/dpmodel/descriptor/dpa2.py @@ -14,6 +14,7 @@ xp_take_along_axis, ) from deepmd.dpmodel.common import ( + cast_precision, to_numpy_array, ) from deepmd.dpmodel.utils import ( @@ -594,6 +595,7 @@ def init_subclass_params(sub_data, sub_class): self.rcut = self.repinit.get_rcut() self.ntypes = ntypes self.sel = self.repinit.sel + self.precision = precision def get_rcut(self) -> float: """Returns the cut-off radius.""" @@ -757,6 +759,7 @@ def get_stat_mean_and_stddev(self) -> tuple[list[np.ndarray], list[np.ndarray]]: stddev_list.append(self.repinit_three_body.stddev) return mean_list, stddev_list + @cast_precision def call( self, coord_ext: np.ndarray, diff --git a/deepmd/dpmodel/descriptor/se_e2_a.py b/deepmd/dpmodel/descriptor/se_e2_a.py index 63402b6f84..4ffdd025e8 100644 --- a/deepmd/dpmodel/descriptor/se_e2_a.py +++ b/deepmd/dpmodel/descriptor/se_e2_a.py @@ -15,6 +15,7 @@ NativeOP, ) from deepmd.dpmodel.common import ( + cast_precision, to_numpy_array, ) from deepmd.dpmodel.utils import ( @@ -29,9 +30,6 @@ from deepmd.dpmodel.utils.update_sel import ( UpdateSel, ) -from deepmd.env import ( - GLOBAL_NP_FLOAT_PRECISION, -) from deepmd.utils.data_system import ( DeepmdDataSystem, ) @@ -340,6 +338,7 @@ def reinit_exclude( self.exclude_types = exclude_types self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types) + @cast_precision def call( self, coord_ext, @@ -415,9 +414,7 @@ def call( # nf x nloc x ng x ng1 grrg = np.einsum("flid,fljd->flij", gr, gr1) # nf x nloc x (ng x ng1) - grrg = grrg.reshape(nf, nloc, ng * self.axis_neuron).astype( - GLOBAL_NP_FLOAT_PRECISION - ) + grrg = grrg.reshape(nf, nloc, ng * self.axis_neuron) return grrg, gr[..., 1:], None, None, ww def serialize(self) -> dict: @@ -506,6 +503,7 @@ def update_sel( class DescrptSeAArrayAPI(DescrptSeA): + @cast_precision def call( self, coord_ext, @@ -585,7 +583,5 @@ def call( # grrg = xp.einsum("flid,fljd->flij", gr, gr1) grrg = xp.sum(gr[:, :, :, None, :] * gr1[:, :, None, :, :], axis=4) # nf x nloc x (ng x ng1) - grrg = xp.astype( - xp.reshape(grrg, (nf, nloc, ng * self.axis_neuron)), input_dtype - ) + grrg = xp.reshape(grrg, (nf, nloc, ng * self.axis_neuron)) return grrg, gr[..., 1:], None, None, ww diff --git a/deepmd/dpmodel/descriptor/se_r.py b/deepmd/dpmodel/descriptor/se_r.py index d652eb1420..45757c68ec 100644 --- a/deepmd/dpmodel/descriptor/se_r.py +++ b/deepmd/dpmodel/descriptor/se_r.py @@ -14,6 +14,7 @@ NativeOP, ) from deepmd.dpmodel.common import ( + cast_precision, get_xp_precision, to_numpy_array, ) @@ -289,6 +290,7 @@ def cal_g( gg = self.embeddings[(ll,)].call(ss) return gg + @cast_precision def call( self, coord_ext, @@ -352,7 +354,6 @@ def call( res_rescale = 1.0 / 5.0 res = xyz_scatter * res_rescale res = xp.reshape(res, (nf, nloc, ng)) - res = xp.astype(res, get_xp_precision(xp, "global")) return res, None, None, None, ww def serialize(self) -> dict: diff --git a/deepmd/dpmodel/descriptor/se_t.py b/deepmd/dpmodel/descriptor/se_t.py index be587c77da..38bd660af2 100644 --- a/deepmd/dpmodel/descriptor/se_t.py +++ b/deepmd/dpmodel/descriptor/se_t.py @@ -14,6 +14,7 @@ NativeOP, ) from deepmd.dpmodel.common import ( + cast_precision, get_xp_precision, to_numpy_array, ) @@ -264,6 +265,7 @@ def reinit_exclude( self.exclude_types = exclude_types self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types) + @cast_precision def call( self, coord_ext, @@ -317,7 +319,6 @@ def call( # we don't require atype is the same in all frames exclude_mask = xp.reshape(exclude_mask, (nf * nloc, nnei)) rr = xp.reshape(rr, (nf * nloc, nnei, 4)) - rr = xp.astype(rr, get_xp_precision(xp, self.precision)) for embedding_idx in itertools.product( range(self.ntypes), repeat=self.embeddings.ndim @@ -349,7 +350,6 @@ def call( result += res_ij # nf x nloc x ng result = xp.reshape(result, (nf, nloc, ng)) - result = xp.astype(result, get_xp_precision(xp, "global")) return result, None, None, None, ww def serialize(self) -> dict: diff --git a/deepmd/dpmodel/descriptor/se_t_tebd.py b/deepmd/dpmodel/descriptor/se_t_tebd.py index 298f823690..b1b7cfa930 100644 --- a/deepmd/dpmodel/descriptor/se_t_tebd.py +++ b/deepmd/dpmodel/descriptor/se_t_tebd.py @@ -16,7 +16,7 @@ xp_take_along_axis, ) from deepmd.dpmodel.common import ( - get_xp_precision, + cast_precision, to_numpy_array, ) from deepmd.dpmodel.utils import ( @@ -168,6 +168,7 @@ def __init__( self.tebd_dim = tebd_dim self.concat_output_tebd = concat_output_tebd self.trainable = trainable + self.precision = precision def get_rcut(self) -> float: """Returns the cut-off radius.""" @@ -287,6 +288,7 @@ def change_type_map( obj["davg"] = obj["davg"][remap_index] obj["dstd"] = obj["dstd"][remap_index] + @cast_precision def call( self, coord_ext, @@ -741,7 +743,6 @@ def call( res_ij = res_ij * (1.0 / float(self.nnei) / float(self.nnei)) # nf x nl x ng result = xp.reshape(res_ij, (nf, nloc, self.filter_neuron[-1])) - result = xp.astype(result, get_xp_precision(xp, "global")) return ( result, None, diff --git a/deepmd/dpmodel/fitting/general_fitting.py b/deepmd/dpmodel/fitting/general_fitting.py index 344dab7ff1..b4691bf8a3 100644 --- a/deepmd/dpmodel/fitting/general_fitting.py +++ b/deepmd/dpmodel/fitting/general_fitting.py @@ -364,6 +364,11 @@ def _call_common( """ xp = array_api_compat.array_namespace(descriptor, atype) + descriptor = xp.astype(descriptor, get_xp_precision(xp, self.precision)) + if fparam is not None: + fparam = xp.astype(fparam, get_xp_precision(xp, self.precision)) + if aparam is not None: + aparam = xp.astype(aparam, get_xp_precision(xp, self.precision)) nf, nloc, nd = descriptor.shape net_dim_out = self._net_out_dim() # check input dim @@ -439,18 +444,24 @@ def _call_common( ): assert xx_zeros is not None atom_property -= self.nets[(type_i,)](xx_zeros) - atom_property = atom_property + self.bias_atom_e[type_i, ...] - atom_property = atom_property * xp.astype(mask, atom_property.dtype) + atom_property = xp.where( + mask, atom_property, xp.zeros_like(atom_property) + ) outs = outs + atom_property # Shape is [nframes, natoms[0], 1] + outs = xp.astype(outs, get_xp_precision(xp, "global")) + for type_i in range(self.ntypes): + outs = outs + self.bias_atom_e[type_i, ...] else: - outs = self.nets[()](xx) + xp.reshape( + outs = self.nets[()](xx) + if xx_zeros is not None: + outs -= self.nets[()](xx_zeros) + outs = xp.astype(outs, get_xp_precision(xp, "global")) + outs += xp.reshape( xp.take(self.bias_atom_e, xp.reshape(atype, [-1]), axis=0), [nf, nloc, net_dim_out], ) - if xx_zeros is not None: - outs -= self.nets[()](xx_zeros) # nf x nloc exclude_mask = self.emask.build_type_exclude_mask(atype) # nf x nloc x nod - outs = outs * xp.astype(exclude_mask[:, :, None], outs.dtype) + outs = xp.where(exclude_mask[:, :, None], outs, xp.zeros_like(outs)) return {self.var_name: xp.astype(outs, get_xp_precision(xp, "global"))} diff --git a/deepmd/jax/env.py b/deepmd/jax/env.py index 1b90433b00..02e31ae66e 100644 --- a/deepmd/jax/env.py +++ b/deepmd/jax/env.py @@ -13,6 +13,9 @@ jax.config.update("jax_enable_x64", True) # jax.config.update("jax_debug_nans", True) +if os.environ.get("DP_DTYPE_PROMOTION_STRICT") == "1": + jax.config.update("jax_numpy_dtype_promotion", "strict") + __all__ = [ "jax", "jnp", From 73b4227790a9c38526f2a6d262f3f599a2339b65 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Tue, 12 Nov 2024 05:12:26 -0500 Subject: [PATCH 02/14] skip float32 tests for numpy Signed-off-by: Jinzhe Zeng --- .../tests/consistent/descriptor/test_dpa1.py | 6 +++ .../tests/consistent/descriptor/test_dpa2.py | 40 ++++++++++++++++++- .../consistent/descriptor/test_se_atten_v2.py | 6 +++ .../consistent/descriptor/test_se_e2_a.py | 6 +++ .../tests/consistent/descriptor/test_se_r.py | 6 +++ .../tests/consistent/descriptor/test_se_t.py | 17 +++++++- .../consistent/descriptor/test_se_t_tebd.py | 24 ++++++++++- .../tests/consistent/fitting/test_dipole.py | 25 +++++++++++- source/tests/consistent/fitting/test_dos.py | 26 ++++++++++++ source/tests/consistent/fitting/test_ener.py | 18 +++++++++ source/tests/consistent/fitting/test_polar.py | 25 +++++++++++- 11 files changed, 194 insertions(+), 5 deletions(-) diff --git a/source/tests/consistent/descriptor/test_dpa1.py b/source/tests/consistent/descriptor/test_dpa1.py index 3d80e310d0..c789eec1d9 100644 --- a/source/tests/consistent/descriptor/test_dpa1.py +++ b/source/tests/consistent/descriptor/test_dpa1.py @@ -182,6 +182,9 @@ def skip_dp(self) -> bool: use_econf_tebd, use_tebd_bias, ) = self.param + if precision == "float32": + # NumPy doesn't throw errors for float64 x float32 + return True return CommonTest.skip_dp or self.is_meaningless_zero_attention_layer_tests( attn_layer, temperature, @@ -238,6 +241,9 @@ def skip_array_api_strict(self) -> bool: use_econf_tebd, use_tebd_bias, ) = self.param + if precision == "float32": + # NumPy doesn't throw errors for float64 x float32 + return True return ( not INSTALLED_ARRAY_API_STRICT or self.is_meaningless_zero_attention_layer_tests( diff --git a/source/tests/consistent/descriptor/test_dpa2.py b/source/tests/consistent/descriptor/test_dpa2.py index 17c55db368..1a4d6d5ec0 100644 --- a/source/tests/consistent/descriptor/test_dpa2.py +++ b/source/tests/consistent/descriptor/test_dpa2.py @@ -245,6 +245,9 @@ def skip_dp(self) -> bool: use_econf_tebd, use_tebd_bias, ) = self.param + if precision == "float32": + # NumPy doesn't throw errors for float64 x float32 + return True return CommonTest.skip_dp @property @@ -281,7 +284,42 @@ def skip_tf(self) -> bool: return True skip_jax = not INSTALLED_JAX - skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT + + @property + def skip_array_api_strict(self) -> bool: + ( + repinit_tebd_input_mode, + repinit_set_davg_zero, + repinit_type_one_side, + repinit_use_three_body, + repformer_update_g1_has_conv, + repformer_direct_dist, + repformer_update_g1_has_drrd, + repformer_update_g1_has_grrg, + repformer_update_g1_has_attn, + repformer_update_g2_has_g1g1, + repformer_update_g2_has_attn, + repformer_update_h2, + repformer_attn2_has_gate, + repformer_update_style, + repformer_update_residual_init, + repformer_set_davg_zero, + repformer_trainable_ln, + repformer_ln_eps, + repformer_use_sqrt_nnei, + repformer_g1_out_conv, + repformer_g1_out_mlp, + smooth, + exclude_types, + precision, + add_tebd_to_repinit_out, + use_econf_tebd, + use_tebd_bias, + ) = self.param + if precision == "float32": + # NumPy doesn't throw errors for float64 x float32 + return True + return not INSTALLED_ARRAY_API_STRICT tf_class = DescrptDPA2TF dp_class = DescrptDPA2DP diff --git a/source/tests/consistent/descriptor/test_se_atten_v2.py b/source/tests/consistent/descriptor/test_se_atten_v2.py index f4a8119ca3..be9eaeb0d9 100644 --- a/source/tests/consistent/descriptor/test_se_atten_v2.py +++ b/source/tests/consistent/descriptor/test_se_atten_v2.py @@ -178,6 +178,9 @@ def skip_dp(self) -> bool: use_econf_tebd, use_tebd_bias, ) = self.param + if precision == "float32": + # NumPy doesn't throw errors for float64 x float32 + return True return CommonTest.skip_dp or self.is_meaningless_zero_attention_layer_tests( attn_layer, attn_dotr, @@ -238,6 +241,9 @@ def skip_array_api_strict(self) -> bool: use_econf_tebd, use_tebd_bias, ) = self.param + if precision == "float32": + # NumPy doesn't throw errors for float64 x float32 + return True return ( not INSTALLED_ARRAY_API_STRICT or self.is_meaningless_zero_attention_layer_tests( diff --git a/source/tests/consistent/descriptor/test_se_e2_a.py b/source/tests/consistent/descriptor/test_se_e2_a.py index 286703e21d..a1c26ef98c 100644 --- a/source/tests/consistent/descriptor/test_se_e2_a.py +++ b/source/tests/consistent/descriptor/test_se_e2_a.py @@ -98,6 +98,9 @@ def skip_dp(self) -> bool: precision, env_protection, ) = self.param + if precision == "float32": + # NumPy doesn't throw errors for float64 x float32 + return True return CommonTest.skip_dp @property @@ -131,6 +134,9 @@ def skip_array_api_strict(self) -> bool: precision, env_protection, ) = self.param + if precision == "float32": + # NumPy doesn't throw errors for float64 x float32 + return True return not type_one_side or not INSTALLED_ARRAY_API_STRICT tf_class = DescrptSeATF diff --git a/source/tests/consistent/descriptor/test_se_r.py b/source/tests/consistent/descriptor/test_se_r.py index e851106c44..aa352eba14 100644 --- a/source/tests/consistent/descriptor/test_se_r.py +++ b/source/tests/consistent/descriptor/test_se_r.py @@ -92,6 +92,9 @@ def skip_dp(self) -> bool: excluded_types, precision, ) = self.param + if precision == "float32": + # NumPy doesn't throw errors for float64 x float32 + return True return not type_one_side or CommonTest.skip_dp @property @@ -112,6 +115,9 @@ def skip_array_api_strict(self) -> bool: excluded_types, precision, ) = self.param + if precision == "float32": + # NumPy doesn't throw errors for float64 x float32 + return True return not type_one_side or not INSTALLED_ARRAY_API_STRICT tf_class = DescrptSeRTF diff --git a/source/tests/consistent/descriptor/test_se_t.py b/source/tests/consistent/descriptor/test_se_t.py index 1e6110705a..7a66873af9 100644 --- a/source/tests/consistent/descriptor/test_se_t.py +++ b/source/tests/consistent/descriptor/test_se_t.py @@ -89,6 +89,9 @@ def skip_dp(self) -> bool: precision, env_protection, ) = self.param + if precision == "float32": + # NumPy doesn't throw errors for float64 x float32 + return True return CommonTest.skip_dp @property @@ -101,7 +104,19 @@ def skip_tf(self) -> bool: ) = self.param return env_protection != 0.0 or excluded_types - skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT + @property + def skip_array_api_strict(self) -> bool: + ( + resnet_dt, + excluded_types, + precision, + env_protection, + ) = self.param + if precision == "float32": + # NumPy doesn't throw errors for float64 x float32 + return True + return not INSTALLED_ARRAY_API_STRICT + skip_jax = not INSTALLED_JAX tf_class = DescrptSeTTF diff --git a/source/tests/consistent/descriptor/test_se_t_tebd.py b/source/tests/consistent/descriptor/test_se_t_tebd.py index 4712c28e53..2e5f7fd2ad 100644 --- a/source/tests/consistent/descriptor/test_se_t_tebd.py +++ b/source/tests/consistent/descriptor/test_se_t_tebd.py @@ -127,6 +127,9 @@ def skip_dp(self) -> bool: use_econf_tebd, use_tebd_bias, ) = self.param + if precision == "float32": + # NumPy doesn't throw errors for float64 x float32 + return True return CommonTest.skip_dp @property @@ -147,7 +150,26 @@ def skip_tf(self) -> bool: return True skip_jax = not INSTALLED_JAX - skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT + + @property + def skip_array_api_strict(self) -> bool: + ( + tebd_dim, + tebd_input_mode, + resnet_dt, + excluded_types, + env_protection, + set_davg_zero, + smooth, + concat_output_tebd, + precision, + use_econf_tebd, + use_tebd_bias, + ) = self.param + if precision == "float32": + # NumPy doesn't throw errors for float64 x float32 + return True + return not INSTALLED_ARRAY_API_STRICT tf_class = DescrptSeTTebdTF dp_class = DescrptSeTTebdDP diff --git a/source/tests/consistent/fitting/test_dipole.py b/source/tests/consistent/fitting/test_dipole.py index 088cb30238..d77beab161 100644 --- a/source/tests/consistent/fitting/test_dipole.py +++ b/source/tests/consistent/fitting/test_dipole.py @@ -86,6 +86,30 @@ def skip_pt(self) -> bool: ) = self.param return CommonTest.skip_pt + @property + def skip_dp(self) -> bool: + ( + resnet_dt, + precision, + mixed_types, + ) = self.param + if precision == "float32": + # NumPy doesn't throw errors for float64 x float32 + return True + return CommonTest.skip_dp + + @property + def skip_array_api_strict(self) -> bool: + ( + resnet_dt, + precision, + mixed_types, + ) = self.param + if precision == "float32": + # NumPy doesn't throw errors for float64 x float32 + return True + return not INSTALLED_ARRAY_API_STRICT + tf_class = DipoleFittingTF dp_class = DipoleFittingDP pt_class = DipoleFittingPT @@ -93,7 +117,6 @@ def skip_pt(self) -> bool: array_api_strict_class = DipoleFittingArrayAPIStrict args = fitting_dipole() skip_jax = not INSTALLED_JAX - skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT def setUp(self): CommonTest.setUp(self) diff --git a/source/tests/consistent/fitting/test_dos.py b/source/tests/consistent/fitting/test_dos.py index 0649681ccb..c1155a4190 100644 --- a/source/tests/consistent/fitting/test_dos.py +++ b/source/tests/consistent/fitting/test_dos.py @@ -97,12 +97,38 @@ def skip_pt(self) -> bool: ) = self.param return CommonTest.skip_pt + @property + def skip_dp(self) -> bool: + ( + resnet_dt, + precision, + mixed_types, + numb_fparam, + numb_aparam, + numb_dos, + ) = self.param + if precision == "float32": + # NumPy doesn't throw errors for float64 x float32 + return True + return CommonTest.skip_dp + @property def skip_jax(self) -> bool: return not INSTALLED_JAX @property def skip_array_api_strict(self) -> bool: + ( + resnet_dt, + precision, + mixed_types, + numb_fparam, + numb_aparam, + numb_dos, + ) = self.param + if precision == "float32": + # NumPy doesn't throw errors for float64 x float32 + return True return not INSTALLED_ARRAY_API_STRICT tf_class = DOSFittingTF diff --git a/source/tests/consistent/fitting/test_ener.py b/source/tests/consistent/fitting/test_ener.py index 7be0382b16..b987d16929 100644 --- a/source/tests/consistent/fitting/test_ener.py +++ b/source/tests/consistent/fitting/test_ener.py @@ -100,6 +100,21 @@ def skip_pt(self) -> bool: ) = self.param return CommonTest.skip_pt + @property + def skip_dp(self) -> bool: + ( + resnet_dt, + precision, + mixed_types, + numb_fparam, + (numb_aparam, use_aparam_as_mask), + atom_ener, + ) = self.param + if precision == "float32": + # NumPy doesn't throw errors for float64 x float32 + return True + return CommonTest.skip_dp + skip_jax = not INSTALLED_JAX @property @@ -112,6 +127,9 @@ def skip_array_api_strict(self) -> bool: (numb_aparam, use_aparam_as_mask), atom_ener, ) = self.param + if precision == "float32": + # NumPy doesn't throw errors for float64 x float32 + return True # TypeError: The array_api_strict namespace does not support the dtype 'bfloat16' return not INSTALLED_ARRAY_API_STRICT or precision == "bfloat16" diff --git a/source/tests/consistent/fitting/test_polar.py b/source/tests/consistent/fitting/test_polar.py index 12f13d1e08..ed9e48ee7a 100644 --- a/source/tests/consistent/fitting/test_polar.py +++ b/source/tests/consistent/fitting/test_polar.py @@ -86,6 +86,30 @@ def skip_pt(self) -> bool: ) = self.param return CommonTest.skip_pt + @property + def skip_dp(self) -> bool: + ( + resnet_dt, + precision, + mixed_types, + ) = self.param + if precision == "float32": + # NumPy doesn't throw errors for float64 x float32 + return True + return CommonTest.skip_dp + + @property + def skip_array_api_strict(self) -> bool: + ( + resnet_dt, + precision, + mixed_types, + ) = self.param + if precision == "float32": + # NumPy doesn't throw errors for float64 x float32 + return True + return not INSTALLED_ARRAY_API_STRICT + tf_class = PolarFittingTF dp_class = PolarFittingDP pt_class = PolarFittingPT @@ -93,7 +117,6 @@ def skip_pt(self) -> bool: array_api_strict_class = PolarFittingArrayAPIStrict args = fitting_polar() skip_jax = not INSTALLED_JAX - skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT def setUp(self): CommonTest.setUp(self) From 5f52bb704b5cefc9cdb7f615011a7ec21518f8d2 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Tue, 12 Nov 2024 06:18:41 -0500 Subject: [PATCH 03/14] move bias Signed-off-by: Jinzhe Zeng --- deepmd/dpmodel/fitting/general_fitting.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/deepmd/dpmodel/fitting/general_fitting.py b/deepmd/dpmodel/fitting/general_fitting.py index b4691bf8a3..2c06644afb 100644 --- a/deepmd/dpmodel/fitting/general_fitting.py +++ b/deepmd/dpmodel/fitting/general_fitting.py @@ -448,18 +448,15 @@ def _call_common( mask, atom_property, xp.zeros_like(atom_property) ) outs = outs + atom_property # Shape is [nframes, natoms[0], 1] - outs = xp.astype(outs, get_xp_precision(xp, "global")) - for type_i in range(self.ntypes): - outs = outs + self.bias_atom_e[type_i, ...] else: outs = self.nets[()](xx) if xx_zeros is not None: outs -= self.nets[()](xx_zeros) - outs = xp.astype(outs, get_xp_precision(xp, "global")) - outs += xp.reshape( - xp.take(self.bias_atom_e, xp.reshape(atype, [-1]), axis=0), - [nf, nloc, net_dim_out], - ) + outs = xp.astype(outs, get_xp_precision(xp, "global")) + outs += xp.reshape( + xp.take(self.bias_atom_e, xp.reshape(atype, [-1]), axis=0), + [nf, nloc, net_dim_out], + ) # nf x nloc exclude_mask = self.emask.build_type_exclude_mask(atype) # nf x nloc x nod From 918284a05c54bd3e7d367657c6ed19bfaa2d024b Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Tue, 12 Nov 2024 06:40:07 -0500 Subject: [PATCH 04/14] Revert "skip float32 tests for numpy" This reverts commit 73b4227790a9c38526f2a6d262f3f599a2339b65. --- .../tests/consistent/descriptor/test_dpa1.py | 6 --- .../tests/consistent/descriptor/test_dpa2.py | 40 +------------------ .../consistent/descriptor/test_se_atten_v2.py | 6 --- .../consistent/descriptor/test_se_e2_a.py | 6 --- .../tests/consistent/descriptor/test_se_r.py | 6 --- .../tests/consistent/descriptor/test_se_t.py | 17 +------- .../consistent/descriptor/test_se_t_tebd.py | 24 +---------- .../tests/consistent/fitting/test_dipole.py | 25 +----------- source/tests/consistent/fitting/test_dos.py | 26 ------------ source/tests/consistent/fitting/test_ener.py | 18 --------- source/tests/consistent/fitting/test_polar.py | 25 +----------- 11 files changed, 5 insertions(+), 194 deletions(-) diff --git a/source/tests/consistent/descriptor/test_dpa1.py b/source/tests/consistent/descriptor/test_dpa1.py index c789eec1d9..3d80e310d0 100644 --- a/source/tests/consistent/descriptor/test_dpa1.py +++ b/source/tests/consistent/descriptor/test_dpa1.py @@ -182,9 +182,6 @@ def skip_dp(self) -> bool: use_econf_tebd, use_tebd_bias, ) = self.param - if precision == "float32": - # NumPy doesn't throw errors for float64 x float32 - return True return CommonTest.skip_dp or self.is_meaningless_zero_attention_layer_tests( attn_layer, temperature, @@ -241,9 +238,6 @@ def skip_array_api_strict(self) -> bool: use_econf_tebd, use_tebd_bias, ) = self.param - if precision == "float32": - # NumPy doesn't throw errors for float64 x float32 - return True return ( not INSTALLED_ARRAY_API_STRICT or self.is_meaningless_zero_attention_layer_tests( diff --git a/source/tests/consistent/descriptor/test_dpa2.py b/source/tests/consistent/descriptor/test_dpa2.py index 1a4d6d5ec0..17c55db368 100644 --- a/source/tests/consistent/descriptor/test_dpa2.py +++ b/source/tests/consistent/descriptor/test_dpa2.py @@ -245,9 +245,6 @@ def skip_dp(self) -> bool: use_econf_tebd, use_tebd_bias, ) = self.param - if precision == "float32": - # NumPy doesn't throw errors for float64 x float32 - return True return CommonTest.skip_dp @property @@ -284,42 +281,7 @@ def skip_tf(self) -> bool: return True skip_jax = not INSTALLED_JAX - - @property - def skip_array_api_strict(self) -> bool: - ( - repinit_tebd_input_mode, - repinit_set_davg_zero, - repinit_type_one_side, - repinit_use_three_body, - repformer_update_g1_has_conv, - repformer_direct_dist, - repformer_update_g1_has_drrd, - repformer_update_g1_has_grrg, - repformer_update_g1_has_attn, - repformer_update_g2_has_g1g1, - repformer_update_g2_has_attn, - repformer_update_h2, - repformer_attn2_has_gate, - repformer_update_style, - repformer_update_residual_init, - repformer_set_davg_zero, - repformer_trainable_ln, - repformer_ln_eps, - repformer_use_sqrt_nnei, - repformer_g1_out_conv, - repformer_g1_out_mlp, - smooth, - exclude_types, - precision, - add_tebd_to_repinit_out, - use_econf_tebd, - use_tebd_bias, - ) = self.param - if precision == "float32": - # NumPy doesn't throw errors for float64 x float32 - return True - return not INSTALLED_ARRAY_API_STRICT + skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT tf_class = DescrptDPA2TF dp_class = DescrptDPA2DP diff --git a/source/tests/consistent/descriptor/test_se_atten_v2.py b/source/tests/consistent/descriptor/test_se_atten_v2.py index be9eaeb0d9..f4a8119ca3 100644 --- a/source/tests/consistent/descriptor/test_se_atten_v2.py +++ b/source/tests/consistent/descriptor/test_se_atten_v2.py @@ -178,9 +178,6 @@ def skip_dp(self) -> bool: use_econf_tebd, use_tebd_bias, ) = self.param - if precision == "float32": - # NumPy doesn't throw errors for float64 x float32 - return True return CommonTest.skip_dp or self.is_meaningless_zero_attention_layer_tests( attn_layer, attn_dotr, @@ -241,9 +238,6 @@ def skip_array_api_strict(self) -> bool: use_econf_tebd, use_tebd_bias, ) = self.param - if precision == "float32": - # NumPy doesn't throw errors for float64 x float32 - return True return ( not INSTALLED_ARRAY_API_STRICT or self.is_meaningless_zero_attention_layer_tests( diff --git a/source/tests/consistent/descriptor/test_se_e2_a.py b/source/tests/consistent/descriptor/test_se_e2_a.py index a1c26ef98c..286703e21d 100644 --- a/source/tests/consistent/descriptor/test_se_e2_a.py +++ b/source/tests/consistent/descriptor/test_se_e2_a.py @@ -98,9 +98,6 @@ def skip_dp(self) -> bool: precision, env_protection, ) = self.param - if precision == "float32": - # NumPy doesn't throw errors for float64 x float32 - return True return CommonTest.skip_dp @property @@ -134,9 +131,6 @@ def skip_array_api_strict(self) -> bool: precision, env_protection, ) = self.param - if precision == "float32": - # NumPy doesn't throw errors for float64 x float32 - return True return not type_one_side or not INSTALLED_ARRAY_API_STRICT tf_class = DescrptSeATF diff --git a/source/tests/consistent/descriptor/test_se_r.py b/source/tests/consistent/descriptor/test_se_r.py index aa352eba14..e851106c44 100644 --- a/source/tests/consistent/descriptor/test_se_r.py +++ b/source/tests/consistent/descriptor/test_se_r.py @@ -92,9 +92,6 @@ def skip_dp(self) -> bool: excluded_types, precision, ) = self.param - if precision == "float32": - # NumPy doesn't throw errors for float64 x float32 - return True return not type_one_side or CommonTest.skip_dp @property @@ -115,9 +112,6 @@ def skip_array_api_strict(self) -> bool: excluded_types, precision, ) = self.param - if precision == "float32": - # NumPy doesn't throw errors for float64 x float32 - return True return not type_one_side or not INSTALLED_ARRAY_API_STRICT tf_class = DescrptSeRTF diff --git a/source/tests/consistent/descriptor/test_se_t.py b/source/tests/consistent/descriptor/test_se_t.py index 7a66873af9..1e6110705a 100644 --- a/source/tests/consistent/descriptor/test_se_t.py +++ b/source/tests/consistent/descriptor/test_se_t.py @@ -89,9 +89,6 @@ def skip_dp(self) -> bool: precision, env_protection, ) = self.param - if precision == "float32": - # NumPy doesn't throw errors for float64 x float32 - return True return CommonTest.skip_dp @property @@ -104,19 +101,7 @@ def skip_tf(self) -> bool: ) = self.param return env_protection != 0.0 or excluded_types - @property - def skip_array_api_strict(self) -> bool: - ( - resnet_dt, - excluded_types, - precision, - env_protection, - ) = self.param - if precision == "float32": - # NumPy doesn't throw errors for float64 x float32 - return True - return not INSTALLED_ARRAY_API_STRICT - + skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT skip_jax = not INSTALLED_JAX tf_class = DescrptSeTTF diff --git a/source/tests/consistent/descriptor/test_se_t_tebd.py b/source/tests/consistent/descriptor/test_se_t_tebd.py index 2e5f7fd2ad..4712c28e53 100644 --- a/source/tests/consistent/descriptor/test_se_t_tebd.py +++ b/source/tests/consistent/descriptor/test_se_t_tebd.py @@ -127,9 +127,6 @@ def skip_dp(self) -> bool: use_econf_tebd, use_tebd_bias, ) = self.param - if precision == "float32": - # NumPy doesn't throw errors for float64 x float32 - return True return CommonTest.skip_dp @property @@ -150,26 +147,7 @@ def skip_tf(self) -> bool: return True skip_jax = not INSTALLED_JAX - - @property - def skip_array_api_strict(self) -> bool: - ( - tebd_dim, - tebd_input_mode, - resnet_dt, - excluded_types, - env_protection, - set_davg_zero, - smooth, - concat_output_tebd, - precision, - use_econf_tebd, - use_tebd_bias, - ) = self.param - if precision == "float32": - # NumPy doesn't throw errors for float64 x float32 - return True - return not INSTALLED_ARRAY_API_STRICT + skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT tf_class = DescrptSeTTebdTF dp_class = DescrptSeTTebdDP diff --git a/source/tests/consistent/fitting/test_dipole.py b/source/tests/consistent/fitting/test_dipole.py index d77beab161..088cb30238 100644 --- a/source/tests/consistent/fitting/test_dipole.py +++ b/source/tests/consistent/fitting/test_dipole.py @@ -86,30 +86,6 @@ def skip_pt(self) -> bool: ) = self.param return CommonTest.skip_pt - @property - def skip_dp(self) -> bool: - ( - resnet_dt, - precision, - mixed_types, - ) = self.param - if precision == "float32": - # NumPy doesn't throw errors for float64 x float32 - return True - return CommonTest.skip_dp - - @property - def skip_array_api_strict(self) -> bool: - ( - resnet_dt, - precision, - mixed_types, - ) = self.param - if precision == "float32": - # NumPy doesn't throw errors for float64 x float32 - return True - return not INSTALLED_ARRAY_API_STRICT - tf_class = DipoleFittingTF dp_class = DipoleFittingDP pt_class = DipoleFittingPT @@ -117,6 +93,7 @@ def skip_array_api_strict(self) -> bool: array_api_strict_class = DipoleFittingArrayAPIStrict args = fitting_dipole() skip_jax = not INSTALLED_JAX + skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT def setUp(self): CommonTest.setUp(self) diff --git a/source/tests/consistent/fitting/test_dos.py b/source/tests/consistent/fitting/test_dos.py index c1155a4190..0649681ccb 100644 --- a/source/tests/consistent/fitting/test_dos.py +++ b/source/tests/consistent/fitting/test_dos.py @@ -97,38 +97,12 @@ def skip_pt(self) -> bool: ) = self.param return CommonTest.skip_pt - @property - def skip_dp(self) -> bool: - ( - resnet_dt, - precision, - mixed_types, - numb_fparam, - numb_aparam, - numb_dos, - ) = self.param - if precision == "float32": - # NumPy doesn't throw errors for float64 x float32 - return True - return CommonTest.skip_dp - @property def skip_jax(self) -> bool: return not INSTALLED_JAX @property def skip_array_api_strict(self) -> bool: - ( - resnet_dt, - precision, - mixed_types, - numb_fparam, - numb_aparam, - numb_dos, - ) = self.param - if precision == "float32": - # NumPy doesn't throw errors for float64 x float32 - return True return not INSTALLED_ARRAY_API_STRICT tf_class = DOSFittingTF diff --git a/source/tests/consistent/fitting/test_ener.py b/source/tests/consistent/fitting/test_ener.py index b987d16929..7be0382b16 100644 --- a/source/tests/consistent/fitting/test_ener.py +++ b/source/tests/consistent/fitting/test_ener.py @@ -100,21 +100,6 @@ def skip_pt(self) -> bool: ) = self.param return CommonTest.skip_pt - @property - def skip_dp(self) -> bool: - ( - resnet_dt, - precision, - mixed_types, - numb_fparam, - (numb_aparam, use_aparam_as_mask), - atom_ener, - ) = self.param - if precision == "float32": - # NumPy doesn't throw errors for float64 x float32 - return True - return CommonTest.skip_dp - skip_jax = not INSTALLED_JAX @property @@ -127,9 +112,6 @@ def skip_array_api_strict(self) -> bool: (numb_aparam, use_aparam_as_mask), atom_ener, ) = self.param - if precision == "float32": - # NumPy doesn't throw errors for float64 x float32 - return True # TypeError: The array_api_strict namespace does not support the dtype 'bfloat16' return not INSTALLED_ARRAY_API_STRICT or precision == "bfloat16" diff --git a/source/tests/consistent/fitting/test_polar.py b/source/tests/consistent/fitting/test_polar.py index ed9e48ee7a..12f13d1e08 100644 --- a/source/tests/consistent/fitting/test_polar.py +++ b/source/tests/consistent/fitting/test_polar.py @@ -86,30 +86,6 @@ def skip_pt(self) -> bool: ) = self.param return CommonTest.skip_pt - @property - def skip_dp(self) -> bool: - ( - resnet_dt, - precision, - mixed_types, - ) = self.param - if precision == "float32": - # NumPy doesn't throw errors for float64 x float32 - return True - return CommonTest.skip_dp - - @property - def skip_array_api_strict(self) -> bool: - ( - resnet_dt, - precision, - mixed_types, - ) = self.param - if precision == "float32": - # NumPy doesn't throw errors for float64 x float32 - return True - return not INSTALLED_ARRAY_API_STRICT - tf_class = PolarFittingTF dp_class = PolarFittingDP pt_class = PolarFittingPT @@ -117,6 +93,7 @@ def skip_array_api_strict(self) -> bool: array_api_strict_class = PolarFittingArrayAPIStrict args = fitting_polar() skip_jax = not INSTALLED_JAX + skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT def setUp(self): CommonTest.setUp(self) From 174c582a5fc6d8bb40d6d23f177131dcab5ea3a2 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Tue, 12 Nov 2024 17:28:56 -0500 Subject: [PATCH 05/14] fix the test error Signed-off-by: Jinzhe Zeng --- deepmd/dpmodel/atomic_model/base_atomic_model.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/deepmd/dpmodel/atomic_model/base_atomic_model.py b/deepmd/dpmodel/atomic_model/base_atomic_model.py index 4e7620bdda..c807dbf2ef 100644 --- a/deepmd/dpmodel/atomic_model/base_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/base_atomic_model.py @@ -201,17 +201,18 @@ def forward_common_atomic( ret_dict = self.apply_out_stat(ret_dict, atype) # nf x nloc - atom_mask = ext_atom_mask[:, :nloc].astype(xp.int32) + atom_mask = ext_atom_mask[:, :nloc] if self.atom_excl is not None: - atom_mask *= self.atom_excl.build_type_exclude_mask(atype) + atom_mask = xp.logical_and( + atom_mask, self.atom_excl.build_type_exclude_mask(atype) + ) for kk in ret_dict.keys(): out_shape = ret_dict[kk].shape out_shape2 = math.prod(out_shape[2:]) - ret_dict[kk] = ( - ret_dict[kk].reshape([out_shape[0], out_shape[1], out_shape2]) - * atom_mask[:, :, None] - ).reshape(out_shape) + tmp_arr = ret_dict[kk].reshape([out_shape[0], out_shape[1], out_shape2]) + tmp_arr = xp.where(atom_mask[:, :, None], tmp_arr, xp.zeros_like(tmp_arr)) + ret_dict[kk] = xp.reshape(tmp_arr, out_shape) ret_dict["mask"] = atom_mask return ret_dict From d5e8b4949cb0425bc57915b23078bef8a3df80d4 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Tue, 12 Nov 2024 17:31:53 -0500 Subject: [PATCH 06/14] improve type annotations of `safe_cast_array` Signed-off-by: Jinzhe Zeng --- deepmd/dpmodel/common.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/deepmd/dpmodel/common.py b/deepmd/dpmodel/common.py index 63f3a34105..42a3136d2c 100644 --- a/deepmd/dpmodel/common.py +++ b/deepmd/dpmodel/common.py @@ -10,6 +10,7 @@ Any, Callable, Optional, + overload, ) import array_api_compat @@ -177,9 +178,15 @@ def wrapper(self, *args, **kwargs): return wrapper +@overload def safe_cast_array( input: np.ndarray, from_precision: str, to_precision: str -) -> np.ndarray: +) -> np.ndarray: ... +@overload +def safe_cast_array(input: None, from_precision: str, to_precision: str) -> None: ... +def safe_cast_array( + input: Optional[np.ndarray], from_precision: str, to_precision: str +) -> Optional[np.ndarray]: """Convert an array from a precision to another precision. If input is not an array or without the specific precision, the method will not From b7d2b324249a21a1d02d620ba32e4986d3e2ccac Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Tue, 12 Nov 2024 17:32:26 -0500 Subject: [PATCH 07/14] support dict Signed-off-by: Jinzhe Zeng --- deepmd/dpmodel/common.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/deepmd/dpmodel/common.py b/deepmd/dpmodel/common.py index 42a3136d2c..efeeabaea1 100644 --- a/deepmd/dpmodel/common.py +++ b/deepmd/dpmodel/common.py @@ -172,6 +172,11 @@ def wrapper(self, *args, **kwargs): return tuple( safe_cast_array(vv, self.precision, "global") for vv in returned_tensor ) + elif isinstance(returned_tensor, dict): + return { + kk: safe_cast_array(vv, self.precision, "global") + for kk, vv in returned_tensor.items() + } else: return safe_cast_array(returned_tensor, self.precision, "global") From 201cf80a22e497075c608fd8d6364f6abd89defc Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Tue, 12 Nov 2024 18:10:07 -0500 Subject: [PATCH 08/14] fix docstring Signed-off-by: Jinzhe Zeng --- deepmd/dpmodel/common.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/deepmd/dpmodel/common.py b/deepmd/dpmodel/common.py index efeeabaea1..2bef086726 100644 --- a/deepmd/dpmodel/common.py +++ b/deepmd/dpmodel/common.py @@ -125,7 +125,7 @@ def cast_precision(func: Callable[..., Any]) -> Callable[..., Any]: """A decorator that casts and casts back the input and output tensor of a method. - The decorator should be used in a classmethod. + The decorator should be used on an instance method. The decorator will do the following thing: (1) It casts input arrays from the global precision @@ -201,8 +201,8 @@ def safe_cast_array( Parameters ---------- - input : tf.Tensor - Input tensor + input : np.ndarray or None + Input array from_precision : str Array data type that is casted from to_precision : str @@ -210,8 +210,8 @@ def safe_cast_array( Returns ------- - tf.Tensor - casted Tensor + np.ndarray or None + casted array """ if array_api_compat.is_array_api_obj(input): xp = array_api_compat.array_namespace(input) From 638acc2a869fea35cbae89170629f4b26163d259 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Tue, 12 Nov 2024 18:55:42 -0500 Subject: [PATCH 09/14] fix mask dtype Signed-off-by: Jinzhe Zeng --- deepmd/dpmodel/atomic_model/base_atomic_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/dpmodel/atomic_model/base_atomic_model.py b/deepmd/dpmodel/atomic_model/base_atomic_model.py index c807dbf2ef..970019028c 100644 --- a/deepmd/dpmodel/atomic_model/base_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/base_atomic_model.py @@ -213,7 +213,7 @@ def forward_common_atomic( tmp_arr = ret_dict[kk].reshape([out_shape[0], out_shape[1], out_shape2]) tmp_arr = xp.where(atom_mask[:, :, None], tmp_arr, xp.zeros_like(tmp_arr)) ret_dict[kk] = xp.reshape(tmp_arr, out_shape) - ret_dict["mask"] = atom_mask + ret_dict["mask"] = xp.astype(atom_mask, xp.int32) return ret_dict From 41d80b90e6dd329f831d1958f67c40d1af89f28b Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Tue, 12 Nov 2024 22:38:30 -0500 Subject: [PATCH 10/14] cast fitting Signed-off-by: Jinzhe Zeng --- deepmd/dpmodel/fitting/dipole_fitting.py | 4 ++++ deepmd/dpmodel/fitting/general_fitting.py | 12 ++++-------- deepmd/dpmodel/fitting/invar_fitting.py | 4 ++++ deepmd/dpmodel/fitting/polarizability_fitting.py | 11 +++++++++-- 4 files changed, 21 insertions(+), 10 deletions(-) diff --git a/deepmd/dpmodel/fitting/dipole_fitting.py b/deepmd/dpmodel/fitting/dipole_fitting.py index 5988951445..986c48d878 100644 --- a/deepmd/dpmodel/fitting/dipole_fitting.py +++ b/deepmd/dpmodel/fitting/dipole_fitting.py @@ -11,6 +11,9 @@ from deepmd.dpmodel import ( DEFAULT_PRECISION, ) +from deepmd.dpmodel.common import ( + cast_precision, +) from deepmd.dpmodel.fitting.base_fitting import ( BaseFitting, ) @@ -174,6 +177,7 @@ def output_def(self): ] ) + @cast_precision def call( self, descriptor: np.ndarray, diff --git a/deepmd/dpmodel/fitting/general_fitting.py b/deepmd/dpmodel/fitting/general_fitting.py index 2c06644afb..fe3fbcc744 100644 --- a/deepmd/dpmodel/fitting/general_fitting.py +++ b/deepmd/dpmodel/fitting/general_fitting.py @@ -364,11 +364,6 @@ def _call_common( """ xp = array_api_compat.array_namespace(descriptor, atype) - descriptor = xp.astype(descriptor, get_xp_precision(xp, self.precision)) - if fparam is not None: - fparam = xp.astype(fparam, get_xp_precision(xp, self.precision)) - if aparam is not None: - aparam = xp.astype(aparam, get_xp_precision(xp, self.precision)) nf, nloc, nd = descriptor.shape net_dim_out = self._net_out_dim() # check input dim @@ -452,13 +447,14 @@ def _call_common( outs = self.nets[()](xx) if xx_zeros is not None: outs -= self.nets[()](xx_zeros) - outs = xp.astype(outs, get_xp_precision(xp, "global")) outs += xp.reshape( - xp.take(self.bias_atom_e, xp.reshape(atype, [-1]), axis=0), + xp.take( + xp.astype(self.bias_atom_e, outs.dtype), xp.reshape(atype, [-1]), axis=0 + ), [nf, nloc, net_dim_out], ) # nf x nloc exclude_mask = self.emask.build_type_exclude_mask(atype) # nf x nloc x nod outs = xp.where(exclude_mask[:, :, None], outs, xp.zeros_like(outs)) - return {self.var_name: xp.astype(outs, get_xp_precision(xp, "global"))} + return {self.var_name: xp.astype(outs, descriptor.dtype)} diff --git a/deepmd/dpmodel/fitting/invar_fitting.py b/deepmd/dpmodel/fitting/invar_fitting.py index afae455441..8fe9c00c0f 100644 --- a/deepmd/dpmodel/fitting/invar_fitting.py +++ b/deepmd/dpmodel/fitting/invar_fitting.py @@ -10,6 +10,9 @@ from deepmd.dpmodel import ( DEFAULT_PRECISION, ) +from deepmd.dpmodel.common import ( + cast_precision, +) from deepmd.dpmodel.output_def import ( FittingOutputDef, OutputVariableDef, @@ -203,6 +206,7 @@ def output_def(self): ] ) + @cast_precision def call( self, descriptor: np.ndarray, diff --git a/deepmd/dpmodel/fitting/polarizability_fitting.py b/deepmd/dpmodel/fitting/polarizability_fitting.py index 8cf3c4faaa..3e01b85056 100644 --- a/deepmd/dpmodel/fitting/polarizability_fitting.py +++ b/deepmd/dpmodel/fitting/polarizability_fitting.py @@ -15,6 +15,7 @@ DEFAULT_PRECISION, ) from deepmd.dpmodel.common import ( + cast_precision, to_numpy_array, ) from deepmd.dpmodel.fitting.base_fitting import ( @@ -241,6 +242,7 @@ def change_type_map( self.scale = self.scale[remap_index] self.constant_matrix = self.constant_matrix[remap_index] + @cast_precision def call( self, descriptor: np.ndarray, @@ -285,7 +287,8 @@ def call( ] # out = out * self.scale[atype, ...] scale_atype = xp.reshape( - xp.take(self.scale, xp.reshape(atype, [-1]), axis=0), (*atype.shape, 1) + xp.take(xp.astype(self.scale, out.dtype), xp.reshape(atype, [-1]), axis=0), + (*atype.shape, 1), ) out = out * scale_atype # (nframes * nloc, m1, 3) @@ -308,7 +311,11 @@ def call( if self.shift_diag: # bias = self.constant_matrix[atype] bias = xp.reshape( - xp.take(self.constant_matrix, xp.reshape(atype, [-1]), axis=0), + xp.take( + xp.astype(self.constant_matrix, out.dtype), + xp.reshape(atype, [-1]), + axis=0, + ), (nframes, nloc), ) # (nframes, nloc, 1) From 30d33b767a3fa164185f10e2ba66b7caf770d165 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 13 Nov 2024 04:27:41 -0500 Subject: [PATCH 11/14] Update deepmd/dpmodel/fitting/general_fitting.py Co-authored-by: Han Wang <92130845+wanghan-iapcm@users.noreply.github.com> Signed-off-by: Jinzhe Zeng --- deepmd/dpmodel/fitting/general_fitting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/dpmodel/fitting/general_fitting.py b/deepmd/dpmodel/fitting/general_fitting.py index fe3fbcc744..eba5076418 100644 --- a/deepmd/dpmodel/fitting/general_fitting.py +++ b/deepmd/dpmodel/fitting/general_fitting.py @@ -457,4 +457,4 @@ def _call_common( exclude_mask = self.emask.build_type_exclude_mask(atype) # nf x nloc x nod outs = xp.where(exclude_mask[:, :, None], outs, xp.zeros_like(outs)) - return {self.var_name: xp.astype(outs, descriptor.dtype)} + return {self.var_name: outs)} From 4ba6e9758ef42b356ef8ff3e8dc07a8fe42a8849 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 13 Nov 2024 04:34:45 -0500 Subject: [PATCH 12/14] Update deepmd/dpmodel/fitting/general_fitting.py Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Signed-off-by: Jinzhe Zeng --- deepmd/dpmodel/fitting/general_fitting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/dpmodel/fitting/general_fitting.py b/deepmd/dpmodel/fitting/general_fitting.py index eba5076418..23fa35f3a5 100644 --- a/deepmd/dpmodel/fitting/general_fitting.py +++ b/deepmd/dpmodel/fitting/general_fitting.py @@ -457,4 +457,4 @@ def _call_common( exclude_mask = self.emask.build_type_exclude_mask(atype) # nf x nloc x nod outs = xp.where(exclude_mask[:, :, None], outs, xp.zeros_like(outs)) - return {self.var_name: outs)} + return {self.var_name: outs} From e43134d23c5a4e27523d1bc9a441e360f712732e Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 13 Nov 2024 15:13:54 -0500 Subject: [PATCH 13/14] workaround for bfloat16 Signed-off-by: Jinzhe Zeng --- deepmd/dpmodel/utils/network.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/deepmd/dpmodel/utils/network.py b/deepmd/dpmodel/utils/network.py index 9ae5fd2b40..5087b77800 100644 --- a/deepmd/dpmodel/utils/network.py +++ b/deepmd/dpmodel/utils/network.py @@ -248,6 +248,10 @@ def call(self, x: np.ndarray) -> np.ndarray: if self.b is not None else xp.matmul(x, self.w) ) + if y.dtype != x.dtype: + # workaround for bfloat16 + # https://github.com/jax-ml/ml_dtypes/issues/235 + y = xp.astype(y, x.dtype) y = fn(y) if self.idt is not None: y *= self.idt From 36b087d62c53223a3911ba271e108209b4efaecb Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 13 Nov 2024 18:08:35 -0500 Subject: [PATCH 14/14] fix input in `test_network.py` Signed-off-by: Jinzhe Zeng --- source/tests/common/dpmodel/test_network.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/source/tests/common/dpmodel/test_network.py b/source/tests/common/dpmodel/test_network.py index 381c542272..4b5eb4fa66 100644 --- a/source/tests/common/dpmodel/test_network.py +++ b/source/tests/common/dpmodel/test_network.py @@ -8,6 +8,9 @@ import numpy as np +from deepmd.dpmodel.common import ( + get_xp_precision, +) from deepmd.dpmodel.utils import ( EmbeddingNet, FittingNet, @@ -46,7 +49,9 @@ def test_serialize_deserize(self): inp_shap = [ni] if ashp is not None: inp_shap = ashp + inp_shap - inp = np.arange(np.prod(inp_shap)).reshape(inp_shap) + inp = np.arange( + np.prod(inp_shap), dtype=get_xp_precision(np, prec) + ).reshape(inp_shap) np.testing.assert_allclose(nl0.call(inp), nl1.call(inp)) def test_shape_error(self): @@ -168,7 +173,7 @@ def test_embedding_net(self): resnet_dt=idt, ) en1 = EmbeddingNet.deserialize(en0.serialize()) - inp = np.ones([ni]) + inp = np.ones([ni], dtype=get_xp_precision(np, prec)) np.testing.assert_allclose(en0.call(inp), en1.call(inp)) @@ -191,7 +196,7 @@ def test_fitting_net(self): bias_out=bo, ) en1 = FittingNet.deserialize(en0.serialize()) - inp = np.ones([ni]) + inp = np.ones([ni], dtype=get_xp_precision(np, prec)) en0.call(inp) en1.call(inp) np.testing.assert_allclose(en0.call(inp), en1.call(inp))