Skip to content

Commit

Permalink
Use fortran types in API functions (#238)
Browse files Browse the repository at this point in the history
* use fortran types in API functions

* add fortran MICM solve function that takes c pointers

* update python wrapper and address reviewer comments

* fix fortran API test

* address reviewer comments
  • Loading branch information
mattldawson authored Nov 5, 2024
1 parent 5817b5a commit f7a69ef
Show file tree
Hide file tree
Showing 12 changed files with 779 additions and 231 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ option(MUSICA_BUILD_DOCS "Build the documentation" OFF)
option(MUSICA_ENABLE_MICM "Enable MICM" ON)
option(MUSICA_ENABLE_TUVX "Enable TUV-x" ON)

set(MUSICA_SET_MICM_VECTOR_MATRIX_SIZE "1" CACHE STRING "Set MICM vector-ordered matrix dimension")
set(MUSICA_SET_MICM_VECTOR_MATRIX_SIZE "4" CACHE STRING "Set MICM vector-ordered matrix dimension")

cmake_dependent_option(
MUSICA_ENABLE_PYTHON_LIBRARY "Adds pybind11, a lightweight header-only library that exposes C++ types in Python and vice versa" OFF "MUSICA_BUILD_C_CXX_INTERFACE" OFF)
Expand Down
166 changes: 129 additions & 37 deletions fortran/micm.F90
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ module musica_micm
implicit none

public :: micm_t, solver_stats_t, get_micm_version
public :: Rosenbrock, RosenbrockStandardOrder, BackwardEuler, BackwardEulerStandardOrder
public :: UndefinedSolver, Rosenbrock, RosenbrockStandardOrder, BackwardEuler, BackwardEulerStandardOrder
private

!> Wrapper for c solver stats
Expand All @@ -28,6 +28,7 @@ module musica_micm
! We could use Fortran 2023 enum type feature if Fortran 2023 is supported
! https://fortran-lang.discourse.group/t/enumerator-type-in-bind-c-derived-type-best-practice/5947/2
enum, bind(c)
enumerator :: UndefinedSolver = 0
enumerator :: Rosenbrock = 1
enumerator :: RosenbrockStandardOrder = 2
enumerator :: BackwardEuler = 3
Expand Down Expand Up @@ -135,10 +136,14 @@ end function get_user_defined_reaction_rates_ordering_c
type :: micm_t
type(mappings_t), pointer :: species_ordering => null()
type(mappings_t), pointer :: user_defined_reaction_rates => null()
type(c_ptr), private :: ptr = c_null_ptr
type(c_ptr), private :: ptr = c_null_ptr
integer, private :: number_of_grid_cells = 0
integer, private :: solver_type = UndefinedSolver
contains
! Solve the chemical system
procedure :: solve
procedure, private :: solve_arrays
procedure, private :: solve_c_ptrs
generic :: solve => solve_arrays, solve_c_ptrs
! Get species properties
procedure :: get_species_property_string
procedure :: get_species_property_double
Expand Down Expand Up @@ -191,9 +196,11 @@ function constructor(config_path, solver_type, num_grid_cells, error) result( t
use musica_util, only: error_t_c, error_t, copy_mappings
type(micm_t), pointer :: this
character(len=*), intent(in) :: config_path
integer(c_int), intent(in) :: solver_type
integer(c_int), intent(in) :: num_grid_cells
integer, intent(in) :: solver_type
integer, intent(in) :: num_grid_cells
type(error_t), intent(inout) :: error

! local variables
character(len=1, kind=c_char) :: c_config_path(len_trim(config_path)+1)
integer :: n, i
type(error_t_c) :: error_c
Expand All @@ -206,7 +213,10 @@ function constructor(config_path, solver_type, num_grid_cells, error) result( t
end do
c_config_path(n+1) = c_null_char

this%ptr = create_micm_c(c_config_path, solver_type, num_grid_cells, error_c)
this%number_of_grid_cells = num_grid_cells
this%solver_type = solver_type
this%ptr = create_micm_c( c_config_path, int(solver_type, kind=c_int), &
int(num_grid_cells, kind=c_int), error_c )
error = error_t(error_c)
if (.not. error%is_success()) then
deallocate(this)
Expand All @@ -233,41 +243,121 @@ function constructor(config_path, solver_type, num_grid_cells, error) result( t

end function constructor

subroutine solve(this, time_step, temperature, pressure, air_density, concentrations, &
user_defined_reaction_rates, solver_state, solver_stats, error)
!> Solves the chemical system
!!
!! This function accepts fortran arrays and checks their sizes
!! against the number of grid cells and the species/rate parameter ordering.
subroutine solve_arrays(this, time_step, temperature, pressure, air_density, &
concentrations, user_defined_reaction_rates, solver_state, solver_stats, error)
use iso_c_binding, only: c_loc
use iso_fortran_env, only: real64
use musica_util, only: string_t, string_t_c, error_t_c, error_t
class(micm_t), intent(in) :: this
real(real64), intent(in) :: time_step
real(real64), target, intent(in) :: temperature(:)
real(real64), target, intent(in) :: pressure(:)
real(real64), target, intent(in) :: air_density(:)
real(real64), target, intent(inout) :: concentrations(:,:)
real(real64), target, intent(in) :: user_defined_reaction_rates(:,:)
type(string_t), intent(out) :: solver_state
type(solver_stats_t), intent(out) :: solver_stats
type(error_t), intent(out) :: error

type(string_t_c) :: solver_state_c
type(solver_stats_t_c) :: solver_stats_c
type(error_t_c) :: error_c

if (size(temperature) .ne. this%number_of_grid_cells) then
error = error_t(1, "MICM_SOLVE", "Temperature array size does not match number of grid cells")
return
end if
if (size(pressure) .ne. this%number_of_grid_cells) then
error = error_t(1, "MICM_SOLVE", "Pressure array size does not match number of grid cells")
return
end if
if (size(air_density) .ne. this%number_of_grid_cells) then
error = error_t(1, "MICM_SOLVE", "Air density array size does not match number of grid cells")
return
end if
if (this%solver_type .eq. Rosenbrock .or. this%solver_type .eq. BackwardEuler) then
if (size(concentrations, 1) .ne. this%number_of_grid_cells) then
error = error_t(1, "MICM_SOLVE", "Concentrations array dimension 1 does not match number of grid cells")
return
end if
if (size(concentrations, 2) .ne. this%species_ordering%size()) then
error = error_t(1, "MICM_SOLVE", "Concentrations array dimension 2 does not match species ordering")
return
end if
if (size(user_defined_reaction_rates, 1) .ne. this%number_of_grid_cells) then
error = error_t(1, "MICM_SOLVE", "User defined reaction rates array dimension 1 does not match number of grid cells")
return
end if
if (size(user_defined_reaction_rates, 2) .ne. this%user_defined_reaction_rates%size()) then
error = error_t(1, "MICM_SOLVE", "User defined reaction rates array dimension 2 does not match user defined reaction rates ordering")
return
end if
else
if (size(concentrations, 1) .ne. this%species_ordering%size()) then
error = error_t(1, "MICM_SOLVE", "Concentrations array dimension 1 does not match species ordering")
return
end if
if (size(concentrations, 2) .ne. this%number_of_grid_cells) then
error = error_t(1, "MICM_SOLVE", "Concentrations array dimension 2 does not match number of grid cells")
return
end if
if (size(user_defined_reaction_rates, 1) .ne. this%user_defined_reaction_rates%size()) then
error = error_t(1, "MICM_SOLVE", "User defined reaction rates array dimension 1 does not match user defined reaction rates ordering")
return
end if
if (size(user_defined_reaction_rates, 2) .ne. this%number_of_grid_cells) then
error = error_t(1, "MICM_SOLVE", "User defined reaction rates array dimension 2 does not match number of grid cells")
return
end if
end if

call micm_solve_c(this%ptr, real(time_step, kind=c_double), c_loc(temperature), &
c_loc(pressure), c_loc(air_density), c_loc(concentrations), &
c_loc(user_defined_reaction_rates), &
solver_state_c, solver_stats_c, error_c)

solver_state = string_t(solver_state_c)
solver_stats = solver_stats_t(solver_stats_c)
error = error_t(error_c)

end subroutine solve_arrays

!> Solves the chemical system
!!
!! This function accepts c pointers and does not check their sizes.
!! The user is responsible for ensuring the sizes are correct.
subroutine solve_c_ptrs(this, time_step, temperature, pressure, air_density, &
concentrations, user_defined_reaction_rates, solver_state, solver_stats, error)
use iso_fortran_env, only: real64
use musica_util, only: string_t, string_t_c, error_t_c, error_t
class(micm_t) :: this
real(c_double), intent(in) :: time_step
real(c_double), target, intent(in) :: temperature(:)
real(c_double), target, intent(in) :: pressure(:)
real(c_double), target, intent(in) :: air_density(:)
real(c_double), target, intent(inout) :: concentrations(:)
real(c_double), target, intent(in) :: user_defined_reaction_rates(:)
type(string_t), intent(out) :: solver_state
type(solver_stats_t), intent(out) :: solver_stats
type(error_t), intent(out) :: error
class(micm_t), intent(in) :: this
real(real64), intent(in) :: time_step
type(c_ptr), intent(in) :: temperature
type(c_ptr), intent(in) :: pressure
type(c_ptr), intent(in) :: air_density
type(c_ptr), intent(in) :: concentrations
type(c_ptr), intent(in) :: user_defined_reaction_rates
type(string_t), intent(out) :: solver_state
type(solver_stats_t), intent(out) :: solver_stats
type(error_t), intent(out) :: error

type(string_t_c) :: solver_state_c
type(solver_stats_t_c) :: solver_stats_c
type(error_t_c) :: error_c
type(c_ptr) :: temperature_c, pressure_c, air_density_c, concentrations_c, &
user_defined_reaction_rates_c

temperature_c = c_loc(temperature)
pressure_c = c_loc(pressure)
air_density_c = c_loc(air_density)
concentrations_c = c_loc(concentrations)
user_defined_reaction_rates_c = c_loc(user_defined_reaction_rates)
call micm_solve_c(this%ptr, time_step, temperature_c, pressure_c, air_density_c, &
concentrations_c, user_defined_reaction_rates_c, solver_state_c, &
solver_stats_c, error_c)

call micm_solve_c(this%ptr, real(time_step, kind=c_double), temperature, pressure, &
air_density, concentrations, user_defined_reaction_rates, &
solver_state_c, solver_stats_c, error_c)

solver_state = string_t(solver_state_c)
solver_stats = solver_stats_t(solver_stats_c)
error = error_t(error_c)

end subroutine solve
end subroutine solve_c_ptrs

!> Constructor for solver_stats_t object that takes ownership of solver_stats_t_c
function solver_stats_t_constructor( c_solver_stats ) result( new_solver_stats )
Expand Down Expand Up @@ -359,8 +449,9 @@ end function solver_stats_t_solves

!> Get the final time the solver iterated to
function solver_stats_t_final_time( this ) result( final_time )
use iso_fortran_env, only: real64
class(solver_stats_t), intent(in) :: this
real :: final_time
real(real64) :: final_time

final_time = this%final_time_

Expand All @@ -382,15 +473,16 @@ function get_species_property_string(this, species_name, property_name, error) r
end function get_species_property_string

function get_species_property_double(this, species_name, property_name, error) result(value)
use iso_fortran_env, only: real64
use musica_util, only: error_t_c, error_t, to_c_string
class(micm_t) :: this
character(len=*), intent(in) :: species_name, property_name
type(error_t), intent(inout) :: error
real(c_double) :: value
real(real64) :: value

type(error_t_c) :: error_c
value = get_species_property_double_c(this%ptr, &
to_c_string(species_name), to_c_string(property_name), error_c)
value = real( get_species_property_double_c( this%ptr, to_c_string(species_name), &
to_c_string(property_name), error_c ), kind=real64 )
error = error_t(error_c)
end function get_species_property_double

Expand All @@ -399,11 +491,11 @@ function get_species_property_int(this, species_name, property_name, error) resu
class(micm_t) :: this
character(len=*), intent(in) :: species_name, property_name
type(error_t), intent(inout) :: error
integer(c_int) :: value
integer :: value

type(error_t_c) :: error_c
value = get_species_property_int_c(this%ptr, &
to_c_string(species_name), to_c_string(property_name), error_c)
value = int( get_species_property_int_c(this%ptr, &
to_c_string(species_name), to_c_string(property_name), error_c) )
error = error_t(error_c)
end function get_species_property_int

Expand Down
2 changes: 2 additions & 0 deletions fortran/test/fetch_content_integration/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ if (MUSICA_ENABLE_MICM)
LINKER_LANGUAGE Fortran
)

target_compile_definitions(test_micm_fortran_api PUBLIC MICM_VECTOR_MATRIX_SIZE=${MUSICA_SET_MICM_VECTOR_MATRIX_SIZE})

add_test(
NAME test_micm_fortran_api
COMMAND $<TARGET_FILE:test_micm_fortran_api>
Expand Down
Loading

0 comments on commit f7a69ef

Please sign in to comment.