Skip to content

Commit

Permalink
register scoring functions
Browse files Browse the repository at this point in the history
  • Loading branch information
albertbou92 committed Apr 23, 2024
1 parent 7cb8183 commit 802251c
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 11 deletions.
3 changes: 3 additions & 0 deletions acegen/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
__version__ = "1.0"

from models import custom_models
from scoring_functions import custom_scoring_functions, register_custom_scoring_function

from acegen.models import (
create_gpt2_actor,
create_gpt2_actor_critic,
Expand Down
18 changes: 9 additions & 9 deletions acegen/scoring_functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,33 +19,33 @@ def check_scoring_function(scoring_function):
f"scoring_function must be a callable, got {type(scoring_function)}"
)

# Check it accepts a single smiles and returns typing number of a tensor
if not isinstance(scoring_function("CCO"), (int, float, list, Tensor, ndarray)):
# Check it accepts a single smiles and returns a number, list, tensor or array
if not isinstance(scoring_function("CCO"), (float, list, Tensor, ndarray)):
raise ValueError(
f"scoring_function must return a number, got {type(scoring_function('CCO'))}"
f"scoring_function must return a float, list, array or tensor, got {type(scoring_function('CCO'))}"
)

# Check it accepts a single smiles and returns a list of number or a tensor
# Check it accepts multiple smiles and returns a list, a tensor or an array
scores = scoring_function(["CCO", "CCC"])
if not isinstance(scores, (list, Tensor, ndarray)):
raise ValueError(
f"scoring_function must return a list of number, got {type(scoring_function(['CCO', 'CCC']))}"
f"scoring_function must return a list, array or tensor, got {type(scores)}"
)

# If scores is a list, check that each element is a number
# If scores is a list, check that each element is a float
if isinstance(scores, list):
for score in scores:
if not isinstance(score, (int, float)):
if not isinstance(score, float):
raise ValueError(
f"scoring_function must return a list of number, got {type(scoring_function(['CCO', 'CCC']))}"
f"scoring_function must return a list of floats, got {type(score)}"
)


def register_custom_scoring_function(name, scoring_function):
"""Register a custom scoring function.
Example:
>>> from acegen import register_custom_scoring_function, custom_scoring_functions
>>> from acegen.scoring_functions import register_custom_scoring_function, custom_scoring_functions
>>> from my_module import my_scoring_function
>>> register_custom_scoring_function("my_scoring_function", my_scoring_function)
>>> custom_scoring_functions["my_scoring_function"]
Expand Down
1 change: 0 additions & 1 deletion scripts/ahc/ahc.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
LazyTensorStorage,
PrioritizedSampler,
TensorDictMaxValueWriter,
TensorDictPrioritizedReplayBuffer,
TensorDictReplayBuffer,
)
from torchrl.envs import InitTracker, TensorDictPrimer, TransformedEnv
Expand Down
11 changes: 10 additions & 1 deletion scripts/reinvent/reinvent.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
LazyTensorStorage,
PrioritizedSampler,
TensorDictMaxValueWriter,
TensorDictPrioritizedReplayBuffer,
TensorDictReplayBuffer,
)
from torchrl.envs import InitTracker, TensorDictPrimer, TransformedEnv
Expand All @@ -45,6 +44,16 @@
os.chdir("/tmp")


from acegen.scoring_functions import (
custom_scoring_functions,
register_custom_scoring_function,
)

my_scoring_function = lambda x: [float(1)] * len(x)
register_custom_scoring_function("my_scoring_function", my_scoring_function)
custom_scoring_functions["my_scoring_function"]


@hydra.main(
config_path=".",
config_name="config_denovo",
Expand Down

0 comments on commit 802251c

Please sign in to comment.