Skip to content

Commit

Permalink
working on w-avg operator
Browse files Browse the repository at this point in the history
  • Loading branch information
cehbrecht committed Nov 7, 2023
1 parent e18cdeb commit 1452348
Showing 1 changed file with 71 additions and 7 deletions.
78 changes: 71 additions & 7 deletions rook/utils/weighted_average_utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,30 @@
# import numpy as np
import numpy as np

# import xarray as xr

import collections

from roocs_utils.parameter import collection_parameter
from roocs_utils.parameter import dimension_parameter

from roocs_utils.project_utils import derive_ds_id

from daops.ops.base import Operation
from daops.utils import normalise

from clisops.ops.average import average_over_dims as average


from clisops.ops.average import average_over_dims as clisops_average_over_dims
def apply_weights(ds):
ds["time"] = ds["time"].astype("int64")
ds["time_bnds"] = ds["time_bnds"].astype("int64")
# weights
weights = np.cos(np.deg2rad(ds.lat))
weights.name = "weights"
weights.fillna(0)
# apply weights
ds_weighted = ds.weighted(weights)
return ds_weighted


class WeightedAverage(Operation):
Expand All @@ -25,10 +42,58 @@ def _resolve_params(self, collection, **params):
"ignore_undetected_dims": params.get("ignore_undetected_dims"),
}

def get_operation_callable(self):
return clisops_average_over_dims
def _calculate(self):
config = {
"output_type": self._output_type,
"output_dir": self._output_dir,
"split_method": self._split_method,
"file_namer": self._file_namer,
}

self.params.update(config)

new_collection = collections.OrderedDict()

for dset in self.collection:
ds_id = derive_ds_id(dset)
new_collection[ds_id] = dset.file_paths

# Normalise (i.e. "fix") data inputs based on "character"
norm_collection = normalise.normalise(
new_collection, False # self._apply_fixes
)

rs = normalise.ResultSet(vars())

# apply weights
datasets = []
for ds_id in norm_collection.keys():
ds = norm_collection[ds_id]
ds_mod = apply_weights(ds)
datasets.append(ds_mod)

dims = dimension_parameter.DimensionParameter(
self.params.get("dims", None)
).value

# processed_ds = xr.concat(
# datasets,
# "time",
# )

# average over dimensions
outputs = average(
# processed_ds,
datasets,
dims=dims,
output_type="nc",
)
# result
rs.add("output", outputs)

return rs



def run_weighted_average(args):
result = weighted_average(**args)
return result.file_uris
Expand All @@ -42,7 +107,6 @@ def weighted_average(
split_method="time:auto",
file_namer="standard",
apply_fixes=False,
apply_average=False,
):
result_set = WeightedAverage(**locals()).calculate()
return result_set
return result_set

0 comments on commit 1452348

Please sign in to comment.