Skip to content

Commit

Permalink
Merge pull request #83 from ClimateImpactLab/dscim-v0.4.0_fixes
Browse files Browse the repository at this point in the history
Fix chunking issues in sum_AMEL and reduce_damages
  • Loading branch information
kemccusker authored Jul 6, 2023
2 parents 152ae4f + 43b7843 commit d9bdae3
Show file tree
Hide file tree
Showing 4 changed files with 184 additions and 15 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [0.4.0] - Unreleased
### Added
- Functions to concatenate input damages across batches. ([PR #83](https://github.com/ClimateImpactLab/dscim/pull/83), [@davidrzhdu](https://github.com/davidrzhdu))
- New unit tests for [dscim/preprocessing/input_damages.py](https://github.com/ClimateImpactLab/dscim/blob/main/src/dscim/preprocessing/input_damages.py). ([PR #68](https://github.com/ClimateImpactLab/dscim/pull/68), [@davidrzhdu](https://github.com/davidrzhdu))
- New unit tests for [dscim/utils/rff.py](https://github.com/ClimateImpactLab/dscim/blob/main/src/dscim/utils/rff.py). ([PR #73](https://github.com/ClimateImpactLab/dscim/pull/73), [@JMGilbert](https://github.com/JMGilbert))
- New unit tests for [dscim/preprocessing/preprocessing.py](https://github.com/ClimateImpactLab/dscim/blob/main/src/dscim/preprocessing/preprocessing.py). ([PR #67](https://github.com/ClimateImpactLab/dscim/pull/67), [@JMGilbert](https://github.com/JMGilbert))
Expand All @@ -23,6 +24,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Remove old/unnecessary files. ([PR #57](https://github.com/ClimateImpactLab/dscim/pull/57), [@JMGilbert](https://github.com/JMGilbert))
- Remove unused `save_path` and `ec_cls` from `read_energy_files_parallel()`. ([PR #56](https://github.com/ClimateImpactLab/dscim/pull/56), [@davidrzhdu](https://github.com/davidrzhdu))
### Fixed
- Make all input damages output files with correct chunksizes. ([PR #83](https://github.com/ClimateImpactLab/dscim/pull/83), [@JMGilbert](https://github.com/JMGilbert))
- Add `.load()` to every loading of population data from EconVars. ([PR #82](https://github.com/ClimateImpactLab/dscim/pull/82), [@davidrzhdu](https://github.com/davidrzhdu))
- Make `compute_ag_damages` function correctly save outputs in float32. ([PR #72](https://github.com/ClimateImpactLab/dscim/pull/72) and [PR #82](https://github.com/ClimateImpactLab/dscim/pull/82), [@davidrzhdu](https://github.com/davidrzhdu))
- Make rff damage functions read in and save out in the proper filepath structure. ([PR #79](https://github.com/ClimateImpactLab/dscim/pull/79), [@JMGilbert](https://github.com/JMGilbert))
Expand Down
67 changes: 64 additions & 3 deletions src/dscim/preprocessing/input_damages.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
"""

import os
import glob
import re
import logging
import warnings
Expand Down Expand Up @@ -95,6 +94,50 @@ def _parse_projection_filesys(input_path, query="exists==True"):
return df.query(query)


def concatenate_damage_output(damage_dir, basename, save_path, num_batches=15):
    """Concatenate labor/energy damage output across batches.

    Opens the per-batch zarr stores as one lazy dataset, rechunks so the
    ``batch`` dimension sits in a single chunk, coerces any object-dtype
    coordinates/variables to unicode (zarr cannot serialize arbitrary
    ``object`` arrays), and writes one combined zarr store.

    Parameters
    ----------
    damage_dir : str
        Directory containing separate labor/energy damage output files by batches.
    basename : str
        Prefix of the damage output filenames (ex. {basename}_batch0.zarr)
    save_path : str
        Path to save concatenated file in .zarr format
    num_batches : int, optional
        Number of per-batch stores to concatenate, named
        ``{basename}_batch0.zarr`` .. ``{basename}_batch{num_batches-1}.zarr``.
        Defaults to 15, matching the projection system's batch count.
    """
    paths = [f"{damage_dir}/{basename}_batch{i}.zarr" for i in range(num_batches)]
    data = xr.open_mfdataset(paths=paths, engine="zarr")

    # Drop stale per-file chunk encoding so the rechunk below controls the
    # chunking of the saved store; pop() tolerates variables that carry no
    # "chunks" key (the original `del` would raise KeyError on those).
    for v in data:
        data[v].encoding.pop("chunks", None)

    chunkies = {
        "batch": num_batches,  # keep all batches in one chunk
        "rcp": 1,
        "gcm": 1,
        "model": 1,
        "ssp": 1,
        "region": -1,  # -1 = single chunk spanning the full region dim
        "year": 10,
    }

    data = data.chunk(chunkies)

    # zarr cannot write object-dtype arrays; cast them to unicode. "batch"
    # is cast unconditionally so the output coord dtype is always unicode.
    for v in list(data.coords.keys()):
        if data.coords[v].dtype == object:
            data.coords[v] = data.coords[v].astype("unicode")
    data.coords["batch"] = data.coords["batch"].astype("unicode")
    for v in list(data.variables.keys()):
        if data[v].dtype == object:
            data[v] = data[v].astype("unicode")

    data.to_zarr(save_path, mode="w")


def calculate_labor_impacts(input_path, file_prefix, variable, val_type):
"""Calculate impacts for labor results.
Expand Down Expand Up @@ -371,7 +414,7 @@ def process_batch(g):
batches = [ds for ds in batches if ds is not None]
chunkies = {
"rcp": 1,
"region": 24378,
"region": -1,
"gcm": 1,
"year": 10,
"model": 1,
Expand Down Expand Up @@ -738,12 +781,21 @@ def prep(
).expand_dims({"gcm": [gcm]})

damages = damages.chunk(
{"batch": 15, "ssp": 1, "model": 1, "rcp": 1, "gcm": 1, "year": 10}
{
"batch": 15,
"ssp": 1,
"model": 1,
"rcp": 1,
"gcm": 1,
"year": 10,
"region": -1,
}
)
damages.coords.update({"batch": [f"batch{i}" for i in damages.batch.values]})

# convert to EPA VSL
damages = damages * 0.90681089
damages = damages.astype(np.float32)

for v in list(damages.coords.keys()):
if damages.coords[v].dtype == object:
Expand Down Expand Up @@ -790,6 +842,15 @@ def coastal_inputs(
)
else:
d = d.sel(adapt_type=adapt_type, vsl_valuation=vsl_valuation, drop=True)
chunkies = {
"batch": 15,
"ssp": 1,
"model": 1,
"slr": 1,
"year": 10,
"region": -1,
}
d = d.chunk(chunkies)
d.to_zarr(
f"{path}/coastal_damages_{version}-{adapt_type}-{vsl_valuation}.zarr",
consolidated=True,
Expand Down
41 changes: 36 additions & 5 deletions src/dscim/preprocessing/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,24 @@ def reduce_damages(
xr.open_zarr(damages).chunks["batch"][0] == 15
), "'batch' dim on damages does not have chunksize of 15. Please rechunk."

if "coastal" not in sector:
chunkies = {
"rcp": 1,
"region": -1,
"gcm": 1,
"year": 10,
"model": 1,
"ssp": 1,
}
else:
chunkies = {
"region": -1,
"slr": 1,
"year": 10,
"model": 1,
"ssp": 1,
}

ce_batch_dims = [i for i in gdppc.dims] + [
i for i in ds.dims if i not in gdppc.dims and i != "batch"
]
Expand All @@ -110,15 +128,14 @@ def reduce_damages(
i for i in gdppc.region.values if i in ce_batch_coords["region"]
]
ce_shapes = [len(ce_batch_coords[c]) for c in ce_batch_dims]
ce_chunks = [xr.open_zarr(damages).chunks[c][0] for c in ce_batch_dims]

template = xr.DataArray(
da.empty(ce_shapes, chunks=ce_chunks),
da.empty(ce_shapes),
dims=ce_batch_dims,
coords=ce_batch_coords,
)
).chunk(chunkies)

other = xr.open_zarr(damages)
other = xr.open_zarr(damages).chunk(chunkies)

out = other.map_blocks(
ce_from_chunk,
Expand Down Expand Up @@ -205,7 +222,21 @@ def sum_AMEL(
for sector in sectors:
print(f"Opening {sector},{params[sector]['sector_path']}")
ds = xr.open_zarr(params[sector]["sector_path"], consolidated=True)
ds = ds[params[sector][var]].rename(var)
ds = (
ds[params[sector][var]]
.rename(var)
.chunk(
{
"batch": 15,
"ssp": 1,
"model": 1,
"rcp": 1,
"gcm": 1,
"year": 10,
"region": -1,
}
)
)
ds = xr.where(np.isinf(ds), np.nan, ds)
datasets.append(ds)

Expand Down
89 changes: 82 additions & 7 deletions tests/test_input_damages.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from dscim.menu.simple_storage import EconVars
from dscim.preprocessing.input_damages import (
_parse_projection_filesys,
concatenate_damage_output,
calculate_labor_impacts,
concatenate_labor_damages,
calculate_labor_batch_damages,
Expand All @@ -31,7 +32,7 @@ def test_parse_projection_filesys(tmp_path):
"""
Test that parse_projection_filesys correctly retrieves projection system output structure
"""
rcp = ["rcp85", "rcp45"]
rcp = ["rcp45", "rcp85"]
gcm = ["ACCESS1-0", "GFDL-CM3"]
model = ["high", "low"]
ssp = [f"SSP{n}" for n in range(2, 4)]
Expand All @@ -45,14 +46,14 @@ def test_parse_projection_filesys(tmp_path):
os.makedirs(os.path.join(tmp_path, b, r, g, m, s))

out_expected = {
"batch": list(chain(repeat("batch9", 16), repeat("batch6", 16))),
"rcp": list(chain(repeat("rcp85", 8), repeat("rcp45", 8))) * 2,
"batch": list(chain(repeat("batch6", 16), repeat("batch9", 16))),
"rcp": list(chain(repeat("rcp45", 8), repeat("rcp85", 8))) * 2,
"gcm": list(chain(repeat("ACCESS1-0", 4), repeat("GFDL-CM3", 4))) * 4,
"model": list(chain(repeat("high", 2), repeat("low", 2))) * 8,
"ssp": ["SSP2", "SSP3"] * 16,
"path": [
os.path.join(tmp_path, b, r, g, m, s)
for b in ["batch9", "batch6"]
for b in ["batch6", "batch9"]
for r in rcp
for g in gcm
for m in model
Expand All @@ -64,11 +65,83 @@ def test_parse_projection_filesys(tmp_path):
df_out_expected = pd.DataFrame(out_expected)

df_out_actual = _parse_projection_filesys(input_path=tmp_path)
df_out_actual = df_out_actual.sort_values(
by=["batch", "rcp", "gcm", "model", "ssp"]
)
df_out_actual.reset_index(drop=True, inplace=True)

pd.testing.assert_frame_equal(df_out_expected, df_out_actual)


def test_concatenate_damage_output(tmp_path):
    """
    Test that concatenate_damage_output correctly concatenates damages across batches and saves to a single zarr file
    """
    in_dir = os.path.join(tmp_path, "concatenate_in")
    if not os.path.exists(in_dir):
        os.makedirs(in_dir)

    # Coordinate values shared by every per-batch input and the expected output.
    dims = ["ssp", "rcp", "model", "gcm", "batch", "year", "region"]
    batch_names = [f"batch{i}" for i in range(15)]
    models = ["IIASA GDP", "OECD Env-Growth"]
    rcps = ["rcp45", "rcp85"]
    regions = ["ZWE.test_region", "USA.test_region"]
    ssps = ["SSP2", "SSP3"]
    years = [2020, 2099]

    # Write one single-batch zarr store per batch; delta_rebased is
    # deliberately object-dtype to exercise the unicode cast on save.
    for batch in batch_names:
        single = xr.Dataset(
            {
                "delta_rebased": (
                    dims,
                    np.full((2, 2, 2, 2, 1, 2, 2), 1).astype(object),
                ),
                "histclim_rebased": (
                    dims,
                    np.full((2, 2, 2, 2, 1, 2, 2), 2),
                ),
            },
            coords={
                "batch": (["batch"], [batch]),
                "gcm": (["gcm"], np.array(["ACCESS1-0", "BNU-ESM"], dtype=object)),
                "model": (["model"], models),
                "rcp": (["rcp"], rcps),
                "region": (["region"], regions),
                "ssp": (["ssp"], ssps),
                "year": (["year"], years),
            },
        )

        single.to_zarr(os.path.join(in_dir, f"test_insuffix_{batch}.zarr"))

    # All 15 batches stacked along the batch dimension.
    ds_out_expected = xr.Dataset(
        {
            "delta_rebased": (dims, np.full((2, 2, 2, 2, 15, 2, 2), 1)),
            "histclim_rebased": (dims, np.full((2, 2, 2, 2, 15, 2, 2), 2)),
        },
        coords={
            "batch": (["batch"], batch_names),
            "gcm": (["gcm"], ["ACCESS1-0", "BNU-ESM"]),
            "model": (["model"], models),
            "rcp": (["rcp"], rcps),
            "region": (["region"], regions),
            "ssp": (["ssp"], ssps),
            "year": (["year"], years),
        },
    )

    out_store = os.path.join(in_dir, "concatenate.zarr")
    concatenate_damage_output(
        damage_dir=in_dir,
        basename="test_insuffix",
        save_path=out_store,
    )

    # Re-select in batch order so the comparison is order-insensitive.
    ds_out_actual = xr.open_zarr(out_store).sel(batch=batch_names)

    xr.testing.assert_equal(ds_out_expected, ds_out_actual)


@pytest.fixture
def labor_in_val_fixture(tmp_path):
"""
Expand Down Expand Up @@ -697,7 +770,9 @@ def energy_in_netcdf_fixture(tmp_path):
"region",
"year",
],
np.full((1, 1, 1, 1, 1, 2, 2), 2),
np.full((1, 1, 1, 1, 1, 2, 2), 2).astype(
object
),
),
},
coords={
Expand Down Expand Up @@ -1030,11 +1105,11 @@ def test_prep_mortality_damages(
{
"delta": (
["gcm", "batch", "ssp", "rcp", "model", "year", "region"],
np.full((2, 2, 2, 2, 2, 2, 2), -0.90681089),
np.float32(np.full((2, 2, 2, 2, 2, 2, 2), -0.90681089)),
),
"histclim": (
["gcm", "batch", "ssp", "rcp", "model", "year", "region"],
np.full((2, 2, 2, 2, 2, 2, 2), 2 * 0.90681089),
np.float32(np.full((2, 2, 2, 2, 2, 2, 2), 2 * 0.90681089)),
),
},
coords={
Expand Down

0 comments on commit d9bdae3

Please sign in to comment.