diff --git a/examples/llama/llama_ark.py b/examples/llama/llama_ark.py index d1a4f3103..c1e9b6342 100644 --- a/examples/llama/llama_ark.py +++ b/examples/llama/llama_ark.py @@ -207,7 +207,7 @@ def __init__(self, params: ModelArgs): # ) self.layers = [] - for layer_id in range(6): + for layer_id in range(self.n_layers): self.tmp_layer = TransformerBlock(layer_id, params) self.layers.append(self.tmp_layer) diff --git a/examples/llama/llama_test.py b/examples/llama/llama_test.py index 22e6b1fbe..c0192e3b4 100644 --- a/examples/llama/llama_test.py +++ b/examples/llama/llama_test.py @@ -371,5 +371,5 @@ def test_rotary_embedding(): # test_attention() # test_feedforward() # test_transformerblock() - # test_transformer() - test_rotary_embedding() + test_transformer() + # test_rotary_embedding()