Adding inference support
MekkCyber committed May 23, 2024
1 parent e1c3c1a commit 2239f41
Showing 4 changed files with 111 additions and 57 deletions.
46 changes: 23 additions & 23 deletions examples/config_tiny_llama.py
@@ -22,35 +22,14 @@
)
from nanotron.logging import human_format

model_config = LlamaConfig(
# Config for a tiny model with 1.62M parameters
bos_token_id=1,
eos_token_id=2,
hidden_act="silu",
hidden_size=16,
initializer_range=0.02,
intermediate_size=64,
max_position_embeddings=256,
num_attention_heads=4,
num_hidden_layers=2,
num_key_value_heads=4,
pretraining_tp=1,
rms_norm_eps=1e-05,
rope_scaling=None,
tie_word_embeddings=True,
use_cache=True,
vocab_size=256,
)

# model_config = LlamaBitNetConfig(
# # Config for a tiny 1.58bit model with 1.62M parameters
# model_config = LlamaConfig(
# # Config for a tiny model with 1.62M parameters
# bos_token_id=1,
# eos_token_id=2,
# hidden_act="silu",
# hidden_size=16,
# initializer_range=0.02,
# intermediate_size=64,
# is_bitnet_config=True,
# max_position_embeddings=256,
# num_attention_heads=4,
# num_hidden_layers=2,
@@ -63,6 +42,27 @@
# vocab_size=256,
# )

model_config = LlamaBitNetConfig(
# Config for a tiny 1.58bit model with 1.62M parameters
bos_token_id=1,
eos_token_id=2,
hidden_act="silu",
hidden_size=16,
initializer_range=0.02,
intermediate_size=64,
is_bitnet_config=True,
max_position_embeddings=256,
num_attention_heads=4,
num_hidden_layers=2,
num_key_value_heads=4,
pretraining_tp=1,
rms_norm_eps=1e-05,
rope_scaling=None,
tie_word_embeddings=True,
use_cache=True,
vocab_size=256,
)

num_params = human_format(
model_config.vocab_size * model_config.hidden_size * 2
+ model_config.num_hidden_layers
2 changes: 1 addition & 1 deletion examples/config_tiny_llama.yaml
@@ -55,7 +55,7 @@ model:
hidden_size: 16
initializer_range: 0.02
intermediate_size: 64
is_llama_config: true
is_bitnet_config: true
max_position_embeddings: 256
num_attention_heads: 4
num_hidden_layers: 2
12 changes: 6 additions & 6 deletions src/nanotron/models/llama_bitnet.py
@@ -608,32 +608,32 @@ def __init__(
layer_idx: int,
):
super().__init__()
self.input_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
#self.input_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.attn = CausalSelfAttentionBitNet(
config=config,
parallel_config=parallel_config,
tp_pg=tp_pg,
layer_idx=layer_idx,
)

self.post_attention_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.mlp = MLPBitNet(config=config, parallel_config=parallel_config, tp_pg=tp_pg)
#self.post_attention_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.MLPBitNet = MLPBitNet(config=config, parallel_config=parallel_config, tp_pg=tp_pg)

def forward(
self,
hidden_states: Union[torch.Tensor, TensorPointer],
sequence_mask: Union[torch.Tensor, TensorPointer],
) -> Dict[str, Union[torch.Tensor, TensorPointer]]:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
#hidden_states = self.input_layernorm(hidden_states)
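# NOTE: normalization now happens inside the BitLinear layers themselves (see normalize() in src/nanotron/parallel/tensor_parallel/nn.py below), which is why the RMSNorm calls here and after the attention block are commented out.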

output = self.attn(hidden_states=hidden_states, sequence_mask=sequence_mask)
hidden_states = output["hidden_states"]
hidden_states = hidden_states + residual

residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states=hidden_states)["hidden_states"]
#hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.MLPBitNet(hidden_states=hidden_states)["hidden_states"]
hidden_states = hidden_states + residual

return {
108 changes: 81 additions & 27 deletions src/nanotron/parallel/tensor_parallel/nn.py
@@ -16,6 +16,7 @@

import torch
from torch import nn
import torch.nn.functional as F

from nanotron import distributed as dist
from nanotron.distributed import get_global_rank
@@ -311,6 +312,17 @@ def activation_quant(x):
y = (x * scale).round().clamp_(-128, 127) / scale
return y

def activation_quant_inference(x):
"""Per−token quantization to 8 bits. No grouping is needed for quantization.
Args:
x: an activation tensor with shape [n, d]
Returns:
y: a quantized activation tensor with shape [n, d]
"""
scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
y = (x * scale).round().clamp_(-128, 127)
return y, scale
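
For clarity, here is a minimal, self-contained sketch of how this inference helper relates to the training-time activation_quant above. The scale computation inside activation_quant is not visible in this hunk, so the sketch assumes it matches the inference variant; under that assumption, both paths produce the same numerics and differ only in when the division by the scale happens.

import torch

def activation_quant(x):
    # training-time fake quantization (assumed body): quantize and immediately
    # dequantize, so the result stays in the activation's original dtype
    scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
    return (x * scale).round().clamp_(-128, 127) / scale

def activation_quant_inference(x):
    # inference-time quantization (as in this diff): keep the integer-valued
    # activations and their per-token scale separate, deferring dequantization
    scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
    return (x * scale).round().clamp_(-128, 127), scale

x = torch.randn(4, 16)
y, scale = activation_quant_inference(x)
# dividing by the per-token scale recovers the training-time fake-quantized values
assert torch.allclose(y / scale, activation_quant(x))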

def weight_quant(w):
"""Per−tensor quantization to 1.58 bits.
Args:
@@ -329,6 +341,14 @@ def normalize_last_two_dimensions(tensor, eps=1e-6):

return normalized_tensor

def normalize(x, dim, eps=1e-5):
scale = dim ** (-0.5)
#return F.normalize(x, dim=-1) * scale
norm_x = x.norm(2, dim=-1, keepdim=True)
rms_x = norm_x * scale
x_normed = x / (rms_x + eps)
return x_normed
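
The new normalize helper is effectively an RMSNorm without a learnable weight, applied over the last dimension; note that its dim argument is the hidden size (feature count), not an axis index. A small self-contained check of that reading:

import torch

def normalize(x, dim, eps=1e-5):
    scale = dim ** (-0.5)
    norm_x = x.norm(2, dim=-1, keepdim=True)
    rms_x = norm_x * scale          # equals sqrt(mean(x**2)) over the last dim
    return x / (rms_x + eps)

x = torch.randn(3, 16)
rms_reference = x / (x.pow(2).mean(dim=-1, keepdim=True).sqrt() + 1e-5)
assert torch.allclose(normalize(x, x.shape[-1]), rms_reference)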

class TensorParallelColumnLinearBitNet(nn.Linear):
def __init__(
self,
@@ -367,27 +387,44 @@ def __init__(
), f"Sum of contiguous chunks ({sum(contiguous_chunks)}) must equal to out_features ({out_features})"
split_config = SplitConfig(split_dim=0, contiguous_chunks=contiguous_chunks)


self.register_parameter('weight_scale', None)
# self.is_model_inference = False
mark_all_parameters_in_module_as_sharded(
self,
pg=self.pg,
split_config=split_config,
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
w = self.weight
x_norm = normalize_last_two_dimensions(x)

x_quant = x_norm + (activation_quant(x_norm) - x_norm).detach()
w_quant = w + (weight_quant(w) - w).detach()

return column_linear(
input=x_quant,
weight=w_quant,
bias=self.bias,
group=self.pg,
tp_mode=self.mode,
async_communication=self.async_communication,
)
if not self.training:
w = self.weight # a weight tensor with shape [d, k]
w = w.to(torch.bfloat16)
w_scale = getattr(self, "weight_scale").data
x_norm = normalize(x, self.in_features)
x_quant, x_scale = activation_quant_inference(x_norm)
return column_linear(
input=x_quant,
weight=w,
bias=self.bias,
group=self.pg,
tp_mode=self.mode,
async_communication=self.async_communication,
) / w_scale / x_scale
else:
w = self.weight
x_norm = normalize(x, self.in_features)
x_quant = x_norm + (activation_quant(x_norm) - x_norm).detach()
w_quant = w + (weight_quant(w) - w).detach()

return column_linear(
input=x_quant,
weight=w_quant,
bias=self.bias,
group=self.pg,
tp_mode=self.mode,
async_communication=self.async_communication,
)
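
To see what the new inference branch computes without the tensor-parallel machinery, here is a single-device sketch. It assumes (this is not stated in the diff) that at inference time self.weight already holds the ternary-quantized values and weight_scale the matching per-tensor scale, mirroring the 1.58-bit scheme described by weight_quant above; all names below are illustrative, not nanotron APIs.

import torch
import torch.nn.functional as F

d_in, d_out = 16, 32
x = torch.randn(4, d_in)

# assumed offline step: store ternary weights plus their per-tensor scale
w_fp = torch.randn(d_out, d_in)
w_scale = 1.0 / w_fp.abs().mean().clamp_(min=1e-5)
w = (w_fp * w_scale).round().clamp_(-1, 1)            # values in {-1, 0, 1}

# inference branch of the forward above, single device (plain matmul instead of column_linear)
x_norm = x / (x.pow(2).mean(dim=-1, keepdim=True).sqrt() + 1e-5)   # normalize()
x_scale = 127.0 / x_norm.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
x_quant = (x_norm * x_scale).round().clamp_(-128, 127)             # activation_quant_inference()
y = F.linear(x_quant, w, bias=None) / w_scale / x_scale            # dequantize after the matmul

# reference: the same product computed on dequantized operands
y_ref = F.linear(x_quant / x_scale, w / w_scale, bias=None)
assert torch.allclose(y, y_ref, rtol=1e-4, atol=1e-5)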

def extra_repr(self) -> str:
return f"tp_rank={dist.get_rank(self.pg)}, {super().extra_repr()}, unsharded_out_features={self.out_features * self.world_size}"
@@ -436,6 +473,8 @@ def __init__(

split_config = SplitConfig(split_dim=1, contiguous_chunks=contiguous_chunks)

self.register_parameter('weight_scale', None)
# self.is_model_inference = False
self._mark_all_parameters_in_module_as_sharded(split_config)

def _mark_all_parameters_in_module_as_sharded(self, split_config: SplitConfig):
@@ -452,19 +491,34 @@ def _mark_all_parameters_in_module_as_sharded(self, split_config: SplitConfig):
setattr(self, name, new_param)

def forward(self, x: torch.Tensor) -> torch.Tensor:
w = self.weight
x_norm = normalize_last_two_dimensions(x)

x_quant = x_norm + (activation_quant(x_norm) - x_norm).detach()
w_quant = w + (weight_quant(w) - w).detach()
return row_linear(
input=x_quant,
weight=w_quant,
bias=self.bias,
group=self.pg,
tp_mode=self.mode,
async_communication=self.async_communication,
)
if not self.training:
w = self.weight # a weight tensor with shape [d, k]
w = w.to(torch.bfloat16)
w_scale = getattr(self, "weight_scale").data
x_norm = normalize(x, self.in_features)
x_quant, x_scale = activation_quant_inference(x_norm)
# print("x_quant from row_linear: ",x_quant.shape)
return row_linear(
input=x_quant,
weight=w,
bias=self.bias,
group=self.pg,
tp_mode=self.mode,
async_communication=self.async_communication,
) / w_scale / x_scale
else:
w = self.weight
x_norm = normalize(x, self.in_features)
x_quant = x_norm + (activation_quant(x_norm) - x_norm).detach()
w_quant = w + (weight_quant(w) - w).detach()
return row_linear(
input=x_quant,
weight=w_quant,
bias=self.bias,
group=self.pg,
tp_mode=self.mode,
async_communication=self.async_communication,
)
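
Both linear classes now register weight_scale as None, so it has to be filled in when a trained checkpoint is converted for inference. The diff does not include that conversion; the following is a hypothetical sketch of what it could look like for a plain nn.Linear, reusing the per-tensor 1.58-bit scheme described by weight_quant (class and function names here are illustrative, not nanotron APIs).

import torch
import torch.nn as nn

class TinyBitLinear(nn.Linear):
    # hypothetical single-device stand-in for the sharded classes above
    def __init__(self, in_features, out_features):
        super().__init__(in_features, out_features, bias=False)
        self.register_parameter("weight_scale", None)

@torch.no_grad()
def quantize_for_inference(layer: TinyBitLinear) -> None:
    w = layer.weight
    w_scale = 1.0 / w.abs().mean().clamp_(min=1e-5)      # per-tensor scale
    w_q = (w * w_scale).round().clamp_(-1, 1)            # ternary weights in {-1, 0, 1}
    layer.weight.data = w_q.to(torch.bfloat16)
    layer.weight_scale = nn.Parameter(w_scale.reshape(1), requires_grad=False)

layer = TinyBitLinear(16, 32)
quantize_for_inference(layer)
print(layer.weight.dtype, layer.weight_scale.item())     # torch.bfloat16, per-tensor scale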

def extra_repr(self) -> str:
return f"tp_rank={dist.get_rank(self.pg)}, {super().extra_repr()}, unsharded_in_features={self.in_features * self.world_size}"
