custom CudaProjector for large models to avoid overflow error in CUDA kernel
kristian-georgiev committed Nov 2, 2023
1 parent 62426eb commit 259f087
Showing 4 changed files with 426 additions and 36 deletions.
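The tests below exercise the new chunked projection path. As a rough orientation (not part of this diff), the idea the new helpers encode is: cap each chunk's width so that batch_size * chunk_size stays below the CUDA kernel's uint32 index range, split the per-sample gradient matrix along the parameter dimension, and feed the resulting dict of chunks to ChunkedCudaProjector, one CudaProjector per chunk. A minimal sketch mirroring the helpers added in tests/test_jl.py (the standalone chunk_gradients function here is illustrative, not part of the commit):

import numpy as np
import torch

def get_max_chunk_size(batch_size: int) -> int:
    # Widest chunk such that batch_size * chunk_size still fits in uint32 indexing.
    return np.iinfo(np.uint32).max // batch_size

def chunk_gradients(g: torch.Tensor, max_chunk_size: int) -> dict[int, torch.Tensor]:
    # Split a (batch_size, num_params) gradient matrix along dim 1; keys are chunk indices.
    _, num_params = g.shape
    num_chunks = int(np.ceil(num_params / max_chunk_size))
    return dict(enumerate(torch.chunk(g, num_chunks, dim=1)))

Each chunk then gets its own CudaProjector (seeded per chunk), and ChunkedCudaProjector combines the per-chunk projections into a single (batch_size, proj_dim) output; the combination itself is implemented in trak.projectors, from which the tests import ChunkedCudaProjector.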
199 changes: 173 additions & 26 deletions tests/test_jl.py
@@ -2,10 +2,38 @@
import math
from itertools import product
import numpy as np
import torch as ch
import torch
from torch import testing

from trak.projectors import CudaProjector, ProjectionType
from trak.projectors import CudaProjector, ProjectionType, ChunkedCudaProjector

ch = torch


def get_max_chunk_size(
batch_size: int,
) -> int:
max_chunk_size = np.iinfo(np.uint32).max // batch_size
return max_chunk_size


def make_input(
input_shape, max_chunk_size, device="cuda", dtype=torch.float32, g_tensor=None
):
if g_tensor is None:
g = testing.make_tensor(*input_shape, device=device, dtype=dtype)
else:
g = g_tensor
_, num_params = input_shape
num_chunks = np.ceil(num_params / max_chunk_size).astype("int32")
g_chunks = ch.chunk(g, num_chunks, dim=1)
result = {}
for i, x in enumerate(g_chunks):
result[i] = x
print(f"Input param group {i} shape: {x.shape}")

return result


BasicProjector = CudaProjector

@@ -28,6 +56,22 @@
)
)

PARAM = list(
product(
[123], # seed
[ProjectionType.rademacher], # proj type
[ch.float32], # dtype
[
# tests for MAXINT32 overflow
(8, 180645096), # pass: np.prod(shape) < np.iinfo(np.int32).max
(31, 180645096), # fail: np.prod(shape) > np.iinfo(np.int32).max
(32, 180645096), # fail: np.prod(shape) > np.iinfo(np.int32).max
(2, 780645096), # fail (note: np.prod(shape) here is below np.iinfo(np.int32).max)
], # input shape
[15_360], # proj dim
)
)
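To make the overflow shapes above concrete, here is a small worked example (illustrative, not part of the diff) applying the same uint32 cap from get_max_chunk_size to those (batch_size, num_params) pairs; the trailing comments show the chunk counts the make_input helper above produces for each shape:

import numpy as np

UINT32_MAX = np.iinfo(np.uint32).max  # 4_294_967_295

for batch_size, num_params in [(8, 180645096), (31, 180645096), (32, 180645096), (2, 780645096)]:
    max_chunk_size = UINT32_MAX // batch_size
    num_chunks = int(np.ceil(num_params / max_chunk_size))
    print(batch_size, max_chunk_size, num_chunks)
# 8   536870911   -> 1 chunk
# 31  138547332   -> 2 chunks
# 32  134217727   -> 2 chunks
# 2   2147483647  -> 1 chunk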


@pytest.mark.parametrize("seed, proj_type, dtype, input_shape, proj_dim", PARAM)
@pytest.mark.cuda
@@ -43,7 +87,6 @@ def test_seed_consistency(
leads to the same result.
"""

g = testing.make_tensor(*input_shape, device="cuda:0", dtype=dtype)
proj = BasicProjector(
grad_dim=input_shape[-1],
proj_dim=proj_dim,
@@ -53,11 +96,17 @@ def test_seed_consistency(
dtype=dtype,
max_batch_size=MAX_BATCH_SIZE,
)
batch_size = input_shape[0]
max_chunk_size = get_max_chunk_size(batch_size)
g = make_input(input_shape, max_chunk_size, "cuda:0", dtype)

result = proj.project(g, model_id=0)
result_again = proj.project(g, model_id=0)
testing.assert_close(result, result_again, equal_nan=True)

del g
torch.cuda.empty_cache()


@pytest.mark.parametrize("seed, proj_type, dtype, input_shape, proj_dim", PARAM)
@pytest.mark.cuda
@@ -73,7 +122,10 @@ def test_seed_consistency_2(
with the same seed leads to the same result.
"""

g = testing.make_tensor(*input_shape, device="cuda:0", dtype=dtype)
batch_size = input_shape[0]
max_chunk_size = get_max_chunk_size(batch_size)
g = make_input(input_shape, max_chunk_size, "cuda:0", dtype)

proj = BasicProjector(
grad_dim=input_shape[-1],
proj_dim=proj_dim,
@@ -98,6 +150,9 @@ def test_seed_consistency_2(
result_again = proj_again.project(g, model_id=0)
testing.assert_close(result, result_again, equal_nan=True)

del g
torch.cuda.empty_cache()


@pytest.mark.parametrize("seed, proj_type, dtype, input_shape, proj_dim", PARAM)
@pytest.mark.cuda
@@ -111,15 +166,37 @@ def test_norm_preservation(
"""
Check that norms of differences are approximately preserved.
"""
g = testing.make_tensor(*input_shape, device="cuda:0", dtype=dtype)
proj = BasicProjector(
grad_dim=input_shape[-1],
proj_dim=proj_dim,
proj_type=proj_type,
seed=seed,
device="cuda:0",
dtype=dtype,
max_batch_size=MAX_BATCH_SIZE,
batch_size = input_shape[0]
max_chunk_size = get_max_chunk_size(batch_size)
g = make_input(input_shape, max_chunk_size, "cuda:0", dtype)

rng = np.random.default_rng(seed)
seeds = rng.integers(
low=0,
high=500,
size=len(g),
)

param_chunk_sizes = [v.size(1) for v in g.values()]
projector_per_chunk = [
BasicProjector(
grad_dim=chunk_size,
proj_dim=proj_dim,
seed=seeds[i],
proj_type=proj_type,
max_batch_size=MAX_BATCH_SIZE,
dtype=dtype,
device="cuda:0",
)
for i, chunk_size in enumerate(param_chunk_sizes)
]
proj = ChunkedCudaProjector(
projector_per_chunk,
max_chunk_size,
param_chunk_sizes,
batch_size,
"cuda:0",
dtype,
)

p = proj.project(g, model_id=0)
@@ -129,6 +206,10 @@ def test_norm_preservation(
num_trials = 100
num_successes = 0

# flatten
g = ch.cat([v for v in g.values()], dim=1)
print(f"Flattened input shape: {g.shape}")

for _ in range(num_trials):
i, j = np.random.choice(range(g.shape[0]), size=2)
n = (g[i] - g[j]).norm()
@@ -144,6 +225,9 @@ def test_norm_preservation(
num_successes += int(res <= 35 * eps * n)
assert num_successes >= num_trials * (1 - 3 * delta) # leeway with 2 *

del g
torch.cuda.empty_cache()


@pytest.mark.parametrize("seed, proj_type, dtype, input_shape, proj_dim", PARAM)
@pytest.mark.cuda
@@ -157,7 +241,10 @@ def test_prod_preservation(
"""
Check that dot products are approximately preserved.
"""
g = testing.make_tensor(*input_shape, device="cuda:0", dtype=dtype)
batch_size = input_shape[0]
max_chunk_size = get_max_chunk_size(batch_size)
g = make_input(input_shape, max_chunk_size, "cuda:0", dtype)

proj = BasicProjector(
grad_dim=input_shape[-1],
proj_dim=proj_dim,
@@ -179,6 +266,10 @@ def test_prod_preservation(
num_trials = 100
num_successes = 0

# flatten
g = ch.cat([v for v in g.values()], dim=1)
print(f"Flattened input shape: {g.shape}")

for _ in range(num_trials):
i, j = np.random.choice(range(g.shape[0]), size=2)
n = g[i] @ g[j]
@@ -190,6 +281,9 @@ def test_prod_preservation(

assert num_successes >= num_trials * (1 - 2 * delta)

del g
torch.cuda.empty_cache()


@pytest.mark.parametrize("seed, proj_type, dtype, input_shape, proj_dim", PARAM)
@pytest.mark.cuda
@@ -203,11 +297,18 @@ def test_single_nonzero_feature(
"""
Check that output takes into account every feature.
"""
g = ch.zeros(*input_shape, device="cuda:0", dtype=dtype)

batch_size = input_shape[0]
max_chunk_size = get_max_chunk_size(batch_size)
g = make_input(input_shape, max_chunk_size, "cuda:0", dtype)
for k in g.keys():
g[k] = ch.zeros_like(g[k])

for ind in range(input_shape[0]):
coord = np.random.choice(range(input_shape[1]))
param_group = np.random.choice(range(len(g.keys())))
coord = np.random.choice(range(g[param_group].size(1)))
val = ch.randn(1)
g[ind, coord] = val.item()
g[param_group][ind, coord] = val.item()

proj = BasicProjector(
grad_dim=input_shape[-1],
@@ -237,6 +338,11 @@ def test_first_nonzero_feature(
g = ch.zeros(*input_shape, device="cuda:0", dtype=dtype)
g[:, 0] = 1.0

batch_size = input_shape[0]
max_chunk_size = get_max_chunk_size(batch_size)
g = make_input(input_shape, max_chunk_size, g_tensor=g)
print(g[0])

proj = BasicProjector(
grad_dim=input_shape[-1],
proj_dim=proj_dim,
@@ -265,6 +371,11 @@ def test_last_nonzero_feature(
g = ch.zeros(*input_shape, device="cuda:0", dtype=dtype)
g[:, -1] = 1.0

batch_size = input_shape[0]
max_chunk_size = get_max_chunk_size(batch_size)
g = make_input(input_shape, max_chunk_size, g_tensor=g)
print(g[0])

proj = BasicProjector(
grad_dim=input_shape[-1],
proj_dim=proj_dim,
@@ -293,19 +404,47 @@ def test_same_features(
g = testing.make_tensor(*input_shape, device="cuda:0", dtype=dtype)
g[-1] = g[0]

proj = BasicProjector(
grad_dim=input_shape[-1],
proj_dim=proj_dim,
proj_type=proj_type,
seed=seed,
device="cuda:0",
dtype=dtype,
max_batch_size=MAX_BATCH_SIZE,
batch_size = input_shape[0]
max_chunk_size = get_max_chunk_size(batch_size)
g = make_input(input_shape, max_chunk_size, g_tensor=g)
for i in range(len(g)):
print(g[i][0] == g[i][-1])

rng = np.random.default_rng(seed)
seeds = rng.integers(
low=0,
high=500,
size=len(g),
)

param_chunk_sizes = [v.size(1) for v in g.values()]
projector_per_chunk = [
BasicProjector(
grad_dim=chunk_size,
proj_dim=proj_dim,
seed=seeds[i],
proj_type=proj_type,
max_batch_size=MAX_BATCH_SIZE,
dtype=dtype,
device="cuda:0",
)
for i, chunk_size in enumerate(param_chunk_sizes)
]
proj = ChunkedCudaProjector(
projector_per_chunk,
max_chunk_size,
param_chunk_sizes,
batch_size,
"cuda:0",
dtype,
)
p = proj.project(g, model_id=0)

assert ch.allclose(p[0], p[-1])

del g
torch.cuda.empty_cache()


@pytest.mark.parametrize("seed, proj_type, dtype, input_shape, proj_dim", PARAM)
@pytest.mark.cuda
@@ -333,11 +472,19 @@ def test_orthogonality(
)

num_successes = 0
num_trials = 100
num_trials = 10
for _ in range(num_trials):
g = testing.make_tensor(*input_shape, device="cuda:0", dtype=dtype)
g[-1] -= g[0] @ g[-1] / (g[0].norm() ** 2) * g[0]

batch_size = input_shape[0]
max_chunk_size = get_max_chunk_size(batch_size)
g = make_input(input_shape, max_chunk_size, g_tensor=g)

p = proj.project(g, model_id=0)
if p[0] @ p[-1] < 1e-3:
num_successes += 1
assert num_successes > 0.33 * num_trials

del g
torch.cuda.empty_cache()
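For reference, here is one way to read what these tests assume ChunkedCudaProjector does (a hypothetical sketch; the real class, with the full constructor signature used above, lives in trak.projectors): each chunk is projected by its own CudaProjector, and the per-chunk results are combined into one (batch_size, proj_dim) tensor. Summing the per-chunk projections, as below, is equivalent to projecting the concatenated gradient with a single block-partitioned JL matrix [P_0 | P_1 | ...], which is consistent with the norm- and inner-product-preservation checks above; the summation is an assumption of this sketch, not a claim about the actual implementation.

import torch

class ChunkedProjectorSketch:
    # Hypothetical stand-in for ChunkedCudaProjector; simplified constructor.
    def __init__(self, projector_per_chunk):
        self.projector_per_chunk = projector_per_chunk

    def project(self, g_chunks: dict, model_id: int) -> torch.Tensor:
        # Summing per-chunk projections equals projecting the full gradient by
        # the horizontally concatenated per-chunk projection matrices.
        out = None
        for i, chunk in g_chunks.items():
            p = self.projector_per_chunk[i].project(chunk, model_id=model_id)
            out = p if out is None else out + p
        return out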

1 comment on commit 259f087

@kristian-georgiev (Member Author):
Co-authored by @AlaaKhaddaj
