diff --git a/rook/utils/weighted_average_utils.py b/rook/utils/weighted_average_utils.py index 3fb5286..25b0097 100644 --- a/rook/utils/weighted_average_utils.py +++ b/rook/utils/weighted_average_utils.py @@ -1,16 +1,14 @@ -import numpy as np -import xarray as xr - -import collections +# import numpy as np +# import xarray as xr from roocs_utils.parameter import collection_parameter +from roocs_utils.parameter import dimension_parameter -from roocs_utils.project_utils import derive_ds_id from daops.ops.base import Operation -from daops.utils import normalise -from clisops.ops.average import average_over_dims +from clisops.ops.average import average_over_dims as clisops_average_over_dims + class WeightedAverage(Operation): def _resolve_params(self, collection, **params): @@ -18,70 +16,19 @@ def _resolve_params(self, collection, **params): Resolve the input parameters to `self.params` and parameterise collection parameter and set to `self.collection`. """ + dims = dimension_parameter.DimensionParameter(["latitude", "longitude"]) collection = collection_parameter.CollectionParameter(collection) self.collection = collection self.params = { + "dims": dims, "ignore_undetected_dims": params.get("ignore_undetected_dims"), } - def _calculate(self): - config = { - "output_type": self._output_type, - "output_dir": self._output_dir, - "split_method": self._split_method, - "file_namer": self._file_namer, - } - - self.params.update(config) - - new_collection = collections.OrderedDict() - - for dset in self.collection: - ds_id = derive_ds_id(dset) - new_collection[ds_id] = dset.file_paths - - # Normalise (i.e. "fix") data inputs based on "character" - norm_collection = normalise.normalise( - new_collection, False # self._apply_fixes - ) - - rs = normalise.ResultSet(vars()) - - # apply weights - datasets = [] - for ds_id in norm_collection.keys(): - ds = norm_collection[ds_id] - # fix time - ds['time'] = ds['time'].astype('int64') - ds['time_bnds'] = ds['time_bnds'].astype('int64') - # calculate weights - weights = np.cos(np.deg2rad(ds.lat)) - weights.name = "weights" - weights.fillna(0) - # apply weights - ds_weighted = ds.weighted(weights) - # add to list - datasets.append(ds_weighted) - - # concat over time - processed_ds = xr.concat( - datasets, - "time", - ) - - # average - outputs = average_over_dims( - processed_ds, - dims=["latitude", "longitude"], - output_type="nc", - ) - # result - rs.add("output", outputs) - - return rs - + def get_operation_callable(self): + return clisops_average_over_dims + def run_weighted_average(args): result = weighted_average(**args) return result.file_uris