updates
chhwang committed Oct 8, 2024
1 parent 02bf517 · commit 7d05da8
Showing 2 changed files with 3 additions and 47 deletions.
examples/llama/model.py (3 additions, 3 deletions)
@@ -405,7 +405,7 @@ def forward(self, x, ffn_norm):
             ]
         elif seqlen == 128:
             schedule = [
-                [128, [128, 64], 16480],
+                [128, [128, 64], 12384],
             ]
         else:
             raise ValueError(f"Unsupported seqlen {seqlen}")
@@ -471,7 +471,7 @@ def forward(self, x, ffn_norm):
             sram = 24672
         elif seqlen == 128:
             tile = [128, 64]
-            sram = 16480
+            sram = 12384
         else:
             raise ValueError(f"Unsupported seqlen {seqlen}")

@@ -645,7 +645,7 @@ def softmax(scores):
             sram = 24672
         elif seqlen == 128:
             tile = [128, 64]
-            sram = 16480
+            sram = 12384
         else:
             raise ValueError(f"Unsupported seqlen {seqlen}")

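All three model.py hunks above make the same tuning change: for seqlen == 128 the attention path keeps the [128, 64] tile but its SRAM budget drops from 16480 to 12384. A minimal sketch of that per-seqlen selection pattern follows; the helper name select_attention_config and the seqlen == 64 branch (only its sram value, 24672, is visible in the context above) are assumptions for illustration, not code from the repository.

# Sketch only: reconstructed from the diff context above. The function name and
# the seqlen == 64 branch are assumptions; the seqlen == 128 values are the ones
# this commit sets (tile [128, 64], sram 12384, previously 16480).
def select_attention_config(seqlen: int):
    if seqlen == 64:  # assumed condition; only sram = 24672 is visible above
        tile = [128, 64]  # tile for this branch is also an assumption
        sram = 24672
    elif seqlen == 128:
        tile = [128, 64]
        sram = 12384  # lowered from 16480 by this commit
    else:
        raise ValueError(f"Unsupported seqlen {seqlen}")
    return tile, sram


if __name__ == "__main__":
    print(select_attention_config(128))  # ([128, 64], 12384)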
examples/llama/model_test.py (0 additions, 44 deletions)
@@ -363,49 +363,6 @@ def test_column_parallel_linear(
     )
 
 
-def test_attention(
-    args: ModelArgs,
-    batch_size: int,
-    seq_len: int,
-    dtype: np.dtype,
-    rank: int = 0,
-    world_size: int = 1,
-):
-    #
-    freqs_cis = precompute_freqs_cis(
-        args.dim // args.n_heads, args.max_seq_len * 2
-    )[0:seq_len]
-
-    freqs_cis_ark = freqs_cis.astype(np.complex64)
-    freqs_cis_ark = (
-        np.stack([freqs_cis_ark.real, freqs_cis_ark.imag], axis=-1)
-        .astype(dtype)
-        .reshape(1, seq_len, 1, args.dim // args.n_heads)
-    )
-
-    seed = 1695878986  # int(time.time())
-    print(f"seed: {seed}")
-    np.random.seed(seed)
-    feature = np.random.uniform(
-        low=-0.1, high=0.1, size=(batch_size, seq_len, args.dim)
-    ).astype(dtype)
-
-    test_module(
-        module_class_ark=model_ark.Attention,
-        module_args_ark=[
-            args,
-            ark.DataType.from_numpy(dtype),
-            rank,
-            world_size,
-        ],
-        inputs_ark=[feature, 0, freqs_cis_ark, None],
-        module_class_pt=model_pt.Attention,
-        module_args_pt=[args],
-        inputs_pt=[feature.astype(dtype), 0, freqs_cis, None],
-        module_name_prefix="layers.0.attention",
-    )
-
-
 def test_transformer(
     args: ModelArgs,
     batch_size: int,
@@ -472,7 +429,6 @@ def test(args, batch_size, seq_len, dtype, rank, world_size):
     # test_rmsnorm(args, batch_size, seq_len, dtype)
     # test_row_parallel_linear(args, batch_size, seq_len, dtype, rank, world_size)
     # test_column_parallel_linear(args, batch_size, seq_len, dtype, rank, world_size)
-    # test_attention(args, batch_size, seq_len, dtype, rank, world_size)
     test_transformer(args, batch_size, seq_len, dtype, rank, world_size)

