diff --git a/tests/test_temporal.py b/tests/test_temporal.py index e5489b1b..7951d3f3 100644 --- a/tests/test_temporal.py +++ b/tests/test_temporal.py @@ -1,4 +1,5 @@ import logging +import warnings import cftime import numpy as np @@ -209,6 +210,7 @@ def test_averages_for_monthly_time_series(self): ) xr.testing.assert_allclose(result, expected) + assert result.ts.attrs == expected.ts.attrs # Test unweighted averages result = ds.temporal.average("ts", weighted=False) @@ -226,6 +228,7 @@ def test_averages_for_monthly_time_series(self): }, ) xr.testing.assert_allclose(result, expected) + assert result.ts.attrs == expected.ts.attrs def test_averages_for_daily_time_series(self): ds = xr.Dataset( @@ -571,27 +574,28 @@ def test_weighted_annual_averages_with_chunking(self): assert result.ts.attrs == expected.ts.attrs assert result.time.attrs == expected.time.attrs - def test_weighted_seasonal_averages_with_DJF_and_drop_incomplete_seasons(self): + def test_weighted_seasonal_averages_with_DJF_without_dropping_incomplete_seasons( + self, + ): ds = self.ds.copy() result = ds.temporal.group_average( "ts", "season", - season_config={"dec_mode": "DJF", "drop_incomplete_djf": True}, + season_config={"dec_mode": "DJF", "drop_incomplete_seasons": False}, ) expected = ds.copy() - # Drop the incomplete DJF seasons - expected = expected.isel(time=slice(2, -1)) expected = expected.drop_dims("time") expected["ts"] = xr.DataArray( name="ts", - data=np.array([[[1]], [[1]], [[1]], [[2.0]]]), + data=np.array([[[2.0]], [[1.0]], [[1.0]], [[1.0]], [[2.0]]]), coords={ "lat": expected.lat, "lon": expected.lon, "time": xr.DataArray( data=np.array( [ + cftime.DatetimeGregorian(2000, 1, 1), cftime.DatetimeGregorian(2000, 4, 1), cftime.DatetimeGregorian(2000, 7, 1), cftime.DatetimeGregorian(2000, 10, 1), @@ -614,35 +618,82 @@ def test_weighted_seasonal_averages_with_DJF_and_drop_incomplete_seasons(self): "mode": "group_average", "freq": "season", "weighted": "True", + "drop_incomplete_seasons": "False", "dec_mode": "DJF", - "drop_incomplete_djf": "True", }, ) xr.testing.assert_identical(result, expected) - def test_weighted_seasonal_averages_with_DJF_without_dropping_incomplete_seasons( - self, - ): + def test_weighted_seasonal_averages_with_DJF_and_drop_incomplete_seasons(self): + ds = generate_dataset(decode_times=True, cf_compliant=True, has_bounds=True) + + result = ds.temporal.group_average( + "ts", + "season", + season_config={"dec_mode": "DJF", "drop_incomplete_seasons": True}, + ) + + expected = ds.copy() + expected = expected.drop_dims("time") + expected["ts"] = xr.DataArray( + name="ts", + data=np.ones((4, 4, 4)), + coords={ + "lat": expected.lat, + "lon": expected.lon, + "time": xr.DataArray( + data=np.array( + [ + cftime.DatetimeGregorian(2000, 4, 1), + cftime.DatetimeGregorian(2000, 7, 1), + cftime.DatetimeGregorian(2000, 10, 1), + cftime.DatetimeGregorian(2001, 1, 1), + ], + ), + dims=["time"], + attrs={ + "axis": "T", + "long_name": "time", + "standard_name": "time", + "bounds": "time_bnds", + }, + ), + }, + dims=["time", "lat", "lon"], + attrs={ + "operation": "temporal_avg", + "mode": "group_average", + "freq": "season", + "weighted": "True", + "drop_incomplete_seasons": "True", + "dec_mode": "DJF", + }, + ) + + xr.testing.assert_identical(result, expected) + + def test_weighted_seasonal_averages_with_DJF_and_drop_incomplete_djf(self): ds = self.ds.copy() result = ds.temporal.group_average( "ts", "season", - season_config={"dec_mode": "DJF", "drop_incomplete_djf": False}, + season_config={"dec_mode": "DJF", 
"drop_incomplete_djf": True}, ) expected = ds.copy() + # Drop the incomplete DJF seasons + expected = expected.isel(time=slice(2, -1)) expected = expected.drop_dims("time") expected["ts"] = xr.DataArray( name="ts", - data=np.array([[[2.0]], [[1.0]], [[1.0]], [[1.0]], [[2.0]]]), + data=np.array([[[1]], [[1]], [[1]], [[2.0]]]), coords={ "lat": expected.lat, "lon": expected.lon, "time": xr.DataArray( data=np.array( [ - cftime.DatetimeGregorian(2000, 1, 1), cftime.DatetimeGregorian(2000, 4, 1), cftime.DatetimeGregorian(2000, 7, 1), cftime.DatetimeGregorian(2000, 10, 1), @@ -666,7 +717,7 @@ def test_weighted_seasonal_averages_with_DJF_without_dropping_incomplete_seasons "freq": "season", "weighted": "True", "dec_mode": "DJF", - "drop_incomplete_djf": "False", + "drop_incomplete_djf": "True", }, ) @@ -725,12 +776,57 @@ def test_weighted_seasonal_averages_with_JFD(self): "mode": "group_average", "freq": "season", "weighted": "True", + "drop_incomplete_seasons": "False", "dec_mode": "JFD", }, ) xr.testing.assert_identical(result, expected) + def test_raises_error_with_incorrect_custom_seasons_argument(self): + # Test raises error with non-3 letter strings + with pytest.raises(ValueError): + custom_seasons = [ + ["J", "Feb", "Mar"], + ["Apr", "May", "Jun"], + ["Jul", "Aug", "Sep"], + ["Oct", "Nov", "Dec"], + ] + self.ds.temporal.group_average( + "ts", + "season", + season_config={"custom_seasons": custom_seasons}, + ) + + # Test raises error if duplicate month(s) were found + with pytest.raises(ValueError): + custom_seasons = [ + ["Jan", "Jan", "Mar"], + ["Apr", "May", "Jun"], + ["Jul", "Aug", "Sep"], + ["Oct", "Nov", "Dec"], + ] + self.ds.temporal.group_average( + "ts", + "season", + season_config={"custom_seasons": custom_seasons}, + ) + + def test_raises_error_with_dataset_that_has_no_complete_seasons(self): + ds = self.ds.copy() + ds = ds.isel(time=slice(0, 1)) + custom_seasons = [["Dec", "Jan"]] + + with pytest.raises(RuntimeError): + ds.temporal.group_average( + "ts", + "season", + season_config={ + "custom_seasons": custom_seasons, + "drop_incomplete_seasons": True, + }, + ) + def test_weighted_custom_seasonal_averages(self): ds = self.ds.copy() @@ -777,60 +873,194 @@ def test_weighted_custom_seasonal_averages(self): "operation": "temporal_avg", "mode": "group_average", "freq": "season", + "weighted": "True", + "drop_incomplete_seasons": "False", "custom_seasons": [ "JanFebMar", "AprMayJun", "JulAugSep", "OctNovDec", ], + }, + ) + + xr.testing.assert_identical(result, expected) + + def test_weighted_seasonal_averages_with_custom_seasons_and_all_complete_seasons( + self, + ): + ds = self.ds.copy() + ds["time"].values[:] = np.array( + [ + "2000-01-16T12:00:00.000000000", + "2000-02-15T12:00:00.000000000", + "2000-03-16T12:00:00.000000000", + "2000-06-16T00:00:00.000000000", + "2000-09-16T00:00:00.000000000", + ], + dtype="datetime64[ns]", + ) + + result = ds.temporal.group_average( + "ts", + "season", + season_config={ + "custom_seasons": [["Jan", "Mar", "Jun"], ["Feb", "Sep"]], + "drop_incomplete_seasons": True, + }, + ) + expected = ds.copy() + expected = expected.drop_dims("time") + expected["ts"] = xr.DataArray( + name="ts", + data=np.array([[[1.34065934]], [[1.47457627]]]), + coords={ + "lat": expected.lat, + "lon": expected.lon, + "time": xr.DataArray( + data=np.array( + [ + cftime.DatetimeGregorian(2000, 3, 1), + cftime.DatetimeGregorian(2000, 9, 1), + ], + ), + dims=["time"], + attrs={ + "axis": "T", + "long_name": "time", + "standard_name": "time", + "bounds": "time_bnds", + }, + ), + 
}, + dims=["time", "lat", "lon"], + attrs={ + "test_attr": "test", + "operation": "temporal_avg", + "mode": "group_average", + "freq": "season", "weighted": "True", + "drop_incomplete_seasons": "True", + "custom_seasons": ["JanMarJun", "FebSep"], + }, + ) + + xr.testing.assert_allclose(result, expected) + assert result.ts.attrs == expected.ts.attrs + + def test_weighted_custom_seasonal_averages_drops_incomplete_seasons(self): + ds = self.ds.copy() + ds["time"].values[:] = np.array( + [ + "2000-11-16T12:00:00.000000000", + "2000-12-16T12:00:00.000000000", + "2001-01-16T00:00:00.000000000", + "2001-02-16T00:00:00.000000000", + "2001-03-16T00:00:00.000000000", + ], + dtype="datetime64[ns]", + ) + + custom_seasons = [["Nov", "Dec"], ["Feb", "Mar", "Apr"]] + + result = ds.temporal.group_average( + "ts", + "season", + season_config={ + "drop_incomplete_seasons": True, + "custom_seasons": custom_seasons, + }, + ) + expected = ds.copy() + expected = expected.drop_dims("time") + expected["ts"] = xr.DataArray( + name="ts", + data=np.array([[[1.5]]]), + coords={ + "lat": expected.lat, + "lon": expected.lon, + "time": xr.DataArray( + data=np.array([cftime.datetime(2000, 12, 1)], dtype=object), + dims=["time"], + attrs={ + "axis": "T", + "long_name": "time", + "standard_name": "time", + "bounds": "time_bnds", + }, + ), + }, + dims=["time", "lat", "lon"], + attrs={ + "test_attr": "test", + "operation": "temporal_avg", + "mode": "group_average", + "freq": "season", + "weighted": "True", + "drop_incomplete_seasons": "True", + "custom_seasons": ["NovDec", "FebMarApr"], }, ) xr.testing.assert_identical(result, expected) - def test_raises_error_with_incorrect_custom_seasons_argument(self): - # Test raises error with non-3 letter strings - with pytest.raises(ValueError): - custom_seasons = [ - ["J", "Feb", "Mar"], - ["Apr", "May", "Jun"], - ["Jul", "Aug", "Sep"], - ["Oct", "Nov", "Dec"], - ] - self.ds.temporal.group_average( - "ts", - "season", - season_config={"custom_seasons": custom_seasons}, - ) + def test_weighted_custom_seasonal_averages_with_seasons_spanning_calendar_years( + self, + ): + ds = self.ds.copy() + ds["time"].values[:] = np.array( + [ + "2000-11-16T12:00:00.000000000", + "2000-12-16T12:00:00.000000000", + "2001-01-16T00:00:00.000000000", + "2001-02-16T00:00:00.000000000", + "2001-03-16T00:00:00.000000000", + ], + dtype="datetime64[ns]", + ) - # Test raises error with missing month(s) - with pytest.raises(ValueError): - custom_seasons = [ - ["Feb", "Mar"], - ["Apr", "May", "Jun"], - ["Jul", "Aug", "Sep"], - ["Oct", "Nov", "Dec"], - ] - self.ds.temporal.group_average( - "ts", - "season", - season_config={"custom_seasons": custom_seasons}, - ) + custom_seasons = [ + ["Nov", "Dec", "Jan", "Feb", "Mar"], + ] - # Test raises error if duplicate month(s) were found - with pytest.raises(ValueError): - custom_seasons = [ - ["Jan", "Jan", "Mar"], - ["Apr", "May", "Jun"], - ["Jul", "Aug", "Sep"], - ["Oct", "Nov", "Dec"], - ] - self.ds.temporal.group_average( - "ts", - "season", - season_config={"custom_seasons": custom_seasons}, - ) + result = ds.temporal.group_average( + "ts", + "season", + season_config={"custom_seasons": custom_seasons}, + ) + expected = ds.copy() + expected = expected.drop_dims("time") + expected["ts"] = xr.DataArray( + name="ts", + data=np.array([[[1.3933333]]]), + coords={ + "lat": expected.lat, + "lon": expected.lon, + "time": xr.DataArray( + data=np.array([cftime.datetime(2001, 1, 1)], dtype=object), + dims=["time"], + attrs={ + "axis": "T", + "long_name": "time", + 
"standard_name": "time", + "bounds": "time_bnds", + }, + ), + }, + dims=["time", "lat", "lon"], + attrs={ + "test_attr": "test", + "operation": "temporal_avg", + "mode": "group_average", + "freq": "season", + "weighted": "True", + "drop_incomplete_seasons": "False", + "custom_seasons": ["NovDecJanFebMar"], + }, + ) + + xr.testing.assert_allclose(result, expected) + assert result.ts.attrs == expected.ts.attrs def test_weighted_monthly_averages(self): ds = self.ds.copy() @@ -1051,7 +1281,7 @@ def test_raises_error_if_reference_period_arg_is_incorrect(self): ds.temporal.climatology( "ts", "season", - season_config={"dec_mode": "DJF", "drop_incomplete_djf": True}, + season_config={"dec_mode": "DJF", "drop_incomplete_seasons": True}, reference_period=("01-01-2000", "01-01-2000"), ) @@ -1059,7 +1289,7 @@ def test_raises_error_if_reference_period_arg_is_incorrect(self): ds.temporal.climatology( "ts", "season", - season_config={"dec_mode": "DJF", "drop_incomplete_djf": True}, + season_config={"dec_mode": "DJF", "drop_incomplete_seasons": True}, reference_period=("01-01-2000"), ) @@ -1069,7 +1299,7 @@ def test_subsets_climatology_based_on_reference_period(self): result = ds.temporal.climatology( "ts", "season", - season_config={"dec_mode": "DJF", "drop_incomplete_djf": True}, + season_config={"dec_mode": "DJF", "drop_incomplete_seasons": True}, reference_period=("2000-01-01", "2000-06-01"), ) @@ -1101,7 +1331,7 @@ def test_subsets_climatology_based_on_reference_period(self): "freq": "season", "weighted": "True", "dec_mode": "DJF", - "drop_incomplete_djf": "True", + "drop_incomplete_seasons": "True", }, ) @@ -1113,7 +1343,7 @@ def test_weighted_seasonal_climatology_with_DJF(self): result = ds.temporal.climatology( "ts", "season", - season_config={"dec_mode": "DJF", "drop_incomplete_djf": True}, + season_config={"dec_mode": "DJF", "drop_incomplete_seasons": True}, ) expected = ds.copy() @@ -1154,8 +1384,72 @@ def test_weighted_seasonal_climatology_with_DJF(self): "mode": "climatology", "freq": "season", "weighted": "True", + "drop_incomplete_seasons": "True", "dec_mode": "DJF", + }, + ) + + xr.testing.assert_identical(result, expected) + + def test_raises_deprecation_warning_with_drop_incomplete_djf_season_config(self): + # NOTE: This will test will also cover the other public APIs that + # have drop_incomplete_djf as a season_config arg. + ds = self.ds.copy() + + with warnings.catch_warnings(record=True) as w: + result = ds.temporal.climatology( + "ts", + "season", + season_config={"dec_mode": "DJF", "drop_incomplete_djf": True}, + ) + + assert len(w) == 1 + assert issubclass(w[0].category, DeprecationWarning) + assert str(w[0].message) == ( + "The `season_config` argument 'drop_incomplete_djf' is being deprecated. " + "Please use 'drop_incomplete_seasons' instead." 
+ ) + + expected = ds.copy() + expected = expected.drop_dims("time") + expected_time = xr.DataArray( + data=np.array( + [ + cftime.DatetimeGregorian(1, 1, 1), + cftime.DatetimeGregorian(1, 4, 1), + cftime.DatetimeGregorian(1, 7, 1), + cftime.DatetimeGregorian(1, 10, 1), + ], + ), + coords={ + "time": np.array( + [ + cftime.DatetimeGregorian(1, 1, 1), + cftime.DatetimeGregorian(1, 4, 1), + cftime.DatetimeGregorian(1, 7, 1), + cftime.DatetimeGregorian(1, 10, 1), + ], + ), + }, + attrs={ + "axis": "T", + "long_name": "time", + "standard_name": "time", + "bounds": "time_bnds", + }, + ) + expected["ts"] = xr.DataArray( + name="ts", + data=np.ones((4, 4, 4)), + coords={"lat": expected.lat, "lon": expected.lon, "time": expected_time}, + dims=["time", "lat", "lon"], + attrs={ + "operation": "temporal_avg", + "mode": "climatology", + "freq": "season", + "weighted": "True", "drop_incomplete_djf": "True", + "dec_mode": "DJF", }, ) @@ -1168,7 +1462,7 @@ def test_chunked_weighted_seasonal_climatology_with_DJF(self): result = ds.temporal.climatology( "ts", "season", - season_config={"dec_mode": "DJF", "drop_incomplete_djf": True}, + season_config={"dec_mode": "DJF", "drop_incomplete_seasons": True}, ) expected = ds.copy() @@ -1210,7 +1504,7 @@ def test_chunked_weighted_seasonal_climatology_with_DJF(self): "freq": "season", "weighted": "True", "dec_mode": "DJF", - "drop_incomplete_djf": "True", + "drop_incomplete_seasons": "True", }, ) @@ -1261,6 +1555,7 @@ def test_weighted_seasonal_climatology_with_JFD(self): "mode": "climatology", "freq": "season", "weighted": "True", + "drop_incomplete_seasons": "False", "dec_mode": "JFD", }, ) @@ -1319,6 +1614,7 @@ def test_weighted_custom_seasonal_climatology(self): "mode": "climatology", "freq": "season", "weighted": "True", + "drop_incomplete_seasons": "False", "custom_seasons": [ "JanFebMar", "AprMayJun", @@ -1330,6 +1626,57 @@ def test_weighted_custom_seasonal_climatology(self): xr.testing.assert_identical(result, expected) + def test_weighted_custom_seasonal_climatology_with_seasons_spanning_calendar_years( + self, + ): + ds = self.ds.copy() + + custom_seasons = [["Nov", "Dec", "Jan", "Feb", "Mar"]] + result = ds.temporal.climatology( + "ts", + "season", + season_config={ + "drop_incomplete_seasons": False, + "custom_seasons": custom_seasons, + }, + ) + + expected = ds.copy() + expected = expected.drop_dims("time") + expected_time = xr.DataArray( + data=np.array( + [cftime.DatetimeGregorian(1, 1, 1)], + ), + coords={ + "time": np.array( + [cftime.DatetimeGregorian(1, 1, 1)], + ), + }, + attrs={ + "axis": "T", + "long_name": "time", + "standard_name": "time", + "bounds": "time_bnds", + }, + ) + + expected["ts"] = xr.DataArray( + name="ts", + data=np.ones((1, 4, 4)), + coords={"lat": expected.lat, "lon": expected.lon, "time": expected_time}, + dims=["time", "lat", "lon"], + attrs={ + "operation": "temporal_avg", + "mode": "climatology", + "freq": "season", + "weighted": "True", + "drop_incomplete_seasons": "False", + "custom_seasons": ["NovDecJanFebMar"], + }, + ) + + xr.testing.assert_identical(result, expected) + def test_weighted_monthly_climatology(self): result = self.ds.temporal.climatology("ts", "month") @@ -1749,7 +2096,7 @@ def test_raises_error_if_reference_period_arg_is_incorrect(self): ds.temporal.departures( "ts", "season", - season_config={"dec_mode": "DJF", "drop_incomplete_djf": True}, + season_config={"dec_mode": "DJF", "drop_incomplete_seasons": True}, reference_period=("01-01-2000", "01-01-2000"), ) @@ -1757,7 +2104,7 @@ def 
test_raises_error_if_reference_period_arg_is_incorrect(self): ds.temporal.departures( "ts", "season", - season_config={"dec_mode": "DJF", "drop_incomplete_djf": True}, + season_config={"dec_mode": "DJF", "drop_incomplete_seasons": True}, reference_period=("01-01-2000"), ) @@ -1768,7 +2115,7 @@ def test_seasonal_departures_relative_to_climatology_reference_period(self): "ts", "season", weighted=True, - season_config={"dec_mode": "DJF", "drop_incomplete_djf": True}, + season_config={"dec_mode": "DJF", "drop_incomplete_seasons": False}, reference_period=("2000-01-01", "2000-06-01"), ) @@ -1776,13 +2123,14 @@ def test_seasonal_departures_relative_to_climatology_reference_period(self): expected = expected.drop_dims("time") expected["ts"] = xr.DataArray( name="ts", - data=np.array([[[0.0]], [[np.nan]], [[np.nan]], [[np.nan]]]), + data=np.array([[[0.0]], [[0.0]], [[np.nan]], [[np.nan]], [[0.0]]]), coords={ "lat": expected.lat, "lon": expected.lon, "time": xr.DataArray( data=np.array( [ + cftime.DatetimeGregorian(2000, 1, 1), cftime.DatetimeGregorian(2000, 4, 1), cftime.DatetimeGregorian(2000, 7, 1), cftime.DatetimeGregorian(2000, 10, 1), @@ -1806,7 +2154,7 @@ def test_seasonal_departures_relative_to_climatology_reference_period(self): "freq": "season", "weighted": "True", "dec_mode": "DJF", - "drop_incomplete_djf": "True", + "drop_incomplete_seasons": "False", }, ) @@ -1821,7 +2169,7 @@ def test_monthly_departures_relative_to_climatology_reference_period_with_same_o "ts", "month", weighted=True, - season_config={"dec_mode": "DJF", "drop_incomplete_djf": True}, + season_config={"dec_mode": "DJF", "drop_incomplete_seasons": True}, reference_period=("2000-01-01", "2000-06-01"), ) @@ -1904,20 +2252,21 @@ def test_weighted_seasonal_departures_with_DJF(self): "ts", "season", weighted=True, - season_config={"dec_mode": "DJF", "drop_incomplete_djf": True}, + season_config={"dec_mode": "DJF", "drop_incomplete_seasons": False}, ) expected = ds.copy() expected = expected.drop_dims("time") expected["ts"] = xr.DataArray( name="ts", - data=np.array([[[0.0]], [[0.0]], [[0.0]], [[0.0]]]), + data=np.array([[[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]]]), coords={ "lat": expected.lat, "lon": expected.lon, "time": xr.DataArray( data=np.array( [ + cftime.DatetimeGregorian(2000, 1, 1), cftime.DatetimeGregorian(2000, 4, 1), cftime.DatetimeGregorian(2000, 7, 1), cftime.DatetimeGregorian(2000, 10, 1), @@ -1941,7 +2290,7 @@ def test_weighted_seasonal_departures_with_DJF(self): "freq": "season", "weighted": "True", "dec_mode": "DJF", - "drop_incomplete_djf": "True", + "drop_incomplete_seasons": "False", }, ) @@ -1955,20 +2304,21 @@ def test_weighted_seasonal_departures_with_DJF_and_keep_weights(self): "season", weighted=True, keep_weights=True, - season_config={"dec_mode": "DJF", "drop_incomplete_djf": True}, + season_config={"dec_mode": "DJF", "drop_incomplete_seasons": False}, ) expected = ds.copy() expected = expected.drop_dims("time") expected["ts"] = xr.DataArray( name="ts", - data=np.array([[[0.0]], [[0.0]], [[0.0]], [[0.0]]]), + data=np.array([[[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]]]), coords={ "lat": expected.lat, "lon": expected.lon, "time": xr.DataArray( data=np.array( [ + cftime.DatetimeGregorian(2000, 1, 1), cftime.DatetimeGregorian(2000, 4, 1), cftime.DatetimeGregorian(2000, 7, 1), cftime.DatetimeGregorian(2000, 10, 1), @@ -1992,16 +2342,17 @@ def test_weighted_seasonal_departures_with_DJF_and_keep_weights(self): "freq": "season", "weighted": "True", "dec_mode": "DJF", - "drop_incomplete_djf": "True", + 
"drop_incomplete_seasons": "False", }, ) expected["time_wts"] = xr.DataArray( name="ts", - data=np.array([1.0, 1.0, 1.0, 1.0]), + data=np.array([0.52542373, 1.0, 1.0, 1.0, 0.47457627]), coords={ "time_original": xr.DataArray( data=np.array( [ + "2000-01-16T12:00:00.000000000", "2000-03-16T12:00:00.000000000", "2000-06-16T00:00:00.000000000", "2000-09-16T00:00:00.000000000", @@ -2021,7 +2372,8 @@ def test_weighted_seasonal_departures_with_DJF_and_keep_weights(self): dims=["time_original"], ) - xr.testing.assert_identical(result, expected) + xr.testing.assert_allclose(result, expected) + assert result.ts.attrs == expected.ts.attrs def test_unweighted_seasonal_departures_with_DJF(self): ds = self.ds.copy() @@ -2030,20 +2382,21 @@ def test_unweighted_seasonal_departures_with_DJF(self): "ts", "season", weighted=False, - season_config={"dec_mode": "DJF", "drop_incomplete_djf": True}, + season_config={"dec_mode": "DJF", "drop_incomplete_seasons": False}, ) expected = ds.copy() expected = expected.drop_dims("time") expected["ts"] = xr.DataArray( name="ts", - data=np.array([[[0.0]], [[0.0]], [[0.0]], [[0.0]]]), + data=np.array([[[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]]]), coords={ "lat": expected.lat, "lon": expected.lon, "time": xr.DataArray( data=np.array( [ + cftime.DatetimeGregorian(2000, 1, 1), cftime.DatetimeGregorian(2000, 4, 1), cftime.DatetimeGregorian(2000, 7, 1), cftime.DatetimeGregorian(2000, 10, 1), @@ -2066,8 +2419,8 @@ def test_unweighted_seasonal_departures_with_DJF(self): "mode": "departures", "freq": "season", "weighted": "False", + "drop_incomplete_seasons": "False", "dec_mode": "DJF", - "drop_incomplete_djf": "True", }, ) @@ -2117,6 +2470,7 @@ def test_unweighted_seasonal_departures_with_JFD(self): "mode": "departures", "freq": "season", "weighted": "False", + "drop_incomplete_seasons": "False", "dec_mode": "JFD", }, ) @@ -3290,7 +3644,7 @@ def test_raises_error_with_incorrect_mode_arg(self): weighted=True, season_config={ "dec_mode": "DJF", - "drop_incomplete_djf": False, + "drop_incomplete_seasons": False, "custom_seasons": None, }, ) @@ -3306,7 +3660,7 @@ def test_raises_error_if_freq_arg_is_not_supported_by_operation(self): weighted=True, season_config={ "dec_mode": "DJF", - "drop_incomplete_djf": False, + "drop_incomplete_seasons": False, "custom_seasons": None, }, ) @@ -3318,7 +3672,7 @@ def test_raises_error_if_freq_arg_is_not_supported_by_operation(self): weighted=True, season_config={ "dec_mode": "DJF", - "drop_incomplete_djf": False, + "drop_incomplete_seasons": False, "custom_seasons": None, }, ) @@ -3330,7 +3684,7 @@ def test_raises_error_if_freq_arg_is_not_supported_by_operation(self): weighted=True, season_config={ "dec_mode": "DJF", - "drop_incomplete_djf": False, + "drop_incomplete_seasons": False, "custom_seasons": None, }, ) @@ -3356,7 +3710,7 @@ def test_raises_error_if_december_mode_is_not_supported(self): weighted=True, season_config={ "dec_mode": "unsupported", - "drop_incomplete_djf": False, + "drop_incomplete_seasons": False, "custom_seasons": None, }, ) diff --git a/xcdat/temporal.py b/xcdat/temporal.py index ce44bbfd..b4d446e4 100644 --- a/xcdat/temporal.py +++ b/xcdat/temporal.py @@ -1,5 +1,6 @@ """Module containing temporal functions.""" +import warnings from datetime import datetime from itertools import chain from typing import Dict, List, Literal, Optional, Tuple, TypedDict, Union, get_args @@ -13,6 +14,7 @@ from xarray.coding.cftime_offsets import get_date_type from xarray.core.common import contains_cftime_datetimes, is_np_datetime_like from 
xarray.core.groupby import DataArrayGroupBy +from xarray.groupers import SeasonGrouper, UniqueGrouper from xcdat import bounds # noqa: F401 from xcdat._logger import _setup_custom_logger @@ -66,8 +68,10 @@ SeasonConfigInput = TypedDict( "SeasonConfigInput", { - "dec_mode": Literal["DJF", "JFD"], + # TODO: Deprecate incomplete_djf. "drop_incomplete_djf": bool, + "drop_incomplete_seasons": bool, + "dec_mode": Literal["DJF", "JFD"], "custom_seasons": Optional[List[List[str]]], }, total=False, @@ -76,16 +80,20 @@ SeasonConfigAttr = TypedDict( "SeasonConfigAttr", { - "dec_mode": Literal["DJF", "JFD"], + # TODO: Deprecate incomplete_djf. "drop_incomplete_djf": bool, + "drop_incomplete_seasons": bool, + "dec_mode": Literal["DJF", "JFD"], "custom_seasons": Optional[Dict[str, List[str]]], }, total=False, ) DEFAULT_SEASON_CONFIG: SeasonConfigInput = { - "dec_mode": "DJF", + # TODO: Deprecate incomplete_djf. "drop_incomplete_djf": False, + "drop_incomplete_seasons": False, + "dec_mode": "DJF", "custom_seasons": None, } @@ -104,6 +112,7 @@ 11: "Nov", 12: "Dec", } +MONTH_STR_TO_INT = {v: k for k, v in MONTH_INT_TO_STR.items()} # A dictionary mapping pre-defined seasons to their middle month. This # dictionary is used during the creation of datetime objects, which don't @@ -248,6 +257,11 @@ def group_average( Time bounds are used for generating weights to calculate weighted group averages (refer to the ``weighted`` parameter documentation below). + .. deprecated:: v0.8.0 + The ``season_config`` dictionary argument ``"drop_incomplete_djf"`` + is being deprecated. Please use ``"drop_incomplete_seasons"`` + instead. + Parameters ---------- data_var: str @@ -280,35 +294,48 @@ def group_average( keep_weights : bool, optional If calculating averages using weights, keep the weights in the final dataset output, by default False. - season_config: SeasonConfigInput, optional + season_config : SeasonConfigInput, optional A dictionary for "season" frequency configurations. If configs for predefined seasons are passed, configs for custom seasons are ignored and vice versa. - Configs for predefined seasons: + * "drop_incomplete_seasons" (bool, by default False) + Seasons are considered incomplete if they do not have all of + the required months to form the season. This argument supersedes + "drop_incomplete_djf". For example, if we have + the time coordinates ["2000-11-16", "2000-12-16", "2001-01-16", + "2001-02-16"] and we want to group seasons by "ND" ("Nov", + "Dec") and "JFM" ("Jan", "Feb", "Mar"). - * "dec_mode" (Literal["DJF", "JFD"], by default "DJF") - The mode for the season that includes December. - - * "DJF": season includes the previous year December. - * "JFD": season includes the same year December. - Xarray labels the season with December as "DJF", but it is - actually "JFD". + * ["2000-11-16", "2000-12-16"] is considered a complete "ND" + season since both "Nov" and "Dec" are present. + * ["2001-01-16", "2001-02-16"] is considered an incomplete "JFM" + season because it only has "Jan" and "Feb". Therefore, these + time coordinates are dropped. * "drop_incomplete_djf" (bool, by default False) If the "dec_mode" is "DJF", this flag drops (True) or keeps (False) time coordinates that fall under incomplete DJF seasons Incomplete DJF seasons include the start year Jan/Feb and the - end year Dec. + end year Dec. This argument is superceded by + "drop_incomplete_seasons" and will be deprecated in a future + release. 
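# Illustrative sketch of the "drop_incomplete_seasons" rule described in the
# docstring above, restated with plain pandas rather than xcdat internals.
# The "ND"/"JFM" season map below is the docstring's own example, not an
# xcdat constant.
import pandas as pd

seasons = {"ND": [11, 12], "JFM": [1, 2, 3]}
times = pd.to_datetime(["2000-11-16", "2000-12-16", "2001-01-16", "2001-02-16"])

df = pd.DataFrame({"year": times.year, "month": times.month})
df["season"] = df["month"].map({m: s for s, ms in seasons.items() for m in ms})

# A (year, season) group is kept only if every required month is present.
actual = df.groupby(["year", "season"])["month"].transform("size")
expected = df["season"].map(lambda s: len(seasons[s]))
print(df[actual == expected])  # "ND" 2000 is complete; the partial "JFM" 2001 rows drop out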
- Configs for custom seasons: + * "dec_mode" (Literal["DJF", "JFD"], by default "DJF") + The mode for the season that includes December in the list of + list of pre-defined seasons ("DJF"/"JFD", "MAM", "JJA", "SON"). + This config is ignored if the ``custom_seasons`` config is set. + + * "DJF": season includes the previous year December. + * "JFD": season includes the same year December. + Xarray labels the season with December as "DJF", but it is + actually "JFD". * "custom_seasons" ([List[List[str]]], by default None) List of sublists containing month strings, with each sublist representing a custom season. * Month strings must be in the three letter format (e.g., 'Jan') - * Each month must be included once in a custom season * Order of the months in each custom season does not matter * Custom seasons can vary in length @@ -349,7 +376,7 @@ def group_average( >>> "season", >>> season_config={ >>> "dec_mode": "DJF", - >>> "drop_incomplete_season": True + >>> "drop_incomplete_seasons": True >>> } >>> ) >>> ds_season.ts @@ -385,7 +412,7 @@ def group_average( 'freq': 'season', 'weighted': 'True', 'dec_mode': 'DJF', - 'drop_incomplete_djf': 'False' + 'drop_incomplete_seasons': 'False' } """ self._set_data_var_attrs(data_var) @@ -414,6 +441,11 @@ def climatology( Time bounds are used for generating weights to calculate weighted climatology (refer to the ``weighted`` parameter documentation below). + .. deprecated:: v0.8.0 + The ``season_config`` dictionary argument ``"drop_incomplete_djf"`` + is being deprecated. Please use ``"drop_incomplete_seasons"`` + instead. + Parameters ---------- data_var: str @@ -455,35 +487,48 @@ def climatology( 'yyyy-mm-dd'. For example, ``('1850-01-01', '1899-12-31')``. If no value is provided, the climatological reference period will be the full period covered by the dataset. - season_config: SeasonConfigInput, optional + season_config : SeasonConfigInput, optional A dictionary for "season" frequency configurations. If configs for predefined seasons are passed, configs for custom seasons are ignored and vice versa. - Configs for predefined seasons: + * "drop_incomplete_seasons" (bool, by default False) + Seasons are considered incomplete if they do not have all of + the required months to form the season. This argument supersedes + "drop_incomplete_djf". For example, if we have + the time coordinates ["2000-11-16", "2000-12-16", "2001-01-16", + "2001-02-16"] and we want to group seasons by "ND" ("Nov", + "Dec") and "JFM" ("Jan", "Feb", "Mar"). - * "dec_mode" (Literal["DJF", "JFD"], by default "DJF") - The mode for the season that includes December. - - * "DJF": season includes the previous year December. - * "JFD": season includes the same year December. - Xarray labels the season with December as "DJF", but it is - actually "JFD". + * ["2000-11-16", "2000-12-16"] is considered a complete "ND" + season since both "Nov" and "Dec" are present. + * ["2001-01-16", "2001-02-16"] is considered an incomplete "JFM" + season because it only has "Jan" and "Feb". Therefore, these + time coordinates are dropped. * "drop_incomplete_djf" (bool, by default False) If the "dec_mode" is "DJF", this flag drops (True) or keeps (False) time coordinates that fall under incomplete DJF seasons Incomplete DJF seasons include the start year Jan/Feb and the - end year Dec. + end year Dec. This argument is superceded by + "drop_incomplete_seasons" and will be deprecated in a future + release. 
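# Illustrative sketch of the weighted aggregation these APIs perform
# (WA = sum(x * w) / sum(w), as noted later in this module), with weights
# proportional to time-bounds lengths. The day counts and values below are
# made up for illustration only.
import numpy as np
import xarray as xr

days = np.array([31.0, 28.0, 31.0])            # e.g. Jan, Feb, Mar interval lengths
ts = xr.DataArray([1.0, 2.0, 4.0], dims="time")
w = xr.DataArray(days / days.sum(), dims="time")

print(float((ts * w).sum() / w.sum()))  # weighted mean, ~2.34
print(float(ts.mean()))                 # unweighted mean, ~2.33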
- Configs for custom seasons: + * "dec_mode" (Literal["DJF", "JFD"], by default "DJF") + The mode for the season that includes December in the list of + list of pre-defined seasons ("DJF"/"JFD", "MAM", "JJA", "SON"). + This config is ignored if the ``custom_seasons`` config is set. + + * "DJF": season includes the previous year December. + * "JFD": season includes the same year December. + Xarray labels the season with December as "DJF", but it is + actually "JFD". * "custom_seasons" ([List[List[str]]], by default None) List of sublists containing month strings, with each sublist representing a custom season. * Month strings must be in the three letter format (e.g., 'Jan') - * Each month must be included once in a custom season * Order of the months in each custom season does not matter * Custom seasons can vary in length @@ -528,7 +573,7 @@ def climatology( >>> "season", >>> season_config={ >>> "dec_mode": "DJF", - >>> "drop_incomplete_season": True + >>> "drop_incomplete_seasons": True >>> } >>> ) >>> ds_season.ts @@ -564,7 +609,7 @@ def climatology( 'freq': 'season', 'weighted': 'True', 'dec_mode': 'DJF', - 'drop_incomplete_djf': 'False' + 'drop_incomplete_seasons': 'False' } """ self._set_data_var_attrs(data_var) @@ -600,6 +645,11 @@ def departures( Time bounds are used for generating weights to calculate weighted climatology (refer to the ``weighted`` parameter documentation below). + .. deprecated:: v0.8.0 + The ``season_config`` dictionary argument ``"drop_incomplete_djf"`` + is being deprecated. Please use ``"drop_incomplete_seasons"`` + instead. + Parameters ---------- data_var: str @@ -642,11 +692,35 @@ def departures( ``('1850-01-01', '1899-12-31')``. If no value is provided, the climatological reference period will be the full period covered by the dataset. - season_config: SeasonConfigInput, optional + season_config : SeasonConfigInput, optional A dictionary for "season" frequency configurations. If configs for predefined seasons are passed, configs for custom seasons are ignored and vice versa. + General configs: + + * "drop_incomplete_seasons" (bool, by default False) + Seasons are considered incomplete if they do not have all of + the required months to form the season. This argument supersedes + "drop_incomplete_djf". For example, if we have + the time coordinates ["2000-11-16", "2000-12-16", "2001-01-16", + "2001-02-16"] and we want to group seasons by "ND" ("Nov", + "Dec") and "JFM" ("Jan", "Feb", "Mar"). + + * ["2000-11-16", "2000-12-16"] is considered a complete "ND" + season since both "Nov" and "Dec" are present. + * ["2001-01-16", "2001-02-16"] is considered an incomplete "JFM" + season because it only has "Jan" and "Feb". Therefore, these + time coordinates are dropped. + + * "drop_incomplete_djf" (bool, by default False) + If the "dec_mode" is "DJF", this flag drops (True) or keeps + (False) time coordinates that fall under incomplete DJF seasons + Incomplete DJF seasons include the start year Jan/Feb and the + end year Dec. This argument is superceded by + "drop_incomplete_seasons" and will be deprecated in a future + release. + Configs for predefined seasons: * "dec_mode" (Literal["DJF", "JFD"], by default "DJF") @@ -657,12 +731,6 @@ def departures( Xarray labels the season with December as "DJF", but it is actually "JFD". 
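# Illustrative sketch of the December-shifting idea behind "dec_mode": "DJF"
# (see _shift_djf_decembers further down in this diff): pushing each December
# one year forward makes a (year, season) grouping pair it with the following
# January/February, i.e. true "DJF" rather than xarray's native "JFD". Plain
# pandas is used here purely for illustration.
import pandas as pd

times = pd.Series(pd.to_datetime(["2000-01-16", "2000-02-15", "2000-12-16", "2001-01-16"]))
year = times.dt.year.where(times.dt.month != 12, times.dt.year + 1)
print([(int(y), int(m)) for y, m in zip(year, times.dt.month)])
# [(2000, 1), (2000, 2), (2001, 12), (2001, 1)] -> Dec 2000 now groups with Jan/Feb 2001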
- * "drop_incomplete_djf" (bool, by default False) - If the "dec_mode" is "DJF", this flag drops (True) or keeps - (False) time coordinates that fall under incomplete DJF seasons - Incomplete DJF seasons include the start year Jan/Feb and the - end year Dec. - Configs for custom seasons: * "custom_seasons" ([List[List[str]]], by default None) @@ -670,7 +738,6 @@ def departures( representing a custom season. * Month strings must be in the three letter format (e.g., 'Jan') - * Each month must be included once in a custom season * Order of the months in each custom season does not matter * Custom seasons can vary in length @@ -730,7 +797,7 @@ def departures( 'frequency': 'season', 'weighted': 'True', 'dec_mode': 'DJF', - 'drop_incomplete_djf': 'False' + 'drop_incomplete_seasons': 'False' } """ # 1. Set the attributes for this instance of `TemporalAccessor`. @@ -931,30 +998,46 @@ def _set_arg_attrs( self._is_valid_reference_period(reference_period) self._reference_period = reference_period - # "season" frequency specific configuration attributes. + self._set_season_config_attr(season_config) + + def _set_season_config_attr(self, season_config: SeasonConfigInput): for key in season_config.keys(): - if key not in DEFAULT_SEASON_CONFIG.keys(): + if key not in DEFAULT_SEASON_CONFIG: raise KeyError( f"'{key}' is not a supported season config. Supported " f"configs include: {DEFAULT_SEASON_CONFIG.keys()}." ) - custom_seasons = season_config.get("custom_seasons", None) - dec_mode = season_config.get("dec_mode", "DJF") - drop_incomplete_djf = season_config.get("drop_incomplete_djf", False) self._season_config: SeasonConfigAttr = {} - if custom_seasons is None: + self._season_config["drop_incomplete_seasons"] = season_config.get( + "drop_incomplete_seasons", False + ) + + custom_seasons = season_config.get("custom_seasons", None) + if custom_seasons is not None: + self._season_config["custom_seasons"] = self._form_seasons(custom_seasons) + else: + dec_mode = season_config.get("dec_mode", "DJF") if dec_mode not in ("DJF", "JFD"): raise ValueError( "Incorrect 'dec_mode' key value for `season_config`. " "Supported modes include 'DJF' or 'JFD'." ) + self._season_config["dec_mode"] = dec_mode + # TODO: Deprecate incomplete_djf. + drop_incomplete_djf = season_config.get("drop_incomplete_djf", False) if dec_mode == "DJF": + if drop_incomplete_djf is not False: + warnings.warn( + "The `season_config` argument 'drop_incomplete_djf' is being " + "deprecated. Please use 'drop_incomplete_seasons' instead.", + DeprecationWarning, + stacklevel=2, + ) + self._season_config["drop_incomplete_djf"] = drop_incomplete_djf - else: - self._season_config["custom_seasons"] = self._form_seasons(custom_seasons) def _is_valid_reference_period(self, reference_period: Tuple[str, str]): try: @@ -997,10 +1080,6 @@ def _form_seasons(self, custom_seasons: List[List[str]]) -> Dict[str, List[str]] predefined_months = list(MONTH_INT_TO_STR.values()) input_months = list(chain.from_iterable(custom_seasons)) - if len(input_months) != len(predefined_months): - raise ValueError( - "Exactly 12 months were not passed in the list of custom seasons." - ) if len(input_months) != len(set(input_months)): raise ValueError( "Duplicate month(s) were found in the list of custom seasons." @@ -1013,16 +1092,22 @@ def _form_seasons(self, custom_seasons: List[List[str]]) -> Dict[str, List[str]] f"Supported months include: {predefined_months}." 
) - c_seasons = {"".join(months): months for months in custom_seasons} + c_seasons = {} + for season in custom_seasons: + key = "".join([month[0] for month in season]) + c_seasons[key] = season return c_seasons def _preprocess_dataset(self, ds: xr.Dataset) -> xr.Dataset: """Preprocess the dataset based on averaging settings. - Preprocessing operations include: - - Drop incomplete DJF seasons (leading/trailing) - - Drop leap days + Operations include: + 1. Drop leap days for daily climatologies. + 2. Subset the dataset based on the reference period. + 3. Shift years for custom seasons spanning the calendar year. + 4. Shift Decembers for "DJF" mode and drop incomplete "DJF" seasons. + 5. Drop incomplete seasons if specified. Parameters ---------- @@ -1033,13 +1118,6 @@ def _preprocess_dataset(self, ds: xr.Dataset) -> xr.Dataset: ------- xr.Dataset """ - if ( - self._freq == "season" - and self._season_config.get("dec_mode") == "DJF" - and self._season_config.get("drop_incomplete_djf") is True - ): - ds = self._drop_incomplete_djf(ds) - if ( self._freq == "day" and self._mode in ["climatology", "departures"] @@ -1052,8 +1130,187 @@ def _preprocess_dataset(self, ds: xr.Dataset) -> xr.Dataset: {self.dim: slice(self._reference_period[0], self._reference_period[1])} ) + if ( + self._freq == "season" + and self._season_config.get("custom_seasons") is not None + ): + months = self._season_config["custom_seasons"].values() # type: ignore + months = list(chain.from_iterable(months)) + + if len(months) != 12: + ds = self._subset_coords_for_custom_seasons(ds, months) + + # FIXME: This causes a bug when accessing `.groups` with + # SeasonGrouper(). Also shifting custom seasons is done for + # drop_incomplete_seasons and grouping for months that span the + # calendar year. The Xarray PR will handle both of these cases + # and this method will be removed. + # ds = self._shift_custom_season_years(ds) + pass + + if self._freq == "season" and self._season_config.get("dec_mode") == "DJF": + ds = self._shift_djf_decembers(ds) + + # TODO: Deprecate incomplete_djf. + if ( + self._season_config.get("drop_incomplete_djf") is True + and self._season_config.get("drop_incomplete_seasons") is False + ): + ds = self._drop_incomplete_djf(ds) + + # if ( + # self._freq == "season" + # and self._season_config["drop_incomplete_seasons"] is True + # ): + # ds = self._drop_incomplete_seasons(ds) + return ds + def _subset_coords_for_custom_seasons( + self, ds: xr.Dataset, months: List[str] + ) -> xr.Dataset: + """Subsets time coordinates to the months included in custom seasons. + + Parameters + ---------- + ds : xr.Dataset + The dataset. + months : List[str] + A list of months included in custom seasons. + Example: ["Nov", "Dec", "Jan"] + + Returns + ------- + xr.Dataset + The dataset with time coordinate subsetted to months used in + custom seasons. + """ + month_ints = sorted([MONTH_STR_TO_INT[month] for month in months]) + + coords_by_month = ds[self.dim].groupby(f"{self.dim}.month").groups + month_to_time_idx = { + k: coords_by_month[k] for k in month_ints if k in coords_by_month + } + month_to_time_idx = sorted( + list(chain.from_iterable(month_to_time_idx.values())) # type: ignore + ) + ds_new = ds.isel({f"{self.dim}": month_to_time_idx}) + + return ds_new + + def _shift_custom_season_years(self, ds: xr.Dataset) -> xr.Dataset: + """Shifts the year for custom seasons spanning the calendar year. + + A season spans the calendar year if it includes "Jan" and "Jan" is not + the first month. 
For example, for + ``custom_seasons = ["Nov", "Dec", "Jan", "Feb", "Mar"]``: + - ["Nov", "Dec"] are from the previous year. + - ["Jan", "Feb", "Mar"] are from the current year. + + Therefore, ["Nov", "Dec"] need to be shifted a year forward for correct + grouping. + + Parameters + ---------- + ds : xr.Dataset + The Dataset with time coordinates. + + Returns + ------- + xr.Dataset + The Dataset with shifted time coordinates. + + Examples + -------- + + Before and after shifting months for "NDJFM" seasons: + + >>> # Before shifting months + >>> [(2000, "NDJFM", 11), (2000, "NDJFM", 12), (2001, "NDJFM", 1), + >>> (2001, "NDJFM", 2), (2001, "NDJFM", 3)] + + >>> # After shifting months + >>> [(2001, "NDJFM", 11), (2001, "NDJFM", 12), (2001, "NDJFM", 1), + >>> (2001, "NDJFM", 2), (2001, "NDJFM", 3)] + """ + ds_new = ds.copy() + custom_seasons = self._season_config["custom_seasons"] + + span_months: List[int] = [] + + # Identify the months that span across years in custom seasons. + # This is done by checking if "Jan" is not the first month in the + # custom season and getting all months before "Jan". + for months in custom_seasons.values(): # type: ignore + month_nums = [MONTH_STR_TO_INT[month] for month in months] + if 1 in month_nums: + jan_index = month_nums.index(1) + + if jan_index != 0: + span_months.extend(month_nums[:jan_index]) + break + + if span_months: + time_coords = ds_new[self.dim].copy() + idxs = np.where(time_coords.dt.month.isin(span_months))[0] + + if isinstance(time_coords.values[0], cftime.datetime): + for idx in idxs: + time_coords.values[idx] = time_coords.values[idx].replace( + year=time_coords.values[idx].year + 1 + ) + else: + for idx in idxs: + time_coords.values[idx] = pd.Timestamp( + time_coords.values[idx] + ) + pd.DateOffset(years=1) + + ds_new = ds_new.assign_coords({self.dim: time_coords}) + + return ds_new + + def _shift_djf_decembers(self, ds: xr.Dataset) -> xr.Dataset: + """Shifts Decembers to the next year for "DJF" seasons. + + This ensures correct grouping for "DJF" seasons by shifting Decembers + to the next year. Without this, grouping defaults to "JFD", which + is the native Xarray behavior. + + Parameters + ---------- + ds : xr.Dataset + The Dataset with time coordinates. + + Returns + ------- + xr.Dataset + The Dataset with shifted time coordinates. + + Examples + -------- + + Comparison of "JFD" and "DJF" seasons: + + >>> # "JFD" (native xarray behavior) + >>> [(2000, "DJF", 1), (2000, "DJF", 2), (2000, "DJF", 12), + >>> (2001, "DJF", 1), (2001, "DJF", 2)] + + >>> # "DJF" (shifted Decembers) + >>> [(2000, "DJF", 1), (2000, "DJF", 2), (2001, "DJF", 12), + >>> (2001, "DJF", 1), (2001, "DJF", 2)] + """ + ds_new = ds.copy() + time_coords = ds_new[self.dim].copy() + dec_indexes = time_coords.dt.month == 12 + + time_coords.values[dec_indexes] = [ + time.replace(year=time.year + 1) for time in time_coords.values[dec_indexes] + ] + + ds_new = ds_new.assign_coords({self.dim: time_coords}) + + return ds_new + def _drop_incomplete_djf(self, dataset: xr.Dataset) -> xr.Dataset: """Drops incomplete DJF seasons within a continuous time series. @@ -1067,7 +1324,6 @@ def _drop_incomplete_djf(self, dataset: xr.Dataset) -> xr.Dataset: ---------- dataset : xr.Dataset The dataset with some possibly incomplete DJF seasons. 
- Returns ------- xr.Dataset @@ -1102,6 +1358,75 @@ def _drop_incomplete_djf(self, dataset: xr.Dataset) -> xr.Dataset: return ds_final + def _drop_incomplete_seasons(self, ds: xr.Dataset) -> xr.Dataset: + """Drops incomplete seasons within a continuous time series. + + Seasons are considered incomplete if they do not have all of the + required months to form the season. For example, if we have the time + coordinates ["2000-11-16", "2000-12-16", "2001-01-16", "2001-02-16"] + and we want to group seasons by "ND" ("Nov", "Dec") and "JFM" ("Jan", + "Feb", "Mar"). + - ["2000-11-16", "2000-12-16"] is considered a complete "ND" season + since both "Nov" and "Dec" are present. + - ["2001-01-16", "2001-02-16"] is considered an incomplete "JFM" + season because it only has "Jan" and "Feb". Therefore, these + time coordinates are dropped. + + Parameters + ---------- + df : pd.DataFrame + A DataFrame of seasonal datetime components with potentially + incomplete seasons. + + Returns + ------- + pd.DataFrame + A DataFrame of seasonal datetime components with only complete + seasons. + + Notes + ----- + TODO: Refactor this method to use pure Xarray/NumPy operations, rather + than Pandas. + """ + # Transform the time coords into a DataFrame of seasonal datetime + # components based on the grouping mode. + time_coords = ds[self.dim].copy() + df = self._get_df_dt_components(time_coords, drop_obsolete_cols=False) + + # Get the expected and actual number of months for each season group. + df["expected_months"] = df["season"].str.split(r"(?<=.)(?=[A-Z])").str.len() + df["actual_months"] = df.groupby(["year", "season"])["year"].transform("count") + + # Get the incomplete seasons and drop the time coordinates that are in + # those incomplete seasons. + indexes_to_drop = df[df["expected_months"] != df["actual_months"]].index + + if len(indexes_to_drop) == len(time_coords): + raise RuntimeError( + "No time coordinates remain with `drop_incomplete_seasons=True`. " + "Check the dataset has at least one complete season and/or " + "specify `drop_incomplete_seasons=False` instead." + ) + elif len(indexes_to_drop) > 0: + # The dataset needs to be split into a dataset with and a dataset + # without the time dimension because the xarray `.where()` method + # adds the time dimension to non-time dimension data vars when + # broadcasting, which is a behavior we do not desire. + # https://github.com/pydata/xarray/issues/1234 + # https://github.com/pydata/xarray/issues/8796#issuecomment-1974878267 + ds_no_time = ds.get([v for v in ds.data_vars if self.dim not in ds[v].dims]) # type: ignore + ds_time = ds.get([v for v in ds.data_vars if self.dim in ds[v].dims]) # type: ignore + + coords_to_drop = time_coords.values[indexes_to_drop] + ds_time = ds_time.where(~time_coords.isin(coords_to_drop), drop=True) + + ds_new = xr.merge([ds_time, ds_no_time]) + + return ds_new + + return ds + def _drop_leap_days(self, ds: xr.Dataset): """Drop leap days from time coordinates. @@ -1120,7 +1445,9 @@ def _drop_leap_days(self, ds: xr.Dataset): ------- xr.Dataset """ - ds = ds.sel(**{self.dim: ~((ds.time.dt.month == 2) & (ds.time.dt.day == 29))}) + ds = ds.sel( + **{self.dim: ~((ds[self.dim].dt.month == 2) & (ds[self.dim].dt.day == 29))} + ) return ds def _average(self, ds: xr.Dataset, data_var: str) -> xr.DataArray: @@ -1172,8 +1499,7 @@ def _group_average(self, ds: xr.Dataset, data_var: str) -> xr.DataArray: # Label the time coordinates for grouping weights and the data variable # values. 
- self._labeled_time = self._label_time_coords(dv[self.dim]) - dv = dv.assign_coords({self.dim: self._labeled_time}) + dv_grouped = self._label_time_coords_for_grouping(dv) if self._weighted: time_bounds = ds.bounds.get_bounds("T", var_key=data_var) @@ -1192,13 +1518,14 @@ def _group_average(self, ds: xr.Dataset, data_var: str) -> xr.DataArray: # Perform weighted average using the formula # WA = sum(data*weights) / sum(weights). The denominator must be # included to take into account zero weight for missing data. + weights_gb = self._label_time_coords_for_grouping(weights) with xr.set_options(keep_attrs=True): - dv = self._group_data(dv).sum() / self._group_data(weights).sum() + dv = dv_grouped.sum() / weights_gb.sum() # Restore the data variable's name. dv.name = data_var else: - dv = self._group_data(dv).mean() + dv = dv_grouped.mean() # After grouping and aggregating, the grouped time dimension's # attributes are removed. Xarray's `keep_attrs=True` option only keeps @@ -1256,7 +1583,10 @@ def _get_weights(self, time_bounds: xr.DataArray) -> xr.DataArray: time_lengths = time_lengths.astype(np.float64) - grouped_time_lengths = self._group_data(time_lengths) + grouped_time_lengths = self._label_time_coords_for_grouping(time_lengths) + # FIXME: File "/opt/miniconda3/envs/xcdat_dev_416_xr/lib/python3.12/site-packages/xarray/core/groupby.py", line 639, in _raise_if_not_single_group + # raise NotImplementedError( + # NotImplementedError: This method is not supported for grouping by multiple variables yet. weights: xr.DataArray = grouped_time_lengths / grouped_time_lengths.sum() weights.name = f"{self.dim}_wts" @@ -1294,9 +1624,9 @@ def _label_time_coords(self, time_coords: xr.DataArray) -> xr.DataArray: This methods labels time coordinates for grouping by first extracting specific xarray datetime components from time coordinates and storing them in a pandas DataFrame. After processing (if necessary) is performed - on the DataFrame, it is converted to a numpy array of datetime - objects. This numpy serves as the data source for the final - DataArray of labeled time coordinates. + on the DataFrame, it is converted to a numpy array of datetime objects. + This numpy array serves as the data source for the final DataArray of + labeled time coordinates. Parameters ---------- @@ -1332,7 +1662,9 @@ def _label_time_coords(self, time_coords: xr.DataArray) -> xr.DataArray: >>> Coordinates: >>> * time (time) datetime64[ns] 2000-01-01T00:00:00 ... 2000-04-01T00:00:00 """ - df_dt_components: pd.DataFrame = self._get_df_dt_components(time_coords) + df_dt_components: pd.DataFrame = self._get_df_dt_components( + time_coords, drop_obsolete_cols=True + ) dt_objects = self._convert_df_to_dt(df_dt_components) time_grouped = xr.DataArray( @@ -1346,7 +1678,38 @@ def _label_time_coords(self, time_coords: xr.DataArray) -> xr.DataArray: return time_grouped - def _get_df_dt_components(self, time_coords: xr.DataArray) -> pd.DataFrame: + def _label_time_coords_for_grouping(self, dv: xr.DataArray) -> DataArrayGroupBy: + # Use the TIME_GROUPS dictionary to determine which components + # are needed to form the labeled time coordinates. 
+ dt_comps = TIME_GROUPS[self._mode][self._freq] + dt_comps_map: Dict[str, UniqueGrouper | SeasonGrouper] = { + comp: UniqueGrouper() for comp in dt_comps if comp != "season" + } + + dv_new = dv.copy() + for comp in dt_comps_map.keys(): + dv_new.coords[comp] = dv_new[self.dim][f"{self.dim}.{comp}"] + + if self._freq == "season": + custom_seasons = self._season_config.get("custom_seasons") + # NOTE: SeasonGrouper() does not drop incomplete seasons yet. + # TODO: Add `drop_incomplete` arg once available. + + if custom_seasons is not None: + season_keys = list(custom_seasons.keys()) + season_grouper = SeasonGrouper(season_keys) + else: + season_keys = list(SEASON_TO_MONTH.keys()) + season_grouper = SeasonGrouper(season_keys) + + dt_comps_map[self.dim] = season_grouper + dv_gb = dv_new.groupby(**dt_comps_map) # type: ignore + + return dv_gb + + def _get_df_dt_components( + self, time_coords: xr.DataArray, drop_obsolete_cols: bool + ) -> pd.DataFrame: """Returns a DataFrame of xarray datetime components. This method extracts the applicable xarray datetime components from each @@ -1367,6 +1730,12 @@ def _get_df_dt_components(self, time_coords: xr.DataArray) -> pd.DataFrame: ---------- time_coords : xr.DataArray The time coordinates. + drop_obsolete_cols : bool + Drop obsolete columns after processing seasonal DataFrame when + ``self._freq="season"``. Set to False to keep datetime columns + needed for preprocessing the dataset (e.g,. removing incomplete + seasons), and set to True to remove obsolete columns when needing + to group time coordinates. Returns ------- @@ -1397,41 +1766,18 @@ def _get_df_dt_components(self, time_coords: xr.DataArray) -> pd.DataFrame: if self._mode in ["climatology", "departures"]: df["year"] = time_coords[f"{self.dim}.year"].values df["month"] = time_coords[f"{self.dim}.month"].values - - if self._mode == "group_average": + elif self._mode == "group_average": df["month"] = time_coords[f"{self.dim}.month"].values - df = self._process_season_df(df) - - return df - - def _process_season_df(self, df: pd.DataFrame) -> pd.DataFrame: - """ - Processes a DataFrame of datetime components for the season frequency. - - Parameters - ---------- - df : pd.DataFrame - A DataFrame of xarray datetime components. - - Returns - ------- - pd.DataFrame - A DataFrame of processed xarray datetime components. - """ - df_new = df.copy() - custom_seasons = self._season_config.get("custom_seasons") - dec_mode = self._season_config.get("dec_mode") + custom_seasons = self._season_config.get("custom_seasons") + if custom_seasons is not None: + df = self._map_months_to_custom_seasons(df) - if custom_seasons is not None: - df_new = self._map_months_to_custom_seasons(df_new) - else: - if dec_mode == "DJF": - df_new = self._shift_decembers(df_new) + if drop_obsolete_cols: + df = self._drop_obsolete_columns(df) + df = self._map_seasons_to_mid_months(df) - df_new = self._drop_obsolete_columns(df_new) - df_new = self._map_seasons_to_mid_months(df_new) - return df_new + return df def _map_months_to_custom_seasons(self, df: pd.DataFrame) -> pd.DataFrame: """Maps the month column in the DataFrame to a custom season. @@ -1467,45 +1813,6 @@ def _map_months_to_custom_seasons(self, df: pd.DataFrame) -> pd.DataFrame: return df_new - def _shift_decembers(self, df_season: pd.DataFrame) -> pd.DataFrame: - """Shifts Decembers over to the next year for "DJF" seasons in-place. 
- - For "DJF" seasons, Decembers must be shifted over to the next year in - order for the xarray groupby operation to correctly label and group the - corresponding time coordinates. If the aren't shifted over, grouping is - incorrectly performed with the native xarray "DJF" season (which is - actually "JFD"). - - Parameters - ---------- - df_season : pd.DataFrame - The DataFrame of xarray datetime components produced using the - "season" frequency. - - Returns - ------- - pd.DataFrame - The DataFrame of xarray datetime components with Decembers shifted - over to the next year. - - Examples - -------- - - Comparison of "JFD" and "DJF" seasons: - - >>> # "JFD" (native xarray behavior) - >>> [(2000, "DJF", 1), (2000, "DJF", 2), (2000, "DJF", 12), - >>> (2001, "DJF", 1), (2001, "DJF", 2)] - - >>> # "DJF" (shifted Decembers) - >>> [(2000, "DJF", 1), (2000, "DJF", 2), (2001, "DJF", 12), - >>> (2001, "DJF", 1), (2001, "DJF", 2)] - - """ - df_season.loc[df_season["month"] == 12, "year"] = df_season["year"] + 1 - - return df_season - def _map_seasons_to_mid_months(self, df: pd.DataFrame) -> pd.DataFrame: """Maps the season column values to the integer of its middle month. @@ -1682,17 +1989,22 @@ def _add_operation_attrs(self, data_var: xr.DataArray) -> xr.DataArray: ) if self._freq == "season": - custom_seasons = self._season_config.get("custom_seasons") - - if custom_seasons is None: - dec_mode = self._season_config.get("dec_mode") - drop_incomplete_djf = self._season_config.get("drop_incomplete_djf") + drop_incomplete_seasons = self._season_config["drop_incomplete_seasons"] + drop_incomplete_djf = self._season_config.get("drop_incomplete_djf", False) - data_var.attrs["dec_mode"] = dec_mode - if dec_mode == "DJF": - data_var.attrs["drop_incomplete_djf"] = str(drop_incomplete_djf) + # TODO: Deprecate drop_incomplete_djf. This attr is only set if the + # user does not set drop_incomplete_seasons. + if drop_incomplete_seasons is False and drop_incomplete_djf is not False: + data_var.attrs["drop_incomplete_djf"] = str(drop_incomplete_djf) else: + data_var.attrs["drop_incomplete_seasons"] = str(drop_incomplete_seasons) + + custom_seasons = self._season_config.get("custom_seasons") + if custom_seasons is not None: data_var.attrs["custom_seasons"] = list(custom_seasons.keys()) + else: + dec_mode = self._season_config.get("dec_mode") + data_var.attrs["dec_mode"] = dec_mode return data_var
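# Condensed, illustrative restatement of the season-attr branch above, to make
# the resulting output attributes easy to scan. `cfg` is a plain dict standing
# in for the parsed season config; the expected values mirror the tests in
# this diff.
def season_attrs(cfg):
    attrs = {}
    drop_seasons = cfg.get("drop_incomplete_seasons", False)
    drop_djf = cfg.get("drop_incomplete_djf", False)

    # The legacy flag is only recorded when the new flag is left at its default.
    if drop_seasons is False and drop_djf is not False:
        attrs["drop_incomplete_djf"] = str(drop_djf)
    else:
        attrs["drop_incomplete_seasons"] = str(drop_seasons)

    custom = cfg.get("custom_seasons")
    if custom is not None:
        attrs["custom_seasons"] = list(custom.keys())
    else:
        attrs["dec_mode"] = cfg.get("dec_mode")

    return attrs

print(season_attrs({"dec_mode": "DJF", "drop_incomplete_seasons": True}))
# -> {'drop_incomplete_seasons': 'True', 'dec_mode': 'DJF'}
print(season_attrs({"dec_mode": "DJF", "drop_incomplete_djf": True}))
# -> {'drop_incomplete_djf': 'True', 'dec_mode': 'DJF'}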