Commit

add llama3 test and is_compiling lookaside (#1397)
t-vi authored Nov 4, 2024
1 parent cb02f8a commit dcf0729
Showing 2 changed files with 95 additions and 12 deletions.
14 changes: 14 additions & 0 deletions thunder/core/jit_ext.py
@@ -411,6 +411,20 @@ def lookaside_wrapper(lookaside):
    return lookaside_wrapper


# PyTorch moved this to torch.compiler.is_compiling as the official API.
# Under Thunder's JIT we are compiling, so the lookaside below reports True.
if hasattr(torch, "compiler") and hasattr(torch.compiler, "is_compiling"):
    is_compiling = torch.compiler.is_compiling
else:
    is_compiling = torch._dynamo.is_compiling


@register_general_jit_lookaside(is_compiling)
@interpreter_needs_wrap
def jit_is_compiling_lookaside():
    return True


# lookaside for getattr. We record the provenance of the attribute but for the core attribute getting, we
# rely on the default JIT getattr lookaside (as returned from default_lookaside)
@register_general_jit_lookaside(getattr)
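A minimal sketch of what the lookaside enables (illustrative only, not part of this commit; assumes a PyTorch recent enough to expose torch.compiler.is_compiling): code that branches on is_compiling() takes its "compiling" branch while Thunder's interpreter is running, mirroring the behavior under torch.compile.

import torch
import thunder

def scale(x):
    # Toy stand-in for the transformers pattern: take a different path when compiling.
    if torch.compiler.is_compiling():
        return x * 2
    return x

x = torch.ones(3)
jfn = thunder.jit(scale)
print(scale(x))  # eager: is_compiling() is False, x is returned unchanged
print(jfn(x))    # under Thunder's JIT the lookaside returns True, so x * 2 is returned
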
93 changes: 81 additions & 12 deletions thunder/tests/test_networks.py
@@ -255,18 +255,6 @@ def test_hf_bert():
    def dummy(*args):
        pass

    # Transformers 4.41+ adds some more non-essential data-dependent
    # control flow behind a check of whether we are compiling
    if hasattr(torch, "compiler") and hasattr(torch.compiler, "is_compiling"):
        is_compiling = torch.compiler.is_compiling
    else:
        is_compiling = torch._dynamo.is_compiling

    @thunder.core.jit_ext.register_general_jit_lookaside(is_compiling)
    @thunder.core.jit_ext.interpreter_needs_wrap
    def dummy(*args):
        return True

    # transformers accesses the old attribute and triggers the FutureWarning
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=FutureWarning, message=".*torch._dynamo.*.is_compiling.*")
@@ -410,3 +398,84 @@ def test_thunderfx_mistral_nemo_small():
    logits.backward(grad_logits)

    assert th_backend.subgraph_infos, "Should have at least 1 subgraph"


LLAMA_3_2_1B_CFG = {
    "architectures": ["LlamaForCausalLM"],
    "attention_bias": False,
    "attention_dropout": 0.0,
    "bos_token_id": 128000,
    "eos_token_id": 128001,
    "head_dim": 64,
    "hidden_act": "silu",
    "hidden_size": 2048,
    "initializer_range": 0.02,
    "intermediate_size": 8192,
    "max_position_embeddings": 131072,
    "mlp_bias": False,
    "model_type": "llama",
    "num_attention_heads": 32,
    "num_hidden_layers": 16,
    "num_key_value_heads": 8,
    "pretraining_tp": 1,
    "rms_norm_eps": 1e-05,
    "rope_scaling": {
        "factor": 32.0,
        "high_freq_factor": 4.0,
        "low_freq_factor": 1.0,
        "original_max_position_embeddings": 8192,
        "rope_type": "llama3",
    },
    "rope_theta": 500000.0,
    "tie_word_embeddings": True,
    "torch_dtype": "bfloat16",
    "transformers_version": "4.45.0.dev0",
    "use_cache": True,
    "vocab_size": 128256,
    "_commit_hash": "4e20de362430cd3b72f300e6b0f18e50e7166e08",
}


@requiresCUDA
def test_hf_llama():
    from transformers.models.llama import LlamaForCausalLM, LlamaConfig
    from transformers import DynamicCache
    from transformers.models.llama.modeling_llama import logger as llama_logger
    import logging

    # transformers logs a cache deprecation warning
    llama_logger.setLevel(logging.CRITICAL)
    model_id = "meta-llama/Llama-3.2-1B"

    # shrink the model to a single hidden layer to keep the test small and fast
    config_args = LLAMA_3_2_1B_CFG.copy()
    config_args["num_hidden_layers"] = 1
    with torch.device("cuda"):
        model = LlamaForCausalLM(LlamaConfig(**config_args)).to(torch.bfloat16).requires_grad_(False).eval()

    jm = thunder.jit(model)

    # prefill: run the six-token prompt through the jitted and eager models
    args1 = dict(
        cache_position=torch.tensor([0, 1, 2, 3, 4, 5], device="cuda:0"),
        input_ids=torch.tensor([[128000, 791, 1401, 311, 2324, 374]], device="cuda:0"),
        inputs_embeds=None,
        attention_mask=torch.tensor([[1, 1, 1, 1, 1, 1]], device="cuda:0"),
        use_cache=True,
        return_dict=True,
    )
    res = jm(**args1)
    expected = model(**args1)

    assert_close(res, expected, rtol=1e-1, atol=1e-1)

    # decode step: a single new token plus the key/value cache from the prefill
    args2 = dict(
        cache_position=torch.tensor([6], device="cuda:0"),
        input_ids=torch.tensor([[311]], device="cuda:0"),
        inputs_embeds=None,
        attention_mask=torch.tensor([[1, 1, 1, 1, 1, 1, 1]], device="cuda:0"),
        use_cache=True,
        return_dict=True,
    )

    res2 = jm(past_key_values=res["past_key_values"], **args2)
    expected2 = model(past_key_values=res["past_key_values"], **args2)
    assert_close(res2, expected2, rtol=1e-1, atol=1e-1)
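
The two calls above follow the prefill-then-decode pattern used for generation: one pass over the full prompt that produces a key/value cache, then a single-token step that reuses it. A hedged sketch (not part of the commit) of continuing that pattern for a few greedy decode steps with the jitted model, reusing jm and args1 from the test above; the chosen tokens are meaningless because the weights are randomly initialized.

out = jm(**args1)                     # fresh prefill over the six-token prompt
generated = args1["input_ids"]
past = out["past_key_values"]
for _ in range(3):
    # pick the highest-scoring next token and append it to the running sequence
    next_id = out["logits"][:, -1].argmax(dim=-1, keepdim=True)
    generated = torch.cat([generated, next_id], dim=-1)
    out = jm(
        input_ids=next_id,
        cache_position=torch.tensor([generated.shape[1] - 1], device="cuda:0"),
        attention_mask=torch.ones(1, generated.shape[1], dtype=torch.long, device="cuda:0"),
        past_key_values=past,
        use_cache=True,
        return_dict=True,
    )
    past = out["past_key_values"]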
