From b074ed61cd8f393cd6437fa7f4fca5ffb93dfcb8 Mon Sep 17 00:00:00 2001
From: Pau Riba
Date: Mon, 14 Aug 2023 00:20:02 +0200
Subject: [PATCH] Make convert_time a State method

---
 composer/core/state.py      |  49 ++++++++++++
 composer/optim/scheduler.py |  55 ++++----------
 tests/test_state.py         | 148 +++++++++++++++++++++++++++++++++-
 3 files changed, 212 insertions(+), 40 deletions(-)

diff --git a/composer/core/state.py b/composer/core/state.py
index a2d59b5e57..5f3a7ad409 100644
--- a/composer/core/state.py
+++ b/composer/core/state.py
@@ -651,6 +651,55 @@ def get_elapsed_duration(self) -> Optional[Time[float]]:
             return None
         return self.timestamp.get(self.max_duration.unit) / self.max_duration
 
+    def convert_time(self, time: Union[str, Time[int], Time[float]], ssr: float = 1.0) -> Time[int]:
+        """Converts ``time`` into a fine-grained :class:`.TimeUnit` according to the current state.
+
+        ``TimeUnit.BATCH``, ``TimeUnit.SAMPLE``, and ``TimeUnit.TOKEN`` are returned unchanged.
+        ``TimeUnit.EPOCH`` is converted to ``TimeUnit.BATCH``. ``TimeUnit.DURATION`` is converted
+        to ``TimeUnit.BATCH`` if ``.max_duration`` is in epochs; otherwise, it is converted to the
+        unit of ``.max_duration``.
+
+        .. note::
+
+            ``.max_duration`` cannot be ``None``.
+
+            ``.dataloader_len`` cannot be ``None`` if:
+
+            * ``time.unit`` is ``TimeUnit.EPOCH``, or
+            * ``time.unit`` is ``TimeUnit.DURATION`` and ``.max_duration.unit`` is ``TimeUnit.EPOCH``.
+
+            The Scale Schedule Ratio (``ssr``) is not applied if ``time.unit`` is ``TimeUnit.DURATION``.
+
+        Args:
+            time (Time | str): A time string, or instance of :class:`.Time`.
+            ssr (float): The Scale Schedule Ratio to apply. Default: ``1.0``.
+
+        Returns:
+            Time: An instance of :class:`.Time` in the converted unit.
+        """
+        if isinstance(time, str):
+            time = Time.from_timestring(time)
+
+        if self.max_duration is None:
+            raise RuntimeError('`max_duration` should be set whenever time conversion is invoked.')
+
+        if time.unit == TimeUnit.DURATION:
+            if self.max_duration.unit == TimeUnit.EPOCH:
+                if self.dataloader_len is None:
+                    raise RuntimeError('Cannot convert time, as state.dataloader_len is None.')
+                return Time(
+                    int(time.value * int(self.dataloader_len) * self.max_duration.value),
+                    TimeUnit.BATCH,
+                )
+            return Time(int(time.value * self.max_duration.value), self.max_duration.unit)
+        elif time.unit == TimeUnit.EPOCH:
+            # Epochs do not provide sufficient granularity for SSR scaling,
+            # e.g. if max_duration = 1ep, then any SSR < 1 would result in a new duration of 0.
+            # So, convert the time into batches.
+            if self.dataloader_len is None:
+                raise RuntimeError('Cannot convert time, as state.dataloader_len is None.')
+            time = Time(value=time.value * int(self.dataloader_len), unit=TimeUnit.BATCH)
+
+        return Time(value=int(time.value * ssr), unit=time.unit)
+
     def stop_training(self):
         """Gracefully stop training.
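For reference, the rules above imply the following behavior (a minimal sketch, assuming a
hypothetical state whose max_duration is '10ep' and whose dataloader yields 100 batches per epoch;
neither value appears in the patch itself):

    state.convert_time('42ba')             # -> Time(42, TimeUnit.BATCH): batches pass through unchanged
    state.convert_time('3ep')              # -> Time(300, TimeUnit.BATCH): 3 * 100
    state.convert_time('3ep', ssr=0.5)     # -> Time(150, TimeUnit.BATCH): SSR applied after unit conversion
    state.convert_time('0.5dur')           # -> Time(500, TimeUnit.BATCH): 0.5 * 100 * 10
    state.convert_time('0.5dur', ssr=0.5)  # -> Time(500, TimeUnit.BATCH): SSR is ignored for DURATION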
diff --git a/composer/optim/scheduler.py b/composer/optim/scheduler.py
index 424cc71557..bd4fc1f1e4 100644
--- a/composer/optim/scheduler.py
+++ b/composer/optim/scheduler.py
@@ -20,7 +20,7 @@
 
 from torch.optim.lr_scheduler import LambdaLR
 
-from composer.core import PyTorchScheduler, State, Time, TimeUnit
+from composer.core import PyTorchScheduler, State, Time
 
 if TYPE_CHECKING:
     from typing import Protocol
@@ -124,29 +124,6 @@ def __call__(self, state: State, ssr: float = 1.0) -> float:
         raise NotImplementedError
 
 
-def _convert_time(time: Union[str, Time[int], Time[float]], state: State, ssr: float = 1.0) -> Time[int]:
-    if isinstance(time, str):
-        time = Time.from_timestring(time)
-
-    assert state.max_duration is not None, 'max_duration should be set whenever schedulers are invoked'
-
-    if time.unit == TimeUnit.DURATION:
-        if state.max_duration.unit == TimeUnit.EPOCH:
-            if state.dataloader_len is None:
-                raise RuntimeError('Cannot convert time, as state.dataloader_len is None.')
-            return Time(int(time.value * int(state.dataloader_len) * state.max_duration.value), TimeUnit.BATCH)
-        return Time(int(time.value * state.max_duration.value), state.max_duration.unit)
-    elif time.unit == TimeUnit.EPOCH:
-        # Epochs do not provide sufficient granularity for SSR scaling
-        # e.g. if max_duration = 1ep, then any SSR would result in a new duration of 0.
-        # so, convert the time into batches
-        if state.dataloader_len is None:
-            raise RuntimeError('Cannot convert time, as state.dataloader_len is None.')
-        time = Time(value=time.value * int(state.dataloader_len), unit=TimeUnit.BATCH)
-
-    return Time(value=int(time.value * ssr), unit=time.unit)
-
-
 def compile_composer_scheduler(scheduler: ComposerScheduler, state: State, ssr: float = 1.0) -> PyTorchScheduler:
     """Converts a stateless scheduler into a PyTorch scheduler object.
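The migration for callers is mechanical. Wherever a scheduler previously passed the state into the
free function, for example:

    step_size = _convert_time(self.step_size, state, ssr=ssr)

it now invokes the method on the state object it already receives:

    step_size = state.convert_time(self.step_size, ssr=ssr)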
@@ -215,7 +192,7 @@ def __init__(self, step_size: Union[str, Time], gamma: float = 0.1):
         self.gamma = gamma
 
     def __call__(self, state: State, ssr: float = 1.0):
-        step_size = _convert_time(self.step_size, state, ssr=ssr)
+        step_size = state.convert_time(self.step_size, ssr=ssr)
         current_time = state.timestamp.get(step_size.unit)
         steps = int(current_time / step_size)
 
@@ -248,7 +225,7 @@ def __init__(self, milestones: List[Union[str, Time]], gamma: float = 0.1):
         self.gamma = gamma
 
     def __call__(self, state: State, ssr: float = 1.0):
-        milestones = [_convert_time(milestone, state, ssr=ssr) for milestone in self.milestones]
+        milestones = [state.convert_time(milestone, ssr=ssr) for milestone in self.milestones]
 
         factor = 1.0
         for milestone in milestones:
@@ -284,7 +261,7 @@ def __init__(self, alpha: float = 1.0, t_max: Union[str, Time] = '1dur') -> None
         self.t_max = t_max
 
     def __call__(self, state: State, ssr: float = 1.0) -> float:
-        t_max = _convert_time(self.t_max, state, ssr=ssr)
+        t_max = state.convert_time(self.t_max, ssr=ssr)
 
         if state.timestamp < t_max:
             return self.alpha
@@ -331,7 +308,7 @@ def __init__(self, alpha_i: float = 1.0, alpha_f: float = 0.0, t_max: Union[str,
         self.t_max = Time.from_timestring(t_max) if isinstance(t_max, str) else t_max
 
     def __call__(self, state: State, ssr: float = 1.0):
-        t_max = _convert_time(self.t_max, state, ssr=ssr)
+        t_max = state.convert_time(self.t_max, ssr=ssr)
         current_time = state.timestamp.get(t_max.unit)
         frac_of_total = min(1.0, (current_time / t_max).value)
 
@@ -365,7 +342,7 @@ def __init__(self, gamma: float, decay_period: Union[str, Time] = '1ep'):
         self.decay_period = decay_period
 
     def __call__(self, state: State, ssr: float = 1.0):
-        decay_period = _convert_time(self.decay_period, state, ssr)
+        decay_period = state.convert_time(self.decay_period, ssr=ssr)
         current_time_in_decay_units = state.timestamp.get(decay_period.unit)
 
         return self.gamma**float(current_time_in_decay_units / decay_period)
@@ -410,7 +387,7 @@ def __init__(self, t_max: Union[str, Time] = '1dur', alpha_f: float = 0.0):
         self.alpha_f = alpha_f
 
     def __call__(self, state: State, ssr: float = 1.0):
-        t_max = _convert_time(self.t_max, state, ssr=ssr)
+        t_max = state.convert_time(self.t_max, ssr=ssr)
         current_time = state.timestamp.get(t_max.unit)
         frac_of_total = (current_time / t_max).value
 
@@ -453,7 +430,7 @@ def __init__(self, t_0: Union[str, Time], t_mult: float = 1.0, alpha_f: float =
         self.alpha_f = alpha_f
 
     def __call__(self, state: State, ssr: float = 1.0):
-        t_0 = _convert_time(self.t_0, state, ssr=ssr)
+        t_0 = state.convert_time(self.t_0, ssr=ssr)
         current_interval_len = t_0
         current_interval_end = t_0
         while current_interval_end <= state.timestamp.get(current_interval_end.unit):
@@ -501,7 +478,7 @@ def __init__(self, power: float, t_max: Union[str, Time] = '1dur', alpha_f: floa
         self.alpha_f = alpha_f
 
     def __call__(self, state: State, ssr: float = 1.0):
-        t_max = _convert_time(self.t_max, state, ssr=ssr)
+        t_max = state.convert_time(self.t_max, ssr=ssr)
         current_time = state.timestamp.get(t_max.unit)
         frac_of_total = (current_time / t_max).value
 
@@ -558,7 +535,7 @@ def __init__(self,
         self.step_scheduler = MultiStepScheduler(milestones=milestones, gamma=gamma)
 
     def __call__(self, state: State, ssr: float = 1.0):
-        t_warmup = _convert_time(self.t_warmup, state)
+        t_warmup = state.convert_time(self.t_warmup)
         if t_warmup.value == 0:
             warnings.warn(
                 textwrap.dedent("""\
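One detail worth noting in the warmup schedulers above and below: ``t_warmup`` is converted
without passing ``ssr``, while ``t_max`` is converted with it, so the warmup period keeps its
absolute length when the overall schedule is rescaled:

    t_warmup = state.convert_time(self.t_warmup)      # no ssr: warmup length is not rescaled
    t_max = state.convert_time(self.t_max, ssr=ssr)   # total schedule length is rescaled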
@@ -676,7 +653,7 @@ def __init__(self,
         self.warmup_scheduler = LinearScheduler(alpha_i=0.0, alpha_f=alpha_i, t_max=t_warmup)
 
     def __call__(self, state: State, ssr: float = 1.0):
-        t_warmup = _convert_time(self.t_warmup, state)
+        t_warmup = state.convert_time(self.t_warmup)
         if t_warmup.value == 0:
             warnings.warn(
                 textwrap.dedent("""\
@@ -689,7 +666,7 @@ def __call__(self, state: State, ssr: float = 1.0):
                 return self.warmup_scheduler(state, ssr)
             return self.warmup_scheduler(state)
 
-        t_max = _convert_time(self.t_max, state, ssr=ssr)
+        t_max = state.convert_time(self.t_max, ssr=ssr)
         current_time = state.timestamp.get(t_warmup.unit)
         frac_of_total = ((current_time - t_warmup) / (t_max - t_warmup)).value if (t_max > t_warmup) else 0.0
         frac_of_total = min(1.0, frac_of_total)
@@ -744,7 +721,7 @@ def __init__(self,
         self.warmup_scheduler = LinearScheduler(alpha_i=0.0, alpha_f=1.0, t_max=t_warmup)
 
     def __call__(self, state: State, ssr: float = 1.0):
-        t_warmup = _convert_time(self.t_warmup, state)
+        t_warmup = state.convert_time(self.t_warmup)
         if t_warmup.value == 0:
             warnings.warn(
                 textwrap.dedent("""\
@@ -757,7 +734,7 @@ def __call__(self, state: State, ssr: float = 1.0):
                 return self.warmup_scheduler(state, ssr)
             return self.warmup_scheduler(state)
 
-        t_max = _convert_time(self.t_max, state, ssr=ssr)
+        t_max = state.convert_time(self.t_max, ssr=ssr)
         current_time = state.timestamp.get(t_warmup.unit)
         frac_of_total = ((current_time - t_warmup) / (t_max - t_warmup)).value if (t_max > t_warmup) else 0.0
         frac_of_total = min(1.0, frac_of_total)
@@ -814,7 +791,7 @@ def __init__(self,
         self.warmup_scheduler = LinearScheduler(alpha_i=0.0, alpha_f=1.0, t_max=t_warmup)
 
     def __call__(self, state: State, ssr: float = 1.0):
-        t_warmup = _convert_time(self.t_warmup, state)
+        t_warmup = state.convert_time(self.t_warmup)
         if t_warmup.value == 0:
             warnings.warn(
                 textwrap.dedent("""\
@@ -827,7 +804,7 @@ def __call__(self, state: State, ssr: float = 1.0):
                 return self.warmup_scheduler(state, ssr)
             return self.warmup_scheduler(state)
 
-        t_max = _convert_time(self.t_max, state, ssr=ssr)
+        t_max = state.convert_time(self.t_max, ssr=ssr)
         current_time = state.timestamp.get(t_warmup.unit)
         frac_of_total = ((current_time - t_warmup) / (t_max - t_warmup)).value if (t_max > t_warmup) else 0.0
         frac_of_total = min(1.0, frac_of_total)
diff --git a/tests/test_state.py b/tests/test_state.py
index 2660cc25ab..31cae3fec8 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -1,8 +1,8 @@
 # Copyright 2022 MosaicML Composer authors
 # SPDX-License-Identifier: Apache-2.0
-
 import pathlib
 import random
+from typing import Union
 
 import pytest
 import torch
@@ -11,6 +11,7 @@
 
 import composer
 from composer.core import Batch, Precision, State
+from composer.core.time import Time
 from composer.devices import DeviceCPU, DeviceGPU
 from composer.loggers import Logger
 from tests.common import SimpleModel, assert_state_equivalent
@@ -155,3 +156,148 @@ def test_composer_metadata_in_state_dict(tmp_path, request: pytest.FixtureReques
     assert loaded_state_dict['metadata']['world_size'] == 1
     assert loaded_state_dict['metadata']['device_train_microbatch_size'] == 2
     assert loaded_state_dict['metadata']['train_dataloader_batch_size'] == 4
+
+
+@pytest.mark.parametrize(
+    'time',
+    [
+        '3ep',
+        '5dur',
+        Time.from_timestring('3ep'),
+        Time.from_timestring('5dur'),
+    ],
+)
+def test_convert_time_none_dataloader_len(time: Union[str, Time], request: pytest.FixtureRequest):
+    # Get a dummy state without a dataloader_len
+    state = get_dummy_state(request)
+    assert state.max_duration is not None
+    assert state.dataloader_len is None
+
+    with pytest.raises(RuntimeError, match='Cannot convert time, as state.dataloader_len is None.'):
+        state.convert_time(time)
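A brief aside on the '5dur' cases above: a DURATION time requires ``dataloader_len`` only when
``max_duration`` is measured in epochs, so the expected ``RuntimeError`` implies that the dummy
state's ``max_duration`` is epoch-based. With a batch- or sample-based maximum (value borrowed
from the fixtures below, for illustration), the same call would convert without touching the
dataloader:

    state.max_duration = Time.from_timestring('25600ba')
    state.convert_time('5dur')  # -> Time(128000, TimeUnit.BATCH); no dataloader_len needed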
+
+
+@pytest.mark.parametrize(
+    'time',
+    [
+        '3ep',
+        '2ba',
+        '1024sp',
+        '98tok',
+        '5dur',
+        Time.from_timestring('3ep'),
+        Time.from_timestring('2ba'),
+        Time.from_timestring('1024sp'),
+        Time.from_timestring('98tok'),
+        Time.from_timestring('5dur'),
+    ],
+)
+def test_convert_time_max_duration_none(time: Union[str, Time], request: pytest.FixtureRequest):
+    # Get a dummy state and clear its max_duration
+    state = get_dummy_state(request)
+    state.max_duration = None
+    assert state.max_duration is None
+
+    with pytest.raises(
+        RuntimeError,
+        match='`max_duration` should be set whenever time conversion is invoked.',
+    ):
+        state.convert_time(time)
+
+
+@pytest.mark.parametrize(
+    'max_duration',
+    [
+        Time.from_timestring('1024ep'),
+        Time.from_timestring('25600ba'),
+        Time.from_timestring('102400sp'),
+    ],
+)
+@pytest.mark.parametrize(
+    'time,expected_time',
+    [
+        ('3ep', Time.from_timestring('75ba')),
+        ('2ba', Time.from_timestring('2ba')),
+        ('1024sp', Time.from_timestring('1024sp')),
+        ('98tok', Time.from_timestring('98tok')),
+        (Time.from_timestring('3ep'), Time.from_timestring('75ba')),
+        (Time.from_timestring('2ba'), Time.from_timestring('2ba')),
+        (Time.from_timestring('1024sp'), Time.from_timestring('1024sp')),
+        (Time.from_timestring('98tok'), Time.from_timestring('98tok')),
+    ],
+)
+def test_convert_time(
+    max_duration: Time,
+    time: Union[str, Time],
+    expected_time: Time,
+    request: pytest.FixtureRequest,
+):
+    # Get a dummy state and attach a dataloader so dataloader_len is set
+    state = get_dummy_state(request)
+
+    state.set_dataloader(dataloader=state.train_dataloader, dataloader_label='train_dataloader')
+    assert state.dataloader_len is not None
+
+    state.max_duration = max_duration
+    assert state.max_duration is not None
+
+    out_time = state.convert_time(time)
+    assert out_time == expected_time
+
+
+@pytest.mark.parametrize(
+    'max_duration',
+    [
+        Time.from_timestring('1024ep'),
+        Time.from_timestring('25600ba'),
+    ],
+)
+@pytest.mark.parametrize(
+    'time,expected_time',
+    [
+        ('0.1dur', Time.from_timestring('2560ba')),
+        (Time.from_timestring('0.1dur'), Time.from_timestring('2560ba')),
+    ],
+)
+def test_convert_time_duration(
+    max_duration: Time,
+    time: Union[str, Time],
+    expected_time: Time,
+    request: pytest.FixtureRequest,
+):
+    # Get a dummy state and attach a dataloader so dataloader_len is set
+    state = get_dummy_state(request)
+
+    state.set_dataloader(dataloader=state.train_dataloader, dataloader_label='train_dataloader')
+    assert state.dataloader_len is not None
+
+    state.max_duration = max_duration
+    assert state.max_duration is not None
+
+    out_time = state.convert_time(time)
+    assert out_time == expected_time
+
+
+@pytest.mark.parametrize(
+    'time,expected_time',
+    [
+        ('0.1dur', Time.from_timestring('10240sp')),
+        (Time.from_timestring('0.1dur'), Time.from_timestring('10240sp')),
+    ],
+)
+def test_convert_time_duration_samples(
+    time: Union[str, Time],
+    expected_time: Time,
+    request: pytest.FixtureRequest,
+):
+    # Get a dummy state and attach a dataloader so dataloader_len is set
+    state = get_dummy_state(request)
+
+    state.set_dataloader(dataloader=state.train_dataloader, dataloader_label='train_dataloader')
+    assert state.dataloader_len is not None
+
+    state.max_duration = Time.from_timestring('102400sp')
+    assert state.max_duration is not None
+
+    out_time = state.convert_time(time)
+    assert out_time == expected_time
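As a check on the expected values: they are all consistent with the dummy state's train dataloader
yielding 25 batches per epoch (a deduction from these fixtures; the test file never states the
length explicitly):

    # '3ep'                                 -> 3 * 25          = 75ba   (epochs to batches via dataloader_len)
    # '0.1dur' with max_duration='1024ep'   -> 0.1 * 1024 * 25 = 2560ba
    # '0.1dur' with max_duration='25600ba'  -> 0.1 * 25600     = 2560ba
    # '0.1dur' with max_duration='102400sp' -> 0.1 * 102400    = 10240sp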