add elu op #1417

Merged: 6 commits, Nov 13, 2024
thunder/executors/torchex.py (2 additions, 0 deletions)

```diff
@@ -803,6 +803,7 @@ def _erfcinv_impl(a: torch.Tensor) -> torch.Tensor:
 
 # nn.functional elementwise unary
 celu = _register_torch_operation("celu", module=torch.nn.functional)
+elu = _register_torch_operation("elu", module=torch.nn.functional)
 gelu = _register_torch_operation("gelu", module=torch.nn.functional)
 relu = _register_torch_operation("relu", module=torch.nn.functional)
 relu6 = _register_torch_operation("relu6", module=torch.nn.functional)
@@ -815,6 +816,7 @@ def _elementwise_unary_with_inplace_checker(a: TensorProxy, /, inplace: bool = F
     return isinstance(a, TensorProxy) and not inplace
 
 
+_register_elementwise_unary_implementation(ltorch.elu, elu, checker=_always_executable)
 _register_elementwise_unary_implementation(ltorch.celu, celu, checker=_always_executable)
 _register_elementwise_unary_implementation(ltorch.gelu, gelu, checker=_always_executable)
 _register_elementwise_unary_implementation(ltorch.relu, relu, checker=_elementwise_unary_with_inplace_checker)
```
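With the executor registration above, `torch.nn.functional.elu` calls inside a jitted function can be dispatched to the torch executor. A minimal sketch of how this is exercised (not part of the PR; assumes a Thunder build that includes this change):

```python
import torch
import torch.nn.functional as F
import thunder

def fn(x):
    return F.elu(x, alpha=0.5)

jfn = thunder.jit(fn)
x = torch.randn(4, 4)
# The jitted call should match eager PyTorch now that elu has a torch-executor mapping.
torch.testing.assert_close(jfn(x), fn(x))
```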
thunder/tests/opinfos.py (14 additions, 2 deletions)

```diff
@@ -1633,7 +1633,7 @@ def _abs_torch(x: torch.Tensor | Number):
 elementwise_unary_ops.append(reciprocal_opinfo)
 
 
-def celu_sample_generator(op, device, dtype, requires_grad):
+def elementwise_unary_with_alpha_generator(op, device, dtype, requires_grad):
     alphas = (None, -1.0, 0.5)
     samples = elementwise_unary_generator(op, device, dtype, requires_grad)
     for alpha, sample in itertools.product(alphas, samples):
@@ -1646,13 +1646,25 @@ def celu_sample_generator(op, device, dtype, requires_grad):
 celu_opinfo = OpInfo(
     ltorch.celu,
     dtypes=(datatypes.floating,),
-    sample_input_generator=celu_sample_generator,
+    sample_input_generator=elementwise_unary_with_alpha_generator,
     torch_reference=_elementwise_unary_torch(torch.celu),
     test_directives=(),
 )
 elementwise_unary_ops.append(celu_opinfo)
 
 
+elu_opinfo = OpInfo(
+    ltorch.elu,
+    dtypes=(datatypes.floating,),
+    sample_input_generator=elementwise_unary_with_alpha_generator,
+    torch_reference=torch.nn.functional.elu,
+    # fdm.jvp, which is used in test_vjp_correctness, behaves badly on (-1e-6, 1e-6) for this function
+    singularity_fn=lambda x: x,
+    test_directives=(),
+)
+elementwise_unary_ops.append(elu_opinfo)
 
 
 relu_opinfo = OpInfo(
     ltorch.relu,
     sample_input_generator=elementwise_unary_generator,
```
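The `singularity_fn=lambda x: x` entry presumably keeps test samples away from x = 0, where ELU has a kink: the central differences used by `fdm.jvp` in `test_vjp_correctness` straddle the kink and blend the two one-sided slopes. A rough standalone illustration of the effect with plain PyTorch (the exact points and tolerances here are only indicative):

```python
import torch

# Central difference of elu across the kink at 0 (alpha != 1 makes the mismatch visible).
alpha = 0.5
f = lambda t: torch.nn.functional.elu(t, alpha=alpha)

x = torch.tensor(1e-7, dtype=torch.float64)   # point just inside (-1e-6, 1e-6)
eps = 1e-6
fd = (f(x + eps) - f(x - eps)) / (2 * eps)    # straddles 0, blends slopes 1 and ~alpha
print(fd.item())   # ~0.78, while the true derivative for x > 0 is 1.0
```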
thunder/tests/test_grad.py (2 additions, 2 deletions)

```diff
@@ -281,7 +281,7 @@ def check_vjp(f, *primals, comp, executor="torch", set_compile_data: bool = Fals
     """
     # Let f be a function from vectors of size n to vectors of size m.
     # Its Jacobian is a matrix J of size m x n.
-    # The adjoint property is J^* J = I, where J^* is the conjugate transpose (adjoint) of J.
+    # Represent by J^* the conjugate transpose (adjoint) of J.
     # J^* is a matrix of size n x m.
     # For any vector v of size m, J^* v is a vector of size n.
     # For any vector u of size n, J u is a vector of size m.
@@ -296,7 +296,7 @@ def check_vjp(f, *primals, comp, executor="torch", set_compile_data: bool = Fals
 
     u = tree_map(make, primals)
 
-    comp_f = thunder.jit(f)
+    comp_f = thunder.jit(f, disable_torch_autograd=True)
 
     outs_p, J_u = numerical_jvp(comp_f)(primals, u)
 
```
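For reference, the identity this kind of VJP check relies on, ⟨v, J u⟩ = ⟨J^* v, u⟩ with J u and J^* v as described in the docstring above, can be verified numerically with plain PyTorch. A generic sketch, independent of Thunder (the function `f` is arbitrary):

```python
import torch

def f(x):
    return torch.tanh(x) * x.sum()

x = torch.randn(5, dtype=torch.float64)
u = torch.randn(5, dtype=torch.float64)   # input-space direction (size n)
v = torch.randn(5, dtype=torch.float64)   # output-space direction (size m)

_, Ju = torch.func.jvp(f, (x,), (u,))     # J u, a vector of size m
_, vjp_fn = torch.func.vjp(f, x)
(Jtv,) = vjp_fn(v)                        # J^* v, a vector of size n

# <v, J u> should equal <J^* v, u>
assert torch.allclose(torch.dot(v, Ju), torch.dot(Jtv, u))
```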
thunder/torch/__init__.py (12 additions, 0 deletions)

```diff
@@ -1774,6 +1774,18 @@ def celu(a: TensorLike, /, alpha: float = 1.0, inplace: bool = False) -> TensorL
 _inplace_to_out_of_place[celu] = celu, 2
 
 
+@torchsymbol(torch.nn.functional.elu, is_method=False)
+def elu(a: TensorProxy, /, alpha: float = 1.0, inplace: bool = False) -> TensorLike:
+    negative_domain_value = alpha * expm1(a)
+    out = where(a > 0, a, negative_domain_value)
+    if inplace:
+        return prims.copy_(out, a)
+    return out
+
+
+_inplace_to_out_of_place[elu] = elu, 2
+
+
 @torchsymbol(torch.nn.functional.gelu, is_method=False)
 def gelu(a: TensorProxy, /, *, approximate: str = "none") -> TensorLike:
     if approximate == "none":
```
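The symbol above implements ELU(x) = x for x > 0 and alpha * (exp(x) - 1) otherwise, which is the formula used by `torch.nn.functional.elu`. A quick sanity check of that formula in plain PyTorch (the Thunder symbol itself is covered by the opinfo added above):

```python
import torch

x = torch.linspace(-3, 3, 101, dtype=torch.float64)
for alpha in (1.0, 0.5, -1.0):
    # x if x > 0 else alpha * (exp(x) - 1)
    manual = torch.where(x > 0, x, alpha * torch.expm1(x))
    torch.testing.assert_close(manual, torch.nn.functional.elu(x, alpha=alpha))
```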
thunder/torch/default_torch_ops.py (0 additions, 1 deletion)

```diff
@@ -338,7 +338,6 @@
     torch.nn.functional.dropout1d,
     torch.nn.functional.dropout2d,
     torch.nn.functional.dropout3d,
-    torch.nn.functional.elu,
     torch.nn.functional.embedding_bag,
     torch.nn.functional.feature_alpha_dropout,
     torch.nn.functional.fold,
```