diff --git a/src/gfn/modules.py b/src/gfn/modules.py
index d8f9e31c..e7dea7fd 100644
--- a/src/gfn/modules.py
+++ b/src/gfn/modules.py
@@ -148,7 +148,7 @@ def __init__(
         self,
         module: nn.Module,
         n_actions: int,
-        preprocessor: Preprocessor | None,
+        preprocessor: Preprocessor | None = None,
         is_backward: bool = False,
     ):
         """Initializes a estimator for P_F for discrete environments.
@@ -226,7 +226,7 @@ def __init__(
         conditioning_module: nn.Module,
         final_module: nn.Module,
         n_actions: int,
-        preprocessor: Preprocessor | None,
+        preprocessor: Preprocessor | None = None,
         is_backward: bool = False,
     ):
         """Initializes a estimator for P_F for discrete environments.
@@ -252,3 +252,57 @@ def forward(
             self._output_dim_is_checked = True
 
         return out
+
+
+class ConditionalScalarEstimator(ConditionalDiscretePolicyEstimator):
+    """Conditional estimator whose expected output dimension is 1.
+
+    Reuses the constructor of `ConditionalDiscretePolicyEstimator` with
+    `n_actions` fixed to 1. It does not define a probability distribution:
+    `to_probability_distribution` always raises `NotImplementedError`.
+    """
+
+    def __init__(
+        self,
+        state_module: nn.Module,
+        conditioning_module: nn.Module,
+        final_module: nn.Module,
+        preprocessor: Preprocessor | None = None,
+        is_backward: bool = False,
+    ):
+        """Initializes a conditional estimator with `n_actions` fixed to 1.
+
+        Args:
+            state_module: forwarded unchanged to the parent constructor.
+            conditioning_module: forwarded unchanged to the parent constructor.
+            final_module: forwarded unchanged to the parent constructor.
+            preprocessor: forwarded unchanged to the parent constructor.
+            is_backward: forwarded unchanged to the parent constructor.
+        """
+        super().__init__(
+            state_module,
+            conditioning_module,
+            final_module,
+            n_actions=1,  # the parent requires n_actions; fixed to 1 here.
+            preprocessor=preprocessor,
+            is_backward=is_backward,
+        )
+
+    def expected_output_dim(self) -> int:
+        """Returns the expected output dimension, which is always 1."""
+        return 1
+
+    def to_probability_distribution(
+        self,
+        states: States,
+        module_output: TT["batch_shape", "output_dim", float],
+        **policy_kwargs: Optional[dict],
+    ) -> Distribution:
+        """Not supported for this estimator.
+
+        Raises:
+            NotImplementedError: always.
+        """
+        raise NotImplementedError(
+            f"{type(self).__name__} does not define a probability distribution."
+        )