From 9f4f6db237287b23045b19d3fffe0e513ed7b838 Mon Sep 17 00:00:00 2001 From: Lasse Bjermeland Date: Fri, 15 Nov 2024 12:27:57 +0100 Subject: [PATCH] Fix missing column names in csv file --- metocean_api/ts/internal/aux_funcs.py | 3 ++ metocean_api/ts/internal/ec/ec_products.py | 33 +++++++++-------- metocean_api/ts/internal/metno/met_product.py | 9 +++-- metocean_api/ts/internal/product.py | 3 +- tests/test_extract_data.py | 37 ++++++++++++++++++- 5 files changed, 62 insertions(+), 23 deletions(-) diff --git a/metocean_api/ts/internal/aux_funcs.py b/metocean_api/ts/internal/aux_funcs.py index bf39485..98e4b56 100644 --- a/metocean_api/ts/internal/aux_funcs.py +++ b/metocean_api/ts/internal/aux_funcs.py @@ -86,6 +86,9 @@ def create_dataframe(product, ds: xr.Dataset, lon_near, lat_near, outfile, start units = varattr.get("units", "-") header_lines.append("#" + name + ";" + standard_name + ";" + long_name + ";" + units) + # Add column names last + header_lines.append("time," + ",".join(df.columns)) + header = "\n".join(header_lines) + "\n" if save_csv: diff --git a/metocean_api/ts/internal/ec/ec_products.py b/metocean_api/ts/internal/ec/ec_products.py index 6d4d8ce..aaedf6f 100644 --- a/metocean_api/ts/internal/ec/ec_products.py +++ b/metocean_api/ts/internal/ec/ec_products.py @@ -24,7 +24,7 @@ def find_product(name: str) -> Product: class ERA5(Product): - + @property def convention(self) -> Convention: return Convention.METEOROLOGICAL @@ -54,9 +54,9 @@ def import_data(self, ts: TimeSeries, save_csv=True, save_nc=False, use_cache=Fa "significant_height_of_wind_waves", ] filenames = self.__download_era5_from_cds(ts.start_time, ts.end_time, ts.lon, ts.lat,ts.variable, folder='cache') + # Combine the data from the multiple files into a single dataframe df_res = None ds_res = None - variable_info = [] for filename in filenames: with xr.open_mfdataset(filename) as ds: @@ -67,17 +67,6 @@ def import_data(self, ts: TimeSeries, save_csv=True, save_nc=False, use_cache=Fa ds = ds.drop_vars(['longitude','latitude'], errors="ignore") df = aux_funcs.create_dataframe(self.name, ds, lon_near, lat_near, ts.datafile, ts.start_time, ts.end_time, save_csv=False) df.drop(columns=['number', 'expver'], inplace=True, errors='ignore') - variable = df.columns[0] - try: - standard_name = ds[variable].standard_name - except AttributeError: - standard_name = '-' - try: - long_name = ds[variable].long_name - except AttributeError: - long_name = '-' - variable_info.append(f'#{variable};{standard_name};{long_name};{ds[variable].units}\n') - if df_res is None: df_res = df ds_res = ds @@ -88,10 +77,22 @@ def import_data(self, ts: TimeSeries, save_csv=True, save_nc=False, use_cache=Fa if save_csv: lon_near = ds.longitude.values[0] lat_near = ds.latitude.values[0] - top_header = f'#{ts.product};LONGITUDE:{lon_near:0.4f};LATITUDE:{lat_near:0.4f}\n' - header = [top_header, '#Variable_name;standard_name;long_name;units\n'] + variable_info + header_lines =[f'#{ts.product};LONGITUDE:{lon_near:0.4f};LATITUDE:{lat_near:0.4f}'] + header_lines.append("#Variable_name;standard_name;long_name;units") + var_names = ["time"] + for name,vardata in ds_res.data_vars.items(): + varattr = vardata.attrs + standard_name =varattr.get("standard_name", "-") + long_name = varattr.get("long_name", "-") + units = varattr.get("units", "-") + header_lines.append("#" + name + ";" + standard_name + ";" + long_name + ";" + units) + var_names.append(name) + # Add column names last + header_lines.append(",".join(var_names)) + header = "\n".join(header_lines) + "\n" + with open(ts.datafile, "w", encoding="utf8", newline="") as f: - f.writelines(header) + f.write(header) df_res.to_csv(f, header=False, encoding=f.encoding, index_label="time") if save_nc: diff --git a/metocean_api/ts/internal/metno/met_product.py b/metocean_api/ts/internal/metno/met_product.py index 88e3bf5..799062a 100644 --- a/metocean_api/ts/internal/metno/met_product.py +++ b/metocean_api/ts/internal/metno/met_product.py @@ -4,6 +4,7 @@ from abc import abstractmethod from tqdm import tqdm import xarray as xr +import pandas as pd from .. import aux_funcs from ..product import Product @@ -20,11 +21,11 @@ class MetProduct(Product): """ @abstractmethod - def get_default_variables(self): + def get_default_variables(self) -> List[str]: raise NotImplementedError(f"Not implemented for {self.name}") @abstractmethod - def _get_url_info(self, date: str): + def _get_url_info(self, date: str) -> str: raise NotImplementedError(f"Not implemented for {self.name}") @abstractmethod @@ -35,7 +36,7 @@ def _get_near_coord(self, url: str, lon: float, lat: float): def get_dates(self, start_date, end_date): raise NotImplementedError(f"Not implemented for {self.name}") - def get_url_for_dates(self, start_date, end_date): + def get_url_for_dates(self, start_date, end_date) -> List[str]: """Returns the necessary files to download to support the given date range""" return [self._get_url_info(date) for date in self.get_dates(start_date, end_date)] @@ -115,7 +116,7 @@ def _combine_temporary_files( return df - def create_dataframe(self, ds: xr.Dataset, lon_near, lat_near, outfile, start_time, end_time, save_csv=True, **flatten_dims): + def create_dataframe(self, ds: xr.Dataset, lon_near, lat_near, outfile, start_time, end_time, save_csv=True, **flatten_dims) -> pd.DataFrame: ds = self._flatten_data_structure(ds, **flatten_dims) return aux_funcs.create_dataframe(self.name, ds, lon_near, lat_near, outfile, start_time, end_time, save_csv) diff --git a/metocean_api/ts/internal/product.py b/metocean_api/ts/internal/product.py index fe91b78..6b886b0 100644 --- a/metocean_api/ts/internal/product.py +++ b/metocean_api/ts/internal/product.py @@ -3,6 +3,7 @@ from __future__ import annotations from typing import TYPE_CHECKING, Tuple, List from abc import ABC, abstractmethod +import pandas as pd from .convention import Convention if TYPE_CHECKING: @@ -20,7 +21,7 @@ def convention(self) -> Convention: return Convention.NONE @abstractmethod - def import_data(self, ts: TimeSeries, save_csv=True, save_nc=False, use_cache=False): + def import_data(self, ts: TimeSeries, save_csv=True, save_nc=False, use_cache=False) -> pd.DataFrame: """Import data specified by the TimeSeries object""" raise NotImplementedError(f"import_data method not implemented for product {self.name}") diff --git a/tests/test_extract_data.py b/tests/test_extract_data.py index 79b4228..f7a3648 100644 --- a/tests/test_extract_data.py +++ b/tests/test_extract_data.py @@ -1,4 +1,5 @@ import xarray as xr +import pandas as pd from metocean_api import ts from metocean_api.ts.internal import products from metocean_api.ts.internal.convention import Convention @@ -18,6 +19,7 @@ def test_extract_nora3_wind(): df_ts.import_data(save_csv=SAVE_CSV,save_nc=SAVE_NC, use_cache=USE_CACHE) assert (df_ts.lat_data, df_ts.lon_data) == (53.32374838481946, 1.3199893172215793) assert df_ts.data.shape == (744,14) + __compare_loaded_data(df_ts) def test_download_of_temporary_files(): # Pick a time region with a start and end time where the temporary files will cover more than the requested time @@ -40,27 +42,52 @@ def test_extract_nora3_wave(): df_ts.import_data(save_csv=SAVE_CSV,save_nc=SAVE_NC, use_cache=USE_CACHE) assert (df_ts.lat_data, df_ts.lon_data) == (53.32494354248047, 1.3358169794082642) assert df_ts.data.shape == (744,14) + __compare_loaded_data(df_ts) def test_nora3_wind_wave_combined(): df_ts = ts.TimeSeries(lon=3.73, lat=64.60,start_time='2020-09-14', end_time='2020-09-15', product='NORA3_wind_wave', height=[10]) - # Import data from thredds.met.no df_ts.import_data(save_csv=SAVE_CSV,save_nc=SAVE_NC, use_cache=USE_CACHE) assert (df_ts.lat_data, df_ts.lon_data) == (64.60475157243123, 3.752025547482376) assert df_ts.data.shape == (48, 16) + __compare_loaded_data(df_ts) -#def test_extract_nora3_stormsurge(): +# def test_extract_nora3_stormsurge(): # df_ts = ts.TimeSeries(lon=1.320, lat=53.324,start_time='2000-01-01', end_time='2000-01-31', product='NORA3_stormsurge') # # Import data from thredds.met.no # df_ts.import_data(save_csv=SAVE_CSV,save_nc=SAVE_NC) # assert df_ts.data.shape == (744,1) +def __inferr_frequency(data: pd.DataFrame): + inferred_freq = pd.infer_freq(data.index) + # Set the inferred frequency if it’s detected + if inferred_freq: + data.index.freq = inferred_freq + else: + print("Could not infer frequency. Intervals may not be consistent.") + +def __compare_loaded_data(df_ts: ts.TimeSeries): + # Load the data back in and check that the data is the same + df_ts2 = ts.TimeSeries( + lon=df_ts.lon, + lat=df_ts.lat, + start_time=df_ts.start_time, + end_time=df_ts.end_time, + product=df_ts.product, + ) + df_ts2.load_data(local_file=df_ts.datafile) + __inferr_frequency(df_ts.data) + __inferr_frequency(df_ts2.data) + pd.testing.assert_frame_equal(df_ts.data, df_ts2.data) + + def test_extract_nora3_atm(): df_ts = ts.TimeSeries(lon=1.320, lat=53.324,start_time='2000-01-01', end_time='2000-01-31', product='NORA3_atm_sub') # Import data from thredds.met.no df_ts.import_data(save_csv=SAVE_CSV,save_nc=SAVE_NC, use_cache=USE_CACHE) assert (df_ts.lat_data, df_ts.lon_data) == (53.32374838481946, 1.3199893172215793) assert df_ts.data.shape == (744,7) + __compare_loaded_data(df_ts) def test_extract_nora3_atm3hr(): df_ts = ts.TimeSeries(lon=1.320, lat=53.324,start_time='2000-01-01', end_time='2000-01-31', product='NORA3_atm3hr_sub') @@ -69,12 +96,14 @@ def test_extract_nora3_atm3hr(): print(f"product: {df_ts.product}: {df_ts.lat_data}, {df_ts.lon_data}") assert (df_ts.lat_data, df_ts.lon_data) == (53.32374838481946, 1.3199893172215793) assert df_ts.data.shape == (248,30) + __compare_loaded_data(df_ts) def test_extract_obs(): df_ts = ts.TimeSeries(lon='', lat='',start_time='2017-01-01', end_time='2017-01-31' , product='E39_B_Sulafjorden_wave', variable=['Hm0', 'tp']) # Import data from thredds.met.no df_ts.import_data(save_csv=SAVE_CSV,save_nc=SAVE_NC, use_cache=USE_CACHE) assert df_ts.data.shape == (4464,2) + __compare_loaded_data(df_ts) def test_norkyst_800(): df_ts = ts.TimeSeries(lon=3.73, lat=64.60,start_time='2020-09-14', end_time='2020-09-15', product='NORKYST800') @@ -82,6 +111,7 @@ def test_norkyst_800(): df_ts.import_data(save_csv=SAVE_CSV,save_nc=SAVE_NC, use_cache=USE_CACHE) assert (df_ts.lat_data, df_ts.lon_data) == (64.59832175874106, 3.728905373023728) assert df_ts.data.shape == (48, 65) + __compare_loaded_data(df_ts) def test_norkyst_da_zdepth(): # We want to collect a subset @@ -91,6 +121,7 @@ def test_norkyst_da_zdepth(): df_ts.import_data(save_csv=SAVE_CSV,save_nc=SAVE_NC, use_cache=USE_CACHE) assert (df_ts.lat_data, df_ts.lon_data) == (64.59537563943964, 3.74450378868417) assert df_ts.data.shape == (24, 16) + __compare_loaded_data(df_ts) def test_norkyst_da_surface(): df_ts = ts.TimeSeries(lon=3.73, lat=64.60,start_time='2017-01-19', end_time='2017-01-20', product='NorkystDA_surface') @@ -98,12 +129,14 @@ def test_norkyst_da_surface(): df_ts.import_data(save_csv=SAVE_CSV,save_nc=SAVE_NC, use_cache=USE_CACHE) assert (df_ts.lat_data, df_ts.lon_data) == (64.59537563943964, 3.74450378868417) assert df_ts.data.shape == (48, 5) + __compare_loaded_data(df_ts) def test_echowave(): df_ts = ts.TimeSeries(lon=3.098, lat=52.48,start_time='2017-01-19', end_time='2017-01-20', product='ECHOWAVE') # Import data from https://data.4tu.nl/datasets/ df_ts.import_data(save_csv=SAVE_CSV,save_nc=SAVE_NC, use_cache=USE_CACHE) assert df_ts.data.shape == (48, 22) + __compare_loaded_data(df_ts) def test_extract_nora3_wave_spectra(): df_ts = ts.TimeSeries(lon=3.73, lat=64.60,start_time='2017-01-29',end_time='2017-02-02',product='NORA3_wave_spec')