Skip to content

Commit

Permalink
added a ConditionalScalarEstimator which subclasses ConditionalDiscre…
Browse files Browse the repository at this point in the history
…tePolicyEstimator
  • Loading branch information
josephdviviano committed Oct 4, 2024
1 parent 1d64b55 commit 348ee82
Showing 1 changed file with 32 additions and 2 deletions.
34 changes: 32 additions & 2 deletions src/gfn/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def __init__(
self,
module: nn.Module,
n_actions: int,
preprocessor: Preprocessor | None,
preprocessor: Preprocessor | None = None,
is_backward: bool = False,
):
"""Initializes a estimator for P_F for discrete environments.
Expand Down Expand Up @@ -226,7 +226,7 @@ def __init__(
conditioning_module: nn.Module,
final_module: nn.Module,
n_actions: int,
preprocessor: Preprocessor | None,
preprocessor: Preprocessor | None = None,
is_backward: bool = False,
):
"""Initializes a estimator for P_F for discrete environments.
Expand All @@ -252,3 +252,33 @@ def forward(
self._output_dim_is_checked = True

return out


class ConditionalScalarEstimator(ConditionalDiscretePolicyEstimator):
def __init__(
self,
state_module: nn.Module,
conditioning_module: nn.Module,
final_module: nn.Module,
preprocessor: Preprocessor | None = None,
is_backward: bool = False,
):
super().__init__(
state_module,
conditioning_module,
final_module,
n_actions=1,
preprocessor=preprocessor,
is_backward=is_backward,
)

def expected_output_dim(self) -> int:
return 1

def to_probability_distribution(
self,
states: States,
module_output: TT["batch_shape", "output_dim", float],
**policy_kwargs: Optional[dict],
) -> Distribution:
raise NotImplementedError

0 comments on commit 348ee82

Please sign in to comment.