diff --git a/README.md b/README.md index 9b0a3afa8..17d9c557a 100644 --- a/README.md +++ b/README.md @@ -60,6 +60,28 @@ This may require a reservation on the Infiniband cluster. See the [Beaker documentation](https://beaker-docs.apps.allenai.org/distributed-training.html) for more information on distributed training. +## Generating text + +You can use the `generate()` method to produce text using beam search with a variety of options. + +For example: + +```python +# Prepare inputs. +# Note: we don't want the EOS token added to the end of the input, hence +# the `add_special_tokens=False`. +input_ids = tokenizer.encode("I'm a large language model, ", add_special_tokens=False) +# `model.generate()` expects a batch. +input_tensor = torch.tensor(input_ids).unsqueeze(0) + +# Run beam search. +outputs = model.generate(input_tensor, max_steps=3, beam_size=3) + +# The output token IDs are shape (batch_size, beam_size, max_steps) +best_generation = outputs.token_ids[0][0].tolist() +print(tokenizer.decode(best_generation)) +``` + ## Finding official runs We keep all of our runs in WandB under [the "ai2-llm" entity](https://wandb.ai/ai2-llm). @@ -67,4 +89,4 @@ We don't store model checkpoints in WandB. Those are in GCS under `gs://allennlp ### Highlighted models - * 300M parameters, ~70B tokens, a starter model that's not completely random: https://wandb.ai/ai2-llm/LLM-scripts/runs/ed5krfk9 \ No newline at end of file + * 300M parameters, ~70B tokens, a starter model that's not completely random: https://wandb.ai/ai2-llm/LLM-scripts/runs/ed5krfk9 diff --git a/dolma/beam_search.py b/dolma/beam_search.py new file mode 100644 index 000000000..02fa99c6a --- /dev/null +++ b/dolma/beam_search.py @@ -0,0 +1,1082 @@ +""" +This is a self-contained and flexible beam search implementation adapted from +AllenNLP's beam search: https://github.com/allenai/allennlp/blob/main/allennlp/nn/beam_search.py +""" + +import copy +import warnings +from abc import abstractmethod +from inspect import signature +from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, cast + +import torch + +__all__ = [ + "Sampler", + "DeterministicSampler", + "MultinomialSampler", + "TopKSampler", + "TopPSampler", + "GumbelSampler", + "FinalSequenceScorer", + "SequenceLogProbabilityScorer", + "LengthNormalizedSequenceLogProbabilityScorer", + "Constraint", + "RepeatedNGramBlockingConstraint", + "BeamSearch", +] + +StateType = Dict[str, torch.Tensor] +StepFunctionTypeWithTimestep = Callable[[torch.Tensor, StateType, int], Tuple[torch.Tensor, StateType]] +StepFunctionTypeNoTimestep = Callable[[torch.Tensor, StateType], Tuple[torch.Tensor, StateType]] + +StepFunctionType = TypeVar("StepFunctionType", StepFunctionTypeWithTimestep, StepFunctionTypeNoTimestep) +""" +The type of step function that can be passed to [`BeamSearch.search`](#search). + +This can either be [`StepFunctionTypeWithTimestep`](#stepfunctiontypewithtimestep) +or [`StepFunctionTypeNoTimestep`](#stepfunctiontypenotimestep). +""" + +ConstraintStateType = List[List[Dict[str, Any]]] + + +class Sampler: + """ + An abstract class that can be used to sample candidates (either nodes or beams) + within `BeamSearch`. + + A `Sampler` just has three methods, `init_state()`, `sample_nodes()` and `sample_beams()`. + + `init_state()` takes three arguments: + + - a tensor of starting log probs with shape `(batch_size,, num_classes)`, + - the batch size, an int, + - and the number of classes, also an int. 
+ + It returns a state dictionary with any state tensors needed for subsequent + calls to `sample_nodes()` and `sample_beams()`. + + By default this method just returns an empty dictionary. + + Both `sample_nodes()` and `sample_beams()` should take three arguments: + + - tensor of normalized log probabilities with shape `(batch_size, num_examples)`, + - an integer representing the number of samples to take for each example in the batch, + - and a state dictionary which could contain any tensors needed for the `Sampler` to keep + track of state. + + For `sample_nodes()`, `num_examples = num_classes`, but for `sample_beams`, + `num_examples = beam_size * per_node_beam_size`. + + The return value should be a tuple containing: + + - a tensor of log probabilities of the sampled examples with shape `(batch_size, num_samples)`, + - a tensor of indices of the sampled examples with shape `(batch_size, num_samples)`, + - and the updated state dictionary. + + A default implementation of `sample_beams` is provided, which just deterministically + picks the `k` examples with highest log probability. + """ + + default_implementation = "deterministic" + + def init_state( + self, start_class_log_probabilities: torch.Tensor, batch_size: int, num_classes: int + ) -> StateType: + del start_class_log_probabilities, batch_size, num_classes + return {} + + @abstractmethod + def sample_nodes( + self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType + ) -> Tuple[torch.Tensor, torch.Tensor, StateType]: + raise NotImplementedError + + def sample_beams( + self, log_probs: torch.Tensor, beam_size: int, state: StateType + ) -> Tuple[torch.Tensor, torch.Tensor, StateType]: + del state + selected_log_probs, selected_indices = torch.topk(log_probs, beam_size, dim=-1) + return selected_log_probs, selected_indices, {} + + +class DeterministicSampler(Sampler): + """ + A `Sampler` that just deterministically returns the `k` nodes or beams with highest + log probability. + """ + + def sample_nodes( + self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType + ) -> Tuple[torch.Tensor, torch.Tensor, StateType]: + del state + selected_log_probs, selected_indices = torch.topk(log_probs, per_node_beam_size, dim=-1) + return selected_log_probs, selected_indices, {} + + +class MultinomialSampler(Sampler): + """ + A `Sampler` which samples nodes from the given multinomial distribution. Beams are sampled + in the default, non-deterministic way. + + :param temperature: A `temperature` below 1.0 produces a sharper probability distribution and a `temperature` + above 1.0 produces a flatter probability distribution. + :param with_replacement: Whether to sample with replacement. 
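+
+    A minimal usage sketch (illustration only; assumes `log_probs` is a `(batch_size, num_classes)`
+    tensor of normalized log probabilities with at least 5 classes):
+
+        sampler = MultinomialSampler(temperature=0.7)
+        top_log_probs, indices, state = sampler.sample_nodes(log_probs, 5, {})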
+ + """ + + def __init__( + self, + temperature: float = 1.0, + with_replacement: bool = False, + ) -> None: + self.temperature = temperature + self.with_replacement = with_replacement + + def sample_nodes( + self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType + ) -> Tuple[torch.Tensor, torch.Tensor, StateType]: + if self.temperature != 1.0: + _probabilities = torch.nn.functional.softmax(log_probs / self.temperature, dim=-1) + else: + _probabilities = log_probs.exp() + + selected_indices = torch.multinomial(_probabilities, per_node_beam_size, replacement=self.with_replacement) + + return torch.gather(log_probs, 1, selected_indices), selected_indices, state + + +class TopKSampler(Sampler): + """ + A `Sampler` which redistributes the probability mass function for nodes among the + top `k` choices, then samples from that subset after re-normalizing the probabilities. + + Beams are sampled in the default, deterministic way. + + :param k: The number of top choices to be selected from. + :param temperature: A `temperature` below 1.0 produces a sharper probability distribution and a `temperature` + above 1.0 produces a flatter probability distribution. + :param with_replacement: If set to `True`, samples will be selected with replacement from the top k choices. + """ + + def __init__( + self, + k: int = 1, + temperature: float = 1.0, + with_replacement: bool = False, + ): + self.k = k + self.temperature = temperature or 1.0 + self.with_replacement = with_replacement + + def sample_nodes( + self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType + ) -> Tuple[torch.Tensor, torch.Tensor, StateType]: + if not per_node_beam_size <= self.k <= log_probs.size()[1]: + raise ValueError( + "k must be a postive integer no less than per_node_beam_size and no greater than vocabulary size" + ) + + # shape (both): (batch_size, k) + top_k_log_probs, top_k_indices = log_probs.topk(self.k, dim=-1) + + # Apply temperature if necessary. + # shape: (batch_size, k) + if self.temperature != 1.0: + top_k_log_probs = top_k_log_probs / self.temperature + + # Re-normalize the subset. + # shape: (batch_size, k) + normalized_top_k_probs = torch.nn.functional.softmax(top_k_log_probs, dim=-1) + + # Sample from the re-normalized subset. + # NOTE: These indices are not indices into `log_probs`, they are indices into `top_k_log_probs`. + # shape: (batch_size, per_node_beam_size) + sampled_indices = torch.multinomial( + normalized_top_k_probs, per_node_beam_size, replacement=self.with_replacement + ) + + # Convert `sampled_indices` back to indices in the original `log_probs` tensor. + # shape: (batch_size, per_node_beam_size) + indices = top_k_indices.gather(-1, sampled_indices) + + return log_probs.gather(1, indices), indices, state + + +class TopPSampler(Sampler): + """ + A `Sampler` which redistributes the probability mass function for nodes among + the top choices with a cumulative probability of at least `p`, then samples from that subset + after re-normalizing the probabilities. + + Beams are sampled in the default, deterministic way. + + :param p: + The cumulative probability cutoff threshold. A higher value of `p` will result in more possible + examples to sample from. If `with_replacement` is `False` and the number of possible samples is + insufficient to sample without replacement from when calling `sample_nodes`, then the top + `per_node_beam_size` examples will be chosen. 
+ :param temperature: + A `temperature` below 1.0 produces a sharper probability distribution and a `temperature` + above 1.0 produces a flatter probability distribution. + :param with_replacement: + If set to `True`, samples will be selected with replacement from the top choices. + + """ + + def __init__( + self, + p: float = 0.9, + temperature: float = 1.0, + with_replacement: bool = False, + ): + if p < 0.0 or p > 1.0: + raise ValueError("p must be a positive float no greater than 1.0") + self.p = p + self.temperature = temperature or 1.0 + self.with_replacement = with_replacement + + def sample_nodes( + self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType + ) -> Tuple[torch.Tensor, torch.Tensor, StateType]: + if not per_node_beam_size <= log_probs.size()[1]: + raise ValueError("per_node_beam_size cannot be greater than vocabulary size") + + # First apply temperature coefficient: + if self.temperature != 1.0: + _log_probs = torch.nn.functional.log_softmax(log_probs / self.temperature, dim=-1) + else: + _log_probs = log_probs + + # Sort the probabilities in descending order to then find cumulative sum + log_probs_descending, sorting_indices = torch.sort(_log_probs, descending=True) + + # shape: (batch_size, num_classes) + probabilities_descending = log_probs_descending.exp() + probabilities_summed = torch.cumsum(probabilities_descending, dim=-1) + + # Create a mask for filtering out probabilities that don't make the top `p`. + # shape: (batch_size, num_classes) + exclusion_mask = probabilities_summed >= self.p + + # We want to include the first index where probabilities_summed >= p, so we shift over one. + exclusion_mask[..., 1:] = exclusion_mask[..., :-1].clone() + exclusion_mask[..., 0] = False + + # Make sure there's at least `per_node_beam_size` options to be selected. + if not self.with_replacement: + exclusion_mask[..., :per_node_beam_size] = False + + log_probs_descending[exclusion_mask] = torch.finfo(log_probs.dtype).min + + # Now re-normalized the included log probs. + # shape: (batch_size, num_classes) + filtered_probabilities = torch.nn.functional.softmax(log_probs_descending, dim=-1) + + # Sample from the re-normalized subset. + # NOTE: These indices are not indices into `log_probs`, they are indices into `log_probs_descending`. + # shape: (batch_size, per_node_beam_size) + sampled_indices = torch.multinomial( + filtered_probabilities, per_node_beam_size, replacement=self.with_replacement + ) + + # Convert `sampled_indices` back to indices in the original `log_probs` tensor. + # shape: (batch_size, per_node_beam_size) + selected_indices = sorting_indices.gather(-1, sampled_indices) + + # Return (selected log probabilities, selected classes) + # shape: (len(log_probs),1) , (len(log_probs), 1) + return torch.gather(log_probs, 1, selected_indices), selected_indices, state + + +class GumbelSampler(Sampler): + """ + A `Sampler` which uses the Gumbel-Top-K trick to sample without replacement. See + [*Stochastic Beams and Where to Find Them: The Gumbel-Top-k Trick for Sampling + Sequences Without Replacement*, W Kool, H Van Hoof and M Welling, 2010] + (https://api.semanticscholar.org/CorpusID:76662039). + + :param temperature: A `temperature` below 1.0 produces a sharper probability distribution and a `temperature` + above 1.0 produces a flatter probability distribution. 
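+
+    To get Stochastic Beam Search, pass an instance of this sampler to `BeamSearch`
+    (a sketch; `end_index` is assumed to be your EOS token ID):
+
+        beam_search = BeamSearch(end_index, beam_size=5, sampler=GumbelSampler())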
+ """ + + def __init__(self, temperature: float = 1.0): + self.temperature = temperature + + def init_state( + self, start_class_log_probabilities: torch.Tensor, batch_size: int, num_classes: int + ) -> StateType: + # shape: (batch_size, num_classes) + zeros = start_class_log_probabilities.new_zeros((batch_size, num_classes)) + + # shape: (batch_size, num_classes) + G_phi_S = self.gumbel_with_max(start_class_log_probabilities, zeros) + + return {"G_phi_S": G_phi_S} + + def sample_nodes( + self, + log_probs: torch.Tensor, + per_node_beam_size: int, + state: StateType, + ) -> Tuple[torch.Tensor, torch.Tensor, StateType]: + # First apply temperature coefficient: + # shape: (batch_size * beam_size, num_classes) + if self.temperature != 1.0: + _log_probs = torch.nn.functional.log_softmax(log_probs / self.temperature, dim=-1) + else: + _log_probs = log_probs + + # shape: (group_size,) + phi_S = state["phi_S"] + + # shape: (group_size, num_classes) + phi_S = phi_S.unsqueeze(-1).expand_as(_log_probs) + + # shape: (group_size, num_classes) + phi_S_new = phi_S + _log_probs + + # shape: (group_size, 1) + G_phi_S = state["G_phi_S"].unsqueeze(-1) + + # shape: (group_size, num_classes) + G_phi_S_new = self.gumbel_with_max(phi_S_new, G_phi_S) + + # Replace NaNs with very negative number. + # shape: (group_size, num_classes) + # G_phi_S_new[G_phi_S_new.isnan()] = torch.finfo(G_phi_S_new.dtype).min + + # shape (both): (group_size, per_node_beam_size) + top_G_phi_S_new, top_indices = torch.topk(G_phi_S_new, per_node_beam_size, dim=-1) + + # shape: (group_size, per_node_beam_size) + top_log_probs = log_probs.gather(1, top_indices) + + return top_log_probs, top_indices, {"G_phi_S": top_G_phi_S_new} + + def sample_beams( + self, + log_probs: torch.Tensor, + beam_size: int, + state: StateType, + ) -> Tuple[torch.Tensor, torch.Tensor, StateType]: + """ + Returns the beams with the highest perturbed log probabilities. + """ + # shape (log_probs): (batch_size, beam_size * per_node_beam_size) + + batch_size = log_probs.size()[0] + + # shape: (batch_size * beam_size, per_node_beam_size) + G_phi_S = state["G_phi_S"] + + # shape: (batch_size, beam_size * per_node_beam_size) + G_phi_S = G_phi_S.reshape_as(log_probs) + + # shape (both): (batch_size, beam_size) + G_phi_S_new, selected_indices = torch.topk(G_phi_S, beam_size, dim=-1) + + # shape: (batch_size, beam_size) + selected_log_probs = log_probs.gather(1, selected_indices) + + # Now sort the selected beams by their true log prob. + # shape (all): (batch_size, beam_size) + selected_log_probs, sort_indices = selected_log_probs.sort(dim=-1, descending=True) + selected_indices = selected_indices.gather(1, sort_indices) + G_phi_S_new = G_phi_S_new.gather(1, sort_indices) + + # shape: (batch_size * beam_size,) + G_phi_S_new = G_phi_S_new.reshape(batch_size * beam_size) + + # shape: (batch_size * beam_size,) + phi_S = selected_log_probs.reshape(batch_size * beam_size) + + return selected_log_probs, selected_indices, {"G_phi_S": G_phi_S_new, "phi_S": phi_S} + + def gumbel(self, phi) -> torch.Tensor: + """ + Sample `Gumbel(phi)`. + + `phi` should have shape `(batch_size, num_classes)`. + """ + return -torch.log(-torch.log(torch.rand_like(phi))) + phi + + def gumbel_with_max(self, phi, T) -> torch.Tensor: + """ + Sample `Gumbel(phi)` conditioned on the maximum value being equal to `T`. + + `phi` should have shape `(batch_size, num_classes)` and `T` should have + shape `(batch_size, 1)`. 
+ """ + # Shape: (batch_size, num_classes) + G_phi = self.gumbel(phi) + + # Now we find the maximum from these samples. + # Shape: (batch_size, ) + Z, _ = G_phi.max(dim=-1) + + # Shape: (batch_size, num_classes) + v = T - G_phi + torch.log1p(-torch.exp(G_phi - Z.unsqueeze(-1))) + + # Shape: (batch_size, num_classes) + return T - torch.nn.functional.relu(v) - torch.log1p(torch.exp(-v.abs())) + + +class FinalSequenceScorer: + """ + An abstract class that can be used to score the final generated sequences found + by beam search. Given the predicted sequences and the corresponding log probabilities of + those sequences, the class calculates and returns the final score of the sequences. + + The default implementation scores the sequences using the sum of the log probabilities of + the sequence, which is passed as input. + """ + + default_implementation = "sequence-log-prob" + + @abstractmethod + def score(self, predictions: torch.Tensor, log_probabilities: torch.Tensor, end_index: int) -> torch.Tensor: + """ + Score the final predictions found by beam search. + Returns a tensor of the final sequence scores of shape `(batch_size, beam_size)`. + + :param predictions: A tensor containing the initial predictions with shape `(batch_size, beam_size, max_steps)`. + :param log_probabilities: A tensor containing the log probabilities of the sequence, defined as the sum + of the log probabilities per token, with shape `(batch_size, beam_size)`. + :param end_index: The index of the end symbol. + + """ + raise NotImplementedError + + +class SequenceLogProbabilityScorer(FinalSequenceScorer): + """ + A :class:`FinalSequenceScorer` which scores the sequences by the sum of the log probabilities + across the sequence's tokens. + """ + + def score(self, predictions: torch.Tensor, log_probabilities: torch.Tensor, end_index: int) -> torch.Tensor: + del predictions, end_index + # The sum of the sequence log probabilities is the input parameter, so just + # return it. + return log_probabilities + + +class LengthNormalizedSequenceLogProbabilityScorer(FinalSequenceScorer): + """ + A :class:`FinalSequenceScorer` which scores the sequences by the average log probability of the + tokens in the sequence. It optionally includes a length penalty which promotes + or demotes sequences based on their lengths. The final score for a sequence will + be `(sequence_log_probability) / (sequence_length ** length_penalty)`. The sequence length + here includes the end token. + + :param length_penalty: The length penalty to use. A value of 1.0 means no length penalty is used. + A value > 1.0 favors longer sequences, and < 1.0 favors shorter sequences. + """ + + def __init__(self, length_penalty: float = 1.0): + super().__init__() + self.length_penalty = length_penalty + + def score(self, predictions: torch.Tensor, log_probabilities: torch.Tensor, end_index: int) -> torch.Tensor: + # shape: (batch_size, beam_size) + lengths = (predictions != end_index).long().sum(dim=2) + + # If the sequence ended during beam search, the `log_probabilities` will include + # the transition to the end token. Therefore, in such situations, `lengths` is + # actually off by 1. This corrects for that. 
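+        # For example, for a prediction of [3, 4, END], `lengths` counts only the 2 non-end
+        # tokens, while `log_probabilities` sums over 3 transitions, so we add 1 below.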
+ # shape: (batch_size, beam_size) + is_end_token = predictions[:, :, -1] == end_index + lengths += is_end_token.long() + + # shape: (batch_size, beam_size) + average_log_probs = log_probabilities / (lengths**self.length_penalty) + return average_log_probs + + +class Constraint: + """ + An abstract class that can be used to enforce constraints on the output predictions + by manipulating the class log probabilities during beam search. + + A `Constraint` just has three methods that need to be implemented by subclasses: + `init_state()`, `apply()` and `_update_state()`. + + `init_state()` takes one argument: + + - the batch size, an int + + It returns a constraint state, which is a nested list of dictionaries, with any state needed for subsequent + calls to `apply()` and `update_state()`. The length of the outer list should be equal to `batch_size`. + Each inner list should be of length 1. + + `apply()` takes two arguments: + + - the constraint state, which is a nested list of dictionaries. The length of the outer list is `batch_size` + and the length of each inner list is `beam_size` except on the first time `apply()` is called when it is 1. + - `class_log_probabilities`, a tensor of shape `(batch_size, beam_size, num_classes)` that contains the + log probabilities for the classes during search. The first time `apply()` is called, `beam_size = 1`. + + The `apply()` method should return new `class_log_probabilities` that enforce the constraint + for this step of beam search. For instance, it may prevent a specific class from being selected by setting + the corresponding log probability to a negligible value such as `float("-inf")` or + `torch.finfo(class_log_probabilities.dtype).min`. + + `_update_state()` takes two arguments: + + - the copied parent constraint state, which is a nested list of dictionaries. `state[i][j]` contains the + copied state for the parent of `last_prediction[i, j]`. It is unique to that batch and beam, so it can be + directly edited in-place without affecting the others. + - last_prediction, a tensor of shape `(batch_size, beam_size)` containing the predictions from the last + step of beam search. + + The `_update_state()` function should return a new constraint state, a nested list of dictionaries of + length `batch_size` and inner list of length `beam_size`, one for each of the predictions in `last_prediction`. + + """ + + @abstractmethod + def init_state( + self, + batch_size: int, + ) -> ConstraintStateType: + raise NotImplementedError + + @abstractmethod + def apply( + self, + state: ConstraintStateType, + class_log_probabilities: torch.Tensor, + ) -> torch.Tensor: + raise NotImplementedError + + @staticmethod + def _copy_state( + state: ConstraintStateType, + batch_size: int, + beam_size: int, + last_backpointer: Optional[torch.Tensor] = None, + ) -> ConstraintStateType: + """ + Copies the `state` . This method copies the data in `state` using `copy.deepcopy()`. If this + is not appropriate for your constraint, you will need to implement the copying yourself. 
+ """ + new_state = [] + for i in range(batch_size): + batch_state = [] + for j in range(beam_size): + if last_backpointer is None: + # This is the first prediction, so the backpointer is 0 + backpointer = 0 + else: + backpointer = last_backpointer[i, j].item() + batch_state.append(copy.deepcopy(state[i][backpointer])) # type: ignore + new_state.append(batch_state) + return new_state + + def update_state( + self, + state: ConstraintStateType, + last_prediction: torch.Tensor, + last_backpointer: Optional[torch.Tensor] = None, + ) -> ConstraintStateType: + batch_size, beam_size = last_prediction.size() + new_state = self._copy_state(state, batch_size, beam_size, last_backpointer) + return self._update_state(new_state, last_prediction) + + @abstractmethod + def _update_state( + self, + state: ConstraintStateType, + last_prediction: torch.Tensor, + ) -> ConstraintStateType: + raise NotImplementedError + + +class RepeatedNGramBlockingConstraint(Constraint): + def __init__(self, ngram_size: int, **kwargs) -> None: + super().__init__(**kwargs) + self.ngram_size = ngram_size + + def init_state( + self, + batch_size: int, + ) -> ConstraintStateType: + return [[{"seen_ngrams": {}, "current_prefix": []}] for _ in range(batch_size)] + + def apply( + self, + state: ConstraintStateType, + class_log_probabilities: torch.Tensor, + ) -> torch.Tensor: + for i, batch in enumerate(state): + for j, beam in enumerate(batch): + current_prefix = tuple(beam["current_prefix"]) + seen_ngrams = beam["seen_ngrams"] + try: + disallowed_indices = seen_ngrams[current_prefix] + class_log_probabilities[i, j, disallowed_indices] = torch.finfo( + class_log_probabilities.dtype + ).min + except KeyError: + # We have not seen this prefix before, so there is no index + # that needs to be blocked + pass + return class_log_probabilities + + def _update_state( + self, + state: ConstraintStateType, + last_prediction: torch.Tensor, + ) -> ConstraintStateType: + for i, batch in enumerate(state): + for j, beam in enumerate(batch): + prediction = last_prediction[i, j].item() + prefix = beam["current_prefix"] + seen_ngrams = beam["seen_ngrams"] + + if len(prefix) == self.ngram_size - 1: + # This is a new ngram that we have to remember + if tuple(prefix) not in seen_ngrams: + seen_ngrams[tuple(prefix)] = [] + seen_ngrams[tuple(prefix)].append(prediction) + + # Create the new prefix, removing the oldest index if the prefix + # is too long + prefix.append(prediction) + if len(prefix) == self.ngram_size: + prefix.pop(0) + return state + + +class BeamSearch: + """ + Implements the beam search algorithm for decoding the most likely sequences. + + :param end_index: The index of the "stop" or "end" token in the vocabulary. Usually the EOS token ID. + + :param max_steps: The maximum number of decoding steps to take, i.e. the maximum length + of the predicted sequences. + + :param beam_size: The width of the beam used. + + :param per_node_beam_size: The maximum number of candidates to consider per node, at each step in the search. + If not given, this just defaults to `beam_size`. Setting this parameter + to a number smaller than `beam_size` may give better results, as it can introduce + more diversity into the search. See + [*Beam Search Strategies for Neural Machine Translation*, Freitag and Al-Onaizan, 2017] + (https://api.semanticscholar.org/CorpusID:2229477). + + :param sampler: An optional `Sampler` which is used to pick next candidate nodes and beams. 
+ If not specified, `DeterministicSampler` will be used, which just takes the + `per_node_beam_size` most likely nodes and the `beam_size` most likely beams. + + Using the [`GumbelSampler`](#gumbelsampler), on the other hand, will give you + [Stochastic Beam Search](https://api.semanticscholar.org/CorpusID:76662039). + + :param min_steps: The minimum number of decoding steps to take, i.e. the minimum length of + the predicted sequences. This does not include the start or end tokens. If `None`, + no minimum is enforced. + + :param final_sequence_scorer: An optional `FinalSequenceScorer` which is used to score the final generated sequences. + The output from this module is what is returned by the `search` method. If not + specified, `SequenceLogProbabilityScorer` will be used, which scores the sequences + by the sum of the token log probabilities. + + :param constraints: An optional list of `Constraint`s which should be applied during beam search. If not + provided, no constraints will be enforced. + + """ + + def __init__( + self, + end_index: int, + *, + max_steps: int = 50, + beam_size: int = 10, + per_node_beam_size: Optional[int] = None, + sampler: Optional[Sampler] = None, + min_steps: Optional[int] = None, + final_sequence_scorer: Optional[FinalSequenceScorer] = None, + constraints: Optional[List[Constraint]] = None, + ) -> None: + if not max_steps > 0: + raise ValueError("max_steps must be positive") + if not beam_size > 0: + raise ValueError("beam_size must be positive") + if per_node_beam_size is not None and not per_node_beam_size > 0: + raise ValueError("per_node_beam_size must be positive") + if min_steps is not None: + if not min_steps >= 0: + raise ValueError("min_steps must be non-negative") + if not min_steps <= max_steps: + raise ValueError("min_steps must be less than or equal to max_steps") + + self._end_index = end_index + self.max_steps = max_steps + self.beam_size = beam_size + self.per_node_beam_size = per_node_beam_size or beam_size + self.sampler = sampler or DeterministicSampler() + self.min_steps = min_steps or 0 + self.final_sequence_scorer = final_sequence_scorer or SequenceLogProbabilityScorer() + self.constraints = constraints or [] + + @staticmethod + def _reconstruct_sequences(predictions, backpointers): + # Reconstruct the sequences. + # shape: [(batch_size, beam_size, 1)] + reconstructed_predictions = [predictions[-1].unsqueeze(2)] + + if not backpointers: + return reconstructed_predictions + + # shape: (batch_size, beam_size) + cur_backpointers = backpointers[-1] + + for timestep in range(len(predictions) - 2, 0, -1): + # shape: (batch_size, beam_size, 1) + cur_preds = predictions[timestep].gather(1, cur_backpointers).unsqueeze(2) + + reconstructed_predictions.append(cur_preds) + + # shape: (batch_size, beam_size) + cur_backpointers = backpointers[timestep - 1].gather(1, cur_backpointers) + + # shape: (batch_size, beam_size, 1) + final_preds = predictions[0].gather(1, cur_backpointers).unsqueeze(2) + + reconstructed_predictions.append(final_preds) + + return reconstructed_predictions + + def search( + self, + start_predictions: torch.Tensor, + start_state: StateType, + step: StepFunctionType, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Given a starting state and a step function, apply beam search to find the + most likely target sequences. + + Returns a tuple of `(predictions, final_scores)`, where `predictions` + has shape `(batch_size, beam_size, max_steps)` and `final_scores` + has shape `(batch_size, beam_size)`. + + .. 
note:: + If your step function returns `-inf` for some log probabilities + (like if you're using a masked log-softmax) then some of the "best" + sequences returned may also have `-inf` log probability. Specifically + this happens when the beam size is smaller than the number of actions + with finite log probability (non-zero probability) returned by the step function. + Therefore if you're using a mask you may want to check the results from `search` + and potentially discard sequences with non-finite log probability. + + :param start_predictions: A tensor containing the initial predictions with shape `(batch_size,)`. + Usually the initial predictions are just the index of the "start" token + in the target vocabulary. + + :param start_state: The initial state passed to the `step` function. Each value of the state dict + should be a tensor of shape `(batch_size, *)`, where `*` means any other + number of dimensions. + + :param step: A function that is responsible for computing the next most likely tokens, + given the current state and the predictions from the last time step. + The function should accept two or three arguments: + + - a tensor of shape `(group_size,)` or representing the index of the predicted + tokens from the last time step, + - the current state, a `StateType`, and + - optionally, the timestep, an `int`. + + The `group_size` will be `batch_size * beam_size`, except in the initial + step, for which it will just be `batch_size`. + + The function is expected to return a tuple, where the first element + is a tensor of shape `(group_size, vocab_size)` containing + the log probabilities of the tokens for the next step, and the second + element is the updated state. The tensor in the state should have shape + `(group_size, *)`, where `*` means any other number of dimensions. + + """ + step_signature = signature(step) + if len(step_signature.parameters) < 3: + # If the step function we're given does not take the time step argument, wrap it + # in one that does. + old_step = cast(StepFunctionTypeNoTimestep, step) + + def new_step(last_predictions: torch.Tensor, state: Dict[str, torch.Tensor], time_step: int): + del time_step + return old_step(last_predictions, state) + + return self._search(start_predictions, start_state, new_step) + else: + return self._search(start_predictions, start_state, cast(StepFunctionTypeWithTimestep, step)) + + def _search( + self, + start_predictions: torch.Tensor, + start_state: StateType, + step: StepFunctionTypeWithTimestep, + ) -> Tuple[torch.Tensor, torch.Tensor]: + batch_size = start_predictions.size()[0] + + # List of (batch_size, beam_size) tensors. One for each time step. Does not + # include the start symbols, which are implicit. + predictions: List[torch.Tensor] = [] + + # List of (batch_size, beam_size) tensors. One for each time step. None for + # the first. Stores the index n for the parent prediction, i.e. + # predictions[t-1][i][n], that it came from. + backpointers: List[torch.Tensor] = [] + + constraint_states = [constraint.init_state(batch_size) for constraint in self.constraints] + + # Calculate the first timestep. This is done outside the main loop + # because we are going from a single decoder input (the output from the + # encoder) to the top `beam_size` decoder outputs. On the other hand, + # within the main loop we are going from the `beam_size` elements of the + # beam to `beam_size`^2 candidates from which we will select the top + # `beam_size` elements for the next iteration. 
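+        # (More precisely, each loop iteration considers `beam_size * per_node_beam_size`
+        # candidates per batch element; this equals `beam_size`^2 in the default case
+        # where `per_node_beam_size == beam_size`.)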
+ # shape: (batch_size, num_classes) + start_class_log_probabilities, state = step(start_predictions, start_state, 0) + + num_classes = start_class_log_probabilities.size()[1] + + # Make sure `per_node_beam_size` is not larger than `num_classes`. + if self.per_node_beam_size > num_classes: + raise ValueError( + f"Vocab size ({num_classes:d}) too small " + f"relative to per_node_beam_size ({self.per_node_beam_size:d}).\n" + f"Please decrease beam_size or per_node_beam_size." + ) + + sampler_state = self.sampler.init_state(start_class_log_probabilities, batch_size, num_classes) + + # Apply all constraints. + if self.constraints: + # shape: (batch_size, 1, num_classes) + expanded_start_class_log_probabilities = start_class_log_probabilities.unsqueeze(1) + for constraint, constraint_state in zip(self.constraints, constraint_states): + expanded_start_class_log_probabilities = constraint.apply( + constraint_state, expanded_start_class_log_probabilities + ) + start_class_log_probabilities = expanded_start_class_log_probabilities.squeeze(1) + + # Prevent selecting the end symbol if there is any min_steps constraint + if self.min_steps >= 1: + start_class_log_probabilities[:, self._end_index] = torch.finfo( + start_class_log_probabilities.dtype + ).min + + # Get the initial predicted classed and their log probabilities. + # shape: (batch_size, beam_size), (batch_size, beam_size) + ( + start_top_log_probabilities, + start_predicted_classes, + sampler_state, + ) = self.sampler.sample_beams(start_class_log_probabilities, self.beam_size, sampler_state) + + if self.beam_size == 1 and (start_predicted_classes == self._end_index).all(): + warnings.warn( + "Empty sequences predicted. You may want to increase the beam size or ensure " + "your step function is working properly.", + RuntimeWarning, + ) + return start_predicted_classes.unsqueeze(-1), start_top_log_probabilities + + # The log probabilities for the last time step. + # shape: (batch_size, beam_size) + last_log_probabilities = start_top_log_probabilities + + # shape: [(batch_size, beam_size)] + predictions.append(start_predicted_classes) + + # Log probability tensor that mandates that the end token is selected. + # shape: (batch_size * beam_size, num_classes) + log_probs_after_end = start_class_log_probabilities.new_full( + (batch_size * self.beam_size, num_classes), + torch.finfo(start_class_log_probabilities.dtype).min, + ) + log_probs_after_end[:, self._end_index] = 0.0 + + # Set the same state for each element in the beam. + self._update_initial_state(state, batch_size) + + for i, constraint in enumerate(self.constraints): + constraint_states[i] = constraint.update_state(constraint_states[i], start_predicted_classes) + + for timestep in range(self.max_steps - 1): + # shape: (batch_size * beam_size,) + last_predictions = predictions[-1].reshape(batch_size * self.beam_size) + + # If every predicted token from the last step is `self._end_index`, + # then we can stop early. + if (last_predictions == self._end_index).all(): + break + # Take a step. This get the predicted log probs of the next classes + # and updates the state. + # shape: (batch_size * beam_size, num_classes) + class_log_probabilities, state = step(last_predictions, state, timestep + 1) + + # Apply all constraints. 
+ if self.constraints: + # shape: (batch_size, beam_size, num_classes) + reshaped_class_log_probabilities = class_log_probabilities.view(batch_size, self.beam_size, -1) + for constraint, constraint_state in zip(self.constraints, constraint_states): + reshaped_class_log_probabilities = constraint.apply( + constraint_state, reshaped_class_log_probabilities + ) + # shape: (batch_size * beam_size, num_classes) + class_log_probabilities = reshaped_class_log_probabilities.view(batch_size * self.beam_size, -1) + + # The `timestep`-th iteration of the for loop is generating the `timestep + 2`-th token + # of the sequence (because `timestep` is 0-indexed and we generated the first token + # before the for loop). Here we block the end index if the search is not allowed to + # terminate on this iteration. + if timestep + 2 <= self.min_steps: + class_log_probabilities[:, self._end_index] = torch.finfo(class_log_probabilities.dtype).min + + # shape: (batch_size * beam_size, num_classes) + last_predictions_expanded = last_predictions.unsqueeze(-1).expand( + batch_size * self.beam_size, num_classes + ) + + # Here we are finding any beams where we predicted the end token in + # the previous timestep and replacing the distribution with a + # one-hot distribution, forcing the beam to predict the end token + # this timestep as well. + # shape: (batch_size * beam_size, num_classes) + cleaned_log_probabilities = torch.where( + last_predictions_expanded == self._end_index, + log_probs_after_end, + class_log_probabilities, + ) + + # shape (both): (batch_size * beam_size, per_node_beam_size) + top_log_probabilities, predicted_classes, sampler_state = self.sampler.sample_nodes( + cleaned_log_probabilities, self.per_node_beam_size, sampler_state + ) + + # Here we expand the last log probabilities to (batch_size * beam_size, per_node_beam_size) + # so that we can add them to the current log probs for this timestep. + # This lets us maintain the log probability of each element on the beam. + # shape: (batch_size * beam_size, per_node_beam_size) + expanded_last_log_probabilities = ( + last_log_probabilities.unsqueeze(2) + .expand(batch_size, self.beam_size, self.per_node_beam_size) + .reshape(batch_size * self.beam_size, self.per_node_beam_size) + ) + + # shape: (batch_size * beam_size, per_node_beam_size) + summed_top_log_probabilities = top_log_probabilities + expanded_last_log_probabilities + + # shape: (batch_size, beam_size * per_node_beam_size) + reshaped_summed = summed_top_log_probabilities.reshape( + batch_size, self.beam_size * self.per_node_beam_size + ) + + # shape: (batch_size, beam_size * per_node_beam_size) + reshaped_predicted_classes = predicted_classes.reshape( + batch_size, self.beam_size * self.per_node_beam_size + ) + + # Keep only the top `beam_size` beam indices. + # shape (both): (batch_size, beam_size) + ( + restricted_beam_log_probs, + restricted_beam_indices, + sampler_state, + ) = self.sampler.sample_beams(reshaped_summed, self.beam_size, sampler_state) + + # Use the beam indices to extract the corresponding classes. + # shape: (batch_size, beam_size) + restricted_predicted_classes = reshaped_predicted_classes.gather(1, restricted_beam_indices) + + predictions.append(restricted_predicted_classes) + + # shape: (batch_size, beam_size) + last_log_probabilities = restricted_beam_log_probs + + # The beam indices come from a `beam_size * per_node_beam_size` dimension where the + # indices with a common ancestor are grouped together. Hence + # dividing by per_node_beam_size gives the ancestor. 
(Note that this is integer + # division as the tensor is a LongTensor.) + # shape: (batch_size, beam_size) + backpointer = torch.divide(restricted_beam_indices, self.per_node_beam_size, rounding_mode="trunc") + backpointers.append(backpointer) + + # Keep only the pieces of the state tensors corresponding to the + # ancestors created this iteration. + self._update_state(state, backpointer) + + for i, constraint in enumerate(self.constraints): + constraint_states[i] = constraint.update_state( + constraint_states[i], restricted_predicted_classes, last_backpointer=backpointer + ) + + # Warn about "-inf" log probabilities if not using any constraints (negligible + # log probabilities are expected when using constraints). + if not self.constraints and ( + not torch.isfinite(last_log_probabilities).all() + or (last_log_probabilities == torch.finfo(last_log_probabilities.dtype).min).any() + ): + warnings.warn( + "Negligible log probabilities encountered ('-inf' or equivalent). " + "Some final sequences may not make sense. " + "This can happen when the beam size is larger than the number of valid (non-zero " + "probability) transitions that the step function produces.", + RuntimeWarning, + ) + + reconstructed_predictions = self._reconstruct_sequences(predictions, backpointers) + + # shape: (batch_size, beam_size, max_steps) + all_predictions = torch.cat(list(reversed(reconstructed_predictions)), 2) + + # Calculate the final sequence scores + # shape: (batch_size, beam_size) + final_scores = self.final_sequence_scorer.score(all_predictions, last_log_probabilities, self._end_index) + + # Sort the sequences based on the final scores so the best scoring + # sequence is at index 0 + sorted_final_scores, sorted_indices = torch.sort(final_scores, dim=1, descending=True) + sorted_all_predictions = torch.gather( + all_predictions, 1, sorted_indices.unsqueeze(-1).expand_as(all_predictions) + ) + + return sorted_all_predictions, sorted_final_scores + + def _update_initial_state(self, state: StateType, batch_size: int): + """ + Expand tensors in a state dictionary from `(batch_size, *)` to `(batch_size * beam_size, *)`. 
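+
+        For example, with `beam_size = 3` a state tensor of shape `(2, 5)` becomes one of
+        shape `(6, 5)`, with each original row repeated 3 times consecutively.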
+ """ + for key, state_tensor in state.items(): + if state_tensor is None: + continue + # shape: (batch_size * beam_size, *) + _, *last_dims = state_tensor.size() + state[key] = ( + state_tensor.unsqueeze(1) + .expand(batch_size, self.beam_size, *last_dims) + .reshape(batch_size * self.beam_size, *last_dims) + ) + + def _update_state(self, state: StateType, backpointer: torch.Tensor): + batch_size = backpointer.size()[0] + + for key, state_tensor in state.items(): + if state_tensor is None: + continue + _, *last_dims = state_tensor.size() + # shape: (batch_size, beam_size, *) + expanded_backpointer = backpointer.view(batch_size, self.beam_size, *([1] * len(last_dims))).expand( + batch_size, self.beam_size, *last_dims + ) + # shape: (batch_size * beam_size, *) + state[key] = ( + state_tensor.reshape(batch_size, self.beam_size, *last_dims) + .gather(1, expanded_backpointer) + .reshape(batch_size * self.beam_size, *last_dims) + ) diff --git a/dolma/model.py b/dolma/model.py index 98732026b..98c69f841 100644 --- a/dolma/model.py +++ b/dolma/model.py @@ -8,7 +8,7 @@ import math from abc import abstractmethod -from typing import NamedTuple, Optional, cast +from typing import List, NamedTuple, Optional, cast import torch import torch.backends.cuda @@ -16,6 +16,7 @@ import torch.nn.functional as F from torch import einsum +from .beam_search import BeamSearch, Constraint, FinalSequenceScorer, Sampler from .config import ActivationType, BlockType, LayerNormType, ModelConfig from .exceptions import DolmaConfigurationError @@ -32,6 +33,8 @@ "DolmaSequentialBlock", "DolmaParallelBlock", "Dolma", + "DolmaOutput", + "DolmaGenerateOutput", ] @@ -413,6 +416,19 @@ class DolmaOutput(NamedTuple): """ +class DolmaGenerateOutput(NamedTuple): + token_ids: torch.LongTensor + """ + The generated token IDs, a tensor of shape `(batch_size, beam_size, max_steps)`. + These do *not* include the original input IDs. + """ + + scores: torch.FloatTensor + """ + The scores of the generated sequences, a tensor of shape `(batch_size, beam_size)`. + """ + + class Dolma(nn.Module): def __init__(self, config: ModelConfig, init_params: bool = True): super().__init__() @@ -694,3 +710,100 @@ def num_fwd_flops(self): ) self.__num_fwd_flops = params_flops_per_seq + attn_flops_per_seq return self.__num_fwd_flops + + def generate( + self, + input_ids: torch.LongTensor, + attention_mask: Optional[torch.Tensor] = None, + attention_bias: Optional[torch.Tensor] = None, + max_steps: int = 10, + beam_size: int = 1, + per_node_beam_size: Optional[int] = None, + sampler: Optional[Sampler] = None, + min_steps: Optional[int] = None, + final_sequence_scorer: Optional[FinalSequenceScorer] = None, + constraints: Optional[List[Constraint]] = None, + ) -> DolmaGenerateOutput: + """ + Generate token IDs using beam search. + + Note that by default ``beam_size`` is set to 1, which is greedy decoding. + + :param input_ids: A tensor of shape `(batch_size, seq_len)`. + :param attention_mask: A optional tensor of shape `(batch_size, seq_len)`, the same + as for the forward method. + :param attention_bias: A tensor of shape + `(batch_size, 1, seq_len + tokens_to_generate, seq_len + tokens_to_generate)`, + the same as for the forward method except only one shape is excepted here. + + For an explanation of the other arguments, see the :class:`BeamSearch` class. 
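+
+        A minimal usage sketch (mirroring the README example; `model` and `tokenizer` are
+        assumed to be a loaded `Dolma` model and a compatible tokenizer):
+
+            input_ids = torch.tensor(
+                tokenizer.encode("Hello", add_special_tokens=False)
+            ).unsqueeze(0)
+            output = model.generate(input_ids, max_steps=5, beam_size=3)
+            best_ids = output.token_ids[0][0]  # shape: (max_steps,)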
+ """ + beam_search = BeamSearch( + self.config.eos_token_id, + max_steps=max_steps, + beam_size=beam_size, + per_node_beam_size=per_node_beam_size, + sampler=sampler, + min_steps=min_steps, + final_sequence_scorer=final_sequence_scorer, + constraints=constraints, + ) + + # Validate inputs. + batch_size, seq_len = input_ids.shape + if attention_mask is not None: + assert attention_mask.shape == (batch_size, seq_len) + if attention_bias is not None: + assert len(attention_bias.shape) == 4 + assert attention_bias.shape[:2] == (batch_size, 1) + assert ( + seq_len + beam_search.max_steps + <= attention_bias.shape[2] + == attention_bias.shape[3] + <= self.config.max_sequence_length + ) + + tokens_generated = 0 + + def step( + last_predictions: torch.Tensor, state: dict[str, torch.Tensor] + ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: + nonlocal tokens_generated + + input_ids = state["input_ids"] + attention_mask = state.get("attention_mask") + attention_bias = state.get("attention_bias") + group_size = input_ids.shape[0] + + if tokens_generated > 0: + input_ids = torch.cat((input_ids, last_predictions.unsqueeze(1)), dim=-1) + if attention_mask is not None: + attention_mask = torch.cat((attention_mask, attention_mask.new_ones((group_size, 1))), dim=-1) + + tokens_generated += 1 + + # Run forward pass of model to get logits, then normalize to get log probs. + output = self(input_ids, attention_mask=attention_mask, attention_bias=attention_bias) + log_probs = F.log_softmax(output.logits[:, -1, :], dim=-1) + + # Create new state. + state = {"input_ids": input_ids} + if attention_mask is not None: + state["attention_mask"] = attention_mask + if attention_bias is not None: + state["attention_bias"] = attention_bias + + return log_probs, state + + initial_preds = input_ids.new_zeros((batch_size,)) # This is arbitrary, we won't use this. 
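+        # (The `step` closure above ignores `last_predictions` on its first call, since
+        # `tokens_generated` is still 0, so any placeholder value works here.)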
+ state: dict[str, torch.Tensor] = {"input_ids": input_ids} + if attention_mask is not None: + state["attention_mask"] = attention_mask + if attention_bias is not None: + state["attention_bias"] = attention_bias + token_ids, scores = beam_search.search(initial_preds, state, step) + + return DolmaGenerateOutput( + token_ids=token_ids, # type: ignore[arg-type] + scores=scores, # type: ignore[arg-type] + ) diff --git a/tests/beam_search_test.py b/tests/beam_search_test.py new file mode 100644 index 000000000..872517c8a --- /dev/null +++ b/tests/beam_search_test.py @@ -0,0 +1,731 @@ +from typing import Optional, Union, cast + +import numpy as np +import pytest +import torch + +from dolma.beam_search import ( + BeamSearch, + GumbelSampler, + LengthNormalizedSequenceLogProbabilityScorer, + MultinomialSampler, + RepeatedNGramBlockingConstraint, + StepFunctionTypeNoTimestep, + StepFunctionTypeWithTimestep, + TopKSampler, + TopPSampler, +) + +# fmt: off +transition_probabilities = torch.tensor( + [ # START 1 2 3 4 END + [0.0, 0.4, 0.3, 0.2, 0.1, 0.0], # START -> j + [0.0, 0.0, 1.0, 0.0, 0.0, 0.0], # 1 -> j + [0.0, 0.0, 0.0, 1.0, 0.0, 0.0], # 2 -> j + [0.0, 0.0, 0.0, 0.0, 1.0, 0.0], # 3 -> j + [0.0, 0.0, 0.0, 0.0, 0.0, 1.0], # 4 -> j + [0.2, 0.1, 0.2, 0.2, 0.2, 0.1], # END -> j (doesn't matter) + ] +) + +# A transition matrix that favors shorter sequences over longer ones +short_sequence_transition_probabilities = torch.tensor( + [ # START 1 2 3 4 END + [0.0, 0.1, 0.0, 0.0, 0.0, 0.9], # START -> j + [0.0, 0.0, 0.1, 0.0, 0.0, 0.9], # 1 -> j + [0.0, 0.0, 0.0, 0.1, 0.0, 0.9], # 2 -> j + [0.0, 0.0, 0.0, 0.0, 0.1, 0.9], # 3 -> j + [0.0, 0.0, 0.0, 0.0, 0.0, 1.0], # 4 -> j + [0.2, 0.1, 0.2, 0.2, 0.2, 0.1], # END -> j (doesn't matter) + ] +) + +# A transition matrix that favors repeated ngrams +repeated_ngram_transition_probabilities_0 = torch.tensor( + [ # START 1 2 3 4 END + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0], # START -> j + [0.0, 0.0, 0.4, 0.6, 0.0, 1e-9], # 1 -> j + [0.0, 0.0, 0.0, 1.0, 0.0, 1e-9], # 2 -> j + [0.0, 1.0, 0.0, 0.0, 0.0, 1e-9], # 3 -> j + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], # 4 -> j (not used) + [0.0, 0.0, 0.0, 0.0, 0.0, 1.0], # END -> j (doesn't matter) + ] +) + +# Another transition matrix that favors repeated ngrams +repeated_ngram_transition_probabilities_1 = torch.tensor( + [ # START 1 2 3 4 END + [0.0, 0.4, 0.3, 0.2, 0.1, 0.0], # START -> j + [0.0, 0.4, 0.3, 0.2, 0.1, 0.1], # 1 -> j + [0.0, 0.0, 0.4, 0.3, 0.2, 0.1], # 2 -> j + [0.0, 0.0, 0.3, 0.4, 0.2, 0.1], # 3 -> j + [0.0, 0.0, 0.2, 0.3, 0.4, 0.1], # 4 -> j + [0.2, 0.1, 0.2, 0.2, 0.2, 0.1], # END -> j (doesn't matter) + ] +) +# fmt: on + +log_probabilities = torch.log(torch.tensor([[0.1, 0.3, 0.3, 0.3, 0.0, 0.0], [0.0, 0.0, 0.4, 0.3, 0.2, 0.1]])) + + +def get_step_function( + transition_matrix: torch.Tensor, with_timestep: bool = False +) -> Union[StepFunctionTypeNoTimestep, StepFunctionTypeWithTimestep]: + def _step_function( + last_predictions: torch.Tensor, state: dict[str, torch.Tensor] + ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: + log_probs_list = [] + for last_token in last_predictions: + log_probs = torch.log(transition_matrix[last_token.item()]) # type: ignore + log_probs_list.append(log_probs) + + return torch.stack(log_probs_list), state + + if not with_timestep: + return _step_function + + def _step_function_with_timestep( + last_predictions: torch.Tensor, + state: dict[str, torch.Tensor], + timestep: int, + ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: + del timestep + return _step_function(last_predictions, 
state) + + return _step_function_with_timestep + + +take_step_no_timestep = cast(StepFunctionTypeNoTimestep, get_step_function(transition_probabilities)) +take_step_with_timestep = cast( + StepFunctionTypeWithTimestep, get_step_function(transition_probabilities, with_timestep=True) +) +take_short_sequence_step = cast( + StepFunctionTypeNoTimestep, get_step_function(short_sequence_transition_probabilities) +) + + +class BeamSearchTest: + def setup_method(self): + self.end_index = transition_probabilities.size()[0] - 1 + self.beam_search = BeamSearch(self.end_index, max_steps=10, beam_size=3) + + # This is what the top k should look like for each item in the batch. + self.expected_top_k = np.array([[1, 2, 3, 4, 5], [2, 3, 4, 5, 5], [3, 4, 5, 5, 5]]) + + # This is what the log probs should look like for each item in the batch. + self.expected_log_probs = np.log(np.array([0.4, 0.3, 0.2])) + + def _check_results( + self, + batch_size: int = 5, + expected_top_k: Optional[np.array] = None, # type: ignore + expected_log_probs: Optional[np.array] = None, # type: ignore + beam_search: Optional[BeamSearch] = None, + state: Optional[dict[str, torch.Tensor]] = None, + take_step: Union[StepFunctionTypeNoTimestep, StepFunctionTypeWithTimestep] = take_step_with_timestep, + ) -> None: + expected_top_k = expected_top_k if expected_top_k is not None else self.expected_top_k + expected_log_probs = expected_log_probs if expected_log_probs is not None else self.expected_log_probs + state = state or {} + + beam_search = beam_search or self.beam_search + beam_size = beam_search.beam_size + + initial_predictions = torch.tensor([0] * batch_size) + top_k, log_probs = beam_search.search(initial_predictions, state, take_step) # type: ignore + + # top_k should be shape `(batch_size, beam_size, max_predicted_length)`. + assert list(top_k.size())[:-1] == [batch_size, beam_size] + np.testing.assert_array_equal(top_k[0].numpy(), expected_top_k) + + # log_probs should be shape `(batch_size, beam_size, max_predicted_length)`. + assert list(log_probs.size()) == [batch_size, beam_size] + np.testing.assert_allclose(log_probs[0].numpy(), expected_log_probs, rtol=1e-6) + + @pytest.mark.parametrize("step_function", [take_step_with_timestep, take_step_no_timestep]) + def test_search(self, step_function): + self._check_results(take_step=step_function) + + def test_finished_state(self): + state = {} + state["foo"] = torch.tensor([[1, 0, 1], [2, 0, 1], [0, 0, 1], [1, 1, 1], [0, 0, 0]]) + # shape: (batch_size, 3) + + expected_finished_state = {} + expected_finished_state["foo"] = np.array( + [ + [1, 0, 1], + [1, 0, 1], + [1, 0, 1], + [2, 0, 1], + [2, 0, 1], + [2, 0, 1], + [0, 0, 1], + [0, 0, 1], + [0, 0, 1], + [1, 1, 1], + [1, 1, 1], + [1, 1, 1], + [0, 0, 0], + [0, 0, 0], + [0, 0, 0], + ] + ) + # shape: (batch_size x beam_size, 3) + + self._check_results(state=state) + + # check finished state. 
+ for key, array in expected_finished_state.items(): + np.testing.assert_allclose(state[key].numpy(), array) + + def test_batch_size_of_one(self): + self._check_results(batch_size=1) + + def test_greedy_search(self): + beam_search = BeamSearch(self.end_index, beam_size=1) + expected_top_k = np.array([[1, 2, 3, 4, 5]]) + expected_log_probs = np.log(np.array([0.4])) + self._check_results( + expected_top_k=expected_top_k, + expected_log_probs=expected_log_probs, + beam_search=beam_search, + ) + + def test_single_step(self): + self.beam_search.max_steps = 1 + expected_top_k = np.array([[1], [2], [3]]) + expected_log_probs = np.log(np.array([0.4, 0.3, 0.2])) + self._check_results( + expected_top_k=expected_top_k, + expected_log_probs=expected_log_probs, + ) + + def test_early_stopping(self): + """ + Checks case where beam search will reach `max_steps` before finding end tokens. + """ + beam_search = BeamSearch(self.end_index, beam_size=3, max_steps=3) + expected_top_k = np.array([[1, 2, 3], [2, 3, 4], [3, 4, 5]]) + expected_log_probs = np.log(np.array([0.4, 0.3, 0.2])) + self._check_results( + expected_top_k=expected_top_k, + expected_log_probs=expected_log_probs, + beam_search=beam_search, + ) + + def test_take_short_sequence_step(self): + """ + Tests to ensure the top-k from the short_sequence_transition_probabilities + transition matrix is expected + """ + self.beam_search.beam_size = 5 + expected_top_k = np.array( + [[5, 5, 5, 5, 5], [1, 5, 5, 5, 5], [1, 2, 5, 5, 5], [1, 2, 3, 5, 5], [1, 2, 3, 4, 5]] + ) + expected_log_probs = np.log(np.array([0.9, 0.09, 0.009, 0.0009, 0.0001])) + self._check_results( + expected_top_k=expected_top_k, + expected_log_probs=expected_log_probs, + take_step=take_short_sequence_step, + ) + + def test_min_steps(self): + """ + Tests to ensure all output sequences are greater than a specified minimum length. + It uses the `take_short_sequence_step` step function, which favors shorter sequences. + See `test_take_short_sequence_step`. 
+ """ + self.beam_search.beam_size = 1 + + # An empty sequence is allowed under this step function + self.beam_search.min_steps = 0 + expected_top_k = np.array([[5]]) + expected_log_probs = np.log(np.array([0.9])) + with pytest.warns(RuntimeWarning, match="Empty sequences predicted"): + self._check_results( + expected_top_k=expected_top_k, + expected_log_probs=expected_log_probs, + take_step=take_short_sequence_step, + ) + + self.beam_search.min_steps = 1 + expected_top_k = np.array([[1, 5]]) + expected_log_probs = np.log(np.array([0.09])) + self._check_results( + expected_top_k=expected_top_k, + expected_log_probs=expected_log_probs, + take_step=take_short_sequence_step, + ) + + self.beam_search.min_steps = 2 + expected_top_k = np.array([[1, 2, 5]]) + expected_log_probs = np.log(np.array([0.009])) + self._check_results( + expected_top_k=expected_top_k, + expected_log_probs=expected_log_probs, + take_step=take_short_sequence_step, + ) + + self.beam_search.beam_size = 3 + self.beam_search.min_steps = 2 + expected_top_k = np.array([[1, 2, 5, 5, 5], [1, 2, 3, 5, 5], [1, 2, 3, 4, 5]]) + expected_log_probs = np.log(np.array([0.009, 0.0009, 0.0001])) + self._check_results( + expected_top_k=expected_top_k, + expected_log_probs=expected_log_probs, + take_step=take_short_sequence_step, + ) + + def test_different_per_node_beam_size(self): + # per_node_beam_size = 1 + beam_search = BeamSearch(self.end_index, beam_size=3, per_node_beam_size=1) + self._check_results(beam_search=beam_search) + + # per_node_beam_size = 2 + beam_search = BeamSearch(self.end_index, beam_size=3, per_node_beam_size=2) + self._check_results(beam_search=beam_search) + + def test_catch_bad_config(self): + """ + If `per_node_beam_size` (which defaults to `beam_size`) is larger than + the size of the target vocabulary, `BeamSearch.search` should raise + a ValueError. + """ + beam_search = BeamSearch(self.end_index, beam_size=20) + with pytest.raises(ValueError): + self._check_results(beam_search=beam_search) + + def test_warn_for_bad_log_probs(self): + # The only valid next step from the initial predictions is the end index. + # But with a beam size of 3, the call to `topk` to find the 3 most likely + # next beams will result in 2 new beams that are invalid, in that have probability of 0. + # The beam search should warn us of this. + initial_predictions = torch.LongTensor([self.end_index - 1, self.end_index - 1]) + with pytest.warns(RuntimeWarning, match="Negligible log probabilities"): + self.beam_search.search(initial_predictions, {}, take_step_no_timestep) + + def test_empty_sequences(self): + initial_predictions = torch.LongTensor([self.end_index - 1, self.end_index - 1]) + beam_search = BeamSearch(self.end_index, beam_size=1) + with pytest.warns(RuntimeWarning, match="Empty sequences predicted"): + predictions, log_probs = beam_search.search(initial_predictions, {}, take_step_with_timestep) + # predictions hould have shape `(batch_size, beam_size, max_predicted_length)`. + assert list(predictions.size()) == [2, 1, 1] + # log probs hould have shape `(batch_size, beam_size)`. 
+        assert list(log_probs.size()) == [2, 1]
+        assert (predictions == self.end_index).all()
+        assert (log_probs == 0).all()
+
+    def test_top_p_search(self):
+        initial_predictions = torch.tensor([0] * 5)
+        beam_size = 3
+        take_step = take_step_with_timestep
+        p_sampler = TopPSampler(p=0.8)
+
+        top_p, log_probs = BeamSearch(self.end_index, beam_size=beam_size, max_steps=10, sampler=p_sampler).search(
+            initial_predictions, {}, take_step
+        )
+
+        beam_size = beam_size or 1
+        batch_size = 5
+
+        # `top_p` should have shape `(batch_size, beam_size, max_predicted_length)`.
+        assert list(top_p.size())[:-1] == [batch_size, beam_size]
+
+        assert ((0 <= top_p) & (top_p <= 5)).all()
+
+        # `log_probs` should have shape `(batch_size, beam_size)`.
+        assert list(log_probs.size()) == [batch_size, beam_size]
+
+    @pytest.mark.parametrize("p_val", [-1.0, 1.2, 1.1, float("inf")])
+    def test_p_val(self, p_val):
+        with pytest.raises(ValueError):
+            initial_predictions = torch.tensor([0] * 5)
+            take_step = take_step_with_timestep
+            beam_size = 3
+            p_sampler = TopPSampler(p=p_val, with_replacement=True)
+
+            BeamSearch(self.end_index, beam_size=beam_size, max_steps=10, sampler=p_sampler).search(
+                initial_predictions, {}, take_step
+            )
+
+    def test_top_k_search(self):
+        initial_predictions = torch.tensor([0] * 5)
+        beam_size = 3
+        take_step = take_step_with_timestep
+        k_sampler = TopKSampler(k=5, with_replacement=True)
+
+        top_k, log_probs = BeamSearch(self.end_index, beam_size=beam_size, max_steps=10, sampler=k_sampler).search(
+            initial_predictions, {}, take_step
+        )
+
+        beam_size = beam_size or 1
+        batch_size = 5
+
+        # `top_k` should have shape `(batch_size, beam_size, max_predicted_length)`.
+        assert list(top_k.size())[:-1] == [batch_size, beam_size]
+
+        assert ((0 <= top_k) & (top_k <= 5)).all()
+
+        # `log_probs` should have shape `(batch_size, beam_size)`.
+        assert list(log_probs.size()) == [batch_size, beam_size]
+
+    @pytest.mark.parametrize("k_val", [-1, 0])
+    def test_k_val(self, k_val):
+        with pytest.raises(ValueError):
+            initial_predictions = torch.tensor([0] * 5)
+            take_step = take_step_with_timestep
+            beam_size = 3
+            k_sampler = TopKSampler(k=k_val, with_replacement=True)
+
+            BeamSearch(self.end_index, beam_size=beam_size, max_steps=10, sampler=k_sampler).search(
+                initial_predictions, {}, take_step
+            )
+
+    def test_stochastic_beam_search(self):
+        initial_predictions = torch.tensor([0] * 5)
+        batch_size = 5
+        beam_size = 3
+        take_step = take_step_with_timestep
+
+        gumbel_sampler = GumbelSampler()
+
+        top_k, log_probs = BeamSearch(
+            self.end_index, beam_size=beam_size, max_steps=10, sampler=gumbel_sampler
+        ).search(initial_predictions, {}, take_step)
+
+        # `top_k` should have shape `(batch_size, beam_size, max_predicted_length)`.
+        assert list(top_k.size())[:-1] == [batch_size, beam_size]
+
+        assert ((0 <= top_k) & (top_k <= 5)).all()
+
+        # `log_probs` should have shape `(batch_size, beam_size)`.
+        assert list(log_probs.size()) == [batch_size, beam_size]
+
+        # Check to make sure that once the end index is predicted, all subsequent tokens
+        # must be the end index.
+        for batch in top_k:
+            for beam in batch:
+                reached_end = False
+                for token in beam:
+                    if token == self.end_index:
+                        reached_end = True
+                    if reached_end:
+                        assert token == self.end_index
+
+    def test_multinomial_sampler(self):
+        sampler = MultinomialSampler(temperature=0.9)
+
+        probabilities, classes, _ = sampler.sample_nodes(log_probabilities, 3, {})
+
+        assert probabilities.size() == classes.size()
+        assert classes.size() == (2, 3)
+        assert all([x < 4 for x in classes[0]])
+        assert all([x > 1 for x in classes[1]])
+
+    def test_top_k_sampler(self):
+        sampler = TopKSampler(k=3, temperature=0.9)
+
+        probabilities, classes, _ = sampler.sample_nodes(log_probabilities, 3, {})
+
+        assert probabilities.size() == classes.size()
+        assert classes.size() == (2, 3)
+
+        assert all([x > 0 and x < 4 for x in classes[0]])
+        assert all([x > 1 and x < 5 for x in classes[1]])
+
+    def test_top_p_sampler(self):
+        sampler = TopPSampler(p=0.8, temperature=0.9)
+
+        probabilities, classes, _ = sampler.sample_nodes(log_probabilities, 3, {})
+
+        assert probabilities.size() == classes.size()
+        assert classes.size() == (2, 3)
+
+        assert all([x > 0 and x < 4 for x in classes[0]])
+        assert all([x > 1 and x < 5 for x in classes[1]])
+
+        # Make sure the filtered classes include the first class that exceeds p.
+        sampler = TopPSampler(p=0.7, temperature=1.0)
+
+        probabilities, classes, _ = sampler.sample_nodes(log_probabilities, 2, {})
+
+        assert all([x == 2 or x == 3 or x == 1 for x in classes[0]])
+        assert all([x == 2 or x == 3 for x in classes[1]])
+
+    def test_gumbel_sampler(self):
+        sampler = GumbelSampler()
+        num_classes = len(log_probabilities[0])
+        sampler_state = sampler.init_state(log_probabilities, batch_size=2, num_classes=num_classes)
+
+        log_probs, indices, _ = sampler.sample_beams(log_probabilities, 3, sampler_state)
+
+        assert log_probs.size() == indices.size()
+        assert indices.size() == (2, 3)
+
+        # Make sure the probabilities are sorted.
+        _, sorted_indices = log_probs.sort(dim=-1, descending=True)
+        assert (sorted_indices == torch.arange(3).unsqueeze(0)).all()
+
+        assert all([x >= 0 and x < 4 for x in indices[0]])
+        assert all([x > 1 and x <= 5 for x in indices[1]])
+
+    def test_length_normalized_sequence_log_prob_scorer(self):
+        """
+        Ensures the sequences are normalized by the correct values. The end token is
+        included in the length. The start token is not.
+ """ + self.beam_search.final_sequence_scorer = LengthNormalizedSequenceLogProbabilityScorer() + expected_log_probs = np.log(np.array([0.4, 0.3, 0.2])) + length_normalization = np.array([5, 4, 3]) + expected_scores = expected_log_probs / length_normalization + self._check_results(expected_log_probs=expected_scores) + + # Introduce a length penalty + length_penalty = 2.0 + self.beam_search.final_sequence_scorer = LengthNormalizedSequenceLogProbabilityScorer( + length_penalty=length_penalty + ) + expected_log_probs = np.log(np.array([0.4, 0.3, 0.2])) + length_normalization = np.array([5**length_penalty, 4**length_penalty, 3**length_penalty]) + expected_scores = expected_log_probs / length_normalization + self._check_results(expected_log_probs=expected_scores) + + # Pick a length penalty so extreme that the order of the sequences is reversed + length_penalty = -2.0 + self.beam_search.final_sequence_scorer = LengthNormalizedSequenceLogProbabilityScorer( + length_penalty=length_penalty + ) + expected_top_k = np.array([[3, 4, 5, 5, 5], [2, 3, 4, 5, 5], [1, 2, 3, 4, 5]]) + expected_log_probs = np.log(np.array([0.2, 0.3, 0.4])) + length_normalization = np.array([3**length_penalty, 4**length_penalty, 5**length_penalty]) + expected_scores = expected_log_probs / length_normalization + self._check_results(expected_top_k=expected_top_k, expected_log_probs=expected_scores) + + # Here, we set the max_steps = 4. This prevents the first sequence from finishing, + # so its length does not include the end token, whereas the other sequences do. + length_penalty = 2.0 + self.beam_search.max_steps = 4 + self.beam_search.final_sequence_scorer = LengthNormalizedSequenceLogProbabilityScorer( + length_penalty=length_penalty + ) + expected_top_k = np.array([[1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 5]]) + expected_log_probs = np.log(np.array([0.4, 0.3, 0.2])) + length_normalization = np.array([4**length_penalty, 4**length_penalty, 3**length_penalty]) + expected_scores = expected_log_probs / length_normalization + self._check_results(expected_top_k=expected_top_k, expected_log_probs=expected_scores) + + def test_repeated_ngram_blocking_constraint_init_state(self): + ngram_size = 3 + batch_size = 2 + constraint = RepeatedNGramBlockingConstraint(ngram_size) + + state = constraint.init_state(batch_size) + assert len(state) == batch_size + for beam_states in state: + assert len(beam_states) == 1 + beam_state = beam_states[0] + assert len(beam_state.keys()) == 2 + assert len(beam_state["current_prefix"]) == 0 + assert len(beam_state["seen_ngrams"]) == 0 + + def test_repeated_ngram_blocking_constraint_apply(self): + ngram_size = 3 + batch_size = 2 + beam_size = 2 + num_classes = 10 + constraint = RepeatedNGramBlockingConstraint(ngram_size) + + state = [ + [ + {"current_prefix": [0, 1], "seen_ngrams": {}}, + {"current_prefix": [2, 3], "seen_ngrams": {(2, 3): [4]}}, + ], + [ + {"current_prefix": [4, 5], "seen_ngrams": {(8, 9): []}}, + {"current_prefix": [6, 7], "seen_ngrams": {(6, 7): [0, 1, 2]}}, + ], + ] + log_probabilities = torch.rand(batch_size, beam_size, num_classes) + constraint.apply(state, log_probabilities) # type: ignore + + disallowed_locations = torch.nonzero( + log_probabilities == torch.finfo(log_probabilities.dtype).min + ).tolist() + assert len(disallowed_locations) == 4 + assert [0, 1, 4] in disallowed_locations + assert [1, 1, 0] in disallowed_locations + assert [1, 1, 1] in disallowed_locations + assert [1, 1, 2] in disallowed_locations + + def test_repeated_ngram_blocking_constraint_update_state(self): + 
+        ngram_size = 3
+        constraint = RepeatedNGramBlockingConstraint(ngram_size)
+
+        # We will have [2, 3] -> {5, 6} from batch index 0 and [4, 5] -> {0} and [6, 7] -> {3}
+        # from batch index 1.
+        state = [
+            [
+                {"current_prefix": [0, 1], "seen_ngrams": {}},
+                {"current_prefix": [2, 3], "seen_ngrams": {(2, 3): [4]}},
+            ],
+            [
+                {"current_prefix": [4, 5], "seen_ngrams": {(8, 9): []}},
+                {"current_prefix": [6, 7], "seen_ngrams": {(6, 7): [0, 1, 2]}},
+            ],
+        ]
+        predictions = torch.LongTensor([[5, 6], [0, 3]])
+        backpointers = torch.LongTensor([[1, 1], [0, 1]])
+
+        expected_state = [
+            [
+                {"current_prefix": [3, 5], "seen_ngrams": {(2, 3): [4, 5]}},
+                {"current_prefix": [3, 6], "seen_ngrams": {(2, 3): [4, 6]}},
+            ],
+            [
+                {"current_prefix": [5, 0], "seen_ngrams": {(8, 9): [], (4, 5): [0]}},
+                {"current_prefix": [7, 3], "seen_ngrams": {(6, 7): [0, 1, 2, 3]}},
+            ],
+        ]
+        updated_state = constraint.update_state(state, predictions, backpointers)  # type: ignore
+        assert updated_state == expected_state
+
+    def test_take_repeated_ngram_step(self):
+        """
+        Ensures the top-k given by the `repeated_ngram_transition_probabilities_0`
+        transition matrix is as expected. The transitions are:
+
+        - p(1|start) = 1.0
+        - p(2|1) = 0.4
+        - p(3|1) = 0.6
+        - p(end|1) = 1e-9
+        - p(3|2) = 1.0
+        - p(end|2) = 1e-9
+        - p(1|3) = 1.0
+        - p(end|3) = 1e-9
+
+        The probabilities don't add up to 1 because of the 1e-9 transitions to end. That doesn't
+        really matter. Each state just needs some very small transition probability to the end
+        state, to ensure it's possible to reach the end state from there and that it isn't
+        selected by beam search without a constraint.
+
+        Below is the beam search trace for beam size 2. Any sequence below the
+        line is not selected by beam search. The number that comes before the sequence
+        is the probability of the sequence.
+ + Step 1 + 1.0: [1] + + Step 2 + 0.6: [1, 3] + 0.4: [1, 2] + ----- + 1e-9: [1, 2, end] + + Step 3 + 0.6: [1, 3, 1] + 0.4: [1, 2, 3] + ----- + 0.6 * 1e-9: [1, 3, end] + 0.4 * 1e-9: [1, 2, end] + + Step 4 + 0.4: [1, 2, 3, 1] + 0.36: [1, 3, 1, 3] + ----- + 0.24: [1, 3, 1, 2] + 0.6 * 1e-9: [1, 3, 1, end] + 0.4 * 1e-9: [1, 2, 3, end] + + Step 5 + 0.36: [1, 3, 1, 3, 1] + 0.24: [1, 2, 3, 1, 3] + ----- + 0.16: [1, 2, 3, 1, 2] + 0.4 * 1e-9: [1, 2, 3, 1, end] + 0.36 * 1e-9: [1, 3, 1, 3, end] + """ + step_function = get_step_function(repeated_ngram_transition_probabilities_0) + self.beam_search.beam_size = 2 + self.beam_search.max_steps = 5 + expected_top_k = np.array([[1, 3, 1, 3, 1], [1, 2, 3, 1, 3]]) + expected_log_probs = np.log(np.array([0.36, 0.24])) + self._check_results( + expected_top_k=expected_top_k, + expected_log_probs=expected_log_probs, + take_step=step_function, + ) + + def test_repeated_ngram_blocking_end_to_end_unigrams(self): + step_function = get_step_function(repeated_ngram_transition_probabilities_0) + self.beam_search.beam_size = 2 + + # Unigrams: On step 3, [1, 3, 1] will be blocked and [1, 3, end] will take its place + self.beam_search.max_steps = 3 + self.beam_search.constraints = [RepeatedNGramBlockingConstraint(ngram_size=1)] # type: ignore + expected_top_k = np.array([[1, 2, 3], [1, 3, 5]]) + expected_log_probs = np.log(np.array([0.4, 0.6 * 1e-9])) + self._check_results( + expected_top_k=expected_top_k, + expected_log_probs=expected_log_probs, + take_step=step_function, + ) + + step_function = get_step_function(repeated_ngram_transition_probabilities_1) + self.beam_search.max_steps = 5 + expected_top_k = np.array([[1, 2, 3, 4, 5], [1, 2, 4, 3, 5]]) + expected_log_probs = np.log(np.array([0.4 * 0.3 * 0.3 * 0.2 * 0.1, 0.4 * 0.3 * 0.2 * 0.3 * 0.1])) + self._check_results( + expected_top_k=expected_top_k, + expected_log_probs=expected_log_probs, + take_step=step_function, + ) + + def test_repeated_ngram_blocking_end_to_end_bigrams(self): + step_function = get_step_function(repeated_ngram_transition_probabilities_0) + self.beam_search.beam_size = 2 + + # Bigrams: On step 4, [1, 3, 1, 3] will be blocked and [1, 3, 1, 2] will take its place + self.beam_search.max_steps = 4 + self.beam_search.constraints = [RepeatedNGramBlockingConstraint(ngram_size=2)] # type: ignore + expected_top_k = np.array([[1, 2, 3, 1], [1, 3, 1, 2]]) + expected_log_probs = np.log(np.array([0.4, 0.24])) + self._check_results( + expected_top_k=expected_top_k, + expected_log_probs=expected_log_probs, + take_step=step_function, + ) + + def test_repeated_ngram_blocking_end_to_end_trigrams(self): + step_function = get_step_function(repeated_ngram_transition_probabilities_0) + self.beam_search.beam_size = 2 + + # Trigrams: On step 5, [1, 3, 1, 3, 1] will be blocked and [1, 2, 3, 1, 2] will take its place + self.beam_search.max_steps = 5 + self.beam_search.constraints = [RepeatedNGramBlockingConstraint(ngram_size=3)] # type: ignore + expected_top_k = np.array([[1, 2, 3, 1, 3], [1, 2, 3, 1, 2]]) + expected_log_probs = np.log(np.array([0.24, 0.16])) + self._check_results( + expected_top_k=expected_top_k, + expected_log_probs=expected_log_probs, + take_step=step_function, + ) + + def test_repeated_ngram_blocking_end_indices(self): + """ + Ensures that the ngram blocking does not mess up when one sequence is shorter + than another, which would result in repeated "end" symbols. 
+ """ + # We block unigrams, but 5 (the end symbol) is repeated and it does not mess + # up the sequence's probability + step_function = get_step_function(repeated_ngram_transition_probabilities_0) + self.beam_search.beam_size = 2 + self.beam_search.constraints = [RepeatedNGramBlockingConstraint(ngram_size=1)] # type: ignore + expected_top_k = np.array([[1, 3, 5, 5], [1, 2, 3, 5]]) + expected_log_probs = np.log(np.array([0.6 * 1e-9, 0.4 * 1e-9])) + self._check_results( + expected_top_k=expected_top_k, + expected_log_probs=expected_log_probs, + take_step=step_function, + ) diff --git a/tests/model_test.py b/tests/model_test.py index f834fcbd5..a952adb01 100644 --- a/tests/model_test.py +++ b/tests/model_test.py @@ -4,6 +4,7 @@ from dolma import BlockType, Dolma, ModelConfig, Tokenizer, TrainConfig from dolma.composer import build_optimizer +from dolma.config import PaddingDirection from dolma.data import DataCollator @@ -107,7 +108,7 @@ def test_forward( flash_attn: bool, block_type: BlockType, cuda: bool, - dtype, + dtype: torch.dtype, ): torch.manual_seed(0) torch.use_deterministic_algorithms(True) @@ -255,3 +256,70 @@ def test_backward( def test_build_optimizer(model_config: ModelConfig): build_optimizer(Dolma(model_config)) + + +@pytest.mark.parametrize( + "cuda, dtype", + [ + pytest.param(False, torch.float32, id="cpu-fp32"), + pytest.param( + True, + torch.float32, + id="cuda-fp32", + marks=( + pytest.mark.gpu, + pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Requires CUDA device"), + ), + ), + # TODO: with an uninitialized model like we have here we'll end up with nan's + # when we use half-precision. So eventually we should use a trained model in these tests. + # pytest.param(False, torch.bfloat16, id="cpu-bf16"), + ], +) +def test_generate( + train_config: TrainConfig, + tokenizer: Tokenizer, + cuda: bool, + dtype: torch.dtype, +): + torch.manual_seed(0) + torch.use_deterministic_algorithms(True) + + # Should always pad left when generating. + train_config.data.pad_direction = PaddingDirection.left + # We also need to use a relative positional embedding so that the + # padding doesn't affect the results. + train_config.model.alibi = True + + if cuda: + train_config.model.init_device = "cuda" + else: + train_config.model.init_device = "cpu" + use_amp = dtype in {torch.float16, torch.bfloat16} + + model = Dolma(train_config.model).eval() + + input1 = tokenizer.encode("My name is DOLMA! ", add_special_tokens=False) + input2 = tokenizer.encode("I'm a delightful large open language model :) ", add_special_tokens=False) + batch_inputs = DataCollator.from_train_config(train_config)( + [ # type: ignore + {"input_ids": input1, "attention_mask": [1.0] * len(input1)}, + {"input_ids": input2, "attention_mask": [1.0] * len(input2)}, + ] + ) + batch_inputs = { # type: ignore + k: v.to(device=train_config.device) if isinstance(v, torch.Tensor) else v for k, v in batch_inputs.items() + } + beam_search_kwargs = dict(beam_size=3, max_steps=5) + + with torch.inference_mode(): + with torch.autocast( + device_type="cuda" if cuda else "cpu", enabled=use_amp, dtype=None if not use_amp else dtype + ): + output1 = model.generate( + torch.tensor(input1, device=train_config.device).unsqueeze(0), # type: ignore + **beam_search_kwargs, + ) + batch_output = model.generate(**{**batch_inputs, **beam_search_kwargs}) + + torch.testing.assert_close(output1.scores[0], batch_output.scores[0])
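For reviewers unfamiliar with the `dolma.beam_search` API that these tests exercise, here is a minimal, hypothetical sketch of driving `BeamSearch.search()` directly with a toy step function, mirroring the pattern the tests above use. The transition matrix, end index, and `take_step` function below are invented for illustration and are not part of this diff.

```python
import torch

from dolma.beam_search import BeamSearch

# Toy vocabulary of 6 classes where class 5 acts as the "end" symbol.
# This transition matrix is made up purely for illustration.
end_index = 5
transition_probabilities = torch.tensor(
    [
        [0.0, 0.4, 0.3, 0.2, 0.1, 0.0],  # from class 0
        [0.0, 0.0, 1.0, 0.0, 0.0, 0.0],  # from class 1
        [0.0, 0.0, 0.0, 1.0, 0.0, 0.0],  # from class 2
        [0.0, 0.0, 0.0, 0.0, 1.0, 0.0],  # from class 3
        [0.0, 0.0, 0.0, 0.0, 0.0, 1.0],  # from class 4
        [0.0, 0.0, 0.0, 0.0, 0.0, 1.0],  # from class 5 (end)
    ]
)


def take_step(last_predictions, state, timestep):
    # Look up the next-token distribution for each sequence on the beam and
    # return log probabilities along with the (unchanged) state dict.
    log_probs = torch.log(transition_probabilities[last_predictions])
    return log_probs, state


# A batch of two sequences, both starting from class 0.
initial_predictions = torch.tensor([0, 0])
beam_search = BeamSearch(end_index, beam_size=3, max_steps=10)
predictions, log_probs = beam_search.search(initial_predictions, {}, take_step)

# `predictions` has shape (batch_size, beam_size, num_steps);
# `log_probs` has shape (batch_size, beam_size).
print(predictions.shape, log_probs.shape)
```

Samplers, constraints, and final sequence scorers layer on top of this same `search()` call, either through the constructor (e.g. `sampler=...`) or by assigning attributes the way the tests above do.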