
Add Goodput & Badput recording and monitoring support. #783

Open · wants to merge 7 commits into base: main
Conversation

@dipannita08 commented on Oct 25, 2024

This change adds the following:

  • Upgrades to the latest ml-goodput-measurement library
  • Integrates badput recording into the GoodputRecorder in AxLearn
  • Builds a Goodput monitor to configure and visualize Goodput and Badput using Tensorboard
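For context, a minimal sketch of how the underlying ml-goodput-measurement monitor is typically configured and started is shown below. The job name, logger name, TensorBoard directory, and upload interval are placeholders, and the exact constructor arguments may differ across library versions; this is an illustration, not this PR's measurement.py code.

```python
# Illustrative only: configuring and starting a GoodputMonitor so that
# Goodput/Badput metrics are computed asynchronously and uploaded to TensorBoard.
from ml_goodput_measurement import monitoring

monitor = monitoring.GoodputMonitor(
    job_name="example-job",                    # placeholder job name
    logger_name="goodput_example-job",         # Cloud Logging logger shared with the recorder
    tensorboard_dir="gs://example-bucket/tb",  # placeholder upload destination
    upload_interval=30,                        # seconds between asynchronous uploads
    monitoring_enabled=True,
    include_badput_breakdown=True,             # also upload the Badput breakdown
)
monitor.start_goodput_uploader()  # runs in the background; the training loop is not blocked
```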

Tested:

(screenshot omitted)

@markblee (Contributor) left a comment:

Thanks!

axlearn/common/launch_trainer_main.py (outdated; thread resolved)
@Ethanlm (Contributor) commented on Oct 25, 2024

> Functional end to end testing using fuji-test and fuji-7b (example tensorboard instance)

I don't have access to this link. Can you provide an example that we can take a look at?

Comment on lines 79 to 82
# Instantiate ml-goodput-measurement's GoodputMonitor
# to asynchronously calculate goodput and badput at
# the upload_interval and upload to the specified
# tensorboard directory.
Contributor:

Convert this into a docstring for the method? BTW, we use a 100-character line length.
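For example, the comment could become something like the following; the method name is hypothetical, and only the docstring wording comes from the quoted comment:

```python
def _start_monitoring(self):
    """Instantiates ml-goodput-measurement's GoodputMonitor to asynchronously compute Goodput
    and Badput at the configured upload_interval and upload them to the specified TensorBoard
    directory.
    """
```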

axlearn/cloud/gcp/measurement.py (thread resolved)
fv.mark_as_parsed()

recorder = GoodputRecorder.from_flags(fv)
recorder._monitor = None # Ensure _monitor is initially None
Contributor:

Suggested change:
- recorder._monitor = None  # Ensure _monitor is initially None
+ self.assertIsNone(recorder._monitor)  # Ensure _monitor is initially None

axlearn/cloud/gcp/measurement.py (thread resolved)
axlearn/common/launch_trainer_main.py (thread resolved)
@markblee (Contributor) commented on Nov 7, 2024

Please feel free to "re-request review" when ready. Thanks!

axlearn/cloud/gcp/measurement.py (outdated; thread resolved)
axlearn/cloud/gcp/measurement.py (thread resolved)
@@ -324,6 +325,7 @@ def __init__(
model=self.model,
model_param_partition_specs=model_param_partition_specs,
)
self._maybe_record_event(measurement.Event.END_ACCELERATOR_INIT)
Contributor:

Can you clarify what accelerator init is supposed to capture? E.g., would utils_spmd where we call jax distributed init be more appropriate?

Author:

This is supposed to capture device-related initialization such as device scanning, mesh initialization, device reinit/reset, security setup, initialization of pre-mapped buffers, etc. You are right, JAX distributed init should be included here.

I would lean on your team to update/re-position the record calls to the locations that seem the best fit for this codebase.
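As a rough placement sketch (not this PR's exact diff): assuming the Event enum added here has a matching START_ACCELERATOR_INIT member and that the base Recorder in axlearn.common.measurement exposes a record(event) method, the bracketing discussed above could look like this, with the helper function and the calls it wraps being illustrative only.

```python
from typing import Optional

import jax

from axlearn.common import measurement


def setup_accelerators(recorder: Optional[measurement.Recorder]) -> None:
    """Illustrative only: brackets device/distributed setup with accelerator-init events."""
    if recorder is not None:
        recorder.record(measurement.Event.START_ACCELERATOR_INIT)
    # Distributed runtime startup and device discovery; mesh construction, device
    # reinit/reset, and pre-mapped buffer initialization would also fall in this bracket.
    jax.distributed.initialize()
    if recorder is not None:
        recorder.record(measurement.Event.END_ACCELERATOR_INIT)
```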

Contributor:

Thanks @dipannita08! WDYT about removing the trainer.py changes for now, so that we can add them in a follow-up PR? This PR can focus on adding the scaffolding in measurement.py.

@@ -847,6 +850,7 @@ def _prepare_training(self, prng_key: Tensor) -> bool:
with fs.open(os.path.join(cfg.dir, "model_analysis.txt"), "w") as f:
f.write(model_analysis)

self._maybe_record_event(measurement.Event.END_TRAINING_PREPARATION)
Contributor:

What's usually considered part of "training preparation"? Should we count the jit compilation below as a potentially substantial part of it? What about the checkpoint restoration above?

Author:

This would include the time spent on the creation of checkpoint managers, checkpoint loading, running mesh and model optimizers, etc.

The JIT compilation is currently computed algorithmically from the entire timeline of events (the other recorded logs) and is included in the "program startup" badput, which is meant to measure the time spent on framework-specific function transformations (such as JAX tracing), compilation tasks, runtime initialization, etc.

For now, checkpoint restoration is not included in this bucket; the expectation is that the next version of the library (v0.0.5) will have more definitive recorder and calculator APIs for checkpoint save and restore badput.
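A placement sketch consistent with the description above (assuming a matching START_TRAINING_PREPARATION event exists alongside the END event in the diff; the helper and the steps it wraps are illustrative, not this PR's exact changes):

```python
from typing import Optional

from axlearn.common import measurement


def prepare_training(recorder: Optional[measurement.Recorder]) -> None:
    """Illustrative only: brackets trainer setup with training-preparation events."""
    if recorder is not None:
        recorder.record(measurement.Event.START_TRAINING_PREPARATION)
    # Checkpoint-manager creation, checkpoint loading, mesh/model optimizer setup,
    # model analysis, etc. would run here. JIT compilation is not recorded explicitly;
    # the library attributes it to "program startup" badput from the overall event timeline.
    if recorder is not None:
        recorder.record(measurement.Event.END_TRAINING_PREPARATION)
```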

@@ -883,6 +887,7 @@ def restore_checkpoint(self, restore_step: Optional[int] = None) -> Optional[int
restore_input_iter = cfg.save_input_iterator
try:
# Try to restore with `input_iter`.
self._maybe_record_event(measurement.Event.START_DATA_LOADING)
Contributor:

Hm, is data loading in relation to the input loading, the checkpoint, or both? Here it seems to only capture the checkpoint restoration?

Author:

Data loading is in relation to the input data only. Checkpoint restore would be recorded, and its badput computed, separately, with changes coming in the next version of the Goodput package. Please feel free to update the location of the record calls; specifically for data loading, I wasn't sure where the most appropriate place to put it would be.

Thanks for your help!
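To illustrate the distinction (assuming a matching END_DATA_LOADING event; the helper below is a suggestion for where the events could live, not this PR's final placement): data-loading events would wrap fetching input batches, while checkpoint restore would get its own recording in a later library version.

```python
from typing import Iterator, Optional

from axlearn.common import measurement


def next_batch(input_iter: Iterator, recorder: Optional[measurement.Recorder]):
    """Illustrative only: records data-loading time around fetching one input batch."""
    if recorder is not None:
        recorder.record(measurement.Event.START_DATA_LOADING)
    batch = next(input_iter)  # host-side input pipeline work only, not checkpoint restore
    if recorder is not None:
        recorder.record(measurement.Event.END_DATA_LOADING)
    return batch
```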

@markblee (Contributor) left a comment:

Ack.

Labels: none yet
Projects: none yet
5 participants