Fix functorch error for 2.0.x #59

Merged · 3 commits · May 3, 2024
src/tad_dftd3/__version__.py (2 changes: 1 addition & 1 deletion)
@@ -15,4 +15,4 @@
 """
 Version module for *tad-dftd3*.
 """
-__version__ = "0.2.1"
+__version__ = "0.2.2"
src/tad_dftd3/model/c6.py (19 changes: 19 additions & 0 deletions)
@@ -69,6 +69,25 @@ def atomic_c6(
     """
     _check_memory(numbers, weights, chunk_size)

+    # PyTorch 2.0.x has a bug with functorch and custom autograd functions as
+    # documented in: https://github.com/pytorch/pytorch/issues/99973
+    #
+    # RuntimeError: unwrapped_count > 0 INTERNAL ASSERT FAILED at "../aten/src/
+    # ATen/functorch/TensorWrapper.cpp":202, please report a bug to PyTorch.
+    # Should have at least one dead wrapper
+    #
+    # Hence, we cannot use the custom backwards for reduced memory consumption.
+    if __tversion__[0] == 2 and __tversion__[1] == 0:  # pragma: no cover
+        track_weights = torch._C._functorch.is_gradtrackingtensor(weights)
+        track_numbers = torch._C._functorch.is_gradtrackingtensor(numbers)
+        if track_weights or track_numbers:
+
+            if chunk_size is None:
+                return _atomic_c6_full(numbers, weights, reference)
+
+            return _atomic_c6_chunked(numbers, weights, reference, chunk_size)
+
+    # Use custom autograd function for reduced memory consumption
     AtomicC6 = AtomicC6_V1 if __tversion__ < (2, 0, 0) else AtomicC6_V2
     res = AtomicC6.apply(numbers, weights, reference, chunk_size)
     assert res is not None
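The gate added above only takes effect when atomic_c6 is called from inside a functorch transform, which is when its inputs arrive as grad-tracking wrapper tensors. Below is a minimal sketch of when that (private) check fires, assuming PyTorch >= 2.0 so that torch.func is available; it is illustration only, not part of the PR.

import torch

def energy(x: torch.Tensor) -> torch.Tensor:
    # Inside a functorch transform the input is a wrapped GradTrackingTensor,
    # so the same private check used in the patch returns True and the plain
    # (non-custom-autograd) path would be taken on PyTorch 2.0.x.
    print("grad-tracking:", torch._C._functorch.is_gradtrackingtensor(x))
    return (x * x).sum()

x = torch.rand(4, dtype=torch.double)
energy(x)                   # prints: grad-tracking: False
torch.func.grad(energy)(x)  # prints: grad-tracking: True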
test/test_grad/test_pos.py (85 changes: 61 additions & 24 deletions)
@@ -19,7 +19,7 @@

 import pytest
 import torch
-from tad_mctc.autograd import dgradcheck, dgradgradcheck
+from tad_mctc.autograd import dgradcheck, dgradgradcheck, jacrev
 from tad_mctc.batch import pack

 from tad_dftd3 import dftd3
@@ -151,29 +151,29 @@ def test_gradgradcheck_batch(dtype: torch.dtype, name1: str, name2: str) -> None
 @pytest.mark.parametrize("name", sample_list)
 def test_autograd(dtype: torch.dtype, name: str) -> None:
     """Compare with reference values from tblite."""
+    dd: DD = {"device": DEVICE, "dtype": dtype}
+
     sample = samples[name]
-    numbers = sample["numbers"]
-    positions = sample["positions"].type(dtype)
+    numbers = sample["numbers"].to(DEVICE)
+    positions = sample["positions"].to(**dd)
+
+    ref = sample["grad"].to(**dd)

     # GFN1-xTB parameters
     param = {
-        "s6": positions.new_tensor(1.00000000),
-        "s8": positions.new_tensor(2.40000000),
-        "s9": positions.new_tensor(0.00000000),
-        "a1": positions.new_tensor(0.63000000),
-        "a2": positions.new_tensor(5.00000000),
+        "s6": torch.tensor(1.00000000, **dd),
+        "s8": torch.tensor(2.40000000, **dd),
+        "s9": torch.tensor(0.00000000, **dd),
+        "a1": torch.tensor(0.63000000, **dd),
+        "a2": torch.tensor(5.00000000, **dd),
     }

-    ref = sample["grad"].type(dtype)
-
     # variable to be differentiated
-    positions.requires_grad_(True)
+    pos = positions.clone().requires_grad_(True)

     # automatic gradient
-    energy = torch.sum(dftd3(numbers, positions, param))
-    (grad,) = torch.autograd.grad(energy, positions)
-
-    positions.detach_()
+    energy = torch.sum(dftd3(numbers, pos, param))
+    (grad,) = torch.autograd.grad(energy, pos)

     assert pytest.approx(ref.cpu(), abs=tol) == grad.cpu()
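The rewritten test above relies on two conventions that recur throughout this PR: a dd dict that carries device and dtype to every tensor, and differentiating a cloned leaf tensor instead of toggling requires_grad_()/detach_() on the shared sample data. Here is a minimal stand-alone sketch of the pattern; DEVICE and the energy expression are placeholders, not the fixtures from this test file.

import torch

DEVICE = None  # placeholder for the test module's DEVICE
dtype = torch.double
dd = {"device": DEVICE, "dtype": dtype}

# Move/cast stored sample data once ...
positions = torch.rand((4, 3)).to(**dd)
# ... and create parameters directly with the same device/dtype.
s6 = torch.tensor(1.0, **dd)

# Differentiate a clone, so the original tensor never has requires_grad
# flipped on and never needs a detach_() afterwards.
pos = positions.clone().requires_grad_(True)
energy = (s6 * pos * pos).sum()  # placeholder for dftd3(numbers, pos, param)
(grad,) = torch.autograd.grad(energy, pos)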

@@ -183,21 +183,23 @@ def test_autograd(dtype: torch.dtype, name: str) -> None:
 @pytest.mark.parametrize("name", sample_list)
 def test_backward(dtype: torch.dtype, name: str) -> None:
     """Compare with reference values from tblite."""
+    dd: DD = {"device": DEVICE, "dtype": dtype}
+
     sample = samples[name]
-    numbers = sample["numbers"]
-    positions = sample["positions"].type(dtype)
+    numbers = sample["numbers"].to(DEVICE)
+    positions = sample["positions"].to(**dd)
+
+    ref = sample["grad"].to(**dd)

     # GFN1-xTB parameters
     param = {
-        "s6": positions.new_tensor(1.00000000),
-        "s8": positions.new_tensor(2.40000000),
-        "s9": positions.new_tensor(0.00000000),
-        "a1": positions.new_tensor(0.63000000),
-        "a2": positions.new_tensor(5.00000000),
+        "s6": torch.tensor(1.00000000, **dd),
+        "s8": torch.tensor(2.40000000, **dd),
+        "s9": torch.tensor(0.00000000, **dd),
+        "a1": torch.tensor(0.63000000, **dd),
+        "a2": torch.tensor(5.00000000, **dd),
     }

-    ref = sample["grad"].type(dtype)
-
     # variable to be differentiated
     positions.requires_grad_(True)

@@ -213,3 +215,38 @@ def test_backward(dtype: torch.dtype, name: str) -> None:
     positions.grad.data.zero_()

     assert pytest.approx(ref.cpu(), abs=tol) == grad_backward.cpu()
+
+
+@pytest.mark.grad
+@pytest.mark.parametrize("dtype", [torch.double])
+@pytest.mark.parametrize("name", sample_list)
+def test_functorch(dtype: torch.dtype, name: str) -> None:
+    """Compare with reference values from tblite."""
+    dd: DD = {"device": DEVICE, "dtype": dtype}
+
+    sample = samples[name]
+    numbers = sample["numbers"].to(DEVICE)
+    positions = sample["positions"].to(**dd)
+
+    ref = sample["grad"].to(**dd)
+
+    # GFN1-xTB parameters
+    param = {
+        "s6": torch.tensor(1.00000000, **dd),
+        "s8": torch.tensor(2.40000000, **dd),
+        "s9": torch.tensor(0.00000000, **dd),
+        "a1": torch.tensor(0.63000000, **dd),
+        "a2": torch.tensor(5.00000000, **dd),
+    }
+
+    # variable to be differentiated
+    pos = positions.clone().requires_grad_(True)
+
+    def dftd3_func(p: Tensor) -> Tensor:
+        return dftd3(numbers, p, param).sum()
+
+    grad = jacrev(dftd3_func)(pos)
+    assert isinstance(grad, Tensor)
+
+    assert grad.shape == ref.shape
+    assert pytest.approx(ref.cpu(), abs=tol) == grad.detach().cpu()
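The new test_functorch computes the same nuclear gradient as test_autograd above, only through a functorch transform: jacrev of the scalar-valued energy with respect to the positions. A minimal sketch of that equivalence using torch.func.jacrev directly (the test itself imports jacrev from tad_mctc.autograd; the energy here is a placeholder, not dftd3):

import torch

def energy(pos: torch.Tensor) -> torch.Tensor:
    return (pos * pos).sum()  # placeholder for dftd3(numbers, pos, param).sum()

pos = torch.rand((4, 3), dtype=torch.double, requires_grad=True)

# functorch route: the Jacobian of a scalar function is its gradient
g_jacrev = torch.func.jacrev(energy)(pos)

# classic autograd route, as in test_autograd
(g_autograd,) = torch.autograd.grad(energy(pos), pos)

assert torch.allclose(g_jacrev, g_autograd)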