fix evaler
apoorvtintin committed Dec 7, 2024
1 parent eaf45a4 commit 0f5171b
Showing 2 changed files with 32 additions and 17 deletions.
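This commit threads two new evaler options through the stack: input_partition_type, which selects how evaluation input batches are partitioned across the device mesh, and batch_axis_names, the subset of mesh axes over which the batch leaves are sharded. A minimal usage sketch, assuming the fork's DataPartitionType enum and the SpmdEvaler config fields added below; the axis names and partition choice are illustrative, not taken from the commit:

    # Hedged sketch: configuring an evaler with the fields introduced by this commit.
    # The axis names and partition choice below are illustrative assumptions.
    from axlearn.common.evaler import SpmdEvaler
    from axlearn.common.utils import DataPartitionType

    evaler_cfg: SpmdEvaler.Config = SpmdEvaler.default_config().set(
        input_partition_type=DataPartitionType.FULL,  # default; BATCH / REPLICATED are the other listed options
        batch_axis_names=("data", "fsdp"),  # mesh axes that carry the batch dimension
    )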
21 changes: 17 additions & 4 deletions axlearn/common/evaler.py
@@ -31,6 +31,7 @@
from axlearn.common.module import Module, OutputCollection
from axlearn.common.module import functional as F
from axlearn.common.utils import (
DataPartitionType,
NestedPartitionSpec,
NestedTensor,
Tensor,
@@ -81,6 +82,11 @@ class Config(Module.Config):
# evalers, not setting prefix will show the accuracies on the same plot for comparison
# across evalers.
prefix: Optional[str] = None
# Subset of mesh axis names over which the leaves of the input batch are sharded.
batch_axis_names: Union[str, Sequence[str]] = "data"
# How evaluation input batches are partitioned across the mesh.
# Options: FULL (default), BATCH, REPLICATED.
input_partition_type: Optional[DataPartitionType] = DataPartitionType.FULL

def __init__(
self,
@@ -188,11 +194,11 @@ def _pjit(self, fn: Callable) -> Callable:
in_shardings=(
self._model_param_partition_specs, # model_params.
None, # replicated_inputs (e.g., prng_key).
utils.input_partition_spec(), # per_example_inputs.
utils.data_partition_type_to_spec(
    partition=self.config.input_partition_type,
    batch_axis_names=self.config.batch_axis_names,
),  # per_example_inputs.
),
out_shardings=dict(
replicated=None,
per_example=utils.input_partition_spec(),
per_example=utils.data_partition_type_to_spec(
    partition=self.config.input_partition_type,
    batch_axis_names=self.config.batch_axis_names,
),
),
)
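The shardings above replace the fixed utils.input_partition_spec() with a spec derived from the configured partition type. A rough sketch of the mapping this relies on, under the assumption that the fork's utils.data_partition_type_to_spec behaves like upstream axlearn's helper (FULL shards the batch dimension over the given axes, REPLICATED returns no sharding); the BATCH option listed in the comments is not sketched here:

    # Hedged sketch of the assumed partition-type -> PartitionSpec mapping; the
    # fork's utils.data_partition_type_to_spec may differ in detail.
    from typing import Optional, Sequence, Union

    from jax.sharding import PartitionSpec

    from axlearn.common.utils import DataPartitionType

    def partition_spec_sketch(
        partition: DataPartitionType,
        batch_axis_names: Union[str, Sequence[str]] = "data",
    ) -> Optional[PartitionSpec]:
        if partition == DataPartitionType.FULL:
            # Shard the leading (batch) dimension over the given mesh axes.
            return PartitionSpec(batch_axis_names)
        if partition == DataPartitionType.REPLICATED:
            # No sharding: every device holds the full batch.
            return None
        raise NotImplementedError(f"Unsupported partition type: {partition}")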

@@ -574,6 +580,11 @@ class Config(Module.Config):
metric_calculator: BaseMetricCalculator.Config = ModelSummaryAccumulator.default_config()
# If not None, writes input batches and `metric_calculator` forward outputs.
output_writer: Optional[BaseOutputWriter.Config] = None
# Subset of mesh axis names over which the leaves of the input batch are sharded.
batch_axis_names: Union[str, Sequence[str]] = "data"
# How evaluation input batches are partitioned across the mesh.
# Options: FULL (default), BATCH, REPLICATED.
input_partition_type: Optional[DataPartitionType] = DataPartitionType.FULL

def __init__(
self,
@@ -595,7 +606,7 @@ def __init__(
self._add_child("input", maybe_set_config(cfg.input, is_training=False))
self._add_child(
"metric_calculator",
cfg.metric_calculator.set(eval_dtype=cfg.eval_dtype),
cfg.metric_calculator.set(
    eval_dtype=cfg.eval_dtype,
    batch_axis_names=cfg.batch_axis_names,
    input_partition_type=cfg.input_partition_type,
),
model=model,
model_param_partition_specs=model_param_partition_specs,
)
@@ -691,7 +702,9 @@ def eval_step(

with jax.profiler.StepTraceAnnotation(cfg.name, step_num=step):
with jax.profiler.TraceAnnotation(f"{cfg.name}.forward"):
global_input_batch = utils.host_to_global_device_array(input_batch)
print("evaler data partition type ", self.config.input_partition_type, flush=True)
print("evaler batch_axis_names type ", self.config.batch_axis_names, flush=True)
global_input_batch = utils.host_to_global_device_array(
    input_batch,
    partition=self.config.input_partition_type,
    batch_axis_names=self.config.batch_axis_names,
)
forward_outputs = self.metric_calculator.forward(
global_input_batch,
model_params=model_params,
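In eval_step, the host batch is now converted into a global array using the configured partition rather than the default. A minimal standalone sketch of the call pattern, assuming the fork's host_to_global_device_array accepts the partition and batch_axis_names keywords shown in the diff and that a device mesh is already active (as it is inside the trainer):

    # Hedged usage sketch; assumes an active device mesh and the fork's extended
    # host_to_global_device_array signature shown in the diff above.
    import numpy as np

    from axlearn.common import utils
    from axlearn.common.utils import DataPartitionType

    # Per-host shard of the evaluation batch (names and shapes are illustrative).
    host_batch = {"input_ids": np.zeros((8, 128), dtype=np.int32)}
    global_batch = utils.host_to_global_device_array(
        host_batch,
        partition=DataPartitionType.FULL,  # default; the commit makes this configurable per evaler
        batch_axis_names=("data",),  # assumed keyword added in this fork
    )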
28 changes: 15 additions & 13 deletions axlearn/experiments/text/gpt/common.py
@@ -699,9 +699,24 @@ def config_fn() -> InstantiableConfig:
)
if input_partition_type:
cfg.input_partition_type = input_partition_type
if len(mesh_axis_names) != len(mesh_shape):
raise ValueError(
f"Number of mesh axis names ({mesh_axis_names}) "
f"must match number of mesh dims ({mesh_shape})."
)
cfg.mesh_axis_names = mesh_axis_names
cfg.mesh_shape = mesh_shape
# Set batch sharding spec to exclude the "model" axis (assumed for tensor-parallelism) and
# "pipeline" axis (for pipeline parallelism).
cfg.batch_axis_names = tuple(
el for el in mesh_axis_names if el not in ("model", "pipeline")
)
cfg.mesh_rules = mesh_rules
cfg.evalers = {}
for name, evaler_cfg in evalers.items():
evaler_cfg.input.batcher.set(global_batch_size=eval_batch_size or train_batch_size)
if input_partition_type:
    evaler_cfg.set(input_partition_type=input_partition_type)
evaler_cfg.set(batch_axis_names=cfg.batch_axis_names)
evaler_cfg.set(
eval_policy=config_for_function(eval_every_n_steps_policy).set(
n=eval_every_n_steps,
@@ -718,19 +733,6 @@
cfg.checkpointer.keep_last_n = 3
cfg.summary_writer.write_every_n_steps = min(eval_every_n_steps, 100)
cfg.summary_writer.max_queue = 1000
if len(mesh_axis_names) != len(mesh_shape):
raise ValueError(
f"Number of mesh axis names ({mesh_axis_names}) "
f"must match number of mesh dims ({mesh_shape})."
)
cfg.mesh_axis_names = mesh_axis_names
cfg.mesh_shape = mesh_shape
# Set batch sharding spec to exclude the "model" axis (assumed for tensor-parallelism) and
# "pipeline" axis (for pipeline parallelism).
cfg.batch_axis_names = tuple(
el for el in mesh_axis_names if el not in ("model", "pipeline")
)
cfg.mesh_rules = mesh_rules
# Maybe load state.
if init_state_builder:
cfg.init_state_builder = init_state_builder
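The relocated block above derives batch_axis_names from the mesh axis names before the evaler loop runs, so each evaler can be configured with it; the "model" and "pipeline" axes are excluded because they shard parameters and stages, not the batch. A small worked example of the derivation (the mesh below is illustrative, not from the commit):

    # Illustrative example of the batch_axis_names derivation used in common.py.
    mesh_axis_names = ("pipeline", "data", "fsdp", "model")  # hypothetical mesh
    batch_axis_names = tuple(el for el in mesh_axis_names if el not in ("model", "pipeline"))
    assert batch_axis_names == ("data", "fsdp")  # only these axes shard the batch dimension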
