Skip to content

Commit

Permalink
fix w-avg
Browse files Browse the repository at this point in the history
  • Loading branch information
cehbrecht committed Oct 30, 2023
1 parent 267b7f8 commit 821068f
Showing 1 changed file with 10 additions and 63 deletions.
73 changes: 10 additions & 63 deletions rook/utils/weighted_average_utils.py
Original file line number Diff line number Diff line change
@@ -1,87 +1,34 @@
import numpy as np
import xarray as xr

import collections
# import numpy as np
# import xarray as xr

from roocs_utils.parameter import collection_parameter
from roocs_utils.parameter import dimension_parameter

from roocs_utils.project_utils import derive_ds_id

from daops.ops.base import Operation
from daops.utils import normalise

from clisops.ops.average import average_over_dims
from clisops.ops.average import average_over_dims as clisops_average_over_dims


class WeightedAverage(Operation):
def _resolve_params(self, collection, **params):
"""
Resolve the input parameters to `self.params` and parameterise
collection parameter and set to `self.collection`.
"""
dims = dimension_parameter.DimensionParameter(["latitude", "longitude"])
collection = collection_parameter.CollectionParameter(collection)

self.collection = collection
self.params = {
"dims": dims,
"ignore_undetected_dims": params.get("ignore_undetected_dims"),
}

def _calculate(self):
config = {
"output_type": self._output_type,
"output_dir": self._output_dir,
"split_method": self._split_method,
"file_namer": self._file_namer,
}

self.params.update(config)

new_collection = collections.OrderedDict()

for dset in self.collection:
ds_id = derive_ds_id(dset)
new_collection[ds_id] = dset.file_paths

# Normalise (i.e. "fix") data inputs based on "character"
norm_collection = normalise.normalise(
new_collection, False # self._apply_fixes
)

rs = normalise.ResultSet(vars())

# apply weights
datasets = []
for ds_id in norm_collection.keys():
ds = norm_collection[ds_id]
# fix time
ds['time'] = ds['time'].astype('int64')
ds['time_bnds'] = ds['time_bnds'].astype('int64')
# calculate weights
weights = np.cos(np.deg2rad(ds.lat))
weights.name = "weights"
weights.fillna(0)
# apply weights
ds_weighted = ds.weighted(weights)
# add to list
datasets.append(ds_weighted)

# concat over time
processed_ds = xr.concat(
datasets,
"time",
)

# average
outputs = average_over_dims(
processed_ds,
dims=["latitude", "longitude"],
output_type="nc",
)
# result
rs.add("output", outputs)

return rs

def get_operation_callable(self):
return clisops_average_over_dims


def run_weighted_average(args):
result = weighted_average(**args)
return result.file_uris
Expand Down

0 comments on commit 821068f

Please sign in to comment.