Trajectory Class Design #184

Open
ashdtu opened this issue Aug 23, 2024 · 0 comments
Collaborator

ashdtu commented Aug 23, 2024

Description

  1. Indexing Order: The State class currently supports batching along the timestep dimension alone (i.e. batch shape (Timesteps,)) and along both the timestep and trajectory dimensions (i.e. batch shape (Timesteps, Trajectory)). Flipping the indexing order from (Timesteps, Trajectory) to (Trajectory, Timesteps) may be easier for users to understand, since conceptually Trajectories is a container over the State class.
    The decision to keep the former might have been motivated by keeping easy access to the sink/special states using the mask, but we can check whether it is worth flipping to the latter. At first glance, it seems that a lot of small changes would be needed to accommodate it.

  2. The Trajectory class should be agnostic to the implementation of the State and Action classes. Currently, for example, the __repr__() method uses the following hardcoded implementation, which assumes that State always has a tensor attribute. This fails, for example, with the GraphStates implementation. It should be kept generic and derived from the __repr__() methods of the states and actions.

def __repr__(self) -> str:
        states = self.states.tensor.transpose(0, 1)
        assert states.ndim == 3
        trajectories_representation = ""
        for traj in states[:10]:
            one_traj_repr = []
            for step in traj:
                one_traj_repr.append(str(step.numpy()))
                if step.equal(self.env.s0 if self.is_backward else self.env.sf):
                    break
            trajectories_representation += "-> ".join(one_traj_repr) + "\n"
        return (
            f"Trajectories(n_trajectories={self.n_trajectories}, max_length={self.max_length}, First 10 trajectories:"
            + f"states=\n{trajectories_representation}"
            # + f"actions=\n{self.actions.tensor.squeeze().transpose(0, 1)[:10].numpy()}, "
            + f"when_is_done={self.when_is_done[:10].numpy()})"
        )
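One possible direction for the point above, as a minimal sketch rather than a proposed patch: the container only indexes into its states and delegates formatting to each state's own __repr__(), so it never touches a tensor attribute. ToyGraphState, ToyStates, and the lengths field are hypothetical stand-ins invented for this illustration (they are not the library's classes), and using per-trajectory lengths instead of comparing against env.sf / env.s0 is an assumption.

```python
from typing import Any, List, Tuple


class ToyGraphState:
    """Hypothetical non-tensor state: it only promises a __repr__."""

    def __init__(self, edges: List[Tuple[int, int]]):
        self.edges = edges

    def __repr__(self) -> str:
        return f"Graph(edges={self.edges})"


class ToyStates:
    """Hypothetical States container indexed as (timestep, trajectory)."""

    def __init__(self, grid: List[List[Any]]):
        self.grid = grid  # grid[t][n] is the state of trajectory n at step t

    def __getitem__(self, index: Tuple[int, int]) -> Any:
        t, n = index
        return self.grid[t][n]


class Trajectories:
    """Sketch of a container whose __repr__ makes no assumption about State internals."""

    def __init__(self, states: ToyStates, lengths: List[int]):
        self.states = states
        self.lengths = lengths  # assumed: number of valid states per trajectory
        self.n_trajectories = len(lengths)
        self.max_length = max(lengths, default=0)

    def __repr__(self) -> str:
        lines = []
        for n in range(min(self.n_trajectories, 10)):
            # Delegate formatting of every step to the state's own __repr__.
            steps = [repr(self.states[t, n]) for t in range(self.lengths[n])]
            lines.append(" -> ".join(steps))
        body = "\n".join(lines)
        return (
            f"Trajectories(n_trajectories={self.n_trajectories}, "
            f"max_length={self.max_length}, first 10 trajectories:\n{body})"
        )


g0 = ToyGraphState([])
g1 = ToyGraphState([(0, 1)])
g2 = ToyGraphState([(0, 1), (1, 2)])
# Two trajectories stored time-major; the second is padded at the last step.
trajs = Trajectories(ToyStates([[g0, g0], [g1, g1], [g2, g1]]), lengths=[3, 2])
print(trajs)
```

The same loop works unchanged for a tensor-backed state, as long as a single indexed state knows how to print itself.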
  3. We should check whether the entry point for accessing States can be made consistent. FlowMatching, for example, accesses the States objects directly, while other algorithms go through the Trajectory class to access the states. We should check whether we can provide a general template for users implementing their own loss functions, to keep things consistent. We could perhaps add a simple example showing how to address both cases using the Trajectory class itself.
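As a rough illustration of the last point, a loss template could normalize its input through a single helper, so that user-written losses look the same whether they receive states directly or a trajectory container. Everything here is hypothetical: the States and Trajectories dataclasses are toy stand-ins, and states_of and toy_loss are names made up for this sketch, not existing library functions.

```python
from dataclasses import dataclass
from typing import List, Union


@dataclass
class States:
    """Hypothetical stand-in for a batch of states (here: plain floats)."""

    values: List[float]


@dataclass
class Trajectories:
    """Hypothetical stand-in: a container that exposes its states."""

    states: States


def states_of(samples: Union[States, Trajectories]) -> States:
    """Single entry point: accept either container and return the States."""
    if isinstance(samples, Trajectories):
        return samples.states
    return samples


def toy_loss(samples: Union[States, Trajectories]) -> float:
    """Placeholder loss (mean of squares); only the access pattern matters."""
    values = states_of(samples).values
    return sum(v * v for v in values) / len(values)


batch = States([1.0, 2.0, 3.0])
# Both entry points reach the same states and therefore the same loss value.
print(toy_loss(batch), toy_loss(Trajectories(batch)))
```

Whether such a helper belongs on the Trajectory class itself or in a shared base class for losses is a design question this sketch leaves open.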