Rethinking sampling #147

Merged: 70 commits, Feb 16, 2024

The diff shown below is from 1 commit. Full commit list (all authored by josephdviviano):
95f4e01  improved documentation, and also formatting (Nov 16, 2023)
7b0688b  notes, and some changes RE: passing of policy_kwargs (renamed) (Nov 16, 2023)
0b812e5  renaming policy kwargs and changing usage (Nov 16, 2023)
6eeb0af  added policy_kwargs (Nov 16, 2023)
a391b28  renamed args (Nov 16, 2023)
d9fa884  changed logic and added TODOs for improved efficiency (Nov 16, 2023)
2a16b1b  sample trajectories bugfix (Nov 16, 2023)
740e16f  sample trajectories bugfix (Nov 16, 2023)
9d1ffd0  sync for debug (Nov 18, 2023)
963693c  adding line environment (Nov 18, 2023)
76ab487  TODO (Nov 21, 2023)
8999dd3  added logic for sampling off policy and some helper functions (Nov 21, 2023)
450ebf0  estimator_outputs can be passed around (Nov 21, 2023)
a8b637e  estimator outputs can be saved (Nov 21, 2023)
1acfcce  tweaks to demo (Nov 21, 2023)
e052c82  added back in default recomputing behaviour for pf in off policy mode. (Nov 21, 2023)
f897aab  bugfix (Nov 21, 2023)
67ea36e  simplified logprobs calc (Nov 21, 2023)
e88e1bb  v1 of the line tutorial (Nov 22, 2023)
b28dc95  estimator outputs can be passed around to avoid recalculation (Nov 22, 2023)
a72afe9  documentation & removal of reward clamping. (Nov 22, 2023)
46693af  added clone (just a test) (Nov 23, 2023)
119559d  black formatting, debugging code left in (commented), and log_reward_… (Nov 23, 2023)
e8ab999  log_reward_clip_min is now default off (Nov 23, 2023)
5e87e3f  log_reward_clip_min is now optional (Nov 23, 2023)
732fb0f  black (Nov 23, 2023)
5048f3c  added log reward clipping (Nov 23, 2023)
3b4e597  formatting (Nov 23, 2023)
5ef0c22  reorg of training loop (nothing is functionally different (Nov 23, 2023)
d69d258  isort (Nov 23, 2023)
e29a278  variable naming and a note (Nov 23, 2023)
ff45949  no longer using deepcopy. removed all log_reward clipping, which shou… (Nov 23, 2023)
2ca4ced  note RE typecasting (Nov 23, 2023)
a4c1786  improved efficiency of the init, and also added a clone method for st… (Nov 23, 2023)
f6edd53  note to self RE typecasting in the identity preprocessor -- not sure … (Nov 23, 2023)
6aab6e0  black / isort (Nov 23, 2023)
716ee7a  log reward clipping removed (Nov 23, 2023)
dfb929d  debugging sync (Nov 23, 2023)
5d62bee  Independent distributions (Nov 23, 2023)
1ceb53d  synced (debug still included) (Nov 23, 2023)
b67a6d2  Merge branch 'easier_environment_definition' of github.com:saleml/tor… (Nov 24, 2023)
c419dd3  removed debugging notes (confirmed that the issue is with my personal… (Nov 24, 2023)
a0f43c6  turned clipping on (Nov 24, 2023)
d6ad17f  estimator_outputs now live inside trajectories (Nov 24, 2023)
50b74d2  estimator outputs now live inside trajectories. if they aren't comput… (Nov 24, 2023)
80b4c29  full support for estimator_outputs (lots of padding logic added -- th… (Nov 27, 2023)
b5fdd32  TODO - I think we found a function that is never called (Nov 27, 2023)
ac42f22  estimator outputs now saved in a padded format. Also, some logic chan… (Nov 27, 2023)
05732a8  all flags are now off_policy for consistency (Nov 27, 2023)
25235f0  all flags are now off_policy for consistency (Nov 27, 2023)
0814725  all tests pass (Nov 27, 2023)
72ac58f  added off policy flag (Nov 27, 2023)
b0432c9  isort / black (Nov 27, 2023)
b987a39  updated scripts with new API and tweaked tests (with reproducibility) (Nov 27, 2023)
12eab45  tests passing (Nov 27, 2023)
cd35cb8  syncing notebook states (Nov 27, 2023)
44050a9  removed one order of magnitude precision required (Nov 27, 2023)
038b67b  merge issues resolved (Nov 27, 2023)
8388362  fixed tests (Nov 27, 2023)
93f2e5f  removed comments (Nov 27, 2023)
a6601d7  further loosened test tolerances (Nov 27, 2023)
71da6b5  changes requested for PR (Feb 13, 2024)
bafa1ad  moved training specific imports here to avoid circular deps (Feb 13, 2024)
0990d51  circular deps fix (Feb 13, 2024)
aa3c656  removing addiditons (additions commented out) (Feb 14, 2024)
e2ad9dd  formatting common (Feb 14, 2024)
be122ed  indexing reverted to old strategy with copius documentation (Feb 14, 2024)
cfc560c  formatting of tests (Feb 14, 2024)
2bebde2  isort / black (Feb 14, 2024)
e7c7453  isort (Feb 14, 2024)
Commit f897aabf56fe68060f37007e337ee012ec35cec5 ("bugfix")
josephdviviano committed Nov 21, 2023

54 changes: 32 additions & 22 deletions in tutorials/examples/train_line.py
@@ -23,10 +23,10 @@ class Line(Env):
 
     def __init__(
         self,
-        mus: list = [-2, 2],
-        variances: list = [0.5, 0.5],
+        mus: list,
+        variances: list,
+        init_value: float,
         n_sd: float = 4.5,
-        init_value: float = 0,
         n_steps_per_trajectory: int = 5,
         device_str: Literal["cpu", "cuda"] = "cpu",
     ):
@@ -37,15 +37,14 @@ def __init__(
         self.n_sd = n_sd
         self.n_steps_per_trajectory = n_steps_per_trajectory
         self.mixture = [
-            Normal(torch.tensor(m), torch.tensor(s)) for m, s in zip(mus, self.sigmas)
+            Normal(m, s) for m, s in zip(self.mus, self.sigmas)
         ]
 
         self.init_value = init_value  # Used in s0.
         self.lb = min(self.mus) - self.n_sd * max(self.sigmas)  # Convenience only.
         self.ub = max(self.mus) + self.n_sd * max(self.sigmas)  # Convenience only.
         assert self.lb < self.init_value < self.ub
 
-        # The state is [x_value, count]. x_value is initialized close to the lower bound.
         s0 = torch.tensor([self.init_value, 0.0], device=torch.device(device_str))
         super().__init__(s0=s0)  # sf is -inf.
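
For intuition, lb and ub above are convenience bounds that bracket the mixture at n_sd standard deviations beyond the outermost means; they only feed the assertion and the plotting range. A quick hedged sketch of that arithmetic outside the class, assuming the sigmas are the square roots of the variances passed to the constructor (values match the __main__ block further down):

    import math

    mus = [2.0, 5.0]
    sigmas = [math.sqrt(0.2), math.sqrt(0.2)]  # Assumption: sigma = sqrt(variance).
    n_sd = 4.5

    lb = min(mus) - n_sd * max(sigmas)  # Lower edge of the region of interest.
    ub = max(mus) + n_sd * max(sigmas)  # Upper edge of the region of interest.
    print(round(lb, 2), round(ub, 2))   # Roughly -0.01 and 7.01 for these settings.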

@@ -54,7 +53,7 @@ def make_States_class(self) -> type[States]:
 
         class LineStates(States):
             state_shape: ClassVar[Tuple[int, ...]] = (2,)
-            s0 = env.s0  # should be [init value, 0].
+            s0 = env.s0  # should be [init x value, 0].
             sf = env.sf  # should be [-inf, -inf].
 
         return LineStates
@@ -81,14 +80,21 @@ def maskless_backward_step(self, states: States, actions: Actions) -> TT["batch_
 
     def is_action_valid(self, states: States, actions: Actions, backward: bool = False) -> bool:
         # Can't take a backward step at the beginning of a trajectory.
-        non_terminal_s0_states = states[~actions.is_exit].is_initial_state
-        if torch.any(non_terminal_s0_states) and backward:
+        if torch.any(states[~actions.is_exit].is_initial_state) and backward:
             return False
 
         return True
 
-    def reward(self, final_states: States) -> TT["batch_shape", torch.float]:
-        return torch.exp(self.log_reward(final_states))
+    # def reward(self, final_states: States) -> TT["batch_shape", torch.float]:
+    #     s = final_states.tensor[..., 0]
+    #     if s.nelement() == 0:
+    #         return torch.zeros(final_states.batch_shape)
+
+    #     rewards = torch.empty(final_states.batch_shape)
+    #     for i, m in enumerate(self.mixture):
+    #         rewards = rewards + torch.exp(m.log_prob(s))
+
+    #     return rewards
 
     def log_reward(self, final_states: States) -> TT["batch_shape", torch.float]:
         s = final_states.tensor[..., 0]
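
The commented-out reward above sums exp(m.log_prob(s)) over the mixture components, while the retained log_reward path works in log space; the two agree because a sum of exponentials equals the exponential of a logsumexp. A small hedged check in plain PyTorch (the component parameters are illustrative and no torchgfn types are involved):

    import torch
    from torch.distributions import Normal

    s = torch.tensor([0.0, 2.0, 5.0])
    mixture = [Normal(torch.tensor(2.0), torch.tensor(0.2)),
               Normal(torch.tensor(5.0), torch.tensor(0.2))]

    # Sum of per-component densities, as in the commented-out reward().
    reward_direct = sum(torch.exp(m.log_prob(s)) for m in mixture)

    # The same quantity accumulated stably in log space, then exponentiated once.
    log_reward = torch.logsumexp(torch.stack([m.log_prob(s) for m in mixture]), dim=0)

    print(torch.allclose(reward_direct, torch.exp(log_reward)))  # True
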
@@ -115,9 +121,13 @@ def render(env, validation_samples=None):
         100,
     )
 
-    d = np.zeros(x.shape)
-    for mu, sigma in zip(env.mus, env.sigmas):
-        d += stats.norm.pdf(x, mu, sigma)
+    # Get the rewards from our environment.
+    r = env.States(
+        torch.tensor(
+            np.stack((x, torch.ones(len(x))), 1)  # Add dummy counter.
+        )
+    )
+    d = torch.exp(env.log_reward(r))  # Plots the reward, not the log reward.
 
     fig, ax1 = plt.subplots()

@@ -216,7 +226,7 @@ def __init__(
         hidden_dim: int,
         n_hidden_layers: int,
         policy_std_min: float = 0.1,
-        policy_std_max: float = 5,
+        policy_std_max: float = 1,
     ):
         """Instantiates the neural network for the forward policy."""
         assert policy_std_min > 0
@@ -256,7 +266,7 @@ def to_probability_distribution(
         self,
         states: States,
         module_output: TT["batch_shape", "output_dim", float],
-        scale_factor = 1,  # policy_kwarg.
+        scale_factor = 0,  # policy_kwarg.
     ) -> Distribution:
         # First, we verify that the batch shape of states is 1
         assert len(states.batch_shape) == 1
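
The scale_factor keyword above is the policy kwarg the sampler forwards when sampling off policy: 0 leaves the learned policy untouched, while a positive value widens its standard deviation for exploration. A hedged, self-contained sketch of that pattern with a plain Gaussian head (the function and its internals are illustrative, not the module's actual implementation):

    import torch
    from torch.distributions import Distribution, Normal

    def to_gaussian_policy(
        module_output: torch.Tensor,   # Shape [batch, 2]: predicted mean and raw scale.
        policy_std_min: float = 0.1,
        policy_std_max: float = 1.0,
        scale_factor: float = 0.0,     # 0 -> on-policy; > 0 -> broadened for exploration.
    ) -> Distribution:
        loc = module_output[..., 0]
        # Squash the raw scale prediction into [policy_std_min, policy_std_max].
        scale = policy_std_min + (policy_std_max - policy_std_min) * torch.sigmoid(module_output[..., 1])
        return Normal(loc, scale + scale_factor)

    dist = to_gaussian_policy(torch.randn(8, 2), scale_factor=0.5)
    actions = dist.sample()
    print(actions.shape, dist.log_prob(actions).shape)  # torch.Size([8]) torch.Size([8])
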
@@ -360,7 +370,7 @@ def train(
             env,
             n_samples=batch_size,
             sample_off_policy=True,
-            scale_factor=scale_schedule[iteration],
+            scale_factor=scale_schedule[iteration],  # Off policy kwargs.
         )
         training_samples = gflownet.to_training_samples(trajectories)
         optimizer.zero_grad()
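
Indexing scale_schedule by iteration implies a per-iteration exploration scale; a common choice is to anneal it from exploration_var_starting_val down to 0 so that late training is effectively on policy. A hedged sketch of one such schedule (the linear anneal is an assumption; the script may use a different decay):

    import torch

    n_iterations = 1000
    exploration_var_starting_val = 2.0

    # Linearly anneal the exploration scale to 0 (pure on-policy) by the final iteration.
    scale_schedule = torch.linspace(exploration_var_starting_val, 0.0, n_iterations)

    print(scale_schedule[0].item(), scale_schedule[n_iterations // 2].item(), scale_schedule[-1].item())
    # 2.0, ~1.0, 0.0 -- the value handed to the sampler as scale_factor at each step.
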
Expand Down Expand Up @@ -412,10 +422,10 @@ def train(

if __name__ == "__main__":

env = Line(mus=[-2, 2], variances=[0.5, 0.5], n_sd=4.5, init_value=0.5, n_steps_per_trajectory=10)
env = Line(mus=[2, 5], variances=[0.2, 0.2], init_value=0, n_sd=4.5, n_steps_per_trajectory=5)
# Forward and backward policy estimators. We pass the lower bound from the env here.
hid_dim = 128
n_hidden_layers = 2
hid_dim = 64
n_hidden_layers = 1
policy_std_min = 0.1
policy_std_max = 1
exploration_var_starting_val = 2
Expand Down Expand Up @@ -454,9 +464,9 @@ def train(
# Magic hyperparameters: lr_base=4e-2, n_trajectories=3e6, batch_size=2048
gflownet, jsd = train(
gflownet,
lr_base=1e-4,
n_trajectories=3e6,
batch_size=1024,
lr_base=1e-3,
n_trajectories=1e6,
batch_size=256,
exploration_var_starting_val=exploration_var_starting_val
) # I started training this with 1e-3 and then reduced it.
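
The trailing comment notes that the learning rate was started at 1e-3 and then reduced by hand; if that reduction were automated, a standard PyTorch scheduler would do it. A purely illustrative, hedged sketch (the milestones and gamma are assumptions, and the parameter below stands in for whatever train() actually optimizes):

    import torch

    params = [torch.nn.Parameter(torch.zeros(1))]  # Stand-in for the gflownet parameters.
    optimizer = torch.optim.Adam(params, lr=1e-3)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[2000, 4000], gamma=0.1)

    for _ in range(6000):
        optimizer.step()    # In the real loop this follows loss.backward().
        scheduler.step()    # Drops the lr to 1e-4 at iteration 2000 and 1e-5 at 4000.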
