Skip to content

Commit

Permalink
bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
joey00072 committed Apr 22, 2024
1 parent a598db9 commit 16941b3
Show file tree
Hide file tree
Showing 4 changed files with 394 additions and 2 deletions.
308 changes: 308 additions & 0 deletions dev/test.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,308 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 66,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"from torch import Tensor\n",
"import bitmat"
]
},
{
"cell_type": "code",
"execution_count": 67,
"metadata": {},
"outputs": [],
"source": [
"device = torch.device(\"cuda\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"y_mat = bitlinear(x)"
]
},
{
"cell_type": "code",
"execution_count": 70,
"metadata": {},
"outputs": [],
"source": [
"class RMSNorm(nn.Module):\n",
" def __init__(self, dim: int, eps: float = 1e-5):\n",
" super().__init__()\n",
" \"\"\"\n",
" Paper: https://arxiv.org/abs/1910.07467\n",
" \"\"\"\n",
" self.eps = eps\n",
" self.weight = nn.Parameter(torch.ones(dim))\n",
"\n",
" def _norm(self, x: Tensor):\n",
" return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)\n",
"\n",
" def forward(self, x):\n",
" output = self._norm(x.float()).type_as(x)\n",
" return output * self.weight\n",
"\n",
"\n",
"def activation_quant(x):\n",
" \"\"\"Per−token quantization to 8 bits. No grouping is needed for quantization.\n",
" Args:\n",
" x: an activation tensor with shape [n, d]\n",
" Returns:\n",
" y: a quantized activation tensor with shape [n, d]\n",
" \"\"\"\n",
" scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)\n",
" y = (x * scale).round().clamp_(-128, 127) / scale\n",
" return y\n",
"\n",
"def weight_quant(w):\n",
" \"\"\"Per−tensor quantization to 1.58 bits. No grouping is needed for quantization.\n",
" Args:\n",
" w: a weight tensor with shape [d, k]\n",
" Returns:\n",
" u: a quantized weight with shape [d, k]\n",
" \"\"\"\n",
" scale = 1.0 / w.abs().mean().clamp_(min=1e-5)\n",
" u = (w * scale).round().clamp_(-1, 1) / scale\n",
" return u\n",
"\n",
"\n",
"class BitLinear(nn.Linear):\n",
" \"\"\"\n",
" This is only for training, and kernel optimization is needed for efficiency.\n",
" \"\"\"\n",
" def __init__(self, in_features: int, out_features: int, bias: bool = True,\n",
" device=None, dtype=None, config=None):\n",
" super().__init__(in_features, out_features, bias, device, dtype)\n",
" self.norm = RMSNorm(in_features)\n",
"\n",
" def forward(self, x):\n",
" \"\"\"\n",
" Args:\n",
" x: an input tensor with shape [n, d]\n",
" Returns:\n",
" y: an output tensor with shape [n, d]\n",
" \"\"\"\n",
" w = self.weight # a weight tensor with shape [d, k]\n",
" x_norm = self.norm(x)\n",
" # Atrick for implementing Straight−Through−Estimator (STE) using detach()\n",
" x_quant = x_norm + (activation_quant(x_norm) - x_norm).detach()\n",
" w_quant = w + (weight_quant(w) - w).detach()\n",
" y = F.linear(x_quant, w_quant)\n",
" return y\n"
]
},
{
"cell_type": "code",
"execution_count": 71,
"metadata": {},
"outputs": [],
"source": [
"blin = BitLinear(2000, 200,bias=False).to(device=device)\n",
"bitlinear = bitmat.BitLinear(2000, 200,bias=False).to(device=device)\n",
"x = torch.randn(8, 2000,device=device)"
]
},
{
"cell_type": "code",
"execution_count": 72,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<All keys matched successfully>"
]
},
"execution_count": 72,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"bitlinear.convert_weights_to_parameters()\n",
"blin.load_state_dict(bitlinear.state_dict())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 73,
"metadata": {},
"outputs": [],
"source": [
"y_exp = blin(x)\n",
"y_mat = bitlinear(x)"
]
},
{
"cell_type": "code",
"execution_count": 74,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(False, device='cuda:0')"
]
},
"execution_count": 74,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.isclose(y_exp, y_mat).all()"
]
},
{
"cell_type": "code",
"execution_count": 76,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(2.0575, device='cuda:0', grad_fn=<SumBackward0>)"
]
},
"execution_count": 76,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"(y_exp - y_mat).sum()"
]
},
{
"cell_type": "code",
"execution_count": 77,
"metadata": {},
"outputs": [],
"source": [
"x_ref = blin.norm(x.to(blin.norm.weight.dtype)).to(x.dtype)\n",
"x_tri = bitlinear.norm(x.to(bitlinear.norm.weight.dtype)).to(x.dtype)"
]
},
{
"cell_type": "code",
"execution_count": 78,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(True, device='cuda:0')"
]
},
"execution_count": 78,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.isclose(x_ref, x_tri).all()"
]
},
{
"cell_type": "code",
"execution_count": 85,
"metadata": {},
"outputs": [],
"source": [
"from bitmat.utils.bitmat import bitmat as pi_bitmat"
]
},
{
"cell_type": "code",
"execution_count": 95,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ 0.8170, 1.1203, -0.1639, ..., 0.2053, -0.8501, -2.2957],\n",
" [ 1.6535, -1.3252, 1.0182, ..., 0.5661, -0.0032, 0.7885],\n",
" [ 1.9974, 0.1114, -1.1349, ..., 1.1120, -1.1433, -0.3425],\n",
" ...,\n",
" [-0.5807, -0.2230, 0.7941, ..., 1.2162, -0.1268, -0.4903],\n",
" [ 0.0493, 0.3854, -0.9208, ..., 1.3942, -0.2622, -2.5241],\n",
" [-1.3806, 0.6822, -1.7021, ..., 0.1908, -0.3309, 0.2036]],\n",
" device='cuda:0', grad_fn=<BitMatBackward>)"
]
},
"execution_count": 95,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"out_tri = pi_bitmat(bitlinear.weight,x_tri,None)\n",
"out_tri"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"x_quant = x_ref + (activation_quant(x_ref) - x_ref).detach()\n",
"w_quant = w + (weight_quant(w) - w).detach()\n",
"y = F.linear(x_quant, w_quant)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
2 changes: 1 addition & 1 deletion experiments/bitnet/bitnet.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ This is from a precursor paper https://arxiv.org/pdf/2310.11453.pdf
from torch import Tensor

def activation_quant(x: Tensor) -> Tensor:
scale: Tensor = 127.0 / x.abs().max(dim=1, keepdim=True).values.clamp(min=1e-5)
scale: Tensor = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
y: Tensor = (x * scale).round().clamp(-128, 127) / scale
return x + (y - x).detach()
```
Expand Down
2 changes: 1 addition & 1 deletion experiments/bitnet/bitnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def forward(self, x):

@torch.jit.script # jit speedup https://colab.research.google.com/drive/1B_-PfHKzSmuwF3TETx_ZMlFSE5PNcr1k?usp=sharing
def activation_quant(x: Tensor) -> Tensor:
scale: Tensor = 127.0 / x.abs().max(dim=1, keepdim=True).values.clamp(min=1e-5)
scale: Tensor = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
y: Tensor = (x * scale).round().clamp(-128, 127) / scale
return x + (y - x).detach()

Expand Down
Loading

0 comments on commit 16941b3

Please sign in to comment.