Skip to content
This repository has been archived by the owner on Oct 16, 2023. It is now read-only.

Commit

Permalink
Optimize LayerNorm
Browse files Browse the repository at this point in the history
  • Loading branch information
danny.jang authored Oct 6, 2023
1 parent 1813a8b commit 96b0328
Showing 1 changed file with 59 additions and 17 deletions.
76 changes: 59 additions & 17 deletions trident/kernel/layer_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

class LayerNorm:
@staticmethod
@triton.heuristics({"require_x_boundary_check": lambda args: args["x_size"] % args["x_block_size"]})
@triton.jit
def forward(
output_ptr: tl.tensor,
Expand All @@ -31,6 +32,7 @@ def forward(
eps: tl.float32,
dtype: tl.constexpr,
x_block_size: tl.constexpr,
require_x_boundary_check: tl.constexpr,
):
y_offset = tl.program_id(0)

Expand Down Expand Up @@ -67,10 +69,16 @@ def forward(
order=(1, 0),
)

input = tl.load(input_block_ptr, boundary_check=(1,), padding_option="zero")
mean = tl.sum(input / x_size, 1)
condition = tl.arange(0, x_block_size) < x_size
centered_mean = tl.where(condition, input - mean, 0)
if require_x_boundary_check:
input = tl.load(input_block_ptr, boundary_check=(1,), padding_option="zero")
mean = tl.sum(input / x_size, 1)
condition = tl.arange(0, x_block_size) < x_size
centered_mean = tl.where(condition, input - mean, 0)
else:
input = tl.load(input_block_ptr)
mean = tl.sum(input / x_size, 1)
centered_mean = input - mean

var = tl.sum(centered_mean * centered_mean / x_size, 1)
rstd = tl.math.rsqrt(var + eps)
output = centered_mean * rstd
Expand All @@ -84,7 +92,12 @@ def forward(
block_shape=(x_block_size,),
order=(0,),
)
weight = tl.load(weight_block_ptr, boundary_check=(0,))

if require_x_boundary_check:
weight = tl.load(weight_block_ptr, boundary_check=(0,))
else:
weight = tl.load(weight_block_ptr)

output *= weight

if bias_ptr is not None:
Expand All @@ -96,14 +109,24 @@ def forward(
block_shape=(x_block_size,),
order=(0,),
)
bias = tl.load(bias_block_ptr, boundary_check=(0,))

if require_x_boundary_check:
bias = tl.load(bias_block_ptr, boundary_check=(0,))
else:
bias = tl.load(bias_block_ptr)

output += bias

tl.store(output_block_ptr, output.to(dtype), boundary_check=(1,))
if require_x_boundary_check:
tl.store(output_block_ptr, output.to(dtype), boundary_check=(1,))
else:
tl.store(output_block_ptr, output.to(dtype))

tl.store(rstd_block_ptr, rstd.to(dtype))
tl.store(mean_block_ptr, mean.to(dtype))

@staticmethod
@triton.heuristics({"require_x_boundary_check": lambda args: args["x_size"] % args["x_block_size"]})
@triton.jit
def backward(
grad_input_ptr: tl.tensor,
Expand All @@ -117,6 +140,7 @@ def backward(
mean_ptr: tl.tensor,
dtype: tl.constexpr,
x_block_size: tl.constexpr,
require_x_boundary_check: tl.constexpr,
):
y_offset = tl.program_id(0)

Expand Down Expand Up @@ -161,12 +185,16 @@ def backward(
order=(0,),
)

grad_output = tl.load(grad_output_block_ptr, boundary_check=(1,))
input = tl.load(input_block_ptr, boundary_check=(1,))
if require_x_boundary_check:
grad_output = tl.load(grad_output_block_ptr, boundary_check=(1,), padding_option="zero")
input = tl.load(input_block_ptr, boundary_check=(1,), padding_option="zero")
else:
grad_output = tl.load(grad_output_block_ptr)
input = tl.load(input_block_ptr)

rstd = tl.load(rstd_block_ptr)
mean = tl.load(mean_block_ptr)
condition = tl.arange(0, x_block_size) < x_size
centered_mean = tl.where(condition, input - mean, 0)
centered_mean = input - mean

if weight_ptr is not None:
weight_block_ptr = tl.make_block_ptr(
Expand All @@ -177,22 +205,29 @@ def backward(
block_shape=(1, x_block_size),
order=(1, 0),
)
weight = tl.load(weight_block_ptr, boundary_check=(1,))

if require_x_boundary_check:
weight = tl.load(weight_block_ptr, boundary_check=(1,))
else:
weight = tl.load(weight_block_ptr)

grad_norm = weight * grad_output
else:
grad_norm = grad_output

grad_std = tl.sum(grad_norm * centered_mean, 1)
grad_var = grad_std * -(0.5 * rstd * rstd * rstd) / x_size
grad_distance = 2 * centered_mean * grad_var
grad_centered_mean = tl.where(condition, grad_norm * rstd + grad_distance, 0)
grad_centered_mean = grad_norm * rstd + grad_distance
grad_mean = -tl.sum(grad_centered_mean, 1) / x_size
grad_input = grad_centered_mean + grad_mean
tl.store(grad_input_block_ptr, grad_input.to(dtype), boundary_check=(1,))

if require_x_boundary_check:
tl.store(grad_input_block_ptr, grad_input.to(dtype), boundary_check=(1,))
else:
tl.store(grad_input_block_ptr, grad_input.to(dtype))

if grad_weight_staging_ptr is not None:
norm = centered_mean * rstd
grad_weight = norm * grad_output
grad_weight_staging_block_ptr = tl.make_block_ptr(
grad_weight_staging_ptr,
shape=(y_size, x_size),
Expand All @@ -201,4 +236,11 @@ def backward(
block_shape=(1, x_block_size),
order=(1, 0),
)
tl.store(grad_weight_staging_block_ptr, grad_weight.to(dtype), boundary_check=(1,))

norm = centered_mean * rstd
grad_weight = norm * grad_output

if require_x_boundary_check:
tl.store(grad_weight_staging_block_ptr, grad_weight.to(dtype), boundary_check=(1,))
else:
tl.store(grad_weight_staging_block_ptr, grad_weight.to(dtype))

0 comments on commit 96b0328

Please sign in to comment.