From b9f640ba2804a981797c6c6a12c7ca05bb680ff0 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Tue, 19 Nov 2024 21:30:29 -0500 Subject: [PATCH 1/2] fix: don't distinguish nlist types in high model interfaces It will be distinguished in the low interfaces anyway. Signed-off-by: Jinzhe Zeng --- deepmd/dpmodel/model/make_model.py | 4 +++- deepmd/jax/jax2tf/make_model.py | 4 +++- deepmd/pt/model/model/make_model.py | 4 +++- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/deepmd/dpmodel/model/make_model.py b/deepmd/dpmodel/model/make_model.py index a539adb292..fbf2c6e21f 100644 --- a/deepmd/dpmodel/model/make_model.py +++ b/deepmd/dpmodel/model/make_model.py @@ -108,7 +108,9 @@ def model_call_from_call_lower( nloc, rcut, sel, - distinguish_types=not mixed_types, + # types will be distinguished in the lower interface, + # so it doesn't need to be distinguished here + distinguish_types=False, ) extended_coord = extended_coord.reshape(nframes, -1, 3) model_predict_lower = call_lower( diff --git a/deepmd/jax/jax2tf/make_model.py b/deepmd/jax/jax2tf/make_model.py index d21fc998b5..29ed131f8e 100644 --- a/deepmd/jax/jax2tf/make_model.py +++ b/deepmd/jax/jax2tf/make_model.py @@ -90,7 +90,9 @@ def model_call_from_call_lower( nloc, rcut, sel, - distinguish_types=not mixed_types, + # types will be distinguished in the lower interface, + # so it doesn't need to be distinguished here + distinguish_types=False, ) extended_coord = extended_coord.reshape(nframes, -1, 3) model_predict_lower = call_lower( diff --git a/deepmd/pt/model/model/make_model.py b/deepmd/pt/model/model/make_model.py index 6bb5f6b8e9..15b8a00613 100644 --- a/deepmd/pt/model/model/make_model.py +++ b/deepmd/pt/model/model/make_model.py @@ -175,7 +175,9 @@ def forward_common( atype, self.get_rcut(), self.get_sel(), - mixed_types=self.mixed_types(), + # types will be distinguished in the lower interface, + # so it doesn't need to be distinguished here + mixed_types=False, box=bb, ) model_predict_lower = self.forward_common_lower( From d3e244b94d03f86a099c1f610209b3b6f9527946 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Tue, 19 Nov 2024 22:32:50 -0500 Subject: [PATCH 2/2] True, not False --- deepmd/pt/model/model/make_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/pt/model/model/make_model.py b/deepmd/pt/model/model/make_model.py index 15b8a00613..83abf9ee4a 100644 --- a/deepmd/pt/model/model/make_model.py +++ b/deepmd/pt/model/model/make_model.py @@ -177,7 +177,7 @@ def forward_common( self.get_sel(), # types will be distinguished in the lower interface, # so it doesn't need to be distinguished here - mixed_types=False, + mixed_types=True, box=bb, ) model_predict_lower = self.forward_common_lower(