From 3a45a340f353c12dc82fd4ce4c1bc0a99e96a8d3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E2=80=9Celiebak=E2=80=9D?=
Date: Tue, 27 Aug 2024 15:22:26 +0000
Subject: [PATCH] remove torch compile

---
 src/nanotron/models/llama.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py
index 49ea86e6..ca6c2441 100644
--- a/src/nanotron/models/llama.py
+++ b/src/nanotron/models/llama.py
@@ -165,7 +165,7 @@ def __init__(
             bias=False,
             async_communication=tp_linear_async_communication and tp_mode is TensorParallelLinearMode.REDUCE_SCATTER,
         )
-        self.split_silu_mul = torch.compile(GLUActivation(config.hidden_act))
+        self.split_silu_mul = GLUActivation(config.hidden_act)
 
     def forward(self, hidden_states):  # [seq_length, batch_size, hidden_dim]
         merged_states = self.gate_up_proj(hidden_states)
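
For context, a minimal sketch of what this one-line change amounts to. The class
name GLUActivation and the config.hidden_act argument come from the patch itself;
the module body below is a hypothetical reconstruction (a SiLU-gated linear unit
over a merged [gate, up] projection), not nanotron's actual implementation.

import torch
from torch import nn
import torch.nn.functional as F


class GLUActivation(nn.Module):
    """Hypothetical sketch: the input is the concatenated [gate, up]
    projection produced by gate_up_proj, split in half on the last dim."""

    def __init__(self, act_fn_name: str):
        super().__init__()
        # The patch passes config.hidden_act; "silu" is assumed here.
        self.act = getattr(F, act_fn_name)

    def forward(self, merged_states: torch.Tensor) -> torch.Tensor:
        gate, up = torch.chunk(merged_states, 2, dim=-1)
        return self.act(gate) * up


# Before the patch: the module was wrapped in torch.compile, which
# JIT-compiles the chunk + activation + multiply on first call.
split_silu_mul = torch.compile(GLUActivation("silu"))

# After the patch: the plain eager-mode module is used instead. The two are
# numerically equivalent; the change only drops the compilation step and its
# warm-up cost.
split_silu_mul = GLUActivation("silu")

Note that removing torch.compile changes performance characteristics only, not
the computed values, which is why the diff touches nothing but this assignment.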