diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py
index c48b481bca..7b1cf3eb87 100644
--- a/thunder/core/jit_ext.py
+++ b/thunder/core/jit_ext.py
@@ -411,6 +411,20 @@ def lookaside_wrapper(lookaside):
     return lookaside_wrapper


+# PyTorch moved this to torch.compiler.is_compiling as the official API
+# to check whether we are compiling
+if hasattr(torch, "compiler") and hasattr(torch.compiler, "is_compiling"):
+    is_compiling = torch.compiler.is_compiling
+else:
+    is_compiling = torch._dynamo.is_compiling
+
+
+@register_general_jit_lookaside(is_compiling)
+@interpreter_needs_wrap
+def jit_is_compiling_lookaside():
+    return True
+
+
 # lookaside for getattr. We record the provenance of the attribute but for the core attribute getting, we
 # rely on the default JIT getattr lookaside (as returned from default_lookaside)
 @register_general_jit_lookaside(getattr)
diff --git a/thunder/tests/test_networks.py b/thunder/tests/test_networks.py
index 6076b681a4..fb8b6f4546 100644
--- a/thunder/tests/test_networks.py
+++ b/thunder/tests/test_networks.py
@@ -255,18 +255,6 @@ def test_hf_bert():
     def dummy(*args):
         pass

-    # Transformers 2.41+ adds some more non-essential data-dependent
-    # control flow behind a check whether we are compiling
-    if hasattr(torch, "compiler") and hasattr(torch.compiler, "is_compiling"):
-        is_compiling = torch.compiler.is_compiling
-    else:
-        is_compiling = torch._dynamo.is_compiling
-
-    @thunder.core.jit_ext.register_general_jit_lookaside(is_compiling)
-    @thunder.core.jit_ext.interpreter_needs_wrap
-    def dummy(*args):
-        return True
-
     # transformers accesses the old attrib and causes the future warning
     with warnings.catch_warnings():
         warnings.filterwarnings("ignore", category=FutureWarning, message=".*torch._dynamo.*.is_compiling.*")
@@ -410,3 +398,84 @@ def test_thunderfx_mistral_nemo_small():
     logits.backward(grad_logits)

     assert th_backend.subgraph_infos, "Should have at least 1 subgraph"
+
+
+LLAMA_3_2_1B_CFG = {
+    "architectures": ["LlamaForCausalLM"],
+    "attention_bias": False,
+    "attention_dropout": 0.0,
+    "bos_token_id": 128000,
+    "eos_token_id": 128001,
+    "head_dim": 64,
+    "hidden_act": "silu",
+    "hidden_size": 2048,
+    "initializer_range": 0.02,
+    "intermediate_size": 8192,
+    "max_position_embeddings": 131072,
+    "mlp_bias": False,
+    "model_type": "llama",
+    "num_attention_heads": 32,
+    "num_hidden_layers": 16,
+    "num_key_value_heads": 8,
+    "pretraining_tp": 1,
+    "rms_norm_eps": 1e-05,
+    "rope_scaling": {
+        "factor": 32.0,
+        "high_freq_factor": 4.0,
+        "low_freq_factor": 1.0,
+        "original_max_position_embeddings": 8192,
+        "rope_type": "llama3",
+    },
+    "rope_theta": 500000.0,
+    "tie_word_embeddings": True,
+    "torch_dtype": "bfloat16",
+    "transformers_version": "4.45.0.dev0",
+    "use_cache": True,
+    "vocab_size": 128256,
+    "_commit_hash": "4e20de362430cd3b72f300e6b0f18e50e7166e08",
+}
+
+
+@requiresCUDA
+def test_hf_llama():
+    from transformers.models.llama import LlamaForCausalLM, LlamaConfig
+    from transformers import DynamicCache
+    from transformers.models.llama.modeling_llama import logger as llama_logger
+    import logging
+
+    # transformers logs a cache deprecation warning
+    llama_logger.setLevel(logging.CRITICAL)
+    model_id = "meta-llama/Llama-3.2-1B"
+
+    config_args = LLAMA_3_2_1B_CFG.copy()
+    config_args["num_hidden_layers"] = 1
+    with torch.device("cuda"):
+        model = LlamaForCausalLM(LlamaConfig(**config_args)).to(torch.bfloat16).requires_grad_(False).eval()
+
+    jm = thunder.jit(model)
+
+    args1 = dict(
+        cache_position=torch.tensor([0, 1, 2, 3, 4, 5], device="cuda:0"),
+        input_ids=torch.tensor([[128000, 791, 1401, 311, 2324, 374]], device="cuda:0"),
+        inputs_embeds=None,
+        attention_mask=torch.tensor([[1, 1, 1, 1, 1, 1]], device="cuda:0"),
+        use_cache=True,
+        return_dict=True,
+    )
+    res = jm(**args1)
+    expected = model(**args1)
+
+    assert_close(res, expected, rtol=1e-1, atol=1e-1)
+
+    args2 = dict(
+        cache_position=torch.tensor([6], device="cuda:0"),
+        input_ids=torch.tensor([[311]], device="cuda:0"),
+        inputs_embeds=None,
+        attention_mask=torch.tensor([[1, 1, 1, 1, 1, 1, 1]], device="cuda:0"),
+        use_cache=True,
+        return_dict=True,
+    )
+
+    res2 = jm(past_key_values=res["past_key_values"], **args2)
+    expected2 = model(past_key_values=res["past_key_values"], **args2)
+    assert_close(res2, expected2, rtol=1e-1, atol=1e-1)
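Note (not part of the diff): the lookaside added to thunder/core/jit_ext.py makes the same shim that test_hf_bert previously registered locally available to every thunder.jit call, so code that branches on the "are we compiling?" check takes the compiled path under the thunder interpreter. A minimal sketch of the intended effect follows; it assumes a PyTorch build that exposes torch.compiler.is_compiling, and the function name branchy is purely illustrative.

    import torch
    import thunder


    def branchy(x):
        # Under thunder.jit, the registered lookaside makes
        # torch.compiler.is_compiling() evaluate to True, so this branch runs.
        if torch.compiler.is_compiling():
            return x + 1
        # In plain eager execution is_compiling() is False, so this branch runs.
        return x - 1


    x = torch.zeros(2)
    assert branchy(x).sum().item() == -2.0               # eager: is_compiling() is False
    assert thunder.jit(branchy)(x).sum().item() == 2.0   # jitted: lookaside returns True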