diff --git a/apax/data/input_pipeline.py b/apax/data/input_pipeline.py index 19cf2d92..3e53fa33 100644 --- a/apax/data/input_pipeline.py +++ b/apax/data/input_pipeline.py @@ -252,7 +252,6 @@ def shuffle_and_batch(self) -> Iterator[jax.Array]: ) if self.n_jit_steps > 1: - print("JIT BATCH") ds = ds.batch(batch_size=self.n_jit_steps) ds = prefetch_to_single_device(ds.as_numpy_iterator(), 2)