Decoupled Momentum Optimization #771

Open
wants to merge 1 commit into base: main
23 changes: 23 additions & 0 deletions olmo/config.py
@@ -498,6 +498,7 @@ def effective_n_kv_heads(self) -> int:
class OptimizerType(StrEnum):
lionw = "lionw"
adamw = "adamw"
demo = "demo"


@dataclass
@@ -533,6 +534,20 @@ class OptimizerConfig(BaseConfig):
of the update with AdamW.
"""

### DeMo parameters
compression_decay: float = 0.999
"""
Decay factor for the DeMo momentum (gradient residual) accumulator.
"""

compression_topk: int = 32
"""
Number of top-k coefficients to transmit per chunk. If dynamic top-k is enabled, this is the initial value.
"""

compression_chunk: int = 64
"""
Chunk size used when chunking gradients for the DCT. Note that 2D gradients are chunked in 2D,
so the top-k sparsity is effectively squared compared to the 1D case.
"""


def __post_init__(self):
self.betas = tuple(self.betas) # type: ignore[assignment]
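
For reference, a minimal sketch of how the new DeMo options might be set in code (illustrative only; the `name` and `learning_rate` fields are assumed to be OptimizerConfig's existing optimizer-selection and learning-rate fields, which are outside this diff):

from olmo.config import OptimizerConfig, OptimizerType

# Sketch: pick the new "demo" optimizer and tune its compression knobs.
opt_cfg = OptimizerConfig(
    name=OptimizerType.demo,    # assumed existing selection field
    learning_rate=4e-4,         # assumed existing LR field
    compression_decay=0.999,    # decay of the DeMo momentum / gradient residual
    compression_topk=32,        # top-k DCT coefficients transmitted per chunk
    compression_chunk=64,       # DCT chunk size (applied per dimension for 2D weights)
)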

@@ -736,6 +751,12 @@ class DDPGradSyncMode(StrEnum):
set to True, to prevent errors.
"""

none = "none"
"""
Totally disable gradient synchronization within the distributed model.
Should only be done with some explicit external synchronization (e.g. DeMo) or if you just like spinning your wheels
"""


@dataclass
class DDPConfig(BaseConfig):
@@ -830,6 +851,8 @@ class FSDPConfig(BaseConfig):
PyTorch's default HSDP behavior matches this default behavior.
"""

disable_grad_sync: bool = False
"""
Disable FSDP gradient synchronization entirely. Intended for optimizers like DeMo that perform
their own (compressed) synchronization.
"""


class CheckpointType(StrEnum):
sharded = "sharded"
286 changes: 286 additions & 0 deletions olmo/demo_util.py
@@ -0,0 +1,286 @@
import math
import torch
import torch.fft
import torch.distributed as dist

from einops import rearrange


class TransformDCT:
@torch.no_grad()
def __init__(self, param_groups, target_chunk, norm="ortho"):
self.target_chunk = target_chunk

self.shape_dict = dict()
self.f_dict = dict()
self.b_dict = dict()

# Get all variants of model tensor sizes
# Generate all possible valid DCT sizes for model tensors
for group in param_groups:
for p in group["params"]:
if not p.requires_grad:
continue
for s in p.shape:
# Get the closest smallest divisor to the targeted DCT size
sc = _get_smaller_split(s, self.target_chunk)
self.shape_dict[s] = sc

# Pregenerate DCT basis matrices
if sc not in self.f_dict:
I = torch.eye(sc)
self.f_dict[sc] = _dct(I, norm=norm).to(p.dtype).to(p.device)
self.b_dict[sc] = _idct(I, norm=norm).to(p.dtype).to(p.device)

@torch.no_grad()
def einsum_2d(self, x, b, d=None):
if d is None:
return torch.einsum("...ij, jb -> ...ib", x, b)
else:
# Note: b-c axis output is transposed to chunk DCT in 2D
return torch.einsum("...ijkl, jb, ld -> ...ikbd", x, b, d)

@torch.no_grad()
def einsum_2d_t(self, x, b, d=None):
if d is None:
return torch.einsum("...ij, jb -> ...ib", x, b)
else:
# Note: b-c axis output is transposed to chunk DCT in 2D
return torch.einsum("...ijkl, kb, ld -> ...ibjd", x, b, d)

@torch.no_grad()
def encode(self, x):
if len(x.shape) > 1: # 2D weights
n1 = self.shape_dict[x.shape[0]]
n2 = self.shape_dict[x.shape[1]]
n1w = self.f_dict[n1].to(x.device)
n2w = self.f_dict[n2].to(x.device)
self.f_dict[n1] = n1w
self.f_dict[n2] = n2w

x = rearrange(x, "(y h) (x w) -> y h x w", h=n1, w=n2)
x = self.einsum_2d(x, n1w, n2w)

else: # 1D weights
n1 = self.shape_dict[x.shape[0]]
n1w = self.f_dict[n1].to(x.device)
self.f_dict[n1] = n1w

x = rearrange(x, "(x w) -> x w", w=n1)
x = self.einsum_2d(x, n1w)

return x

@torch.no_grad()
def decode(self, x):
if len(x.shape) > 2: # 2D weights
n1 = x.shape[2]
n2 = x.shape[3]
n1w = self.b_dict[n1].to(x.device)
n2w = self.b_dict[n2].to(x.device)
self.b_dict[n1] = n1w
self.b_dict[n2] = n2w

x = self.einsum_2d_t(x, n1w, n2w)
x = rearrange(x, "y h x w -> (y h) (x w)")

else: # 1D weights
n1 = x.shape[1]
n1w = self.b_dict[n1].to(x.device)
self.b_dict[n1] = n1w

x = self.einsum_2d_t(x, n1w)
x = rearrange(x, "x w -> (x w)")

return x
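
An illustrative round-trip sketch for TransformDCT (not part of the diff; tensor sizes are arbitrary). With norm="ortho", decode(encode(x)) should recover x up to floating-point error:

import torch
from olmo.demo_util import TransformDCT

p = torch.nn.Parameter(torch.randn(128, 192))
tdct = TransformDCT([{"params": [p]}], target_chunk=64)

g = torch.randn_like(p)         # stand-in for a gradient
encoded = tdct.encode(g)        # chunked DCT coefficients, shape (2, 3, 64, 64)
decoded = tdct.decode(encoded)  # back to the original shape (128, 192)
assert torch.allclose(decoded, g, atol=1e-5)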


class CompressDCT:
@torch.no_grad()
def __init__(self):
pass

def _clamp_topk(self, x, topk):
if topk > x.shape[-1]:
topk = x.shape[-1]
if topk < 1:
topk = 1
return topk

@torch.no_grad()
def compress(self, x, topk):
xshape = x.shape
if len(x.shape) > 2: # 2D weights
x = rearrange(x, "y x h w -> y x (h w)")

# Limit topk to max size
totalk = x.shape[-1]
topk = self._clamp_topk(x, topk)

idx = torch.topk(x.abs(), k=topk, dim=-1, largest=True, sorted=False).indices
val = torch.gather(x, dim=-1, index=idx)

return idx, val, xshape, totalk

@torch.no_grad()
def decompress(self, p, idx, val, xshape, totalk):
x = torch.zeros(xshape, device=p.device, dtype=p.dtype)

if len(xshape) > 2: # 2D weights
x = rearrange(x, "y x h w -> y x (h w)")

# TODO: Careful, this is nondeterministic across different CUDA devices! might cause errors to accumulate between nodes!
x.scatter_reduce_(dim=-1, index=idx, src=val, reduce="mean", include_self=False).reshape(xshape)

if len(x.shape) > 2: # 2D weights
x = rearrange(x, "y x (h w) -> y x h w", h=xshape[2])

return x

@torch.no_grad()
def batch_decompress(self, p, idx, val, xshape, totalk):
idx = torch.concatenate(idx, dim=-1).to(device=p.device)
val = torch.concatenate(val, dim=-1).to(device=p.device)
return self.decompress(p, idx, val, xshape, totalk)
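
Continuing the TransformDCT sketch above, CompressDCT keeps only the top-k coefficients of each DCT chunk; decompress scatters them back into a dense tensor (zeros elsewhere) that can then be decoded into a lossy gradient approximation:

cdct = CompressDCT()

idx, val, xshape, totalk = cdct.compress(encoded, topk=32)  # keep 32 of 64*64 coefficients per chunk
sparse = cdct.decompress(p, idx, val, xshape, totalk)       # dense tensor, zero except the kept entries
approx_g = tdct.decode(sparse)                              # lossy reconstruction of g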


# Code modified and sourced from https://github.com/zh217/torch-dct
def _dct_fft_impl(v):
return torch.view_as_real(torch.fft.fft(v, dim=1))


def _idct_irfft_impl(V):
return torch.fft.irfft(torch.view_as_complex(V), n=V.shape[1], dim=1)


def _dct(x, norm=None):
"""
Discrete Cosine Transform, Type II (a.k.a. the DCT)

For the meaning of the parameter `norm`, see:
https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html

:param x: the input signal
:param norm: the normalization, None or 'ortho'
:return: the DCT-II of the signal over the last dimension
"""
x_shape = x.shape
N = x_shape[-1]
x = x.contiguous().view(-1, N)

v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=1)

Vc = _dct_fft_impl(v)

k = -torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * math.pi / (2 * N)
W_r = torch.cos(k)
W_i = torch.sin(k)

V = Vc[:, :, 0] * W_r - Vc[:, :, 1] * W_i

if norm == "ortho":
V[:, 0] /= math.sqrt(N) * 2
V[:, 1:] /= math.sqrt(N / 2) * 2

V = 2 * V.view(*x_shape)

return V


def _idct(X, norm=None):
"""
The inverse to DCT-II, which is a scaled Discrete Cosine Transform, Type III

Our definition of idct is that idct(dct(x)) == x

For the meaning of the parameter `norm`, see:
https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html

:param X: the input signal
:param norm: the normalization, None or 'ortho'
:return: the inverse DCT-II of the signal over the last dimension
"""

x_shape = X.shape
N = x_shape[-1]

X_v = X.contiguous().view(-1, x_shape[-1]) / 2

if norm == "ortho":
X_v[:, 0] *= math.sqrt(N) * 2
X_v[:, 1:] *= math.sqrt(N / 2) * 2

k = torch.arange(x_shape[-1], dtype=X.dtype, device=X.device)[None, :] * math.pi / (2 * N)
W_r = torch.cos(k)
W_i = torch.sin(k)

V_t_r = X_v
V_t_i = torch.cat([X_v[:, :1] * 0, -X_v.flip([1])[:, :-1]], dim=1)

V_r = V_t_r * W_r - V_t_i * W_i
V_i = V_t_r * W_i + V_t_i * W_r

V = torch.cat([V_r.unsqueeze(2), V_i.unsqueeze(2)], dim=2)

v = _idct_irfft_impl(V)
x = v.new_zeros(v.shape)
x[:, ::2] += v[:, : N - (N // 2)]
x[:, 1::2] += v.flip([1])[:, : N // 2]

return x.view(*x_shape)
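
A quick sanity check of the round-trip property stated in the docstrings (a sketch, assuming _dct and _idct are imported from olmo.demo_util):

x = torch.randn(4, 64)
assert torch.allclose(_idct(_dct(x, norm="ortho"), norm="ortho"), x, atol=1e-5)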


def _get_prime_divisors(n):
divisors = []
while n % 2 == 0:
divisors.append(2)
n //= 2
while n % 3 == 0:
divisors.append(3)
n //= 3
i = 5
while i * i <= n:
for k in (i, i + 2):
while n % k == 0:
divisors.append(k)
n //= k
i += 6
if n > 1:
divisors.append(n)
return divisors


def _get_divisors(n):
divisors = []
if n == 1:
divisors.append(1)
elif n > 1:
prime_factors = _get_prime_divisors(n)
divisors = [1]
last_prime = 0
factor = 0
slice_len = 0
# Find all the products that are divisors of n
for prime in prime_factors:
if last_prime != prime:
slice_len = len(divisors)
factor = prime
else:
factor *= prime
for i in range(slice_len):
divisors.append(divisors[i] * factor)
last_prime = prime
divisors.sort()
return divisors


def _get_smaller_split(n, close_to):
all_divisors = _get_divisors(n)
for ix, val in enumerate(all_divisors):
if val == close_to:
return val
if val > close_to:
if ix == 0:
return val
return all_divisors[ix - 1]
return n
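
For illustration, a few expected values from these helpers (a sketch): _get_smaller_split(n, close_to) returns the largest divisor of n that does not exceed close_to.

assert _get_divisors(12) == [1, 2, 3, 4, 6, 12]
assert _get_smaller_split(192, 64) == 64   # 64 divides 192 exactly
assert _get_smaller_split(100, 64) == 50   # largest divisor of 100 that is <= 64
assert _get_smaller_split(50, 64) == 50    # every divisor is <= 64, so n itself is returned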