Skip to content

Commit

Permalink
fix(jax): handle DPA-2 pbc/nopbc without mapping (#4363)
Browse files Browse the repository at this point in the history
In the C++ API, generate the mapping for the no PBC and throw the error
for PBC. Considering I forgot setting `atom_modify map yes` when testing
it, others may also forget.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
- Introduced a function to determine if the model supports message
passing, enhancing the model's interface.
- Added a private member variable to facilitate message passing
functionality.
- Implemented unit tests for the `DeepPot` class, validating its
functionality under various conditions.

- **Bug Fixes**
- Improved error handling for TensorFlow function retrieval, ensuring
more specific exceptions are thrown.
- Enhanced compatibility with earlier model versions by managing
exceptions related to the new message passing variable.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz authored Nov 19, 2024
1 parent 031c3ce commit f879b48
Show file tree
Hide file tree
Showing 4 changed files with 424 additions and 1 deletion.
6 changes: 6 additions & 0 deletions deepmd/jax/jax2tf/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,12 @@ def get_model_def_script():
)

tf_model.get_model_def_script = get_model_def_script

@tf.function
def has_message_passing() -> tf.Tensor:
return tf.constant(model.has_message_passing(), dtype=tf.bool)

tf_model.has_message_passing = has_message_passing
tf.saved_model.save(
tf_model,
model_file,
Expand Down
2 changes: 2 additions & 0 deletions source/api_cc/include/DeepPotJAX.h
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,8 @@ class DeepPotJAX : public DeepPotBackend {
std::vector<int64_t> sel;
// number of neighbors
int nnei;
// do message passing
bool do_message_passing;
// padding to nall
int padding_to_nall = 0;
// padding for nloc
Expand Down
24 changes: 23 additions & 1 deletion source/api_cc/src/DeepPotJAX.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,12 @@ inline TF_DataType get_data_tensor_type(const std::vector<int64_t>& data) {
return TF_INT64;
}

struct tf_function_not_found : public deepmd::deepmd_exception {
public:
tf_function_not_found() : deepmd_exception() {};
tf_function_not_found(const std::string& msg) : deepmd_exception(msg) {};
};

inline TFE_Op* get_func_op(TFE_Context* ctx,
const std::string func_name,
const std::vector<TF_Function*>& funcs,
Expand All @@ -72,7 +78,7 @@ inline TFE_Op* get_func_op(TFE_Context* ctx,
TF_Function* func = NULL;
find_function(func, funcs, func_name);
if (func == NULL) {
throw std::runtime_error("Function " + func_name + " not found");
throw tf_function_not_found("Function " + func_name + " not found");
}
const char* real_func_name = TF_FunctionName(func);
// execute the function
Expand Down Expand Up @@ -314,6 +320,13 @@ void deepmd::DeepPotJAX::init(const std::string& model,
ntypes = type_map_.size();
sel = get_vector<int64_t>(ctx, "get_sel", func_vector, device, status);
nnei = std::accumulate(sel.begin(), sel.end(), decltype(sel)::value_type(0));
try {
do_message_passing = get_scalar<bool>(ctx, "do_message_passing",
func_vector, device, status);
} catch (tf_function_not_found& e) {
// compatibile with models generated by v3.0.0rc0
do_message_passing = false;
}
inited = true;
}

Expand Down Expand Up @@ -584,6 +597,15 @@ void deepmd::DeepPotJAX::compute(std::vector<ENERGYTYPE>& ener,
for (size_t ii = 0; ii < nall_real; ii++) {
mapping[ii] = lmp_list.mapping[fwd_map[ii]];
}
} else if (nloc_real == nall_real) {
// no ghost atoms
for (size_t ii = 0; ii < nall_real; ii++) {
mapping[ii] = ii;
}
} else if (do_message_passing) {
throw deepmd::deepmd_exception(
"Mapping is required for a message passing model. If you are using "
"LAMMPS, set `atom_modify map yes`");
}
input_list[3] = add_input(op, mapping, mapping_shape, data_tensor[3], status);
// fparam
Expand Down
Loading

0 comments on commit f879b48

Please sign in to comment.