Skip to content

Commit

Permalink
Refactor collate_fn
Browse files Browse the repository at this point in the history
  • Loading branch information
erfanzar committed Jan 26, 2025
1 parent 79c1cc2 commit 218aaf9
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
14 changes: 7 additions & 7 deletions easydel/trainers/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def create_collect_function(
def collate_fn(batch):
results = {}
for key in batch[0].keys():
if self.model.loss_function.__name__ == "ForCausalLMLoss":
if self.model.__class__.__name__ == "ForCausalLMLoss":
if truncation_mode == "keep_end":
corrected_sequence = [
jnp.array(f[key])[..., -max_sequence_length:] for f in batch
Expand All @@ -69,13 +69,10 @@ def collate_fn(batch):
corrected_sequence = [
jnp.array(f[key])[..., :max_sequence_length] for f in batch
]
results[key] = jnp.stack(corrected_sequence)
else:
corrected_sequence = [jnp.array(f[key]) for f in batch]

results[key] = jnp.stack(corrected_sequence).reshape(
-1,
corrected_sequence[0].shape[-1],
)
results[key] = jnp.stack(corrected_sequence)
return results

return collate_fn
Expand Down Expand Up @@ -359,7 +356,9 @@ def _execute_eval_step(self, state, batch) -> LossMetrics:
return metrics

def _execute_train_step(
self, state, batch
self,
state,
batch,
) -> tp.Tuple[EasyDeLState, LossMetrics, Exception]:
"""Execute a single training step."""
if self.pruning_module is not None:
Expand All @@ -369,6 +368,7 @@ def _execute_train_step(
state.opt_state,
)
)
metrics = LossMetrics()
try:
state, metrics = jax.block_until_ready(
self.sharded_training_step_function(state, batch)
Expand Down
4 changes: 2 additions & 2 deletions tests/trainer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ def create_dataset():

dataset = datasets.load_dataset(
"PowerInfer/QWQ-LONGCOT-500K",
split="train",
streaming=True,
split="train[:5%]",
streaming=False,
)

def to_ids(sample):
Expand Down

0 comments on commit 218aaf9

Please sign in to comment.