Skip to content

Commit

Permalink
add leaky_relu op
Browse files Browse the repository at this point in the history
  • Loading branch information
beverlylytle committed Nov 20, 2024
1 parent 60f3ee1 commit b43481f
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 11 deletions.
2 changes: 2 additions & 0 deletions thunder/executors/torchex.py
Original file line number Diff line number Diff line change
Expand Up @@ -836,6 +836,7 @@ def _erfcinv_impl(a: torch.Tensor) -> torch.Tensor:
celu = _register_torch_operation("celu", module=torch.nn.functional)
elu = _register_torch_operation("elu", module=torch.nn.functional)
gelu = _register_torch_operation("gelu", module=torch.nn.functional)
leaky_relu = _register_torch_operation("leaky_relu", module=torch.nn.functional)
relu = _register_torch_operation("relu", module=torch.nn.functional)
relu6 = _register_torch_operation("relu6", module=torch.nn.functional)
hardswish = _register_torch_operation("hardswish", module=torch.nn.functional)
Expand All @@ -850,6 +851,7 @@ def _elementwise_unary_with_inplace_checker(a: TensorProxy, /, inplace: bool = F
_register_elementwise_unary_implementation(ltorch.elu, elu, checker=_always_executable)
_register_elementwise_unary_implementation(ltorch.celu, celu, checker=_always_executable)
_register_elementwise_unary_implementation(ltorch.gelu, gelu, checker=_always_executable)
_register_elementwise_unary_implementation(ltorch.leaky_relu, leaky_relu, checker=_always_executable)
_register_elementwise_unary_implementation(ltorch.relu, relu, checker=_elementwise_unary_with_inplace_checker)
_register_elementwise_unary_implementation(ltorch.relu6, relu6, checker=_elementwise_unary_with_inplace_checker)
_register_elementwise_unary_implementation(ltorch.hardswish, hardswish, checker=_elementwise_unary_with_inplace_checker)
Expand Down
34 changes: 24 additions & 10 deletions thunder/tests/opinfos.py
Original file line number Diff line number Diff line change
Expand Up @@ -1633,20 +1633,24 @@ def _abs_torch(x: torch.Tensor | Number):
elementwise_unary_ops.append(reciprocal_opinfo)


def elementwise_unary_with_alpha_generator(op, device, dtype, requires_grad):
alphas = (None, -1.0, 0.5)
samples = elementwise_unary_generator(op, device, dtype, requires_grad)
for alpha, sample in itertools.product(alphas, samples):
if alpha is None:
yield sample
else:
yield SampleInput(*sample.args, alpha=alpha, **sample.kwargs)
def get_elementwise_unary_with_alpha_generator():
kwargs_list = [{}, {"alpha": -1.0}, {"alpha": 0.5}]
return get_elementwise_unary_with_kwargs_generator(kwargs_list)


def get_elementwise_unary_with_kwargs_generator(kwargs_list):
def gen(op, device, dtype, requires_grad):
samples = elementwise_unary_generator(op, device, dtype, requires_grad)
for kwargs, sample in itertools.product(kwargs_list, samples):
yield SampleInput(*sample.args, **kwargs, **sample.kwargs)

return gen


celu_opinfo = OpInfo(
ltorch.celu,
dtypes=(datatypes.floating,),
sample_input_generator=elementwise_unary_with_alpha_generator,
sample_input_generator=get_elementwise_unary_with_alpha_generator(),
torch_reference=_elementwise_unary_torch(torch.celu),
test_directives=(),
)
Expand All @@ -1656,7 +1660,7 @@ def elementwise_unary_with_alpha_generator(op, device, dtype, requires_grad):
elu_opinfo = OpInfo(
ltorch.elu,
dtypes=(datatypes.floating,),
sample_input_generator=elementwise_unary_with_alpha_generator,
sample_input_generator=get_elementwise_unary_with_alpha_generator(),
torch_reference=torch.nn.functional.elu,
# fdm.jvp, which is used in test_vjp_correctness, behaves badly on (-1e-6, 1e-6) for this function
singularity_fn=lambda x: x,
Expand All @@ -1665,6 +1669,16 @@ def elementwise_unary_with_alpha_generator(op, device, dtype, requires_grad):
elementwise_unary_ops.append(elu_opinfo)


leaky_relu_opinfo = OpInfo(
ltorch.leaky_relu,
dtypes=(datatypes.floating,),
sample_input_generator=get_elementwise_unary_with_kwargs_generator([{}, {"negative_slope": 0.5}]),
torch_reference=torch.nn.functional.leaky_relu,
test_directives=(),
)
elementwise_unary_ops.append(leaky_relu_opinfo)


relu_opinfo = OpInfo(
ltorch.relu,
sample_input_generator=elementwise_unary_generator,
Expand Down
11 changes: 11 additions & 0 deletions thunder/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1801,6 +1801,17 @@ def gelu(a: TensorProxy, /, *, approximate: str = "none") -> TensorLike:
raise ValueError(f"gelu does not support the approximate={approximate} argument")


@torchsymbol(torch.nn.functional.leaky_relu, is_method=False)
def leaky_relu(a: TensorProxy, /, negative_slope=0.01, inplace: bool = False) -> TensorLike:
out = where(a > 0, a, a * negative_slope)
if inplace:
return prims.copy_(out, a)
return out


_inplace_to_out_of_place[leaky_relu] = leaky_relu, 2


# TODO Should this use clamp? -- Would that propagate NaNs properly?
@torchsymbol(torch.relu, torch.nn.functional.relu, id="torch.relu", is_method=True)
def relu(a: TensorLike, /, inplace: bool = False) -> TensorLike:
Expand Down
1 change: 0 additions & 1 deletion thunder/torch/default_torch_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,6 @@
torch.nn.functional.instance_norm,
torch.nn.functional.kl_div,
torch.nn.functional.l1_loss,
torch.nn.functional.leaky_relu,
torch.nn.functional.local_response_norm,
torch.nn.functional.logsigmoid,
torch.nn.functional.lp_pool1d,
Expand Down

0 comments on commit b43481f

Please sign in to comment.