Skip to content

Commit

Permalink
- enable building stubs in both in-tree and out-of-tree builds (rdkit…
Browse files Browse the repository at this point in the history
…#6980)

with cmake --build . --target stubs (also make stubs on *NIX)
- improved the patching script to do a better assignment of
  overloaded constructor parameters, which results in a number
  of docstring fixes

Co-authored-by: ptosco <[email protected]>
  • Loading branch information
ptosco and ptosco authored Dec 13, 2023
1 parent c7c9ad3 commit c43f6d3
Show file tree
Hide file tree
Showing 15 changed files with 136 additions and 59 deletions.
1 change: 0 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ add_library(rdkit_base INTERFACE)

option(RDK_BUILD_SWIG_WRAPPERS "build the SWIG wrappers" OFF )
option(RDK_BUILD_PYTHON_WRAPPERS "build the standard python wrappers" ON )
option(RDK_BUILD_PYTHON_STUBS "build the python stubs" OFF )
option(RDK_BUILD_COMPRESSED_SUPPLIERS "build in support for compressed MolSuppliers" OFF )
option(RDK_BUILD_INCHI_SUPPORT "build the rdkit inchi wrapper" OFF )
option(RDK_BUILD_AVALON_SUPPORT "install support for the avalon toolkit. Use the variable AVALONTOOLS_DIR to set the location of the source." OFF )
Expand Down
2 changes: 1 addition & 1 deletion Code/DataStructs/ExplicitBitVect.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class RDKIT_DATASTRUCTS_EXPORT ExplicitBitVect : public BitVect {
ExplicitBitVect(unsigned int size, bool bitsSet);
ExplicitBitVect(const ExplicitBitVect &other);
//! construct from a string pickle
ExplicitBitVect(const std::string &);
ExplicitBitVect(const std::string &pkl);
//! construct from a text pickle
ExplicitBitVect(const char *, const unsigned int);
//! construct directly from a dynamic_bitset pointer
Expand Down
2 changes: 1 addition & 1 deletion Code/DataStructs/SparseBitVect.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class RDKIT_DATASTRUCTS_EXPORT SparseBitVect : public BitVect {
std::copy(bv->begin(), bv->end(), std::inserter(*dp_bits, dp_bits->end()));
}
//! construct from a string pickle
SparseBitVect(const std::string &);
SparseBitVect(const std::string &pkl);
//! construct from a text pickle
SparseBitVect(const char *data, const unsigned int dataLen);

Expand Down
2 changes: 1 addition & 1 deletion Code/DataStructs/Wrap/DiscreteValueVect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ struct discreteValVec_wrapper {
python::class_<DiscreteValueVect>(
"DiscreteValueVect", disValVectDoc.c_str(),
python::init<DiscreteValueVect::DiscreteValueType, unsigned int>(
python::args("self", "pkl", "len"), "Constructor"))
python::args("self", "valType", "length"), "Constructor"))
.def(python::init<std::string>(python::args("self", "pkl")))
.def("__len__", &DiscreteValueVect::getLength, python::args("self"),
"Get the number of entries in the vector")
Expand Down
2 changes: 1 addition & 1 deletion Code/DataStructs/Wrap/wrap_ExplicitBV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ struct EBV_wrapper {
python::class_<ExplicitBitVect, boost::shared_ptr<ExplicitBitVect>>(
"ExplicitBitVect", ebvClassDoc.c_str(),
python::init<unsigned int>(python::args("self", "size")))
.def(python::init<std::string>(python::args("self", "size")))
.def(python::init<std::string>(python::args("self", "pkl")))
.def(python::init<unsigned int, bool>(
python::args("self", "size", "bitsSet")))
.def("SetBit", (bool(EBV::*)(unsigned int)) & EBV::setBit,
Expand Down
2 changes: 1 addition & 1 deletion Code/DataStructs/Wrap/wrap_SparseBV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ struct SBV_wrapper {
python::class_<SparseBitVect, boost::shared_ptr<SparseBitVect>>(
"SparseBitVect", sbvClassDoc.c_str(),
python::init<unsigned int>(python::args("self", "size")))
.def(python::init<std::string>(python::args("self", "size")))
.def(python::init<std::string>(python::args("self", "pkl")))
.def("SetBit", (bool(SBV::*)(unsigned int)) & SBV::setBit,
python::args("self", "which"),
"Turns on a particular bit. Returns the original state of the "
Expand Down
2 changes: 1 addition & 1 deletion Code/GraphMol/ChemReactions/Wrap/rdChemReactions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -587,7 +587,7 @@ Sample Usage:
python::init<>(python::args("self"), "Constructor, takes no arguments"))
.def(python::init<const std::string &>(python::args("self", "binStr")))
.def(python::init<const RDKit::ChemicalReaction &>(
python::args("self", "binStr")))
python::args("self", "other")))
.def("GetNumReactantTemplates",
&RDKit::ChemicalReaction::getNumReactantTemplates,
python::args("self"),
Expand Down
4 changes: 2 additions & 2 deletions Code/GraphMol/FilterCatalog/Wrap/FilterCatalog.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ struct filtercat_wrapper {
python::bases<FilterMatcherBase>>(
"SmartsMatcher", SmartsMatcherDoc,
python::init<const std::string &>(python::args("self", "name")))
.def(python::init<const ROMol &>(python::args("self", "name"),
.def(python::init<const ROMol &>(python::args("self", "rhs"),
"Construct from a molecule"))
.def(python::init<const std::string &, const ROMol &, unsigned int,
unsigned int>(
Expand Down Expand Up @@ -532,7 +532,7 @@ struct filtercat_wrapper {
python::init<>(python::args("self")))
.def(python::init<const std::string &>(python::args("self", "binStr")))
.def(python::init<const FilterCatalogParams &>(
python::args("self", "catalogs")))
python::args("self", "params")))
.def(python::init<FilterCatalogParams::FilterCatalogs>(
python::args("self", "catalogs")))
.def("Serialize", &FilterCatalog_Serialize, python::args("self"))
Expand Down
4 changes: 2 additions & 2 deletions Code/GraphMol/Fingerprints/Wrap/MHFPWrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,8 @@ BOOST_PYTHON_FUNCTION_OVERLOADS(EncodeSECFPMolsBulkOverloads,

BOOST_PYTHON_MODULE(rdMHFPFingerprint) {
python::class_<MHFPEncoder>(
"MHFPEncoder",
python::init<python::optional<unsigned int, unsigned int>>())
"MHFPEncoder", python::init<python::optional<unsigned int, unsigned int>>(
python::args("self", "n_permutations", "seed")))
.def("FromStringArray", FromStringArray, python::args("self", "vec"),
"Creates a MHFP vector from a list of arbitrary strings.")
.def("FromArray", FromArray, python::args("self", "vec"),
Expand Down
5 changes: 3 additions & 2 deletions Code/GraphMol/SubstructLibrary/Wrap/SubstructLibraryWrap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -636,7 +636,8 @@ struct substructlibrary_wrapper {
" - idx: which molecule to return\n\n"
" - sanitize: if sanitize is False, return the internal "
"molecule state [default True]\n\n"
" NOTE: molecule indices start at 0\n");
" NOTE: molecule indices start at 0\n")
.def("__len__", &MolHolderBase::size, python::args("self"));

python::class_<MolHolder, boost::shared_ptr<MolHolder>,
python::bases<MolHolderBase>>(
Expand Down Expand Up @@ -753,7 +754,7 @@ struct substructlibrary_wrapper {
python::args("self", "molecules", "fingerprints")))
.def(python::init<boost::shared_ptr<MolHolderBase>,
boost::shared_ptr<KeyHolderBase>>(
python::args("self", "molecules", "fingerprints")))
python::args("self", "molecules", "keys")))
.def(python::init<boost::shared_ptr<MolHolderBase>,
boost::shared_ptr<FPHolderBase>,
boost::shared_ptr<KeyHolderBase>>(
Expand Down
7 changes: 4 additions & 3 deletions Code/GraphMol/Wrap/Atom.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,10 +140,11 @@ Note that, though it is possible to create one, having an Atom on its own\n\
(i.e not associated with a molecule) is not particularly useful.\n";
struct atom_wrapper {
static void wrap() {
python::class_<Atom>("Atom", atomClassDoc.c_str(),
python::init<std::string>(python::args("self", "num")))
python::class_<Atom>(
"Atom", atomClassDoc.c_str(),
python::init<std::string>(python::args("self", "what")))

.def(python::init<const Atom &>(python::args("self", "num")))
.def(python::init<const Atom &>(python::args("self", "other")))
.def(python::init<unsigned int>(
python::args("self", "num"),
"Constructor, takes the atomic number"))
Expand Down
2 changes: 1 addition & 1 deletion Code/GraphMol/Wrap/Conformer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ struct conformer_wrapper {
.def(python::init<unsigned int>(
python::args("self", "numAtoms"),
"Constructor with the number of atoms specified"))
.def(python::init<const Conformer &>(python::args("self", "numAtoms")))
.def(python::init<const Conformer &>(python::args("self", "other")))

.def("GetNumAtoms", &Conformer::getNumAtoms, python::args("self"),
"Get the number of atoms in the conformer\n")
Expand Down
2 changes: 1 addition & 1 deletion Scripts/gen_rdkit_stubs/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def parse_args():
default_n_cpus = max(1, multiprocessing.cpu_count() - 2)
default_output_dirs = [os.getcwd()]
parser = argparse.ArgumentParser()
parser.add_argument("--concurrency",
parser.add_argument("--concurrency", type=int,
help=f"max number of CPUs to be used (defaults to {default_n_cpus})",
default=default_n_cpus)
parser.add_argument("--verbose",
Expand Down
99 changes: 82 additions & 17 deletions Scripts/patch_rdkit_docstrings/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import sys
import os
import re
import itertools
import glob
import json
import importlib
Expand All @@ -23,6 +24,7 @@
from threading import Thread
from pathlib import Path


RDKIT_MODULE_NAME = "rdkit"
CLANG_CPP_EXE = os.environ.get("CLANG_CPP_EXE", "clang++")
CLANG_FORMAT_EXE = os.environ.get("CLANG_FORMAT_EXE", "clang-format")
Expand Down Expand Up @@ -129,7 +131,8 @@ class CppFile(DictLike):

QUOTED_FIELD_REGEX = re.compile(r"\"([^\"]*)\"")
EXTRACT_BASE_CLASS_NAME_REGEX = re.compile(r"\s*(\S+)\s*<[^>]+>\s*$")
EXTRACT_INIT_ARGS = re.compile(r"^<(.*)>$")
EXTRACT_INIT_ARGS = re.compile(r"^<(.*)\s>+\s$")
IS_TEMPLATE_TYPE = re.compile(r"^T\d*$")
SELF_LITERAL = "self"

def __init__(self, cpp_path=None):
Expand Down Expand Up @@ -618,7 +621,22 @@ def find_cpp_class_r(self, cursor, cpp_class_name, func_name):
return res

@staticmethod
def num_matching_parameters(expected_params, params):
def have_param(param_list, param):
    """Consume one occurrence of param from param_list, if present.

    The first element of param_list that compares equal to param is
    removed in place, so a subsequent call with the same param only
    succeeds if param_list held it more than once.

    Args:
        param_list (list[str]): list of parameters (modified in place)
        param (str): parameter to look for

    Returns:
        bool: True if param was found (and removed), False otherwise
    """
    try:
        # list.remove drops the first matching element, exactly like
        # pop(index(param)), and raises ValueError when there is none
        param_list.remove(param)
    except ValueError:
        return False
    return True

def num_matching_parameters(self, expected_params, params):
"""Find the number of matching params between params
(list of individual parameter typenames) and expected_params
(concatenated string of expected parameter typenames)
Expand All @@ -627,10 +645,13 @@ def num_matching_parameters(expected_params, params):
params (list[str]): list of individual parameter typenames
Returns:
int: number of matching params
tuple[int, int]: number of matching params, number of non-matching params
"""
expected_params_concat = "".join(expected_params)
return [p in expected_params_concat for p in params].count(True)
expected_params_tok = [p.split("::")[-1] for p in expected_params.split()]
params_tok = [p.split("::")[-1] for p in " ".join(params).split()]
num_matched_params = [self.have_param(expected_params_tok, p) for p in params_tok].count(True)
num_non_matched_params = len(params_tok) - num_matched_params
return num_matched_params, -num_non_matched_params

def find_cpp_func_params(self, cursor, is_staticmethod, cpp_class_name, func_name,
expected_cpp_params, expected_param_count):
Expand All @@ -656,10 +677,23 @@ def find_cpp_func_params(self, cursor, is_staticmethod, cpp_class_name, func_nam
list[str]: list of parameter names
"""
self.params = None
for cmp_func in (int.__eq__, int.__gt__):
self.find_cpp_func_params_r(cursor, cpp_class_name, func_name,
expected_cpp_params, expected_param_count, cmp_func)
assigned_overloads = None
if cpp_class_name == func_name:
key = f"{cpp_class_name}::{cpp_class_name}"
assigned_overloads = self.assigned_overloads.get(key, [])
if not assigned_overloads:
self.assigned_overloads[key] = assigned_overloads
self.assigned_overloads_for_func = assigned_overloads
for accept_params_no_type in (False, True):
self.accept_params_no_type = accept_params_no_type
for cmp_func in (int.__eq__, int.__gt__):
self.find_cpp_func_params_r(cursor, cpp_class_name, func_name,
expected_cpp_params, expected_param_count, cmp_func)
if self.params is not None:
break
if self.params is not None:
if assigned_overloads is not None and not self.has_template_type(self.params):
assigned_overloads.append(self.get_params_hash(self.params))
break
if self.params is None:
params = [f"arg{i + 1}" for i in range(expected_param_count)]
Expand All @@ -668,6 +702,31 @@ def find_cpp_func_params(self, cursor, is_staticmethod, cpp_class_name, func_nam
return params
return [p for p, _ in self.params]

def has_template_type(self, params):
    """Tell whether at least one parameter has a template type.

    A type counts as a template type when it matches
    self.IS_TEMPLATE_TYPE.

    Args:
        params (list[tuple[str, str]]): list of (name, type) tuples

    Returns:
        bool: True if any parameter type matches IS_TEMPLATE_TYPE,
        False otherwise (including when params is empty)
    """
    for _, type_name in params:
        if self.IS_TEMPLATE_TYPE.match(type_name):
            return True
    return False

@staticmethod
def get_params_hash(params):
    """Build an order-independent, hashable key for a parameter list.

    Args:
        params (list[tuple[str, str]]): list of function parameters
        as (parameter name, parameter type) tuples

    Returns:
        tuple: the parameters sorted and packed into a tuple, so that
        the result can be compared against, or stored with, other keys
    """
    ordered = list(params)
    ordered.sort()
    return tuple(ordered)

def find_cpp_func_params_r(self, cursor, cpp_class_name, func_name,
expected_cpp_params, expected_param_count, cmp_func):
"""Find parameter names of a C++ method (recursive).
Expand Down Expand Up @@ -695,16 +754,21 @@ def find_cpp_func_params_r(self, cursor, cpp_class_name, func_name,
accepted_kinds.append(CursorKind.CONSTRUCTOR)
for child in cursor.get_children():
if child.kind in accepted_kinds and child.spelling.split("<")[0] == func_name:
params = [(child2.spelling, "".join(child3.spelling for child3 in child2.get_children() if child3.kind == CursorKind.TYPE_REF))
for child2 in child.get_children() if child2.kind == CursorKind.PARM_DECL]
params = [(child2.spelling, " ".join(child3.spelling for child3 in child2.get_children()
if child3.kind in (CursorKind.TEMPLATE_REF, CursorKind.TYPE_REF)))
for child2 in child.get_children() if child2.kind == CursorKind.PARM_DECL]
# certain C++ headers have only the type declaration but no variable name,
# in that case we replace "" with a dummy parameter name since python::args("")
# is not acceptable
params = [(p or f"arg{i + 1}", t) for i, (p, t) in enumerate(params)]
params_hash = self.get_params_hash(params)
if self.assigned_overloads_for_func is not None and params_hash in self.assigned_overloads_for_func:
continue
if ((expected_param_count == -1 or cmp_func(len(params), expected_param_count))
and (not expected_cpp_params or self.params is None or
self.num_matching_parameters(expected_cpp_params, [t for _, t in params])
> self.num_matching_parameters(expected_cpp_params, [t for _, t in self.params]))):
and (not expected_cpp_params or (self.accept_params_no_type and self.params is None)
or (self.params is not None and
self.num_matching_parameters(expected_cpp_params, [t for _, t in params])
> self.num_matching_parameters(expected_cpp_params, [t for _, t in self.params])))):
if expected_param_count != -1:
params = params[:expected_param_count]
self.params = params
Expand Down Expand Up @@ -892,7 +956,7 @@ def find_no_arg(self, is_init, tokens, is_staticmethod, cpp_func_name, expected_
open_bracket_count = t.spelling.count("<")
closed_bracket_count = t.spelling.count(">")
if open_bracket_count or bracket_count:
init_args += t.spelling
init_args += t.spelling + " "
bracket_count += (open_bracket_count - closed_bracket_count)
if bracket_count == 0:
if init_args:
Expand All @@ -901,7 +965,7 @@ def find_no_arg(self, is_init, tokens, is_staticmethod, cpp_func_name, expected_
init_args = ""
is_init = False
else:
init_args = m.group(1)
init_args = m.group(1).replace("<", "").strip()
if init_args:
cpp_func_name = f"{class_info.cpp_class_name}::{class_info.cpp_class_name}"
expected_param_count = 1 + init_args.count(",")
Expand Down Expand Up @@ -1115,6 +1179,7 @@ def parse_ast(self, arg1_func_byclass_dict):
that need fixing. Also free functions are included under class name
FixSignatures.NO_CLASS_KEY
"""
self.assigned_overloads = {}
try:
translation_unit = TranslationUnit.from_ast_file(self.ast_path)
out_path = self.cpp_path_noext + ".out"
Expand All @@ -1124,7 +1189,7 @@ def parse_ast(self, arg1_func_byclass_dict):
with open(log_path, "w") as hnd:
pass
class_info_by_class_hash = self.find_nodes(translation_unit.cursor)
class_method_node_hashes = set(sum([[node.hash for node in class_info.parents] for class_info in class_info_by_class_hash.values()], []))
class_method_node_hashes = set(itertools.chain.from_iterable([node.hash for node in class_info.parents] for class_info in class_info_by_class_hash.values()))
arg1_non_class_func_names = arg1_func_byclass_dict.get(FixSignatures.NO_CLASS_KEY, None)
if arg1_non_class_func_names is not None:
non_class_defs = self.find_non_class_defs(translation_unit.cursor, class_method_node_hashes, arg1_non_class_func_names)
Expand Down Expand Up @@ -1483,7 +1548,7 @@ def generate_ast_files(self):
self.queue = queue.Queue()
cpp_class_files = list(self.cpp_file_dict.values())
# Uncomment the following to troubleshoot specific file(s)
# cpp_class_files = [f for f in cpp_class_files if os.path.basename(f.cpp_path).startswith("Validate")]
# cpp_class_files = [f for f in cpp_class_files if os.path.basename(f.cpp_path) == "Atom.cpp"]
n_files = len(cpp_class_files)
self.logger.debug(f"Number of files: {n_files}")
n_workers = min(self.concurrency, n_files)
Expand Down
Loading

0 comments on commit c43f6d3

Please sign in to comment.