Skip to content

Commit

Permalink
Port surface radiative flux override capability (#5)
Browse files Browse the repository at this point in the history
This PR ports the ability to override the surface radiative fluxes seen by the land surface model from the wrapper.  This was split across two PRs originally in the case of FV3GFS:

- ai2cm/fv3gfs-fortran#158
- ai2cm/fv3gfs-wrapper#244

This depends on the fortran changes made in:

- NOAA-GFDL/SHiELD_physics#31
- NOAA-GFDL/atmos_drivers#31

which have now been merged, and incorporated into this repo via #11.
  • Loading branch information
spencerkclark authored Nov 2, 2023
1 parent a20d668 commit 18bcafe
Show file tree
Hide file tree
Showing 11 changed files with 247 additions and 5 deletions.
6 changes: 6 additions & 0 deletions wrapper/fill_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@
"real": {"type_c": "c_double", "type_cython": "REAL_t"},
"logical": {"type_c": "c_int", "type_cython": "bint"},
}
OVERRIDES_FOR_SURFACE_RADIATIVE_FLUXES = [
"override_for_time_adjusted_total_sky_downward_longwave_flux_at_surface",
"override_for_time_adjusted_total_sky_downward_shortwave_flux_at_surface",
"override_for_time_adjusted_total_sky_net_shortwave_flux_at_surface",
]


def get_dim_range_string(dim_list):
Expand Down Expand Up @@ -98,6 +103,7 @@ def assign_types_to_flags(flag_data):
physics_3d_properties=physics_3d_properties,
dynamics_properties=dynamics_properties,
flagstruct_properties=flagstruct_properties,
overriding_fluxes=OVERRIDES_FOR_SURFACE_RADIATIVE_FLUXES,
)
with open(out_filename, "w") as f:
f.write(result)
6 changes: 6 additions & 0 deletions wrapper/shield/wrapper/_properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@
with open(os.path.join(DIR, "flagstruct_properties.json"), "r") as f:
FLAGSTRUCT_PROPERTIES = json.load(f)

OVERRIDES_FOR_SURFACE_RADIATIVE_FLUXES = [
"override_for_time_adjusted_total_sky_downward_longwave_flux_at_surface",
"override_for_time_adjusted_total_sky_downward_shortwave_flux_at_surface",
"override_for_time_adjusted_total_sky_net_shortwave_flux_at_surface",
]

DIM_NAMES = {
properties["name"]: properties["dims"]
for properties in DYNAMICS_PROPERTIES + PHYSICS_PROPERTIES
Expand Down
5 changes: 3 additions & 2 deletions wrapper/shield/wrapper/_restart/io.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from .._wrapper import get_tracer_metadata
from .._properties import (
DYNAMICS_PROPERTIES,
PHYSICS_PROPERTIES
OVERRIDES_FOR_SURFACE_RADIATIVE_FLUXES,
PHYSICS_PROPERTIES,
)

# these variables are found not to be needed for smooth restarts
Expand All @@ -10,7 +11,7 @@
"convective_cloud_fraction",
"convective_cloud_top_pressure",
"convective_cloud_bottom_pressure",
]
] + OVERRIDES_FOR_SURFACE_RADIATIVE_FLUXES


def get_restart_names():
Expand Down
6 changes: 6 additions & 0 deletions wrapper/shield/wrapper/flagstruct_properties.json
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,11 @@
"fortran_name" : "do_adiabatic_init",
"location" : "do_adiabatic_init",
"type_fortran": "logical"
},
{
"name": "override_surface_radiative_fluxes",
"fortran_name" : "override_surface_radiative_fluxes",
"location" : "IPD_Control",
"type_fortran": "logical"
}
]
21 changes: 21 additions & 0 deletions wrapper/shield/wrapper/physics_properties.json
Original file line number Diff line number Diff line change
Expand Up @@ -484,5 +484,26 @@
"description": "orographic metrics",
"container": "Sfcprop",
"dims": ["orographic_variable", "y", "x"]
},
{
"name": "override_for_time_adjusted_total_sky_downward_longwave_flux_at_surface",
"fortran_name": "adjsfcdlw_override",
"units": "W/m^2",
"container": "Overrides",
"dims": ["y", "x"]
},
{
"name": "override_for_time_adjusted_total_sky_downward_shortwave_flux_at_surface",
"fortran_name": "adjsfcdsw_override",
"units": "W/m^2",
"container": "Overrides",
"dims": ["y", "x"]
},
{
"name": "override_for_time_adjusted_total_sky_net_shortwave_flux_at_surface",
"fortran_name": "adjsfcnsw_override",
"units": "W/m^2",
"container": "Overrides",
"dims": ["y", "x"]
}
]
18 changes: 18 additions & 0 deletions wrapper/templates/_wrapper.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -231,8 +231,16 @@ cdef int set_2d_quantity(name, REAL_t[:, ::1] array) except -1:
{% endif %}
{% endfor %}
{% for item in physics_2d_properties %}
{% if item.name in overriding_fluxes %}
elif name == '{{ item.name }}':
if flags.override_surface_radiative_fluxes:
set_{{ item.fortran_name }}{% if "fortran_subname" in item %}_{{ item.fortran_subname }}{% endif %}(&array[0, 0])
else:
raise pace.util.InvalidQuantityError('Overriding surface fluxes can only be set if gfs_physics_nml.override_surface_radiative_fluxes is set to .true.')
{% else %}
elif name == '{{ item.name }}':
set_{{ item.fortran_name }}{% if "fortran_subname" in item %}_{{ item.fortran_subname }}{% endif %}(&array[0, 0])
{% endif %}
{% endfor %}
else:
raise ValueError(f'no setter available for {name}')
Expand Down Expand Up @@ -296,10 +304,20 @@ def get_state(names, dict state=None, allocator=None):
state['initialization_time'] = get_time(which='initialization_time')

{% for item in physics_2d_properties %}
{% if item.name in overriding_fluxes %}
if '{{ item.name }}' in input_names_set:
if flags.override_surface_radiative_fluxes:
quantity = _get_quantity(state, "{{ item.name }}", allocator, {{ item.dims | safe }}, "{{ item.units }}", dtype=real_type)
with pace.util.recv_buffer(quantity.np.empty, quantity.view[:]) as array_2d:
get_{{ item.fortran_name }}{% if "fortran_subname" in item %}_{{ item.fortran_subname }}{% endif %}(&array_2d[0, 0])
else:
raise pace.util.InvalidQuantityError('Overriding surface fluxes can only be accessed if gfs_physics_nml.override_surface_radiative_fluxes is set to .true.')
{% else %}
if '{{ item.name }}' in input_names_set:
quantity = _get_quantity(state, "{{ item.name }}", allocator, {{ item.dims | safe }}, "{{ item.units }}", dtype=real_type)
with pace.util.recv_buffer(quantity.np.empty, quantity.view[:]) as array_2d:
get_{{ item.fortran_name }}{% if "fortran_subname" in item %}_{{ item.fortran_subname }}{% endif %}(&array_2d[0, 0])
{% endif %}
{% endfor %}

{% for item in physics_3d_properties %}
Expand Down
8 changes: 7 additions & 1 deletion wrapper/tests/test_all_mpi_requiring.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@ def test_getters(self):
run_unittest_script("test_getters.py")

def test_setters_default(self):
run_unittest_script("test_setters.py")
run_unittest_script("test_setters.py", "false")

def test_setters_while_overriding_surface_radiative_fluxes(self):
run_unittest_script("test_setters.py", "true")

def test_tracer_metadata(self):
run_unittest_script("test_tracer_metadata.py")
Expand All @@ -37,6 +40,9 @@ def test_get_initialization_time(self):
def test_flags(self):
run_unittest_script("test_flags.py")

def test_overrides_for_surface_radiative_fluxes_modify_diagnostics(self):
run_unittest_script("test_overrides_for_surface_radiative_fluxes.py")

def test_set_ocean_surface_temperature(self):
run_unittest_script("test_set_ocean_surface_temperature.py")

Expand Down
5 changes: 5 additions & 0 deletions wrapper/tests/test_flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@ def test_dt_atmos(self):
expected = 900
self.assertEqual(result, expected)

def test_override_surface_radiative_fluxes(self):
"""Test that getting a boolean flag produces its expected result."""
result = shield.wrapper.flags.override_surface_radiative_fluxes
self.assertFalse(result)


if __name__ == "__main__":
config = get_default_config()
Expand Down
16 changes: 15 additions & 1 deletion wrapper/tests/test_getters.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pace.util
from shield.wrapper._properties import (
DYNAMICS_PROPERTIES,
OVERRIDES_FOR_SURFACE_RADIATIVE_FLUXES,
PHYSICS_PROPERTIES,
)
from mpi4py import MPI
Expand All @@ -13,14 +14,18 @@

test_dir = os.path.dirname(os.path.abspath(__file__))
MM_PER_M = 1000
DEFAULT_PHYSICS_PROPERTIES = []
for entry in PHYSICS_PROPERTIES:
if entry["name"] not in OVERRIDES_FOR_SURFACE_RADIATIVE_FLUXES:
DEFAULT_PHYSICS_PROPERTIES.append(entry)


class GetterTests(unittest.TestCase):
def __init__(self, *args, **kwargs):
super(GetterTests, self).__init__(*args, **kwargs)
self.tracer_data = shield.wrapper.get_tracer_metadata()
self.dynamics_data = generate_data_dict(DYNAMICS_PROPERTIES)
self.physics_data = generate_data_dict(PHYSICS_PROPERTIES)
self.physics_data = generate_data_dict(DEFAULT_PHYSICS_PROPERTIES)
self.mpi_comm = MPI.COMM_WORLD

def setUp(self):
Expand Down Expand Up @@ -153,6 +158,15 @@ def _get_names_helper(self, name_list):
self.assertIn(name, state)
self.assertEqual(len(name_list), len(state.keys()))

def _get_unallocated_name_helper(self, name):
with self.assertRaisesRegex(pace.util.InvalidQuantityError, "Overriding"):
shield.wrapper.get_state(names=[name])

def test_unallocated_physics_properties(self):
for name in OVERRIDES_FOR_SURFACE_RADIATIVE_FLUXES:
with self.subTest(name):
self._get_unallocated_name_helper(name)

class TracerMetadataTests(unittest.TestCase):
def test_tracer_index_is_one_based(self):
data = shield.wrapper.get_tracer_metadata()
Expand Down
114 changes: 114 additions & 0 deletions wrapper/tests/test_overrides_for_surface_radiative_fluxes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import unittest
import os
from copy import deepcopy
from shield.wrapper._properties import OVERRIDES_FOR_SURFACE_RADIATIVE_FLUXES
import numpy as np
import shield.wrapper
import pace.util
from mpi4py import MPI
from util import get_default_config, main


test_dir = os.path.dirname(os.path.abspath(__file__))
(
DOWNWARD_LONGWAVE,
DOWNWARD_SHORTWAVE,
NET_SHORTWAVE,
) = OVERRIDES_FOR_SURFACE_RADIATIVE_FLUXES


def override_surface_radiative_fluxes_with_random_values():
old_state = shield.wrapper.get_state(names=OVERRIDES_FOR_SURFACE_RADIATIVE_FLUXES)
replace_state = deepcopy(old_state)
for name, quantity in replace_state.items():
quantity.view[:] = np.random.uniform(size=quantity.extent)
shield.wrapper.set_state(replace_state)
return replace_state


def get_state_single_variable(name):
return shield.wrapper.get_state([name])[name].view[:]


class OverridingSurfaceRadiativeFluxTests(unittest.TestCase):
def __init__(self, *args, **kwargs):
super(OverridingSurfaceRadiativeFluxTests, self).__init__(*args, **kwargs)

def setUp(self):
pass

def tearDown(self):
MPI.COMM_WORLD.barrier()

def test_resetting_to_checkpoint_allows_for_exact_restart(self):
checkpoint_state = shield.wrapper.get_state(shield.wrapper.get_restart_names())
print(checkpoint_state["time"])

# Run the model forward a timestep and save the temperature.
shield.wrapper.step()
expected = get_state_single_variable("air_temperature")

# Restore state to original checkpoint; step the model forward again.
# Check that the temperature is identical as after the first time we
# took a step.
shield.wrapper.set_state(checkpoint_state)
shield.wrapper.step()
result = get_state_single_variable("air_temperature")
np.testing.assert_equal(result, expected)

def test_overriding_fluxes_changes_model_state(self):
checkpoint_state = shield.wrapper.get_state(shield.wrapper.get_restart_names())

shield.wrapper.step()
temperature_with_default_override = get_state_single_variable("air_temperature")

# Restore state to original checkpoint; modify the radiative fluxes;
# step the model again.
shield.wrapper.set_state(checkpoint_state)
override_surface_radiative_fluxes_with_random_values()
shield.wrapper.step()
temperature_with_random_override = get_state_single_variable("air_temperature")

# We expect these states to differ.
assert not np.array_equal(
temperature_with_default_override, temperature_with_random_override
)

def test_overriding_fluxes_are_propagated_to_diagnostics(self):
replace_state = override_surface_radiative_fluxes_with_random_values()

# We need to step the model to fill the diagnostics buckets.
shield.wrapper.step()

timestep = shield.wrapper.flags.dt_atmos
expected_DSWRFI = replace_state[DOWNWARD_SHORTWAVE].view[:]
expected_DLWRFI = replace_state[DOWNWARD_LONGWAVE].view[:]
expected_USWRFI = (
replace_state[DOWNWARD_SHORTWAVE].view[:]
- replace_state[NET_SHORTWAVE].view[:]
)

result_DSWRF = shield.wrapper.get_diagnostic_by_name("DSWRF").view[:]
result_DLWRF = shield.wrapper.get_diagnostic_by_name("DLWRF").view[:]
result_USWRF = shield.wrapper.get_diagnostic_by_name("USWRF").view[:]
result_DSWRFI = shield.wrapper.get_diagnostic_by_name("DSWRFI").view[:]
result_DLWRFI = shield.wrapper.get_diagnostic_by_name("DLWRFI").view[:]
result_USWRFI = shield.wrapper.get_diagnostic_by_name("USWRFI").view[:]

np.testing.assert_allclose(result_DSWRF, timestep * expected_DSWRFI)
np.testing.assert_allclose(result_DLWRF, timestep * expected_DLWRFI)
np.testing.assert_allclose(result_USWRF, timestep * expected_USWRFI)
np.testing.assert_allclose(result_DSWRFI, expected_DSWRFI)
np.testing.assert_allclose(result_DLWRFI, expected_DLWRFI)
np.testing.assert_allclose(result_USWRFI, expected_USWRFI)


if __name__ == "__main__":
config = get_default_config()
config["namelist"]["gfs_physics_nml"]["override_surface_radiative_fluxes"] = True

# Clear diag_table for these tests, since outputting interval-averaged
# physics diagnostics leads to unrelated errors; see
# ai2cm/fv3gfs-fortran#384 for more context.
config["diag_table"] = "no_output"
main(test_dir, config)
Loading

0 comments on commit 18bcafe

Please sign in to comment.