From 0ffec4f1587d6e27e84d798bcf690222f0090ed0 Mon Sep 17 00:00:00 2001 From: Joe Wallwork Date: Mon, 20 Jan 2025 11:39:01 +0000 Subject: [PATCH 01/24] Enable requires_grad in autograd example --- examples/6_Autograd/autograd.f90 | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/examples/6_Autograd/autograd.f90 b/examples/6_Autograd/autograd.f90 index 20067e90..889bcf47 100644 --- a/examples/6_Autograd/autograd.f90 +++ b/examples/6_Autograd/autograd.f90 @@ -35,9 +35,8 @@ program example in_data2(:,1) = [6.0_wp, 4.0_wp] ! Construct a Torch Tensor from a Fortran array - ! TODO: Implement requires_grad=.true. - call torch_tensor_from_array(a, in_data1, tensor_layout, torch_kCPU) - call torch_tensor_from_array(b, in_data2, tensor_layout, torch_kCPU) + call torch_tensor_from_array(a, in_data1, tensor_layout, torch_kCPU, requires_grad=.true.) + call torch_tensor_from_array(b, in_data2, tensor_layout, torch_kCPU, requires_grad=.true.) ! Check arithmetic operations work for torch_tensors write (*,*) "a = ", in_data1(:,1) From 2135968cb5ffb313fa8621eecbd81c38e5a91883 Mon Sep 17 00:00:00 2001 From: Joe Wallwork Date: Mon, 20 Jan 2025 11:40:22 +0000 Subject: [PATCH 02/24] Implement torch_tensor_backward --- src/ctorch.cpp | 11 +++++++++++ src/ctorch.h | 13 +++++++++++++ src/ftorch.F90 | 24 ++++++++++++++++++++++++ src/ftorch.fypp | 24 ++++++++++++++++++++++++ 4 files changed, 72 insertions(+) diff --git a/src/ctorch.cpp b/src/ctorch.cpp index b5d5fd5f..5f1c6863 100644 --- a/src/ctorch.cpp +++ b/src/ctorch.cpp @@ -384,6 +384,17 @@ torch_tensor_t torch_tensor_power_float(const torch_tensor_t tensor, return output; } +// ============================================================================= +// --- Functions related to automatic differentiation functionality for tensors +// ============================================================================= + +void torch_tensor_backward(const torch_tensor_t tensor, + const torch_tensor_t external_gradient) { + auto t = reinterpret_cast(tensor); + auto g = reinterpret_cast(external_gradient); + t->backward(*g); +} + // ============================================================================= // --- Torch model API // ============================================================================= diff --git a/src/ctorch.h b/src/ctorch.h index 1123d6eb..0a9ed223 100644 --- a/src/ctorch.h +++ b/src/ctorch.h @@ -242,6 +242,19 @@ EXPORT_C torch_tensor_t torch_tensor_power_int(const torch_tensor_t tensor, EXPORT_C torch_tensor_t torch_tensor_power_float(const torch_tensor_t tensor, const torch_float_t exponent); +// ============================================================================= +// --- Functions related to automatic differentiation functionality for tensors +// ============================================================================= + +/** + * Function to perform back-propagation on a Torch Tensor. + * Note that the Tensor must have the requires_grad attribute set to true. 
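 * (Sketch, added for orientation: the Fortran-level call pattern this C
 *  function supports, via the binding added to ftorch.F90 and ftorch.fypp
 *  below, is
 *
 *      call torch_tensor_backward(Q, external_gradient)
 *
 *  where external_gradient is a tensor of the same shape as Q, typically
 *  filled with ones, as in the autograd example later in this series.)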
+ * @param Tensor to perform back-propagation on
+ * @param Tensor with an external gradient to supply for the back-propagation
+ */
+EXPORT_C void torch_tensor_backward(const torch_tensor_t tensor,
+                                    const torch_tensor_t external_gradient);
+
 // =============================================================================
 // --- Torch model API
 // =============================================================================
diff --git a/src/ftorch.F90 b/src/ftorch.F90
index c3612a13..eceedbbe 100644
--- a/src/ftorch.F90
+++ b/src/ftorch.F90
@@ -2895,6 +2895,30 @@ end function torch_tensor_power_float_c
 
   end function torch_tensor_power_real64
 
+  ! ============================================================================
+  ! --- Procedures related to automatic differentiation functionality for tensors
+  ! ============================================================================
+
+  !> Performs back-propagation on a Torch Tensor, given some external gradient.
+  subroutine torch_tensor_backward(tensor, external_gradient)
+    type(torch_tensor), intent(in) :: tensor
+    type(torch_tensor), intent(in) :: external_gradient
+
+    interface
+      subroutine torch_tensor_backward_c(tensor_c, external_gradient_c) &
+          bind(c, name = 'torch_tensor_backward')
+        use, intrinsic :: iso_c_binding, only : c_ptr
+        implicit none
+        type(c_ptr), value, intent(in) :: tensor_c
+        type(c_ptr), value, intent(in) :: external_gradient_c
+      end subroutine torch_tensor_backward_c
+    end interface
+
+    ! TODO: Make external_gradient optional, setting to ones by default
+
+    call torch_tensor_backward_c(tensor%p, external_gradient%p)
+  end subroutine torch_tensor_backward
+
   ! ============================================================================
   ! --- Torch Model API
   ! ============================================================================
diff --git a/src/ftorch.fypp b/src/ftorch.fypp
index 0f57e0d8..b9c564d4 100644
--- a/src/ftorch.fypp
+++ b/src/ftorch.fypp
@@ -830,6 +830,30 @@ contains
 
 #:endfor
 
+  ! ============================================================================
+  ! --- Procedures related to automatic differentiation functionality for tensors
+  ! ============================================================================
+
+  !> Performs back-propagation on a Torch Tensor, given some external gradient.
+  subroutine torch_tensor_backward(tensor, external_gradient)
+    type(torch_tensor), intent(in) :: tensor
+    type(torch_tensor), intent(in) :: external_gradient
+
+    interface
+      subroutine torch_tensor_backward_c(tensor_c, external_gradient_c) &
+          bind(c, name = 'torch_tensor_backward')
+        use, intrinsic :: iso_c_binding, only : c_ptr
+        implicit none
+        type(c_ptr), value, intent(in) :: tensor_c
+        type(c_ptr), value, intent(in) :: external_gradient_c
+      end subroutine torch_tensor_backward_c
+    end interface
+
+    ! TODO: Make external_gradient optional, setting to ones by default
+
+    call torch_tensor_backward_c(tensor%p, external_gradient%p)
+  end subroutine torch_tensor_backward
+
   ! ============================================================================
   ! --- Torch Model API
   ! 
============================================================================ From 24559fe020d5e74f50eb2d355cb018f1dadae12a Mon Sep 17 00:00:00 2001 From: Joe Wallwork Date: Mon, 20 Jan 2025 11:53:40 +0000 Subject: [PATCH 03/24] Simplify autograd example --- examples/6_Autograd/autograd.f90 | 42 +++++++++++++++----------------- 1 file changed, 20 insertions(+), 22 deletions(-) diff --git a/examples/6_Autograd/autograd.f90 b/examples/6_Autograd/autograd.f90 index 889bcf47..2df0630c 100644 --- a/examples/6_Autograd/autograd.f90 +++ b/examples/6_Autograd/autograd.f90 @@ -4,9 +4,9 @@ program example use, intrinsic :: iso_fortran_env, only : sp => real32 ! Import our library for interfacing with PyTorch's Autograd module - use ftorch, only: assignment(=), operator(+), operator(-), operator(*), & - operator(/), operator(**), torch_kCPU, torch_tensor, torch_tensor_delete, & - torch_tensor_from_array, torch_tensor_to_array + use ftorch, only: assignment(=), operator(+), operator(-), operator(*), operator(/), & + operator(**), torch_kCPU, torch_tensor, torch_tensor_backward, & + torch_tensor_delete, torch_tensor_from_array, torch_tensor_to_array ! Import our tools module for testing utils use ftorch_test_utils, only : assert_allclose @@ -17,47 +17,44 @@ program example integer, parameter :: wp = sp ! Set up Fortran data structures - integer, parameter :: n=2, m=1 - real(wp), dimension(n,m), target :: in_data1 - real(wp), dimension(n,m), target :: in_data2 - real(wp), dimension(:,:), pointer :: out_data - real(wp), dimension(n,m) :: expected - integer :: tensor_layout(2) = [1, 2] - - ! Flag for testing - logical :: test_pass + integer, parameter :: n = 2 + real(wp), dimension(n), target :: in_data1, in_data2, in_data3 + real(wp), dimension(:), pointer :: out_data + real(wp), dimension(n) :: expected + integer :: tensor_layout(1) = [1] ! Set up Torch data structures - type(torch_tensor) :: a, b, Q + type(torch_tensor) :: a, b, Q, external_gradient ! Initialise input arrays as in Python example - in_data1(:,1) = [2.0_wp, 3.0_wp] - in_data2(:,1) = [6.0_wp, 4.0_wp] + in_data1(:) = [2.0_wp, 3.0_wp] + in_data2(:) = [6.0_wp, 4.0_wp] ! Construct a Torch Tensor from a Fortran array call torch_tensor_from_array(a, in_data1, tensor_layout, torch_kCPU, requires_grad=.true.) call torch_tensor_from_array(b, in_data2, tensor_layout, torch_kCPU, requires_grad=.true.) ! Check arithmetic operations work for torch_tensors - write (*,*) "a = ", in_data1(:,1) - write (*,*) "b = ", in_data2(:,1) + write (*,*) "a = ", in_data1(:) + write (*,*) "b = ", in_data2(:) Q = 3 * (a**3 - b * b / 3) ! Extract a Fortran array from a Torch tensor call torch_tensor_to_array(Q, out_data, shape(in_data1)) - write (*,*) "Q = 3 * (a ** 3 - b * b / 2) =", out_data(:,1) + write (*,*) "Q = 3 * (a ** 3 - b * b / 2) =", out_data(:) ! Check output tensor matches expected value - expected(:,1) = [-12.0_wp, 65.0_wp] - test_pass = assert_allclose(out_data, expected, test_name="torch_tensor_to_array", rtol=1e-5) - if (.not. test_pass) then + expected(:) = [-12.0_wp, 65.0_wp] + if (.not. assert_allclose(out_data, expected, test_name="torch_tensor_to_array")) then call clean_up() print *, "Error :: out_data does not match expected value" stop 999 end if ! Back-propagation - ! 
TODO: Requires API extension + in_data3(:) = [1.0_wp, 1.0_wp] + call torch_tensor_from_array(external_gradient, in_data3, tensor_layout, torch_kCPU) + call torch_tensor_backward(Q, external_gradient) call clean_up() write (*,*) "Autograd example ran successfully" @@ -70,6 +67,7 @@ subroutine clean_up() call torch_tensor_delete(a) call torch_tensor_delete(b) call torch_tensor_delete(Q) + call torch_tensor_delete(external_gradient) end subroutine clean_up end program example From be868464af5fe677e295fecca75bbff1621a5d17 Mon Sep 17 00:00:00 2001 From: Joe Wallwork Date: Mon, 20 Jan 2025 12:13:58 +0000 Subject: [PATCH 04/24] Setup requires_grad properly; Use TensorOptions in tensor constructors --- src/ctorch.cpp | 39 +++++++++++++++++++++++++++++---------- 1 file changed, 29 insertions(+), 10 deletions(-) diff --git a/src/ctorch.cpp b/src/ctorch.cpp index 5f1c6863..b35a761d 100644 --- a/src/ctorch.cpp +++ b/src/ctorch.cpp @@ -122,9 +122,12 @@ torch_tensor_t torch_empty(int ndim, const int64_t *shape, torch_data_t dtype, try { // This doesn't throw if shape and dimensions are incompatible c10::IntArrayRef vshape(shape, ndim); + auto options = torch::TensorOptions() + .dtype(get_libtorch_dtype(dtype)) + .device(get_libtorch_device(device_type, device_index)) + .requires_grad(requires_grad); tensor = new torch::Tensor; - *tensor = torch::empty(vshape, torch::dtype(get_libtorch_dtype(dtype))) - .to(get_libtorch_device(device_type, device_index)); + *tensor = torch::empty(vshape, options); } catch (const torch::Error &e) { std::cerr << "[ERROR]: " << e.msg() << std::endl; delete tensor; @@ -145,9 +148,12 @@ torch_tensor_t torch_zeros(int ndim, const int64_t *shape, torch_data_t dtype, try { // This doesn't throw if shape and dimensions are incompatible c10::IntArrayRef vshape(shape, ndim); + auto options = torch::TensorOptions() + .dtype(get_libtorch_dtype(dtype)) + .device(get_libtorch_device(device_type, device_index)) + .requires_grad(requires_grad); tensor = new torch::Tensor; - *tensor = torch::zeros(vshape, torch::dtype(get_libtorch_dtype(dtype))) - .to(get_libtorch_device(device_type, device_index)); + *tensor = torch::zeros(vshape, options); } catch (const torch::Error &e) { std::cerr << "[ERROR]: " << e.msg() << std::endl; delete tensor; @@ -168,9 +174,12 @@ torch_tensor_t torch_ones(int ndim, const int64_t *shape, torch_data_t dtype, try { // This doesn't throw if shape and dimensions are incompatible c10::IntArrayRef vshape(shape, ndim); + auto options = torch::TensorOptions() + .dtype(get_libtorch_dtype(dtype)) + .device(get_libtorch_device(device_type, device_index)) + .requires_grad(requires_grad); tensor = new torch::Tensor; - *tensor = torch::ones(vshape, torch::dtype(get_libtorch_dtype(dtype))) - .to(get_libtorch_device(device_type, device_index)); + *tensor = torch::ones(vshape, options); } catch (const torch::Error &e) { std::cerr << "[ERROR]: " << e.msg() << std::endl; delete tensor; @@ -196,11 +205,14 @@ torch_tensor_t torch_from_blob(void *data, int ndim, const int64_t *shape, // This doesn't throw if shape and dimensions are incompatible c10::IntArrayRef vshape(shape, ndim); c10::IntArrayRef vstrides(strides, ndim); + auto options = torch::TensorOptions() + .dtype(get_libtorch_dtype(dtype)) + .device(get_libtorch_device(device_type, device_index)) + .requires_grad(requires_grad); tensor = new torch::Tensor; - *tensor = torch::from_blob(data, vshape, vstrides, - torch::dtype(get_libtorch_dtype(dtype))) - .to(get_libtorch_device(device_type, device_index)); + *tensor = 
torch::from_blob(data, vshape, vstrides, options); + std::cout << "[DEBUG]: blob " << tensor->requires_grad() << std::endl; // TODO } catch (const torch::Error &e) { std::cerr << "[ERROR]: " << e.msg() << std::endl; delete tensor; @@ -310,7 +322,8 @@ torch_tensor_t torch_tensor_assign(const torch_tensor_t input) { torch::AutoGradMode enable_grad(in->requires_grad()); torch::Tensor *output = nullptr; output = new torch::Tensor; - *output = in->detach().clone(); + *output = *in; + std::cout << "[DEBUG]: assign " << output->requires_grad() << std::endl; // TODO return output; } @@ -321,6 +334,7 @@ torch_tensor_t torch_tensor_add(const torch_tensor_t tensor1, torch::Tensor *output = nullptr; output = new torch::Tensor; *output = *t1 + *t2; + std::cout << "[DEBUG]: add " << output->requires_grad() << std::endl; // TODO return output; } @@ -339,6 +353,7 @@ torch_tensor_t torch_tensor_subtract(const torch_tensor_t tensor1, torch::Tensor *output = nullptr; output = new torch::Tensor; *output = *t1 - *t2; + std::cout << "[DEBUG]: subtract " << output->requires_grad() << std::endl; // TODO return output; } @@ -349,6 +364,7 @@ torch_tensor_t torch_tensor_multiply(const torch_tensor_t tensor1, torch::Tensor *output = nullptr; output = new torch::Tensor; *output = *t1 * *t2; + std::cout << "[DEBUG]: multiply " << output->requires_grad() << std::endl; // TODO return output; } @@ -359,6 +375,7 @@ torch_tensor_t torch_tensor_divide(const torch_tensor_t tensor1, torch::Tensor *output = nullptr; output = new torch::Tensor; *output = *t1 / *t2; + std::cout << "[DEBUG]: divide " << output->requires_grad() << std::endl; // TODO return output; } @@ -370,6 +387,7 @@ torch_tensor_t torch_tensor_power_int(const torch_tensor_t tensor, torch::Tensor *output = nullptr; output = new torch::Tensor; *output = pow(*t, *exp); + std::cout << "[DEBUG]: power_int " << output->requires_grad() << std::endl; // TODO return output; } @@ -381,6 +399,7 @@ torch_tensor_t torch_tensor_power_float(const torch_tensor_t tensor, torch::Tensor *output = nullptr; output = new torch::Tensor; *output = pow(*t, *exp); + std::cout << "[DEBUG]: power_float " << output->requires_grad() << std::endl; // TODO return output; } From 10d2499b6b4cd7b200a80a3a5492796ad96c0ec6 Mon Sep 17 00:00:00 2001 From: Joe Wallwork Date: Mon, 20 Jan 2025 15:23:01 +0000 Subject: [PATCH 05/24] Implement get_gradient --- src/ctorch.cpp | 8 ++++++++ src/ctorch.h | 7 +++++++ src/ftorch.F90 | 18 ++++++++++++++++++ src/ftorch.fypp | 18 ++++++++++++++++++ 4 files changed, 51 insertions(+) diff --git a/src/ctorch.cpp b/src/ctorch.cpp index b35a761d..1267be51 100644 --- a/src/ctorch.cpp +++ b/src/ctorch.cpp @@ -414,6 +414,14 @@ void torch_tensor_backward(const torch_tensor_t tensor, t->backward(*g); } +EXPORT_C torch_tensor_t get_gradient(const torch_tensor_t tensor) { + auto t = reinterpret_cast(tensor); + torch::Tensor *output = nullptr; + output = new torch::Tensor; + *output = t->grad(); + return output; +} + // ============================================================================= // --- Torch model API // ============================================================================= diff --git a/src/ctorch.h b/src/ctorch.h index 0a9ed223..6f39b793 100644 --- a/src/ctorch.h +++ b/src/ctorch.h @@ -255,6 +255,13 @@ EXPORT_C torch_tensor_t torch_tensor_power_float(const torch_tensor_t tensor, EXPORT_C void torch_tensor_backward(const torch_tensor_t tensor, const torch_tensor_t external_gradient); +/** + * Function to return the grad attribute of a Torch Tensor. 
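 * (Sketch: via the Fortran binding added below, the intended usage is
 *
 *      dQda = get_gradient(a)
 *
 *  after torch_tensor_backward has been called on an expression involving a;
 *  the result wraps the tensor's accumulated grad attribute.)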
+ * @param Tensor to get the gradient of
+ * @return Tensor for the gradient
+ */
+EXPORT_C torch_tensor_t get_gradient(const torch_tensor_t tensor);
+
 // =============================================================================
 // --- Torch model API
 // =============================================================================
diff --git a/src/ftorch.F90 b/src/ftorch.F90
index eceedbbe..4f77a0ce 100644
--- a/src/ftorch.F90
+++ b/src/ftorch.F90
@@ -2919,6 +2919,24 @@ end subroutine torch_tensor_backward_c
     call torch_tensor_backward_c(tensor%p, external_gradient%p)
   end subroutine torch_tensor_backward
 
+  !> Retrieves the gradient of a Torch Tensor.
+  function get_gradient(tensor) result(gradient)
+    type(torch_tensor), intent(in) :: tensor
+    type(torch_tensor) :: gradient
+
+    interface
+      function get_gradient_c(tensor_c) result(gradient_c) &
+          bind(c, name = 'get_gradient')
+        use, intrinsic :: iso_c_binding, only : c_ptr
+        implicit none
+        type(c_ptr), value, intent(in) :: tensor_c
+        type(c_ptr) :: gradient_c
+      end function get_gradient_c
+    end interface
+
+    gradient%p = get_gradient_c(tensor%p)
+  end function get_gradient
+
   ! ============================================================================
   ! --- Torch Model API
   ! ============================================================================
diff --git a/src/ftorch.fypp b/src/ftorch.fypp
index b9c564d4..075e96bc 100644
--- a/src/ftorch.fypp
+++ b/src/ftorch.fypp
@@ -854,6 +854,24 @@ contains
     call torch_tensor_backward_c(tensor%p, external_gradient%p)
   end subroutine torch_tensor_backward
 
+  !> Retrieves the gradient of a Torch Tensor.
+  function get_gradient(tensor) result(gradient)
+    type(torch_tensor), intent(in) :: tensor
+    type(torch_tensor) :: gradient
+
+    interface
+      function get_gradient_c(tensor_c) result(gradient_c) &
+          bind(c, name = 'get_gradient')
+        use, intrinsic :: iso_c_binding, only : c_ptr
+        implicit none
+        type(c_ptr), value, intent(in) :: tensor_c
+        type(c_ptr) :: gradient_c
+      end function get_gradient_c
+    end interface
+
+    gradient%p = get_gradient_c(tensor%p)
+  end function get_gradient
+
   ! ============================================================================
   ! --- Torch Model API
   ! ============================================================================

From b124da0d8f96bcb6749de81cb9324d5b808710d1 Mon Sep 17 00:00:00 2001
From: Joe Wallwork
Date: Mon, 20 Jan 2025 15:23:25 +0000
Subject: [PATCH 06/24] Finish autograd example

---
 examples/6_Autograd/autograd.f90 | 39 +++++++++++++++++++++++++-------
 1 file changed, 31 insertions(+), 8 deletions(-)

diff --git a/examples/6_Autograd/autograd.f90 b/examples/6_Autograd/autograd.f90
index 2df0630c..cc0b266b 100644
--- a/examples/6_Autograd/autograd.f90
+++ b/examples/6_Autograd/autograd.f90
@@ -5,7 +5,7 @@ program example
 
   ! Import our library for interfacing with PyTorch's Autograd module
   use ftorch, only: assignment(=), operator(+), operator(-), operator(*), operator(/), &
-                    operator(**), torch_kCPU, torch_tensor, torch_tensor_backward, &
+                    operator(**), get_gradient, torch_kCPU, torch_tensor, torch_tensor_backward, &
                     torch_tensor_delete, torch_tensor_from_array, torch_tensor_to_array
 
   ! Import our tools module for testing utils
@@ -17,12 +17,12 @@ program example
 
  ! 
Set up Fortran data structures
   integer, parameter :: n = 2
   real(wp), dimension(n), target :: in_data1, in_data2, in_data3
-  real(wp), dimension(:), pointer :: out_data
+  real(wp), dimension(:), pointer :: out_data1, out_data2, out_data3
   real(wp), dimension(n) :: expected
   integer :: tensor_layout(1) = [1]
 
   ! Set up Torch data structures
-  type(torch_tensor) :: a, b, Q, external_gradient
+  type(torch_tensor) :: a, b, Q, external_gradient, dQda, dQdb
 
   ! Initialise input arrays as in Python example
   in_data1(:) = [2.0_wp, 3.0_wp]
   in_data2(:) = [6.0_wp, 4.0_wp]
@@ -38,16 +38,17 @@ program example
   write (*,*) "a = ", in_data1(:)
   write (*,*) "b = ", in_data2(:)
   Q = 3 * (a**3 - b * b / 3)
+  ! FIXME: Something seems off with gradients related to scalar multiplication and/or division
 
   ! Extract a Fortran array from a Torch tensor
-  call torch_tensor_to_array(Q, out_data, shape(in_data1))
-  write (*,*) "Q = 3 * (a ** 3 - b * b / 2) =", out_data(:)
+  call torch_tensor_to_array(Q, out_data1, shape(in_data1))
+  write (*,*) "Q = 3 * (a ** 3 - b * b / 3) =", out_data1(:)
 
   ! Check output tensor matches expected value
   expected(:) = [-12.0_wp, 65.0_wp]
-  if (.not. assert_allclose(out_data, expected, test_name="torch_tensor_to_array")) then
+  if (.not. assert_allclose(out_data1, expected, test_name="autograd_Q")) then
     call clean_up()
-    print *, "Error :: out_data does not match expected value"
+    print *, "Error :: value of Q does not match expected value"
     stop 999
   end if
 
   ! Back-propagation
   in_data3(:) = [1.0_wp, 1.0_wp]
   call torch_tensor_from_array(external_gradient, in_data3, tensor_layout, torch_kCPU)
   call torch_tensor_backward(Q, external_gradient)
+  dQda = get_gradient(a)
+  dQdb = get_gradient(b)
+
+  ! Extract Fortran arrays from the Torch tensors and check the gradients take expected values
+  call torch_tensor_to_array(dQda, out_data2, shape(in_data1))
+  print *, "dQda", out_data2
+  expected(:) = [36.0_wp, 81.0_wp]
+  if (.not. assert_allclose(out_data2, expected, test_name="autograd_dQda")) then
+    call clean_up()
+    print *, "Error :: value of dQda does not match expected value"
+    stop 999
+  end if
+  call torch_tensor_to_array(dQdb, out_data3, shape(in_data1))
+  print *, "dQdb", out_data3
+  expected(:) = [-12.0_wp, -8.0_wp]
+  if (.not. assert_allclose(out_data3, expected, test_name="autograd_dQdb")) then
+    call clean_up()
+    print *, "Error :: value of dQdb does not match expected value"
+    stop 999
+  end if
 
   call clean_up()
   write (*,*) "Autograd example ran successfully"
 
 contains
 
  ! 
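  ! (For reference: the expected values checked above follow from
  !  Q = 3 * (a**3 - b*b/3) = 3*a**3 - b**2, giving dQ/da = 9*a**2 = [36, 81]
  !  and dQ/db = -2*b = [-12, -8] at a = [2, 3] and b = [6, 4].)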
Subroutine for freeing memory and nullifying pointers used in the example subroutine clean_up() - nullify(out_data) + nullify(out_data1) + nullify(out_data2) + nullify(out_data3) call torch_tensor_delete(a) call torch_tensor_delete(b) call torch_tensor_delete(Q) From 586de97543e36c0749bcf0fac8230cafd93d0ae7 Mon Sep 17 00:00:00 2001 From: Joe Wallwork Date: Mon, 20 Jan 2025 16:14:34 +0000 Subject: [PATCH 07/24] Unit test for gradient of assignment --- test/unit/CMakeLists.txt | 2 + ...test_tensor_operator_overloads_autograd.pf | 131 ++++++++++++++++++ 2 files changed, 133 insertions(+) create mode 100644 test/unit/test_tensor_operator_overloads_autograd.pf diff --git a/test/unit/CMakeLists.txt b/test/unit/CMakeLists.txt index fc5705b0..410e5b9a 100644 --- a/test/unit/CMakeLists.txt +++ b/test/unit/CMakeLists.txt @@ -14,6 +14,8 @@ add_pfunit_ctest(test_tensor_interrogation TEST_SOURCES test_tensor_interrogation.pf LINK_LIBRARIES FTorch::ftorch) add_pfunit_ctest(test_operator_overloads TEST_SOURCES test_tensor_operator_overloads.pf LINK_LIBRARIES FTorch::ftorch) +add_pfunit_ctest(test_operator_overloads_autograd + TEST_SOURCES test_tensor_operator_overloads_autograd.pf LINK_LIBRARIES FTorch::ftorch) if(ENABLE_CUDA) check_language(CUDA) diff --git a/test/unit/test_tensor_operator_overloads_autograd.pf b/test/unit/test_tensor_operator_overloads_autograd.pf new file mode 100644 index 00000000..77144881 --- /dev/null +++ b/test/unit/test_tensor_operator_overloads_autograd.pf @@ -0,0 +1,131 @@ +!| Unit tests for FTorch's automatic differentiation of overloaded operators involving tensors. +! +! * License +! FTorch is released under an MIT license. +! See the [LICENSE](https://github.com/Cambridge-ICCS/FTorch/blob/main/LICENSE) +! file for details. +module test_tensor_operator_overloads_autograd + use funit + use ftorch, only: assignment(=), get_gradient, ftorch_int, torch_kCPU, torch_kFloat32, & + torch_tensor, torch_tensor_backward, torch_tensor_delete, torch_tensor_empty, & + torch_tensor_from_array, torch_tensor_ones, torch_tensor_to_array + use ftorch_test_utils, only: assert_allclose + use, intrinsic :: iso_c_binding, only : c_associated, c_int64_t + + implicit none + + public + + integer, parameter :: device_type = torch_kCPU + + ! Typedef holding a set of parameter values + @testParameter + type, extends(AbstractTestParameter) :: TestParametersType + logical :: switch + contains + procedure :: toString + end type TestParametersType + + ! Typedef for a test case with a particular set of parameters + @testCase(constructor=test_case_ctor) + type, extends (ParameterizedTestCase) :: TestCaseType + type(TestParametersType) :: param + end type TestCaseType + +contains + + ! A fixture comprised of a full list of parameter sets + function get_parameters_full() result(params) + type(TestParametersType), allocatable :: params(:) + params = [ & + TestParametersType(.false.), & + TestParametersType(.true.) & + ] + end function get_parameters_full + + ! A fixture comprised of a short list of parameter sets + function get_parameters_short() result(params) + type(TestParametersType), allocatable :: params(:) + params = [TestParametersType(.false.)] + end function get_parameters_short + + ! Constructor for the test case type + function test_case_ctor(param) + type(TestCaseType) :: test_case_ctor + type(TestParametersType) :: param + test_case_ctor%param = param + end function test_case_ctor + + ! 
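  ! (Note: the boolean switch parameter is interpreted per test below, e.g.
  !  selecting between integer and floating-point exponents for the power
  !  operator, or between the two operand orders for scalar multiplication.)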
Function for representing a parameter set as a string + function toString(this) result(string) + class(TestParametersType), intent(in) :: this + character(:), allocatable :: string + character(len=1) :: str + write(str,'(l1)') this%switch + string = str + end function toString + + @test(testParameters={get_parameters_short()}) + subroutine test_torch_tensor_assign(this) + use, intrinsic :: iso_fortran_env, only: sp => real32 + + implicit none + + ! Set working precision for reals + integer, parameter :: wp = sp + + class(TestCaseType), intent(inout) :: this + type(torch_tensor) :: Q, a, external_gradient, dQda + integer, parameter :: ndims = 2 + integer(ftorch_int), parameter :: tensor_layout(ndims) = [1, 2] + integer(c_int64_t), parameter, dimension(ndims) :: tensor_shape = [2, 3] + integer, parameter :: dtype = torch_kFloat32 + real(wp), dimension(2,3), target :: in_data + real(wp), dimension(:,:), pointer :: out_data + real(wp), dimension(2,3) :: expected + + ! Create an arbitrary input array + in_data(:,:) = reshape([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [2, 3]) + + ! Create a tensor based off an input array + call torch_tensor_from_array(a, in_data, tensor_layout, device_type, requires_grad=.true.) + + ! Create another empty tensor and assign it to the first using the overloaded assignment + ! operator + call torch_tensor_empty(Q, ndims, tensor_shape, dtype, device_type) + Q = a + + ! Apply back-propagation + ! TODO: Automate choice of ones for external_gradient + call torch_tensor_ones(external_gradient, ndims, tensor_shape, dtype, device_type) + call torch_tensor_backward(Q, external_gradient) + dQda = get_gradient(a) + + ! Extract Fortran array from the computed gradient and its data with the expected value: + ! Q(a) = a => dQ/da = 1 + call torch_tensor_to_array(dQda, out_data, shape(in_data)) + expected(:,:) = 1.0 + if (.not. assert_allclose(out_data, expected, test_name="test_torch_tensor_assign")) then + call clean_up() + print *, "Error :: incorrect gradient for assignment" + stop 999 + end if + + call clean_up() + + contains + + ! Subroutine for freeing memory and nullifying pointers used in the unit test + subroutine clean_up() + nullify(out_data) + call torch_tensor_delete(a) + call torch_tensor_delete(Q) + call torch_tensor_delete(dQda) + call torch_tensor_delete(external_gradient) + end subroutine clean_up + + end subroutine test_torch_tensor_assign + + ! TODO: Other operators + +end module test_tensor_operator_overloads_autograd From 9cc506fb60f643eb889bd4cdb1e159b483ac953c Mon Sep 17 00:00:00 2001 From: Joe Wallwork Date: Mon, 20 Jan 2025 17:52:47 +0000 Subject: [PATCH 08/24] Unit test for gradient of addition --- ...test_tensor_operator_overloads_autograd.pf | 77 +++++++++++++++++++ 1 file changed, 77 insertions(+) diff --git a/test/unit/test_tensor_operator_overloads_autograd.pf b/test/unit/test_tensor_operator_overloads_autograd.pf index 77144881..dbbef3cd 100644 --- a/test/unit/test_tensor_operator_overloads_autograd.pf +++ b/test/unit/test_tensor_operator_overloads_autograd.pf @@ -126,6 +126,83 @@ contains end subroutine test_torch_tensor_assign + @test(testParameters={get_parameters_short()}) + subroutine test_torch_tensor_add(this) + use, intrinsic :: iso_fortran_env, only: sp => real32 + use ftorch, only: operator(+) + + implicit none + + ! 
Set working precision for reals + integer, parameter :: wp = sp + + class(TestCaseType), intent(inout) :: this + type(torch_tensor) :: Q, a, b, external_gradient, dQda, dQdb + integer, parameter :: ndims = 2 + integer(ftorch_int), parameter :: tensor_layout(ndims) = [1, 2] + integer(c_int64_t), parameter, dimension(ndims) :: tensor_shape = [2, 3] + integer, parameter :: dtype = torch_kFloat32 + real(wp), dimension(2,3), target :: in_data1, in_data2 + real(wp), dimension(:,:), pointer :: out_data + real(wp), dimension(2,3) :: expected + + ! Create an arbitrary input array + in_data1(:,:) = reshape([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [2, 3]) + in_data2(:,:) = reshape([7.0, 8.0, 9.0, 10.0, 11.0, 12.0], [2, 3]) + + ! Create tensor based off input arrays + call torch_tensor_from_array(a, in_data1, tensor_layout, device_type, requires_grad=.true.) + call torch_tensor_from_array(b, in_data2, tensor_layout, device_type, requires_grad=.true.) + + ! Create another empty tensor and assign it to the sum of the first two using the overloaded + ! addition operator + call torch_tensor_empty(Q, ndims, tensor_shape, dtype, device_type) + Q = a + b + + ! Apply back-propagation + ! TODO: Automate choice of ones for external_gradient + call torch_tensor_ones(external_gradient, ndims, tensor_shape, dtype, device_type) + call torch_tensor_backward(Q, external_gradient) + dQda = get_gradient(a) + dQdb = get_gradient(b) + + ! Extract Fortran array from the first computed gradient and its data with the expected value: + ! Q(a,b) = a + b => dQ/da = 1 + call torch_tensor_to_array(dQda, out_data, shape(in_data1)) + expected(:,:) = 1.0 + if (.not. assert_allclose(out_data, expected, test_name="test_torch_tensor_add1")) then + call clean_up() + print *, "Error :: incorrect gradient w.r.t. first input for addition" + stop 999 + end if + + ! Extract Fortran array from the second computed gradient and its data with the expected value: + ! Q(a,b) = a + b => dQ/db = 1 + call torch_tensor_to_array(dQdb, out_data, shape(in_data2)) + expected(:,:) = 1.0 + if (.not. assert_allclose(out_data, expected, test_name="test_torch_tensor_add2")) then + call clean_up() + print *, "Error :: incorrect gradient w.r.t. second input for addition" + stop 999 + end if + + call clean_up() + + contains + + ! Subroutine for freeing memory and nullifying pointers used in the unit test + subroutine clean_up() + nullify(out_data) + call torch_tensor_delete(a) + call torch_tensor_delete(b) + call torch_tensor_delete(Q) + call torch_tensor_delete(dQda) + call torch_tensor_delete(dQdb) + call torch_tensor_delete(external_gradient) + end subroutine clean_up + + end subroutine test_torch_tensor_add + ! 
TODO: Other operators end module test_tensor_operator_overloads_autograd From 13d28e8beae04d7f6e27067a2175ba6fb5dee05e Mon Sep 17 00:00:00 2001 From: Joe Wallwork Date: Tue, 21 Jan 2025 16:19:33 +0000 Subject: [PATCH 09/24] Unit test for gradient of subtraction --- ...test_tensor_operator_overloads_autograd.pf | 77 +++++++++++++++++++ 1 file changed, 77 insertions(+) diff --git a/test/unit/test_tensor_operator_overloads_autograd.pf b/test/unit/test_tensor_operator_overloads_autograd.pf index dbbef3cd..20be2211 100644 --- a/test/unit/test_tensor_operator_overloads_autograd.pf +++ b/test/unit/test_tensor_operator_overloads_autograd.pf @@ -203,6 +203,83 @@ contains end subroutine test_torch_tensor_add + @test(testParameters={get_parameters_short()}) + subroutine test_torch_tensor_subtract(this) + use, intrinsic :: iso_fortran_env, only: sp => real32 + use ftorch, only: operator(-) + + implicit none + + ! Set working precision for reals + integer, parameter :: wp = sp + + class(TestCaseType), intent(inout) :: this + type(torch_tensor) :: Q, a, b, external_gradient, dQda, dQdb + integer, parameter :: ndims = 2 + integer(ftorch_int), parameter :: tensor_layout(ndims) = [1, 2] + integer(c_int64_t), parameter, dimension(ndims) :: tensor_shape = [2, 3] + integer, parameter :: dtype = torch_kFloat32 + real(wp), dimension(2,3), target :: in_data1, in_data2 + real(wp), dimension(:,:), pointer :: out_data + real(wp), dimension(2,3) :: expected + + ! Create an arbitrary input array + in_data1(:,:) = reshape([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [2, 3]) + in_data2(:,:) = reshape([7.0, 8.0, 9.0, 10.0, 11.0, 12.0], [2, 3]) + + ! Create tensor based off input arrays + call torch_tensor_from_array(a, in_data1, tensor_layout, device_type, requires_grad=.true.) + call torch_tensor_from_array(b, in_data2, tensor_layout, device_type, requires_grad=.true.) + + ! Create another empty tensor and assign it to the difference of the first two using the + ! overloaded subtraction operator + call torch_tensor_empty(Q, ndims, tensor_shape, dtype, device_type) + Q = a - b + + ! Apply back-propagation + ! TODO: Automate choice of ones for external_gradient + call torch_tensor_ones(external_gradient, ndims, tensor_shape, dtype, device_type) + call torch_tensor_backward(Q, external_gradient) + dQda = get_gradient(a) + dQdb = get_gradient(b) + + ! Extract Fortran array from the first computed gradient and its data with the expected value: + ! Q(a,b) = a - b => dQ/da = 1 + call torch_tensor_to_array(dQda, out_data, shape(in_data1)) + expected(:,:) = 1.0 + if (.not. assert_allclose(out_data, expected, test_name="test_torch_tensor_subtract1")) then + call clean_up() + print *, "Error :: incorrect gradient w.r.t. first input for subtraction" + stop 999 + end if + + ! Extract Fortran array from the second computed gradient and its data with the expected value: + ! Q(a,b) = a - b => dQ/db = -1 + call torch_tensor_to_array(dQdb, out_data, shape(in_data2)) + expected(:,:) = -1.0 + if (.not. assert_allclose(out_data, expected, test_name="test_torch_tensor_subtract2")) then + call clean_up() + print *, "Error :: incorrect gradient w.r.t. second input for subtraction" + stop 999 + end if + + call clean_up() + + contains + + ! 
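    ! (Note: seeding backward with a tensor of ones, as these tests do,
    !  computes the vector-Jacobian product with v = 1, i.e. the same
    !  gradients PyTorch would produce for Q.sum().backward(), so plain
    !  elementwise derivatives can be checked.)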
Subroutine for freeing memory and nullifying pointers used in the unit test + subroutine clean_up() + nullify(out_data) + call torch_tensor_delete(a) + call torch_tensor_delete(b) + call torch_tensor_delete(Q) + call torch_tensor_delete(dQda) + call torch_tensor_delete(dQdb) + call torch_tensor_delete(external_gradient) + end subroutine clean_up + + end subroutine test_torch_tensor_subtract + ! TODO: Other operators end module test_tensor_operator_overloads_autograd From 13d8e40bf21c0e550efb6d22e990cb1c88c2bdb6 Mon Sep 17 00:00:00 2001 From: Joe Wallwork Date: Tue, 21 Jan 2025 16:22:48 +0000 Subject: [PATCH 10/24] Unit test for gradient of negative --- ...test_tensor_operator_overloads_autograd.pf | 61 +++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/test/unit/test_tensor_operator_overloads_autograd.pf b/test/unit/test_tensor_operator_overloads_autograd.pf index 20be2211..b95ca53e 100644 --- a/test/unit/test_tensor_operator_overloads_autograd.pf +++ b/test/unit/test_tensor_operator_overloads_autograd.pf @@ -203,6 +203,67 @@ contains end subroutine test_torch_tensor_add + @test(testParameters={get_parameters_short()}) + subroutine test_torch_tensor_negative(this) + use, intrinsic :: iso_fortran_env, only: sp => real32 + use ftorch, only: operator(-) + + implicit none + + ! Set working precision for reals + integer, parameter :: wp = sp + + class(TestCaseType), intent(inout) :: this + type(torch_tensor) :: Q, a, external_gradient, dQda + integer, parameter :: ndims = 2 + integer(ftorch_int), parameter :: tensor_layout(ndims) = [1, 2] + integer(c_int64_t), parameter, dimension(ndims) :: tensor_shape = [2, 3] + integer, parameter :: dtype = torch_kFloat32 + real(wp), dimension(2,3), target :: in_data + real(wp), dimension(:,:), pointer :: out_data + real(wp), dimension(2,3) :: expected + + ! Create an arbitrary input array + in_data(:,:) = reshape([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [2, 3]) + + ! Create a tensor based off the input array + call torch_tensor_from_array(a, in_data, tensor_layout, device_type, requires_grad=.true.) + + ! Create another empty tensor and assign it to the negative of the first one using the + ! overloaded negation operator + call torch_tensor_empty(Q, ndims, tensor_shape, dtype, device_type) + Q = -a + + ! Apply back-propagation + ! TODO: Automate choice of ones for external_gradient + call torch_tensor_ones(external_gradient, ndims, tensor_shape, dtype, device_type) + call torch_tensor_backward(Q, external_gradient) + dQda = get_gradient(a) + + ! Extract Fortran array from the computed gradient and its data with the expected value: + ! Q(a) = a => dQ/da = -1 + call torch_tensor_to_array(dQda, out_data, shape(in_data)) + expected(:,:) = -1.0 + if (.not. assert_allclose(out_data, expected, test_name="test_torch_tensor_negative")) then + call clean_up() + print *, "Error :: incorrect gradient for negation" + stop 999 + end if + + call clean_up() + + contains + + ! 
Subroutine for freeing memory and nullifying pointers used in the unit test + subroutine clean_up() + nullify(out_data) + call torch_tensor_delete(a) + call torch_tensor_delete(Q) + call torch_tensor_delete(dQda) + end subroutine clean_up + + end subroutine test_torch_tensor_negative + @test(testParameters={get_parameters_short()}) subroutine test_torch_tensor_subtract(this) use, intrinsic :: iso_fortran_env, only: sp => real32 From d8f40b286702861441eafec2e5483d2eef33ce71 Mon Sep 17 00:00:00 2001 From: Joe Wallwork Date: Tue, 21 Jan 2025 16:24:42 +0000 Subject: [PATCH 11/24] Unit test for gradient of multiplication --- ...test_tensor_operator_overloads_autograd.pf | 77 +++++++++++++++++++ 1 file changed, 77 insertions(+) diff --git a/test/unit/test_tensor_operator_overloads_autograd.pf b/test/unit/test_tensor_operator_overloads_autograd.pf index b95ca53e..afa12125 100644 --- a/test/unit/test_tensor_operator_overloads_autograd.pf +++ b/test/unit/test_tensor_operator_overloads_autograd.pf @@ -341,6 +341,83 @@ contains end subroutine test_torch_tensor_subtract + @test(testParameters={get_parameters_short()}) + subroutine test_torch_tensor_multiply(this) + use, intrinsic :: iso_fortran_env, only: sp => real32 + use ftorch, only: operator(*) + + implicit none + + ! Set working precision for reals + integer, parameter :: wp = sp + + class(TestCaseType), intent(inout) :: this + type(torch_tensor) :: Q, a, b, external_gradient, dQda, dQdb + integer, parameter :: ndims = 2 + integer(ftorch_int), parameter :: tensor_layout(ndims) = [1, 2] + integer(c_int64_t), parameter, dimension(ndims) :: tensor_shape = [2, 3] + integer, parameter :: dtype = torch_kFloat32 + real(wp), dimension(2,3), target :: in_data1, in_data2 + real(wp), dimension(:,:), pointer :: out_data + real(wp), dimension(2,3) :: expected + + ! Create an arbitrary input array + in_data1(:,:) = reshape([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [2, 3]) + in_data2(:,:) = reshape([7.0, 8.0, 9.0, 10.0, 11.0, 12.0], [2, 3]) + + ! Create tensor based off input arrays + call torch_tensor_from_array(a, in_data1, tensor_layout, device_type, requires_grad=.true.) + call torch_tensor_from_array(b, in_data2, tensor_layout, device_type, requires_grad=.true.) + + ! Create another empty tensor and assign it to the product of the first two using the + ! overloaded multiplication operator + call torch_tensor_empty(Q, ndims, tensor_shape, dtype, device_type) + Q = a * b + + ! Apply back-propagation + ! TODO: Automate choice of ones for external_gradient + call torch_tensor_ones(external_gradient, ndims, tensor_shape, dtype, device_type) + call torch_tensor_backward(Q, external_gradient) + dQda = get_gradient(a) + dQdb = get_gradient(b) + + ! Extract Fortran array from the first computed gradient and its data with the expected value: + ! Q(a,b) = a * b => dQ/da = b + call torch_tensor_to_array(dQda, out_data, shape(in_data1)) + expected(:,:) = in_data2 + if (.not. assert_allclose(out_data, expected, test_name="test_torch_tensor_multiply1")) then + call clean_up() + print *, "Error :: incorrect gradient w.r.t. first input for multiplication" + stop 999 + end if + + ! Extract Fortran array from the second computed gradient and its data with the expected value: + ! Q(a,b) = a * b => dQ/db = a + call torch_tensor_to_array(dQdb, out_data, shape(in_data2)) + expected(:,:) = in_data1 + if (.not. assert_allclose(out_data, expected, test_name="test_torch_tensor_multiply2")) then + call clean_up() + print *, "Error :: incorrect gradient w.r.t. 
second input for multiplication" + stop 999 + end if + + call clean_up() + + contains + + ! Subroutine for freeing memory and nullifying pointers used in the unit test + subroutine clean_up() + nullify(out_data) + call torch_tensor_delete(a) + call torch_tensor_delete(b) + call torch_tensor_delete(Q) + call torch_tensor_delete(dQda) + call torch_tensor_delete(dQdb) + call torch_tensor_delete(external_gradient) + end subroutine clean_up + + end subroutine test_torch_tensor_multiply + ! TODO: Other operators end module test_tensor_operator_overloads_autograd From abb303e9283314fabf52da70cdf6d587c30b2212 Mon Sep 17 00:00:00 2001 From: Joe Wallwork Date: Tue, 21 Jan 2025 16:44:33 +0000 Subject: [PATCH 12/24] Unit test for gradient of division --- ...test_tensor_operator_overloads_autograd.pf | 77 +++++++++++++++++++ 1 file changed, 77 insertions(+) diff --git a/test/unit/test_tensor_operator_overloads_autograd.pf b/test/unit/test_tensor_operator_overloads_autograd.pf index afa12125..91f0c925 100644 --- a/test/unit/test_tensor_operator_overloads_autograd.pf +++ b/test/unit/test_tensor_operator_overloads_autograd.pf @@ -418,6 +418,83 @@ contains end subroutine test_torch_tensor_multiply + @test(testParameters={get_parameters_short()}) + subroutine test_torch_tensor_divide(this) + use, intrinsic :: iso_fortran_env, only: sp => real32 + use ftorch, only: operator(/) + + implicit none + + ! Set working precision for reals + integer, parameter :: wp = sp + + class(TestCaseType), intent(inout) :: this + type(torch_tensor) :: Q, a, b, external_gradient, dQda, dQdb + integer, parameter :: ndims = 2 + integer(ftorch_int), parameter :: tensor_layout(ndims) = [1, 2] + integer(c_int64_t), parameter, dimension(ndims) :: tensor_shape = [2, 3] + integer, parameter :: dtype = torch_kFloat32 + real(wp), dimension(2,3), target :: in_data1, in_data2 + real(wp), dimension(:,:), pointer :: out_data + real(wp), dimension(2,3) :: expected + + ! Create an arbitrary input array + in_data1(:,:) = reshape([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [2, 3]) + in_data2(:,:) = reshape([7.0, 8.0, 9.0, 10.0, 11.0, 12.0], [2, 3]) + + ! Create tensor based off input arrays + call torch_tensor_from_array(a, in_data1, tensor_layout, device_type, requires_grad=.true.) + call torch_tensor_from_array(b, in_data2, tensor_layout, device_type, requires_grad=.true.) + + ! Create another empty tensor and assign it to the quotient of the first two using the + ! overloaded division operator + call torch_tensor_empty(Q, ndims, tensor_shape, dtype, device_type) + Q = a / b + + ! Apply back-propagation + ! TODO: Automate choice of ones for external_gradient + call torch_tensor_ones(external_gradient, ndims, tensor_shape, dtype, device_type) + call torch_tensor_backward(Q, external_gradient) + dQda = get_gradient(a) + dQdb = get_gradient(b) + + ! Extract Fortran array from the first computed gradient and its data with the expected value: + ! Q(a,b) = a / b => dQ/da = 1 / b + call torch_tensor_to_array(dQda, out_data, shape(in_data1)) + expected(:,:) = 1.0 / in_data2 + if (.not. assert_allclose(out_data, expected, test_name="test_torch_tensor_divide1")) then + call clean_up() + print *, "Error :: incorrect gradient w.r.t. numerator for division" + stop 999 + end if + + ! Extract Fortran array from the second computed gradient and its data with the expected value: + ! Q(a,b) = a / b => dQ/db = -a / b^2 + call torch_tensor_to_array(dQdb, out_data, shape(in_data2)) + expected(:,:) = -in_data1 / in_data2 ** 2 + if (.not. 
assert_allclose(out_data, expected, test_name="test_torch_tensor_divide2")) then + call clean_up() + print *, "Error :: incorrect gradient w.r.t. denominator for division" + stop 999 + end if + + call clean_up() + + contains + + ! Subroutine for freeing memory and nullifying pointers used in the unit test + subroutine clean_up() + nullify(out_data) + call torch_tensor_delete(a) + call torch_tensor_delete(b) + call torch_tensor_delete(Q) + call torch_tensor_delete(dQda) + call torch_tensor_delete(dQdb) + call torch_tensor_delete(external_gradient) + end subroutine clean_up + + end subroutine test_torch_tensor_divide + ! TODO: Other operators end module test_tensor_operator_overloads_autograd From 0c9a59a66c88e0c614de3cfa53a9224adbb1c66e Mon Sep 17 00:00:00 2001 From: Joe Wallwork Date: Tue, 21 Jan 2025 16:54:14 +0000 Subject: [PATCH 13/24] Unit test for gradient of square --- ...test_tensor_operator_overloads_autograd.pf | 65 +++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/test/unit/test_tensor_operator_overloads_autograd.pf b/test/unit/test_tensor_operator_overloads_autograd.pf index 91f0c925..b059037b 100644 --- a/test/unit/test_tensor_operator_overloads_autograd.pf +++ b/test/unit/test_tensor_operator_overloads_autograd.pf @@ -497,4 +497,69 @@ contains ! TODO: Other operators + @test(testParameters={get_parameters_full()}) + subroutine test_torch_tensor_square(this) + use, intrinsic :: iso_fortran_env, only: sp => real32 + use ftorch, only: operator(**) + + implicit none + + ! Set working precision for reals + integer, parameter :: wp = sp + + class(TestCaseType), intent(inout) :: this + type(torch_tensor) :: Q, a, external_gradient, dQda + integer, parameter :: ndims = 2 + integer(ftorch_int), parameter :: tensor_layout(ndims) = [1, 2] + integer(c_int64_t), parameter, dimension(ndims) :: tensor_shape = [2, 3] + integer, parameter :: dtype = torch_kFloat32 + real(wp), dimension(2,3), target :: in_data + real(wp), dimension(:,:), pointer :: out_data + real(wp), dimension(2,3) :: expected + + ! Create an arbitrary input array + in_data(:,:) = reshape([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [2, 3]) + + ! Create a tensor based off the input array + call torch_tensor_from_array(a, in_data, tensor_layout, device_type, requires_grad=.true.) + + ! Create another empty tensor and assign it to the square of the first one using the + ! overloaded power operator + call torch_tensor_empty(Q, ndims, tensor_shape, dtype, device_type) + if (this%param%switch) then + Q = a ** 2.0 + else + Q = a ** 2 + end if + + ! Apply back-propagation + ! TODO: Automate choice of ones for external_gradient + call torch_tensor_ones(external_gradient, ndims, tensor_shape, dtype, device_type) + call torch_tensor_backward(Q, external_gradient) + dQda = get_gradient(a) + + ! Extract Fortran array from the computed gradient and its data with the expected value: + ! Q(a) = a^2 => dQ/da = 2 * a + call torch_tensor_to_array(dQda, out_data, shape(in_data)) + expected(:,:) = 2.0 * in_data + if (.not. assert_allclose(out_data, expected, test_name="test_torch_tensor_square")) then + call clean_up() + print *, "Error :: incorrect gradient for square" + stop 999 + end if + + call clean_up() + + contains + + ! 
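    ! (Note: the parameterised switch means this test exercises both exponent
    !  forms, a ** 2 and a ** 2.0, which presumably dispatch to the
    !  torch_tensor_power_int and torch_tensor_power_float bindings; both
    !  should give dQ/da = 2*a.)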
Subroutine for freeing memory and nullifying pointers used in the unit test + subroutine clean_up() + nullify(out_data) + call torch_tensor_delete(a) + call torch_tensor_delete(Q) + call torch_tensor_delete(dQda) + end subroutine clean_up + + end subroutine test_torch_tensor_square + end module test_tensor_operator_overloads_autograd From f68d647271941ade5c040d8a792c416c043e0b9d Mon Sep 17 00:00:00 2001 From: Joe Wallwork Date: Tue, 21 Jan 2025 16:58:36 +0000 Subject: [PATCH 14/24] Unit test for gradient of square root --- ...test_tensor_operator_overloads_autograd.pf | 61 +++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/test/unit/test_tensor_operator_overloads_autograd.pf b/test/unit/test_tensor_operator_overloads_autograd.pf index b059037b..ff97c735 100644 --- a/test/unit/test_tensor_operator_overloads_autograd.pf +++ b/test/unit/test_tensor_operator_overloads_autograd.pf @@ -562,4 +562,65 @@ contains end subroutine test_torch_tensor_square + @test(testParameters={get_parameters_short()}) + subroutine test_torch_tensor_sqrt(this) + use, intrinsic :: iso_fortran_env, only: sp => real32 + use ftorch, only: operator(**) + + implicit none + + ! Set working precision for reals + integer, parameter :: wp = sp + + class(TestCaseType), intent(inout) :: this + type(torch_tensor) :: Q, a, external_gradient, dQda + integer, parameter :: ndims = 2 + integer(ftorch_int), parameter :: tensor_layout(ndims) = [1, 2] + integer(c_int64_t), parameter, dimension(ndims) :: tensor_shape = [2, 3] + integer, parameter :: dtype = torch_kFloat32 + real(wp), dimension(2,3), target :: in_data + real(wp), dimension(:,:), pointer :: out_data + real(wp), dimension(2,3) :: expected + + ! Create an arbitrary input array + in_data(:,:) = reshape([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [2, 3]) + + ! Create a tensor based off the input array + call torch_tensor_from_array(a, in_data, tensor_layout, device_type, requires_grad=.true.) + + ! Create another empty tensor and assign it to the square root of the first one using the + ! overloaded power operator + call torch_tensor_empty(Q, ndims, tensor_shape, dtype, device_type) + Q = a ** 0.5 + + ! Apply back-propagation + ! TODO: Automate choice of ones for external_gradient + call torch_tensor_ones(external_gradient, ndims, tensor_shape, dtype, device_type) + call torch_tensor_backward(Q, external_gradient) + dQda = get_gradient(a) + + ! Extract Fortran array from the computed gradient and its data with the expected value: + ! Q(a) = a^{1/2} => dQ/da = 0.5 * a^{-1/2}) + call torch_tensor_to_array(dQda, out_data, shape(in_data)) + expected(:,:) = 0.5 / in_data ** 0.5 + if (.not. assert_allclose(out_data, expected, test_name="test_torch_tensor_sqrt")) then + call clean_up() + print *, "Error :: incorrect gradient for square root" + stop 999 + end if + + call clean_up() + + contains + + ! 
Subroutine for freeing memory and nullifying pointers used in the unit test + subroutine clean_up() + nullify(out_data) + call torch_tensor_delete(a) + call torch_tensor_delete(Q) + call torch_tensor_delete(dQda) + end subroutine clean_up + + end subroutine test_torch_tensor_sqrt + end module test_tensor_operator_overloads_autograd From f100c97c6ed16cfd073161832b639770c5c60413 Mon Sep 17 00:00:00 2001 From: Joe Wallwork Date: Tue, 21 Jan 2025 16:50:03 +0000 Subject: [PATCH 15/24] Unit test for gradient of scalar multiplication - FIXME --- test/unit/test_tensor_operator_overloads.pf | 4 +- ...test_tensor_operator_overloads_autograd.pf | 70 +++++++++++++++++++ 2 files changed, 72 insertions(+), 2 deletions(-) diff --git a/test/unit/test_tensor_operator_overloads.pf b/test/unit/test_tensor_operator_overloads.pf index 9d9968c2..0f32058d 100644 --- a/test/unit/test_tensor_operator_overloads.pf +++ b/test/unit/test_tensor_operator_overloads.pf @@ -433,8 +433,8 @@ contains ! Create a tensor based off the input array call torch_tensor_from_array(tensor1, in_data, tensor_layout, device_type) - ! Create another two empty tensors and assign them to the products of a scalar constant and the - ! first tensor using the overloaded multiplication operator (in each order) + ! Create another empty tensors and assign it to the product of a scalar constant and the first + ! tensor using the overloaded multiplication operator call torch_tensor_empty(tensor2, ndims, tensor_shape, dtype, device_type) if (this%param%switch) then tensor2 = scalar * tensor1 diff --git a/test/unit/test_tensor_operator_overloads_autograd.pf b/test/unit/test_tensor_operator_overloads_autograd.pf index ff97c735..4bf50a23 100644 --- a/test/unit/test_tensor_operator_overloads_autograd.pf +++ b/test/unit/test_tensor_operator_overloads_autograd.pf @@ -418,6 +418,76 @@ contains end subroutine test_torch_tensor_multiply + @test(testParameters={get_parameters_full()}) + subroutine test_torch_tensor_scalar_multiply(this) + use, intrinsic :: iso_fortran_env, only: sp => real32 + use ftorch, only: operator(*), torch_tensor_print + + implicit none + + ! Set working precision for reals + integer, parameter :: wp = sp + + class(TestCaseType), intent(inout) :: this + type(torch_tensor) :: Q, a, external_gradient, dQda + integer, parameter :: ndims = 2 + integer(ftorch_int), parameter :: tensor_layout(ndims) = [1, 2] + integer(c_int64_t), parameter, dimension(ndims) :: tensor_shape = [2, 3] + integer, parameter :: dtype = torch_kFloat32 + real(wp), parameter :: scalar = 3.14 + real(wp), dimension(2,3), target :: in_data + real(wp), dimension(:,:), pointer :: out_data + real(wp), dimension(2,3) :: expected + logical :: test_pass + + ! Create an arbitrary input array + in_data(:,:) = reshape([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [2, 3]) + + ! Create tensor based off input array + call torch_tensor_from_array(a, in_data, tensor_layout, device_type, requires_grad=.true.) + + ! Create another empty tensors and assign it to the product of a scalar constant and the first + ! tensor using the overloaded multiplication operator + call torch_tensor_empty(Q, ndims, tensor_shape, dtype, device_type) + if (this%param%switch) then + Q = scalar * a + else + Q = a * scalar + end if + + ! Apply back-propagation + ! TODO: Automate choice of ones for external_gradient + call torch_tensor_ones(external_gradient, ndims, tensor_shape, dtype, device_type) + call torch_tensor_backward(Q, external_gradient) + dQda = get_gradient(a) + + ! 
Extract Fortran array from the first computed gradient and its data with the expected value: + ! Q(a,b) = scalar * a => dQ/da = scalar + call torch_tensor_to_array(dQda, out_data, shape(in_data)) + call torch_tensor_print(dQda) ! TODO: Temp + expected(:,:) = scalar + test_pass = assert_allclose(out_data, expected, test_name="test_torch_tensor_scalar_multiply") + if (.not. test_pass) then + call clean_up() + print *, "Error :: incorrect gradient for scalar multiplication" + stop 999 + end if + + call clean_up() + + contains + + ! Subroutine for freeing memory and nullifying pointers used in the unit test + subroutine clean_up() + nullify(out_data) + call torch_tensor_delete(a) + call torch_tensor_delete(Q) + call torch_tensor_delete(dQda) + call torch_tensor_delete(external_gradient) + end subroutine clean_up + + end subroutine test_torch_tensor_scalar_multiply + @test(testParameters={get_parameters_short()}) subroutine test_torch_tensor_divide(this) use, intrinsic :: iso_fortran_env, only: sp => real32 From 55d38496bfe8b027ca3b92e9ab0ebc45cb4b65ee Mon Sep 17 00:00:00 2001 From: Joe Wallwork Date: Tue, 21 Jan 2025 16:49:37 +0000 Subject: [PATCH 16/24] Unit test for gradient of scalar division - FIXME --- ...test_tensor_operator_overloads_autograd.pf | 66 ++++++++++++++++++- 1 file changed, 65 insertions(+), 1 deletion(-) diff --git a/test/unit/test_tensor_operator_overloads_autograd.pf b/test/unit/test_tensor_operator_overloads_autograd.pf index 4bf50a23..842429ef 100644 --- a/test/unit/test_tensor_operator_overloads_autograd.pf +++ b/test/unit/test_tensor_operator_overloads_autograd.pf @@ -565,7 +565,71 @@ contains end subroutine test_torch_tensor_divide - ! TODO: Other operators + @test(testParameters={get_parameters_short()}) + subroutine test_torch_tensor_scalar_divide(this) + use, intrinsic :: iso_fortran_env, only: sp => real32 + use ftorch, only: operator(/), torch_tensor_print + + implicit none + + ! Set working precision for reals + integer, parameter :: wp = sp + + class(TestCaseType), intent(inout) :: this + type(torch_tensor) :: Q, a, external_gradient, dQda + integer, parameter :: ndims = 2 + integer(ftorch_int), parameter :: tensor_layout(ndims) = [1, 2] + integer(c_int64_t), parameter, dimension(ndims) :: tensor_shape = [2, 3] + integer, parameter :: dtype = torch_kFloat32 + real(wp), parameter :: scalar = 3.14 + real(wp), dimension(2,3), target :: in_data + real(wp), dimension(:,:), pointer :: out_data + real(wp), dimension(2,3) :: expected + logical :: test_pass + + ! Create an arbitrary input array + in_data(:,:) = reshape([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [2, 3]) + + ! Create tensor based off input array + call torch_tensor_from_array(a, in_data, tensor_layout, device_type, requires_grad=.true.) + + ! Create another empty tensors and assign it to the product of a scalar constant and the first + ! tensor using the overloaded multiplication operator + call torch_tensor_empty(Q, ndims, tensor_shape, dtype, device_type) + Q = a / scalar + + ! Apply back-propagation + ! TODO: Automate choice of ones for external_gradient + call torch_tensor_ones(external_gradient, ndims, tensor_shape, dtype, device_type) + call torch_tensor_backward(Q, external_gradient) + dQda = get_gradient(a) + + ! Extract Fortran array from the first computed gradient and its data with the expected value: + ! Q(a,b) = a / scalar => dQ/da = 1 / scalar + call torch_tensor_to_array(dQda, out_data, shape(in_data)) + call torch_tensor_print(dQda) ! 
TODO: Temp + expected(:,:) = 1.0 / scalar + test_pass = assert_allclose(out_data, expected, test_name="test_torch_tensor_scalar_divide") + if (.not. test_pass) then + call clean_up() + print *, "Error :: incorrect gradient for scalar division" + stop 999 + end if + + call clean_up() + + contains + + ! Subroutine for freeing memory and nullifying pointers used in the unit test + subroutine clean_up() + nullify(out_data) + call torch_tensor_delete(a) + call torch_tensor_delete(Q) + call torch_tensor_delete(dQda) + call torch_tensor_delete(external_gradient) + end subroutine clean_up + + end subroutine test_torch_tensor_scalar_divide @test(testParameters={get_parameters_full()}) subroutine test_torch_tensor_square(this) From 37e687380487e8da5986bc334b62eb63ebab76cf Mon Sep 17 00:00:00 2001 From: Joe Wallwork Date: Mon, 17 Feb 2025 12:06:57 +0000 Subject: [PATCH 17/24] Rename get_gradient and provide method --- examples/6_Autograd/autograd.f90 | 6 ++-- src/ctorch.cpp | 2 +- src/ctorch.h | 2 +- src/ftorch.F90 | 15 +++++----- src/ftorch.fypp | 15 +++++----- ...test_tensor_operator_overloads_autograd.pf | 30 +++++++++---------- 6 files changed, 36 insertions(+), 34 deletions(-) diff --git a/examples/6_Autograd/autograd.f90 b/examples/6_Autograd/autograd.f90 index cc0b266b..72a93ba1 100644 --- a/examples/6_Autograd/autograd.f90 +++ b/examples/6_Autograd/autograd.f90 @@ -5,7 +5,7 @@ program example ! Import our library for interfacing with PyTorch's Autograd module use ftorch, only: assignment(=), operator(+), operator(-), operator(*), operator(/), & - operator(**), get_gradient, torch_kCPU, torch_tensor, torch_tensor_backward, & + operator(**), torch_kCPU, torch_tensor, torch_tensor_backward, & torch_tensor_delete, torch_tensor_from_array, torch_tensor_to_array ! Import our tools module for testing utils @@ -56,8 +56,8 @@ program example in_data3(:) = [1.0_wp, 1.0_wp] call torch_tensor_from_array(external_gradient, in_data3, tensor_layout, torch_kCPU) call torch_tensor_backward(Q, external_gradient) - dQda = get_gradient(a) - dQdb = get_gradient(b) + dQda = a%grad() + dQdb = b%grad() ! 
Extract Fortran arrays from the Torch tensors and check the gradients take expected values call torch_tensor_to_array(dQda, out_data2, shape(in_data1)) diff --git a/src/ctorch.cpp b/src/ctorch.cpp index 1267be51..9ffe993d 100644 --- a/src/ctorch.cpp +++ b/src/ctorch.cpp @@ -414,7 +414,7 @@ void torch_tensor_backward(const torch_tensor_t tensor, t->backward(*g); } -EXPORT_C torch_tensor_t get_gradient(const torch_tensor_t tensor) { +EXPORT_C torch_tensor_t torch_tensor_get_gradient(const torch_tensor_t tensor) { auto t = reinterpret_cast(tensor); torch::Tensor *output = nullptr; output = new torch::Tensor; diff --git a/src/ctorch.h b/src/ctorch.h index 6f39b793..e84282e2 100644 --- a/src/ctorch.h +++ b/src/ctorch.h @@ -260,7 +260,7 @@ EXPORT_C void torch_tensor_backward(const torch_tensor_t tensor, * @param Tensor to get the gradient of * @return Tensor for the gradient */ -EXPORT_C torch_tensor_t get_gradient(const torch_tensor_t tensor); +EXPORT_C torch_tensor_t torch_tensor_get_gradient(const torch_tensor_t tensor); // ============================================================================= // --- Torch model API diff --git a/src/ftorch.F90 b/src/ftorch.F90 index 4f77a0ce..6356bab7 100644 --- a/src/ftorch.F90 +++ b/src/ftorch.F90 @@ -34,6 +34,7 @@ module ftorch procedure :: get_dtype => torch_tensor_get_dtype procedure :: get_device_type => torch_tensor_get_device_type procedure :: get_device_index => torch_tensor_get_device_index + procedure :: grad => torch_tensor_get_gradient end type torch_tensor !| Enumerator for Torch data types @@ -2920,22 +2921,22 @@ end subroutine torch_tensor_backward_c end subroutine torch_tensor_backward !> Retreives the gradient of a Torch Tensor. - function get_gradient(tensor) result(gradient) - type(torch_tensor), intent(in) :: tensor + function torch_tensor_get_gradient(tensor) result(gradient) + class(torch_tensor), intent(in) :: tensor type(torch_tensor) :: gradient interface - function get_gradient_c(tensor_c) result(gradient_c) & - bind(c, name = 'get_gradient') + function torch_tensor_get_gradient_c(tensor_c) result(gradient_c) & + bind(c, name = 'torch_tensor_get_gradient') use, intrinsic :: iso_c_binding, only : c_ptr implicit none type(c_ptr), value, intent(in) :: tensor_c type(c_ptr) :: gradient_c - end function get_gradient_c + end function torch_tensor_get_gradient_c end interface - gradient%p = get_gradient_c(tensor%p) - end function get_gradient + gradient%p = torch_tensor_get_gradient_c(tensor%p) + end function torch_tensor_get_gradient ! ============================================================================ ! --- Torch Model API diff --git a/src/ftorch.fypp b/src/ftorch.fypp index 075e96bc..d8789606 100644 --- a/src/ftorch.fypp +++ b/src/ftorch.fypp @@ -53,6 +53,7 @@ module ftorch procedure :: get_dtype => torch_tensor_get_dtype procedure :: get_device_type => torch_tensor_get_device_type procedure :: get_device_index => torch_tensor_get_device_index + procedure :: grad => torch_tensor_get_gradient end type torch_tensor !| Enumerator for Torch data types @@ -855,22 +856,22 @@ contains end subroutine torch_tensor_backward !> Retreives the gradient of a Torch Tensor. 
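  ! With the type-bound procedure added above (procedure :: grad => torch_tensor_get_gradient),
  ! the gradient can now be retrieved directly from the tensor object. A minimal sketch of the
  ! intended usage, assuming a tensor Q computed from a tensor a that was created with
  ! requires_grad=.true.:
  !
  !   call torch_tensor_backward(Q, external_gradient)
  !   dQda = a%grad()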
- function get_gradient(tensor) result(gradient) - type(torch_tensor), intent(in) :: tensor + function torch_tensor_get_gradient(tensor) result(gradient) + class(torch_tensor), intent(in) :: tensor type(torch_tensor) :: gradient interface - function get_gradient_c(tensor_c) result(gradient_c) & - bind(c, name = 'get_gradient') + function torch_tensor_get_gradient_c(tensor_c) result(gradient_c) & + bind(c, name = 'torch_tensor_get_gradient') use, intrinsic :: iso_c_binding, only : c_ptr implicit none type(c_ptr), value, intent(in) :: tensor_c type(c_ptr) :: gradient_c - end function get_gradient_c + end function torch_tensor_get_gradient_c end interface - gradient%p = get_gradient_c(tensor%p) - end function get_gradient + gradient%p = torch_tensor_get_gradient_c(tensor%p) + end function torch_tensor_get_gradient ! ============================================================================ ! --- Torch Model API diff --git a/test/unit/test_tensor_operator_overloads_autograd.pf b/test/unit/test_tensor_operator_overloads_autograd.pf index 842429ef..4a39fcac 100644 --- a/test/unit/test_tensor_operator_overloads_autograd.pf +++ b/test/unit/test_tensor_operator_overloads_autograd.pf @@ -6,7 +6,7 @@ ! file for details. module test_tensor_operator_overloads_autograd use funit - use ftorch, only: assignment(=), get_gradient, ftorch_int, torch_kCPU, torch_kFloat32, & + use ftorch, only: assignment(=), ftorch_int, torch_kCPU, torch_kFloat32, & torch_tensor, torch_tensor_backward, torch_tensor_delete, torch_tensor_empty, & torch_tensor_from_array, torch_tensor_ones, torch_tensor_to_array use ftorch_test_utils, only: assert_allclose @@ -99,7 +99,7 @@ contains ! TODO: Automate choice of ones for external_gradient call torch_tensor_ones(external_gradient, ndims, tensor_shape, dtype, device_type) call torch_tensor_backward(Q, external_gradient) - dQda = get_gradient(a) + dQda = a%grad() ! Extract Fortran array from the computed gradient and its data with the expected value: ! Q(a) = a => dQ/da = 1 @@ -163,8 +163,8 @@ contains ! TODO: Automate choice of ones for external_gradient call torch_tensor_ones(external_gradient, ndims, tensor_shape, dtype, device_type) call torch_tensor_backward(Q, external_gradient) - dQda = get_gradient(a) - dQdb = get_gradient(b) + dQda = a%grad() + dQdb = b%grad() ! Extract Fortran array from the first computed gradient and its data with the expected value: ! Q(a,b) = a + b => dQ/da = 1 @@ -238,7 +238,7 @@ contains ! TODO: Automate choice of ones for external_gradient call torch_tensor_ones(external_gradient, ndims, tensor_shape, dtype, device_type) call torch_tensor_backward(Q, external_gradient) - dQda = get_gradient(a) + dQda = a%grad() ! Extract Fortran array from the computed gradient and its data with the expected value: ! Q(a) = a => dQ/da = -1 @@ -301,8 +301,8 @@ contains ! TODO: Automate choice of ones for external_gradient call torch_tensor_ones(external_gradient, ndims, tensor_shape, dtype, device_type) call torch_tensor_backward(Q, external_gradient) - dQda = get_gradient(a) - dQdb = get_gradient(b) + dQda = a%grad() + dQdb = b%grad() ! Extract Fortran array from the first computed gradient and its data with the expected value: ! Q(a,b) = a - b => dQ/da = 1 @@ -378,8 +378,8 @@ contains ! TODO: Automate choice of ones for external_gradient call torch_tensor_ones(external_gradient, ndims, tensor_shape, dtype, device_type) call torch_tensor_backward(Q, external_gradient) - dQda = get_gradient(a) - dQdb = get_gradient(b) + dQda = a%grad() + dQdb = b%grad() ! 
Extract Fortran array from the first computed gradient and its data with the expected value: ! Q(a,b) = a * b => dQ/da = b @@ -459,7 +459,7 @@ contains ! TODO: Automate choice of ones for external_gradient call torch_tensor_ones(external_gradient, ndims, tensor_shape, dtype, device_type) call torch_tensor_backward(Q, external_gradient) - dQda = get_gradient(a) + dQda = a%grad() ! Extract Fortran array from the first computed gradient and its data with the expected value: ! Q(a,b) = scalar * a => dQ/da = scalar @@ -525,8 +525,8 @@ contains ! TODO: Automate choice of ones for external_gradient call torch_tensor_ones(external_gradient, ndims, tensor_shape, dtype, device_type) call torch_tensor_backward(Q, external_gradient) - dQda = get_gradient(a) - dQdb = get_gradient(b) + dQda = a%grad() + dQdb = b%grad() ! Extract Fortran array from the first computed gradient and its data with the expected value: ! Q(a,b) = a / b => dQ/da = 1 / b @@ -602,7 +602,7 @@ contains ! TODO: Automate choice of ones for external_gradient call torch_tensor_ones(external_gradient, ndims, tensor_shape, dtype, device_type) call torch_tensor_backward(Q, external_gradient) - dQda = get_gradient(a) + dQda = a%grad() ! Extract Fortran array from the first computed gradient and its data with the expected value: ! Q(a,b) = a / scalar => dQ/da = 1 / scalar @@ -670,7 +670,7 @@ contains ! TODO: Automate choice of ones for external_gradient call torch_tensor_ones(external_gradient, ndims, tensor_shape, dtype, device_type) call torch_tensor_backward(Q, external_gradient) - dQda = get_gradient(a) + dQda = a%grad() ! Extract Fortran array from the computed gradient and its data with the expected value: ! Q(a) = a^2 => dQ/da = 2 * a @@ -731,7 +731,7 @@ contains ! TODO: Automate choice of ones for external_gradient call torch_tensor_ones(external_gradient, ndims, tensor_shape, dtype, device_type) call torch_tensor_backward(Q, external_gradient) - dQda = get_gradient(a) + dQda = a%grad() ! Extract Fortran array from the computed gradient and its data with the expected value: ! Q(a) = a^{1/2} => dQ/da = 0.5 * a^{-1/2}) From bd28159e2094f1d85b38810bd10e6a3e433163c2 Mon Sep 17 00:00:00 2001 From: Joe Wallwork Date: Mon, 17 Feb 2025 14:54:32 +0000 Subject: [PATCH 18/24] Drop unnecessary c_loc use --- src/ftorch.F90 | 18 ------------------ src/ftorch.fypp | 3 --- 2 files changed, 21 deletions(-) diff --git a/src/ftorch.F90 b/src/ftorch.F90 index 6356bab7..ea214c5f 100644 --- a/src/ftorch.F90 +++ b/src/ftorch.F90 @@ -2501,7 +2501,6 @@ end function torch_tensor_multiply !> Overloads multiplication operator for a scalar of type int8 and a tensor. function torch_tensor_premultiply_int8(scalar, tensor) result(output) - use, intrinsic :: iso_c_binding, only : c_loc use, intrinsic :: iso_fortran_env, only : int8 integer(int8), target, intent(in) :: scalar type(torch_tensor), intent(in) :: tensor @@ -2515,7 +2514,6 @@ end function torch_tensor_premultiply_int8 !> Overloads multiplication operator for a scalar of type int16 and a tensor. function torch_tensor_premultiply_int16(scalar, tensor) result(output) - use, intrinsic :: iso_c_binding, only : c_loc use, intrinsic :: iso_fortran_env, only : int16 integer(int16), target, intent(in) :: scalar type(torch_tensor), intent(in) :: tensor @@ -2529,7 +2527,6 @@ end function torch_tensor_premultiply_int16 !> Overloads multiplication operator for a scalar of type int32 and a tensor. 
function torch_tensor_premultiply_int32(scalar, tensor) result(output) - use, intrinsic :: iso_c_binding, only : c_loc use, intrinsic :: iso_fortran_env, only : int32 integer(int32), target, intent(in) :: scalar type(torch_tensor), intent(in) :: tensor @@ -2543,7 +2540,6 @@ end function torch_tensor_premultiply_int32 !> Overloads multiplication operator for a scalar of type int64 and a tensor. function torch_tensor_premultiply_int64(scalar, tensor) result(output) - use, intrinsic :: iso_c_binding, only : c_loc use, intrinsic :: iso_fortran_env, only : int64 integer(int64), target, intent(in) :: scalar type(torch_tensor), intent(in) :: tensor @@ -2557,7 +2553,6 @@ end function torch_tensor_premultiply_int64 !> Overloads multiplication operator for a scalar of type real32 and a tensor. function torch_tensor_premultiply_real32(scalar, tensor) result(output) - use, intrinsic :: iso_c_binding, only : c_loc use, intrinsic :: iso_fortran_env, only : real32 real(real32), target, intent(in) :: scalar type(torch_tensor), intent(in) :: tensor @@ -2571,7 +2566,6 @@ end function torch_tensor_premultiply_real32 !> Overloads multiplication operator for a scalar of type real64 and a tensor. function torch_tensor_premultiply_real64(scalar, tensor) result(output) - use, intrinsic :: iso_c_binding, only : c_loc use, intrinsic :: iso_fortran_env, only : real64 real(real64), target, intent(in) :: scalar type(torch_tensor), intent(in) :: tensor @@ -2586,7 +2580,6 @@ end function torch_tensor_premultiply_real64 !> Overloads multiplication operator for a tensor and a scalar of type int8. function torch_tensor_postmultiply_int8(tensor, scalar) result(output) - use, intrinsic :: iso_c_binding, only : c_loc use, intrinsic :: iso_fortran_env, only : int8 type(torch_tensor), intent(in) :: tensor integer(int8), intent(in) :: scalar @@ -2600,7 +2593,6 @@ end function torch_tensor_postmultiply_int8 !> Overloads multiplication operator for a tensor and a scalar of type int16. function torch_tensor_postmultiply_int16(tensor, scalar) result(output) - use, intrinsic :: iso_c_binding, only : c_loc use, intrinsic :: iso_fortran_env, only : int16 type(torch_tensor), intent(in) :: tensor integer(int16), intent(in) :: scalar @@ -2614,7 +2606,6 @@ end function torch_tensor_postmultiply_int16 !> Overloads multiplication operator for a tensor and a scalar of type int32. function torch_tensor_postmultiply_int32(tensor, scalar) result(output) - use, intrinsic :: iso_c_binding, only : c_loc use, intrinsic :: iso_fortran_env, only : int32 type(torch_tensor), intent(in) :: tensor integer(int32), intent(in) :: scalar @@ -2628,7 +2619,6 @@ end function torch_tensor_postmultiply_int32 !> Overloads multiplication operator for a tensor and a scalar of type int64. function torch_tensor_postmultiply_int64(tensor, scalar) result(output) - use, intrinsic :: iso_c_binding, only : c_loc use, intrinsic :: iso_fortran_env, only : int64 type(torch_tensor), intent(in) :: tensor integer(int64), intent(in) :: scalar @@ -2642,7 +2632,6 @@ end function torch_tensor_postmultiply_int64 !> Overloads multiplication operator for a tensor and a scalar of type real32. function torch_tensor_postmultiply_real32(tensor, scalar) result(output) - use, intrinsic :: iso_c_binding, only : c_loc use, intrinsic :: iso_fortran_env, only : real32 type(torch_tensor), intent(in) :: tensor real(real32), intent(in) :: scalar @@ -2656,7 +2645,6 @@ end function torch_tensor_postmultiply_real32 !> Overloads multiplication operator for a tensor and a scalar of type real64. 
function torch_tensor_postmultiply_real64(tensor, scalar) result(output) - use, intrinsic :: iso_c_binding, only : c_loc use, intrinsic :: iso_fortran_env, only : real64 type(torch_tensor), intent(in) :: tensor real(real64), intent(in) :: scalar @@ -2679,7 +2667,6 @@ end function torch_tensor_divide !> Overloads division operator for a tensor and a scalar of type int8. function torch_tensor_postdivide_int8(tensor, divisor) result(output) - use, intrinsic :: iso_c_binding, only : c_loc use, intrinsic :: iso_fortran_env, only : int8 type(torch_tensor), intent(in) :: tensor integer(int8), intent(in) :: divisor @@ -2693,7 +2680,6 @@ end function torch_tensor_postdivide_int8 !> Overloads division operator for a tensor and a scalar of type int16. function torch_tensor_postdivide_int16(tensor, divisor) result(output) - use, intrinsic :: iso_c_binding, only : c_loc use, intrinsic :: iso_fortran_env, only : int16 type(torch_tensor), intent(in) :: tensor integer(int16), intent(in) :: divisor @@ -2707,7 +2693,6 @@ end function torch_tensor_postdivide_int16 !> Overloads division operator for a tensor and a scalar of type int32. function torch_tensor_postdivide_int32(tensor, divisor) result(output) - use, intrinsic :: iso_c_binding, only : c_loc use, intrinsic :: iso_fortran_env, only : int32 type(torch_tensor), intent(in) :: tensor integer(int32), intent(in) :: divisor @@ -2721,7 +2706,6 @@ end function torch_tensor_postdivide_int32 !> Overloads division operator for a tensor and a scalar of type int64. function torch_tensor_postdivide_int64(tensor, divisor) result(output) - use, intrinsic :: iso_c_binding, only : c_loc use, intrinsic :: iso_fortran_env, only : int64 type(torch_tensor), intent(in) :: tensor integer(int64), intent(in) :: divisor @@ -2735,7 +2719,6 @@ end function torch_tensor_postdivide_int64 !> Overloads division operator for a tensor and a scalar of type real32. function torch_tensor_postdivide_real32(tensor, divisor) result(output) - use, intrinsic :: iso_c_binding, only : c_loc use, intrinsic :: iso_fortran_env, only : real32 type(torch_tensor), intent(in) :: tensor real(real32), intent(in) :: divisor @@ -2749,7 +2732,6 @@ end function torch_tensor_postdivide_real32 !> Overloads division operator for a tensor and a scalar of type real64. function torch_tensor_postdivide_real64(tensor, divisor) result(output) - use, intrinsic :: iso_c_binding, only : c_loc use, intrinsic :: iso_fortran_env, only : real64 type(torch_tensor), intent(in) :: tensor real(real64), intent(in) :: divisor diff --git a/src/ftorch.fypp b/src/ftorch.fypp index d8789606..b4d96390 100644 --- a/src/ftorch.fypp +++ b/src/ftorch.fypp @@ -725,7 +725,6 @@ contains #:for PREC in PRECISIONS !> Overloads multiplication operator for a scalar of type ${PREC}$ and a tensor. function torch_tensor_premultiply_${PREC}$(scalar, tensor) result(output) - use, intrinsic :: iso_c_binding, only : c_loc use, intrinsic :: iso_fortran_env, only : ${PREC}$ ${f_type(PREC)}$(${PREC}$), target, intent(in) :: scalar type(torch_tensor), intent(in) :: tensor @@ -742,7 +741,6 @@ contains #:for PREC in PRECISIONS !> Overloads multiplication operator for a tensor and a scalar of type ${PREC}$. 
function torch_tensor_postmultiply_${PREC}$(tensor, scalar) result(output) - use, intrinsic :: iso_c_binding, only : c_loc use, intrinsic :: iso_fortran_env, only : ${PREC}$ type(torch_tensor), intent(in) :: tensor ${f_type(PREC)}$(${PREC}$), intent(in) :: scalar @@ -767,7 +765,6 @@ contains #:for PREC in PRECISIONS !> Overloads division operator for a tensor and a scalar of type ${PREC}$. function torch_tensor_postdivide_${PREC}$(tensor, divisor) result(output) - use, intrinsic :: iso_c_binding, only : c_loc use, intrinsic :: iso_fortran_env, only : ${PREC}$ type(torch_tensor), intent(in) :: tensor ${f_type(PREC)}$(${PREC}$), intent(in) :: divisor From c81530132007097f33afee952165dd80536787c3 Mon Sep 17 00:00:00 2001 From: Joe Wallwork Date: Mon, 17 Feb 2025 15:27:51 +0000 Subject: [PATCH 19/24] FIXME- backward needs intent(inout) --- src/ctorch.cpp | 5 +++-- src/ctorch.h | 2 +- src/ftorch.F90 | 4 ++-- src/ftorch.fypp | 4 ++-- 4 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/ctorch.cpp b/src/ctorch.cpp index 9ffe993d..f6ace212 100644 --- a/src/ctorch.cpp +++ b/src/ctorch.cpp @@ -407,10 +407,11 @@ torch_tensor_t torch_tensor_power_float(const torch_tensor_t tensor, // --- Functions related to automatic differentiation functionality for tensors // ============================================================================= -void torch_tensor_backward(const torch_tensor_t tensor, +void torch_tensor_backward(torch_tensor_t tensor, const torch_tensor_t external_gradient) { - auto t = reinterpret_cast(tensor); + auto t = reinterpret_cast(tensor); auto g = reinterpret_cast(external_gradient); + // FIXME: tensor needs to not be const but this crashes t->backward(*g); } diff --git a/src/ctorch.h b/src/ctorch.h index e84282e2..73fc5252 100644 --- a/src/ctorch.h +++ b/src/ctorch.h @@ -252,7 +252,7 @@ EXPORT_C torch_tensor_t torch_tensor_power_float(const torch_tensor_t tensor, * @param Tensor to perform back-propagation on * @param Tensor with an external gradient to supply for the back-propagation */ -EXPORT_C void torch_tensor_backward(const torch_tensor_t tensor, +EXPORT_C void torch_tensor_backward(torch_tensor_t tensor, const torch_tensor_t external_gradient); /** diff --git a/src/ftorch.F90 b/src/ftorch.F90 index ea214c5f..3c26193f 100644 --- a/src/ftorch.F90 +++ b/src/ftorch.F90 @@ -2884,7 +2884,7 @@ end function torch_tensor_power_real64 !> Performs back-propagation on a Torch Tensor, given some external gradient. subroutine torch_tensor_backward(tensor, external_gradient) - type(torch_tensor), intent(in) :: tensor + type(torch_tensor), intent(inout) :: tensor type(torch_tensor), intent(in) :: external_gradient interface @@ -2892,7 +2892,7 @@ subroutine torch_tensor_backward_c(tensor_c, external_gradient_c) & bind(c, name = 'torch_tensor_backward') use, intrinsic :: iso_c_binding, only : c_ptr implicit none - type(c_ptr), value, intent(in) :: tensor_c + type(c_ptr), value, intent(inout) :: tensor_c type(c_ptr), value, intent(in) :: external_gradient_c end subroutine torch_tensor_backward_c end interface diff --git a/src/ftorch.fypp b/src/ftorch.fypp index b4d96390..16029e71 100644 --- a/src/ftorch.fypp +++ b/src/ftorch.fypp @@ -834,7 +834,7 @@ contains !> Performs back-propagation on a Torch Tensor, given some external gradient. 
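  ! Note that for a non-scalar tensor, Torch requires a seed gradient of the same shape as the
  ! tensor: the backward call accumulates the corresponding vector-Jacobian product into the
  ! gradients of the leaf tensors. A sketch of the usage pattern at this stage of the series,
  ! seeding with a tensor of ones:
  !
  !   call torch_tensor_ones(external_gradient, ndims, tensor_shape, dtype, device_type)
  !   call torch_tensor_backward(Q, external_gradient)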
subroutine torch_tensor_backward(tensor, external_gradient) - type(torch_tensor), intent(in) :: tensor + type(torch_tensor), intent(inout) :: tensor type(torch_tensor), intent(in) :: external_gradient interface @@ -842,7 +842,7 @@ contains bind(c, name = 'torch_tensor_backward') use, intrinsic :: iso_c_binding, only : c_ptr implicit none - type(c_ptr), value, intent(in) :: tensor_c + type(c_ptr), value, intent(inout) :: tensor_c type(c_ptr), value, intent(in) :: external_gradient_c end subroutine torch_tensor_backward_c end interface From 7e350e0ce663d83e41db0bde62c3c06e0483b9f5 Mon Sep 17 00:00:00 2001 From: Joe Wallwork Date: Mon, 17 Feb 2025 16:15:39 +0000 Subject: [PATCH 20/24] Drop unused imports --- test/unit/test_tensor_operator_overloads.pf | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/test/unit/test_tensor_operator_overloads.pf b/test/unit/test_tensor_operator_overloads.pf index 0f32058d..477ee62c 100644 --- a/test/unit/test_tensor_operator_overloads.pf +++ b/test/unit/test_tensor_operator_overloads.pf @@ -10,7 +10,7 @@ module test_tensor_operator_overloads torch_tensor_delete, torch_tensor_empty, torch_tensor_from_array, & torch_tensor_to_array use ftorch_test_utils, only: assert_allclose - use, intrinsic :: iso_c_binding, only : c_associated, c_int64_t + use, intrinsic :: iso_c_binding, only : c_int64_t implicit none @@ -275,7 +275,6 @@ contains torch_tensor_from_array, torch_tensor_to_array use ftorch_test_utils, only: assert_allclose use, intrinsic :: iso_fortran_env, only: sp => real32 - use, intrinsic :: iso_c_binding, only : c_associated, c_int64_t implicit none @@ -683,7 +682,6 @@ contains subroutine test_torch_tensor_sqrt(this) use ftorch, only: operator(**) use, intrinsic :: iso_fortran_env, only: sp => real32 - use, intrinsic :: iso_c_binding, only : c_associated, c_int64_t implicit none From 3bc60dcbc74f8034f364502497e87b6f368a46e6 Mon Sep 17 00:00:00 2001 From: Joe Wallwork Date: Mon, 17 Feb 2025 16:36:48 +0000 Subject: [PATCH 21/24] Set up external gradient of ones for now --- examples/6_Autograd/autograd.f90 | 6 +- src/ctorch.cpp | 6 +- src/ctorch.h | 2 +- src/ftorch.F90 | 18 +++-- src/ftorch.fypp | 18 +++-- ...test_tensor_operator_overloads_autograd.pf | 67 ++++++------------- 6 files changed, 54 insertions(+), 63 deletions(-) diff --git a/examples/6_Autograd/autograd.f90 b/examples/6_Autograd/autograd.f90 index 72a93ba1..d1a5a6a8 100644 --- a/examples/6_Autograd/autograd.f90 +++ b/examples/6_Autograd/autograd.f90 @@ -24,7 +24,7 @@ program example integer :: tensor_layout(1) = [1] ! Set up Torch data structures - type(torch_tensor) :: a, b, Q, external_gradient, dQda, dQdb + type(torch_tensor) :: a, b, Q, dQda, dQdb ! Initialise input arrays as in Python example in_data1(:) = [2.0_wp, 3.0_wp] @@ -54,8 +54,7 @@ program example ! 
Back-propagation in_data3(:) = [1.0_wp, 1.0_wp] - call torch_tensor_from_array(external_gradient, in_data3, tensor_layout, torch_kCPU) - call torch_tensor_backward(Q, external_gradient) + call torch_tensor_backward(Q) dQda = a%grad() dQdb = b%grad() @@ -90,7 +89,6 @@ subroutine clean_up() call torch_tensor_delete(a) call torch_tensor_delete(b) call torch_tensor_delete(Q) - call torch_tensor_delete(external_gradient) end subroutine clean_up end program example diff --git a/src/ctorch.cpp b/src/ctorch.cpp index f6ace212..3971af2b 100644 --- a/src/ctorch.cpp +++ b/src/ctorch.cpp @@ -407,10 +407,14 @@ torch_tensor_t torch_tensor_power_float(const torch_tensor_t tensor, // --- Functions related to automatic differentiation functionality for tensors // ============================================================================= -void torch_tensor_backward(torch_tensor_t tensor, +void torch_tensor_backward(const torch_tensor_t tensor, const torch_tensor_t external_gradient) { auto t = reinterpret_cast(tensor); auto g = reinterpret_cast(external_gradient); + std::cout << "[DEBUG] tensor"; + torch_tensor_print(tensor); + std::cout << "[DEBUG] external_gradient"; + torch_tensor_print(external_gradient); // FIXME: tensor needs to not be const but this crashes t->backward(*g); } diff --git a/src/ctorch.h b/src/ctorch.h index 73fc5252..e84282e2 100644 --- a/src/ctorch.h +++ b/src/ctorch.h @@ -252,7 +252,7 @@ EXPORT_C torch_tensor_t torch_tensor_power_float(const torch_tensor_t tensor, * @param Tensor to perform back-propagation on * @param Tensor with an external gradient to supply for the back-propagation */ -EXPORT_C void torch_tensor_backward(torch_tensor_t tensor, +EXPORT_C void torch_tensor_backward(const torch_tensor_t tensor, const torch_tensor_t external_gradient); /** diff --git a/src/ftorch.F90 b/src/ftorch.F90 index 3c26193f..5312488f 100644 --- a/src/ftorch.F90 +++ b/src/ftorch.F90 @@ -2883,23 +2883,31 @@ end function torch_tensor_power_real64 ! ============================================================================ !> Performs back-propagation on a Torch Tensor, given some external gradient. - subroutine torch_tensor_backward(tensor, external_gradient) - type(torch_tensor), intent(inout) :: tensor - type(torch_tensor), intent(in) :: external_gradient + subroutine torch_tensor_backward(tensor) + type(torch_tensor), intent(in) :: tensor + type(torch_tensor) :: external_gradient interface subroutine torch_tensor_backward_c(tensor_c, external_gradient_c) & bind(c, name = 'torch_tensor_backward') use, intrinsic :: iso_c_binding, only : c_ptr implicit none - type(c_ptr), value, intent(inout) :: tensor_c + type(c_ptr), value, intent(in) :: tensor_c type(c_ptr), value, intent(in) :: external_gradient_c end subroutine torch_tensor_backward_c end interface - ! TODO: Make external_gradient optional, setting to ones by default + ! External gradient to provide to the back-propagation consisting of a tensor of ones + ! TODO: Accept other external gradients as an optional argument + call torch_tensor_ones(external_gradient, tensor%get_rank(), tensor%get_shape(), & + tensor%get_dtype(), tensor%get_device_type(), & + device_index=tensor%get_device_index()) + ! Call back-propagation with the provided external gradient call torch_tensor_backward_c(tensor%p, external_gradient%p) + + ! Delete the external gradient tensor + call torch_tensor_delete(external_gradient) end subroutine torch_tensor_backward !> Retreives the gradient of a Torch Tensor. 
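With this change the seed gradient is constructed inside torch_tensor_backward as a tensor of
ones matching the rank, shape, data type and device of the tensor being differentiated, so
calling code no longer creates or deletes an external gradient itself. A minimal sketch of the
resulting pattern, assuming Q was computed from tensors a and b that were created with
requires_grad=.true. (tensor setup omitted):

    call torch_tensor_backward(Q)
    dQda = a%grad()
    dQdb = b%grad()

The generated interface in ftorch.fypp is updated in the same way below.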
diff --git a/src/ftorch.fypp b/src/ftorch.fypp index 16029e71..47f2da93 100644 --- a/src/ftorch.fypp +++ b/src/ftorch.fypp @@ -833,23 +833,31 @@ contains ! ============================================================================ !> Performs back-propagation on a Torch Tensor, given some external gradient. - subroutine torch_tensor_backward(tensor, external_gradient) - type(torch_tensor), intent(inout) :: tensor - type(torch_tensor), intent(in) :: external_gradient + subroutine torch_tensor_backward(tensor) + type(torch_tensor), intent(in) :: tensor + type(torch_tensor) :: external_gradient interface subroutine torch_tensor_backward_c(tensor_c, external_gradient_c) & bind(c, name = 'torch_tensor_backward') use, intrinsic :: iso_c_binding, only : c_ptr implicit none - type(c_ptr), value, intent(inout) :: tensor_c + type(c_ptr), value, intent(in) :: tensor_c type(c_ptr), value, intent(in) :: external_gradient_c end subroutine torch_tensor_backward_c end interface - ! TODO: Make external_gradient optional, setting to ones by default + ! External gradient to provide to the back-propagation consisting of a tensor of ones + ! TODO: Accept other external gradients as an optional argument + call torch_tensor_ones(external_gradient, tensor%get_rank(), tensor%get_shape(), & + tensor%get_dtype(), tensor%get_device_type(), & + device_index=tensor%get_device_index()) + ! Call back-propagation with the provided external gradient call torch_tensor_backward_c(tensor%p, external_gradient%p) + + ! Delete the external gradient tensor + call torch_tensor_delete(external_gradient) end subroutine torch_tensor_backward !> Retreives the gradient of a Torch Tensor. diff --git a/test/unit/test_tensor_operator_overloads_autograd.pf b/test/unit/test_tensor_operator_overloads_autograd.pf index 4a39fcac..18352de4 100644 --- a/test/unit/test_tensor_operator_overloads_autograd.pf +++ b/test/unit/test_tensor_operator_overloads_autograd.pf @@ -75,7 +75,7 @@ contains integer, parameter :: wp = sp class(TestCaseType), intent(inout) :: this - type(torch_tensor) :: Q, a, external_gradient, dQda + type(torch_tensor) :: Q, a, dQda integer, parameter :: ndims = 2 integer(ftorch_int), parameter :: tensor_layout(ndims) = [1, 2] integer(c_int64_t), parameter, dimension(ndims) :: tensor_shape = [2, 3] @@ -96,9 +96,7 @@ contains Q = a ! Apply back-propagation - ! TODO: Automate choice of ones for external_gradient - call torch_tensor_ones(external_gradient, ndims, tensor_shape, dtype, device_type) - call torch_tensor_backward(Q, external_gradient) + call torch_tensor_backward(Q) dQda = a%grad() ! Extract Fortran array from the computed gradient and its data with the expected value: @@ -121,7 +119,6 @@ contains call torch_tensor_delete(a) call torch_tensor_delete(Q) call torch_tensor_delete(dQda) - call torch_tensor_delete(external_gradient) end subroutine clean_up end subroutine test_torch_tensor_assign @@ -137,7 +134,7 @@ contains integer, parameter :: wp = sp class(TestCaseType), intent(inout) :: this - type(torch_tensor) :: Q, a, b, external_gradient, dQda, dQdb + type(torch_tensor) :: Q, a, b, dQda, dQdb integer, parameter :: ndims = 2 integer(ftorch_int), parameter :: tensor_layout(ndims) = [1, 2] integer(c_int64_t), parameter, dimension(ndims) :: tensor_shape = [2, 3] @@ -160,9 +157,7 @@ contains Q = a + b ! Apply back-propagation - ! 
TODO: Automate choice of ones for external_gradient - call torch_tensor_ones(external_gradient, ndims, tensor_shape, dtype, device_type) - call torch_tensor_backward(Q, external_gradient) + call torch_tensor_backward(Q) dQda = a%grad() dQdb = b%grad() @@ -198,7 +193,6 @@ contains call torch_tensor_delete(Q) call torch_tensor_delete(dQda) call torch_tensor_delete(dQdb) - call torch_tensor_delete(external_gradient) end subroutine clean_up end subroutine test_torch_tensor_add @@ -214,7 +208,7 @@ contains integer, parameter :: wp = sp class(TestCaseType), intent(inout) :: this - type(torch_tensor) :: Q, a, external_gradient, dQda + type(torch_tensor) :: Q, a, dQda integer, parameter :: ndims = 2 integer(ftorch_int), parameter :: tensor_layout(ndims) = [1, 2] integer(c_int64_t), parameter, dimension(ndims) :: tensor_shape = [2, 3] @@ -235,9 +229,7 @@ contains Q = -a ! Apply back-propagation - ! TODO: Automate choice of ones for external_gradient - call torch_tensor_ones(external_gradient, ndims, tensor_shape, dtype, device_type) - call torch_tensor_backward(Q, external_gradient) + call torch_tensor_backward(Q) dQda = a%grad() ! Extract Fortran array from the computed gradient and its data with the expected value: @@ -275,7 +267,7 @@ contains integer, parameter :: wp = sp class(TestCaseType), intent(inout) :: this - type(torch_tensor) :: Q, a, b, external_gradient, dQda, dQdb + type(torch_tensor) :: Q, a, b, dQda, dQdb integer, parameter :: ndims = 2 integer(ftorch_int), parameter :: tensor_layout(ndims) = [1, 2] integer(c_int64_t), parameter, dimension(ndims) :: tensor_shape = [2, 3] @@ -298,9 +290,7 @@ contains Q = a - b ! Apply back-propagation - ! TODO: Automate choice of ones for external_gradient - call torch_tensor_ones(external_gradient, ndims, tensor_shape, dtype, device_type) - call torch_tensor_backward(Q, external_gradient) + call torch_tensor_backward(Q) dQda = a%grad() dQdb = b%grad() @@ -336,7 +326,6 @@ contains call torch_tensor_delete(Q) call torch_tensor_delete(dQda) call torch_tensor_delete(dQdb) - call torch_tensor_delete(external_gradient) end subroutine clean_up end subroutine test_torch_tensor_subtract @@ -352,7 +341,7 @@ contains integer, parameter :: wp = sp class(TestCaseType), intent(inout) :: this - type(torch_tensor) :: Q, a, b, external_gradient, dQda, dQdb + type(torch_tensor) :: Q, a, b, dQda, dQdb integer, parameter :: ndims = 2 integer(ftorch_int), parameter :: tensor_layout(ndims) = [1, 2] integer(c_int64_t), parameter, dimension(ndims) :: tensor_shape = [2, 3] @@ -375,9 +364,7 @@ contains Q = a * b ! Apply back-propagation - ! TODO: Automate choice of ones for external_gradient - call torch_tensor_ones(external_gradient, ndims, tensor_shape, dtype, device_type) - call torch_tensor_backward(Q, external_gradient) + call torch_tensor_backward(Q) dQda = a%grad() dQdb = b%grad() @@ -413,7 +400,6 @@ contains call torch_tensor_delete(Q) call torch_tensor_delete(dQda) call torch_tensor_delete(dQdb) - call torch_tensor_delete(external_gradient) end subroutine clean_up end subroutine test_torch_tensor_multiply @@ -429,7 +415,7 @@ contains integer, parameter :: wp = sp class(TestCaseType), intent(inout) :: this - type(torch_tensor) :: Q, a, external_gradient, dQda + type(torch_tensor) :: Q, a, dQda integer, parameter :: ndims = 2 integer(ftorch_int), parameter :: tensor_layout(ndims) = [1, 2] integer(c_int64_t), parameter, dimension(ndims) :: tensor_shape = [2, 3] @@ -456,9 +442,7 @@ contains end if ! Apply back-propagation - ! 
TODO: Automate choice of ones for external_gradient - call torch_tensor_ones(external_gradient, ndims, tensor_shape, dtype, device_type) - call torch_tensor_backward(Q, external_gradient) + call torch_tensor_backward(Q) dQda = a%grad() ! Extract Fortran array from the first computed gradient and its data with the expected value: @@ -483,7 +467,6 @@ contains call torch_tensor_delete(a) call torch_tensor_delete(Q) call torch_tensor_delete(dQda) - call torch_tensor_delete(external_gradient) end subroutine clean_up end subroutine test_torch_tensor_scalar_multiply @@ -499,7 +482,7 @@ contains integer, parameter :: wp = sp class(TestCaseType), intent(inout) :: this - type(torch_tensor) :: Q, a, b, external_gradient, dQda, dQdb + type(torch_tensor) :: Q, a, b, dQda, dQdb integer, parameter :: ndims = 2 integer(ftorch_int), parameter :: tensor_layout(ndims) = [1, 2] integer(c_int64_t), parameter, dimension(ndims) :: tensor_shape = [2, 3] @@ -522,9 +505,7 @@ contains Q = a / b ! Apply back-propagation - ! TODO: Automate choice of ones for external_gradient - call torch_tensor_ones(external_gradient, ndims, tensor_shape, dtype, device_type) - call torch_tensor_backward(Q, external_gradient) + call torch_tensor_backward(Q) dQda = a%grad() dQdb = b%grad() @@ -560,7 +541,6 @@ contains call torch_tensor_delete(Q) call torch_tensor_delete(dQda) call torch_tensor_delete(dQdb) - call torch_tensor_delete(external_gradient) end subroutine clean_up end subroutine test_torch_tensor_divide @@ -576,7 +556,7 @@ contains integer, parameter :: wp = sp class(TestCaseType), intent(inout) :: this - type(torch_tensor) :: Q, a, external_gradient, dQda + type(torch_tensor) :: Q, a, dQda integer, parameter :: ndims = 2 integer(ftorch_int), parameter :: tensor_layout(ndims) = [1, 2] integer(c_int64_t), parameter, dimension(ndims) :: tensor_shape = [2, 3] @@ -599,9 +579,7 @@ contains Q = a / scalar ! Apply back-propagation - ! TODO: Automate choice of ones for external_gradient - call torch_tensor_ones(external_gradient, ndims, tensor_shape, dtype, device_type) - call torch_tensor_backward(Q, external_gradient) + call torch_tensor_backward(Q) dQda = a%grad() ! Extract Fortran array from the first computed gradient and its data with the expected value: @@ -626,7 +604,6 @@ contains call torch_tensor_delete(a) call torch_tensor_delete(Q) call torch_tensor_delete(dQda) - call torch_tensor_delete(external_gradient) end subroutine clean_up end subroutine test_torch_tensor_scalar_divide @@ -642,7 +619,7 @@ contains integer, parameter :: wp = sp class(TestCaseType), intent(inout) :: this - type(torch_tensor) :: Q, a, external_gradient, dQda + type(torch_tensor) :: Q, a, dQda integer, parameter :: ndims = 2 integer(ftorch_int), parameter :: tensor_layout(ndims) = [1, 2] integer(c_int64_t), parameter, dimension(ndims) :: tensor_shape = [2, 3] @@ -667,9 +644,7 @@ contains end if ! Apply back-propagation - ! TODO: Automate choice of ones for external_gradient - call torch_tensor_ones(external_gradient, ndims, tensor_shape, dtype, device_type) - call torch_tensor_backward(Q, external_gradient) + call torch_tensor_backward(Q) dQda = a%grad() ! 
Extract Fortran array from the computed gradient and its data with the expected value: @@ -707,7 +682,7 @@ contains integer, parameter :: wp = sp class(TestCaseType), intent(inout) :: this - type(torch_tensor) :: Q, a, external_gradient, dQda + type(torch_tensor) :: Q, a, dQda integer, parameter :: ndims = 2 integer(ftorch_int), parameter :: tensor_layout(ndims) = [1, 2] integer(c_int64_t), parameter, dimension(ndims) :: tensor_shape = [2, 3] @@ -728,9 +703,7 @@ contains Q = a ** 0.5 ! Apply back-propagation - ! TODO: Automate choice of ones for external_gradient - call torch_tensor_ones(external_gradient, ndims, tensor_shape, dtype, device_type) - call torch_tensor_backward(Q, external_gradient) + call torch_tensor_backward(Q) dQda = a%grad() ! Extract Fortran array from the computed gradient and its data with the expected value: From 7e3fe439d69affd9468ae91f566c548950830cd6 Mon Sep 17 00:00:00 2001 From: Joe Wallwork Date: Tue, 18 Feb 2025 15:06:39 +0000 Subject: [PATCH 22/24] Scalar multiplication and division using rank-1 tensors --- examples/6_Autograd/autograd.f90 | 52 ++-- src/ctorch.cpp | 13 - src/ftorch.F90 | 254 ------------------ src/ftorch.fypp | 54 ---- test/unit/test_tensor_operator_overloads.pf | 18 +- ...test_tensor_operator_overloads_autograd.pf | 24 +- 6 files changed, 51 insertions(+), 364 deletions(-) diff --git a/examples/6_Autograd/autograd.f90 b/examples/6_Autograd/autograd.f90 index d1a5a6a8..ceb50b0f 100644 --- a/examples/6_Autograd/autograd.f90 +++ b/examples/6_Autograd/autograd.f90 @@ -18,59 +18,51 @@ program example ! Set up Fortran data structures integer, parameter :: n = 2 - real(wp), dimension(n), target :: in_data1, in_data2, in_data3 real(wp), dimension(:), pointer :: out_data1, out_data2, out_data3 - real(wp), dimension(n) :: expected integer :: tensor_layout(1) = [1] ! Set up Torch data structures - type(torch_tensor) :: a, b, Q, dQda, dQdb + type(torch_tensor) :: a, b, Q, dQda, dQdb, multiplier, divisor - ! Initialise input arrays as in Python example - in_data1(:) = [2.0_wp, 3.0_wp] - in_data2(:) = [6.0_wp, 4.0_wp] + ! Initialise input arrays as in the Python example and construct Torch Tensors from them + call torch_tensor_from_array(a, [2.0_wp, 3.0_wp], tensor_layout, torch_kCPU, requires_grad=.true.) + call torch_tensor_from_array(b, [6.0_wp, 4.0_wp], tensor_layout, torch_kCPU, requires_grad=.true.) - ! Construct a Torch Tensor from a Fortran array - call torch_tensor_from_array(a, in_data1, tensor_layout, torch_kCPU, requires_grad=.true.) - call torch_tensor_from_array(b, in_data2, tensor_layout, torch_kCPU, requires_grad=.true.) + ! Scalar multiplication and division are not currently implemented in FTorch. However, you can + ! achieve the same thing by defining a rank-1 tensor with a single entry, as follows: + call torch_tensor_from_array(multiplier, [3.0_wp], tensor_layout, torch_kCPU) + call torch_tensor_from_array(divisor, [3.0_wp], tensor_layout, torch_kCPU) - ! Check arithmetic operations work for torch_tensors - write (*,*) "a = ", in_data1(:) - write (*,*) "b = ", in_data2(:) - Q = 3 * (a**3 - b * b / 3) - ! FIXME: Something seems off with gradients related to scalar multiplication and/or division + ! Compute the same mathematical expression as in the Python example + Q = multiplier * (a**3 - b * b / divisor) - ! Extract a Fortran array from a Torch tensor - call torch_tensor_to_array(Q, out_data1, shape(in_data1)) - write (*,*) "Q = 3 * (a ** 3 - b * b / 2) =", out_data1(:) + ! 
Extract a Fortran array from the Torch tensor
+  call torch_tensor_to_array(Q, out_data1, [2])
+  write (*,*) "Q = 3 * (a^3 - b*b/3) = 3*a^3 - b^2 = ", out_data1(:)

   ! Check output tensor matches expected value
-  expected(:) = [-12.0_wp, 65.0_wp]
-  if (.not. assert_allclose(out_data1, expected, test_name="autograd_Q")) then
+  if (.not. assert_allclose(out_data1, [-12.0_wp, 65.0_wp], test_name="autograd_Q")) then
     call clean_up()
     print *, "Error :: value of Q does not match expected value"
     stop 999
   end if

-  ! Back-propagation
-  in_data3(:) = [1.0_wp, 1.0_wp]
+  ! Run the back-propagation operator
   call torch_tensor_backward(Q)
   dQda = a%grad()
   dQdb = b%grad()

   ! Extract Fortran arrays from the Torch tensors and check the gradients take expected values
-  call torch_tensor_to_array(dQda, out_data2, shape(in_data1))
-  print *, "dQda", out_data2
-  expected(:) = [36.0_wp, 81.0_wp]
-  if (.not. assert_allclose(out_data2, expected, test_name="autograd_dQdb")) then
+  call torch_tensor_to_array(dQda, out_data2, [2])
+  print *, "dQda = 9*a^2 = ", out_data2
+  if (.not. assert_allclose(out_data2, [36.0_wp, 81.0_wp], test_name="autograd_dQda")) then
     call clean_up()
     print *, "Error :: value of dQda does not match expected value"
     stop 999
   end if

-  call torch_tensor_to_array(dQdb, out_data3, shape(in_data1))
-  print *, "dQdb", out_data3
-  expected(:) = [-12.0_wp, -8.0_wp]
-  if (.not. assert_allclose(out_data3, expected, test_name="autograd_dQdb")) then
+  call torch_tensor_to_array(dQdb, out_data3, [2])
+  print *, "dQdb = -2*b = ", out_data3
+  if (.not. assert_allclose(out_data3, [-12.0_wp, -8.0_wp], test_name="autograd_dQdb")) then
     call clean_up()
     print *, "Error :: value of dQdb does not match expected value"
     stop 999
@@ -88,6 +80,8 @@ subroutine clean_up()
     nullify(out_data3)
     call torch_tensor_delete(a)
     call torch_tensor_delete(b)
+    call torch_tensor_delete(multiplier)
+    call torch_tensor_delete(divisor)
     call torch_tensor_delete(Q)
   end subroutine clean_up

 end program example
diff --git a/src/ctorch.cpp b/src/ctorch.cpp
index 3971af2b..c0c040c5 100644
--- a/src/ctorch.cpp
+++ b/src/ctorch.cpp
@@ -212,7 +212,6 @@ torch_tensor_t torch_from_blob(void *data, int ndim, const int64_t *shape,
     tensor = new torch::Tensor;
     *tensor = torch::from_blob(data, vshape, vstrides, options);
-    std::cout << "[DEBUG]: blob " << tensor->requires_grad() << std::endl; // TODO
   } catch (const torch::Error &e) {
     std::cerr << "[ERROR]: " << e.msg() << std::endl;
     delete tensor;
@@ -323,7 +322,6 @@ torch_tensor_t torch_tensor_assign(const torch_tensor_t input) {
   torch::Tensor *output = nullptr;
   output = new torch::Tensor;
   *output = *in;
-  std::cout << "[DEBUG]: assign " << output->requires_grad() << std::endl; // TODO
   return output;
 }

@@ -334,7 +332,6 @@ torch_tensor_t torch_tensor_add(const torch_tensor_t tensor1,
   torch::Tensor *output = nullptr;
   output = new torch::Tensor;
   *output = *t1 + *t2;
-  std::cout << "[DEBUG]: add " << output->requires_grad() << std::endl; // TODO
   return output;
 }

@@ -353,7 +350,6 @@ torch_tensor_t torch_tensor_subtract(const torch_tensor_t tensor1,
   torch::Tensor *output = nullptr;
   output = new torch::Tensor;
   *output = *t1 - *t2;
-  std::cout << "[DEBUG]: subtract " << output->requires_grad() << std::endl; // TODO
   return output;
 }

@@ -364,7 +360,6 @@ torch_tensor_t torch_tensor_multiply(const torch_tensor_t tensor1,
   torch::Tensor *output = nullptr;
   output = new torch::Tensor;
   *output = *t1 * *t2;
-  std::cout << "[DEBUG]: multiply " << output->requires_grad() << std::endl; // TODO
   return output;
 }

@@ -375,7 +370,6 @@ torch_tensor_t 
torch_tensor_divide(const torch_tensor_t tensor1, torch::Tensor *output = nullptr; output = new torch::Tensor; *output = *t1 / *t2; - std::cout << "[DEBUG]: divide " << output->requires_grad() << std::endl; // TODO return output; } @@ -387,7 +381,6 @@ torch_tensor_t torch_tensor_power_int(const torch_tensor_t tensor, torch::Tensor *output = nullptr; output = new torch::Tensor; *output = pow(*t, *exp); - std::cout << "[DEBUG]: power_int " << output->requires_grad() << std::endl; // TODO return output; } @@ -399,7 +392,6 @@ torch_tensor_t torch_tensor_power_float(const torch_tensor_t tensor, torch::Tensor *output = nullptr; output = new torch::Tensor; *output = pow(*t, *exp); - std::cout << "[DEBUG]: power_float " << output->requires_grad() << std::endl; // TODO return output; } @@ -411,11 +403,6 @@ void torch_tensor_backward(const torch_tensor_t tensor, const torch_tensor_t external_gradient) { auto t = reinterpret_cast(tensor); auto g = reinterpret_cast(external_gradient); - std::cout << "[DEBUG] tensor"; - torch_tensor_print(tensor); - std::cout << "[DEBUG] external_gradient"; - torch_tensor_print(external_gradient); - // FIXME: tensor needs to not be const but this crashes t->backward(*g); } diff --git a/src/ftorch.F90 b/src/ftorch.F90 index 5312488f..c990e3d7 100644 --- a/src/ftorch.F90 +++ b/src/ftorch.F90 @@ -191,18 +191,6 @@ end function torch_to_blob_c interface operator (*) module procedure torch_tensor_multiply - module procedure torch_tensor_premultiply_int8 - module procedure torch_tensor_postmultiply_int8 - module procedure torch_tensor_premultiply_int16 - module procedure torch_tensor_postmultiply_int16 - module procedure torch_tensor_premultiply_int32 - module procedure torch_tensor_postmultiply_int32 - module procedure torch_tensor_premultiply_int64 - module procedure torch_tensor_postmultiply_int64 - module procedure torch_tensor_premultiply_real32 - module procedure torch_tensor_postmultiply_real32 - module procedure torch_tensor_premultiply_real64 - module procedure torch_tensor_postmultiply_real64 end interface interface @@ -218,12 +206,6 @@ end function torch_tensor_multiply_c interface operator (/) module procedure torch_tensor_divide - module procedure torch_tensor_postdivide_int8 - module procedure torch_tensor_postdivide_int16 - module procedure torch_tensor_postdivide_int32 - module procedure torch_tensor_postdivide_int64 - module procedure torch_tensor_postdivide_real32 - module procedure torch_tensor_postdivide_real64 end interface interface @@ -2499,163 +2481,6 @@ function torch_tensor_multiply(tensor1, tensor2) result(output) output%p = torch_tensor_multiply_c(tensor1%p, tensor2%p) end function torch_tensor_multiply - !> Overloads multiplication operator for a scalar of type int8 and a tensor. - function torch_tensor_premultiply_int8(scalar, tensor) result(output) - use, intrinsic :: iso_fortran_env, only : int8 - integer(int8), target, intent(in) :: scalar - type(torch_tensor), intent(in) :: tensor - type(torch_tensor) :: wrk, output - - ! Create a tensor with a single entry, the scalar pre-multiplier - call torch_tensor_from_array(wrk, [scalar], [1], tensor%get_device_type(), & - tensor%get_device_index()) - output%p = torch_tensor_multiply_c(wrk%p, tensor%p) - end function torch_tensor_premultiply_int8 - - !> Overloads multiplication operator for a scalar of type int16 and a tensor. 
- function torch_tensor_premultiply_int16(scalar, tensor) result(output) - use, intrinsic :: iso_fortran_env, only : int16 - integer(int16), target, intent(in) :: scalar - type(torch_tensor), intent(in) :: tensor - type(torch_tensor) :: wrk, output - - ! Create a tensor with a single entry, the scalar pre-multiplier - call torch_tensor_from_array(wrk, [scalar], [1], tensor%get_device_type(), & - tensor%get_device_index()) - output%p = torch_tensor_multiply_c(wrk%p, tensor%p) - end function torch_tensor_premultiply_int16 - - !> Overloads multiplication operator for a scalar of type int32 and a tensor. - function torch_tensor_premultiply_int32(scalar, tensor) result(output) - use, intrinsic :: iso_fortran_env, only : int32 - integer(int32), target, intent(in) :: scalar - type(torch_tensor), intent(in) :: tensor - type(torch_tensor) :: wrk, output - - ! Create a tensor with a single entry, the scalar pre-multiplier - call torch_tensor_from_array(wrk, [scalar], [1], tensor%get_device_type(), & - tensor%get_device_index()) - output%p = torch_tensor_multiply_c(wrk%p, tensor%p) - end function torch_tensor_premultiply_int32 - - !> Overloads multiplication operator for a scalar of type int64 and a tensor. - function torch_tensor_premultiply_int64(scalar, tensor) result(output) - use, intrinsic :: iso_fortran_env, only : int64 - integer(int64), target, intent(in) :: scalar - type(torch_tensor), intent(in) :: tensor - type(torch_tensor) :: wrk, output - - ! Create a tensor with a single entry, the scalar pre-multiplier - call torch_tensor_from_array(wrk, [scalar], [1], tensor%get_device_type(), & - tensor%get_device_index()) - output%p = torch_tensor_multiply_c(wrk%p, tensor%p) - end function torch_tensor_premultiply_int64 - - !> Overloads multiplication operator for a scalar of type real32 and a tensor. - function torch_tensor_premultiply_real32(scalar, tensor) result(output) - use, intrinsic :: iso_fortran_env, only : real32 - real(real32), target, intent(in) :: scalar - type(torch_tensor), intent(in) :: tensor - type(torch_tensor) :: wrk, output - - ! Create a tensor with a single entry, the scalar pre-multiplier - call torch_tensor_from_array(wrk, [scalar], [1], tensor%get_device_type(), & - tensor%get_device_index()) - output%p = torch_tensor_multiply_c(wrk%p, tensor%p) - end function torch_tensor_premultiply_real32 - - !> Overloads multiplication operator for a scalar of type real64 and a tensor. - function torch_tensor_premultiply_real64(scalar, tensor) result(output) - use, intrinsic :: iso_fortran_env, only : real64 - real(real64), target, intent(in) :: scalar - type(torch_tensor), intent(in) :: tensor - type(torch_tensor) :: wrk, output - - ! Create a tensor with a single entry, the scalar pre-multiplier - call torch_tensor_from_array(wrk, [scalar], [1], tensor%get_device_type(), & - tensor%get_device_index()) - output%p = torch_tensor_multiply_c(wrk%p, tensor%p) - end function torch_tensor_premultiply_real64 - - - !> Overloads multiplication operator for a tensor and a scalar of type int8. - function torch_tensor_postmultiply_int8(tensor, scalar) result(output) - use, intrinsic :: iso_fortran_env, only : int8 - type(torch_tensor), intent(in) :: tensor - integer(int8), intent(in) :: scalar - type(torch_tensor) :: wrk, output - - ! 
Create a tensor with a single entry, the scalar post-multiplier - call torch_tensor_from_array(wrk, [scalar], [1], tensor%get_device_type(), & - tensor%get_device_index()) - output%p = torch_tensor_multiply_c(tensor%p, wrk%p) - end function torch_tensor_postmultiply_int8 - - !> Overloads multiplication operator for a tensor and a scalar of type int16. - function torch_tensor_postmultiply_int16(tensor, scalar) result(output) - use, intrinsic :: iso_fortran_env, only : int16 - type(torch_tensor), intent(in) :: tensor - integer(int16), intent(in) :: scalar - type(torch_tensor) :: wrk, output - - ! Create a tensor with a single entry, the scalar post-multiplier - call torch_tensor_from_array(wrk, [scalar], [1], tensor%get_device_type(), & - tensor%get_device_index()) - output%p = torch_tensor_multiply_c(tensor%p, wrk%p) - end function torch_tensor_postmultiply_int16 - - !> Overloads multiplication operator for a tensor and a scalar of type int32. - function torch_tensor_postmultiply_int32(tensor, scalar) result(output) - use, intrinsic :: iso_fortran_env, only : int32 - type(torch_tensor), intent(in) :: tensor - integer(int32), intent(in) :: scalar - type(torch_tensor) :: wrk, output - - ! Create a tensor with a single entry, the scalar post-multiplier - call torch_tensor_from_array(wrk, [scalar], [1], tensor%get_device_type(), & - tensor%get_device_index()) - output%p = torch_tensor_multiply_c(tensor%p, wrk%p) - end function torch_tensor_postmultiply_int32 - - !> Overloads multiplication operator for a tensor and a scalar of type int64. - function torch_tensor_postmultiply_int64(tensor, scalar) result(output) - use, intrinsic :: iso_fortran_env, only : int64 - type(torch_tensor), intent(in) :: tensor - integer(int64), intent(in) :: scalar - type(torch_tensor) :: wrk, output - - ! Create a tensor with a single entry, the scalar post-multiplier - call torch_tensor_from_array(wrk, [scalar], [1], tensor%get_device_type(), & - tensor%get_device_index()) - output%p = torch_tensor_multiply_c(tensor%p, wrk%p) - end function torch_tensor_postmultiply_int64 - - !> Overloads multiplication operator for a tensor and a scalar of type real32. - function torch_tensor_postmultiply_real32(tensor, scalar) result(output) - use, intrinsic :: iso_fortran_env, only : real32 - type(torch_tensor), intent(in) :: tensor - real(real32), intent(in) :: scalar - type(torch_tensor) :: wrk, output - - ! Create a tensor with a single entry, the scalar post-multiplier - call torch_tensor_from_array(wrk, [scalar], [1], tensor%get_device_type(), & - tensor%get_device_index()) - output%p = torch_tensor_multiply_c(tensor%p, wrk%p) - end function torch_tensor_postmultiply_real32 - - !> Overloads multiplication operator for a tensor and a scalar of type real64. - function torch_tensor_postmultiply_real64(tensor, scalar) result(output) - use, intrinsic :: iso_fortran_env, only : real64 - type(torch_tensor), intent(in) :: tensor - real(real64), intent(in) :: scalar - type(torch_tensor) :: wrk, output - - ! Create a tensor with a single entry, the scalar post-multiplier - call torch_tensor_from_array(wrk, [scalar], [1], tensor%get_device_type(), & - tensor%get_device_index()) - output%p = torch_tensor_multiply_c(tensor%p, wrk%p) - end function torch_tensor_postmultiply_real64 - !> Overloads division operator for two tensors. 
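  ! With the scalar overloads removed, scalar multiplication and division are instead expressed
  ! through the remaining tensor-tensor operators by wrapping the scalar in a rank-1 tensor with
  ! a single entry and relying on Torch's broadcasting rules. A sketch, assuming a real32 tensor
  ! a on the CPU and working precision wp:
  !
  !   call torch_tensor_from_array(divisor, [3.0_wp], [1], torch_kCPU)
  !   Q = a / divisor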
function torch_tensor_divide(tensor1, tensor2) result(output) type(torch_tensor), intent(in) :: tensor1 @@ -2665,85 +2490,6 @@ function torch_tensor_divide(tensor1, tensor2) result(output) output%p = torch_tensor_divide_c(tensor1%p, tensor2%p) end function torch_tensor_divide - !> Overloads division operator for a tensor and a scalar of type int8. - function torch_tensor_postdivide_int8(tensor, divisor) result(output) - use, intrinsic :: iso_fortran_env, only : int8 - type(torch_tensor), intent(in) :: tensor - integer(int8), intent(in) :: divisor - type(torch_tensor) :: wrk, output - - ! Create a tensor with a single entry, the scalar post-divisor - call torch_tensor_from_array(wrk, [divisor], [1], tensor%get_device_type(), & - tensor%get_device_index()) - output%p = torch_tensor_divide_c(tensor%p, wrk%p) - end function torch_tensor_postdivide_int8 - - !> Overloads division operator for a tensor and a scalar of type int16. - function torch_tensor_postdivide_int16(tensor, divisor) result(output) - use, intrinsic :: iso_fortran_env, only : int16 - type(torch_tensor), intent(in) :: tensor - integer(int16), intent(in) :: divisor - type(torch_tensor) :: wrk, output - - ! Create a tensor with a single entry, the scalar post-divisor - call torch_tensor_from_array(wrk, [divisor], [1], tensor%get_device_type(), & - tensor%get_device_index()) - output%p = torch_tensor_divide_c(tensor%p, wrk%p) - end function torch_tensor_postdivide_int16 - - !> Overloads division operator for a tensor and a scalar of type int32. - function torch_tensor_postdivide_int32(tensor, divisor) result(output) - use, intrinsic :: iso_fortran_env, only : int32 - type(torch_tensor), intent(in) :: tensor - integer(int32), intent(in) :: divisor - type(torch_tensor) :: wrk, output - - ! Create a tensor with a single entry, the scalar post-divisor - call torch_tensor_from_array(wrk, [divisor], [1], tensor%get_device_type(), & - tensor%get_device_index()) - output%p = torch_tensor_divide_c(tensor%p, wrk%p) - end function torch_tensor_postdivide_int32 - - !> Overloads division operator for a tensor and a scalar of type int64. - function torch_tensor_postdivide_int64(tensor, divisor) result(output) - use, intrinsic :: iso_fortran_env, only : int64 - type(torch_tensor), intent(in) :: tensor - integer(int64), intent(in) :: divisor - type(torch_tensor) :: wrk, output - - ! Create a tensor with a single entry, the scalar post-divisor - call torch_tensor_from_array(wrk, [divisor], [1], tensor%get_device_type(), & - tensor%get_device_index()) - output%p = torch_tensor_divide_c(tensor%p, wrk%p) - end function torch_tensor_postdivide_int64 - - !> Overloads division operator for a tensor and a scalar of type real32. - function torch_tensor_postdivide_real32(tensor, divisor) result(output) - use, intrinsic :: iso_fortran_env, only : real32 - type(torch_tensor), intent(in) :: tensor - real(real32), intent(in) :: divisor - type(torch_tensor) :: wrk, output - - ! Create a tensor with a single entry, the scalar post-divisor - call torch_tensor_from_array(wrk, [divisor], [1], tensor%get_device_type(), & - tensor%get_device_index()) - output%p = torch_tensor_divide_c(tensor%p, wrk%p) - end function torch_tensor_postdivide_real32 - - !> Overloads division operator for a tensor and a scalar of type real64. - function torch_tensor_postdivide_real64(tensor, divisor) result(output) - use, intrinsic :: iso_fortran_env, only : real64 - type(torch_tensor), intent(in) :: tensor - real(real64), intent(in) :: divisor - type(torch_tensor) :: wrk, output - - ! 
Create a tensor with a single entry, the scalar post-divisor - call torch_tensor_from_array(wrk, [divisor], [1], tensor%get_device_type(), & - tensor%get_device_index()) - output%p = torch_tensor_divide_c(tensor%p, wrk%p) - end function torch_tensor_postdivide_real64 - - !> Overloads exponentiation operator for a tensor and a scalar of type `int8` function torch_tensor_power_int8(tensor, power) result(output) use, intrinsic :: iso_c_binding, only : c_loc diff --git a/src/ftorch.fypp b/src/ftorch.fypp index 47f2da93..777e5ef1 100644 --- a/src/ftorch.fypp +++ b/src/ftorch.fypp @@ -160,10 +160,6 @@ module ftorch interface operator (*) module procedure torch_tensor_multiply - #:for PREC in PRECISIONS - module procedure torch_tensor_premultiply_${PREC}$ - module procedure torch_tensor_postmultiply_${PREC}$ - #:endfor end interface interface @@ -179,9 +175,6 @@ module ftorch interface operator (/) module procedure torch_tensor_divide - #:for PREC in PRECISIONS - module procedure torch_tensor_postdivide_${PREC}$ - #:endfor end interface interface @@ -722,37 +715,6 @@ contains output%p = torch_tensor_multiply_c(tensor1%p, tensor2%p) end function torch_tensor_multiply - #:for PREC in PRECISIONS - !> Overloads multiplication operator for a scalar of type ${PREC}$ and a tensor. - function torch_tensor_premultiply_${PREC}$(scalar, tensor) result(output) - use, intrinsic :: iso_fortran_env, only : ${PREC}$ - ${f_type(PREC)}$(${PREC}$), target, intent(in) :: scalar - type(torch_tensor), intent(in) :: tensor - type(torch_tensor) :: wrk, output - - ! Create a tensor with a single entry, the scalar pre-multiplier - call torch_tensor_from_array(wrk, [scalar], [1], tensor%get_device_type(), & - tensor%get_device_index()) - output%p = torch_tensor_multiply_c(wrk%p, tensor%p) - end function torch_tensor_premultiply_${PREC}$ - - #:endfor - - #:for PREC in PRECISIONS - !> Overloads multiplication operator for a tensor and a scalar of type ${PREC}$. - function torch_tensor_postmultiply_${PREC}$(tensor, scalar) result(output) - use, intrinsic :: iso_fortran_env, only : ${PREC}$ - type(torch_tensor), intent(in) :: tensor - ${f_type(PREC)}$(${PREC}$), intent(in) :: scalar - type(torch_tensor) :: wrk, output - - ! Create a tensor with a single entry, the scalar post-multiplier - call torch_tensor_from_array(wrk, [scalar], [1], tensor%get_device_type(), & - tensor%get_device_index()) - output%p = torch_tensor_multiply_c(tensor%p, wrk%p) - end function torch_tensor_postmultiply_${PREC}$ - - #:endfor !> Overloads division operator for two tensors. function torch_tensor_divide(tensor1, tensor2) result(output) type(torch_tensor), intent(in) :: tensor1 @@ -762,22 +724,6 @@ contains output%p = torch_tensor_divide_c(tensor1%p, tensor2%p) end function torch_tensor_divide - #:for PREC in PRECISIONS - !> Overloads division operator for a tensor and a scalar of type ${PREC}$. - function torch_tensor_postdivide_${PREC}$(tensor, divisor) result(output) - use, intrinsic :: iso_fortran_env, only : ${PREC}$ - type(torch_tensor), intent(in) :: tensor - ${f_type(PREC)}$(${PREC}$), intent(in) :: divisor - type(torch_tensor) :: wrk, output - - ! 
Create a tensor with a single entry, the scalar post-divisor
-    call torch_tensor_from_array(wrk, [divisor], [1], tensor%get_device_type(), &
-                                 tensor%get_device_index())
-    output%p = torch_tensor_divide_c(tensor%p, wrk%p)
-  end function torch_tensor_postdivide_${PREC}$
-
-  #:endfor
-
   #:for PREC in INT_PRECISIONS
   !> Overloads exponentiation operator for a tensor and a scalar of type `${PREC}$`
   function torch_tensor_power_${PREC}$(tensor, power) result(output)
diff --git a/test/unit/test_tensor_operator_overloads.pf b/test/unit/test_tensor_operator_overloads.pf
index 477ee62c..6fb6fb02 100644
--- a/test/unit/test_tensor_operator_overloads.pf
+++ b/test/unit/test_tensor_operator_overloads.pf
@@ -415,7 +415,7 @@ contains
     integer, parameter :: wp = sp

     class(TestCaseType), intent(inout) :: this
-    type(torch_tensor) :: tensor1, tensor2
+    type(torch_tensor) :: tensor1, tensor2, multiplier
     integer, parameter :: ndims = 2
     integer(ftorch_int), parameter :: tensor_layout(ndims) = [1, 2]
     integer(c_int64_t), parameter, dimension(ndims) :: tensor_shape = [2, 3]
@@ -432,13 +432,16 @@ contains
     ! Create a tensor based off the input array
     call torch_tensor_from_array(tensor1, in_data, tensor_layout, device_type)

+    ! Create a single-valued rank-1 tensor based off the scalar
+    call torch_tensor_from_array(multiplier, [scalar], [1], device_type)
+
     ! Create another empty tensor and assign it to the product of a scalar constant and the first
     ! tensor using the overloaded multiplication operator
     call torch_tensor_empty(tensor2, ndims, tensor_shape, dtype, device_type)
     if (this%param%switch) then
-      tensor2 = scalar * tensor1
+      tensor2 = multiplier * tensor1
     else
-      tensor2 = tensor1 * scalar
+      tensor2 = tensor1 * multiplier
     end if

     ! Check input array is unchanged by scalar multiplication
@@ -470,6 +473,7 @@ contains
       nullify(out_data)
       call torch_tensor_delete(tensor1)
       call torch_tensor_delete(tensor2)
+      call torch_tensor_delete(multiplier)
     end subroutine clean_up

   end subroutine test_torch_tensor_scalar_multiply
@@ -556,7 +560,7 @@ contains
     integer, parameter :: wp = sp

     class(TestCaseType), intent(inout) :: this
-    type(torch_tensor) :: tensor1, tensor2
+    type(torch_tensor) :: tensor1, tensor2, divisor
     integer, parameter :: ndims = 2
     integer(ftorch_int), parameter :: tensor_layout(ndims) = [1, 2]
     integer(c_int64_t), parameter, dimension(ndims) :: tensor_shape = [2, 3]
@@ -570,13 +574,16 @@ contains
     ! Create an arbitrary input array
     in_data(:,:) = reshape([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [2, 3])

+    ! Create a single-valued rank-1 tensor based off the scalar
+    call torch_tensor_from_array(divisor, [scalar], [1], device_type)
+
     ! Create a tensor based off the input array
     call torch_tensor_from_array(tensor1, in_data, tensor_layout, device_type)

     ! Create another empty tensor and assign it to the quotient of the first tensor and a scalar
     ! constant using the overloaded division operator
     call torch_tensor_empty(tensor2, ndims, tensor_shape, dtype, device_type)
-    tensor2 = tensor1 / scalar
+    tensor2 = tensor1 / divisor

     ! Check input array is unchanged by post-division
     expected(:,:) = reshape([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [2, 3])
@@ -607,6 +614,7 @@ contains
       nullify(out_data)
       call torch_tensor_delete(tensor1)
       call torch_tensor_delete(tensor2)
+      call torch_tensor_delete(divisor)
     end subroutine clean_up

   end subroutine test_torch_tensor_scalar_divide

diff --git a/test/unit/test_tensor_operator_overloads_autograd.pf b/test/unit/test_tensor_operator_overloads_autograd.pf
index 18352de4..26fa3012 100644
--- a/test/unit/test_tensor_operator_overloads_autograd.pf
+++ b/test/unit/test_tensor_operator_overloads_autograd.pf
@@ -407,7 +407,7 @@ contains
   @test(testParameters={get_parameters_full()})
   subroutine test_torch_tensor_scalar_multiply(this)
     use, intrinsic :: iso_fortran_env, only: sp => real32
-    use ftorch, only: operator(*), torch_tensor_print
+    use ftorch, only: operator(*)

     implicit none

@@ -415,7 +415,7 @@ contains
     integer, parameter :: wp = sp

     class(TestCaseType), intent(inout) :: this
-    type(torch_tensor) :: Q, a, dQda
+    type(torch_tensor) :: Q, a, dQda, multiplier
     integer, parameter :: ndims = 2
     integer(ftorch_int), parameter :: tensor_layout(ndims) = [1, 2]
     integer(c_int64_t), parameter, dimension(ndims) :: tensor_shape = [2, 3]
@@ -432,13 +432,16 @@ contains
     ! Create tensor based off input array
     call torch_tensor_from_array(a, in_data, tensor_layout, device_type, requires_grad=.true.)

+    ! Create a single-valued rank-1 tensor based off the scalar
+    call torch_tensor_from_array(multiplier, [scalar], [1], device_type)
+
     ! Create another empty tensor and assign it to the product of a scalar constant and the first
     ! tensor using the overloaded multiplication operator
     call torch_tensor_empty(Q, ndims, tensor_shape, dtype, device_type)
     if (this%param%switch) then
-      Q = scalar * a
+      Q = multiplier * a
     else
-      Q = a * scalar
+      Q = a * multiplier
     end if

     ! Apply back-propagation
@@ -448,7 +451,6 @@ contains
     ! Extract Fortran array from the first computed gradient and compare its data with the expected value:
     ! Q(a,b) = scalar * a => dQ/da = scalar
     call torch_tensor_to_array(dQda, out_data, shape(in_data))
-    call torch_tensor_print(dQda) ! TODO: Temp
     expected(:,:) = scalar
     test_pass = assert_allclose(out_data, expected, test_name="test_torch_tensor_scalar_multiply")
     if (.not. test_pass) then
@@ -467,6 +469,7 @@ contains
       call torch_tensor_delete(a)
       call torch_tensor_delete(Q)
       call torch_tensor_delete(dQda)
+      call torch_tensor_delete(multiplier)
     end subroutine clean_up

   end subroutine test_torch_tensor_scalar_multiply
@@ -548,7 +551,7 @@ contains
   @test(testParameters={get_parameters_short()})
   subroutine test_torch_tensor_scalar_divide(this)
     use, intrinsic :: iso_fortran_env, only: sp => real32
-    use ftorch, only: operator(/), torch_tensor_print
+    use ftorch, only: operator(/)

     implicit none

@@ -556,7 +559,7 @@ contains
     integer, parameter :: wp = sp

     class(TestCaseType), intent(inout) :: this
-    type(torch_tensor) :: Q, a, dQda
+    type(torch_tensor) :: Q, a, dQda, divisor
     integer, parameter :: ndims = 2
     integer(ftorch_int), parameter :: tensor_layout(ndims) = [1, 2]
     integer(c_int64_t), parameter, dimension(ndims) :: tensor_shape = [2, 3]
@@ -573,10 +576,13 @@ contains
     ! Create tensor based off input array
     call torch_tensor_from_array(a, in_data, tensor_layout, device_type, requires_grad=.true.)

+    ! Create a single-valued rank-1 tensor based off the scalar
+    call torch_tensor_from_array(divisor, [scalar], [1], device_type)
+
     ! Create another empty tensor and assign it to the quotient of the first tensor and a scalar
     ! constant using the overloaded division operator
     call torch_tensor_empty(Q, ndims, tensor_shape, dtype, device_type)
-    Q = a / scalar
+    Q = a / divisor

     ! Apply back-propagation
     call torch_tensor_backward(Q)

@@ -585,7 +591,6 @@ contains
     ! Extract Fortran array from the first computed gradient and compare its data with the expected value:
     ! Q(a,b) = a / scalar => dQ/da = 1 / scalar
     call torch_tensor_to_array(dQda, out_data, shape(in_data))
-    call torch_tensor_print(dQda) ! TODO: Temp
     expected(:,:) = 1.0 / scalar
     test_pass = assert_allclose(out_data, expected, test_name="test_torch_tensor_scalar_divide")
     if (.not. test_pass) then
@@ -604,6 +609,7 @@ contains
       call torch_tensor_delete(a)
       call torch_tensor_delete(Q)
       call torch_tensor_delete(dQda)
+      call torch_tensor_delete(divisor)
     end subroutine clean_up

   end subroutine test_torch_tensor_scalar_divide

From d42f3eda275db012ed8c9efccc66e6fff869e62f Mon Sep 17 00:00:00 2001
From: Joe Wallwork
Date: Tue, 18 Feb 2025 15:54:46 +0000
Subject: [PATCH 23/24] Fix static analysis

---
 .github/workflows/static_analysis.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/workflows/static_analysis.yml b/.github/workflows/static_analysis.yml
index 8c25dcaf..8d961e5c 100644
--- a/.github/workflows/static_analysis.yml
+++ b/.github/workflows/static_analysis.yml
@@ -113,7 +113,7 @@ jobs:
           style: 'file'
           tidy-checks: ''
           # Use the compile_commands.json from CMake to locate headers
-          database: ${{ github.workspace }}/src/build
+          database: ${{ github.workspace }}/build
           # only 'update' a single comment in a pull request thread.
           thread-comments: ${{ github.event_name == 'pull_request' && 'update' }}
       - name: Fail fast?!

From 2c5e8cbd092ee1c331eb974a4943ce232632e2b6 Mon Sep 17 00:00:00 2001
From: Joe Wallwork
Date: Tue, 18 Feb 2025 15:57:10 +0000
Subject: [PATCH 24/24] Apply cmake-format

---
 test/unit/CMakeLists.txt | 31 ++++++++++++++++++-------------
 1 file changed, 18 insertions(+), 13 deletions(-)

diff --git a/test/unit/CMakeLists.txt b/test/unit/CMakeLists.txt
index 410e5b9a..55442c66 100644
--- a/test/unit/CMakeLists.txt
+++ b/test/unit/CMakeLists.txt
@@ -1,21 +1,26 @@
 cmake_minimum_required(VERSION 3.15...3.31)
-cmake_policy (SET CMP0076 NEW)
+cmake_policy(SET CMP0076 NEW)

-project("FTorch unit tests" VERSION 1.0.0 LANGUAGES Fortran)
+project(
+  "FTorch unit tests"
+  VERSION 1.0.0
+  LANGUAGES Fortran)

 find_package(FTorch)
 message(STATUS "Building with Fortran PyTorch coupling")
 find_package(PFUNIT REQUIRED)

-add_pfunit_ctest(test_tensor_constructors
-                 TEST_SOURCES test_tensor_constructors.pf LINK_LIBRARIES FTorch::ftorch)
-add_pfunit_ctest(test_tensor_interrogation
-                 TEST_SOURCES test_tensor_interrogation.pf LINK_LIBRARIES FTorch::ftorch)
-add_pfunit_ctest(test_operator_overloads
-                 TEST_SOURCES test_tensor_operator_overloads.pf LINK_LIBRARIES FTorch::ftorch)
-add_pfunit_ctest(test_operator_overloads_autograd
-                 TEST_SOURCES test_tensor_operator_overloads_autograd.pf LINK_LIBRARIES FTorch::ftorch)
+add_pfunit_ctest(test_tensor_constructors TEST_SOURCES
+                 test_tensor_constructors.pf LINK_LIBRARIES FTorch::ftorch)
+add_pfunit_ctest(test_tensor_interrogation TEST_SOURCES
+                 test_tensor_interrogation.pf LINK_LIBRARIES FTorch::ftorch)
+add_pfunit_ctest(
+  test_operator_overloads TEST_SOURCES test_tensor_operator_overloads.pf
+  LINK_LIBRARIES FTorch::ftorch)
+add_pfunit_ctest(
+  test_operator_overloads_autograd TEST_SOURCES
+  test_tensor_operator_overloads_autograd.pf LINK_LIBRARIES FTorch::ftorch)

 if(ENABLE_CUDA)
   check_language(CUDA)
@@ -24,7 +29,7 @@ if(ENABLE_CUDA)
   else()
     message(ERROR "No CUDA support")
   endif()
-  add_pfunit_ctest(test_tensor_interrogation_cuda
-                   TEST_SOURCES test_tensor_interrogation_cuda.pf
-                   LINK_LIBRARIES FTorch::ftorch)
+  add_pfunit_ctest(
+    test_tensor_interrogation_cuda TEST_SOURCES
+    test_tensor_interrogation_cuda.pf LINK_LIBRARIES FTorch::ftorch)
 endif()
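
For downstream users, the pattern this series converges on can be summarised in a short
standalone program. The sketch below is illustrative only, not part of the patch series:
the program name and data values are invented; it assumes an FTorch build with all 24
patches applied (in particular that torch_tensor_backward can be called without an
external gradient, which then defaults to ones, as exercised by the autograd unit tests);
and it assumes torch_kCPU, torch_kFloat32 and the ftorch_int kind are exported by the
ftorch module as used elsewhere in the test suite. Retrieval of the computed gradient is
omitted, since the accessor does not appear in the hunks above.

program autograd_scalar_sketch
  ! Illustrative sketch only (not part of the patch series): multiply a
  ! gradient-tracked tensor by a scalar wrapped as a rank-1 tensor, then
  ! back-propagate through the result.
  use, intrinsic :: iso_fortran_env, only : sp => real32
  use, intrinsic :: iso_c_binding, only : c_int64_t
  use ftorch, only : torch_tensor, torch_kCPU, torch_kFloat32, ftorch_int, &
                     torch_tensor_from_array, torch_tensor_empty, &
                     torch_tensor_backward, torch_tensor_delete, operator(*)

  implicit none

  integer, parameter :: wp = sp
  integer, parameter :: ndims = 2
  integer(ftorch_int), parameter :: tensor_layout(ndims) = [1, 2]
  integer(c_int64_t), parameter :: tensor_shape(ndims) = [2, 3]
  real(wp), parameter :: scalar = 3.0_wp
  real(wp), dimension(2,3), target :: in_data
  type(torch_tensor) :: a, multiplier, Q

  in_data(:,:) = reshape([1.0_wp, 2.0_wp, 3.0_wp, 4.0_wp, 5.0_wp, 6.0_wp], [2, 3])

  ! Operand tensor, tracked by autograd
  call torch_tensor_from_array(a, in_data, tensor_layout, torch_kCPU, requires_grad=.true.)

  ! The scalar operator overloads were removed by this series, so wrap the
  ! scalar in a single-valued rank-1 tensor and use the tensor-tensor operator
  call torch_tensor_from_array(multiplier, [scalar], [1], torch_kCPU)

  call torch_tensor_empty(Q, ndims, tensor_shape, torch_kFloat32, torch_kCPU)
  Q = multiplier * a

  ! Back-propagate; the external gradient defaults to ones
  call torch_tensor_backward(Q)

  ! (Gradient retrieval omitted; see the autograd unit tests above.)

  call torch_tensor_delete(a)
  call torch_tensor_delete(multiplier)
  call torch_tensor_delete(Q)
end program autograd_scalar_sketch

Wrapping the scalar in a single-entry rank-1 tensor keeps the operand inside the autograd
graph, which is why the premultiply/postmultiply and postdivide overloads could be dropped
in favour of the plain tensor-tensor operators.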