Implement torch_tensor_backward #286

Draft: wants to merge 24 commits into base: main

Changes from all commits (24 commits)
0ffec4f
Enable requires_grad in autograd example
jwallwork23 Jan 20, 2025
2135968
Implement torch_tensor_backward
jwallwork23 Jan 20, 2025
24559fe
Simplify autograd example
jwallwork23 Jan 20, 2025
be86846
Setup requires_grad properly; Use TensorOptions in tensor constructors
jwallwork23 Jan 20, 2025
10d2499
Implement get_gradient
jwallwork23 Jan 20, 2025
b124da0
Finish autograd example
jwallwork23 Jan 20, 2025
586de97
Unit test for gradient of assignment
jwallwork23 Jan 20, 2025
9cc506f
Unit test for gradient of addition
jwallwork23 Jan 20, 2025
13d28e8
Unit test for gradient of subtraction
jwallwork23 Jan 21, 2025
13d8e40
Unit test for gradient of negative
jwallwork23 Jan 21, 2025
d8f40b2
Unit test for gradient of multiplication
jwallwork23 Jan 21, 2025
abb303e
Unit test for gradient of division
jwallwork23 Jan 21, 2025
0c9a59a
Unit test for gradient of square
jwallwork23 Jan 21, 2025
f68d647
Unit test for gradient of square root
jwallwork23 Jan 21, 2025
f100c97
Unit test for gradient of scalar multiplication - FIXME
jwallwork23 Jan 21, 2025
55d3849
Unit test for gradient of scalar division - FIXME
jwallwork23 Jan 21, 2025
37e6873
Rename get_gradient and provide method
jwallwork23 Feb 17, 2025
bd28159
Drop unnecessary c_loc use
jwallwork23 Feb 17, 2025
c815301
FIXME- backward needs intent(inout)
jwallwork23 Feb 17, 2025
7e350e0
Drop unused imports
jwallwork23 Feb 17, 2025
3bc60dc
Set up external gradient of ones for now
jwallwork23 Feb 17, 2025
7e3fe43
Scalar multiplication and division using rank-1 tensors
jwallwork23 Feb 18, 2025
d42f3ed
Fix static analysis
jwallwork23 Feb 18, 2025
2c5e8cb
Apply cmake-format
jwallwork23 Feb 18, 2025
2 changes: 1 addition & 1 deletion .github/workflows/static_analysis.yml
@@ -113,7 +113,7 @@ jobs:
style: 'file'
tidy-checks: ''
# Use the compile_commands.json from CMake to locate headers
database: ${{ github.workspace }}/src/build
database: ${{ github.workspace }}/build
# only 'update' a single comment in a pull request thread.
thread-comments: ${{ github.event_name == 'pull_request' && 'update' }}
- name: Fail fast?!
80 changes: 46 additions & 34 deletions examples/6_Autograd/autograd.f90
@@ -4,9 +4,9 @@ program example
use, intrinsic :: iso_fortran_env, only : sp => real32

! Import our library for interfacing with PyTorch's Autograd module
use ftorch, only: assignment(=), operator(+), operator(-), operator(*), &
operator(/), operator(**), torch_kCPU, torch_tensor, torch_tensor_delete, &
torch_tensor_from_array, torch_tensor_to_array
use ftorch, only: assignment(=), operator(+), operator(-), operator(*), operator(/), &
operator(**), torch_kCPU, torch_tensor, torch_tensor_backward, &
torch_tensor_delete, torch_tensor_from_array, torch_tensor_to_array

! Import our tools module for testing utils
use ftorch_test_utils, only : assert_allclose
@@ -17,48 +17,56 @@ program example
integer, parameter :: wp = sp

! Set up Fortran data structures
integer, parameter :: n=2, m=1
real(wp), dimension(n,m), target :: in_data1
real(wp), dimension(n,m), target :: in_data2
real(wp), dimension(:,:), pointer :: out_data
real(wp), dimension(n,m) :: expected
integer :: tensor_layout(2) = [1, 2]

! Flag for testing
logical :: test_pass
integer, parameter :: n = 2
real(wp), dimension(:), pointer :: out_data1, out_data2, out_data3
integer :: tensor_layout(1) = [1]

! Set up Torch data structures
type(torch_tensor) :: a, b, Q
type(torch_tensor) :: a, b, Q, dQda, dQdb, multiplier, divisor

! Initialise input arrays as in Python example
in_data1(:,1) = [2.0_wp, 3.0_wp]
in_data2(:,1) = [6.0_wp, 4.0_wp]
! Initialise input arrays as in the Python example and construct Torch Tensors from them
call torch_tensor_from_array(a, [2.0_wp, 3.0_wp], tensor_layout, torch_kCPU, requires_grad=.true.)
call torch_tensor_from_array(b, [6.0_wp, 4.0_wp], tensor_layout, torch_kCPU, requires_grad=.true.)

! Construct a Torch Tensor from a Fortran array
! TODO: Implement requires_grad=.true.
call torch_tensor_from_array(a, in_data1, tensor_layout, torch_kCPU)
call torch_tensor_from_array(b, in_data2, tensor_layout, torch_kCPU)
! Scalar multiplication and division are not currently implemented in FTorch. However, you can
! achieve the same thing by defining a rank-1 tensor with a single entry, as follows:
call torch_tensor_from_array(multiplier, [3.0_wp], tensor_layout, torch_kCPU)
call torch_tensor_from_array(divisor, [3.0_wp], tensor_layout, torch_kCPU)

! Check arithmetic operations work for torch_tensors
write (*,*) "a = ", in_data1(:,1)
write (*,*) "b = ", in_data2(:,1)
Q = 3 * (a**3 - b * b / 3)
! Compute the same mathematical expression as in the Python example
Q = multiplier * (a**3 - b * b / divisor)

! Extract a Fortran array from a Torch tensor
call torch_tensor_to_array(Q, out_data, shape(in_data1))
write (*,*) "Q = 3 * (a ** 3 - b * b / 2) =", out_data(:,1)
! Extract a Fortran array from the Torch tensor
call torch_tensor_to_array(Q, out_data1, [2])
write (*,*) "Q = 3 * (a^3 - b*b/3) = 3*a^3 - b^2 = ", out_data1(:)

! Check output tensor matches expected value
expected(:,1) = [-12.0_wp, 65.0_wp]
test_pass = assert_allclose(out_data, expected, test_name="torch_tensor_to_array", rtol=1e-5)
if (.not. test_pass) then
if (.not. assert_allclose(out_data1, [-12.0_wp, 65.0_wp], test_name="autograd_Q")) then
call clean_up()
print *, "Error :: out_data does not match expected value"
print *, "Error :: value of Q does not match expected value"
stop 999
end if

! Back-propagation
! TODO: Requires API extension
! Run the back-propagation operator
call torch_tensor_backward(Q)
dQda = a%grad()
dQdb = b%grad()

! Extract Fortran arrays from the Torch tensors and check the gradients take expected values
call torch_tensor_to_array(dQda, out_data2, [2])
print *, "dQda = 9*a^2 = ", out_data2
if (.not. assert_allclose(out_data2, [36.0_wp, 81.0_wp], test_name="autograd_dQda")) then
call clean_up()
print *, "Error :: value of dQda does not match expected value"
stop 999
end if
call torch_tensor_to_array(dQdb, out_data3, [2])
print *, "dQdb = - 2*b = ", out_data3
if (.not. assert_allclose(out_data3, [-12.0_wp, -8.0_wp], test_name="autograd_dQdb")) then
call clean_up()
print *, "Error :: value of dQdb does not match expected value"
stop 999
end if

call clean_up()
write (*,*) "Autograd example ran successfully"
@@ -67,9 +75,13 @@ program example

! Subroutine for freeing memory and nullifying pointers used in the example
subroutine clean_up()
nullify(out_data)
nullify(out_data1)
nullify(out_data2)
nullify(out_data3)
call torch_tensor_delete(a)
call torch_tensor_delete(b)
call torch_tensor_delete(multiplier)
call torch_tensor_delete(divisor)
call torch_tensor_delete(Q)
end subroutine clean_up

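For reference (not part of the diff), the expected values asserted in the example follow from differentiating Q by hand:

$$
Q = 3\left(a^3 - \tfrac{b^2}{3}\right) = 3a^3 - b^2,
\qquad \frac{\partial Q}{\partial a} = 9a^2,
\qquad \frac{\partial Q}{\partial b} = -2b.
$$

Evaluated elementwise at a = (2, 3) and b = (6, 4), this gives dQ/da = (36, 81) and dQ/db = (-12, -8), which are exactly the values checked by assert_allclose in the example.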
50 changes: 40 additions & 10 deletions src/ctorch.cpp
@@ -122,9 +122,12 @@ torch_tensor_t torch_empty(int ndim, const int64_t *shape, torch_data_t dtype,
try {
// This doesn't throw if shape and dimensions are incompatible
c10::IntArrayRef vshape(shape, ndim);
auto options = torch::TensorOptions()
.dtype(get_libtorch_dtype(dtype))
.device(get_libtorch_device(device_type, device_index))
.requires_grad(requires_grad);
tensor = new torch::Tensor;
*tensor = torch::empty(vshape, torch::dtype(get_libtorch_dtype(dtype)))
.to(get_libtorch_device(device_type, device_index));
*tensor = torch::empty(vshape, options);
} catch (const torch::Error &e) {
std::cerr << "[ERROR]: " << e.msg() << std::endl;
delete tensor;
@@ -145,9 +148,12 @@ torch_tensor_t torch_zeros(int ndim, const int64_t *shape, torch_data_t dtype,
try {
// This doesn't throw if shape and dimensions are incompatible
c10::IntArrayRef vshape(shape, ndim);
auto options = torch::TensorOptions()
.dtype(get_libtorch_dtype(dtype))
.device(get_libtorch_device(device_type, device_index))
.requires_grad(requires_grad);
tensor = new torch::Tensor;
*tensor = torch::zeros(vshape, torch::dtype(get_libtorch_dtype(dtype)))
.to(get_libtorch_device(device_type, device_index));
*tensor = torch::zeros(vshape, options);
} catch (const torch::Error &e) {
std::cerr << "[ERROR]: " << e.msg() << std::endl;
delete tensor;
@@ -168,9 +174,12 @@ torch_tensor_t torch_ones(int ndim, const int64_t *shape, torch_data_t dtype,
try {
// This doesn't throw if shape and dimensions are incompatible
c10::IntArrayRef vshape(shape, ndim);
auto options = torch::TensorOptions()
.dtype(get_libtorch_dtype(dtype))
.device(get_libtorch_device(device_type, device_index))
.requires_grad(requires_grad);
tensor = new torch::Tensor;
*tensor = torch::ones(vshape, torch::dtype(get_libtorch_dtype(dtype)))
.to(get_libtorch_device(device_type, device_index));
*tensor = torch::ones(vshape, options);
} catch (const torch::Error &e) {
std::cerr << "[ERROR]: " << e.msg() << std::endl;
delete tensor;
@@ -196,10 +205,12 @@ torch_tensor_t torch_from_blob(void *data, int ndim, const int64_t *shape,
// This doesn't throw if shape and dimensions are incompatible
c10::IntArrayRef vshape(shape, ndim);
c10::IntArrayRef vstrides(strides, ndim);
auto options = torch::TensorOptions()
.dtype(get_libtorch_dtype(dtype))
.device(get_libtorch_device(device_type, device_index))
.requires_grad(requires_grad);
tensor = new torch::Tensor;
*tensor = torch::from_blob(data, vshape, vstrides,
torch::dtype(get_libtorch_dtype(dtype)))
.to(get_libtorch_device(device_type, device_index));
*tensor = torch::from_blob(data, vshape, vstrides, options);

} catch (const torch::Error &e) {
std::cerr << "[ERROR]: " << e.msg() << std::endl;
@@ -310,7 +321,7 @@ torch_tensor_t torch_tensor_assign(const torch_tensor_t input) {
torch::AutoGradMode enable_grad(in->requires_grad());
torch::Tensor *output = nullptr;
output = new torch::Tensor;
*output = in->detach().clone();
*output = *in;
return output;
}

@@ -384,6 +395,25 @@ torch_tensor_t torch_tensor_power_float(const torch_tensor_t tensor,
return output;
}

// =============================================================================
// --- Functions related to automatic differentiation functionality for tensors
// =============================================================================

void torch_tensor_backward(const torch_tensor_t tensor,
const torch_tensor_t external_gradient) {
auto t = reinterpret_cast<torch::Tensor *>(tensor);
auto g = reinterpret_cast<torch::Tensor *const>(external_gradient);
t->backward(*g);
}

EXPORT_C torch_tensor_t torch_tensor_get_gradient(const torch_tensor_t tensor) {
auto t = reinterpret_cast<torch::Tensor *const>(tensor);
torch::Tensor *output = nullptr;
output = new torch::Tensor;
*output = t->grad();
return output;
}

// =============================================================================
// --- Torch model API
// =============================================================================
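Not part of the diff: a minimal standalone LibTorch sketch of the calls the new wrappers forward to (torch::Tensor::backward with an explicit external gradient, and Tensor::grad), assuming a plain LibTorch C++ build:

```cpp
#include <iostream>
#include <torch/torch.h>

int main() {
  // Leaf tensor tracked by autograd, as required before calling backward()
  torch::Tensor a = torch::tensor({2.0, 3.0}, torch::requires_grad(true));

  // Q = 3 * a^3 is non-scalar, so backward() needs an explicit external gradient
  torch::Tensor Q = 3 * a * a * a;
  torch::Tensor external_gradient = torch::ones_like(Q);

  Q.backward(external_gradient);       // the call made by torch_tensor_backward
  std::cout << a.grad() << std::endl;  // the tensor returned by torch_tensor_get_gradient: 9*a^2 = [36, 81]
  return 0;
}
```

The wrappers above reinterpret the opaque torch_tensor_t handles as torch::Tensor pointers and make exactly these two calls.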
20 changes: 20 additions & 0 deletions src/ctorch.h
@@ -242,6 +242,26 @@ EXPORT_C torch_tensor_t torch_tensor_power_int(const torch_tensor_t tensor,
EXPORT_C torch_tensor_t torch_tensor_power_float(const torch_tensor_t tensor,
const torch_float_t exponent);

// =============================================================================
// --- Functions related to automatic differentiation functionality for tensors
// =============================================================================

/**
* Function to perform back-propagation on a Torch Tensor.
* Note that the Tensor must have the requires_grad attribute set to true.
* @param tensor Tensor to perform back-propagation on
* @param external_gradient Tensor with an external gradient to supply for the back-propagation
*/
EXPORT_C void torch_tensor_backward(const torch_tensor_t tensor,
const torch_tensor_t external_gradient);

/**
* Function to return the grad attribute of a Torch Tensor.
* @param tensor Tensor to get the gradient of
* @return Tensor for the gradient
*/
EXPORT_C torch_tensor_t torch_tensor_get_gradient(const torch_tensor_t tensor);

// =============================================================================
// --- Torch model API
// =============================================================================