Commit: sync local
josephdviviano committed Jul 17, 2024
1 parent 9f38a17 commit 81f38cf
Showing 1 changed file with 136 additions and 37 deletions.

src/gfn/gym/hypergrid.py
@@ -3,9 +3,13 @@
"""

from typing import Literal, Tuple
-from math import gcd
+from math import gcd, log
from functools import reduce
from decimal import Decimal
+import itertools
+import multiprocessing
+from time import time

import torch
from einops import rearrange
@@ -18,6 +22,9 @@
from gfn.states import DiscreteStates


+# "fork" is needed for multiprocessing-torch compatibility. Note that it is
+# only available on POSIX systems, and that set_start_method raises a
+# RuntimeError if the start method has already been set.
+multiprocessing.set_start_method("fork")


def lcm(a, b):
"""Returns the lowest common multiple between a and b."""
return a * b // gcd(a, b)
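The helper `smallest_multiplier_to_integers`, used further down to build
`scale_factor`, is defined outside the lines shown in this diff. As a
hypothetical sketch only (not the repository's definition), and assuming an
extra `from fractions import Fraction` import, it could be built from the
`Decimal` and `reduce` imports above together with `lcm`:

def smallest_multiplier_to_integers(values):
    # Hypothetical sketch: the smallest positive integer m such that m * v
    # is integral for every v in values, assuming each v has a finite
    # decimal expansion (e.g. values 0.1, 0.5, 2.0 give m = 10).
    denominators = [Fraction(Decimal(str(v))).denominator for v in values]
    return reduce(lcm, denominators, 1)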
@@ -53,6 +60,8 @@ def __init__(
reward_cos: bool = False,
device_str: Literal["cpu", "cuda"] = "cpu",
preprocessor_name: Literal["KHot", "OneHot", "Identity", "Enum"] = "KHot",
+calculate_partition: bool = True,
+calculate_all_states: bool = False,
):
"""HyperGrid environment from the GFlowNets paper.
The states are represented as 1-d tensors of length `ndim` with values in
@@ -69,13 +78,29 @@ def __init__(
reward_cos (bool, optional): Which version of the reward to use. Defaults to False.
device_str (str, optional): "cpu" or "cuda". Defaults to "cpu".
preprocessor_name (str, optional): "KHot" or "OneHot" or "Identity". Defaults to "KHot".
+calculate_partition (bool, optional): If True, calculates the true log
+    partition function, which requires enumerating all states of the
+    hypergrid. This can be intractably slow for very large problems.
+    Defaults to True.
+calculate_all_states (bool, optional): If True, stores all states in the
+    internal attribute _all_states. This can require an intractable amount
+    of memory for very large problems. Defaults to False.
"""
self.ndim = ndim
self.height = height
self.R0 = R0
self.R1 = R1
self.R2 = R2
self.reward_cos = reward_cos
+self._all_states = None  # Populated on first request.
+self._log_partition = None  # Populated on first request.
+self._true_dist = None  # Populated on first request.
+self.calculate_partition = calculate_partition
+self.calculate_all_states = calculate_all_states

+# Pre-compute these values if requested.
+if self.calculate_all_states:
+    self.all_states()
+if self.calculate_partition:
+    self.log_partition()

# This scale is used to stabilize calculations.
self.scale_factor = smallest_multiplier_to_integers([R0, R1, R2])
@@ -152,7 +177,13 @@ def reward(self, final_states: DiscreteStates) -> TT["batch_shape", torch.float]
- 0.5 \right\rvert \in (0.25, 0.5] \right)
+ 2 \prod_{d=1}^D \mathbf{1} \left( \left\lvert \frac{s^d}{H-1} - 0.5 \right\rvert \in (0.3, 0.4) \right)
"""
-final_states_raw = final_states.tensor
+if isinstance(final_states, DiscreteStates):
+    final_states_raw = final_states.tensor
+elif isinstance(final_states, torch.Tensor):
+    final_states_raw = final_states
+else:
+    raise TypeError("final_states must be a DiscreteStates instance or a torch.Tensor.")

R0, R1, R2 = (self.R0, self.R1, self.R2)
ax = abs(final_states_raw / (self.height - 1) - 0.5)
if not self.reward_cos:
@@ -191,46 +222,114 @@ def n_terminating_states(self) -> int:

@property
def true_dist_pmf(self) -> torch.Tensor:
-all_states = self.all_states
-assert torch.all(
-    self.get_states_indices(all_states)
-    == torch.arange(self.n_states, device=self.device)
-)
-true_dist = self.reward(all_states)
-true_dist /= true_dist.sum()
-return true_dist
"""Returns the pmf over all states in the hypergrid."""
if not self._true_dist and self.calculate_all_states:
assert torch.all(
self.get_states_indices(self.all_states)
== torch.arange(self.n_states, device=self.device)
)
self._true_dist = self.reward(self.all_states)
self._true_dist /= self._true_dist.sum()

return self._true_dist

-@property
-def log_partition(self) -> float:
-    grid = self.build_grid()
-    rewards = self.reward(grid)
-    return rewards.sum().log().item()

-def build_grid(self) -> DiscreteStates:
-    "Utility function to build the complete grid."
-    H = self.height
-    ndim = self.ndim
-    grid_shape = (H,) * ndim + (ndim,)  # (H, ..., H, ndim)
-    grid = torch.zeros(grid_shape, device=self.device)
-    for i in range(ndim):
-        grid_i = torch.linspace(start=0, end=H - 1, steps=H)
-        for _ in range(i):
-            grid_i = grid_i.unsqueeze(1)
-        grid[..., i] = grid_i
-
-    rearrange_string = " ".join([f"n{i}" for i in range(1, ndim + 1)])
-    rearrange_string += " ndim -> "
-    rearrange_string += " ".join([f"n{i}" for i in range(ndim, 0, -1)])
-    rearrange_string += " ndim"
-    grid = rearrange(grid, rearrange_string).long()
-    return self.States(grid)
+def log_partition(self, batch_size: int = 20_000) -> float:
+    """Returns the log partition function of the complete hypergrid.
+
+    Args:
+        batch_size: Number of hypergrid states to evaluate per batch.
+    """
+    if self._log_partition is None and self.calculate_partition:
+        # The number of possible combinations (with repetition) of n numbers,
+        # where each number can be any integer from 0 to k (inclusive), is
+        # (k + 1) ** n. Here k is height - 1, since it is a Python index:
+        # for example, ndim=2 and height=8 give 8 ** 2 = 64 states.
+        max_height_idx = self.height - 1  # Handles 0-indexing.
+        n_expected = (max_height_idx + 1) ** self.ndim
+        n_found = 0
+        start_time = time()
+        total_reward = 0
+
+        for batch in self._generate_combinations_in_batches(
+            self.ndim,
+            max_height_idx,
+            batch_size,
+        ):
+            # reward operates on raw tensors here due to multiprocessing.
+            batch = torch.LongTensor(list(batch))
+            total_reward += self.reward(batch).sum().item()  # Accumulate.
+            n_found += batch.shape[0]
+
+        assert n_expected == n_found, "failed to compute the reward of all states!"
+        end_time = time()
+        total_log_reward = log(total_reward)
+
+        print(
+            "log_partition = {}, calculated in {} minutes".format(
+                total_log_reward,
+                (end_time - start_time) / 60.0,
+            )
+        )
+
+        self._log_partition = total_log_reward
+
+    return self._log_partition

-@property
-def all_states(self) -> DiscreteStates:
-    grid = self.build_grid()
-    flat_grid = rearrange(grid.tensor, "... ndim -> (...) ndim")
-    return self.States(flat_grid)
+def all_states(self, batch_size: int = 20_000) -> DiscreteStates:
+    """Returns a DiscreteStates instance containing all hypergrid states.
+
+    Args:
+        batch_size: Number of hypergrid states to generate per batch.
+    """
+    if self._all_states is None and self.calculate_all_states:
+        start_time = time()
+        all_states = []
+
+        for batch in self._generate_combinations_in_batches(
+            self.ndim,
+            self.height - 1,  # Handles 0-indexing.
+            batch_size,
+        ):
+            all_states.append(torch.LongTensor(list(batch)))
+
+        all_states = torch.cat(all_states, dim=0)
+        end_time = time()
+
+        print(
+            "calculated tensor of all states in {} minutes".format(
+                (end_time - start_time) / 60.0,
+            )
+        )
+
+        self._all_states = self.States(all_states)
+
+    return self._all_states

@property
def terminating_states(self) -> DiscreteStates:
return self.all_states()

+def _generate_combinations_chunk(self, numbers, n, start, end):
+    """Generates the slice [start, end) of the product of `numbers` repeated `n` times."""
+    # islice accesses a subset of the full iterator - each job does unique work.
+    return itertools.islice(itertools.product(numbers, repeat=n), start, end)
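For intuition, each chunk re-creates the full `itertools.product` iterator
and skips ahead to its own offset; a small standalone illustration of the
same slicing idea (illustrative values, no multiprocessing):

import itertools

# All 3 ** 2 = 9 states of an ndim=2, height=3 grid, in slices of 5.
for start in range(0, 9, 5):
    chunk = itertools.islice(itertools.product(range(3), repeat=2), start, start + 5)
    print(list(chunk))
# [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1)]
# [(1, 2), (2, 0), (2, 1), (2, 2)]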

+def _worker(self, task):
+    """Executes a single call to `_generate_combinations_chunk`."""
+    numbers, n, start, end = task
+    return self._generate_combinations_chunk(numbers, n, start, end)

+def _generate_combinations_in_batches(self, n, k, batch_size):
+    """Uses a multiprocessing Pool to collect slices of itertools.product in parallel."""
+    numbers = list(range(k + 1))
+
+    # Number of possible combinations (with repetition) of n numbers, where
+    # each number can be any integer from 0 to k (inclusive).
+    total_combinations = (k + 1) ** n
+    tasks = [
+        (numbers, n, i, min(i + batch_size, total_combinations))
+        for i in range(0, total_combinations, batch_size)
+    ]
+
+    with multiprocessing.Pool() as pool:
+        for result in pool.imap(self._worker, tasks):
+            yield result
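Taken together, a minimal usage sketch of the new interface (the constructor
arguments are illustrative, and the import path assumes the file shown above,
src/gfn/gym/hypergrid.py):

from gfn.gym.hypergrid import HyperGrid

env = HyperGrid(ndim=2, height=8, calculate_all_states=True, calculate_partition=True)
print(env.log_partition())  # Exact log-partition, accumulated over state batches.
states = env.all_states()   # DiscreteStates holding all 8 ** 2 = 64 grid states.
pmf = env.true_dist_pmf     # Reward-proportional pmf over those states.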
