Add Goodput & Badput recording and monitoring support. #783
base: main
Conversation
Thanks!
I don't have access to this link. Can you provide an example that we can take a look at?
axlearn/cloud/gcp/measurement.py
# Instantiate ml-goodput-measurement's GoodputMonitor
# to asynchronously calculate goodput and badput at
# the upload_interval and upload to the specified
# tensorboard directory.
Convert this into a docstring for the method? BTW, we use 100 line length.
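For instance, a possible shape for that docstring, kept within the 100-character limit (the method name and signature here are illustrative, not necessarily the PR's):

```python
def start_monitoring(self, *args, **kwargs):
    """Starts Goodput monitoring of the job.

    Instantiates ml-goodput-measurement's GoodputMonitor, which asynchronously computes
    goodput and badput at `upload_interval` and uploads the results to the configured
    TensorBoard directory.
    """
```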
fv.mark_as_parsed()

recorder = GoodputRecorder.from_flags(fv)
recorder._monitor = None  # Ensure _monitor is initially None
Suggested change:
recorder._monitor = None  # Ensure _monitor is initially None
self.assertIsNone(recorder._monitor)  # Ensure _monitor is initially None
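Spelled out, the test around this suggestion might look roughly like the following sketch; `define_flags` is an assumed helper for registering the recorder's flags and may not match the PR's exact setup:

```python
from absl import flags

def test_from_flags_monitor_starts_unset(self):
    fv = flags.FlagValues()
    GoodputRecorder.define_flags(fv)  # Assumption: a helper that registers the recorder's flags.
    fv.mark_as_parsed()

    recorder = GoodputRecorder.from_flags(fv)
    # The monitor should not be created eagerly; it starts as None until monitoring begins.
    self.assertIsNone(recorder._monitor)
```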
Please feel free to "re-request review" when ready. Thanks!
@@ -324,6 +325,7 @@ def __init__(
            model=self.model,
            model_param_partition_specs=model_param_partition_specs,
        )
        self._maybe_record_event(measurement.Event.END_ACCELERATOR_INIT)
Can you clarify what accelerator init is supposed to capture? E.g., would utils_spmd, where we call jax distributed init, be more appropriate?
This is supposed to capture device-related initialization such as device scanning, mesh initialization, device reinit/reset, security setup, initialization of pre-mapped buffers, etc. You are right, jax distributed init should be included here.
I would lean on your team to update/re-position the record calls to the locations that seem the best fit for this codebase.
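As a purely illustrative sketch of that re-positioning (the `START_ACCELERATOR_INIT` event, the module-level `record_event` helper, and the import path are assumptions, not confirmed by this PR):

```python
import jax
from axlearn.common import measurement


def setup_distributed(**kwargs):
    # Assumed names and placement: bracket jax distributed init, device scanning,
    # and mesh initialization with accelerator-init events.
    measurement.record_event(measurement.Event.START_ACCELERATOR_INIT)
    jax.distributed.initialize(**kwargs)
    # ... device scanning, mesh setup, pre-mapped buffer initialization ...
    measurement.record_event(measurement.Event.END_ACCELERATOR_INIT)
```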
Thanks @dipannita08! WDYT about removing the trainer.py changes for now, so that we can add them in a follow-up PR? This PR can focus on adding the scaffolding in measurement.py.
@@ -847,6 +850,7 @@ def _prepare_training(self, prng_key: Tensor) -> bool:
        with fs.open(os.path.join(cfg.dir, "model_analysis.txt"), "w") as f:
            f.write(model_analysis)

        self._maybe_record_event(measurement.Event.END_TRAINING_PREPARATION)
What's usually considered part of "training preparation"? Should we count the jit compilation below as a potentially substantial part of it? What about the checkpoint restoration above?
This would include the time spent on the creation of checkpoint managers, checkpoint loading, running mesh and model optimizers, etc.
JIT compilation is currently computed algorithmically from the entire timeline of events (the other recorded logs) and is included in the "program startup" badput, which is meant to measure the time spent on framework-specific function transformations (such as JAX tracing), compilation tasks, runtime initialization, etc.
For now, checkpoint restoration is not included in this bucket; the expectation is that the next version of the library (v0.0.5) will have more definitive recorder and calculator APIs for checkpoint save and restore badput.
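For illustration only, the bracketing described above might look like this inside `_prepare_training` (assuming a `START_TRAINING_PREPARATION` counterpart to the `END_TRAINING_PREPARATION` event shown in the diff; the helper name below is hypothetical):

```python
def _prepare_training(self, prng_key):
    self._maybe_record_event(measurement.Event.START_TRAINING_PREPARATION)
    # Checkpoint manager creation, checkpoint loading, mesh/model optimizer setup, etc.
    trainer_state = self._restore_or_init_state(prng_key)  # Hypothetical helper name.
    self._maybe_record_event(measurement.Event.END_TRAINING_PREPARATION)
    # JIT compilation is not bracketed explicitly here; the library attributes it to
    # "program startup" badput based on the timeline of the other recorded events.
    return trainer_state is not None
```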
@@ -883,6 +887,7 @@ def restore_checkpoint(self, restore_step: Optional[int] = None) -> Optional[int]:
        restore_input_iter = cfg.save_input_iterator
        try:
            # Try to restore with `input_iter`.
            self._maybe_record_event(measurement.Event.START_DATA_LOADING)
Hm, is data loading in relation to the input loading or the checkpoint or both? Here it seems to only capture the checkpoint restoration?
Data loading is in relation to the input data only. Checkpoint restore would be recorded, and its badput computed, separately with changes coming in the next version of the Goodput package. Please feel free to update the location of the record calls; specifically for data loading, I wasn't sure where the most appropriate place to put it would be.
Thanks for your help!
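One possible placement for the input-data events, sketched only as an illustration (the `END_DATA_LOADING` event and the loop/helper names here are assumptions, not the PR's final placement):

```python
for step in range(start_step, max_step):
    self._maybe_record_event(measurement.Event.START_DATA_LOADING)
    input_batch = next(input_iterator)  # Host-side input pipeline work only.
    self._maybe_record_event(measurement.Event.END_DATA_LOADING)
    self._run_step(input_batch)  # The compute step itself is not counted as data loading.
```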
Ack.
This change adds the following:
Tested: