Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Removes all fortran from SHOC #3121

Closed
wants to merge 26 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
529 changes: 0 additions & 529 deletions components/eam/src/physics/cam/shoc.F90

Large diffs are not rendered by default.

227 changes: 215 additions & 12 deletions components/eamxx/scripts/gen_boiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,21 +138,21 @@
)),

("cxx_f2c_bind_decl" , (
lambda phys, sub, gb: f"{phys}_functions_f90.hpp",
lambda phys, sub, gb: f"tests/infra/{phys}_test_data.hpp",
lambda phys, sub, gb: expect_exists(phys, sub, gb, "cxx_f2c_bind_decl"),
lambda phys, sub, gb: get_cxx_close_block_regex(comment="end _f function decls"), # reqs special comment
lambda phys, sub, gb: get_cxx_function_begin_regex(sub + "_f"), # cxx_f decl
lambda phys, sub, gb: get_plain_comment_regex(comment="end _host function decls"), # reqs special comment
lambda phys, sub, gb: get_cxx_function_begin_regex(sub + "_host"), # cxx_host decl
lambda phys, sub, gb: re.compile(r".*;\s*$"), # ;
lambda *x : "The f90 to cxx function declaration(<name>_f)"
lambda *x : "The f90 to cxx function declaration(<name>_host)"
)),

("cxx_f2c_bind_impl" , (
lambda phys, sub, gb: f"{phys}_functions_f90.cpp",
lambda phys, sub, gb: f"tests/infra/{phys}_test_data.cpp",
lambda phys, sub, gb: expect_exists(phys, sub, gb, "cxx_f2c_bind_impl"),
lambda phys, sub, gb: get_namespace_close_regex(phys), # insert at end of namespace
lambda phys, sub, gb: get_cxx_function_begin_regex(sub + "_f"), # cxx_f
lambda phys, sub, gb: get_cxx_function_begin_regex(sub + "_host"), # cxx_f
lambda phys, sub, gb: get_cxx_close_block_regex(at_line_start=True), # terminating }
lambda *x : "The f90 to cxx function implementation(<name>_f)"
lambda *x : "The f90 to cxx function implementation(<name>_host)"
)),

("cxx_func_decl", (
Expand Down Expand Up @@ -455,6 +455,12 @@ def get_cxx_struct_begin_regex(struct):
struct_regex_str = fr"^\s*struct\s+{struct}([\W]|$)"
return re.compile(struct_regex_str)

###############################################################################
def get_plain_comment_regex(comment):
###############################################################################
    """
    Build a compiled regex that recognizes a C++ line comment carrying the
    given text.

    The pattern matches, from the start of the line: optional whitespace,
    '//', optional whitespace, then the comment text. It is not anchored at
    the end of the line, so trailing text after the comment is accepted.

    comment: the literal comment text to look for. It is escaped before being
             placed in the pattern so that characters such as '(', '.', or '+'
             are matched verbatim instead of being treated as regex syntax.
    """
    comment_regex_str = fr"^\s*//\s*{re.escape(comment)}"
    return re.compile(comment_regex_str)

###############################################################################
def get_data_struct_name(sub):
###############################################################################
Expand Down Expand Up @@ -1169,6 +1175,21 @@ def split_by_type(arg_data):

return reals, ints, logicals

###############################################################################
def split_by_scalar_vs_view(arg_data):
###############################################################################
    """
    Partition argument names by whether the argument has dims.

    Each entry of arg_data is unpacked as a 4-item record whose first item is
    the name and whose last item is the dims. Entries whose dims is None are
    reported as scalars; every other entry is reported as a non-scalar.

    Returns a pair of name lists: (scalars, non_scalars), each in the order
    the names appear in arg_data.
    """
    scalar_names = []
    view_names = []
    for arg_name, _argtype, _intent, arg_dims in arg_data:
        # Pick the destination list first, then append once.
        bucket = scalar_names if arg_dims is None else view_names
        bucket.append(arg_name)

    return scalar_names, view_names

###############################################################################
def gen_cxx_data_args(physics, arg_data):
###############################################################################
Expand Down Expand Up @@ -1441,6 +1462,30 @@ def check_existing_piece(lines, begin_regex, end_regex):

return None if begin_idx is None else (begin_idx, end_idx+1)

###############################################################################
def get_data_by_name(arg_data, arg_name, data_idx):
###############################################################################
    """
    Look up one field of the first arg_data record whose name is arg_name.

    Each record is unpacked as four items with the name first; data_idx
    selects which of those four items to return. If no record has the
    requested name, expect() is called with a failing condition and a
    "not found" message.
    """
    for record in arg_data:
        name, second, third, fourth = record
        if name != arg_name:
            continue

        return [name, second, third, fourth][data_idx]

    expect(False, f"Name {arg_name} not found")

###############################################################################
def get_rank_map(arg_data, arg_names):
###############################################################################
    """
    Group argument names by rank, i.e. by the length of their dims.

    For every name in arg_names, the dims are fetched from arg_data via
    get_data_by_name(..., ARG_DIMS). Returns a dict mapping rank -> [names],
    with names kept in the order they appear in arg_names.
    """
    names_by_rank = {}
    for arg_name in arg_names:
        arg_dims = get_data_by_name(arg_data, arg_name, ARG_DIMS)
        # setdefault creates the list the first time a rank is seen
        names_by_rank.setdefault(len(arg_dims), []).append(arg_name)

    return names_by_rank

#
# Main classes
#
Expand Down Expand Up @@ -1505,10 +1550,10 @@ def _get_db(self, phys):
db = parse_origin(origin_file.open(encoding="utf-8").read(), self._subs)
self._db[phys].update(db)
if self._verbose:
print("For physics {}, found:")
print(f"For physics {phys}, found:")
for sub in self._subs:
if sub in db:
print(" For subroutine {}, found args:")
print(f" For subroutine {sub}, found args:")
for name, argtype, intent, dims in db[sub]:
print(" name:{} type:{} intent:{} dims:({})".\
format(name, argtype, intent, ",".join(dims) if dims else "scalar"))
Expand Down Expand Up @@ -1729,7 +1774,7 @@ def gen_cxx_f2c_bind_decl(self, phys, sub, force_arg_data=None):
arg_data = force_arg_data if force_arg_data else self._get_arg_data(phys, sub)
arg_decls = gen_arg_cxx_decls(arg_data)

return f"void {sub}_f({', '.join(arg_decls)});"
return f"void {sub}_host({', '.join(arg_decls)});"

###########################################################################
def gen_cxx_f2c_bind_impl(self, phys, sub, force_arg_data=None):
Expand Down Expand Up @@ -1809,8 +1854,166 @@ def gen_cxx_f2c_bind_impl(self, phys, sub, force_arg_data=None):

impl = ""
if has_arrays(arg_data):
# TODO
impl += " // TODO"
#
# Steps:
# 1) Set up typedefs
# 2) Sync to device
# 3) Unpack view array
# 4) Get nk_pack and policy
# 5) Get subviews
# 6) Call fn
# 7) Sync back to host
#
inputs, inouts, outputs = split_by_intent(arg_data)
reals, ints, bools = split_by_type(arg_data)
scalars, views = split_by_scalar_vs_view(arg_data)
all_inputs = inputs + inouts
all_outputs = inouts + outputs

vreals = list(sorted(set(reals) & set(views)))
vints = list(sorted(set(ints) & set(views)))
vbools = list(sorted(set(bools) & set(views)))

sreals = list(sorted(set(reals) & set(scalars)))
sints = list(sorted(set(ints) & set(scalars)))
sbools = list(sorted(set(bools) & set(scalars)))

ivreals = list(sorted(set(vreals) & set(all_inputs)))
ivints = list(sorted(set(vints) & set(all_inputs)))
ivbools = list(sorted(set(vbools) & set(all_inputs)))

ovreals = list(sorted(set(vreals) & set(all_outputs)))
ovints = list(sorted(set(vints) & set(all_outputs)))
ovbools = list(sorted(set(vbools) & set(all_outputs)))

isreals = list(sorted(set(sreals) & set(all_inputs)))
isints = list(sorted(set(sints) & set(all_inputs)))
isbools = list(sorted(set(sbools) & set(all_inputs)))

osreals = list(sorted(set(sreals) & set(all_outputs)))
osints = list(sorted(set(sints) & set(all_outputs)))
osbools = list(sorted(set(sbools) & set(all_outputs)))

#
# 1) Set up typedefs
#

# set up basics
impl += "#if 0\n" # There's no way to guarantee this code compiles
impl += " using SHF = Functions<Real, DefaultDevice>;\n"
impl += " using Scalar = typename SHF::Scalar;\n"
impl += " using Spack = typename SHF::Spack;\n"
impl += " using KT = typename SHF::KT;\n"
impl += " using ExeSpace = typename KT::ExeSpace;\n"
impl += " using MemberType = typename SHF::MemberType;\n\n"

prefix_list = ["", "i", "b"]
type_list = ["Real", "Int", "bool"]
ktype_list = ["Spack", "Int", "bool"]

# make necessary view types. Anything that's an array needs a view type
for view_group, prefix_char, typename in zip([vreals, vints, vbools], prefix_list, type_list):
if view_group:
rank_map = get_rank_map(arg_data, view_group)
for rank in rank_map:
if typename == "Real" and rank > 1:
# Probably this should be packed data
impl += f" using {prefix_char}view_{rank}d = typename SHF::view_{rank}d<Spack>;\n"
else:
impl += f" using {prefix_char}view_{rank}d = typename SHF::view_{rank}d<{typename}>;\n"

impl += "\n"

#
# 2) Sync to device. Do ALL views, not just inputs
#

for input_group, prefix_char, typename in zip([vreals, vints, vbools], prefix_list, type_list):
if input_group:
rank_map = get_rank_map(arg_data, input_group)

for rank, arg_list in rank_map.items():
impl += f" static constexpr Int {prefix_char}num_arrays_{rank} = {len(arg_list)};\n"
impl += f" std::vector<{prefix_char}view_{rank}d> {prefix_char}temp_d_{rank}({prefix_char}num_arrays_{rank});\n"
for rank_itr in range(rank):
dims = [get_data_by_name(arg_data, arg_name, ARG_DIMS)[rank_itr] for arg_name in arg_list]
impl += f" std::vector<int> {prefix_char}dim_{rank}_{rank_itr}_sizes = {{{', '.join(dims)}}};\n"
dim_vectors = [f"{prefix_char}dim_{rank}_{rank_itr}_sizes" for rank_itr in range(rank)]
funcname = "ekat::host_to_device" if (typename == "Real" and rank > 1) else "ScreamDeepCopy::copy_to_device"
impl += f" {funcname}({{{', '.join(arg_list)}}}, {', '.join(dim_vectors)}, {prefix_char}temp_d_{rank});\n\n"

#
# 3) Unpack view array
#

for input_group, prefix_char, typename in zip([vreals, vints, vbools], prefix_list, type_list):
if input_group:
rank_map = get_rank_map(arg_data, input_group)

for rank, arg_list in rank_map.items():
impl += f" {prefix_char}view_{rank}d\n"
for idx, input_item in enumerate(arg_list):
impl += f" {input_item}_d({prefix_char}temp_d_{rank}[{idx}]){';' if idx == len(arg_list) - 1 else ','}\n"
impl += "\n"


#
# 4) Get nk_pack and policy, launch kernel
#
impl += " const Int nk_pack = ekat::npack<Spack>(nlev);\n"
impl += " const auto policy = ekat::ExeSpaceUtils<ExeSpace>::get_default_team_policy(shcol, nk_pack);\n"
impl += " Kokkos::parallel_for(policy, KOKKOS_LAMBDA(const MemberType& team) {\n"
impl += " const Int i = team.league_rank();\n\n"

#
# 5) Get subviews
#
for view_group, prefix_char, typename in zip([vreals, vints, vbools], prefix_list, type_list):
if view_group:
for view_arg in view_group:
dims = get_data_by_name(arg_data, view_arg, ARG_DIMS)
if "shcol" in dims:
if len(dims) == 1:
impl += f" const Scalar {view_arg}_s = {view_arg}_d(i);\n"
else:
impl += f" const auto {view_arg}_s = ekat::subview({view_arg}_d, i);\n"

impl += "\n"

#
# 6) Call fn
#
kernel_arg_names = []
for arg_name in arg_names:
if arg_name in views:
if "shcol" in dims:
kernel_arg_names.append(f"{arg_name}_s")
else:
kernel_arg_names.append(f"{arg_name}_d")
else:
kernel_arg_names.append(arg_name)

impl += f" SHF::{sub}({', '.join(kernel_arg_names)});\n"
impl += " });\n"

#
# 7) Sync back to host
#
for output_group, prefix_char, typename in zip([ovreals, ovints, ovbools], prefix_list, type_list):
if output_group:
rank_map = get_rank_map(arg_data, output_group)

for rank, arg_list in rank_map.items():
impl += f" std::vector<{prefix_char}view_{rank}d> {prefix_char}tempout_d_{rank}({prefix_char}num_arrays_{rank});\n"
for rank_itr in range(rank):
dims = [get_data_by_name(arg_data, arg_name, ARG_DIMS)[rank_itr] for arg_name in arg_list]
impl += f" std::vector<int> {prefix_char}dim_{rank}_{rank_itr}_out_sizes = {{{', '.join(dims)}}};\n"
dim_vectors = [f"{prefix_char}dim_{rank}_{rank_itr}_out_sizes" for rank_itr in range(rank)]
funcname = "ekat::device_to_host" if (typename == "Real" and rank > 1) else "ScreamDeepCopy::copy_to_host"
impl += f" {funcname}({{{', '.join(arg_list)}}}, {', '.join(dim_vectors)}, {prefix_char}tempout_d_{rank});\n\n"

impl += "#endif\n"

else:
inputs, inouts, outputs = split_by_intent(arg_data)
reals, ints, logicals = split_by_type(arg_data)
Expand Down
2 changes: 1 addition & 1 deletion components/eamxx/src/physics/p3/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ if (NOT SCREAM_P3_SMALL_KERNELS AND NOT SCREAM_ONLY_GENERATE_BASELINES)
CreateUnitTest(p3_sk_tests "p3_main_unit_tests.cpp"
LIBS p3_sk p3_test_infra
EXE_ARGS "--args ${BASELINE_FILE_ARG}"
THREADS 1 ${SCREAM_TEST_MAX_THREADS} ${SCREAM_TEST_THREAD_INC}
THREADS ${P3_THREADS}
LABELS "p3_sk;physics;baseline_cmp")
endif()

Expand Down
4 changes: 2 additions & 2 deletions components/eamxx/src/physics/p3/tests/infra/p3_test_data.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1299,7 +1299,7 @@ Int p3_main_host(
}
}

ekat::host_to_device(ptr_array, dim1_sizes, dim2_sizes, temp_d, true);
ekat::host_to_device(ptr_array, dim1_sizes, dim2_sizes, temp_d);

int counter = 0;
view_2d
Expand Down Expand Up @@ -1452,7 +1452,7 @@ Int p3_main_host(
rho_qi, qv2qi_depos_tend,
liq_ice_exchange, vap_liq_exchange, vap_ice_exchange, precip_liq_flux, precip_ice_flux, precip_liq_surf, precip_ice_surf
},
dim1_sizes_out, dim2_sizes_out, inout_views, true);
dim1_sizes_out, dim2_sizes_out, inout_views);

return elapsed_microsec;
}
Expand Down
2 changes: 0 additions & 2 deletions components/eamxx/src/physics/p3/tests/p3_main_unit_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -445,15 +445,13 @@ void run_bfb_p3_main()

// Get data from cxx
for (auto& d : isds_cxx) {
d.template transpose<ekat::TransposeDirection::c2f>();
p3_main_host(
d.qc, d.nc, d.qr, d.nr, d.th_atm, d.qv, d.dt, d.qi, d.qm, d.ni,
d.bm, d.pres, d.dz, d.nc_nuceat_tend, d.nccn_prescribed, d.ni_activated, d.inv_qc_relvar, d.it, d.precip_liq_surf,
d.precip_ice_surf, d.its, d.ite, d.kts, d.kte, d.diag_eff_radius_qc, d.diag_eff_radius_qi, d.diag_eff_radius_qr,
d.rho_qi, d.do_predict_nc, d.do_prescribed_CCN, d.dpres, d.inv_exner, d.qv2qi_depos_tend,
d.precip_liq_flux, d.precip_ice_flux, d.cld_frac_r, d.cld_frac_l, d.cld_frac_i,
d.liq_ice_exchange, d.vap_liq_exchange, d.vap_ice_exchange, d.qv_prev, d.t_prev);
d.template transpose<ekat::TransposeDirection::f2c>();
}

if (SCREAM_BFB_TESTING && this->m_baseline_action == COMPARE) {
Expand Down
12 changes: 0 additions & 12 deletions components/eamxx/src/physics/shoc/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,19 +1,7 @@
set(SHOC_SRCS
shoc_f90.cpp
shoc_ic_cases.cpp
shoc_iso_c.f90
shoc_iso_f.f90
${SCREAM_BASE_DIR}/../eam/src/physics/cam/shoc.F90
eamxx_shoc_process_interface.cpp
)

if (NOT SCREAM_LIB_ONLY)
list(APPEND SHOC_SRCS
shoc_functions_f90.cpp
shoc_main_wrap.cpp
) # Add f90 bridges needed for testing
endif()

set(SHOC_HEADERS
shoc.hpp
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This file does not even exist. Let's prune this line.

eamxx_shoc_process_interface.hpp
Expand Down
Loading
Loading