diff --git a/easydel/trainers/trainer/trainer.py b/easydel/trainers/trainer/trainer.py index f3ba9911..db1ed3b0 100644 --- a/easydel/trainers/trainer/trainer.py +++ b/easydel/trainers/trainer/trainer.py @@ -60,7 +60,7 @@ def create_collect_function( def collate_fn(batch): results = {} for key in batch[0].keys(): - if self.model.loss_function.__name__ == "ForCausalLMLoss": + if self.model.loss_function.__name__ == "ForCausalLMLoss": if truncation_mode == "keep_end": corrected_sequence = [ jnp.array(f[key])[..., -max_sequence_length:] for f in batch @@ -69,13 +69,10 @@ def collate_fn(batch): corrected_sequence = [ jnp.array(f[key])[..., :max_sequence_length] for f in batch ] + results[key] = jnp.stack(corrected_sequence) else: corrected_sequence = [jnp.array(f[key]) for f in batch] - - results[key] = jnp.stack(corrected_sequence).reshape( - -1, - corrected_sequence[0].shape[-1], - ) + results[key] = jnp.stack(corrected_sequence) return results return collate_fn @@ -359,7 +356,9 @@ def _execute_eval_step(self, state, batch) -> LossMetrics: return metrics def _execute_train_step( - self, state, batch + self, + state, + batch, ) -> tp.Tuple[EasyDeLState, LossMetrics, Exception]: """Execute a single training step.""" if self.pruning_module is not None: state = state.replace( graphstate=self.pruning_module.pre_forward_update( state.graphstate, state.opt_state, ) ) + metrics = LossMetrics() try: state, metrics = jax.block_until_ready( self.sharded_training_step_function(state, batch) diff --git a/tests/trainer_test.py b/tests/trainer_test.py index 69b669de..42411e47 100644 --- a/tests/trainer_test.py +++ b/tests/trainer_test.py @@ -45,8 +45,8 @@ def create_dataset(): dataset = datasets.load_dataset( "PowerInfer/QWQ-LONGCOT-500K", - split="train", - streaming=True, + split="train[:5%]", + streaming=False, ) def to_ids(sample):