diff --git a/.github/workflows/test_python.yml b/.github/workflows/test_python.yml
index 1b1935a2f6..9437c69ae8 100644
--- a/.github/workflows/test_python.yml
+++ b/.github/workflows/test_python.yml
@@ -62,6 +62,7 @@ jobs:
         env:
           NUM_WORKERS: 0
           DP_TEST_TF2_ONLY: 1
+          DP_DTYPE_PROMOTION_STRICT: 1
         if: matrix.group == 1
       - run: mv .test_durations .test_durations_${{ matrix.group }}
       - name: Upload partial durations
diff --git a/deepmd/dpmodel/atomic_model/base_atomic_model.py b/deepmd/dpmodel/atomic_model/base_atomic_model.py
index 4e7620bdda..970019028c 100644
--- a/deepmd/dpmodel/atomic_model/base_atomic_model.py
+++ b/deepmd/dpmodel/atomic_model/base_atomic_model.py
@@ -201,18 +201,19 @@ def forward_common_atomic(
         ret_dict = self.apply_out_stat(ret_dict, atype)
 
         # nf x nloc
-        atom_mask = ext_atom_mask[:, :nloc].astype(xp.int32)
+        atom_mask = ext_atom_mask[:, :nloc]
         if self.atom_excl is not None:
-            atom_mask *= self.atom_excl.build_type_exclude_mask(atype)
+            atom_mask = xp.logical_and(
+                atom_mask, self.atom_excl.build_type_exclude_mask(atype)
+            )
 
         for kk in ret_dict.keys():
             out_shape = ret_dict[kk].shape
             out_shape2 = math.prod(out_shape[2:])
-            ret_dict[kk] = (
-                ret_dict[kk].reshape([out_shape[0], out_shape[1], out_shape2])
-                * atom_mask[:, :, None]
-            ).reshape(out_shape)
-        ret_dict["mask"] = atom_mask
+            tmp_arr = ret_dict[kk].reshape([out_shape[0], out_shape[1], out_shape2])
+            tmp_arr = xp.where(atom_mask[:, :, None], tmp_arr, xp.zeros_like(tmp_arr))
+            ret_dict[kk] = xp.reshape(tmp_arr, out_shape)
+        ret_dict["mask"] = xp.astype(atom_mask, xp.int32)
 
         return ret_dict
 
diff --git a/deepmd/dpmodel/common.py b/deepmd/dpmodel/common.py
index 6e6113b494..2bef086726 100644
--- a/deepmd/dpmodel/common.py
+++ b/deepmd/dpmodel/common.py
@@ -3,9 +3,14 @@
     ABC,
     abstractmethod,
 )
+from functools import (
+    wraps,
+)
 from typing import (
     Any,
+    Callable,
     Optional,
+    overload,
 )
 
 import array_api_compat
@@ -116,6 +121,105 @@ def to_numpy_array(x: Any) -> Optional[np.ndarray]:
     return np.from_dlpack(x)
 
 
+def cast_precision(func: Callable[..., Any]) -> Callable[..., Any]:
+    """A decorator that casts and casts back the input
+    and output tensor of a method.
+
+    The decorator should be used on an instance method.
+
+    The decorator will do the following thing:
+    (1) It casts input arrays from the global precision
+    to precision defined by property `precision`.
+    (2) It casts output arrays from `precision` to
+    the global precision.
+    (3) It checks inputs and outputs and only casts when
+    input or output is an array and its dtype matches
+    the global precision and `precision`, respectively.
+    If it does not match (e.g. it is an integer), the decorator
+    will do nothing on it.
+
+    The decorator supports the array API.
+
+    Returns
+    -------
+    Callable
+        a decorator that casts and casts back the input and
+        output array of a method
+
+    Examples
+    --------
+    >>> class A:
+    ...     def __init__(self):
+    ...         self.precision = "float32"
+    ...
+    ...     @cast_precision
+    ...     def f(x: Array, y: Array) -> Array:
+    ...         return x**2 + y
+    """
+
+    @wraps(func)
+    def wrapper(self, *args, **kwargs):
+        # only convert tensors
+        returned_tensor = func(
+            self,
+            *[safe_cast_array(vv, "global", self.precision) for vv in args],
+            **{
+                kk: safe_cast_array(vv, "global", self.precision)
+                for kk, vv in kwargs.items()
+            },
+        )
+        if isinstance(returned_tensor, tuple):
+            return tuple(
+                safe_cast_array(vv, self.precision, "global") for vv in returned_tensor
+            )
+        elif isinstance(returned_tensor, dict):
+            return {
+                kk: safe_cast_array(vv, self.precision, "global")
+                for kk, vv in returned_tensor.items()
+            }
+        else:
+            return safe_cast_array(returned_tensor, self.precision, "global")
+
+    return wrapper
+
+
+@overload
+def safe_cast_array(
+    input: np.ndarray, from_precision: str, to_precision: str
+) -> np.ndarray: ...
+@overload
+def safe_cast_array(input: None, from_precision: str, to_precision: str) -> None: ...
+def safe_cast_array(
+    input: Optional[np.ndarray], from_precision: str, to_precision: str
+) -> Optional[np.ndarray]:
+    """Convert an array from a precision to another precision.
+
+    If input is not an array or without the specific precision, the method will not
+    cast it.
+
+    Array API is supported.
+
+    Parameters
+    ----------
+    input : np.ndarray or None
+        Input array
+    from_precision : str
+        Array data type that is casted from
+    to_precision : str
+        Array data type that casts to
+
+    Returns
+    -------
+    np.ndarray or None
+        casted array
+    """
+    if array_api_compat.is_array_api_obj(input):
+        xp = array_api_compat.array_namespace(input)
+        if input.dtype == get_xp_precision(xp, from_precision):
+            return xp.astype(input, get_xp_precision(xp, to_precision))
+    return input
+
+
 __all__ = [
     "GLOBAL_NP_FLOAT_PRECISION",
     "GLOBAL_ENER_FLOAT_PRECISION",
diff --git a/deepmd/dpmodel/descriptor/dpa1.py b/deepmd/dpmodel/descriptor/dpa1.py
index 259593e731..d21fc492c3 100644
--- a/deepmd/dpmodel/descriptor/dpa1.py
+++ b/deepmd/dpmodel/descriptor/dpa1.py
@@ -19,6 +19,7 @@
     xp_take_along_axis,
 )
 from deepmd.dpmodel.common import (
+    cast_precision,
     to_numpy_array,
 )
 from deepmd.dpmodel.utils import (
@@ -329,6 +330,7 @@ def __init__(
         self.tebd_dim = tebd_dim
         self.concat_output_tebd = concat_output_tebd
         self.trainable = trainable
+        self.precision = precision
 
     def get_rcut(self) -> float:
         """Returns the cut-off radius."""
@@ -448,6 +450,7 @@ def change_type_map(
         obj["davg"] = obj["davg"][remap_index]
         obj["dstd"] = obj["dstd"][remap_index]
 
+    @cast_precision
     def call(
         self,
         coord_ext,
diff --git a/deepmd/dpmodel/descriptor/dpa2.py b/deepmd/dpmodel/descriptor/dpa2.py
index 097be2ef09..d82f136a9f 100644
--- a/deepmd/dpmodel/descriptor/dpa2.py
+++ b/deepmd/dpmodel/descriptor/dpa2.py
@@ -14,6 +14,7 @@
     xp_take_along_axis,
 )
 from deepmd.dpmodel.common import (
+    cast_precision,
     to_numpy_array,
 )
 from deepmd.dpmodel.utils import (
@@ -594,6 +595,7 @@ def init_subclass_params(sub_data, sub_class):
         self.rcut = self.repinit.get_rcut()
         self.ntypes = ntypes
         self.sel = self.repinit.sel
+        self.precision = precision
 
     def get_rcut(self) -> float:
         """Returns the cut-off radius."""
@@ -757,6 +759,7 @@ def get_stat_mean_and_stddev(self) -> tuple[list[np.ndarray], list[np.ndarray]]:
             stddev_list.append(self.repinit_three_body.stddev)
         return mean_list, stddev_list
 
+    @cast_precision
     def call(
         self,
         coord_ext: np.ndarray,
diff --git a/deepmd/dpmodel/descriptor/se_e2_a.py b/deepmd/dpmodel/descriptor/se_e2_a.py
index 63402b6f84..4ffdd025e8 100644
--- a/deepmd/dpmodel/descriptor/se_e2_a.py
+++ b/deepmd/dpmodel/descriptor/se_e2_a.py
@@ -15,6 +15,7 @@
     NativeOP,
 )
 from deepmd.dpmodel.common import (
+    cast_precision,
     to_numpy_array,
 )
 from deepmd.dpmodel.utils import (
@@ -29,9 +30,6 @@
 from deepmd.dpmodel.utils.update_sel import (
     UpdateSel,
 )
-from deepmd.env import (
-    GLOBAL_NP_FLOAT_PRECISION,
-)
 from deepmd.utils.data_system import (
     DeepmdDataSystem,
 )
@@ -340,6 +338,7 @@ def reinit_exclude(
         self.exclude_types = exclude_types
         self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types)
 
+    @cast_precision
     def call(
         self,
         coord_ext,
@@ -415,9 +414,7 @@ def call(
         # nf x nloc x ng x ng1
         grrg = np.einsum("flid,fljd->flij", gr, gr1)
         # nf x nloc x (ng x ng1)
-        grrg = grrg.reshape(nf, nloc, ng * self.axis_neuron).astype(
-            GLOBAL_NP_FLOAT_PRECISION
-        )
+        grrg = grrg.reshape(nf, nloc, ng * self.axis_neuron)
         return grrg, gr[..., 1:], None, None, ww
 
     def serialize(self) -> dict:
@@ -506,6 +503,7 @@ def update_sel(
 
 
 class DescrptSeAArrayAPI(DescrptSeA):
+    @cast_precision
     def call(
         self,
         coord_ext,
@@ -585,7 +583,5 @@ def call(
         # grrg = xp.einsum("flid,fljd->flij", gr, gr1)
         grrg = xp.sum(gr[:, :, :, None, :] * gr1[:, :, None, :, :], axis=4)
         # nf x nloc x (ng x ng1)
-        grrg = xp.astype(
-            xp.reshape(grrg, (nf, nloc, ng * self.axis_neuron)), input_dtype
-        )
+        grrg = xp.reshape(grrg, (nf, nloc, ng * self.axis_neuron))
         return grrg, gr[..., 1:], None, None, ww
diff --git a/deepmd/dpmodel/descriptor/se_r.py b/deepmd/dpmodel/descriptor/se_r.py
index d652eb1420..45757c68ec 100644
--- a/deepmd/dpmodel/descriptor/se_r.py
+++ b/deepmd/dpmodel/descriptor/se_r.py
@@ -14,6 +14,7 @@
     NativeOP,
 )
 from deepmd.dpmodel.common import (
+    cast_precision,
     get_xp_precision,
     to_numpy_array,
 )
@@ -289,6 +290,7 @@ def cal_g(
         gg = self.embeddings[(ll,)].call(ss)
         return gg
 
+    @cast_precision
     def call(
         self,
         coord_ext,
@@ -352,7 +354,6 @@ def call(
         res_rescale = 1.0 / 5.0
         res = xyz_scatter * res_rescale
         res = xp.reshape(res, (nf, nloc, ng))
-        res = xp.astype(res, get_xp_precision(xp, "global"))
         return res, None, None, None, ww
 
     def serialize(self) -> dict:
diff --git a/deepmd/dpmodel/descriptor/se_t.py b/deepmd/dpmodel/descriptor/se_t.py
index be587c77da..38bd660af2 100644
--- a/deepmd/dpmodel/descriptor/se_t.py
+++ b/deepmd/dpmodel/descriptor/se_t.py
@@ -14,6 +14,7 @@
     NativeOP,
 )
 from deepmd.dpmodel.common import (
+    cast_precision,
     get_xp_precision,
     to_numpy_array,
 )
@@ -264,6 +265,7 @@ def reinit_exclude(
         self.exclude_types = exclude_types
         self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types)
 
+    @cast_precision
     def call(
         self,
         coord_ext,
@@ -317,7 +319,6 @@ def call(
         # we don't require atype is the same in all frames
         exclude_mask = xp.reshape(exclude_mask, (nf * nloc, nnei))
         rr = xp.reshape(rr, (nf * nloc, nnei, 4))
-        rr = xp.astype(rr, get_xp_precision(xp, self.precision))
 
         for embedding_idx in itertools.product(
             range(self.ntypes), repeat=self.embeddings.ndim
@@ -349,7 +350,6 @@ def call(
             result += res_ij
         # nf x nloc x ng
         result = xp.reshape(result, (nf, nloc, ng))
-        result = xp.astype(result, get_xp_precision(xp, "global"))
         return result, None, None, None, ww
 
     def serialize(self) -> dict:
diff --git a/deepmd/dpmodel/descriptor/se_t_tebd.py b/deepmd/dpmodel/descriptor/se_t_tebd.py
index 298f823690..b1b7cfa930 100644
--- a/deepmd/dpmodel/descriptor/se_t_tebd.py
+++ b/deepmd/dpmodel/descriptor/se_t_tebd.py
@@ -16,7 +16,7 @@
     xp_take_along_axis,
 )
 from deepmd.dpmodel.common import (
-    get_xp_precision,
+    cast_precision,
     to_numpy_array,
 )
 from deepmd.dpmodel.utils import (
@@ -168,6 +168,7 @@ def __init__(
         self.tebd_dim = tebd_dim
         self.concat_output_tebd = concat_output_tebd
         self.trainable = trainable
+        self.precision = precision
 
     def get_rcut(self) -> float:
         """Returns the cut-off radius."""
@@ -287,6 +288,7 @@ def change_type_map(
         obj["davg"] = obj["davg"][remap_index]
         obj["dstd"] = obj["dstd"][remap_index]
 
+    @cast_precision
     def call(
         self,
         coord_ext,
@@ -741,7 +743,6 @@ def call(
         res_ij = res_ij * (1.0 / float(self.nnei) / float(self.nnei))
         # nf x nl x ng
         result = xp.reshape(res_ij, (nf, nloc, self.filter_neuron[-1]))
-        result = xp.astype(result, get_xp_precision(xp, "global"))
         return (
             result,
             None,
diff --git a/deepmd/dpmodel/fitting/dipole_fitting.py b/deepmd/dpmodel/fitting/dipole_fitting.py
index 5988951445..986c48d878 100644
--- a/deepmd/dpmodel/fitting/dipole_fitting.py
+++ b/deepmd/dpmodel/fitting/dipole_fitting.py
@@ -11,6 +11,9 @@
 from deepmd.dpmodel import (
     DEFAULT_PRECISION,
 )
+from deepmd.dpmodel.common import (
+    cast_precision,
+)
 from deepmd.dpmodel.fitting.base_fitting import (
     BaseFitting,
 )
@@ -174,6 +177,7 @@ def output_def(self):
             ]
         )
 
+    @cast_precision
     def call(
         self,
         descriptor: np.ndarray,
diff --git a/deepmd/dpmodel/fitting/general_fitting.py b/deepmd/dpmodel/fitting/general_fitting.py
index 344dab7ff1..23fa35f3a5 100644
--- a/deepmd/dpmodel/fitting/general_fitting.py
+++ b/deepmd/dpmodel/fitting/general_fitting.py
@@ -439,18 +439,22 @@ def _call_common(
                 ):
                     assert xx_zeros is not None
                     atom_property -= self.nets[(type_i,)](xx_zeros)
-                atom_property = atom_property + self.bias_atom_e[type_i, ...]
-                atom_property = atom_property * xp.astype(mask, atom_property.dtype)
+                atom_property = xp.where(
+                    mask, atom_property, xp.zeros_like(atom_property)
+                )
                 outs = outs + atom_property  # Shape is [nframes, natoms[0], 1]
         else:
-            outs = self.nets[()](xx) + xp.reshape(
-                xp.take(self.bias_atom_e, xp.reshape(atype, [-1]), axis=0),
-                [nf, nloc, net_dim_out],
-            )
+            outs = self.nets[()](xx)
             if xx_zeros is not None:
                 outs -= self.nets[()](xx_zeros)
+        outs += xp.reshape(
+            xp.take(
+                xp.astype(self.bias_atom_e, outs.dtype), xp.reshape(atype, [-1]), axis=0
+            ),
+            [nf, nloc, net_dim_out],
+        )
         # nf x nloc
         exclude_mask = self.emask.build_type_exclude_mask(atype)
         # nf x nloc x nod
-        outs = outs * xp.astype(exclude_mask[:, :, None], outs.dtype)
-        return {self.var_name: xp.astype(outs, get_xp_precision(xp, "global"))}
+        outs = xp.where(exclude_mask[:, :, None], outs, xp.zeros_like(outs))
+        return {self.var_name: outs}
diff --git a/deepmd/dpmodel/fitting/invar_fitting.py b/deepmd/dpmodel/fitting/invar_fitting.py
index afae455441..8fe9c00c0f 100644
--- a/deepmd/dpmodel/fitting/invar_fitting.py
+++ b/deepmd/dpmodel/fitting/invar_fitting.py
@@ -10,6 +10,9 @@
 from deepmd.dpmodel import (
     DEFAULT_PRECISION,
 )
+from deepmd.dpmodel.common import (
+    cast_precision,
+)
 from deepmd.dpmodel.output_def import (
     FittingOutputDef,
     OutputVariableDef,
@@ -203,6 +206,7 @@ def output_def(self):
             ]
         )
 
+    @cast_precision
     def call(
         self,
         descriptor: np.ndarray,
diff --git a/deepmd/dpmodel/fitting/polarizability_fitting.py b/deepmd/dpmodel/fitting/polarizability_fitting.py
index 8cf3c4faaa..3e01b85056 100644
--- a/deepmd/dpmodel/fitting/polarizability_fitting.py
+++ b/deepmd/dpmodel/fitting/polarizability_fitting.py
@@ -15,6 +15,7 @@
     DEFAULT_PRECISION,
 )
 from deepmd.dpmodel.common import (
+    cast_precision,
     to_numpy_array,
 )
 from deepmd.dpmodel.fitting.base_fitting import (
@@ -241,6 +242,7 @@ def change_type_map(
         self.scale = self.scale[remap_index]
         self.constant_matrix = self.constant_matrix[remap_index]
 
+    @cast_precision
     def call(
         self,
         descriptor: np.ndarray,
@@ -285,7 +287,8 @@ def call(
         ]
         # out = out * self.scale[atype, ...]
         scale_atype = xp.reshape(
-            xp.take(self.scale, xp.reshape(atype, [-1]), axis=0), (*atype.shape, 1)
+            xp.take(xp.astype(self.scale, out.dtype), xp.reshape(atype, [-1]), axis=0),
+            (*atype.shape, 1),
         )
         out = out * scale_atype
         # (nframes * nloc, m1, 3)
@@ -308,7 +311,11 @@ def call(
         if self.shift_diag:
             # bias = self.constant_matrix[atype]
             bias = xp.reshape(
-                xp.take(self.constant_matrix, xp.reshape(atype, [-1]), axis=0),
+                xp.take(
+                    xp.astype(self.constant_matrix, out.dtype),
+                    xp.reshape(atype, [-1]),
+                    axis=0,
+                ),
                 (nframes, nloc),
             )
             # (nframes, nloc, 1)
diff --git a/deepmd/dpmodel/utils/network.py b/deepmd/dpmodel/utils/network.py
index 9ae5fd2b40..5087b77800 100644
--- a/deepmd/dpmodel/utils/network.py
+++ b/deepmd/dpmodel/utils/network.py
@@ -248,6 +248,10 @@ def call(self, x: np.ndarray) -> np.ndarray:
             if self.b is not None
             else xp.matmul(x, self.w)
         )
+        if y.dtype != x.dtype:
+            # workaround for bfloat16
+            # https://github.com/jax-ml/ml_dtypes/issues/235
+            y = xp.astype(y, x.dtype)
         y = fn(y)
         if self.idt is not None:
             y *= self.idt
diff --git a/deepmd/jax/env.py b/deepmd/jax/env.py
index 1b90433b00..02e31ae66e 100644
--- a/deepmd/jax/env.py
+++ b/deepmd/jax/env.py
@@ -13,6 +13,9 @@
 jax.config.update("jax_enable_x64", True)
 # jax.config.update("jax_debug_nans", True)
 
+if os.environ.get("DP_DTYPE_PROMOTION_STRICT") == "1":
+    jax.config.update("jax_numpy_dtype_promotion", "strict")
+
 __all__ = [
     "jax",
     "jnp",
diff --git a/source/tests/common/dpmodel/test_network.py b/source/tests/common/dpmodel/test_network.py
index 381c542272..4b5eb4fa66 100644
--- a/source/tests/common/dpmodel/test_network.py
+++ b/source/tests/common/dpmodel/test_network.py
@@ -8,6 +8,9 @@
 
 import numpy as np
 
+from deepmd.dpmodel.common import (
+    get_xp_precision,
+)
 from deepmd.dpmodel.utils import (
     EmbeddingNet,
     FittingNet,
@@ -46,7 +49,9 @@ def test_serialize_deserize(self):
             inp_shap = [ni]
             if ashp is not None:
                 inp_shap = ashp + inp_shap
-            inp = np.arange(np.prod(inp_shap)).reshape(inp_shap)
+            inp = np.arange(
+                np.prod(inp_shap), dtype=get_xp_precision(np, prec)
+            ).reshape(inp_shap)
             np.testing.assert_allclose(nl0.call(inp), nl1.call(inp))
 
     def test_shape_error(self):
@@ -168,7 +173,7 @@ def test_embedding_net(self):
                 resnet_dt=idt,
             )
             en1 = EmbeddingNet.deserialize(en0.serialize())
-            inp = np.ones([ni])
+            inp = np.ones([ni], dtype=get_xp_precision(np, prec))
             np.testing.assert_allclose(en0.call(inp), en1.call(inp))
 
 
@@ -191,7 +196,7 @@ def test_fitting_net(self):
                 bias_out=bo,
             )
            en1 = FittingNet.deserialize(en0.serialize())
-            inp = np.ones([ni])
+            inp = np.ones([ni], dtype=get_xp_precision(np, prec))
             en0.call(inp)
             en1.call(inp)
             np.testing.assert_allclose(en0.call(inp), en1.call(inp))
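Note (not part of the patch): a minimal usage sketch of the cast_precision decorator added in deepmd/dpmodel/common.py above. The ToyFitting class, its call method, and the scale argument are hypothetical and exist only to show the cast-in/cast-back behaviour; the sketch assumes the patch is applied and the default global precision is in effect (float64 unless DP_INTERFACE_PREC says otherwise).

# Hypothetical sketch, not part of the patch: expected behaviour of
# @cast_precision on a NumPy-backed instance method once this patch is applied.
import numpy as np

from deepmd.dpmodel.common import cast_precision, get_xp_precision


class ToyFitting:
    """Toy class (not a real deepmd class); the decorator reads `self.precision`."""

    def __init__(self) -> None:
        self.precision = "float32"  # working precision of this toy "network"

    @cast_precision
    def call(self, x: np.ndarray, scale: float = 2.0) -> np.ndarray:
        # global-precision array inputs arrive here cast to self.precision;
        # non-array arguments such as `scale` are passed through untouched
        assert x.dtype == get_xp_precision(np, self.precision)
        return scale * x


x_global = np.ones(3, dtype=get_xp_precision(np, "global"))
y = ToyFitting().call(x_global)
# the low-precision result is cast back to the global precision for the caller
assert y.dtype == x_global.dtype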
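Note (not part of the patch): a rough illustration of what the DP_DTYPE_PROMOTION_STRICT=1 switch wired into deepmd/jax/env.py and the CI job turns on. Under JAX's "strict" dtype promotion, combining arrays of different dtypes is rejected instead of silently promoted, which is why the changes above replace implicit promotion with explicit xp.astype and xp.where handling. The array names below are made up for the example, and the exact exception class raised by JAX is intentionally not pinned down.

# Hypothetical sketch, not part of the patch: effect of strict dtype promotion.
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # allow float64, as deepmd/jax/env.py does
jax.config.update("jax_numpy_dtype_promotion", "strict")  # what the env var enables

a32 = jnp.ones(3, dtype=jnp.float32)
b64 = jnp.ones(3, dtype=jnp.float64)

try:
    _ = a32 + b64  # implicit float32/float64 promotion is an error in strict mode
except Exception as err:  # JAX raises a dtype-promotion error here
    print(f"strict promotion rejected the mixed-dtype add: {err!r}")

# an explicit cast, the pattern used throughout the patch, is still fine
print((a32.astype(jnp.float64) + b64).dtype)  # float64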