Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

conditional gfn #188

Merged
merged 33 commits into from
Oct 24, 2024
Merged
Changes from 1 commit
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
6e8dc4d
example of conditional GFN computation with TB only (for now)
josephdviviano Sep 25, 2024
39fb5ee
should be no change
josephdviviano Sep 25, 2024
2bc2263
Trajectories objects now have an optional .conditonal field which opt…
josephdviviano Sep 25, 2024
99afaf3
small changes to logz paramater handling, optionally incorporate cond…
josephdviviano Sep 25, 2024
e6d25a0
logZ is optionally computed using a conditioning vector
josephdviviano Sep 25, 2024
2c72bf9
NeuralNets now have input/output dims
josephdviviano Sep 25, 2024
580c455
added a ConditionalDiscretePolicyEstimator, and the forward of GFNMod…
josephdviviano Sep 25, 2024
a74872f
added conditioning to sampler, which will save the tensor as an attri…
josephdviviano Sep 25, 2024
056d935
black
josephdviviano Sep 25, 2024
96b725c
API changes adapted
josephdviviano Oct 1, 2024
5cd32a7
added conditioning to all gflownets
josephdviviano Oct 1, 2024
877c4a0
both trajectories and transitions can now store a conditioning tensor
josephdviviano Oct 1, 2024
279a313
input_dim setting is now private
josephdviviano Oct 1, 2024
65135c1
added exception handling for all estimator calls potentially involvin…
josephdviviano Oct 1, 2024
b4c418c
API change -- n vs. n_trajectories
josephdviviano Oct 1, 2024
738b062
change test_box target value
josephdviviano Oct 1, 2024
4434e5f
API changes
josephdviviano Oct 1, 2024
851e03e
hacky fix for problematic test (added TODO)
josephdviviano Oct 1, 2024
5152295
working examples for all 4 major losses
josephdviviano Oct 4, 2024
1d64b55
added conditioning indexing for correct broadcasting
josephdviviano Oct 4, 2024
348ee82
added a ConditionalScalarEstimator which subclasses ConditionalDiscre…
josephdviviano Oct 4, 2024
9120afe
added modified DB example
josephdviviano Oct 4, 2024
f59f4de
conditioning added to modified db example
josephdviviano Oct 4, 2024
c5ef7ea
black
josephdviviano Oct 4, 2024
d67dfd5
reorganized keyword arguments and fixed some type errors (not all)
josephdviviano Oct 9, 2024
d56a798
reorganized keyword arguments and fixed some type errors (not all)
josephdviviano Oct 9, 2024
db8844c
added typing and a ConditionalScalarEstimator
josephdviviano Oct 9, 2024
e03c03a
added typing
josephdviviano Oct 9, 2024
6b47e06
typing
josephdviviano Oct 9, 2024
988faf0
typing
josephdviviano Oct 9, 2024
f2bbce3
added kwargs
josephdviviano Oct 9, 2024
eb13a2d
renamed torso to trunk
josephdviviano Oct 24, 2024
fd3d9dc
renamed torso to trunk
josephdviviano Oct 24, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 23 additions & 5 deletions src/gfn/modules.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Optional
from typing import Optional, Any

import torch
import torch.nn as nn
Expand Down Expand Up @@ -109,7 +109,7 @@ def to_probability_distribution(
self,
states: States,
module_output: TT["batch_shape", "output_dim", float],
**policy_kwargs: Optional[dict],
**policy_kwargs: Any,
) -> Distribution:
"""Transform the output of the module into a probability distribution.

Expand Down Expand Up @@ -240,13 +240,20 @@ def __init__(
self.conditioning_module = conditioning_module
self.final_module = final_module

def forward(
self, states: States, conditioning: torch.tensor
def _forward_trunk(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is what you call trunk the same thing I called torso before ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah -- let me unify the naming

self, states: States, conditioning: torch.Tensor
) -> TT["batch_shape", "output_dim", float]:
state_out = self.module(self.preprocessor(states))
conditioning_out = self.conditioning_module(conditioning)
out = self.final_module(torch.cat((state_out, conditioning_out), -1))

return out

def forward(
self, states: States, conditioning: torch.tensor
) -> TT["batch_shape", "output_dim", float]:
out = self._forward_trunk(states, conditioning)

if not self._output_dim_is_checked:
self.check_output_dim(out)
self._output_dim_is_checked = True
Expand All @@ -272,13 +279,24 @@ def __init__(
is_backward=is_backward,
)

def forward(
self, states: States, conditioning: torch.tensor
) -> TT["batch_shape", "output_dim", float]:
out = self._forward_trunk(states, conditioning)

if not self._output_dim_is_checked:
self.check_output_dim(out)
self._output_dim_is_checked = True

return out

def expected_output_dim(self) -> int:
return 1

def to_probability_distribution(
self,
states: States,
module_output: TT["batch_shape", "output_dim", float],
**policy_kwargs: Optional[dict],
**policy_kwargs: Any,
) -> Distribution:
raise NotImplementedError