Skip to content

Commit

Permalink
intermediate commit
Browse files Browse the repository at this point in the history
  • Loading branch information
apoorvtintin committed Dec 17, 2024
1 parent 68c6ee9 commit 29d1ed4
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 31 deletions.
58 changes: 29 additions & 29 deletions axlearn/common/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -1261,49 +1261,49 @@ def apply_rotary_position_embeddings(
"""
# sin [batch_size, num_heads, sequence_length, embed_size_per_head//2]
# cos [batch_size, num_heads, sequence_length, embed_size_per_head//2]
# sin, cos = jnp.split(sinusoidal_pos, 2, axis=-1)
# # sin [θ0,θ1,θ2......θd/2-1] -> sin_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1]
# sin_pos = jnp.reshape(jnp.stack([sin, sin], axis=-1), sinusoidal_pos.shape)
# # cos [θ0,θ1,θ2......θd/2-1] -> cos_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1]
# cos_pos = jnp.reshape(jnp.stack([cos, cos], axis=-1), sinusoidal_pos.shape)
# # rotate_half_query_layer [-q1,q0,-q3,q2......,-qd-1,qd-2]
# rotate_half_query = jnp.reshape(
# jnp.stack([-query[..., 1::2], query[..., ::2]], axis=-1), query.shape
# )
# query = query * cos_pos + rotate_half_query * sin_pos
# # rotate_half_key_layer [-k1,k0,-k3,k2......,-kd-1,kd-2]
# rotate_half_key = jnp.reshape(jnp.stack([-key[..., 1::2], key[..., ::2]], axis=-1), key.shape)
# key = key * cos_pos + rotate_half_key * sin_pos
# if rotary_value:
# # rotate_half_value_layer [-v1,v0,-v3,v2......,-vd-1,vd-2]
# rotate_half_value = jnp.reshape(
# jnp.stack([-value[..., 1::2], value[..., ::2]], axis=-1), value.shape
# )
# value = value * cos_pos + rotate_half_value * sin_pos
# return query, key, value

def _rotate_half(x: jnp.ndarray) -> jnp.ndarray:
halves = jnp.split(x, 2, axis=-1)
return jnp.concatenate((-halves[1], halves[0]), axis=-1)

sin, cos = jnp.split(sinusoidal_pos, 2, axis=-1)
# sin [θ0,θ1,θ2......θd/2-1] -> sin_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1]
sin_pos = jnp.reshape(jnp.stack([sin, sin], axis=-1), sinusoidal_pos.shape)
# cos [θ0,θ1,θ2......θd/2-1] -> cos_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1]
cos_pos = jnp.reshape(jnp.stack([cos, cos], axis=-1), sinusoidal_pos.shape)
# rotate_half_query_layer [-q1,q0,-q3,q2......,-qd-1,qd-2]
rotate_half_query = _rotate_half(query)

rotate_half_query = jnp.reshape(
jnp.stack([-query[..., 1::2], query[..., ::2]], axis=-1), query.shape
)
query = query * cos_pos + rotate_half_query * sin_pos
# rotate_half_key_layer [-k1,k0,-k3,k2......,-kd-1,kd-2]
rotate_half_key = _rotate_half(key)
rotate_half_key = jnp.reshape(jnp.stack([-key[..., 1::2], key[..., ::2]], axis=-1), key.shape)
key = key * cos_pos + rotate_half_key * sin_pos
if rotary_value:
# rotate_half_value_layer [-v1,v0,-v3,v2......,-vd-1,vd-2]
rotate_half_value = _rotate_half(value)
rotate_half_value = jnp.reshape(
jnp.stack([-value[..., 1::2], value[..., ::2]], axis=-1), value.shape
)
value = value * cos_pos + rotate_half_value * sin_pos
return query, key, value

# def _rotate_half(x: jnp.ndarray) -> jnp.ndarray:
# halves = jnp.split(x, 2, axis=-1)
# return jnp.concatenate((-halves[1], halves[0]), axis=-1)

# sin, cos = jnp.split(sinusoidal_pos, 2, axis=-1)
# # sin [θ0,θ1,θ2......θd/2-1] -> sin_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1]
# sin_pos = jnp.reshape(jnp.stack([sin, sin], axis=-1), sinusoidal_pos.shape)
# # cos [θ0,θ1,θ2......θd/2-1] -> cos_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1]
# cos_pos = jnp.reshape(jnp.stack([cos, cos], axis=-1), sinusoidal_pos.shape)
# # rotate_half_query_layer [-q1,q0,-q3,q2......,-qd-1,qd-2]
# rotate_half_query = _rotate_half(query)

# query = query * cos_pos + rotate_half_query * sin_pos
# # rotate_half_key_layer [-k1,k0,-k3,k2......,-kd-1,kd-2]
# rotate_half_key = _rotate_half(key)
# key = key * cos_pos + rotate_half_key * sin_pos
# if rotary_value:
# # rotate_half_value_layer [-v1,v0,-v3,v2......,-vd-1,vd-2]
# rotate_half_value = _rotate_half(value)
# value = value * cos_pos + rotate_half_value * sin_pos
# return query, key, value


class RoFormerQKVLinear(BaseQKVLinear):
"""RoFormerQKVLinear class
Expand Down
51 changes: 51 additions & 0 deletions axlearn/common/checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ def restore_tf_savables(value_map: Nested[Any], *, dir: str) -> Nested[Any]:
"""Restores TF savables from `dir` into `value_map` in-place."""

for path, value in utils.flatten_items(value_map):
logging.info("restoring path %s, value %s", path, value)
tf_checkpoint = tf.train.Checkpoint(value)
tf_checkpoint.read(os.path.join(dir, path))

Expand Down Expand Up @@ -540,9 +541,11 @@ def restore_from_dir(
check_state_structure(
read_index_file(ckpt_dir), target_structure=spec.index, validation=validation
)
logging.info(" 1 Restoring checkpoint from directory %s", ckpt_dir)
restore_tf_savables(
spec.tf_ckpt_map, dir=os.path.join(ckpt_dir, f"tf_{jax.process_index()}")
)
logging.info(" 2 Restoring checkpoint from directory %s", ckpt_dir)
maybe_restore_grain_savables(
spec.grain_ckpt_map, dir=os.path.join(ckpt_dir, f"grain_{jax.process_index()}")
)
Expand All @@ -556,13 +559,16 @@ def restore_from_dir(
)
state_leaves = []
for path, value in spec.index:
logging.info(" path %s, value %s", path, value)
if path == "step":
pass
elif path in spec.tf_ckpt_map:
state_leaves.append(spec.tf_ckpt_map[path])
logging.info("tf_ckpt_map %s", spec.tf_ckpt_map[path])
elif path in spec.grain_ckpt_map:
state_leaves.append(spec.grain_ckpt_map[path])
elif isinstance(value, dict):
logging.info("restored_gda_values.pop(0) %s", restored_gda_values[0])
state_leaves.append(restored_gda_values.pop(0))
else:
raise RuntimeError(f"Unknown index entry '{value}'")
Expand All @@ -571,6 +577,7 @@ def restore_from_dir(
jax.tree_util.tree_structure(state), state_leaves
)
multihost_utils.sync_global_devices(ckpt_dir)
# state_leaves.append(maybe_convert_param(path, restored_gda_values.pop(0)))
return restored_state

def stop(self):
Expand Down Expand Up @@ -1086,6 +1093,50 @@ def validate_and_restore(*, step: int, ckpt_dir: str):
step=step, state=state, ckpt_dir=ckpt_dir
)
logging.info("Restored state from ckpt at step %s", step)
from jax.experimental.pjit import pjit
import re
def convert_to_new_embeddings(x):
    """Permute the last axis of `x` to [1, 3, 5, ..., 0, 2, 4, ...].

    All odd positions come first, then all even positions.
    NOTE(review): presumably converts interleaved rotary-embedding weight
    layout to the half-split layout — confirm against the caller.
    """
    dim = x.shape[-1]
    positions = jnp.arange(dim)
    perm = jnp.concatenate([positions[1::2], positions[::2]])
    return x[..., perm]

def maybe_convert_param(param, value, spec):
    """Converts query/key projection leaves to the new embedding layout.

    Args:
        param: A pytree key path (as supplied by `jax.tree_util.tree_map_with_path`);
            stringified via `jax.tree_util.keystr` for matching.
        value: The checkpoint leaf at that path.
        spec: The corresponding spec entry; only `spec.mesh_axes` is read here,
            and only for debug logging.

    Returns:
        `convert_to_new_embeddings(value)` when the key path contains
        `q_proj` or `k_proj`, otherwise `value` unchanged.

    NOTE(review): these info-level logs print whole parameter values, which can
    be very large — intended as temporary debugging for this intermediate commit.
    """
    logging.info("param str %s", str(param))
    logging.info("value %s", value)
    logging.info("spec %s", spec)
    logging.info("spec.mesh_axes %s", spec.mesh_axes)
    # pattern = r'model/decoder/transformer/layer(\d+)/self_attention/attention/i_proj/i_proj/[qkv]_proj/weight'
    # Only q_proj and k_proj weights are rotated by RoPE, so only they need conversion.
    pattern = r'.*([kq]_proj).*'
    # pytree_str = ""
    pytree_str = jax.tree_util.keystr(param)
    # for key in param:
    #     pytree_str += ('/' + str(key))

    match = re.match(pattern, pytree_str)
    logging.info("pytree_str %s", pytree_str)
    if match:
        # layer_number = match.group(1)
        # proj_type = param.split('/')[-2][0] # Extract 'q', 'k', or 'v'
        # logging.info(f"Projection type: {proj_type}")

        return convert_to_new_embeddings(value)

    return value

logging.info("state %s", restored_state)
# new_state
# for path, value in utils.flatten_items(restored_state):
# maybe_convert_param(path, value)
def convert_wrapper(restored_state):
    """Applies `maybe_convert_param` to every leaf of `restored_state`.

    Maps over (key_path, restored leaf, spec leaf) triples; `state` is captured
    from the enclosing scope and supplies the per-leaf spec argument.
    NOTE(review): the parameter shadows the outer `restored_state` variable.
    """
    return jax.tree_util.tree_map_with_path(maybe_convert_param, restored_state, state)

fn = pjit(
convert_wrapper,
# in_shardings=value.sharding,
# out_shardings=value.sharding,
)
restored_state = fn(restored_state)
if "summary_writer" in self.children:
self.summary_writer.log_checkpoint(
step=step,
Expand Down
4 changes: 2 additions & 2 deletions axlearn/experiments/text/gpt/fuji.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ def get_trainer_kwargs(
elif model_size == "70B":
trainer_kwargs = dict(
model_kwargs=dict(
num_layers=80,
num_layers=8,
hidden_dim=128 * 64,
num_heads=64,
# No GQA support in V1 models, so num_kv_heads is the same as num_heads.
Expand All @@ -415,7 +415,7 @@ def get_trainer_kwargs(
),
learner_kwargs=dict(peak_lr=1.5e-4, weight_decay=0.1),
max_sequence_length=max_sequence_length,
train_batch_size=train_batch_size,
train_batch_size=16,
input_partition_type=None if backend != "neuron" else DataPartitionType.BATCH,
max_step=max_step,
mesh_shape=mesh_shape_from_axes(fsdp=-1),
Expand Down

0 comments on commit 29d1ed4

Please sign in to comment.