Skip to content

Commit

Permalink
[Bugfix] wrong call of threshold_backward
Browse files Browse the repository at this point in the history
  • Loading branch information
StrongSpoon committed Jan 23, 2025
1 parent 8b6148f commit 6abc7bc
Showing 1 changed file with 2 additions and 5 deletions.
7 changes: 2 additions & 5 deletions tests/test_binary_pointwise_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1307,15 +1307,12 @@ def test_accuracy_threshold_backward(shape, dtype):
res_inp = torch.randn(shape, dtype=dtype, device=flag_gems.device)
res_grad = torch.randn_like(res_inp)
threshold = 0
value = 100

ref_inp = to_reference(res_inp, True)
ref_grad = to_reference(res_grad, True)

ref_in_grad = torch.ops.aten.threshold_backward(ref_grad, ref_inp, threshold, value)
ref_in_grad = torch.ops.aten.threshold_backward(ref_grad, ref_inp, threshold)
with flag_gems.use_gems():
res_in_grad = torch.ops.aten.threshold_backward(
res_grad, res_inp, threshold, value
)
res_in_grad = torch.ops.aten.threshold_backward(res_grad, res_inp, threshold)

gems_assert_close(res_in_grad, ref_in_grad, dtype)

0 comments on commit 6abc7bc

Please sign in to comment.