Merge pull request #76 from huggingface/nouamane/bye-recompute
Deprecate `recompute_granularity` in config
NouamaneTazi authored Feb 21, 2024
2 parents 7c01d0f + 0f0d354 commit 53c3064
Showing 10 changed files with 7 additions and 62 deletions.
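In short: `recompute_granularity` is dropped from `ParallelismArgs`, from every example config, and from the `get_flops()` helpers, which now report `hardware_flops = model_flops` as a placeholder. A minimal sketch of the config migration, modeled on `examples/config_tiny_llama.py` (values are illustrative, not prescriptive):

from nanotron.config import ParallelismArgs

# Before this commit (now invalid; the field no longer exists):
# parallelism = ParallelismArgs(
#     dp=2, pp=2, tp=2,
#     pp_engine="1f1b",
#     tp_mode="REDUCE_SCATTER",
#     tp_linear_async_communication=True,
#     recompute_granularity="selective",
# )

# After this commit: drop the field; everything else is unchanged.
parallelism = ParallelismArgs(
    dp=2,
    pp=2,
    tp=2,
    pp_engine="1f1b",
    tp_mode="REDUCE_SCATTER",
    tp_linear_async_communication=True,
)

YAML configs need the matching edit: delete the `recompute_granularity:` key under `parallelism:`.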
1 change: 0 additions & 1 deletion examples/bench_llama_7b.py
@@ -61,7 +61,6 @@
     pp_engine="1f1b",
     tp_mode="REDUCE_SCATTER",
     tp_linear_async_communication=True,
-    recompute_granularity="selective",
 )

 tokens = TokensArgs(sequence_length=8192, train_steps=5, micro_batch_size=1, batch_accumulation_per_replica=8)
1 change: 0 additions & 1 deletion examples/config_tiny_llama.py
@@ -75,7 +75,6 @@
     pp_engine="1f1b",
     tp_mode="REDUCE_SCATTER",
     tp_linear_async_communication=True,
-    recompute_granularity="selective",
 )

 tokens = TokensArgs(sequence_length=32, train_steps=10, micro_batch_size=2, batch_accumulation_per_replica=1)
1 change: 0 additions & 1 deletion examples/config_tiny_llama.yaml
@@ -73,7 +73,6 @@ parallelism:
   dp: 2
   pp: 2
   pp_engine: 1f1b
-  recompute_granularity: SELECTIVE
   tp: 2
   tp_linear_async_communication: true
   tp_mode: REDUCE_SCATTER
1 change: 0 additions & 1 deletion examples/moe/config_llamoe.py
@@ -113,7 +113,6 @@ def __post_init__(self):
     pp_engine="1f1b",
     tp_mode="ALL_REDUCE",
     tp_linear_async_communication=False,
-    recompute_granularity=None,
 )

 assert (
3 changes: 1 addition & 2 deletions examples/moe/config_llamoe.yaml
@@ -77,8 +77,7 @@ parallelism:
   expert_parallel_size: 2
   pp: 1
   pp_engine: 1f1b
-  recompute_granularity: null
-  tp: 2
+  tp: 1
   tp_linear_async_communication: false
   tp_mode: ALL_REDUCE
 profiler: null
19 changes: 2 additions & 17 deletions examples/moe/llamoe.py
@@ -26,7 +26,7 @@
 from moe import dMoE
 from nanotron import distributed as dist
 from nanotron import logging
-from nanotron.config import ParallelismArgs, RecomputeGranularity
+from nanotron.config import ParallelismArgs
 from nanotron.generation.generate_store import AttachableStore
 from nanotron.logging import log_rank
 from nanotron.models import NanotronModel
@@ -754,7 +754,6 @@ def get_flops_per_sec(self, iteration_time_in_sec, sequence_length, global_batch
         ffn_hidden_size=self.config.intermediate_size,
         seq_len=sequence_length,
         batch_size=global_batch_size,
-        recompute_granularity=self.parallel_config.recompute_granularity,
     )

     model_flops_per_s = model_flops / (iteration_time_in_sec * world_size * 1e12)
@@ -986,7 +985,6 @@ def get_flops(
     seq_len,
     ffn_hidden_size,
     batch_size=1,
-    recompute_granularity=None,
 ):
     """Counts flops in a decoder-only model
     Args:
@@ -998,7 +996,6 @@ def get_flops(
         vocab_size: size of the vocabulary
         seq_len: sequence length of the decoder
         batch_size: batch size
-        recompute_granularity: Activation recomputation method. Either None, FULL or SELECTIVE. Check Megatron-LM docs for more info.
     Returns:
         model_flops: flops in the model (should be independent of the hardware and model implementation)
         hardware_flops: flops in the hardware (actual flops performed on the hardware). Check 6.3 in https://arxiv.org/pdf/2205.05198.pdf
@@ -1044,17 +1041,5 @@ def get_flops(
     # both input and weight tensors
     model_flops = 3 * (decoder_flops_fwd + lm_head_flops_fwd)  # 1 for fwd + 2 for bwd

-    if recompute_granularity is None:
-        hardware_flops = model_flops
-    elif recompute_granularity is RecomputeGranularity.FULL:
-        # Note: we don't recompute lm head activs
-        hardware_flops = model_flops + decoder_flops_fwd  # + activ recomputation
-    elif recompute_granularity is RecomputeGranularity.SELECTIVE:
-        # all terms with s^2 are flops that are recomputed
-        # ref. appendix A: https://arxiv.org/pdf/2205.05198.pdf
-        recomputed_decoder_flops = decoder_qk_logits_flops_fwd + decoder_v_logits_flops_fwd
-        hardware_flops = model_flops + recomputed_decoder_flops
-    else:
-        raise ValueError("recompute_granularity must be one of 'full' or 'selective'")
-
+    hardware_flops = model_flops  # TODO @nouamanetazi: add hardware flops
     return model_flops, hardware_flops
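For reference, the deleted branch implemented the hardware-FLOPs accounting of Korthikanti et al. (section 6.3 and appendix A of https://arxiv.org/pdf/2205.05198.pdf): model FLOPs are 3x the forward cost (1x forward plus 2x backward), full recomputation replays the whole decoder forward pass on top of that, and selective recomputation replays only the s^2 attention terms. A self-contained sketch of that accounting; the per-term forward FLOP counts are taken as inputs here, whereas the original computed them earlier inside `get_flops`:

from enum import Enum, auto
from typing import Optional


class RecomputeGranularity(Enum):
    """Mirrors the enum that previously lived in nanotron.config.utils_config."""

    FULL = auto()
    SELECTIVE = auto()


def hardware_flops_with_recompute(
    model_flops: float,
    decoder_flops_fwd: float,
    recomputed_attn_flops_fwd: float,  # QK^T logits + attention@V, the s^2 terms
    granularity: Optional[RecomputeGranularity] = None,
) -> float:
    """Hardware FLOPs under the pre-#76 accounting: model FLOPs plus recomputed work."""
    if granularity is None:
        return model_flops  # no recomputation: hardware FLOPs == model FLOPs
    if granularity is RecomputeGranularity.FULL:
        # full recompute replays the whole decoder forward pass (lm head excluded)
        return model_flops + decoder_flops_fwd
    if granularity is RecomputeGranularity.SELECTIVE:
        # selective recompute replays only the s^2 attention terms
        # (appendix A of https://arxiv.org/pdf/2205.05198.pdf)
        return model_flops + recomputed_attn_flops_fwd
    raise ValueError("granularity must be None, FULL or SELECTIVE")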
1 change: 0 additions & 1 deletion run_generate.py
@@ -70,7 +70,6 @@ def main():
         tp=args.tp or config.parallelism.tp,
         pp_engine=OneForwardOneBackwardPipelineEngine(),
         tp_mode=TensorParallelLinearMode.ALL_REDUCE,
-        recompute_granularity=None,
         tp_linear_async_communication=True,
     )

5 changes: 0 additions & 5 deletions src/nanotron/config/parallelism_config.py
@@ -2,7 +2,6 @@
 from typing import Optional

 from nanotron.config.utils_config import (
-    RecomputeGranularity,
     cast_str_to_pipeline_engine,
 )
 from nanotron.parallel.pipeline_parallel.engine import (
@@ -23,7 +22,6 @@ class ParallelismArgs:
         expert_parallel_size: Number of expert parallel replicas (used only for MoEs)
         pp_engine: Pipeline engine to use between "1f1b" and "afab"
         tp_mode: TP mode to use between "all_reduce" and "reduce_scatter": all_reduce is normal, reduce_scatter activates sequence parallelism
-        recompute_granularity: Recompute granularity to use between "full" and "selective"
         tp_linear_async_communication: Whether to use async communication in TP linear layers
     """

@@ -32,7 +30,6 @@
     tp: int
     pp_engine: Optional[PipelineEngine] = None
     tp_mode: Optional[TensorParallelLinearMode] = None
-    recompute_granularity: Optional[RecomputeGranularity] = None
     tp_linear_async_communication: Optional[bool] = None

     expert_parallel_size: int = 1
@@ -50,5 +47,3 @@ def __post_init__(self):
             self.pp_engine = cast_str_to_pipeline_engine(self.pp_engine)
         if isinstance(self.tp_mode, str):
             self.tp_mode = TensorParallelLinearMode[self.tp_mode.upper()]
-        if isinstance(self.recompute_granularity, str):
-            self.recompute_granularity = RecomputeGranularity[self.recompute_granularity.upper()]
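Since both the dataclass field and the `__post_init__` string cast are gone, a config that still passes the key now fails fast at construction time rather than being silently ignored. A quick check, assuming `dp`, `pp`, and `tp` are the only required fields:

from nanotron.config import ParallelismArgs

# Old-style construction now raises immediately:
try:
    ParallelismArgs(dp=1, pp=1, tp=1, recompute_granularity="selective")
except TypeError as err:
    print(err)  # unexpected keyword argument 'recompute_granularity'

# New-style construction: just omit the argument.
args = ParallelismArgs(dp=1, pp=1, tp=1)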
18 changes: 2 additions & 16 deletions src/nanotron/models/llama.py
@@ -27,7 +27,7 @@

 from nanotron import distributed as dist
 from nanotron import logging
-from nanotron.config import LlamaConfig, ParallelismArgs, RecomputeGranularity
+from nanotron.config import LlamaConfig, ParallelismArgs
 from nanotron.generation.generate_store import AttachableStore
 from nanotron.logging import log_rank
 from nanotron.models import NanotronModel
@@ -803,7 +803,6 @@ def get_flops_per_sec(self, iteration_time_in_sec, sequence_length, global_batch
         ffn_hidden_size=self.config.intermediate_size,
         seq_len=sequence_length,
         batch_size=global_batch_size,
-        recompute_granularity=self.parallel_config.recompute_granularity,
     )

     model_flops_per_s = model_flops / (iteration_time_in_sec * world_size * 1e12)
Expand Down Expand Up @@ -1052,7 +1051,6 @@ def get_flops(
seq_len,
ffn_hidden_size,
batch_size=1,
recompute_granularity=None,
):
"""Counts flops in an decoder-only model
Args:
Expand All @@ -1064,7 +1062,6 @@ def get_flops(
vocab_size: size of the vocabulary
seq_len: sequence length of the decoder
batch_size: batch size
recompute_granularity: Activation recomputation method. Either None, FULL or SELECTIVE. Check Megatron-LM docs for more info.
Returns:
model_flops: flops in the model (should be independent of the hardware and model implementation)
hardware_flops: flops in the hardware (actual flops performed on the hardware). Check 6.3 in https://arxiv.org/pdf/2205.05198.pdf
Expand Down Expand Up @@ -1110,17 +1107,6 @@ def get_flops(
# both input and weight tensors
model_flops = 3 * (decoder_flops_fwd + lm_head_flops_fwd) # 1 for fwd + 2 for bwd

if recompute_granularity is None:
hardware_flops = model_flops
elif recompute_granularity is RecomputeGranularity.FULL:
# Note: we don't recompute lm head activs
hardware_flops = model_flops + decoder_flops_fwd # + activ recomputation
elif recompute_granularity is RecomputeGranularity.SELECTIVE:
# all terms with s^2 are flops that are recomputed
# ref. appendix A: https://arxiv.org/pdf/2205.05198.pdf
recomputed_decoder_flops = decoder_qk_logits_flops_fwd + decoder_v_logits_flops_fwd
hardware_flops = model_flops + recomputed_decoder_flops
else:
raise ValueError("recompute_granularity must be one of 'full' or 'selective'")
hardware_flops = model_flops # TODO: This is a placeholder for now

return model_flops, hardware_flops
19 changes: 2 additions & 17 deletions src/nanotron/models/starcoder2.py
@@ -34,7 +34,7 @@
 from torch.nn import functional as F

 from nanotron import distributed as dist
-from nanotron.config import ParallelismArgs, RecomputeGranularity, Starcoder2Config
+from nanotron.config import ParallelismArgs, Starcoder2Config
 from nanotron.generation.generate_store import AttachableStore
 from nanotron.models import NanotronModel
 from nanotron.nn.activations import ACT2FN
@@ -1660,7 +1660,6 @@ def get_flops_per_sec(self, iteration_time_in_sec, sequence_length, global_batch
         ffn_hidden_size=self.config.n_inner if self.config.n_inner is not None else 4 * self.config.hidden_size,
         seq_len=sequence_length,
         batch_size=global_batch_size,
-        recompute_granularity=self.parallel_config.recompute_granularity,
         kv_channels=None,
         glu_activation=False,
     )
Expand All @@ -1678,7 +1677,6 @@ def get_flops(
kv_channels=None,
ffn_hidden_size=None,
batch_size=1,
recompute_granularity=None,
glu_activation=False,
):
"""Counts flops in an decoder-only model
Expand All @@ -1691,7 +1689,6 @@ def get_flops(
vocab_size: size of the vocabulary
seq_len: sequence length of the decoder
batch_size: batch size
recompute_granularity: Activation recomputation method. Either None, FULL or SELECTIVE. Check Megatron-LM docs for more info.
glu_activation: Whether to use GLU activation in FFN. Check T5 v1.1 for more info.
Returns:
model_flops: flops in the model (should be independent of the hardware and model implementation)
Expand Down Expand Up @@ -1749,17 +1746,5 @@ def get_flops(
# both input and weight tensors
model_flops = 3 * (decoder_flops_fwd + lm_head_flops_fwd) # 1 for fwd + 2 for bwd

if recompute_granularity is None:
hardware_flops = model_flops
elif recompute_granularity is RecomputeGranularity.FULL:
# Note: we don't recompute lm head activs
hardware_flops = model_flops + decoder_flops_fwd # + activ recomputation
elif recompute_granularity is RecomputeGranularity.SELECTIVE:
# all terms with s^2 are flops that are recomputed
# ref. appendix A: https://arxiv.org/pdf/2205.05198.pdf
recomputed_decoder_flops = decoder_qk_logits_flops_fwd + decoder_v_logits_flops_fwd
hardware_flops = model_flops + recomputed_decoder_flops
else:
raise ValueError("recompute_granularity must be one of 'full' or 'selective'")

hardware_flops = model_flops # TODO @nouamanetazi: This is a placeholder for now
return model_flops, hardware_flops
