feat(sat-etl): Enable deletion of raw files via --rm flag
devsjc committed Aug 14, 2024
1 parent aea9a31 commit 90dea5d
Showing 2 changed files with 73 additions and 51 deletions.
96 changes: 52 additions & 44 deletions containers/sat/download_process_sat.py
@@ -8,15 +8,18 @@
import datetime as dt
import itertools
import json
import traceback
import logging
import os
import pathlib
import shutil
import sys
import traceback
from multiprocessing import Pool, cpu_count
from typing import Literal

import dask.delayed
import dask.distributed
import dask.diagnostics
import eumdac
import eumdac.cli
import numpy as np
@@ -26,10 +29,9 @@
import satpy.dataset.dataid
import xarray as xr
import yaml
import dask.delayed
import dask.distributed
import zarr
from ocf_blosc2 import Blosc2

from satpy import Scene

handler = logging.StreamHandler(sys.stdout)
@@ -236,7 +238,7 @@ def process_scans(
zarr_path: pathlib.Path = folder.parent / start.strftime(sat_config.zarr_fmtstr[dstype])
zarr_times: list[dt.datetime] = []
if zarr_path.exists():
zarr_times = xr.open_zarr(zarr_path).sortby("time").time.values.tolist()
zarr_times = xr.open_zarr(zarr_path, consolidated=True).sortby("time").time.values.tolist()
log.debug(f"Zarr store already exists at {zarr_path} for {zarr_times[0]}-{zarr_times[-1]}")
else:
log.debug(f"Zarr store does not exist at {zarr_path}")
@@ -260,9 +262,9 @@
if dataset is not None:
dataset = _preprocess_function(dataset)
datasets.append(dataset)
# Append to zarrs in hourly chunks (12 sets of 5 minute datasets)
# Append to zarrs in hourly chunks
# * This is so zarr doesn't complain about mismatching chunk sizes
if len(datasets) == 12:
if len(datasets) == int(pd.Timedelta("1h") / pd.Timedelta(sat_config.cadence)):
if pathlib.Path(zarr_path).exists():
log.debug(f"Appending to existing zarr store at {zarr_path}")
mode = "a"
@@ -275,7 +277,10 @@
zarr_path.as_posix(),
mode,
chunks={
"time": 12,
"time": len(datasets),
"x_geostationary": -1,
"y_geostationary": -1,
"variable": 1,
},
)
datasets = []
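
The widened chunk spec makes the layout explicit: the time chunk matches the batch being appended, -1 keeps a dimension in a single chunk, and each variable (channel) is chunked separately. A toy illustration of how xarray interprets that mapping (shapes invented for the example):

import numpy as np
import xarray as xr

da = xr.DataArray(
    np.zeros((12, 2, 4, 4)),
    dims=("time", "variable", "y_geostationary", "x_geostationary"),
)
chunked = da.chunk({"time": 12, "variable": 1, "y_geostationary": -1, "x_geostationary": -1})
print(chunked.chunks)  # ((12,), (1, 1), (4,), (4,)) -- spatial dims stay whole
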
@@ -340,15 +345,10 @@ def _convert_scene_to_dataarray(
# Ignore the "area" and "_satpy_id" scene attributes as they are not serializable
# and their data is already present in other scene attrs anyway.
if attr not in ["area", "_satpy_id"]:
try:
serialized_value = json.dumps(scene[channel].attrs[attr])
data_attrs[new_name] = serialized_value
except Exception as e:
log.warning(f"Could not serialize scene attribute {new_name}: {e}")
data_attrs[new_name] = repr(scene[channel].attrs[attr])

dataset: xr.Dataset = scene.to_xarray_dataset()
dataarray = dataset.to_array()
log.debug("Converted to dataarray")

# Lat and Lon are the same for all the channels now
if calculate_osgb:
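
The serialization change above swaps a hard failure for a lossy-but-safe fallback. The same pattern in isolation (a sketch, not the exact helper in this file):

import json
import logging

log = logging.getLogger(__name__)

def serialize_attr(name: str, value: object) -> str:
    # JSON-encode an attribute; fall back to its repr() when it is not serializable.
    try:
        return json.dumps(value)
    except (TypeError, ValueError) as e:
        log.warning(f"Could not serialize scene attribute {name}: {e}")
        return repr(value)

serialize_attr("calibration", "reflectance")  # '"reflectance"'
serialize_attr("area", object())              # falls back to '<object object at 0x...>'
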
@@ -504,15 +504,15 @@ def _write_to_zarr(dataset: xr.Dataset, zarr_name: str, mode: str, chunks: dict)
extra_kwargs = mode_extra_kwargs[mode]
sliced_ds: xr.Dataset = dataset.isel(x_geostationary=slice(0, 5548)).chunk(chunks)
try:
write_job: dask.delayed.Delayed = sliced_ds.to_zarr(
write_job = sliced_ds.to_zarr(
store=zarr_name,
compute=False,
**extra_kwargs,
consolidated=True,
mode=mode,
**extra_kwargs,
)
write_job = write_job.persist()
dask.distributed.progress(write_job, notebook=False)
with dask.diagnostics.ProgressBar():
write_job.compute()
except Exception as e:
log.error(f"Error writing dataset to zarr store {zarr_name} with mode {mode}: {e}")
traceback.print_tb(e.__traceback__)
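
The write path now drives the delayed store through dask's local ProgressBar instead of persist() plus a distributed progress display, so no scheduler or client is required. A self-contained sketch of the pattern (toy data, throwaway path):

import dask.diagnostics
import numpy as np
import xarray as xr

ds = xr.Dataset({"data": (("time",), np.arange(24))}).chunk({"time": 12})
write_job = ds.to_zarr("/tmp/example.zarr", mode="w", consolidated=True, compute=False)
with dask.diagnostics.ProgressBar():  # text progress bar on stdout
    write_job.compute()
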
@@ -522,7 +522,7 @@ def _rewrite_zarr_times(output_name: str) -> None:
def _rewrite_zarr_times(output_name: str) -> None:
"""Rewrites the time coordinates in the given zarr store."""
# Combine time coords
ds = xr.open_zarr(output_name, consolidated=False)
ds = xr.open_zarr(output_name, consolidated=True)

# Prevent numcodecs string error
# See https://github.com/pydata/xarray/issues/3476#issuecomment-1205346130
@@ -555,7 +555,7 @@ def _rewrite_zarr_times(output_name: str) -> None:
data["metadata"]["time/.zarray"] = coord_data["metadata"]["time/.zarray"]
with open(f"{output_name}/.zmetadata", "w") as f:
json.dump(data, f)
# zarr.consolidate_metadata(output_name)
zarr.consolidate_metadata(output_name)
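
Re-enabling the consolidation call matters because the code above rewrites time/.zarray and .zmetadata by hand; refreshing the consolidated copy keeps it authoritative for the consolidated=True reads elsewhere in this file. In brief:

import zarr

# After mutating array metadata on disk, re-consolidate so that
# xr.open_zarr(store, consolidated=True) sees the edited time coordinate.
zarr.consolidate_metadata("/tmp/example.zarr")  # hypothetical store path
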


parser = argparse.ArgumentParser(
@@ -570,34 +570,39 @@
choices=list(CONFIGS.keys()),
)
parser.add_argument(
"--path",
"--path", "-p",
help="Path to store the downloaded data",
default="/mnt/disks/sat",
type=pathlib.Path,
)
parser.add_argument(
"--start_date",
"--start_date", "-s",
help="Date to download from (YYYY-MM-DD)",
type=dt.date.fromisoformat,
required=False,
default=str(dt.datetime.now(tz=dt.UTC).date()),
)
parser.add_argument(
"--end_date",
"--end_date", "-e",
help="Date to download to (YYYY-MM-DD)",
type=dt.date.fromisoformat,
required=False,
default=str(dt.datetime.now(tz=dt.UTC).date()),
)
parser.add_argument(
"--delete_raw", "--rm",
help="Delete raw files after processing",
action="store_true",
default=False,
)
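
Either spelling of the new flag toggles the same option; argparse stores it under the first long name, delete_raw. A quick check against the parser defined above (the satellite selector is collapsed out of this hunk, so the positional "iodc" below is an assumption):

args = parser.parse_args(
    ["iodc", "--path", "/mnt/disks/sat", "--start_date", "2024-08-01", "--end_date", "2024-08-02", "--rm"]
)
assert args.delete_raw is True  # --rm is an alias for --delete_raw
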

if __name__ == "__main__":
def run(args: argparse.Namespace) -> None:
prog_start = dt.datetime.now(tz=dt.UTC)
log.info(f"{prog_start!s}: Running with args: {args}")

# Parse running args
args = parser.parse_args()
# Get running folder from args
folder: pathlib.Path = args.path / args.sat

log.info(f"{prog_start!s}: Running with args: {args}")
# Get config for desired satellite
sat_config = CONFIGS[args.sat]

@@ -618,41 +623,33 @@
# Download data
# We only parallelize if we have a number of files larger than the cpu count
token = _gen_token()
results: list[pathlib.Path] = []
raw_paths: list[pathlib.Path] = []
if len(scan_times) > cpu_count():
log.debug(f"Concurrency: {cpu_count()}")
pool = Pool(min(cpu_count(), 10)) # EUMDAC only allows for 10 concurrent requests
results = pool.starmap(
raw_paths = pool.starmap(
download_scans,
[(sat_config, folder, scan_time, token) for scan_time in scan_times],
)
pool.close()
pool.join()
results = list(itertools.chain(results))
raw_paths = list(itertools.chain.from_iterable(raw_paths)) # Flatten the per-scan lists of paths
else:
results = []
raw_paths = []
for scan_time in scan_times:
result: list[pathlib.Path] = download_scans(sat_config, folder, scan_time, token)
if len(result) > 0:
results.extend(result)
raw_paths.extend(result)

log.info(f"Downloaded {len(results)} files.")
log.info(f"Downloaded {len(raw_paths)} files.")
log.info("Converting raw data to HRV and non-HRV Zarr Stores.")

# Process the HRV and non-HRV data concurrently if possible
completed_types: list[str] = []
if cpu_count() > 1:
pool = Pool(cpu_count())
completed_types = pool.starmap(
process_scans,
[(sat_config, folder, start, end, t) for t in ["hrv", "nonhrv"]],
)
pool.close()
pool.join()
else:
for t in ["hrv", "nonhrv"]:
completed_type = process_scans(sat_config, folder, start, end, t)
completed_types.append(completed_type)
for t in ["hrv", "nonhrv"]:
log.info("Processing {t} data.")
completed_type = process_scans(sat_config, folder, start, end, t)
completed_types.append(completed_type)
for completed_type in completed_types:
log.info(f"Processed {completed_type} data.")

@@ -663,3 +660,14 @@
)
log.info(f"Completed archive for args: {args}. ({new_average_secs_per_scan} seconds per scan).")

# Delete raw files, if desired
if args.delete_raw:
log.info(f"Deleting {len(raw_paths)} raw files in {folder.as_posix()}.")
for f in raw_paths:
f.unlink()
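
unlink() assumes every downloaded artefact is a plain file. If any product were to unpack to a directory (an assumption, not something this diff shows), a more defensive helper would be:

import pathlib
import shutil

def delete_raw(path: pathlib.Path) -> None:
    # Remove a raw artefact whether it is a single file or an unpacked directory.
    if path.is_dir():
        shutil.rmtree(path)
    else:
        path.unlink(missing_ok=True)
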


if __name__ == "__main__":
# Parse running args
args = parser.parse_args()
run(args)
28 changes: 21 additions & 7 deletions containers/sat/test_download_process_sat.py
@@ -27,13 +27,14 @@ def setUpClass(cls) -> None:

token = dps._gen_token()

paths = dps.download_scans(
sat_config=dps.CONFIGS["iodc"],
folder=pathlib.Path("/tmp/test_sat_data"),
scan_time=TIMESTAMP,
token=token,
)
cls.paths = paths
# Four 15-minute scans give one full hour of IODC data: the batch size
# process_scans needs before it appends to a zarr store.
cls.paths = []
for scan_time in [TIMESTAMP + pd.Timedelta(offset) for offset in ["0m", "15m", "30m", "45m"]]:
    paths = dps.download_scans(
        sat_config=dps.CONFIGS["iodc"],
        folder=pathlib.Path("/tmp/test_sat_data"),
        scan_time=scan_time,
        token=token,
    )
    cls.paths.extend(paths)

attrs: dict = {
"end_time": TIMESTAMP + pd.Timedelta("15m"),
@@ -111,3 +112,16 @@ def test_open_and_scale_data(self) -> None:
ds.to_zarr("/tmp/test_sat_data/test.zarr", mode="w", consolidated=True)
ds2 = xr.open_zarr("/tmp/test_sat_data/test.zarr")
self.assertDictEqual(dict(ds.sizes), dict(ds2.sizes))
self.assertNotEqual(dict(ds.attrs), {})

def test_process_scans(self) -> None:
    out: str = dps.process_scans(
        dps.CONFIGS["iodc"],
        pathlib.Path("/tmp/test_sat_data"),
        pd.Timestamp("2024-01-01"),
        pd.Timestamp("2024-01-02"),
        "nonhrv",
    )
    # process_scans returns the dstype it completed.
    self.assertEqual(out, "nonhrv")
