Skip to content

Commit

Permalink
Refactored FSDP sampling and fixed bug when config.n_eval_model_samples < config.eval_batch_size.
Browse files Browse the repository at this point in the history
  • Loading branch information
Eric Mitchell committed Jul 12, 2023
1 parent 5815fab commit 68e6015
Showing 1 changed file with 18 additions and 11 deletions.
29 changes: 18 additions & 11 deletions trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from torch.distributed.fsdp.api import FullStateDictConfig, FullOptimStateDictConfig
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
import tensor_parallel as tp
import contextlib

from preference_datasets import get_batch_iterator
from utils import (
Expand Down Expand Up @@ -172,12 +173,18 @@ def __init__(self, policy: nn.Module, config: DictConfig, seed: int, run_dir: st
def get_batch_samples(self, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]:
"""Generate samples from the policy (and reference model, if doing DPO training) for the given batch of inputs."""

policy_output = self.policy.generate(
batch['prompt_input_ids'], attention_mask=batch['prompt_attention_mask'], max_length=self.config.max_length, do_sample=True, pad_token_id=self.tokenizer.pad_token_id)
if self.config.loss.name == 'dpo':
reference_output = self.reference_model.generate(
# FSDP generation according to https://github.com/pytorch/pytorch/issues/100069
ctx = lambda: (FSDP.summon_full_params(self.policy, writeback=False, recurse=False) if 'FSDP' in self.config.trainer else contextlib.nullcontext())
with ctx():
policy_output = self.policy.generate(
batch['prompt_input_ids'], attention_mask=batch['prompt_attention_mask'], max_length=self.config.max_length, do_sample=True, pad_token_id=self.tokenizer.pad_token_id)

if self.config.loss.name == 'dpo':
ctx = lambda: (FSDP.summon_full_params(self.reference_model, writeback=False, recurse=False) if 'FSDP' in self.config.trainer else contextlib.nullcontext())
with ctx():
reference_output = self.reference_model.generate(
batch['prompt_input_ids'], attention_mask=batch['prompt_attention_mask'], max_length=self.config.max_length, do_sample=True, pad_token_id=self.tokenizer.pad_token_id)

policy_output = pad_to_length(policy_output, self.config.max_length, self.tokenizer.pad_token_id)
policy_output = all_gather_if_needed(policy_output, self.rank, self.world_size)
policy_output_decoded = self.tokenizer.batch_decode(policy_output, skip_special_tokens=True)
Expand Down Expand Up @@ -285,15 +292,15 @@ def train(self):
all_eval_metrics[k].extend(v)

if self.config.sample_during_eval:
n_sample_batches = self.config.n_eval_model_samples // self.config.eval_batch_size
sample_batches = self.eval_batches[:n_sample_batches]
if self.config.n_eval_model_samples < self.config.eval_batch_size:
rank0_print(f'Warning: n_eval_model_samples ({self.config.n_eval_model_samples}) < eval_batch_size ({self.config.eval_batch_size}). Sampling from the first complete eval batch of prompts.')
sample_batches = self.eval_batches[:1]
else:
n_sample_batches = self.config.n_eval_model_samples // self.config.eval_batch_size
sample_batches = self.eval_batches[:n_sample_batches]
for eval_batch in (tqdm.tqdm(sample_batches, desc='Generating samples...') if self.rank == 0 else sample_batches):
local_eval_batch = slice_and_move_batch_for_device(eval_batch, self.rank, self.world_size, self.rank)
if 'FSDP' in self.config.trainer:
with FSDP.summon_full_params(self.policy, writeback=False, recurse=False):
policy_samples, reference_samples = self.get_batch_samples(local_eval_batch)
else:
policy_samples, reference_samples = self.get_batch_samples(local_eval_batch)
policy_samples, reference_samples = self.get_batch_samples(local_eval_batch)

all_policy_samples.extend(policy_samples)
all_reference_samples.extend(reference_samples)
Expand Down

0 comments on commit 68e6015

Please sign in to comment.