diff --git a/src/gfn/env.py b/src/gfn/env.py index c1c1085..7249f36 100644 --- a/src/gfn/env.py +++ b/src/gfn/env.py @@ -94,6 +94,8 @@ def states_from_batch_shape( Args: batch_shape: Tuple representing the shape of the batch of states. + random (optional): Initalize states randomly. + sink (optional): States initialized with s_f (the sink state). Returns: States: A batch of initial states.