diff --git a/src/gfn/gym/hypergrid.py b/src/gfn/gym/hypergrid.py
index bddc734c..4656f1c0 100644
--- a/src/gfn/gym/hypergrid.py
+++ b/src/gfn/gym/hypergrid.py
@@ -3,9 +3,11 @@
 """
 
 from typing import Literal, Tuple
-from math import gcd
+from math import gcd, log
 from functools import reduce
 from decimal import Decimal
+import itertools
+from time import time
 
 import torch
 from einops import rearrange
@@ -53,6 +55,8 @@ def __init__(
         reward_cos: bool = False,
         device_str: Literal["cpu", "cuda"] = "cpu",
         preprocessor_name: Literal["KHot", "OneHot", "Identity", "Enum"] = "KHot",
+        calculate_partition: bool = True,
+        calculate_all_states: bool = False,
     ):
         """HyperGrid environment from the GFlowNets paper.
         The states are represented as 1-d tensors of length `ndim` with values in
@@ -69,6 +73,15 @@ def __init__(
             reward_cos (bool, optional): Which version of the reward to use. Defaults to False.
             device_str (str, optional): "cpu" or "cuda". Defaults to "cpu".
             preprocessor_name (str, optional): "KHot" or "OneHot" or "Identity". Defaults to "KHot".
+            calculate_partition: If True, the true log partition function is
+                calculated (and cached) the first time `log_partition` is read,
+                which requires enumerating all states of the hypergrid. Might have
+                intractable time complexity for very large problems. If False,
+                `log_partition` returns None.
+            calculate_all_states: If True, all states are enumerated (and cached)
+                the first time `all_states` is read. Might have intractable space
+                complexity for very large problems. If False, `all_states` returns
+                None.
         """
         self.ndim = ndim
         self.height = height
@@ -76,6 +89,14 @@ def __init__(
         self.R1 = R1
         self.R2 = R2
         self.reward_cos = reward_cos
+        # Caches for the lazily computed quantities; each one is populated the
+        # first time its property is read. Nothing is pre-computed here because
+        # the remainder of the constructor has not run yet at this point.
+        self._all_states = None
+        self._log_partition = None
+        self._true_dist = None
+        self.calculate_partition = calculate_partition
+        self.calculate_all_states = calculate_all_states
 
         # This scale is used to stabilize calculations.
         self.scale_factor = smallest_multiplier_to_integers([R0, R1, R2])
@@ -152,7 +173,15 @@ def reward(self, final_states: DiscreteStates) -> TT["batch_shape", torch.float]
             - 0.5 \right\rvert \in (0.25, 0.5] \right)
             + 2 \prod_{d=1}^D \mathbf{1} \left( \left\lvert \frac{s^d}{H-1} - 0.5 \right\rvert \in (0.3, 0.4) \right)
         """
-        final_states_raw = final_states.tensor
+        # Raw tensors are accepted too, so that rewards can be computed for batches
+        # of grid coordinates that are not wrapped in a States object.
+        if isinstance(final_states, DiscreteStates):
+            final_states_raw = final_states.tensor
+        elif isinstance(final_states, torch.Tensor):
+            final_states_raw = final_states
+        else:
+            raise TypeError("final_states should be a States instance or Tensor.")
+
         R0, R1, R2 = (self.R0, self.R1, self.R2)
         ax = abs(final_states_raw / (self.height - 1) - 0.5)
         if not self.reward_cos:
@@ -191,46 +220,131 @@ def n_terminating_states(self) -> int:
 
     @property
     def true_dist_pmf(self) -> torch.Tensor:
-        all_states = self.all_states
-        assert torch.all(
-            self.get_states_indices(all_states)
-            == torch.arange(self.n_states, device=self.device)
-        )
-        true_dist = self.reward(all_states)
-        true_dist /= true_dist.sum()
-        return true_dist
+        """Returns the pmf over all states in the hypergrid.
+
+        The pmf is computed on first access and cached. It is only computed when
+        `calculate_all_states` is True; otherwise None is returned.
+        """
+        if self._true_dist is None and self.calculate_all_states:
+            all_states = self.all_states
+            assert torch.all(
+                self.get_states_indices(all_states)
+                == torch.arange(self.n_states, device=self.device)
+            )
+            true_dist = self.reward(all_states)
+            self._true_dist = true_dist / true_dist.sum()
+
+        return self._true_dist
 
     @property
-    def log_partition(self) -> float:
-        grid = self.build_grid()
-        rewards = self.reward(grid)
-        return rewards.sum().log().item()
-
-    def build_grid(self) -> DiscreteStates:
-        "Utility function to build the complete grid"
-        H = self.height
-        ndim = self.ndim
-        grid_shape = (H,) * ndim + (ndim,)  # (H, ..., H, ndim)
-        grid = torch.zeros(grid_shape, device=self.device)
-        for i in range(ndim):
-            grid_i = torch.linspace(start=0, end=H - 1, steps=H)
-            for _ in range(i):
-                grid_i = grid_i.unsqueeze(1)
-            grid[..., i] = grid_i
-
-        rearrange_string = " ".join([f"n{i}" for i in range(1, ndim + 1)])
-        rearrange_string += " ndim -> "
-        rearrange_string += " ".join([f"n{i}" for i in range(ndim, 0, -1)])
-        rearrange_string += " ndim"
-        grid = rearrange(grid, rearrange_string).long()
-        return self.States(grid)
+    def log_partition(self, batch_size: int = 20_000) -> float:
+        """Returns the log partition of the complete hypergrid.
+
+        The value is computed on first access, by summing the reward of every
+        state of the hypergrid, and then cached. It is only computed when
+        `calculate_partition` is True; otherwise None is returned.
+
+        Args:
+            batch_size: Number of states whose reward is computed at once. As
+                this is a property, attribute access always uses the default.
+        """
+        if self._log_partition is None and self.calculate_partition:
+            # Each of the `ndim` coordinates takes an integer value from 0 to
+            # height - 1 (inclusive), so the grid holds height ** ndim states.
+            max_height_idx = self.height - 1  # Handles 0 indexing.
+            n_expected = (max_height_idx + 1) ** self.ndim
+            n_found = 0
+            start_time = time()
+            total_reward = 0.0
+
+            for batch in self._generate_combinations_in_batches(
+                self.ndim,
+                max_height_idx,
+                batch_size,
+            ):
+                batch = torch.tensor(batch, dtype=torch.long, device=self.device)
+                rewards = self.reward(batch)  # Operates on the raw tensor.
+                total_reward += rewards.sum().item()  # Accumulate.
+                n_found += batch.shape[0]
+
+            assert n_expected == n_found, "failed to compute reward of all indices!"
+            end_time = time()
+            total_log_reward = log(total_reward)
+
+            print(
+                "log_partition = {}, calculated in {} minutes".format(
+                    total_log_reward,
+                    (end_time - start_time) / 60.0,
+                )
+            )
+
+            self._log_partition = total_log_reward
+
+        return self._log_partition
 
     @property
-    def all_states(self) -> DiscreteStates:
-        grid = self.build_grid()
-        flat_grid = rearrange(grid.tensor, "... ndim -> (...) ndim")
-        return self.States(flat_grid)
+    def all_states(self, batch_size: int = 20_000) -> DiscreteStates:
+        """Returns all hypergrid states.
+
+        The states are enumerated on first access, in the order produced by
+        `itertools.product` (the last coordinate varies fastest), and then
+        cached. They are only enumerated when `calculate_all_states` is True;
+        otherwise None is returned.
+
+        Args:
+            batch_size: Number of states enumerated per batch. As this is a
+                property, attribute access always uses the default.
+        """
+        if self._all_states is None and self.calculate_all_states:
+            start_time = time()
+            all_states = [
+                torch.tensor(batch, dtype=torch.long, device=self.device)
+                for batch in self._generate_combinations_in_batches(
+                    self.ndim,
+                    self.height - 1,  # Handles 0 indexing.
+                    batch_size,
+                )
+            ]
+            all_states = torch.cat(all_states, dim=0)
+            end_time = time()
+
+            print(
+                "calculated tensor of all states in {} minutes".format(
+                    (end_time - start_time) / 60.0,
+                )
+            )
+
+            self._all_states = self.States(all_states)
+
+        return self._all_states
 
     @property
     def terminating_states(self) -> DiscreteStates:
         return self.all_states
+
+    def _generate_combinations_in_batches(self, n, k, batch_size):
+        """Yields all length-`n` tuples of integers in [0, k], in batches.
+
+        The tuples are enumerated in a single lazy pass over
+        `itertools.product`, so this generator builds only one batch at a time.
+
+        Args:
+            n: Length of each tuple.
+            k: Largest value (inclusive) an element of a tuple can take.
+            batch_size: Maximum number of tuples per batch.
+
+        Yields:
+            Lists of at most `batch_size` tuples. Only the last one can be shorter.
+
+        Raises:
+            ValueError: If `batch_size` is not positive.
+        """
+        if batch_size < 1:
+            raise ValueError("batch_size must be a positive integer.")
+
+        combinations = itertools.product(range(k + 1), repeat=n)
+        while True:
+            batch = list(itertools.islice(combinations, batch_size))
+            if not batch:
+                return
+            yield batch