From f03c54b3cd467335ac63a40a32c61c79cb3179a4 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 20 Nov 2024 15:22:47 -0500 Subject: [PATCH 1/2] chore(pt): Change the type of `do_message_passing` from `int` to `bool` in `DeepPotPT` and `DeepSpinPT` classes Fix #4366. * Update the type of `do_message_passing` to `bool` in the `DeepPotPT` class and `init` method in `source/api_cc/include/DeepPotPT.h` and `source/api_cc/src/DeepPotPT.cc` * Update the type of `do_message_passing` to `bool` in the `DeepSpinPT` class and `init` method in `source/api_cc/include/DeepSpinPT.h` and `source/api_cc/src/DeepSpinPT.cc` --- source/api_cc/include/DeepPotPT.h | 3 +-- source/api_cc/include/DeepSpinPT.h | 3 +-- source/api_cc/src/DeepPotPT.cc | 5 ++--- source/api_cc/src/DeepSpinPT.cc | 5 ++--- 4 files changed, 6 insertions(+), 10 deletions(-) diff --git a/source/api_cc/include/DeepPotPT.h b/source/api_cc/include/DeepPotPT.h index 8f69168b5a..b8d59d790a 100644 --- a/source/api_cc/include/DeepPotPT.h +++ b/source/api_cc/include/DeepPotPT.h @@ -1,4 +1,3 @@ -// SPDX-License-Identifier: LGPL-3.0-or-later #pragma once #include @@ -335,7 +334,7 @@ class DeepPotPT : public DeepPotBackend { NeighborListData nlist_data; int max_num_neighbors; int gpu_id; - int do_message_passing; // 1:dpa2 model 0:others + bool do_message_passing; // 1:dpa2 model 0:others bool gpu_enabled; at::Tensor firstneigh_tensor; c10::optional mapping_tensor; diff --git a/source/api_cc/include/DeepSpinPT.h b/source/api_cc/include/DeepSpinPT.h index 643557eb07..462bc783d7 100644 --- a/source/api_cc/include/DeepSpinPT.h +++ b/source/api_cc/include/DeepSpinPT.h @@ -1,4 +1,3 @@ -// SPDX-License-Identifier: LGPL-3.0-or-later #pragma once #include @@ -257,7 +256,7 @@ class DeepSpinPT : public DeepSpinBackend { NeighborListData nlist_data; int max_num_neighbors; int gpu_id; - int do_message_passing; // 1:dpa2 model 0:others + bool do_message_passing; // 1:dpa2 model 0:others bool gpu_enabled; at::Tensor firstneigh_tensor; c10::optional mapping_tensor; diff --git a/source/api_cc/src/DeepPotPT.cc b/source/api_cc/src/DeepPotPT.cc index ce104b0f8e..79494f7ed6 100644 --- a/source/api_cc/src/DeepPotPT.cc +++ b/source/api_cc/src/DeepPotPT.cc @@ -1,4 +1,3 @@ -// SPDX-License-Identifier: LGPL-3.0-or-later #ifdef BUILD_PYTORCH #include "DeepPotPT.h" @@ -171,7 +170,7 @@ void DeepPotPT::compute(ENERGYVTYPE& ener, nlist_data.copy_from_nlist(lmp_list); nlist_data.shuffle_exclude_empty(fwd_map); nlist_data.padding(); - if (do_message_passing == 1) { + if (do_message_passing) { int nswap = lmp_list.nswap; torch::Tensor sendproc_tensor = torch::from_blob(lmp_list.sendproc, {nswap}, int32_option); @@ -234,7 +233,7 @@ void DeepPotPT::compute(ENERGYVTYPE& ener, .to(device); } c10::Dict outputs = - (do_message_passing == 1) + (do_message_passing) ? module .run_method("forward_lower", coord_wrapped_Tensor, atype_Tensor, firstneigh_tensor, mapping_tensor, fparam_tensor, diff --git a/source/api_cc/src/DeepSpinPT.cc b/source/api_cc/src/DeepSpinPT.cc index 3ae0eb3bb7..1b28274d8b 100644 --- a/source/api_cc/src/DeepSpinPT.cc +++ b/source/api_cc/src/DeepSpinPT.cc @@ -1,4 +1,3 @@ -// SPDX-License-Identifier: LGPL-3.0-or-later #ifdef BUILD_PYTORCH #include "DeepSpinPT.h" @@ -179,7 +178,7 @@ void DeepSpinPT::compute(ENERGYVTYPE& ener, nlist_data.copy_from_nlist(lmp_list); nlist_data.shuffle_exclude_empty(fwd_map); nlist_data.padding(); - if (do_message_passing == 1) { + if (do_message_passing) { int nswap = lmp_list.nswap; torch::Tensor sendproc_tensor = torch::from_blob(lmp_list.sendproc, {nswap}, int32_option); @@ -234,7 +233,7 @@ void DeepSpinPT::compute(ENERGYVTYPE& ener, .to(device); } c10::Dict outputs = - (do_message_passing == 1) + (do_message_passing) ? module .run_method("forward_lower", coord_wrapped_Tensor, atype_Tensor, spin_wrapped_Tensor, firstneigh_tensor, From 9e4eeb87a91f2cb3c483f790aceaccd07c1befb8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 20 Nov 2024 20:24:43 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- source/api_cc/include/DeepPotPT.h | 1 + source/api_cc/include/DeepSpinPT.h | 1 + source/api_cc/src/DeepPotPT.cc | 1 + source/api_cc/src/DeepSpinPT.cc | 1 + 4 files changed, 4 insertions(+) diff --git a/source/api_cc/include/DeepPotPT.h b/source/api_cc/include/DeepPotPT.h index b8d59d790a..207a13286c 100644 --- a/source/api_cc/include/DeepPotPT.h +++ b/source/api_cc/include/DeepPotPT.h @@ -1,3 +1,4 @@ +// SPDX-License-Identifier: LGPL-3.0-or-later #pragma once #include diff --git a/source/api_cc/include/DeepSpinPT.h b/source/api_cc/include/DeepSpinPT.h index 462bc783d7..be4c85d898 100644 --- a/source/api_cc/include/DeepSpinPT.h +++ b/source/api_cc/include/DeepSpinPT.h @@ -1,3 +1,4 @@ +// SPDX-License-Identifier: LGPL-3.0-or-later #pragma once #include diff --git a/source/api_cc/src/DeepPotPT.cc b/source/api_cc/src/DeepPotPT.cc index 79494f7ed6..7e5d391b1f 100644 --- a/source/api_cc/src/DeepPotPT.cc +++ b/source/api_cc/src/DeepPotPT.cc @@ -1,3 +1,4 @@ +// SPDX-License-Identifier: LGPL-3.0-or-later #ifdef BUILD_PYTORCH #include "DeepPotPT.h" diff --git a/source/api_cc/src/DeepSpinPT.cc b/source/api_cc/src/DeepSpinPT.cc index 1b28274d8b..c72cb34b15 100644 --- a/source/api_cc/src/DeepSpinPT.cc +++ b/source/api_cc/src/DeepSpinPT.cc @@ -1,3 +1,4 @@ +// SPDX-License-Identifier: LGPL-3.0-or-later #ifdef BUILD_PYTORCH #include "DeepSpinPT.h"