step tracker #59

Closed
wants to merge 1 commit
27 changes: 20 additions & 7 deletions dolomite_engine/utils/step_tracker.py
@@ -1,22 +1,35 @@
+from enum import Enum
+
 from .parallel import ProcessGroupManager
 
 
 _MICRO_BATCH_SIZE: int | None = None
 _GRADIENT_ACCUMULATION_STEPS: int | None = None
+_SEQUENCE_LENGTH: int | None = None
+
+
+class StepTrackerMethod(Enum):
+    samples = "samples"
+    tokens = "tokens"
 
 
 class StepTracker:
-    def __init__(self, micro_batch_size: int, gradient_accumulation_steps: int) -> None:
-        global _MICRO_BATCH_SIZE, _GRADIENT_ACCUMULATION_STEPS
+    def __init__(self, micro_batch_size: int, gradient_accumulation_steps: int, sequence_length: int = None) -> None:
+        global _MICRO_BATCH_SIZE, _GRADIENT_ACCUMULATION_STEPS, _SEQUENCE_LENGTH
 
         _MICRO_BATCH_SIZE = micro_batch_size
         _GRADIENT_ACCUMULATION_STEPS = gradient_accumulation_steps
+        _SEQUENCE_LENGTH = sequence_length
 
     @staticmethod
-    def get_local_batch_size() -> int:
-        global _MICRO_BATCH_SIZE, _GRADIENT_ACCUMULATION_STEPS
-        return _MICRO_BATCH_SIZE * _GRADIENT_ACCUMULATION_STEPS
+    def get_local_batch_size(tracker_method: StepTrackerMethod) -> int:
+        local_batch_size = _MICRO_BATCH_SIZE * _GRADIENT_ACCUMULATION_STEPS
+
+        if tracker_method == StepTrackerMethod.tokens:
+            local_batch_size = local_batch_size * _SEQUENCE_LENGTH
+
+        return local_batch_size
 
     @staticmethod
-    def get_global_batch_size() -> int:
-        return StepTracker.get_local_batch_size() * ProcessGroupManager.get_data_parallel_world_size()
+    def get_global_batch_size(tracker_method: StepTrackerMethod) -> int:
+        return StepTracker.get_local_batch_size(tracker_method) * ProcessGroupManager.get_data_parallel_world_size()
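
For context, a minimal usage sketch of the changed API follows. The concrete values are illustrative only, the import path simply mirrors the file location dolomite_engine/utils/step_tracker.py, and the global variant additionally assumes that ProcessGroupManager's distributed process groups have already been initialized.

from dolomite_engine.utils.step_tracker import StepTracker, StepTrackerMethod

# Register the run's sizes once (values here are purely illustrative).
StepTracker(micro_batch_size=2, gradient_accumulation_steps=8, sequence_length=4096)

# Local batch size measured in samples: 2 * 8 = 16
samples_per_step = StepTracker.get_local_batch_size(StepTrackerMethod.samples)

# Local batch size measured in tokens: 2 * 8 * 4096 = 65536
tokens_per_step = StepTracker.get_local_batch_size(StepTrackerMethod.tokens)

# The global variant multiplies by the data-parallel world size, so it needs
# the distributed process groups to be set up first:
# global_tokens = StepTracker.get_global_batch_size(StepTrackerMethod.tokens)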