diff --git a/doc/specs/stdlib_linalg.md b/doc/specs/stdlib_linalg.md index 6aa76fa88..cab16279c 100644 --- a/doc/specs/stdlib_linalg.md +++ b/doc/specs/stdlib_linalg.md @@ -168,3 +168,41 @@ program demo_trace print *, trace(A) ! 1 + 5 + 9 end program demo_trace ``` + +## `outer_product` - Computes the outer product of two vectors + +### Status + +Experimental + +### Description + +Computes the outer product of two vectors + +### Syntax + +`d = [[stdlib_linalg(module):outer_product(interface)]](u, v)` + +### Arguments + +`u`: Shall be a rank-1 array + +`v`: Shall be a rank-1 array + +### Return value + +Returns a rank-2 array equal to `u v^T` (where `u, v` are considered column vectors). The shape of the returned array is `[size(u), size(v)]`. + +### Example + +```fortran +program demo_outer_product + use stdlib_linalg, only: outer_product + implicit none + real, allocatable :: A(:,:), u(:), v(:) + u = [1., 2., 3. ] + v = [3., 4.] + A = outer_product(u,v) + !A = reshape([3., 6., 9., 4., 8., 12.], [3,2]) +end program demo_outer_product +``` diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index fa7f6e639..ef1d7963c 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -9,6 +9,7 @@ set(fppFiles stdlib_io.fypp stdlib_linalg.fypp stdlib_linalg_diag.fypp + stdlib_linalg_outer_product.fypp stdlib_optval.fypp stdlib_sorting.fypp stdlib_sorting_ord_sort.fypp diff --git a/src/Makefile.manual b/src/Makefile.manual index e55bb80fd..f869cbd7f 100644 --- a/src/Makefile.manual +++ b/src/Makefile.manual @@ -6,6 +6,7 @@ SRCFYPP =\ stdlib_io.fypp \ stdlib_linalg.fypp \ stdlib_linalg_diag.fypp \ + stdlib_linalg_outer_product.fypp \ stdlib_optval.fypp \ stdlib_quadrature.fypp \ stdlib_quadrature_trapz.fypp \ @@ -128,3 +129,4 @@ stdlib_stats_distribution_PRNG.o: \ stdlib_string_type.o: stdlib_ascii.o stdlib_kinds.o stdlib_strings.o: stdlib_ascii.o stdlib_string_type.o stdlib_math.o: stdlib_kinds.o +stdlib_linalg_outer_product.o: stdlib_linalg.o diff --git a/src/stdlib_linalg.fypp b/src/stdlib_linalg.fypp index 51c2cdd54..5e0388c0b 100644 --- a/src/stdlib_linalg.fypp +++ b/src/stdlib_linalg.fypp @@ -11,6 +11,7 @@ module stdlib_linalg public :: diag public :: eye public :: trace + public :: outer_product interface diag !! version: experimental @@ -52,6 +53,7 @@ module stdlib_linalg #:endfor end interface + ! Matrix trace interface trace !! version: experimental @@ -63,6 +65,21 @@ module stdlib_linalg #:endfor end interface + + ! Outer product (of two vectors) + interface outer_product + !! version: experimental + !! + !! Computes the outer product of two vectors, returning a rank-2 array + !! ([Specification](../page/specs/stdlib_linalg.html#description_3)) + #:for k1, t1 in RCI_KINDS_TYPES + pure module function outer_product_${t1[0]}$${k1}$(u, v) result(res) + ${t1}$, intent(in) :: u(:), v(:) + ${t1}$ :: res(size(u),size(v)) + end function outer_product_${t1[0]}$${k1}$ + #:endfor + end interface outer_product + contains function eye(n) result(res) diff --git a/src/stdlib_linalg_outer_product.fypp b/src/stdlib_linalg_outer_product.fypp new file mode 100644 index 000000000..26c726435 --- /dev/null +++ b/src/stdlib_linalg_outer_product.fypp @@ -0,0 +1,20 @@ +#:include "common.fypp" +#:set RCI_KINDS_TYPES = REAL_KINDS_TYPES + CMPLX_KINDS_TYPES + INT_KINDS_TYPES +submodule (stdlib_linalg) stdlib_linalg_outer_product + + implicit none + +contains + + #:for k1, t1 in RCI_KINDS_TYPES + pure module function outer_product_${t1[0]}$${k1}$(u, v) result(res) + ${t1}$, intent(in) :: u(:), v(:) + ${t1}$ :: res(size(u),size(v)) + integer :: col + do col = 1, size(v) + res(:,col) = v(col) * u + end do + end function outer_product_${t1[0]}$${k1}$ + #:endfor + +end submodule diff --git a/src/tests/linalg/test_linalg.f90 b/src/tests/linalg/test_linalg.f90 index 4ad178d5b..cc8d0db68 100644 --- a/src/tests/linalg/test_linalg.f90 +++ b/src/tests/linalg/test_linalg.f90 @@ -2,7 +2,7 @@ program test_linalg use stdlib_error, only: check use stdlib_kinds, only: sp, dp, qp, int8, int16, int32, int64 - use stdlib_linalg, only: diag, eye, trace + use stdlib_linalg, only: diag, eye, trace, outer_product implicit none @@ -56,6 +56,22 @@ program test_linalg call test_trace_int32 call test_trace_int64 + ! + ! outer product + ! + call test_outer_product_rsp + call test_outer_product_rdp + call test_outer_product_rqp + + call test_outer_product_csp + call test_outer_product_cdp + call test_outer_product_cqp + + call test_outer_product_int8 + call test_outer_product_int16 + call test_outer_product_int32 + call test_outer_product_int64 + contains @@ -75,7 +91,7 @@ subroutine test_eye cye = eye(7) call check(abs(trace(cye) - cmplx(7.0_sp,0.0_sp,kind=sp)) < sptol, & msg="abs(trace(cye) - cmplx(7.0_sp,0.0_sp,kind=sp)) < sptol failed.",warn=warn) - end subroutine + end subroutine test_eye subroutine test_diag_rsp integer, parameter :: n = 3 @@ -90,7 +106,7 @@ subroutine test_diag_rsp call check(all(diag(3*a) == 3*v), & msg="all(diag(3*a) == 3*v) failed.",warn=warn) - end subroutine + end subroutine test_diag_rsp subroutine test_diag_rsp_k integer, parameter :: n = 4 @@ -118,7 +134,7 @@ subroutine test_diag_rsp_k end do call check(size(diag(a,n+1)) == 0, & msg="size(diag(a,n+1)) == 0 failed.",warn=warn) - end subroutine + end subroutine test_diag_rsp_k subroutine test_diag_rdp integer, parameter :: n = 3 @@ -133,7 +149,7 @@ subroutine test_diag_rdp call check(all(diag(3*a) == 3*v), & msg="all(diag(3*a) == 3*v) failed.",warn=warn) - end subroutine + end subroutine test_diag_rdp subroutine test_diag_rqp integer, parameter :: n = 3 @@ -148,7 +164,7 @@ subroutine test_diag_rqp call check(all(diag(3*a) == 3*v), & msg="all(diag(3*a) == 3*v) failed.", warn=warn) - end subroutine + end subroutine test_diag_rqp subroutine test_diag_csp integer, parameter :: n = 3 @@ -165,7 +181,7 @@ subroutine test_diag_csp msg="all(abs(real(diag(a)) - [(i,i=1,n)]) < sptol)", warn=warn) call check(all(abs(aimag(diag(a)) - [(1,i=1,n)]) < sptol), & msg="all(abs(aimag(diag(a)) - [(1,i=1,n)]) < sptol)", warn=warn) - end subroutine + end subroutine test_diag_csp subroutine test_diag_cdp integer, parameter :: n = 3 @@ -175,7 +191,7 @@ subroutine test_diag_cdp a = diag([i_],-2) + diag([i_],2) call check(a(3,1) == i_ .and. a(1,3) == i_, & msg="a(3,1) == i_ .and. a(1,3) == i_ failed.",warn=warn) - end subroutine + end subroutine test_diag_cdp subroutine test_diag_cqp integer, parameter :: n = 3 @@ -185,7 +201,7 @@ subroutine test_diag_cqp a = diag([i_,i_],-1) + diag([i_,i_],1) call check(all(diag(a,-1) == i_) .and. all(diag(a,1) == i_), & msg="all(diag(a,-1) == i_) .and. all(diag(a,1) == i_) failed.",warn=warn) - end subroutine + end subroutine test_diag_cqp subroutine test_diag_int8 integer, parameter :: n = 3 @@ -199,7 +215,7 @@ subroutine test_diag_int8 msg="all(diag(a) == pack(a,mask)) failed.", warn=warn) call check(all(diag(diag(a)) == merge(a,0_int8,mask)), & msg="all(diag(diag(a)) == merge(a,0_int8,mask)) failed.", warn=warn) - end subroutine + end subroutine test_diag_int8 subroutine test_diag_int16 integer, parameter :: n = 4 integer(int16), allocatable :: a(:,:) @@ -212,7 +228,7 @@ subroutine test_diag_int16 msg="all(diag(a) == pack(a,mask))", warn=warn) call check(all(diag(diag(a)) == merge(a,0_int16,mask)), & msg="all(diag(diag(a)) == merge(a,0_int16,mask)) failed.", warn=warn) - end subroutine + end subroutine test_diag_int16 subroutine test_diag_int32 integer, parameter :: n = 3 integer(int32) :: a(n,n) @@ -226,7 +242,7 @@ subroutine test_diag_int32 msg="all(diag([1,1],-1) == a) failed.", warn=warn) call check(all(diag([1,1],1) == transpose(a)), & msg="all(diag([1,1],1) == transpose(a)) failed.", warn=warn) - end subroutine + end subroutine test_diag_int32 subroutine test_diag_int64 integer, parameter :: n = 4 integer(int64) :: a(n,n), c(0:2*n-1) @@ -257,7 +273,7 @@ subroutine test_diag_int64 end do call check(all(diag(a,-2) == diag(a,2)), & msg="all(diag(a,-2) == diag(a,2))", warn=warn) - end subroutine + end subroutine test_diag_int64 @@ -270,7 +286,7 @@ subroutine test_trace_rsp a = reshape([(i,i=1,n**2)],[n,n]) call check(abs(trace(a) - sum(diag(a))) < sptol, & msg="abs(trace(a) - sum(diag(a))) < sptol failed.",warn=warn) - end subroutine + end subroutine test_trace_rsp subroutine test_trace_rsp_nonsquare integer, parameter :: n = 4 @@ -287,7 +303,7 @@ subroutine test_trace_rsp_nonsquare call check(abs(trace(a) - ans) < sptol, & msg="abs(trace(a) - ans) < sptol failed.",warn=warn) - end subroutine + end subroutine test_trace_rsp_nonsquare subroutine test_trace_rdp integer, parameter :: n = 4 @@ -297,7 +313,7 @@ subroutine test_trace_rdp a = reshape([(i,i=1,n**2)],[n,n]) call check(abs(trace(a) - sum(diag(a))) < dptol, & msg="abs(trace(a) - sum(diag(a))) < dptol failed.",warn=warn) - end subroutine + end subroutine test_trace_rdp subroutine test_trace_rdp_nonsquare integer, parameter :: n = 4 @@ -314,7 +330,7 @@ subroutine test_trace_rdp_nonsquare call check(abs(trace(a) - ans) < dptol, & msg="abs(trace(a) - ans) < dptol failed.",warn=warn) - end subroutine + end subroutine test_trace_rdp_nonsquare subroutine test_trace_rqp integer, parameter :: n = 3 @@ -324,7 +340,7 @@ subroutine test_trace_rqp a = reshape([(i,i=1,n**2)],[n,n]) call check(abs(trace(a) - sum(diag(a))) < qptol, & msg="abs(trace(a) - sum(diag(a))) < qptol failed.",warn=warn) - end subroutine + end subroutine test_trace_rqp subroutine test_trace_csp @@ -345,7 +361,7 @@ subroutine test_trace_csp ! tr(A + B) = tr(A) + tr(B) call check(abs(trace(a+b) - (trace(a) + trace(b))) < sptol, & msg="abs(trace(a+b) - (trace(a) + trace(b))) < sptol failed.",warn=warn) - end subroutine + end subroutine test_trace_csp subroutine test_trace_cdp integer, parameter :: n = 3 @@ -359,7 +375,7 @@ subroutine test_trace_cdp call check(abs(trace(a) - ans) < dptol, & msg="abs(trace(a) - ans) < dptol failed.",warn=warn) - end subroutine + end subroutine test_trace_cdp subroutine test_trace_cqp integer, parameter :: n = 3 @@ -369,7 +385,7 @@ subroutine test_trace_cqp a = 3*eye(n) + 4*eye(n)*i_ ! pythagorean triple call check(abs(trace(a)) - 3*5.0_qp < qptol, & msg="abs(trace(a)) - 3*5.0_qp < qptol failed.",warn=warn) - end subroutine + end subroutine test_trace_cqp subroutine test_trace_int8 @@ -380,7 +396,7 @@ subroutine test_trace_int8 a = reshape([(i**2,i=1,n**2)],[n,n]) call check(trace(a) == (1 + 25 + 81), & msg="trace(a) == (1 + 25 + 81) failed.",warn=warn) - end subroutine + end subroutine test_trace_int8 subroutine test_trace_int16 integer, parameter :: n = 3 @@ -390,7 +406,7 @@ subroutine test_trace_int16 a = reshape([(i**3,i=1,n**2)],[n,n]) call check(trace(a) == (1 + 125 + 729), & msg="trace(a) == (1 + 125 + 729) failed.",warn=warn) - end subroutine + end subroutine test_trace_int16 subroutine test_trace_int32 integer, parameter :: n = 3 @@ -400,7 +416,7 @@ subroutine test_trace_int32 a = reshape([(i**4,i=1,n**2)],[n,n]) call check(trace(a) == (1 + 625 + 6561), & msg="trace(a) == (1 + 625 + 6561) failed.",warn=warn) - end subroutine + end subroutine test_trace_int32 subroutine test_trace_int64 integer, parameter :: n = 5 @@ -424,7 +440,129 @@ subroutine test_trace_int64 call check(trace(h) == sum(c(0:nd:2)), & msg="trace(h) == sum(c(0:nd:2)) failed.",warn=warn) - end subroutine + end subroutine test_trace_int64 + + + subroutine test_outer_product_rsp + integer, parameter :: n = 2 + real(sp) :: u(n), v(n), expected(n,n), diff(n,n) + write(*,*) "test_outer_product_rsp" + u = [1.,2.] + v = [1.,3.] + expected = reshape([1.,2.,3.,6.],[n,n]) + diff = expected - outer_product(u,v) + call check(all(abs(diff) < sptol), & + msg="all(abs(diff) < sptol) failed.",warn=warn) + end subroutine test_outer_product_rsp + + subroutine test_outer_product_rdp + integer, parameter :: n = 2 + real(dp) :: u(n), v(n), expected(n,n), diff(n,n) + write(*,*) "test_outer_product_rdp" + u = [1.,2.] + v = [1.,3.] + expected = reshape([1.,2.,3.,6.],[n,n]) + diff = expected - outer_product(u,v) + call check(all(abs(diff) < dptol), & + msg="all(abs(diff) < dptol) failed.",warn=warn) + end subroutine test_outer_product_rdp + + subroutine test_outer_product_rqp + integer, parameter :: n = 2 + real(qp) :: u(n), v(n), expected(n,n), diff(n,n) + write(*,*) "test_outer_product_rqp" + u = [1.,2.] + v = [1.,3.] + expected = reshape([1.,2.,3.,6.],[n,n]) + diff = expected - outer_product(u,v) + call check(all(abs(diff) < qptol), & + msg="all(abs(diff) < qptol) failed.",warn=warn) + end subroutine test_outer_product_rqp + + subroutine test_outer_product_csp + integer, parameter :: n = 2 + complex(sp) :: u(n), v(n), expected(n,n), diff(n,n) + write(*,*) "test_outer_product_csp" + u = [cmplx(1.,1.),cmplx(2.,0.)] + v = [cmplx(1.,0.),cmplx(3.,1.)] + expected = reshape([cmplx(1.,1.),cmplx(2.,0.),cmplx(2.,4.),cmplx(6.,2.)],[n,n]) + diff = expected - outer_product(u,v) + call check(all(abs(diff) < sptol), & + msg="all(abs(diff) < sptol) failed.",warn=warn) + end subroutine test_outer_product_csp + + subroutine test_outer_product_cdp + integer, parameter :: n = 2 + complex(dp) :: u(n), v(n), expected(n,n), diff(n,n) + write(*,*) "test_outer_product_cdp" + u = [cmplx(1.,1.),cmplx(2.,0.)] + v = [cmplx(1.,0.),cmplx(3.,1.)] + expected = reshape([cmplx(1.,1.),cmplx(2.,0.),cmplx(2.,4.),cmplx(6.,2.)],[n,n]) + diff = expected - outer_product(u,v) + call check(all(abs(diff) < dptol), & + msg="all(abs(diff) < dptol) failed.",warn=warn) + end subroutine test_outer_product_cdp + + subroutine test_outer_product_cqp + integer, parameter :: n = 2 + complex(qp) :: u(n), v(n), expected(n,n), diff(n,n) + write(*,*) "test_outer_product_cqp" + u = [cmplx(1.,1.),cmplx(2.,0.)] + v = [cmplx(1.,0.),cmplx(3.,1.)] + expected = reshape([cmplx(1.,1.),cmplx(2.,0.),cmplx(2.,4.),cmplx(6.,2.)],[n,n]) + diff = expected - outer_product(u,v) + call check(all(abs(diff) < qptol), & + msg="all(abs(diff) < qptol) failed.",warn=warn) + end subroutine test_outer_product_cqp + + subroutine test_outer_product_int8 + integer, parameter :: n = 2 + integer(int8) :: u(n), v(n), expected(n,n), diff(n,n) + write(*,*) "test_outer_product_int8" + u = [1,2] + v = [1,3] + expected = reshape([1,2,3,6],[n,n]) + diff = expected - outer_product(u,v) + call check(all(abs(diff) == 0), & + msg="all(abs(diff) == 0) failed.",warn=warn) + end subroutine test_outer_product_int8 + + subroutine test_outer_product_int16 + integer, parameter :: n = 2 + integer(int16) :: u(n), v(n), expected(n,n), diff(n,n) + write(*,*) "test_outer_product_int16" + u = [1,2] + v = [1,3] + expected = reshape([1,2,3,6],[n,n]) + diff = expected - outer_product(u,v) + call check(all(abs(diff) == 0), & + msg="all(abs(diff) == 0) failed.",warn=warn) + end subroutine test_outer_product_int16 + + subroutine test_outer_product_int32 + integer, parameter :: n = 2 + integer(int32) :: u(n), v(n), expected(n,n), diff(n,n) + write(*,*) "test_outer_product_int32" + u = [1,2] + v = [1,3] + expected = reshape([1,2,3,6],[n,n]) + diff = expected - outer_product(u,v) + call check(all(abs(diff) == 0), & + msg="all(abs(diff) == 0) failed.",warn=warn) + end subroutine test_outer_product_int32 + + subroutine test_outer_product_int64 + integer, parameter :: n = 2 + integer(int64) :: u(n), v(n), expected(n,n), diff(n,n) + write(*,*) "test_outer_product_int64" + u = [1,2] + v = [1,3] + expected = reshape([1,2,3,6],[n,n]) + diff = expected - outer_product(u,v) + call check(all(abs(diff) == 0), & + msg="all(abs(diff) == 0) failed.",warn=warn) + end subroutine test_outer_product_int64 + pure recursive function catalan_number(n) result(value) integer, intent(in) :: n