Make convert_time a State method #2431

Open
wants to merge 3 commits into base: main
49 changes: 49 additions & 0 deletions composer/core/state.py
@@ -651,6 +651,55 @@ def get_elapsed_duration(self) -> Optional[Time[float]]:
return None
return self.timestamp.get(self.max_duration.unit) / self.max_duration

def convert_time(self, time: Union[str, Time[int], Time[float]], ssr: float = 1.0) -> Time[int]:
"""Converts ``time`` into a fine-grained :class:`.TimeUnit` according to the current state.

``TimeUnit.BATCH``, ``TimeUnit.SAMPLE`` and ``TimeUnit.TOKEN`` are left in their original units.
``TimeUnit.EPOCH`` is converted to ``TimeUnit.BATCH``.
``TimeUnit.DURATION`` is converted to ``TimeUnit.BATCH`` if ``.max_duration`` is in epochs, and to ``.max_duration.unit`` otherwise.

.. note::

``.max_duration`` cannot be ``None``.

``.dataloader_len`` cannot be ``None`` if:

* ``time.unit`` is ``TimeUnit.EPOCH``, or
* ``time.unit`` is ``TimeUnit.DURATION`` and ``.max_duration.unit`` is ``TimeUnit.EPOCH``.

The scale schedule ratio (``ssr``) is not applied if ``time.unit`` is ``TimeUnit.DURATION``.

Args:
time (Time | str): A time string, or instance of :class:`.Time`.
ssr (float): Scale schedule ratio, applied to all units except ``TimeUnit.DURATION``. Default: ``1.0``.

Returns:
Time: An instance of :class:`.Time`.
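
Example (illustrative values only; assume ``max_duration='10ep'`` and a train dataloader with 100 batches per epoch):

state.convert_time('1ep')            # Time(100, TimeUnit.BATCH): epochs become batches
state.convert_time('0.5dur')         # Time(500, TimeUnit.BATCH): resolved against max_duration; SSR never applied
state.convert_time('8ba', ssr=0.5)   # Time(4, TimeUnit.BATCH): SSR scales batch/sample/token values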
"""
if isinstance(time, str):
time = Time.from_timestring(time)

if self.max_duration is None:
raise RuntimeError('`max_duration` should be set whenever time conversion is invoked.')

if time.unit == TimeUnit.DURATION:
if self.max_duration.unit == TimeUnit.EPOCH:
if self.dataloader_len is None:
raise RuntimeError('Cannot convert time, as state.dataloader_len is None.')
return Time(
int(time.value * int(self.dataloader_len) * self.max_duration.value),
TimeUnit.BATCH,
)

Contributor Author:
@mvpatel2000, I am not sure if we should apply ssr here. I decided to keep the original implementation, but I do not see any reason to apply this scaling once the TimeUnit has changed to BATCH. The same applies to the elif case.

Contributor:
Hm... I think we probably shouldn't apply SSR here.

Do you mind if I directly touch your PR? There are a few other places where we use a similar function (e.g. the runtime estimator), and I think it would be good to consolidate them into this API.

Also, apologies for the delay in review. We're in the middle of a release, so we're freezing everything but bug fixes. I'll make sure this gets into the next release (0.16.1), which will come very soon after 0.16 (likely in a week or so).

Contributor Author:
No worries, feel free to edit the PR 😄

Thanks for your efforts. I am happy to keep contributing with PRs or Issues if I find something. Let me know if I can help.

Contributor:
@priba apologies for the delay, this might be on the back burner for a few weeks. I'm really slammed with other things. If this is a blocker, please let me know and we can try landing a smaller solution. However, I'd really like to get to the full solution and make this much easier for everyone.

Contributor Author:
No problem, we already have a workaround. Although it would be a nice feature, I understand it is not a high priority :)

Contributor Author:
Hi @mvpatel2000, is there any news regarding this PR?

Contributor:
Sorry, I wasn't able to get to it yet. I will make sure to prioritize it over the next 2-3 weeks!

Contributor:
@priba sorry for the delay, I still haven't gotten to it. I will make sure to get to it in January, though. I keep getting pre-empted by higher-priority items.
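
# Illustration of the reviewers' SSR question above (hypothetical values, not part of the PR):
# with max_duration='10ep', dataloader_len=100 and ssr=0.5:
#   convert_time('5ep', ssr=0.5)     -> 5 * 100 = 500 batches, then scaled by ssr -> Time(250, TimeUnit.BATCH)
#   convert_time('0.5dur', ssr=0.5)  -> int(0.5 * 100 * 10) -> Time(500, TimeUnit.BATCH); ssr is ignored for DURATION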
return Time(int(time.value * self.max_duration.value), self.max_duration.unit)
elif time.unit == TimeUnit.EPOCH:
# Epochs do not provide sufficient granularity for SSR scaling
# e.g. if max_duration = 1ep, then any SSR would result in a new duration of 0.
# so, convert the time into batches
if self.dataloader_len is None:
raise RuntimeError('Cannot convert time, as state.dataloader_len is None.')
time = Time(value=time.value * int(self.dataloader_len), unit=TimeUnit.BATCH)

return Time(value=int(time.value * ssr), unit=time.unit)

def stop_training(self):
"""Gracefully stop training.

55 changes: 16 additions & 39 deletions composer/optim/scheduler.py
@@ -20,7 +20,7 @@

from torch.optim.lr_scheduler import LambdaLR

from composer.core import PyTorchScheduler, State, Time, TimeUnit
from composer.core import PyTorchScheduler, State, Time

if TYPE_CHECKING:
from typing import Protocol
@@ -124,29 +124,6 @@ def __call__(self, state: State, ssr: float = 1.0) -> float:
raise NotImplementedError


def _convert_time(time: Union[str, Time[int], Time[float]], state: State, ssr: float = 1.0) -> Time[int]:
if isinstance(time, str):
time = Time.from_timestring(time)

assert state.max_duration is not None, 'max_duration should be set whenever schedulers are invoked'

if time.unit == TimeUnit.DURATION:
if state.max_duration.unit == TimeUnit.EPOCH:
if state.dataloader_len is None:
raise RuntimeError('Cannot convert time, as state.dataloader_len is None.')
return Time(int(time.value * int(state.dataloader_len) * state.max_duration.value), TimeUnit.BATCH)
return Time(int(time.value * state.max_duration.value), state.max_duration.unit)
elif time.unit == TimeUnit.EPOCH:
# Epochs do not provide sufficient granularity for SSR scaling
# e.g. if max_duration = 1ep, then any SSR would result in a new duration of 0.
# so, convert the time into batches
if state.dataloader_len is None:
raise RuntimeError('Cannot convert time, as state.dataloader_len is None.')
time = Time(value=time.value * int(state.dataloader_len), unit=TimeUnit.BATCH)

return Time(value=int(time.value * ssr), unit=time.unit)


def compile_composer_scheduler(scheduler: ComposerScheduler, state: State, ssr: float = 1.0) -> PyTorchScheduler:
"""Converts a stateless scheduler into a PyTorch scheduler object.

@@ -215,7 +192,7 @@ def __init__(self, step_size: Union[str, Time], gamma: float = 0.1):
self.gamma = gamma

def __call__(self, state: State, ssr: float = 1.0):
step_size = _convert_time(self.step_size, state, ssr=ssr)
step_size = state.convert_time(self.step_size, ssr=ssr)
current_time = state.timestamp.get(step_size.unit)
steps = int(current_time / step_size)

@@ -248,7 +225,7 @@ def __init__(self, milestones: List[Union[str, Time]], gamma: float = 0.1):
self.gamma = gamma

def __call__(self, state: State, ssr: float = 1.0):
milestones = [_convert_time(milestone, state, ssr=ssr) for milestone in self.milestones]
milestones = [state.convert_time(milestone, ssr=ssr) for milestone in self.milestones]

factor = 1.0
for milestone in milestones:
@@ -284,7 +261,7 @@ def __init__(self, alpha: float = 1.0, t_max: Union[str, Time] = '1dur') -> None
self.t_max = t_max

def __call__(self, state: State, ssr: float = 1.0) -> float:
t_max = _convert_time(self.t_max, state, ssr=ssr)
t_max = state.convert_time(self.t_max, ssr=ssr)

if state.timestamp < t_max:
return self.alpha
@@ -331,7 +308,7 @@ def __init__(self, alpha_i: float = 1.0, alpha_f: float = 0.0, t_max: Union[str,
self.t_max = Time.from_timestring(t_max) if isinstance(t_max, str) else t_max

def __call__(self, state: State, ssr: float = 1.0):
t_max = _convert_time(self.t_max, state, ssr=ssr)
t_max = state.convert_time(self.t_max, ssr=ssr)
current_time = state.timestamp.get(t_max.unit)
frac_of_total = min(1.0, (current_time / t_max).value)

@@ -365,7 +342,7 @@ def __init__(self, gamma: float, decay_period: Union[str, Time] = '1ep'):
self.decay_period = decay_period

def __call__(self, state: State, ssr: float = 1.0):
decay_period = _convert_time(self.decay_period, state, ssr)
decay_period = state.convert_time(self.decay_period, ssr=ssr)
current_time_in_decay_units = state.timestamp.get(decay_period.unit)

return self.gamma**float(current_time_in_decay_units / decay_period)
@@ -410,7 +387,7 @@ def __init__(self, t_max: Union[str, Time] = '1dur', alpha_f: float = 0.0):
self.alpha_f = alpha_f

def __call__(self, state: State, ssr: float = 1.0):
t_max = _convert_time(self.t_max, state, ssr=ssr)
t_max = state.convert_time(self.t_max, ssr=ssr)
current_time = state.timestamp.get(t_max.unit)
frac_of_total = (current_time / t_max).value

@@ -453,7 +430,7 @@ def __init__(self, t_0: Union[str, Time], t_mult: float = 1.0, alpha_f: float =
self.alpha_f = alpha_f

def __call__(self, state: State, ssr: float = 1.0):
t_0 = _convert_time(self.t_0, state, ssr=ssr)
t_0 = state.convert_time(self.t_0, ssr=ssr)
current_interval_len = t_0
current_interval_end = t_0
while current_interval_end <= state.timestamp.get(current_interval_end.unit):
@@ -501,7 +478,7 @@ def __init__(self, power: float, t_max: Union[str, Time] = '1dur', alpha_f: floa
self.alpha_f = alpha_f

def __call__(self, state: State, ssr: float = 1.0):
t_max = _convert_time(self.t_max, state, ssr=ssr)
t_max = state.convert_time(self.t_max, ssr=ssr)
current_time = state.timestamp.get(t_max.unit)
frac_of_total = (current_time / t_max).value

@@ -558,7 +535,7 @@ def __init__(self,
self.step_scheduler = MultiStepScheduler(milestones=milestones, gamma=gamma)

def __call__(self, state: State, ssr: float = 1.0):
t_warmup = _convert_time(self.t_warmup, state)
t_warmup = state.convert_time(self.t_warmup)
if t_warmup.value == 0:
warnings.warn(
textwrap.dedent("""\
@@ -676,7 +653,7 @@ def __init__(self,
self.warmup_scheduler = LinearScheduler(alpha_i=0.0, alpha_f=alpha_i, t_max=t_warmup)

def __call__(self, state: State, ssr: float = 1.0):
t_warmup = _convert_time(self.t_warmup, state)
t_warmup = state.convert_time(self.t_warmup)
if t_warmup.value == 0:
warnings.warn(
textwrap.dedent("""\
@@ -689,7 +666,7 @@ def __call__(self, state: State, ssr: float = 1.0):
return self.warmup_scheduler(state, ssr)
return self.warmup_scheduler(state)

t_max = _convert_time(self.t_max, state, ssr=ssr)
t_max = state.convert_time(self.t_max, ssr=ssr)
current_time = state.timestamp.get(t_warmup.unit)
frac_of_total = ((current_time - t_warmup) / (t_max - t_warmup)).value if (t_max > t_warmup) else 0.0
frac_of_total = min(1.0, frac_of_total)
@@ -744,7 +721,7 @@ def __init__(self,
self.warmup_scheduler = LinearScheduler(alpha_i=0.0, alpha_f=1.0, t_max=t_warmup)

def __call__(self, state: State, ssr: float = 1.0):
t_warmup = _convert_time(self.t_warmup, state)
t_warmup = state.convert_time(self.t_warmup)
if t_warmup.value == 0:
warnings.warn(
textwrap.dedent("""\
@@ -757,7 +734,7 @@ def __call__(self, state: State, ssr: float = 1.0):
return self.warmup_scheduler(state, ssr)
return self.warmup_scheduler(state)

t_max = _convert_time(self.t_max, state, ssr=ssr)
t_max = state.convert_time(self.t_max, ssr=ssr)
current_time = state.timestamp.get(t_warmup.unit)
frac_of_total = ((current_time - t_warmup) / (t_max - t_warmup)).value if (t_max > t_warmup) else 0.0
frac_of_total = min(1.0, frac_of_total)
@@ -814,7 +791,7 @@ def __init__(self,
self.warmup_scheduler = LinearScheduler(alpha_i=0.0, alpha_f=1.0, t_max=t_warmup)

def __call__(self, state: State, ssr: float = 1.0):
t_warmup = _convert_time(self.t_warmup, state)
t_warmup = state.convert_time(self.t_warmup)
if t_warmup.value == 0:
warnings.warn(
textwrap.dedent("""\
@@ -827,7 +804,7 @@ def __call__(self, state: State, ssr: float = 1.0):
return self.warmup_scheduler(state, ssr)
return self.warmup_scheduler(state)

t_max = _convert_time(self.t_max, state, ssr=ssr)
t_max = state.convert_time(self.t_max, ssr=ssr)
current_time = state.timestamp.get(t_warmup.unit)
frac_of_total = ((current_time - t_warmup) / (t_max - t_warmup)).value if (t_max > t_warmup) else 0.0
frac_of_total = min(1.0, frac_of_total)
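
To illustrate the consolidation discussed in the review thread (e.g. reusing the same conversion in the runtime estimator), here is a minimal sketch of a user-defined scheduler written against this PR. The class name, decay rule, and import path are illustrative assumptions, not part of the diff:

from composer.core import State
from composer.optim.scheduler import ComposerScheduler

class HalveAfterScheduler(ComposerScheduler):
    """Illustrative only: halves the LR multiplier once ``t_switch`` is reached."""

    def __init__(self, t_switch: str = '10ep'):
        self.t_switch = t_switch

    def __call__(self, state: State, ssr: float = 1.0) -> float:
        # Previously this would have required a copy of the module-level _convert_time helper.
        t_switch = state.convert_time(self.t_switch, ssr=ssr)
        current_time = state.timestamp.get(t_switch.unit)
        return 1.0 if current_time < t_switch else 0.5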
148 changes: 147 additions & 1 deletion tests/test_state.py
@@ -1,8 +1,8 @@
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

import pathlib
import random
from typing import Union

import pytest
import torch
@@ -11,6 +11,7 @@

import composer
from composer.core import Batch, Precision, State
from composer.core.time import Time
from composer.devices import DeviceCPU, DeviceGPU
from composer.loggers import Logger
from tests.common import SimpleModel, assert_state_equivalent
@@ -155,3 +156,148 @@ def test_composer_metadata_in_state_dict(tmp_path, request: pytest.FixtureReques
assert loaded_state_dict['metadata']['world_size'] == 1
assert loaded_state_dict['metadata']['device_train_microbatch_size'] == 2
assert loaded_state_dict['metadata']['train_dataloader_batch_size'] == 4


@pytest.mark.parametrize(
'time',
[
'3ep',
'5dur',
Time.from_timestring('3ep'),
Time.from_timestring('5dur'),
],
)
def test_convert_time_none_dataloader_len(time: Union[str, Time], request: pytest.FixtureRequest):
# Get a dummy state without a dataloader_len
state = get_dummy_state(request)
assert state.max_duration is not None
assert state.dataloader_len is None

with pytest.raises(RuntimeError, match='Cannot convert time, as state.dataloader_len is None.'):
state.convert_time(time)


@pytest.mark.parametrize(
'time',
[
'3ep',
'2ba',
'1024sp',
'98tok',
'5dur',
Time.from_timestring('3ep'),
Time.from_timestring('2ba'),
Time.from_timestring('1024sp'),
Time.from_timestring('98tok'),
Time.from_timestring('5dur'),
],
)
def test_convert_time_max_duration_none(time: Union[str, Time], request: pytest.FixtureRequest):
# Get a dummy state and unset its max_duration
state = get_dummy_state(request)
state.max_duration = None
assert state.max_duration is None

with pytest.raises(
RuntimeError,
match='`max_duration` should be set whenever time conversion is invoked.',
):
state.convert_time(time)


@pytest.mark.parametrize(
'max_duration',
[
Time.from_timestring('1024ep'),
Time.from_timestring('25600ba'),
Time.from_timestring('102400sp'),
],
)
@pytest.mark.parametrize(
'time,expected_time',
[
('3ep', Time.from_timestring('75ba')),
('2ba', Time.from_timestring('2ba')),
('1024sp', Time.from_timestring('1024sp')),
('98tok', Time.from_timestring('98tok')),
(Time.from_timestring('3ep'), Time.from_timestring('75ba')),
(Time.from_timestring('2ba'), Time.from_timestring('2ba')),
(Time.from_timestring('1024sp'), Time.from_timestring('1024sp')),
(Time.from_timestring('98tok'), Time.from_timestring('98tok')),
],
)
def test_convert_time(
max_duration: Time,
time: Union[str, Time],
expected_time: Time,
request: pytest.FixtureRequest,
):
# Get a dummy state and attach the train dataloader so dataloader_len is set
state = get_dummy_state(request)

state.set_dataloader(dataloader=state.train_dataloader, dataloader_label='train_dataloader')
assert state.dataloader_len is not None

state.max_duration = max_duration
assert state.max_duration is not None

out_time = state.convert_time(time)
assert out_time == expected_time


@pytest.mark.parametrize(
'max_duration',
[
Time.from_timestring('1024ep'),
Time.from_timestring('25600ba'),
],
)
@pytest.mark.parametrize(
'time,expected_time',
[
('0.1dur', Time.from_timestring('2560ba')),
(Time.from_timestring('0.1dur'), Time.from_timestring('2560ba')),
],
)
def test_convert_time_duration(
max_duration: Time,
time: Union[str, Time],
expected_time: Time,
request: pytest.FixtureRequest,
):
# Get a dummy state and attach the train dataloader so dataloader_len is set
state = get_dummy_state(request)

state.set_dataloader(dataloader=state.train_dataloader, dataloader_label='train_dataloader')
assert state.dataloader_len is not None

state.max_duration = max_duration
assert state.max_duration is not None

out_time = state.convert_time(time)
assert out_time == expected_time


@pytest.mark.parametrize(
'time,expected_time',
[
('0.1dur', Time.from_timestring('10240sp')),
(Time.from_timestring('0.1dur'), Time.from_timestring('10240sp')),
],
)
def test_convert_time_duration_samples(
time: Union[str, Time],
expected_time: Time,
request: pytest.FixtureRequest,
):
# Get a dummy state and attach the train dataloader so dataloader_len is set
state = get_dummy_state(request)

state.set_dataloader(dataloader=state.train_dataloader, dataloader_label='train_dataloader')
assert state.dataloader_len is not None

state.max_duration = Time.from_timestring('102400sp')
assert state.max_duration is not None

out_time = state.convert_time(time)
assert out_time == expected_time
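
A note on the expected values in the parametrizations above: they are consistent with the dummy state's train dataloader providing 25 batches per epoch. That figure is inferred from the expected outputs, not stated in this diff:

# '3ep'    with 25 batches per epoch                    -> 3 * 25          = 75ba
# '0.1dur' with max_duration='1024ep', 25 batches/epoch -> 0.1 * 1024 * 25 = 2560ba
# '0.1dur' with max_duration='25600ba'                  -> 0.1 * 25600     = 2560ba
# '0.1dur' with max_duration='102400sp'                 -> 0.1 * 102400    = 10240sp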