diff --git a/.gitignore b/.gitignore index 703616f..942704e 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ # remove images *.png +!img/* # remove scratch files # C extensions diff --git a/README.md b/README.md index 40d02fe..41c7ada 100644 --- a/README.md +++ b/README.md @@ -1,106 +1,142 @@ -# MOVERO PLOTS -## 0. GENERAL -### 0.1 Task Description / Overview -> **_Task_**: Replace the IDL plot scripts for the verification plots with Python scripts. -> There is a number of different plots, which need to be created. For each type of plot a number of scripts is necessary. - -These plots need to be implemented: -1. Time Series of verification scores -2. Diurnal cycle of verification scores -3. Total scores depending on lead-time ranges -4. Numeric values of total scores - - - -### 0.2 Usage -The entry point for this package is a script called [plot_synop](src/movero/plot_synop.py). By executing `python plot_synop.py --help`, one can see the required and optional command line inputs: +# MOVEROPLOT +Moveroplot is a tool for the production of various verification graphics including line graphs, maps, histograms, and reliability diagrams +based on verification results produced by Movero. +## Getting Started +To setup the project, follow the instructions in [CONTRIBUTING.md](CONTRIBUTING.md). + + +## Usage +The primary command for moveroplot follows the structure +```bash +moveroplot [OPTIONS] MODEL_VERSION(S) +moveroplot [OPTIONS] MODEL_VERSION(S) ``` -Usage: plot_synop.py [OPTIONS] MODEL_VERSION - - Entry Point for the MOVERO Plotting Pipeline. - - The only input argument is the MODEL_VERSION argument. Pass this along with any number - of options. These usually have a default value or are not necessary. - +To see the available options, run: +```bash +moveroplot -h +``` +``` +The following options are listed below: Options: - --debug Add debug comments to command prompt. - --lt_ranges TEXT Specify the lead time range(s) of interest. - Def: 19-24 - --plot_params TEXT Specify parameters to plot. - --plot_scores TEXT Specify scores to plot. - --plot_cat_params TEXT Specify categorical parameters to plot. - --plot_cat_thresh TEXT Specify categorical scores thresholds to - plot. - --plot_cat_scores TEXT Specify categorical scores to plot. - --plot_ens_params TEXT Specify ens parameters to plot. - --plot_ens_thresh TEXT Specify ens scores thresholds to plot. - --plot_ens_scores TEXT Specify ens scores thresholds to plot. - --input_dir PATH Specify input directory. - --output_dir TEXT Specify output directory. Def: plots - --relief Add relief to maps. - --grid Add grid to plots. - --season [2020s4|2021s1|2021s2|2021s3|2021s4] - Specify the season of interest. Def: 2021s4 - --help Show this message and exit. + --plot_type TEXT Specify the type of plot to generate: [total, + time, station, daytime, ensemble]. + --debug Add debug comments to command prompt. + --lt_ranges TEXT Specify the lead time ranges of interest. Def: + 19-24 + --plot_params TEXT Specify parameters to plot. + --plot_scores TEXT Specify scores to plot. + --plot_cat_params TEXT Specify categorical parameters to plot. + --plot_cat_thresh TEXT Specify categorical scores thresholds to plot. + --plot_cat_scores TEXT Specify categorical scores to plot. + --plot_ens_params TEXT Specify parameters to ensemble plots. + --plot_ens_scores TEXT Specify scores to ensemble plots. + --plot_ens_cat_params TEXT Specify categorical parameters to ensemble + plots. + --plot_ens_cat_scores TEXT Specify categorical scores to ensemble plots. + --plot_ens_cat_thresh TEXT Specify categorical scores thresholds to + ensemble plots. + --input_dir PATH Specify input directory. + --output_dir TEXT Specify output directory. Def: plots + --colors TEXT Specify the plot color for each model version + using matploblib's color coding + --relief Add relief to maps. + --grid Add grid to plots. + -V, --version Show the version and exit. + -v, --verbose Increase verbosity; specify multiple times for + more. + -h, --help Show this message and exit. ``` +`moveroplot` efficiently processes user inputs to construct a `plot_setup` dictionary, which is pivotal in organizing the plotting process. +This dictionary is structured with two primary keys: 'model_versions' and 'parameter'. +* 'model_versions': This key maps to a list encompassing the model versions to plot. +* 'parameter': This key connects to a nested dictionary. Within this nested structure, each parameter serves as a key linked to its corresponding scores to plot. -> [time=Wed, Mar 2, 2022 2:21 PM] -**Command so far to create all plots for model v. C-1E-CTR_ch:** - +To offer a clearer understanding, the image below illustrates the potential parameters and their associated scores and their thresholds: +![**Parameters Dictitonary**](https://i.imgur.com/kdQrufu.png) +In the subsequent stages, `plot_setup` is channeled into distinct plotting pipelines. There, the source files are retrieved, parsed and plotted. +Ultimately, all plots are saved in the `/` directory as PNG files. +### Usage Examples +Example Command plotting Station, Time, Total and Daytime Scores: ``` -python plot_synop.py C-1E-CTR_ch ---plot_params TOT_PREC12,TOT_PREC6,TOT_PREC1,CLCT,GLOB,DURSUN12,DURSUN1,T_2M,T_2M_KAL,TD_2M,TD_2M_KAL,RELHUM_2M,FF_10M,FF_10M_KAL,VMAX_10M6,VMAX_10M1,DD_10M,PS,PMSL ---plot_scores ME,MMOD/MOBS,MAE,STDE,RMSE,COR,NOBS ---plot_cat_params TOT_PREC12,TOT_PREC6,TOT_PREC1,CLCT,T_2M,T_2M_KAL,TD_2M,TD_2M_KAL,FF_10M,FF_10M_KAL,VMAX_10M6,VMAX_10M1 +moveroplot C-1E_ch/C-2E_ch --lt_ranges 07-12,19-24,61-72 --input_dir /scratch/osm/movero/wd/2022s4 --plot_type station,time,daytime,total +--plot_cat_params TOT_PREC12,TOT_PREC6,CLCT,T_2M,TD_2M,FF_10M,VMAX_10M6 --plot_cat_thresh 0.1,1,10:0.2,1,5:0.2,0.5,2:2.5,6.5:0,15,25:0,15,25:-5,5,15:-5,5,15:2.5,5,10:2.5,5,10:5,12.5,20:5,12.5,20 --plot_cat_scores FBI,MF/OF,POD,FAR,THS,ETS +--plot_params TOT_PREC12,TOT_PREC6,TOT_PREC1,CLCT,GLOB,DURSUN12,DURSUN1,T_2M,T_2M_KAL,TD_2M,TD_2M_KAL,RELHUM_2M,FF_10M,FF_10M_KAL,VMAX_10M6,VMAX_10M1,DD_10M,PS,PMSL +--plot_scores ME,MMOD/MOBS,MAE,STDE,RMSE,COR,NOBS +``` +Example Command plotting Ensemble Scores: +``` +moveroplot C-1E_ch/C-2E_ch --lt_ranges 07-12,19-24,61-72 --input_dir /scratch/osm/movero/wd/2022s4 --plot_type ensemble +--plot_ens_params TOT_PREC12,TOT_PREC6,CLCT,T_2M,TD_2M,FF_10M,VMAX_10M6 +--plot_ens_scores OUTLIERS,RANK,RPS,RPS_REF +--plot_ens_cat_params TOT_PREC12,TOT_PREC6,CLCT,T_2M,TD_2M,FF_10M,VMAX_10M6 +--plot_ens_cat_thresh 0.1,0.2,2.5,0,0,2.5,5 +--plot_ens_cat_scores REL,RES,BS,BS_REF,BSS,BSSD,REL_DIA ``` -`plot_synop.py` parses these user inputs into a _parameter dictionary_. Each provided parameter is one key in this dictionary. For every key, a list of corresponding scores is assigned. -![**Parameters Dictitonary**](https://i.imgur.com/kdQrufu.png) -Afterwards this `params_dict` is passed to separate plotting pipelines. There, the source files are retrieved, parsed and plotted. Ultimately, all plots are placed in the `//` directory. - -## 1. SPATIAL VERIFICATION - - -> Relevant File: [station_score.py](src/movero/station_scores.py) - -The spatial verification plots feature a map, where all stations have are marked with a coloured dot. The colour of this dot corresponds to a colour-bar on the right side of the map. The smaller the deviation from the centre of the colourbar, the better. One can see directly, if & where the model performed well, or rather less so. - -###### Example: Old Station Score Plot -drawing - -###### Example: New Station Scores Plot -drawing - ---- - +## Plotting Pipeline and Output +### Plotting Multiple Model Versions +`moveroplot` offers the option to visualize multiple results from distinct model versions within a single plot or image, depending on the plot type. +This can be achieved through the use of specific delimiters: a slash (/) signifies combined plotting, while a comma (,) indicates separate plots. -## 2. TIME SERIES OF VERIFICATION SCORES -> Relevant File: [time_scores.py](src/movero/time_scores.py) +Example: +> Input: C-1E_ch/C-2E_ch,C-1E_alps +> +> Interpretation: Display results of C-1E_ch and C-2E_ch into one combined plot. +> Plot the results of C-1E_alps separately. -###### Example: Old vs. New Station Scores Plot -![](https://i.imgur.com/g9t612p.png) -![](https://i.imgur.com/mlwMtTY.png) +### Spatial Verification +> Relevant File: [station_scores.py](src/moveroplot/station_scores.py) +> +> Note: Each station score image is consistent in its LT range. +> The number of plots per image can vary (model versions along columns, scores along rows). +> +> Note: Invalid Atab files are ignored. ---- -## 3. DIURNAL CYCLYE OF VERIFICATION SCORES -> Relevant File: [daytime_scores.py](src/movero/daytime_scores.py) -###### Example: Old vs. New Station Scores Plot -![](https://i.imgur.com/FGSW1My.png) -![](https://i.imgur.com/pSNKEF4.png) +![**Example Station Scores**](img/station_scores_example.png) +### Time Series of Verification Scores +> Relevant File: [time_scores.py](src/moveroplot/time_scores.py) +> +> Remark: The order specified in --plot_scores and --plot_cat_scores is crucial. Two plots are assigned per page. Each threshold and parameter initialize a new page. +> +> Remark: Model versions and scores can be displayed in the same plot using `/` in the input. +> +![**Example Time Scores**](img/time_scores_example.png) -___ -## 4. TOTAL SCORES DEP. ON LEAD-TIME RANGES +### Diurnal Cycle of Verification Scores +> Relevant File: [daytime_scores.py](src/moveroplot/daytime_scores.py) +> +> Remark: The order specified in --plot_scores and --plot_cat_scores is crucial. Two plots are assigned per page. Each threshold and parameter initialize a new page. +> +> Remark: Model versions and scores can be displayed in the same plot using `/` in the input +![**Example Daytime Scores**](img/daytime_scores_example.png) -> Remark: how are scores assigned to subplots? +### Total scores for all lead times +> Relevant File: [total_scores.py](src/moveroplot/total_scores.py) +> +> Remark: The order specified in --plot_scores and --plot_cat_scores is crucial. Four plots are assigned per page. Each threshold and parameter initialize a new page. > -> die Reihenfolge in --plot_scores ist entscheidend. Es kommen immer 4 plots auf eine Seite für die normalen Scores. Die --plot_cat_scores beginnen auf jeden Fall auf einer neuen Seite, und jeder Threshold beginnt wieder auf einer neuen Seite. +> Remark: Model versions and scores can be displayed in the same plot using `/` in the input +> +![**Example Total Scores**](img/total_scores_example.png) -###### Example: Old vs. New Station Scores Plot -![](https://i.imgur.com/RViAUU4.png) -![](https://i.imgur.com/2d69BoT.png) +### Ensemble scores +> Relevant File: [ensemble_scores.py](src/moveroplot/ensemble_scores.py) +> +> Remark: The order specified in --plot_ens_scores and --plot_ens_cat_scores is crucial. Each threshold and parameter initialize a new page. +> +> Remark: Model versions and scores can be displayed in the same plot using `/` in the input. +> +> Remark: RANK, REL_DIA and line plots are saved in separate images. + +#### Regular Line Plots +![**Example Regular Ensemble Scores**](img/ensemble_scores_OUTLIERS_example.png) +#### RANK +![**Example RANK**](img/ensemble_scores_RANK_example.png) +#### Reliability Diagram +![**Example REAL DIA**](img/ensemble_scores_REL_DIA_example.png) diff --git a/img/daytime_scores_example.png b/img/daytime_scores_example.png new file mode 100644 index 0000000..9fda92d Binary files /dev/null and b/img/daytime_scores_example.png differ diff --git a/img/ensemble_scores_OUTLIERS_example.png b/img/ensemble_scores_OUTLIERS_example.png new file mode 100644 index 0000000..7197984 Binary files /dev/null and b/img/ensemble_scores_OUTLIERS_example.png differ diff --git a/img/ensemble_scores_RANK_example.png b/img/ensemble_scores_RANK_example.png new file mode 100644 index 0000000..2baf842 Binary files /dev/null and b/img/ensemble_scores_RANK_example.png differ diff --git a/img/ensemble_scores_REL_DIA_example.png b/img/ensemble_scores_REL_DIA_example.png new file mode 100644 index 0000000..83fbfcc Binary files /dev/null and b/img/ensemble_scores_REL_DIA_example.png differ diff --git a/img/station_scores_example.png b/img/station_scores_example.png new file mode 100644 index 0000000..92ccbd9 Binary files /dev/null and b/img/station_scores_example.png differ diff --git a/img/time_scores_example.png b/img/time_scores_example.png new file mode 100644 index 0000000..bf80175 Binary files /dev/null and b/img/time_scores_example.png differ diff --git a/img/total_scores_example.png b/img/total_scores_example.png new file mode 100644 index 0000000..82cf008 Binary files /dev/null and b/img/total_scores_example.png differ diff --git a/src/moveroplot/cli.py b/src/moveroplot/cli.py index 3df3b4b..3b52878 100644 --- a/src/moveroplot/cli.py +++ b/src/moveroplot/cli.py @@ -22,8 +22,14 @@ ) @click.version_option(__version__, "--version", "-V", message="%(version)s") @click.command(context_settings={"help_option_names": ["-h", "--help"]}) +@click.option( + "--plot_type", + type=str, + help="""Specify the type of plot to generate: + [total, time, station, daytime, ensemble].""", +) @click.argument( - "model_version", type=str + "model_versions", type=str, default="C-1E-CTR_ch,C-1E_ch" ) # help="Specify the correct run. I.e. C-1E-CTR_ch" @click.option( "--debug", type=bool, is_flag=True, help="Add debug comments to command prompt." @@ -31,8 +37,7 @@ @click.option( "--lt_ranges", type=str, - multiple=True, - default=("19-24",), + default="19-24", help="Specify the lead time ranges of interest. Def: 19-24", ) @click.option("--plot_params", type=str, help="Specify parameters to plot.") @@ -51,22 +56,33 @@ # 0.1,1,10:0.2,1,5:0.2,0.5,2:2.5,6.5:0,15,25:0,15,25:-5,5, # 15:-5,5,15:2.5,5,10:2.5,5,10:5,12.5,20:5,12.5,20 @click.option("--plot_cat_scores", type=str, help="Specify categorical scores to plot.") -# FBI,MF,POD,FAR,THS,ETS @click.option( - "--plot_ens_params", type=str, help="Specify ens parameters to plot." -) # TODO: figure out what ens params are + "--plot_ens_params", type=str, help="Specify parameters to ensemble plots." +) +@click.option("--plot_ens_scores", type=str, help="Specify scores to ensemble plots.") @click.option( - "--plot_ens_thresh", type=str, help="Specify ens scores thresholds to plot." -) # TODO: figure out what ens thresh are + "--plot_ens_cat_params", + type=str, + help="Specify categorical parameters to ensemble plots.", +) +@click.option( + "--plot_ens_cat_scores", + type=str, + help="Specify categorical scores to ensemble plots.", +) @click.option( - "--plot_ens_scores", type=str, help="Specify ens scores thresholds to plot." -) # TODO: figure out what ens scores are + "--plot_ens_cat_thresh", + type=str, + help="Specify categorical scores thresholds to ensemble plots.", +) +# FBI,MF,POD,FAR,THS,ETS # C-1E-CTR_ch # 🔰 new options for plot_synop call @click.option( "--input_dir", type=click.Path(exists=True), - default=Path("/scratch/osm/movero/wd"), + default=Path("/scratch/osm/movero/wd/2022s4"), + # default=Path("/scratch/kaufmann/movero/wd/2023s3_icon"), help="Specify input directory.", ) @click.option( @@ -75,23 +91,14 @@ default=Path("plots"), help="Specify output directory. Def: plots", ) -@click.option("--relief", type=bool, is_flag=True, help="Add relief to maps.") -@click.option("--grid", type=bool, is_flag=True, help="Add grid to plots.") @click.option( - "--season", - type=click.Choice( - [ - "2020s4", - "2021s1", - "2021s2", - "2021s3", - "2021s4", - ] - ), - multiple=False, - default="2021s4", - help="Specify the season of interest. Def: 2021s4", + "--colors", + type=str, + help="""Specify the plot color for each model version + using matploblib's color coding""", ) +@click.option("--relief", type=bool, is_flag=True, help="Add relief to maps.") +@click.option("--grid", type=bool, is_flag=True, help="Add grid to plots.") @click.pass_context def cli(ctx: Context, **kwargs) -> None: """Console script for test_cli_project.""" diff --git a/src/moveroplot/config/plot_settings.py b/src/moveroplot/config/plot_settings.py new file mode 100644 index 0000000..6415fb0 --- /dev/null +++ b/src/moveroplot/config/plot_settings.py @@ -0,0 +1,16 @@ +"""Static configurations settings for plots.""" + +modelcolors: list[str] = [ + "black", + "red", + "blue", + "green", + "cyan", + "yellow", + "magenta", + "orange", +] + +line_styles: list[str] = ["-", ":", "--", "-."] + +marker_styles: list[str] = ["D", "^", "o", "v"] diff --git a/src/moveroplot/daytime_scores.py b/src/moveroplot/daytime_scores.py index c3e50e9..3bfcbe4 100644 --- a/src/moveroplot/daytime_scores.py +++ b/src/moveroplot/daytime_scores.py @@ -1,32 +1,29 @@ # pylint: skip-file # Standard library -from pathlib import Path -from pprint import pprint +import re +from datetime import datetime +from datetime import timedelta # Third-party -import matplotlib.dates as md +import matplotlib.dates as mdates import matplotlib.pyplot as plt import numpy as np +from matplotlib.lines import Line2D -# Local -# import datetime -from .utils.atab import Atab -from .utils.check_params import check_params -from .utils.parse_plot_synop_ch import cat_daytime_score_range -from .utils.parse_plot_synop_ch import daytime_score_range +# First-party +import moveroplot.config.plot_settings as plot_settings +from moveroplot.load_files import load_relevant_files +from moveroplot.plotting import get_total_dates_from_headers # enter directory / read station_scores files / call plotting pipeline def _daytime_scores_pipeline( - params_dict, + plot_setup, lt_ranges, file_prefix, file_postfix, input_dir, output_dir, - season, - model_version, - grid, debug, ) -> None: """Read all ```ATAB``` files that are present in: data_dir/season/model_version/<...>. @@ -43,335 +40,210 @@ def _daytime_scores_pipeline( file_postfix (str): postfix of files (i.e. '.dat') input_dir (str): directory to seasons (i.e. /scratch/osm/movero/wd) output_dir (str): output directory (i.e. plots/) - season (str): season of interest (i.e. 2021s4/) model_version (str): model_version of interest (i.e. C-1E_ch) scores (list): list of scores, for which plots should be generated debug (bool): print further comments & debug statements """ # noqa: E501 print("\n--- initialising daytime score pipeline") - for lt_range in lt_ranges: - for parameter in params_dict: - # retrieve list of scores, relevant for current parameter - scores = params_dict[parameter] # this scores is a list of lists - - # define file path to the current parameter (station_score atab file) - file = f"{file_prefix}{lt_range}_{parameter}{file_postfix}" - path = Path(f"{input_dir}/{season}/{model_version}/{file}") - - # check if the file exists - if not path.exists(): - print( - f"""WARNING: No data file for parameter {parameter} could be found. - {path} does not exist.""" - ) - continue # for the current parameter no file could be retrieved - - if debug: - print(f"\nFilepath:\t{path}") - - # extract header & dataframe - header = Atab(file=path, sep=" ").header - df = Atab(file=path, sep=" ").data - - # > remove/replace missing values in dataframe with np.NaN - df = df.replace(float(header["Missing value code"][0]), np.NaN) - - # > if there are columns (= scores), that only contain np.NaN, remove them - # df = df.dropna(axis=1, how="all") - - # > check which relevant scores are available; extract those from df - all_scores = df.columns.tolist() - available_scores = ["hh"] - multiplot_scores = {} - for score in scores: - if len(score) == 1: - if score[0] in all_scores: - available_scores.append(score[0]) - else: # warn that a relevant score was not available in dataframe - print( - f"""WARNING: Score {score[0]} not - available for parameter {parameter}.""" - ) - if ( - len(score) > 1 - ): # # currently only 2-in-1 plots are currently possible - multiplot_scores[score[0]] = score[1] - for sc in score: - if sc in all_scores: - available_scores.append(sc) - else: - print( - f"""WARNING: Score {sc} not available - for parameter {parameter}.""" - ) - - df = df[available_scores] - df = df.set_index("hh") - - if debug: - print("\nFile header:") - pprint(header) - print("\nData:") - pprint(df) - print( - f"""Generating plot for {parameter} for - lt_range: {lt_range}. (File: {file})""" - ) - - # for each score in df, create one map - _generate_daytime_plot( - data=df, - multiplots=multiplot_scores, - lt_range=lt_range, - variable=parameter, - file=file, - file_postfix=file_postfix, - header_dict=header, + if not lt_ranges: + lt_ranges = "19-24" + for model_plots in plot_setup["model_versions"]: + for parameter, scores in plot_setup["parameter"].items(): + model_data = load_relevant_files( + input_dir, + file_prefix, + file_postfix, + debug, + model_plots, + parameter, + lt_ranges, + ltr_first=True, + transform_func=_daytime_score_transformation, + ) + if not model_data: + print(f"No matching files found with given ltr {lt_ranges}") + return + _generate_daytime_plots( + plot_scores=scores, + models_data=model_data, + parameter=parameter, output_dir=output_dir, - grid=grid, debug=debug, ) -# PLOTTING PIPELINE FOR DAYTIME SCORES PLOTS -# generator that gives time between start and end times with delta intervals -# inspired by: https://stackoverflow.com/questions/61733727/how-to-set-minutes-time-as-x-axis-of-a-matplotlib-plot-in-python # noqa: E501 -def deltatime(start, end, delta): - current = start - while current < end: - yield current - current += delta +def _daytime_score_transformation(df, header): + df["hh"] = df["hh"].astype(int) + df = df.replace(float(header["Missing value code"][0]), np.NaN) + return df -def get_xaxis(): - # Standard library - from datetime import datetime - from datetime import timedelta +def _initialize_plots(labels: list): + fig, ((ax0), (ax1)) = plt.subplots( + nrows=2, ncols=1, tight_layout=True, figsize=(10, 10), dpi=200 + ) + custom_lines = [ + Line2D([0], [0], color=plot_settings.modelcolors[i], lw=2) + for i in range(len(labels)) + ] + fig.legend( + custom_lines, + labels, + loc="upper right", + ncol=1, + frameon=False, + ) + plt.tight_layout(w_pad=8, h_pad=5, rect=(0.05, 0.05, 0.90, 0.90)) + return fig, [ax0, ax1] - # two random consecutive dates [date1, date2] - dates = [("01/02/1991", "02/02/1991")] # , '01/03/1991', '01/04/1991'] - # generate the list for each date between 00:00 on date1 to 00:00 on date2 hourly intervals # noqa: E501 - datetimes = [] - for start, end in dates: - startime = datetime.combine( - datetime.strptime(start, "%d/%m/%Y"), - datetime.strptime("0:00:00", "%H:%M:%S").time(), - ) - endtime = datetime.combine( - datetime.strptime(end, "%d/%m/%Y"), - datetime.strptime("01:00:00", "%H:%M:%S").time(), - ) - datetimes.append( - [j for j in deltatime(startime, endtime, timedelta(minutes=60))] - ) +def _clear_empty_axes_if_necessary(subplot_axes, idx): + # remove empty ``axes`` instances + if idx % 2 != 1: + [ax.axis("off") for ax in subplot_axes[(idx + 1) % 2 :]] - # #flatten datetimes list - datetimes = [datetime for day in datetimes for datetime in day] - x = datetimes - return x +def _plot_and_save_scores( + output_dir, + base_filename, + parameter, + plot_scores_setup, + sup_title, + ltr_models_data, + debug=False, +): + for ltr, models_data in ltr_models_data.items(): + fig, subplot_axes = _initialize_plots(ltr_models_data[ltr].keys()) + headers = [data["header"] for data in models_data.values()] + total_start_date, total_end_date = get_total_dates_from_headers(headers) + title_base = f"{parameter.upper()}: " + model_info = ( + f" {list(models_data.keys())[0]}" if len(models_data.keys()) == 1 else "" + ) -def _generate_daytime_plot( - data, - multiplots, - lt_range, - variable, - file, - file_postfix, - header_dict, + x_label_base = f"""{total_start_date.strftime("%Y-%m-%d %H:%M")} - {total_end_date.strftime("%Y-%m-%d %H:%M")}""" # noqa: E501 + filename = base_filename + f"_{ltr}" + pattern = ( + re.search(r"\(.*?\)", next(iter(plot_scores_setup))[0]) + if plot_scores_setup + else None + ) + prev_threshold = None + if pattern is not None: + prev_threshold = pattern.group() + current_threshold = prev_threshold + current_plot_idx = 0 + + for idx, score_setup in enumerate(plot_scores_setup): + prev_threshold = current_threshold + pattern = re.search(r"\(.*?\)", next(iter(score_setup))) + current_threshold = pattern.group() if pattern is not None else None + different_threshold = prev_threshold != current_threshold + if different_threshold: + _clear_empty_axes_if_necessary(subplot_axes, current_plot_idx - 1) + fig.savefig(f"{output_dir}/{filename}.png") + plt.close() + filename = base_filename + f"_{ltr}" + fig, subplot_axes = _initialize_plots(ltr_models_data[ltr].keys()) + current_plot_idx += current_plot_idx % 2 + + title = title_base + ",".join(score_setup) + model_info + ax = subplot_axes[current_plot_idx % 2] + for model_idx, data in enumerate(models_data.values()): + model_plot_color = plot_settings.modelcolors[model_idx] + header = data["header"] + unit = header["Unit"][0] + y_label = ",".join(score_setup) + ax.set_ylabel(f"{y_label.upper()} ({unit})") + ax.set_xlabel(x_label_base) + ax.set_title(title + f", LT: {ltr}") + + for score_idx, score in enumerate(score_setup): + x_int = list(data["df"]["hh"]) + score_values = data["df"][score].to_list() + if 0 not in x_int: + bound_x_values = x_int[:: -len(x_int) + 1] + bound_x_values[0] -= 24 + score_value0 = np.interp( + 0, bound_x_values, score_values[:: len(x_int) - 1] + ) + x_int = [0] + x_int + [24] + score_values = [score_value0] + score_values + [score_value0] + + x_datetimes = [ + datetime.combine(datetime.now().date(), datetime.min.time()) + + timedelta(hours=hour) + for hour in x_int + ] + ax.plot( + x_datetimes, + score_values, + color=model_plot_color, + linestyle=plot_settings.line_styles[score_idx], + fillstyle="none", + label=f"{score.upper()}", + marker="D", + ) + ax.tick_params(axis="both", which="major", labelsize=8) + ax.tick_params(axis="both", which="minor", labelsize=6) + ax.autoscale(axis="y") + ax.set_xlim(x_datetimes[0], x_datetimes[-1]) + ax.xaxis.set_major_locator(mdates.HourLocator(interval=6)) + ax.xaxis.set_major_formatter(mdates.DateFormatter("%H:%M")) + if len(score_setup) > 1: + sub_plot_legend = ax.legend( + score_setup, + loc="upper right", + markerscale=0.9, + bbox_to_anchor=(1.1, 1.05), + ) + for line in sub_plot_legend.get_lines(): + line.set_color("black") + filename += "_" + "_".join(score_setup) + + if current_plot_idx % 2 == 1 or idx == len(plot_scores_setup) - 1: + _clear_empty_axes_if_necessary(subplot_axes, current_plot_idx) + fig.savefig(f"{output_dir}/{filename}.png") + plt.close() + filename = base_filename + f"_{ltr}" + fig, subplot_axes = _initialize_plots(ltr_models_data[ltr].keys()) + current_plot_idx += 1 + + +def _generate_daytime_plots( + plot_scores, + models_data, + parameter, output_dir, - grid, debug, ): - """Generate Daytime Plot.""" - # output_dir = f"{output_dir}/daytime_scores" - if not Path(output_dir).exists(): - Path(output_dir).mkdir(parents=True, exist_ok=True) - print(f"creating plots for file: {file}") - - # extract scores, which are available in the dataframe (data) - scores = data.columns.tolist() - - # Standard library - from datetime import datetime - from datetime import timedelta + model_versions = list(models_data.keys()) - # two random consecutive dates [date1, date2] - start_time = datetime.combine( - datetime.strptime("01/02/1991", "%d/%m/%Y"), - datetime.strptime("00:00:00", "%H:%M:%S").time(), + # initialise filename + base_filename = ( + f"daytime_scores_{model_versions[0]}_{parameter}" + if len(model_versions) == 1 + else f"daytime_scores_{parameter}" ) - end_time = datetime.combine( - datetime.strptime("02/02/1991", "%d/%m/%Y"), - datetime.strptime("00:00:00", "%H:%M:%S").time(), + sup_title = "" + # plot regular scores + _plot_and_save_scores( + output_dir, + base_filename, + parameter, + plot_scores["regular_scores"], + sup_title, + models_data, + debug=False, ) - # define x-axis only once. list of datetimes from date1 00:00 - date2 00:00 - x = get_xaxis() - - # check, which timestamps are actually necessary - available_times = data.index.tolist() - available_x = [] - for available_time in available_times: - available_x.append(x[available_time]) - first_point = available_x[0] - last_point = available_x[-1] - available_x.insert(0, last_point - timedelta(hours=24)) - available_x.append(first_point + timedelta(hours=24)) - - unit = header_dict["Unit"][0] - - # define further plot properties - grid = True - - score_to_skip = None - for score in scores: - if score == score_to_skip: - continue - - param = header_dict["Parameter"][0] - param = check_params(param=param, verbose=debug) - print(f"plotting:\t{param}/{score}") - - multiplt = False - title = f"{variable}: {score}" - footer = f"""Model: {header_dict['Model version'][0]} | - Period: {header_dict['Start time'][0]} - - {header_dict['End time'][0]} ({lt_range}) | © MeteoSwiss""" - - # initialise figure/axes instance - fig, ax = plt.subplots( - 1, 1, figsize=(1660 / 100, 1100 / 100), dpi=150, tight_layout=True - ) - - ax.set_xlim(start_time, end_time) - ax.set_ylabel(f"{score.upper()} ({unit})") - - # TODO: retrieve ymin/ymax from correct tables in plot_synop - # and set ax.set_ylim(ymin,ymax) - - if grid: - ax.grid(which="major", color="#DDDDDD", linewidth=0.8) - ax.grid(which="minor", color="#EEEEEE", linestyle=":", linewidth=0.5) - ax.minorticks_on() - - if debug: - print(f"Extract dataframe for score: {score}") - pprint(data) - - y = data[score].values.tolist() - - if score in multiplots.keys(): - y2 = data[multiplots[score]].values.tolist() - multiplt = True - score_to_skip = multiplots[score] - title = f"{variable}: {score}/{multiplots[score]}" - ax.set_ylabel(f"{score.upper()}/{multiplots[score].upper()} ({unit})") - - # plot dashed line @ 0 - ax.plot(x, [0] * len(x), color="grey", linestyle="--") - - # define limits for yaxis if available - regular_param = (param, "min") in daytime_score_range.columns - regular_score = score in daytime_score_range.index - cat_score = not regular_score - - if regular_param and regular_score: - lower_bound = daytime_score_range[param]["min"].loc[score] - upper_bound = daytime_score_range[param]["max"].loc[score] - if debug: - print( - f"found limits for {param}/{score} --> {lower_bound}/{upper_bound}" - ) - if lower_bound != upper_bound: - ax.set_ylim(lower_bound, upper_bound) - - if cat_score: - # get the index of the current score - index = cat_daytime_score_range[ - cat_daytime_score_range[param]["scores"] == score - ].index.values[0] - # get min/max value - lower_bound = cat_daytime_score_range[param]["min"].iloc[index] - upper_bound = cat_daytime_score_range[param]["max"].iloc[index] - if debug: - print( - f"found limits for {param}/{score} --> {lower_bound}/{upper_bound}" - ) - if lower_bound != upper_bound: - ax.set_ylim(lower_bound, upper_bound) - - label = f"{score.upper()}" - if not multiplt: - # pre-/append first and last values to the scores lists - first_y, last_y = y[0], y[-1] - y.insert(0, last_y) - y.append(first_y) - - ax.plot( - available_x, - y, - color="k", - marker="o", - linestyle="-", - label=label, - ) - if multiplt: - # pre-/append first and last values to the scores lists - first_y, last_y = y[0], y[-1] - y.insert(0, last_y) - y.append(first_y) - first_y2, last_y2 = y2[0], y2[-1] - y2.insert(0, last_y2) - y2.append(first_y2) - - # change title, y-axis label, filename here, for the multiplot case - ax.plot( - available_x, - y, - color="red", - linestyle="-", - marker="o", - label=label, - ) - label = f"{multiplots[score].upper()}" - ax.plot( - available_x, - y2, - color="k", - linestyle="-", - marker="o", - label=label, - ) - - # From the SO:https://stackoverflow.com/questions/42398264/matplotlib-xticks-every-15-minutes-starting-on-the-hour # noqa: E501 - # Set time format and the interval of ticks (every n minutes) - xformatter = md.DateFormatter("%H:%M") - xlocator = md.MinuteLocator(interval=360) - # Set xtick labels to appear every n minutes - ax.xaxis.set_major_locator(xlocator) - # Format xtick labels as HH:MM - plt.gcf().axes[0].xaxis.set_major_formatter(xformatter) - - plt.legend() - - plt.suptitle( - footer, - x=0.03, - y=0.957, - horizontalalignment="left", - verticalalignment="top", - fontdict={ - "size": 6, - "color": "k", - }, - ) - ax.set_title(label=title) - - print(f"saving:\t\t{output_dir}/{file.split(file_postfix)[0]}_{score}.png") - plt.savefig(f"{output_dir}/{file.split(file_postfix)[0]}_{score}.png") - plt.close(fig) - - return + _plot_and_save_scores( + output_dir, + base_filename, + parameter, + plot_scores["cat_scores"], + sup_title, + models_data, + debug=False, + ) diff --git a/src/moveroplot/ensemble_scores.py b/src/moveroplot/ensemble_scores.py new file mode 100644 index 0000000..82f7e89 --- /dev/null +++ b/src/moveroplot/ensemble_scores.py @@ -0,0 +1,360 @@ +"""Calculate ensemble scores from parsed data.""" +# Standard library +import re +from typing import Tuple + +# Third-party +import matplotlib.pyplot as plt +import numpy as np +from matplotlib.lines import Line2D + +# First-party +from moveroplot.config import plot_settings +from moveroplot.load_files import load_relevant_files +from moveroplot.plotting import get_total_dates_from_headers + +# Local +from .station_scores import _calculate_figsize + +# pylint: disable=no-name-in-module + + +def _ensemble_score_transformation(df, header): + df = df.replace(float(header["Missing value code"][0]), np.NaN) + df.set_index(keys="Score", inplace=True) + return df + + +# pylint: disable=too-many-arguments,too-many-locals +def _ensemble_scores_pipeline( + plot_setup, + lt_ranges, + file_prefix, + file_postfix, + input_dir, + output_dir, + debug, +) -> None: + print("\n--- initialising ensemble score pipeline") + if not lt_ranges: + lt_ranges = "07-12,13-18,19-24" + for model_plots in plot_setup["model_versions"]: + for parameter, scores in plot_setup["parameter"].items(): + model_data = load_relevant_files( + input_dir, + file_prefix, + file_postfix, + debug, + model_plots, + parameter, + lt_ranges, + ltr_first=True, + transform_func=_ensemble_score_transformation, + ) + if not model_data: + print(f"No matching files found with given ltr {lt_ranges}") + return + _generate_ensemble_scores_plots( + plot_scores=scores, + models_data=model_data, + parameter=parameter, + output_dir=output_dir, + debug=debug, + ) + + +def _initialize_plots( + num_rows: int, num_cols: int, single_figsize: Tuple[int, int] = (8, 4) +): + figsize = _calculate_figsize( + num_rows, num_cols, single_figsize, (1, 1) + ) # (10, 6.8) + fig, axes = plt.subplots( + nrows=num_rows, + ncols=num_cols, + tight_layout=True, + figsize=figsize, + dpi=100, + squeeze=False, + ) + fig.tight_layout(w_pad=6, h_pad=4, rect=(0.05, 0.05, 0.90, 0.85)) + plt.subplots_adjust(bottom=0.15) + return fig, axes + + +def _add_sample_subplot(fig, ax): + box = ax.get_position() + width = box.width + height = box.height + l, b, h, w = 0.8, 0.025, 0.3, 0.2 + w *= width + h *= height + inax_position = ax.transAxes.transform([l, b]) + transformed_fig = fig.transFigure.inverted() + infig_position = transformed_fig.transform(inax_position) + sub_plot = fig.add_axes([*infig_position, w, h]) + sub_plot.set_xticks([]) + sub_plot.set_title("N") + return sub_plot + + +def _add_boundary_line(ax, points): + ax.plot( + [0, 1], + points, + color="black", + fillstyle="none", + linestyle="--", + alpha=0.2, + ) + + +def _get_bin_values(data: dict, prefix: str, threshold: str): + indices = [ + index for index in data["df"]["Total"].index if f"{prefix}{threshold}" in index + ] + return data["df"]["Total"][indices] + + +def _customize_figure(fig, sup_title, models_color_lines, labels): + fig.suptitle( + sup_title, + horizontalalignment="center", + verticalalignment="top", + fontdict={ + "size": 6, + "color": "k", + }, + bbox={"facecolor": "none", "edgecolor": "grey"}, + ) + + fig.legend( + models_color_lines, + labels, + loc="upper right", + ncol=1, + frameon=False, + ) + + +# pylint: disable=too-many-branches,too-many-statements +def _plot_and_save_scores( + output_dir, + base_filename, + parameter, + plot_scores_setup, + sup_title, + models_data, + models_color_lines, + debug=False, +): + if debug: + print("Plotting ensemble scores.") + for score_setup in plot_scores_setup: + custom_sup_title = sup_title + filename = base_filename + if "RANK" in score_setup: + [score] = score_setup + custom_sup_title = f"RANK: {sup_title}" + for ltr, model_data in models_data.items(): + fig, subplot_axes = _initialize_plots(1, 1) + filename = f"{base_filename}_RANK_{ltr}" + [ax] = subplot_axes.ravel() + ax.set_xlabel("RANK") + ax.set_title(f"{parameter}, LT: {ltr}") + for model_idx, data in enumerate(model_data.values()): + model_plot_color = plot_settings.modelcolors[model_idx] + model_ranks = sorted( + [ + index + for index in data["df"]["Total"].index + if "RANK" in index + ], + key=lambda x: int("".join(filter(str.isdigit, x))), + ) + ranks = data["df"]["Total"][model_ranks].reset_index(drop=True) + ax.bar( + np.arange(len(model_ranks)) + model_idx * 0.25, + ranks, + width=0.25, + color=model_plot_color, + ) + _customize_figure( + fig, + custom_sup_title, + models_color_lines, + list(models_data[next(iter(models_data.keys()))].keys()), + ) + + fig.savefig(f"{output_dir}/{filename}.png") + plt.close() + elif any("REL_DIA" in score for score in score_setup): + [score] = score_setup + threshold = re.search(r"\(.*?\)", score).group() + for ltr, model_data in models_data.items(): + fig, subplot_axes = _initialize_plots(1, 1, (6.7, 6)) + [ax] = subplot_axes.ravel() + filename = f"{base_filename}_{score}_{ltr}" + ax.set_ylabel("Observed Relative Frequency") + ax.set_xlabel("Forecast Probability") + ax.set_xlim(0, 1) + ax.set_ylim(0, 1) + ax.set_aspect("equal") + [unit] = model_data[next(iter(model_data.keys()))]["header"]["Unit"] + ax.set_title(f"{parameter} {threshold[1:-1]} {unit}, LT: {ltr}") + sample_subplot = _add_sample_subplot(fig, ax) + + for model_idx, data in enumerate(model_data.values()): + model_plot_color = plot_settings.modelcolors[model_idx] + fbin_values = _get_bin_values(data, "FBIN", threshold) + obin_values = _get_bin_values(data, "OBIN", threshold) + nbin_values = _get_bin_values(data, "NBIN", threshold) + of_value = _get_bin_values(data, "OF", threshold) + ax.plot( + fbin_values, + obin_values, + color=model_plot_color, + marker="D", + fillstyle="none", + ) + + sample_subplot.bar( + np.arange(len(nbin_values)) + model_idx * 0.25, + nbin_values, + width=0.25, + color=model_plot_color, + ) + + _add_boundary_line(ax, [0, 1]) + _add_boundary_line(ax, [of_value, of_value]) + _add_boundary_line( + ax, + [ + (1 - np.tan(np.pi / 8)) * of_value, + of_value + (1 - of_value) * np.tan(np.pi / 8), + ], + ) + sample_subplot.set_yticks(np.round([max(nbin_values)], -3)) + _customize_figure( + fig, + custom_sup_title, + models_color_lines, + list(models_data[next(iter(models_data.keys()))].keys()), + ) + fig.savefig(f"{output_dir}/{filename}.png") + plt.close() + else: + fig, subplot_axes = _initialize_plots( + 2 if len(score_setup) > 1 else 1, + (len(score_setup) + 1) // 2, + ) + subplot_axes = subplot_axes.ravel() + fig.legend( + models_color_lines, + list(models_data[next(iter(models_data.keys()))].keys()), + loc="upper right", + ncol=1, + frameon=False, + ) + ltr_sorted = sorted( + list(models_data.keys()), key=lambda x: int(x.split("-")[0]) + ) + x_int = list(range(len(ltr_sorted))) + for score_idx, score in enumerate(score_setup): + print( + "SCORE IN ENS ", + score, + models_data.keys(), + [models_data[ltr].keys() for ltr in ltr_sorted], + ) + ax = subplot_axes[score_idx] + filename += f"_{score}" + for model_idx, model_name in enumerate( + models_data[next(iter(ltr_sorted))].keys() + ): + model_plot_color = plot_settings.modelcolors[model_idx] + y_values = [ + models_data[ltr][model_name]["df"]["Total"].loc[score] + if model_name in models_data[ltr].keys() + else None + for ltr in ltr_sorted + ] + filtered_x_int = [ + x for x, y in zip(x_int, y_values) if y is not None + ] + filtered_y_values = [y for y in y_values if y is not None] + ax.plot( + filtered_x_int, + filtered_y_values, + color=model_plot_color, + marker="D", + fillstyle="none", + ) + + ax.set_ylabel(f"{score}") + ax.set_xticks(x_int, ltr_sorted) + ax.set_title(f"{parameter}: {score}") + ax.grid(which="major", color="#DDDDDD", linewidth=0.8) + ax.grid(which="minor", color="#EEEEEE", linestyle=":", linewidth=0.5) + ax.set_xlabel("Lead-Time Range (h)") + + if len(score_setup) > 2 and len(score_setup) % 2 == 1: + subplot_axes[-1].axis("off") + + _customize_figure( + fig, + custom_sup_title, + models_color_lines, + list(models_data[next(iter(models_data.keys()))].keys()), + ) + fig.savefig(f"{output_dir}/{filename}.png") + plt.close() + + +def _generate_ensemble_scores_plots( + plot_scores, + models_data, + parameter, + output_dir, + debug, +): + """Generate Ensemble Score Plots.""" + model_plot_colors = plot_settings.modelcolors + model_versions = list(models_data[next(iter(models_data))].keys()) + custom_lines = [ + Line2D([0], [0], color=model_plot_colors[i], lw=2) + for i in range(len(model_versions)) + ] + + # initialise filename + base_filename = f"ensemble_scores_{parameter}" + + headers = [data["header"] for data in models_data[next(iter(models_data))].values()] + total_start_date, total_end_date = get_total_dates_from_headers(headers) + # pylint: disable=line-too-long + sup_title = f"""{parameter} + Period: {total_start_date.strftime("%Y-%m-%d")} - {total_end_date.strftime("%Y-%m-%d")} | © MeteoSwiss""" # noqa: E501 + # pylint: enable=line-too-long + if debug: + print("Generating ensemble plots.") + _plot_and_save_scores( + output_dir, + base_filename, + parameter, + plot_scores["regular_ens_scores"], + sup_title, + models_data, + custom_lines, + debug=False, + ) + + _plot_and_save_scores( + output_dir, + base_filename, + parameter, + plot_scores["ens_cat_scores"], + sup_title, + models_data, + custom_lines, + debug=False, + ) diff --git a/src/moveroplot/load_files.py b/src/moveroplot/load_files.py new file mode 100644 index 0000000..8b02de8 --- /dev/null +++ b/src/moveroplot/load_files.py @@ -0,0 +1,78 @@ +"""General function to load and collect atab files.""" + +# Standard library +import re +from datetime import datetime +from pathlib import Path + +# Local +# pylint: disable=no-name-in-module +from .utils.atab import Atab + + +# pylint: disable=too-many-arguments,too-many-locals +def is_valid_data(header): + try: + datetime.strptime(" ".join(header["Start time"][0:3:2]), "%Y-%m-%d %H:%M") + datetime.strptime(" ".join(header["End time"][0:2]), "%Y-%m-%d %H:%M") + return True + except ValueError: + return False + + +def load_relevant_files( + input_dir, + file_prefix, + file_postfix, + debug, + model_plots, + parameter, + lt_ranges, + ltr_first=True, + transform_func=None, +): + corresponding_files_dict = {} + files_list = [] + for model in model_plots: + source_path = Path(f"{input_dir}/{model}") + for file_path in source_path.glob(f"{file_prefix}*{parameter}{file_postfix}"): + if file_path.is_file(): + ltr_match = re.search(r"(\d{2})-(\d{2})", file_path.name) + if ltr_match: + lt_range = ltr_match.group() + else: + raise IOError( + f"The filename {file_path.name} does not contain a LT range." + ) + + in_lt_ranges = True + if lt_ranges: + in_lt_ranges = lt_range in lt_ranges + + if in_lt_ranges: + # extract header & dataframe + loaded_atab = Atab(file=file_path, sep=" ") + header = loaded_atab.header + df = loaded_atab.data + if transform_func: + df = transform_func(df, header) + if is_valid_data(header): + # add information to dict + first_key, second_key = ( + (lt_range, model) if ltr_first else (model, lt_range) + ) + corresponding_files_dict.setdefault(first_key, {})[ + second_key + ] = { + "header": header, + "df": df, + } + + # add path of file to list of relevant files + files_list.append(file_path) + + if debug: + print(f"\nFor parameter: {parameter} these files are relevant:\n") + print("Found files: ", files_list) + + return corresponding_files_dict diff --git a/src/moveroplot/main.py b/src/moveroplot/main.py index 4588556..c0db3da 100644 --- a/src/moveroplot/main.py +++ b/src/moveroplot/main.py @@ -46,12 +46,14 @@ """ # noqa: E501 # Standard library from pathlib import Path +from typing import Optional # Third-party from click import Context # Local from .daytime_scores import _daytime_scores_pipeline +from .ensemble_scores import _ensemble_scores_pipeline # local from .parse_inputs import _parse_inputs @@ -67,24 +69,27 @@ def main( ctx: Context, *, # legacy inputs - model_version: str, + model_versions: str, debug: bool, lt_ranges: tuple, # Parameters / Scores / Thresholds - plot_params: str, - plot_scores: str, - plot_cat_params: str, - plot_cat_thresh: str, - plot_cat_scores: str, - plot_ens_params: str, - plot_ens_thresh: str, - plot_ens_scores: str, + plot_params: Optional[str], + plot_scores: Optional[str], + plot_cat_params: Optional[str], + plot_cat_thresh: Optional[str], + plot_cat_scores: Optional[str], + plot_ens_params: Optional[str], + plot_ens_scores: Optional[str], + plot_ens_cat_params: Optional[str], + plot_ens_cat_thresh: Optional[str], + plot_ens_cat_scores: Optional[str], # new inputs input_dir: Path, - output_dir: str, + output_dir: Path, relief: bool, grid: bool, - season: str, + colors: Optional[str], + plot_type: str, ): """Entry Point for the MOVERO Plotting Pipeline. @@ -96,7 +101,7 @@ def main( python plot_synop.py C-1E-CTR_ch - --plot_params TOT_PREC12,TOT_PREC6,TOT_PREC1,CLCT, + --plot_params TOT_PREC12,TOT_PREC6,CLCT, GLOB,DURSUN12,DURSUN1,T_2M,T_2M_KAL, TD_2M,TD_2M_KAL,RELHUM_2M,FF_10M, FF_10M_KAL,VMAX_10M6,VMAX_10M1, @@ -115,82 +120,85 @@ def main( --plot_cat_scores FBI,MF/OF,POD,FAR,THS,ETS """ # noqa: E501 - # -1. DEFINE PLOTS - station_scores = False - time_scores = False - daytime_scores = False - total_scores = True + # -1. Check plot type input + if plot_type is None: + raise ValueError("ERROR: No plot type argument --plot_type.") + + if not Path(output_dir).exists(): + Path(output_dir).mkdir(parents=True, exist_ok=True) + # 0. PARSE USER INPUT - params_dict = _parse_inputs( + plot_setup = _parse_inputs( debug, + input_dir, + model_versions, plot_params, plot_scores, plot_cat_params, plot_cat_thresh, plot_cat_scores, plot_ens_params, - plot_ens_thresh, plot_ens_scores, + plot_ens_cat_params, + plot_ens_cat_thresh, + plot_ens_cat_scores, + colors, + plot_type, ) + print("PLOT SETUP ", plot_setup) # 1. INITIALISE STATION SCORES PLOTTING PIPELINE - if station_scores: + if "station" in plot_type: _station_scores_pipeline( - params_dict=params_dict, + plot_setup=plot_setup, lt_ranges=lt_ranges, file_prefix="station_scores", file_postfix=".dat", input_dir=input_dir, output_dir=output_dir, - season=season, # 2021s4 - model_version=model_version, # C-1E-CTR_ch - relief=relief, debug=debug, ) # 2. INITIALISE TIME SERIES PLOTTING PIPELINE - if time_scores: + if "time" in plot_type: _time_scores_pipeline( - params_dict=params_dict, + plot_setup=plot_setup, lt_ranges=lt_ranges, file_prefix="time_scores", file_postfix=".dat", input_dir=input_dir, output_dir=output_dir, - season=season, - model_version=model_version, - grid=grid, debug=debug, ) # 3. INITIALISE DYURNAL CYCLE PLOTTING PIPELINE - if daytime_scores: + if "daytime" in plot_type: _daytime_scores_pipeline( - params_dict=params_dict, + plot_setup=plot_setup, lt_ranges=lt_ranges, file_prefix="daytime_scores", file_postfix=".dat", input_dir=input_dir, output_dir=output_dir, - season=season, - model_version=model_version, - grid=grid, debug=debug, ) # 4. INITIALIS TOTAL SCORES PLOTTING PIPELINE - if total_scores: + if "total" in plot_type: _total_scores_pipeline( - params_dict=params_dict, - plot_scores=plot_scores, - plot_params=plot_params, - plot_cat_scores=plot_cat_scores, - plot_cat_params=plot_cat_params, - plot_cat_thresh=plot_cat_thresh, - # TODO: add plot_ens params/scores/threshs + plot_setup=plot_setup, + lt_ranges=lt_ranges, + file_prefix="total_scores", + file_postfix=".dat", + input_dir=input_dir, + output_dir=output_dir, + debug=debug, + ) + + if "ensemble" in plot_type: + _ensemble_scores_pipeline( + plot_setup=plot_setup, + lt_ranges=lt_ranges, file_prefix="total_scores", file_postfix=".dat", input_dir=input_dir, output_dir=output_dir, - season=season, - model_version=model_version, - grid=grid, debug=debug, ) print("\n--- Done.") diff --git a/src/moveroplot/mutable_number.py b/src/moveroplot/mutable_number.py deleted file mode 100644 index 2760bfc..0000000 --- a/src/moveroplot/mutable_number.py +++ /dev/null @@ -1,50 +0,0 @@ -"""Mutable number.""" -# Standard library -from typing import List - - -class MutableNumber: - """A mutable number.""" - - def __init__(self, number: float) -> None: - """Create an instance of ``MutableNumber``. - - Args: - number: Initial number. - - """ - self.history: List[float] = [float(number)] - - def get(self, idx: int = -1) -> float: - """Get the current or a past value of the number. - - Args: - idx (optional): Index since the initial value. Defaults to the most - recent (i.e., current) value. - - """ - return self.history[idx] - - def add(self, addend: float) -> "MutableNumber": - """Add ``addend`` to the current number.""" - number = self.get() + float(addend) - self.history.append(number) - return self - - def subtract(self, subtrahend: float) -> "MutableNumber": - """Subtract ``subtrahend`` from the current number.""" - number = self.get() - float(subtrahend) - self.history.append(number) - return self - - def multiply(self, factor: float) -> "MutableNumber": - """Multiply current number by ``factor``.""" - number = self.get() * float(factor) - self.history.append(number) - return self - - def divide(self, divisor: float) -> "MutableNumber": - """Divide current number by ``divisor``.""" - number = self.get() / float(divisor) - self.history.append(number) - return self diff --git a/src/moveroplot/parse_inputs.py b/src/moveroplot/parse_inputs.py index 73582ee..1b5de50 100644 --- a/src/moveroplot/parse_inputs.py +++ b/src/moveroplot/parse_inputs.py @@ -1,24 +1,39 @@ # pylint: skip-file """Parse raw data from ATAB files into data frame.""" # Standard library +import itertools +import re +from pathlib import Path from pprint import pprint +# First-party +import moveroplot.config.plot_settings as plot_settings + +invalid_ensemble_paramter = ["DD_10M", "PS", "PMSL"] + def _parse_inputs( debug, + input_dir, + model_versions, plot_params, plot_scores, plot_cat_params, plot_cat_thresh, plot_cat_scores, plot_ens_params, - plot_ens_thresh, plot_ens_scores, + plot_ens_cat_params, + plot_ens_cat_thresh, + plot_ens_cat_scores, + plotcolors, + plot_type, ): """Parse the user input flags. Args: debug (bool): Add debug statements to command prompt. + model_versions (str): string of models (i.e. "C-1E_ch,C-1E-CTR_ch") plot_params (str): string w/ regular plot parameters. i.e. "TOT_PREC12,TOT_PREC6,TOT_PREC1" plot_scores (str): strint w/ regular plot scores. @@ -33,96 +48,185 @@ def _parse_inputs( plot_ens_params (str): string w/ ens plot parameters. Separated by comma. plot_ens_thresh (str): string w/ ens scores. Separated by comma. plot_ens_scores (str): string w/ ens scores thresholds. Separated by coma. + plot_ens_cat_params (str): string w/ categorical ens plot parameters. Separated by comma. + plot_ens_cat_thresh (str): string w/ categorical ens scores. Separated by comma. + plot_ens_cat_scores (str): string w/ categorical ens scores thresholds. Separated by comma. + plotcolors (str): custom colors for each model version using matplotlib codes, separated by comma + plot_type (str): string which defines the plot types (station, ensemble, time, daytime, total) Returns: dict: Dictionary w/ all relevant parameters as keys. Each key is assigned a list of lists containing the corresponding scores. """ # noqa: E501 - print("--- debugging user inputs") + print("--- parsing user inputs") if debug: print("-------------------------------------------------------------") + + plot_setup = dict() + + # Check if the model versions are in the input dir + all_model_versions = re.split(r"[,/]", model_versions) + input_dir = Path(input_dir) + model_directories = {x.name for x in input_dir.iterdir() if x.is_dir()} + if not set(all_model_versions).issubset(model_directories): + not_in_dir = set(all_model_versions) - model_directories + raise ValueError( + f"""The model version inputs {list(not_in_dir)} + do not exist in the directory {input_dir}.""" + ) + + if plotcolors: + color_list = plotcolors.split(",") + if len(color_list) < len(all_model_versions): + raise ValueError( + f""" + The input length --plotcolor is smaller than the + number of models to plot ({len(color_list)} < {len(all_model_versions)}) + """ + ) + plot_settings.modelcolors = color_list + plot_models_setup = [ + model_combinations.split("/") + for model_combinations in model_versions.split(",") + ] + plot_setup["model_versions"] = plot_models_setup + # initialise empty dictionaries - regular_params_dict = None - cat_params_dict = None - ens_params_dict = None - - # REGULAR PARAMETERS - if plot_params and plot_scores: - params = plot_params.split(",") - # TOT_PREC12,TOT_PREC6,TOT_PREC1,CLCT,GLOB,DURSUN12,DURSUN1,T_2M,T_2M_KAL, - # TD_2M,TD_2M_KAL,RELHUM_2M,FF_10M, - # FF_10M_KAL,VMAX_10M6,VMAX_10M1,DD_10M,PS,PMSL - scores = plot_scores.split(",") # ME,MMOD/MOBS,MAE,STDE,RMSE,COR,NOBS - regular_params_dict = {param: [] for param in params} - for param in params: - for score in scores: - if "/" in score: - regular_params_dict[param].append(score.split("/")) + regular_params_dict = {} + cat_params_dict = {} + regular_ens_params_dict = {} + ens_cat_params_dict = {} + plot_setup["parameter"] = {} + if any(p_type in plot_type for p_type in ["total", "time", "station", "daytime"]): + if not any( + [ + plot_params and plot_scores, + plot_cat_params and plot_cat_scores and plot_cat_thresh, + ] + ): + raise ValueError( + f"Missing params, scores or thresholds for {plot_type} score plots." + ) + # REGULAR PARAMETERS + if plot_params and plot_scores: + params = plot_params.split(",") + scores = plot_scores.split(",") # ME,MMOD/MOBS,MAE,STDE,RMSE,COR,NOBS + regular_params_dict = {param: [] for param in params} + for param in params: + for score in scores: + if "/" in score: + regular_params_dict[param].append(score.split("/")) + else: + regular_params_dict[param].append([score]) + if debug: + print("Regular Parameter Dict: ") + pprint(regular_params_dict) + + # CATEGORICAL PARAMETERS + if plot_cat_params and plot_cat_scores and plot_cat_thresh: + cat_params = plot_cat_params.split(",") + cat_scores = plot_cat_scores.split(",") + cat_threshs = plot_cat_thresh.split(":") + cat_params_dict = {cat_param: [] for cat_param in cat_params} + for param, threshs in zip(cat_params, cat_threshs): + # append all scores with a threshold + thresholds = threshs.split(",") + for threshold in thresholds: + for score in cat_scores: + if "/" in score: + cat_params_dict[param].append( + [x + f"({threshold})" for x in score.split("/")] + ) + else: + cat_params_dict[param].append([f"{score}({threshold})"]) + + if debug: + print("Categorical Parameter Dict: ") + pprint(cat_params_dict) + if "ensemble" in plot_type: + if not any( + [ + plot_ens_params and plot_ens_scores, + plot_ens_cat_params and plot_ens_cat_scores and plot_ens_cat_thresh, + ] + ): + raise ValueError("Missing params, scores or thresholds for ensemble plots.") + if plot_ens_params and plot_ens_scores: + ens_params = plot_ens_params.split(",") + for invalid_param in invalid_ensemble_paramter: + if invalid_param in ens_params: + raise ValueError( + f"{invalid_param} us not a valid parameter for plot_ens_params." + ) + ens_scores = list() + score_setups = [ + score_combinations.split("/") + for score_combinations in plot_ens_scores.split(",") + ] + for score_set in score_setups: + if "RANK" in score_set and len(score_set) > 1: + ens_scores.append( + [score for score in score_set if "RANK" not in score] + ) + ens_scores.append(["RANK"]) else: - regular_params_dict[param].append([score]) - - if debug: - print("Regular Parameter Dict: ") - pprint(regular_params_dict) - - # CATEGORICAL PARAMETERS - if plot_cat_params and plot_cat_scores and plot_cat_thresh: - cat_params = plot_cat_params.split(",") - # categorical parameters: TOT_PREC12,TOT_PREC6,TOT_PREC1,CLCT, - # T_2M,T_2M_KAL,TD_2M,TD_2M_KAL,FF_10M,FF_10M_KAL,VMAX_10M6,VMAX_10M1 - cat_scores = plot_cat_scores.split( - "," - ) # categorical scores: FBI,MF,POD,FAR,THS,ETS - cat_threshs = plot_cat_thresh.split(":") - # categorical thresholds: 0.1,1,10:0.2,1,5:0.2,0.5,2:2.5,6.5:0,15, - # 25:0,15,25:-5,5,15:-5,5,15:2.5,5,10:2.5,5,10:5,12.5,20:5,12.5,20 - cat_params_dict = {cat_param: [] for cat_param in cat_params} - for param, threshs in zip(cat_params, cat_threshs): - # first append all scores w/o thresholds to parameter - for score in plot_scores.split(","): - if "/" in score: - cat_params_dict[param].append(score.split("/")) + ens_scores.append(score_set) + + regular_ens_params_dict = {param: [] for param in ens_params} + for param in ens_params: + regular_ens_params_dict[param].extend(ens_scores) + + if plot_ens_cat_params and plot_ens_cat_scores and plot_ens_cat_thresh: + ens_cat_params = plot_ens_cat_params.split(",") + ens_cat_scores = list() + ens_cat_score_setups = [ + score_comb.split("/") for score_comb in plot_ens_cat_scores.split(",") + ] + for score_set in ens_cat_score_setups: + if "REL_DIA" in score_set and len(score_set) > 1: + ens_cat_scores.append( + [score for score in score_set if "REL_DIA" not in score] + ) + ens_cat_scores.append(["REL_DIA"]) else: - cat_params_dict[param].append([score]) - - # append all scores with a threshold in their name to current to parameter - thresholds = threshs.split(",") - for threshold in thresholds: - for score in cat_scores: - if "/" in score: - cat_params_dict[param].append( - [x + f"({threshold})" for x in score.split("/")] + ens_cat_scores.append(score_set) + ens_cat_params_dict = {} + for param, threshs in zip(ens_cat_params, plot_ens_cat_thresh.split(":")): + param_thresh_combs = [ + thresholds.split("/") for thresholds in threshs.split(",") + ] + for thresh_comb in param_thresh_combs: + for score_comb in ens_cat_scores: + ens_cat_params_dict.setdefault(param, []).append( + [ + f"{score}({thresh})" + for thresh, score in itertools.product( + thresh_comb, score_comb + ) + ] ) - else: - cat_params_dict[param].append([f"{score}({threshold})"]) - - if debug: - print("Categorical Parameter Dict: ") - pprint(cat_params_dict) - - # ENV PARAMETERS (TODO) - if plot_ens_params and plot_ens_scores and plot_ens_thresh: - ens_params_dict = {} - print("extend code here to create a end-dict.") - - # merge the dictionaries if the exist - # regular & categorical parameters together - if regular_params_dict and cat_params_dict: - params_dict = ( - regular_params_dict | cat_params_dict - ) # merges the right dict into the left and is assigned to new dict - # TODO: cover more cases, for the various possible combinations of dictionaries - - # only regular parameters - elif regular_params_dict and not cat_params_dict and not ens_params_dict: - params_dict = regular_params_dict - - elif cat_params_dict and not regular_params_dict and not ens_params_dict: - params_dict = cat_params_dict + all_keys = ( + set(regular_params_dict) + | set(cat_params_dict) + | set(regular_ens_params_dict) + | set(ens_cat_params_dict) + ) + plot_setup["parameter"] = { + key: { + "regular_scores": regular_params_dict.get(key, []), + "cat_scores": cat_params_dict.get(key, []), + "regular_ens_scores": regular_ens_params_dict.get(key, []), + "ens_cat_scores": ens_cat_params_dict.get(key, []), + } + for key in all_keys + } + if not plot_setup["parameter"]: + raise IOError("Invalid Input: parameter and/or scores are missing.") if debug: print("Finally, the following parameter x score pairs will get plotted:") - pprint(params_dict) - return params_dict + pprint(plot_setup) + + return plot_setup diff --git a/src/moveroplot/plotting.py b/src/moveroplot/plotting.py new file mode 100644 index 0000000..ec15dee --- /dev/null +++ b/src/moveroplot/plotting.py @@ -0,0 +1,15 @@ +"""General functions to create plots.""" + +# Standard library +from datetime import datetime + + +def get_total_dates_from_headers(headers: list): + total_start_date = min( + datetime.strptime(header["Start time"][0], "%Y-%m-%d") for header in headers + ) + + total_end_date = max( + datetime.strptime(header["End time"][0], "%Y-%m-%d") for header in headers + ) + return total_start_date, total_end_date diff --git a/src/moveroplot/station_scores.py b/src/moveroplot/station_scores.py index 26e95df..a26db82 100644 --- a/src/moveroplot/station_scores.py +++ b/src/moveroplot/station_scores.py @@ -1,8 +1,7 @@ # pylint: skip-file # relevant imports for parsing pipeline # Standard library -import pprint -import sys +from datetime import datetime from pathlib import Path # Third-party @@ -18,13 +17,13 @@ # > taken from: https://stackoverflow.com/questions/37423997/cartopy-shaded-relief from cartopy.io.img_tiles import GoogleTiles +# First-party +from moveroplot.load_files import load_relevant_files + # Local # local -from .utils.atab import Atab from .utils.check_params import check_params -from .utils.parse_plot_synop_ch import cat_station_score_colortable from .utils.parse_plot_synop_ch import cat_station_score_range -from .utils.parse_plot_synop_ch import station_score_colortable from .utils.parse_plot_synop_ch import station_score_range @@ -40,17 +39,167 @@ def _image_url(self, tile): return url +def _calculate_figsize(num_rows, num_cols, single_plot_size=(8, 6), padding=(2, 2)): + """Calculate the figure size given the number of rows and columns of subplots. + + Args: + - num_rows: Number of rows of subplots. + - num_cols: Number of columns of subplots. + - single_plot_size: A tuple (width, height) of the size of a single subplot. + - padding: A tuple (horizontal_padding, vertical_padding) between plots. + + Returns: + - tuple representing the figure size. + + """ + total_width = num_cols * single_plot_size[0] + (num_cols - 1) * padding[0] + total_height = num_rows * single_plot_size[1] + (num_rows - 1) * padding[1] + return (total_width, total_height) + + +def _initialize_plots(labels: list, scores: list): + num_cols = len(labels) + num_rows = len(scores) + figsize = _calculate_figsize(num_rows, num_cols, (7.3, 5), (0, 2)) + fig, axes = plt.subplots( + subplot_kw=dict(projection=ccrs.PlateCarree()), + nrows=num_rows, + ncols=num_cols, + tight_layout=True, + figsize=figsize, + dpi=100, + squeeze=False, + ) + for ax in axes.ravel(): + ax.set_extent([5.3, 11.2, 45.4, 48.2]) + _add_features(ax) + fig.tight_layout(w_pad=8, h_pad=2, rect=[0.05, 0.05, 0.90, 0.90]) + plt.subplots_adjust(bottom=0.15) + return fig, axes + + +def _add_plot_text(ax, data, score, ltr): + [subplot_title] = data["header"]["Model version"] + ax.set_title(f"{subplot_title}: {score}, LT: {ltr}") + if score not in data["df"].index: + return + min_value = data["df"].loc[score].min() + min_station = data["df"].loc[score].idxmin() + max_value = data["df"].loc[score].max() + max_station = data["df"].loc[score].idxmax() + try: + start_date = datetime.strptime( + " ".join(data["header"]["Start time"][0:3:2]), "%Y-%m-%d %H:%M" + ) + end_date = datetime.strptime( + " ".join(data["header"]["End time"][0:2]), "%Y-%m-%d %H:%M" + ) + except ValueError: + start_date = datetime(9999, 1, 1, hour=0, minute=0) + end_date = datetime(9999, 1, 1, hour=0, minute=0) + print("Found invalid date format.") + + # pylint: disable=line-too-long + plt.text( + 0.5, + -0.1, + f"""{start_date.strftime("%Y-%m-%d %H:%M")} to {end_date.strftime("%Y-%m-%d %H:%M")} -Min: {min_value} mm at station {min_station} +Max: {max_value} mm at station {max_station}""", # noqa: E501 + horizontalalignment="center", + verticalalignment="center", + transform=ax.transAxes, + fontsize=8, + ) + # pylint: enable=line-too-long + + +def _plot_and_save_scores( + output_dir, + base_filename, + parameter, + plot_scores_setup, + sup_title, + ltr_models_data, + debug=False, +): + for ltr, models_data in ltr_models_data.items(): + ltr_info = f"_{ltr}" + model_info = ( + "" if len(models_data.keys()) > 1 else f"_{next(iter(models_data.keys()))}" + ) + for scores in plot_scores_setup: + filename = base_filename + ltr_info + model_info + fig, subplot_axes = _initialize_plots(models_data.keys(), scores) + for idx, score in enumerate(scores): + filename += f"_{score}" + for model_idx, data in enumerate(models_data.values()): + ax = subplot_axes[idx][model_idx] + _add_datapoints2( + fig=fig, + data=data["df"], + score=score, + ax=ax, + min=-10, + max=10, + unit=data["header"]["Unit"][0], + param=data["header"]["Parameter"], + ) + + _add_plot_text(ax, data, score, ltr) + + fig.suptitle( + sup_title, + horizontalalignment="center", + verticalalignment="top", + fontdict={ + "size": 6, + "color": "k", + }, + bbox={"facecolor": "none", "edgecolor": "grey"}, + ) + fig.savefig(f"{output_dir}/{filename}.png") + plt.close() + + +def _generate_station_plots( + plot_scores, + models_data, + parameter, + output_dir, + debug, +): + # initialise filename + base_filename = f"station_scores_{parameter}" + sup_title = f"PARAMETER: {parameter}" + + _plot_and_save_scores( + output_dir, + base_filename, + parameter, + plot_scores["regular_scores"], + sup_title, + models_data, + debug=False, + ) + _plot_and_save_scores( + output_dir, + base_filename, + parameter, + plot_scores["cat_scores"], + sup_title, + models_data, + debug=False, + ) + + # enter directory / read station_scores files / call plotting pipeline +# type: ignore def _station_scores_pipeline( - params_dict, + plot_setup, lt_ranges, file_prefix, file_postfix, input_dir, output_dir, - season, - model_version, - relief, debug, ) -> None: """Read all ```ATAB``` files that are present in: data_dir/season/model_version/<...>. @@ -66,7 +215,6 @@ def _station_scores_pipeline( file_postfix (str): postfix of files (i.e. '.dat') input_dir (str): directory to seasons (i.e. /scratch/osm/movero/wd) output_dir (str): output directory (i.e. plots/) - season (str): season of interest (i.e. 2021s4/) model_version (str): model_version of interest (i.e. C-1E_ch) scores (list): list of scores, for which plots should be generated relief (bool): passed on to plotting pipeline - add relief to map if True @@ -74,111 +222,42 @@ def _station_scores_pipeline( """ # noqa: E501 print("--- initialising station score pipeline") - for lt_range in lt_ranges: - for parameter in params_dict: - # retrieve list of scores, relevant for current parameter - scores = params_dict[parameter] # this scores is a list of lists - - # define path to the file of the current parameter (station_score atab file) - file = f"{file_prefix}{lt_range}_{parameter}{file_postfix}" - path = Path(f"{input_dir}/{season}/{model_version}/{file}") - - # check if the file exists - if not path.exists(): - print( - f"""WARNING: No data file for parameter {parameter} could be found. - {path} does not exist.""" - ) - continue # no file could be retrieved for the current parameter - - if debug: - print(f"\nFilepath:\t{path}") - - # extract header - header = Atab(file=path, sep=" ").header - relevant_header_information = { - "Start time": header["Start time"], - "End time": header[ - "End time" - ], # i.e. ['2021-11-30', '2300', '', '+000'], - "Missing value code": header["Missing value code"][0], - "Model name": header["Model name"][0], - "Model version": header["Model version"][0], - "Parameter": header["Parameter"][0], - "Unit": header["Unit"][0], - } - # pprint.pprint(relevant_header_information) # dbg - - # TODO: longitude gets parsed ugly --> check separator in atab.py - # looks like this: ['7.56100', '', '', '', '', '', '', '8.60800',....] - # should look like this: ['7.56100', '8.60800', ..] - longitudes = list(filter(None, header["Longitude"])) - latitudes = list(filter(None, header["Latitude"])) - - # extract dataframe - df = Atab(file=path, sep=" ").data - - print(path) - pprint(df) - """ - # > rename the first column - # TODO (in ATAB): split the first column based on number of characters and not based - # on separator. get number of characters from header: Width of text label column: 14 - - # alternatively: - # get column names - # get first column name - # remove Score from first column name - # keep rest and rename first column - """ # noqa: E501 - df.rename(columns={"ScoreABO": "ABO"}, inplace=True) - - # > add longitude and latitude to df - df.loc["lon"] = longitudes - df.loc["lat"] = latitudes - - # > check which relevant scores are available; extract those from df - all_scores = df.index.tolist() - available_scores = ["lon", "lat"] # this list, will be kept - for score in scores: # scores = [[score1], [score2],...] - if score[0] in all_scores: - available_scores.append(score[0]) - else: # warn that a relevant score was not available in dataframe - print( - f"""WARNING: Score {score[0]} - not available for parameter {parameter}.""" - ) - - # reduce dataframe, s.t. only relevant scores + lon/lat are kept - df = df.loc[available_scores] - # > remove/replace missing values in dataframe with np.NaN - print("JJJJ ", relevant_header_information["Missing value code"]) - df = df.replace( - float(relevant_header_information["Missing value code"]), np.NaN + if not lt_ranges: + lt_ranges = "19-24" + for model_plots in plot_setup["model_versions"]: + for parameter, scores in plot_setup["parameter"].items(): + model_data = load_relevant_files( + input_dir, + file_prefix, + file_postfix, + debug, + model_plots, + parameter, + lt_ranges, + ltr_first=True, + transform_func=_station_score_transformation, ) - """ - # > if there are rows (= scores), that only contain np.NaN, remove them - # df = df.dropna(how="all") - - # if debug: - # print(f"Generating plot for {parameter} for lt_range: {lt_range}. (File: {file})") - # for each score in df, create one map - """ # noqa: E501 - _generate_map_plot( - data=df, - lt_range=lt_range, - variable=parameter, - file=file, - file_postfix=file_postfix, - header_dict=relevant_header_information, - model_version=model_version, + if not model_data: + print(f"No matching files found with given ltr {lt_ranges}") + return + _generate_station_plots( + plot_scores=scores, + models_data=model_data, + parameter=parameter, output_dir=output_dir, - relief=relief, debug=debug, ) +def _station_score_transformation(df, header): + df = df.replace(float(header["Missing value code"][0]), np.NaN) + df.rename(columns={"ScoreABO": "ABO"}, inplace=True) + df.loc["lon"] = list(filter(None, header["Longitude"])) + df.loc["lat"] = list(filter(None, header["Latitude"])) + return df + + # PLOTTING PIPELINE FOR STATION SCORES PLOTS def _add_features(ax): """Add features to map. @@ -223,110 +302,106 @@ def _add_features(ax): rasterized=True, color="#97b6e1", ) - - return + # ax.add_image(ShadedReliefESRI(), 8) + + +def _add_datapoints2(fig, data, score, ax, min, max, unit, param, debug=False): + # dataframes have two different structures + param = check_params(param[0]) + if param in station_score_range.columns and score in station_score_range.index: + param_score_range = station_score_range[param].loc[score] + elif ( + param in cat_station_score_range.columns + and score in cat_station_score_range[param].set_index("scores").index + ): + param_score_range = ( + cat_station_score_range[param].set_index("scores").loc[score] + ) + else: + param_score_range = {"min": None, "max": None} + lower_bound = param_score_range["min"] + upper_bound = param_score_range["max"] + if score not in data.index: + return + plot_data = data.loc[["lon", "lat", score]].astype(float) + nan_data = plot_data.loc[:, plot_data.isna().any()] + plot_data = plot_data.dropna(axis="columns") + sc = ax.scatter( + x=list(plot_data.loc["lon"]), + y=list(plot_data.loc["lat"]), + marker="o", + c=list(plot_data.loc[score]), + vmin=lower_bound, + vmax=upper_bound, + rasterized=True, + transform=ccrs.PlateCarree(), + ) + if len(plot_data.loc[score]) != 0: + max_idx = plot_data.loc[score].idxmax() + min_idx = plot_data.loc[score].idxmin() + ax.scatter( + x=[plot_data[max_idx].loc["lon"]], + y=[plot_data[max_idx].loc["lat"]], + marker="+", + color="black", + s=80, + transform=ccrs.PlateCarree(), + ) + ax.scatter( + x=[plot_data[min_idx].loc["lon"]], + y=[plot_data[min_idx].loc["lat"]], + marker="_", + color="black", + s=80, + transform=ccrs.PlateCarree(), + ) + cax = fig.add_axes( + [ + ax.get_position().x1 + 0.005, + ax.get_position().y0, + 0.008, + ax.get_position().height, + ] + ) + cbar = plt.colorbar(sc, cax=cax) + cbar.set_label(unit, rotation=270, labelpad=10) + ax.scatter( + x=list(nan_data.loc["lon"]), + y=list(nan_data.loc["lat"]), + rasterized=True, + transform=ccrs.PlateCarree(), + facecolors="none", + edgecolors="black", + linewidth=0.5, + ) def _add_datapoints(data, score, ax, min, max, unit, param, debug=False): - cat_score = False print(f"plotting:\t{param}/{score}") # check param, before trying to assign cmap to it - # i.e. param = TD_2M_KAL - param = check_params(param, debug) - # i.e. param = TD_2M* - print("Station Score Colortable") - pprint(station_score_colortable) + # pprint(station_score_colortable) print("Note: Index = Scores") - print("Cat Station Score Colortable") - pprint(cat_station_score_colortable) - - # RESOLVE CORRECT PARAMETER - try: # try to get the cmap from the regular station_score_colortable - cmap = station_score_colortable[param][score] - - # if a KeyError occurs, the current parameter - # doesn't match the columns of the station_score_colourtable df AND/OR - # because the score is not present in the station_score_colourtable df. - except KeyError: - if score not in station_score_colortable.index.tolist(): - if score in cat_station_score_colortable[param]["scores"].values: - cat_score = True - if debug: - print( - f"""{score} ∉ station score colortable. - {score} ∈ categorical station score colortable.""" - ) - index = cat_station_score_colortable[ - cat_station_score_colortable[param]["scores"] == score - ].index.values[0] - cmap = cat_station_score_colortable[param]["cmap"].iloc[index] - else: - print(f"{score} not known - check again.") - sys.exit(123) - - elif not cat_score: - try: - cmap = station_score_colortable[param][score] - except KeyError: - print("Dont know this parameter and score combination.") - - # define limits for colour bar - if not cat_score: - lower_bound = station_score_range[param]["min"].loc[score] - upper_bound = station_score_range[param]["max"].loc[score] - if cat_score: - # get the index of the current score - index = cat_station_score_range[ - cat_station_score_range[param]["scores"] == score - ].index.values[0] - lower_bound = cat_station_score_range[param]["min"].iloc[index] - upper_bound = cat_station_score_range[param]["max"].iloc[index] - - # if both are equal (i.e. 0), take the min/max values as limits - if lower_bound == upper_bound: - lower_bound = min - upper_bound = max - - # print(param, score, lower_bound, upper_bound) # dbg - tmp = False - for name, info in data.iteritems(): - lon, lat, value = float(info.lon), float(info.lat), float(info[score]) - # add available datapoints - if not np.isnan(value): - tmp = True - sc = ax.scatter( - x=lon, - y=lat, - marker="o", - c=value, - vmin=lower_bound, - vmax=upper_bound, - cmap=cmap, - rasterized=True, - transform=ccrs.PlateCarree(), - ) - - if False: # add the short name of the stations as well - ax.text( - x=lon - 0.025, - y=lat - 0.007, - s=name, - color="k", - fontsize=3, - transform=ccrs.PlateCarree(), - rasterized=True, - ) - - if tmp: - cbar = plt.colorbar(sc, ax=ax) - cbar.set_label(unit, rotation=270, labelpad=15) - return + # print("Cat Station Score Colortable") + # print(data.loc[["lon", "lat", score]]) + lower_bound = station_score_range[param[0], "min"][score] + upper_bound = station_score_range[param[0], "max"][score] + + sc = ax.scatter( + x=list(data.loc["lon"].astype(float)), + y=list(data.loc["lat"].astype(float)), + marker="o", + c=list(data.loc[score].astype(float)), + vmin=lower_bound, + vmax=upper_bound, + rasterized=True, + transform=ccrs.PlateCarree(), + ) - else: - return + cbar = plt.colorbar(sc, ax=ax, orientation="vertical", fraction=0.046, pad=0.04) + cbar.set_label(unit, rotation=270, labelpad=15) def _add_text( @@ -342,10 +417,9 @@ def _add_text( ): """Add footer and title to plot.""" footer = f"""Model: {header_dict['Model version']} | - Period: {header_dict['Start time'][0]} - {header_dict['End time'][0]} - ({lt_range}) | Min: {min_value} {header_dict['Unit']} + Period: {header_dict['Start time'][0]} - {header_dict['End time'][0]} | Min: {min_value} {header_dict['Unit']} @ {min_station} | Max: {max_value} {header_dict['Unit']} @ {max_station} - | © MeteoSwiss""" + | © MeteoSwiss""" # noqa: E501 plt.suptitle( footer, @@ -359,7 +433,7 @@ def _add_text( }, ) - title = f"{variable}: {score}" + title = f"{variable}: {score}, LT: {lt_range}" ax.set_title(title, fontsize=15, fontweight="bold") return ax @@ -403,7 +477,7 @@ def _generate_map_plot( max_station = station # plotting pipeline - fig = plt.figure(figsize=(16, 9), dpi=500) + fig = plt.figure(figsize=(14.7, 10), dpi=100) if relief: ax = plt.axes(projection=ShadedReliefESRI().crs) else: diff --git a/src/moveroplot/time_scores.py b/src/moveroplot/time_scores.py index 18cb1e3..bfa9949 100644 --- a/src/moveroplot/time_scores.py +++ b/src/moveroplot/time_scores.py @@ -1,33 +1,51 @@ # pylint: skip-file # Standard library -import datetime as dt -from pathlib import Path -from pprint import pprint +import re # Third-party +import matplotlib.dates as mdates import matplotlib.pyplot as plt import numpy as np import pandas as pd +from matplotlib.lines import Line2D + +# First-party +import moveroplot.config.plot_settings as plot_settings +from moveroplot.load_files import load_relevant_files +from moveroplot.plotting import get_total_dates_from_headers # Local -# import datetime -from .utils.atab import Atab -from .utils.check_params import check_params -from .utils.parse_plot_synop_ch import cat_time_score_range -from .utils.parse_plot_synop_ch import time_score_range +from .utils.parse_plot_synop_ch import total_score_range + + +def _time_score_transformation(df, header): + df = df.replace(float(header["Missing value code"][0]), np.NaN) + names = { + "YYYY": "year", + "MM": "month", + "DD": "day", + "hh": "hour", + "mm": "minute", + } + df["timestamp"] = pd.to_datetime( + df[["YYYY", "MM", "DD", "hh", "mm"]].rename(columns=names) + ) + df.drop( + ["YYYY", "MM", "DD", "hh", "mm", "lt_hh", "lt_mm"], + axis=1, + inplace=True, + ) + return df # enter directory / read station_scores files / call plotting pipeline def _time_scores_pipeline( - params_dict, + plot_setup, lt_ranges, file_prefix, file_postfix, input_dir, output_dir, - season, - model_version, - grid, debug, ) -> None: """Read all ATAB files that are present in: data_dir/season/model_version/<...>. @@ -49,280 +67,264 @@ def _time_scores_pipeline( """ # noqa: E501 print("---initialising time score pipeline") - for lt_range in lt_ranges: - for parameter in params_dict: - # retrieve list of scores, relevant for current parameter - scores = params_dict[parameter] # this scores is a list of lists - - # define path to the file of current parameter (station_score atab file) - file = f"{file_prefix}{lt_range}_{parameter}{file_postfix}" - path = Path(f"{input_dir}/{season}/{model_version}/{file}") - - # check if the file exists - if not path.exists(): - print( - f"""WARNING: No data file for parameter - {parameter} could be found. {path} does not exist.""" - ) - continue # for the current parameter no file could be retrieved - - if debug: - print(f"\nFilepath:\t{path}") - - # extract header & dataframe - header = Atab(file=path, sep=" ").header - df = Atab(file=path, sep=" ").data - - # cast time columns as str, so they can be combined to one datetime column - data_types_dict = {"YYYY": str, "MM": str, "DD": str, "hh": str, "mm": str} - df = df.astype(data_types_dict) - - # TODO: optimise this - it is inefficient and ugly. - # create datetime column (just called time) & drop unnecessary columns - df["timestamp"] = pd.to_datetime( - df["YYYY"] # noqa: W503 - + "-" # noqa: W503 - + df["MM"] # noqa: W503 - + "-" # noqa: W503 - + df["DD"] # noqa: W503 - + " " # noqa: W503 - + df["hh"] # noqa: W503 - + ":" # noqa: W503 - + df["mm"] # noqa: W503 - ) - """ - # dbg() - # df['timestamp_new'] = [' '.join([x + '-' + y + '-' + z + ' ' + q + ':' + r]) for x, y, z, q, r in zip(df['YYYY'], df['MM'], df['DD'], df['hh'], df['mm'])] - # df['timestamp_new'] = pd.to_datetime(df['timestamp_new']) - # df['timestamp_new_2'] = pd.to_datetime([' '.join([x + '-' + y + '-' + z + ' ' + q + ':' + r]) for x, y, z, q, r in zip(df['YYYY'], df['MM'], df['DD'], df['hh'], df['mm'])]) - # df['timestamp'] = pd.to_datetime([' '.join([x + '-' + y + '-' + z + ' ' + q + ':' + r]) for x, y, z, q, r in zip(df['YYYY'], df['MM'], df['DD'], df['hh'], df['mm'])]) - # dbg() - """ # noqa: E501 - - df.drop( - ["YYYY", "MM", "DD", "hh", "mm", "lt_hh", "lt_mm"], axis=1, inplace=True + if not lt_ranges: + lt_ranges = "19-24" + + for model_plots in plot_setup["model_versions"]: + for parameter, scores in plot_setup["parameter"].items(): + model_data = load_relevant_files( + input_dir, + file_prefix, + file_postfix, + debug, + model_plots, + parameter, + lt_ranges, + ltr_first=True, + transform_func=_time_score_transformation, ) - - # > remove/replace missing values in dataframe with np.NaN - df = df.replace(float(header["Missing value code"][0]), np.NaN) - - # > if there are columns (= scores), that only contain np.NaN, remove them - # df = df.dropna(axis=1, how="all") - - # > check which relevant scores are available; extract those from df - all_scores = df.columns.tolist() - available_scores = [ - "timestamp" - ] # this list is the columns, that should be kept - multiplot_scores = {} - for score in scores: # scores = [[score1], [score2/score3], [score4],...] - if len(score) == 1: - if score[0] in all_scores: - available_scores.append(score[0]) - else: # warn that a relevant score was not available in dataframe - print( - f"""WARNING: Score {score[0]} not available - for parameter {parameter}.""" - ) - if len(score) > 1: # currently only 2-in-1 plots are possible - # MMOD/MOBS --> MMOD:MOBS - multiplot_scores[score[0]] = score[1] - for sc in score: - if sc in all_scores: - available_scores.append(sc) - else: - print( - f"""WARNING: Score {sc} not available - for parameter {parameter}.""" - ) - - df = df[available_scores] - - if False: - print("\nFile header:") - pprint(header) - print("\nData:") - pprint(df) - - if debug: - print( - f"""Generating plot for {parameter} for - lt_range: {lt_range}. (File: {file})""" - ) - # for each score in df, create one map - _generate_timeseries_plot( - data=df, - multiplots=multiplot_scores, # { MMOD : MOBS } - lt_range=lt_range, - variable=parameter, - file=file, - file_postfix=file_postfix, - header_dict=header, + if not model_data: + print(f"No matching files found with given ltr {lt_ranges}") + return + _generate_timeseries_plots( + plot_scores=scores, + models_data=model_data, + parameter=parameter, output_dir=output_dir, - grid=grid, debug=debug, ) -# PLOTTING PIPELINE FOR TIME SCORES PLOTS +def _clear_empty_axes_if_necessary(subplot_axes, idx): + # remove empty ``axes`` instances + if idx % 2 != 1: + [ax.axis("off") for ax in subplot_axes[(idx + 1) % 2 :]] -def _generate_timeseries_plot( - data, - multiplots, - lt_range, - variable, - file, - file_postfix, - header_dict, - output_dir, - grid, - debug, -): - """Generate Timeseries Plot.""" - # output_dir = f"{output_dir}/time_scores" - if not Path(output_dir).exists(): - Path(output_dir).mkdir(parents=True, exist_ok=True) - # print(f"creating plots for file: {file}") - # extract scores, which are available in the dataframe (data) - # for each score - scores = data.columns.tolist() - scores.remove("timestamp") - - # define limits for plot (start, end time specified in header) - start = dt.datetime.strptime( - header_dict["Start time"][0] + " " + header_dict["Start time"][2], - "%Y-%m-%d %H:%M", +def _save_figure(output_dir, filename, title, fig, axes, idx): + fig.suptitle( + title, + horizontalalignment="center", + verticalalignment="top", + fontdict={ + "size": 6, + "color": "k", + }, + bbox={"facecolor": "none", "edgecolor": "grey"}, ) - end = dt.datetime.strptime( - header_dict["End time"][0] + " " + header_dict["End time"][1], "%Y-%m-%d %H:%M" - ) - unit = header_dict["Unit"][0] - - # this variable, remembers if a score has been added to another plot. - # for example in the multiplots dict - # when plotting MMOD, MOBS will also be added to the plot - # and does not need to be plotted again. - score_to_skip = None - for score in scores: - if score == score_to_skip: - continue - - param = header_dict["Parameter"][0] - # param = TD_2M_KAL - param = check_params( - param=param, verbose=debug - ) # TODO: replace param w/ variable - # param = TD_2M* - print(f"plotting:\t{param}/{score}") - - multiplt = False - title = f"{variable}: {score}" # 'variable' is the full parameter name. - footer = f"""Model: {header_dict['Model version'][0]} | - Period: {header_dict['Start time'][0]} - {header_dict['End time'][0]} - ({lt_range}) | © MeteoSwiss""" - # initialise figure/axes instance - fig, ax = plt.subplots( - 1, 1, figsize=(245 / 10, 51 / 10), dpi=150, tight_layout=True - ) + _clear_empty_axes_if_necessary(axes, idx) + fig.savefig(f"{output_dir}/{filename[:-1]}.png") + plt.close() - ax.set_xlim(start, end) - ax.set_ylabel(f"{score.upper()} ({unit})") - if grid: - ax.grid(visible=True) +def _initialize_plots(labels: list): + fig, ((ax0), (ax1)) = plt.subplots( + nrows=2, ncols=1, tight_layout=True, figsize=(10, 10), dpi=200 + ) + custom_lines = [ + Line2D([0], [0], color=plot_settings.modelcolors[i], lw=2) + for i in range(len(labels)) + ] + fig.legend( + custom_lines, + labels, + loc="upper right", + ncol=1, + frameon=False, + ) + plt.tight_layout(w_pad=8, h_pad=5, rect=(0.05, 0.05, 0.90, 0.90)) + return fig, [ax0, ax1] - if debug: - print(f"Extract dataframe for score: {score}") - pprint(data) - x = data["timestamp"].values - y = data[score].values +# PLOTTING PIPELINE FOR TOTAL SCORES PLOTS +def _set_ylim(param, score, ax, debug): # pylint: disable=unused-argument + # define limits for yaxis if available + regular_param = (param, "min") in total_score_range.columns + regular_scores = score in total_score_range.index - if score in multiplots.keys(): - y2 = data[multiplots[score]].values - multiplt = True - score_to_skip = multiplots[score] - title = f"{variable}: {score}/{multiplots[score]}" - ax.set_ylabel(f"{score.upper()}/{multiplots[score].upper()} ({unit})") + if regular_param and regular_scores: + lower_bound = total_score_range[param]["min"].loc[score] + upper_bound = total_score_range[param]["max"].loc[score] + if lower_bound != upper_bound: + ax.set_ylim(lower_bound, upper_bound) - # plot dashed line @ 0 - ax.plot(x, [0] * len(x), color="grey", linestyle="--") - # define limits for yaxis if available - regular_param = (param, "min") in time_score_range.columns - regular_score = score in time_score_range.index - cat_score = not regular_score +def _customise_ax(parameter, scores, x_ticks, grid, ax): + """Apply cosmetics to current ax. - if regular_param and regular_score: - lower_bound = time_score_range[param]["min"].loc[score] - upper_bound = time_score_range[param]["max"].loc[score] - if debug: - print( - f"found limits for {param}/{score} --> {lower_bound}/{upper_bound}" + Args: + parameter (str): current parameter + score (str): current score + x_ticks (list): list of x-ticks labels (lead time ranges, as strings) + grid (bool): add grid to ax + ax (Axes): current ax + + """ + if grid: + ax.grid(which="major", color="#DDDDDD", linewidth=0.8) + ax.grid(which="minor", color="#EEEEEE", linestyle=":", linewidth=0.5) + ax.minorticks_on() + + ax.tick_params(axis="both", which="major", labelsize=8) + ax.tick_params(axis="both", which="minor", labelsize=6) + ax.set_title(f"{parameter}: {','.join(scores)}") + ax.set_xlabel("Lead-Time Range (h)") + # plotting too many data on the x-axis + steps = len(x_ticks) // 7 + skip_indices = slice(None, None, steps) if steps > 0 else slice(None) + ax.set_xticks(range(len(x_ticks))[skip_indices], x_ticks[skip_indices]) + ax.autoscale(axis="y") + + +def _plot_and_save_scores( + output_dir, + base_filename, + parameter, + plot_scores_setup, + sup_title, + ltr_models_data, + debug=False, +): + for ltr, models_data in ltr_models_data.items(): + fig, subplot_axes = _initialize_plots(ltr_models_data[ltr].keys()) + headers = [data["header"] for data in models_data.values()] + total_start_date, total_end_date = get_total_dates_from_headers(headers) + title_base = f"{parameter.upper()}: " + model_info = ( + f" {list(models_data.keys())[0]}" if len(models_data.keys()) == 1 else "" + ) + x_label_base = f"""{total_start_date.strftime("%Y-%m-%d %H:%M")} - {total_end_date.strftime("%Y-%m-%d %H:%M")}""" # noqa: E501 + filename = base_filename + f"_{ltr}" + pattern = ( + re.search(r"\(.*?\)", next(iter(plot_scores_setup))[0]) + if plot_scores_setup + else None + ) + prev_threshold = None + if pattern is not None: + prev_threshold = pattern.group() + current_threshold = prev_threshold + current_plot_idx = 0 + + for idx, score_setup in enumerate(plot_scores_setup): + prev_threshold = current_threshold + pattern = re.search(r"\(.*?\)", next(iter(score_setup))) + current_threshold = pattern.group() if pattern is not None else None + different_threshold = prev_threshold != current_threshold + if different_threshold: + _clear_empty_axes_if_necessary(subplot_axes, current_plot_idx - 1) + fig.suptitle( + sup_title, + horizontalalignment="center", + verticalalignment="top", + fontdict={ + "size": 6, + "color": "k", + }, + bbox={"facecolor": "none", "edgecolor": "grey"}, ) - if lower_bound != upper_bound: - ax.set_ylim(lower_bound, upper_bound) - - if cat_score: - # get the index of the current score - index = cat_time_score_range[ - cat_time_score_range[param]["scores"] == score - ].index.values[0] - # get min/max value - lower_bound = cat_time_score_range[param]["min"].iloc[index] - upper_bound = cat_time_score_range[param]["max"].iloc[index] - if debug: - print( - f"found limits for {param}/{score} --> {lower_bound}/{upper_bound}" + fig.savefig(f"{output_dir}/{filename}.png") + plt.close() + filename = base_filename + f"_{ltr}" + fig, subplot_axes = _initialize_plots(ltr_models_data[ltr].keys()) + current_plot_idx += current_plot_idx % 2 + + title = title_base + ",".join(score_setup) + model_info + f" LT: {ltr}" + ax = subplot_axes[current_plot_idx % 2] + for model_idx, data in enumerate(models_data.values()): + model_plot_color = plot_settings.modelcolors[model_idx] + header = data["header"] + unit = header["Unit"][0] + x_int = data["df"][["timestamp"]] + y_label = ",".join(score_setup) + ax.set_ylabel(f"{y_label.upper()} ({unit})") + ax.set_xlabel(x_label_base) + ax.set_title(title) + for score_idx, score in enumerate(score_setup): + score_values = data["df"][[score]] + ax.plot( + np.asarray(x_int, dtype="datetime64[s]"), + score_values, + color=model_plot_color, + linestyle=plot_settings.line_styles[score_idx], + fillstyle="none", + label=f"{score.upper()}", + ) + ax.tick_params(axis="both", which="major", labelsize=8) + ax.tick_params(axis="both", which="minor", labelsize=6) + ax.autoscale(axis="y") + ax.xaxis.set_major_formatter(mdates.DateFormatter("%b %d\n%H:%M")) + if len(score_setup) > 1: + sub_plot_legend = ax.legend( + score_setup, + loc="upper right", + markerscale=0.9, + bbox_to_anchor=(1.1, 1.05), ) - if lower_bound != upper_bound: - ax.set_ylim(lower_bound, upper_bound) - - label = f"{score.upper()}" - if not multiplt: - ax.plot( - x, - y, - color="k", - linestyle="-", - label=label, - ) - if multiplt: - ax.plot( - x, - y, - color="red", - linestyle="-", - label=label, - ) - label = f"{multiplots[score].upper()}" - ax.plot( - x, - y2, - color="k", - linestyle="-", - label=label, - ) - # change title, y-axis label, filename here, for the multiplot case - - plt.legend() - - plt.suptitle( - footer, - x=0.0215, - y=0.908, - horizontalalignment="left", - verticalalignment="top", - fontdict={ - "size": 6, - "color": "k", - }, - ) - ax.set_title(label=title) + for line in sub_plot_legend.get_lines(): + line.set_color("black") + filename += "_" + "_".join(score_setup) + + if current_plot_idx % 2 == 1 or idx == len(plot_scores_setup) - 1: + _clear_empty_axes_if_necessary(subplot_axes, current_plot_idx) + fig.suptitle( + sup_title, + horizontalalignment="center", + verticalalignment="top", + fontdict={ + "size": 6, + "color": "k", + }, + bbox={"facecolor": "none", "edgecolor": "grey"}, + ) + fig.savefig(f"{output_dir}/{filename}.png") + plt.close() + filename = base_filename + f"_{ltr}" + fig, subplot_axes = _initialize_plots(ltr_models_data[ltr].keys()) + current_plot_idx += 1 - print(f"saving:\t\t{output_dir}/{file.split(file_postfix)[0]}_{score}.png") - plt.savefig(f"{output_dir}/{file.split(file_postfix)[0]}_{score}.png") - plt.close(fig) - return +# PLOTTING PIPELINE FOR TIME SCORES PLOTS +def _generate_timeseries_plots( + plot_scores, + models_data, + parameter, + output_dir, + debug, +): + model_versions = list(models_data.keys()) + + # initialise filename + base_filename = ( + f"time_scores_{model_versions[0]}_{parameter}" + if len(model_versions) == 1 + else f"time_scores_{parameter}" + ) + headers = [ + data["header"] for data in models_data[next(iter(models_data.keys()))].values() + ] + total_start_date, total_end_date = get_total_dates_from_headers(headers) + # pylint: disable=line-too-long + period_info = f"""Period: {total_start_date.strftime("%Y-%m-%d %H:%M")} - {total_end_date.strftime("%Y-%m-%d %H:%M")} | © MeteoSwiss""" # noqa: E501 + # pylint: enable=line-too-long + sup_title = f"{parameter}: " + period_info + # plot regular scores + _plot_and_save_scores( + output_dir, + base_filename, + parameter, + plot_scores["regular_scores"], + sup_title, + models_data, + debug=debug, + ) + + _plot_and_save_scores( + output_dir, + base_filename, + parameter, + plot_scores["cat_scores"], + sup_title, + models_data, + debug=debug, + ) diff --git a/src/moveroplot/total_scores.py b/src/moveroplot/total_scores.py index 0dd3ed9..baf4820 100644 --- a/src/moveroplot/total_scores.py +++ b/src/moveroplot/total_scores.py @@ -1,16 +1,20 @@ -"""Calculate Total score from parsed data.""" +"""Calculate total scores from parsed data.""" # Standard library -from pathlib import Path -from pprint import pprint # noqa: E402¨ +import re # Third-party import matplotlib.pyplot as plt import numpy as np +from matplotlib.lines import Line2D + +# First-party +from moveroplot.config import plot_settings # Local +from .load_files import load_relevant_files +from .plotting import get_total_dates_from_headers + # pylint: disable=no-name-in-module -from .utils.atab import Atab -from .utils.check_params import check_params from .utils.parse_plot_synop_ch import total_score_range # pylint: enable=no-name-in-module @@ -24,81 +28,22 @@ ) -def collect_relevant_files(file_prefix, file_postfix, debug, source_path, parameter): - """Collect all files corresponding to current parameter in 'corresponding_files_dict'. - - Args: - file_prefix (str): prefix of files we're looking for (i.e. total_scores) - file_postfix (str): postfix of files we're looking for (i.e. .dat) - debug (bool): add debug messages command prompt - source_path (Path): path to directory, where source files are - parameter (str): parameter of interest - - Returns: - dict: dictionary conrains all available lead time range dataframes for parameter - # collect the files to this parameter in the corresponding files dict. - # the keys in this dict are the available lead time ranges for the current parameter. - - """ # noqa: E501 - corresponding_files_dict = {} - - # for dbg purposes: - files_list = [] - for file_path in source_path.glob(f"{file_prefix}*{parameter}{file_postfix}"): - if file_path.is_file(): - # check, that the corresponding path belongs to a file - # and not to a sub-directory - # lt_range = key for corresponding_files_dict - # TODO:change here, if ltr is longer than 5 chars - lt_range = file_path.name[ - len(file_prefix) : len(file_prefix) + 5 # noqa: E203 - ] - # extract header & dataframe - header = Atab(file=file_path, sep=" ").header - df = Atab(file=file_path, sep=" ").data - - # clean df - df = df.replace(float(header["Missing value code"][0]), np.NaN) - - df.set_index(keys="Score", inplace=True) - - # add information to dict - corresponding_files_dict[lt_range] = { - # 'path':file_path, - "header": header, - "df": df, - } - - # add path of file to list of relevant files - files_list.append(file_path) - if debug: - print(f"\nFor parameter: {parameter} these files are relevant:\n") - pprint(files_list) - print( - f"""\nFiles have been parsed & combined in the 'corresponding_files_dict'. - Each key (lt-range) has a subdict with two keys: - {corresponding_files_dict['19-24'].keys()}\n""" # noqa: E501 - ) - - return corresponding_files_dict +def _total_score_transformation(df, header): + df = df.replace(float(header["Missing value code"][0]), np.NaN) + df.set_index(keys="Score", inplace=True) + return df +# pylint: disable=too-many-arguments,too-many-locals # enter directory / read total_scores files / call plotting pipeline # pylint: disable=pointless-string-statement,too-many-arguments,too-many-locals def _total_scores_pipeline( - params_dict, - plot_scores, - plot_params, - plot_cat_scores, - plot_cat_params, - plot_cat_thresh, + plot_setup, + lt_ranges, file_prefix, file_postfix, input_dir, output_dir, - season, - model_version, - grid, debug, ) -> None: # pylint: disable=line-too-long @@ -114,8 +59,7 @@ def _total_scores_pipeline( file_postfix (str): postfix of files (i.e. '.dat') input_dir (str): directory to seasons (i.e. /scratch/osm/movero/wd) output_dir (str): output directory (i.e. plots/) - season (str): season of interest (i.e. 2021s4/) - model_version (str): model_version of interest (i.e. C-1E_ch) + model_versions (str): model_versions of interest (i.e. C-1E_ch) scores (list): list of scores, for which plots should be generated debug (bool): print further comments & debug statements @@ -123,52 +67,47 @@ def _total_scores_pipeline( # pylint: enable=line-too-long print("\n--- initialising total scores pipeline") # tmp; define debug = True, to show debug statements for total_scores only - debug = True - - source_path = Path(f"{input_dir}/{season}/{model_version}") - for parameter in params_dict: - corresponding_files_dict = collect_relevant_files( - file_prefix, file_postfix, debug, source_path, parameter - ) - - # pass dict to plotting pipeline - _generate_total_scores_plot( - data=corresponding_files_dict, - parameter=parameter, - plot_scores=plot_scores, - plot_params=plot_params, - plot_cat_scores=plot_cat_scores, - plot_cat_params=plot_cat_params, - plot_cat_thresh=plot_cat_thresh, - # TODO: add ens params/thresh/scores - output_dir=output_dir, - grid=grid, - debug=debug, - ) + # debug = True + for model_plots in plot_setup["model_versions"]: + for parameter, scores in plot_setup["parameter"].items(): + model_data = {} + model_data = load_relevant_files( + input_dir, + file_prefix, + file_postfix, + debug, + model_plots, + parameter, + lt_ranges, + ltr_first=False, + transform_func=_total_score_transformation, + ) + if not model_data: + print(f"No matching files found with given ltr {lt_ranges}") + return + _generate_total_scores_plots( + plot_scores=scores, + models_data=model_data, + parameter=parameter, + output_dir=output_dir, + debug=debug, + ) # PLOTTING PIPELINE FOR TOTAL SCORES PLOTS - - def _set_ylim(param, score, ax, debug): # pylint: disable=unused-argument # define limits for yaxis if available regular_param = (param, "min") in total_score_range.columns - regular_score = score in total_score_range.index + regular_scores = score in total_score_range.index - if regular_param and regular_score: + if regular_param and regular_scores: lower_bound = total_score_range[param]["min"].loc[score] upper_bound = total_score_range[param]["max"].loc[score] - # if debug: - # print( - # f"found limits for {param}/{score} --> {lower_bound}/{upper_bound}" - # ) if lower_bound != upper_bound: ax.set_ylim(lower_bound, upper_bound) - # TODO: add computation of y-lims for cat & ens scores - -def _customise_ax(parameter, score, x_ticks, grid, ax): +def _customise_ax(parameter, scores, x_ticks, grid, ax): """Apply cosmetics to current ax. Args: @@ -186,289 +125,219 @@ def _customise_ax(parameter, score, x_ticks, grid, ax): ax.tick_params(axis="both", which="major", labelsize=8) ax.tick_params(axis="both", which="minor", labelsize=6) - ax.set_title(f"{parameter}: {score}") + ax.set_title(f"{parameter}: {','.join(scores)}") ax.set_xlabel("Lead-Time Range (h)") - ax.legend(fontsize=6) - ax.set_xticks(range(len(x_ticks)), x_ticks) + # plotting too many data on the x-axis + steps = len(x_ticks) // 5 + skip_indices = slice(None, None, steps) if steps > 0 else slice(None) + ax.set_xticks(range(len(x_ticks))[skip_indices], x_ticks[skip_indices]) + ax.autoscale(axis="y") -def _clear_empty_axes(subplot_axes, idx): +def _clear_empty_axes_if_necessary(subplot_axes, idx): # remove empty ``axes`` instances - i = 1 - while (idx % 4 + i) < 4: - ax = subplot_axes[idx % 4 + i] - ax.axis("off") - i += 1 + if (idx + 1) % 4 != 0: + for ax in subplot_axes[(idx + 1) % 4 :]: + ax.axis("off") -def _save_figure(output_dir, filename): - print(f"---\t\tsaving: {output_dir}/{filename[:-1]}.png") - plt.savefig(f"{output_dir}/{filename[:-1]}.png") - plt.clf() +def _initialize_plots(lines: list[Line2D], labels: list): + fig, ((ax0, ax1), (ax2, ax3)) = plt.subplots( + nrows=2, ncols=2, tight_layout=True, figsize=(10, 10), dpi=200 + ) + fig.legend( + lines, + labels, + loc="upper right", + ncol=1, + frameon=False, + ) + plt.tight_layout(w_pad=8, h_pad=3, rect=(0.05, 0.05, 0.90, 0.90)) + return fig, [ax0, ax1, ax2, ax3] + + +def _save_figure(output_dir, filename, title, fig, axes, idx): + fig.suptitle( + title, + horizontalalignment="center", + verticalalignment="top", + fontdict={ + "size": 6, + "color": "k", + }, + bbox={"facecolor": "none", "edgecolor": "grey"}, + ) + _clear_empty_axes_if_necessary(axes, idx) + fig.savefig(f"{output_dir}/{filename[:-1]}.png") + plt.close() -# pylint: disable=too-many-branches,too-many-statements -# pylint: disable=too-many-arguments,too-many-locals -def _generate_total_scores_plot( - data, - parameter, - plot_params, - plot_scores, - plot_cat_scores, - plot_cat_params, - plot_cat_thresh, +def _plot_and_save_scores( output_dir, - grid, # pylint: disable=unused-argument - debug, + base_filename, + parameter, + plot_scores_setup, + sup_title, + models_data, + models_color_lines, + debug=False, ): - """Generate Total Scores Plot.""" - if debug: - print("--- starting plotting pipeline") - print("---\t1) map parameter (i.e. TD_2M_KAL --> TD_2M*)") - - # get correct parameter, i.e. if parameter=T_2M_KAL --> param=T_2M* - param = check_params( - param=parameter, verbose=False - ) # TODO: change False back to debug + filename = base_filename + fig, subplot_axes = _initialize_plots(models_color_lines, models_data.keys()) - if debug: - print("---\t2) check if output_dir exists (& create it if necessary)") - # check (&create) output directory for total scores plots - # output_dir = f"{output_dir}/total_scores" - if not Path(output_dir).exists(): - Path(output_dir).mkdir(parents=True, exist_ok=True) - - if debug: - print("---\t3) initialise figure with a 2x2 subplots grid") - # create 2x2 subplot grid - _, ((ax0, ax1), (ax2, ax3)) = plt.subplots( - nrows=2, ncols=2, tight_layout=True, figsize=(10, 10), dpi=200 + pattern = ( + re.search(r"\(.*?\)", next(iter(plot_scores_setup))[0]) + if plot_scores_setup + else None ) - - subplot_axes = { - 0: ax0, - 1: ax1, - 2: ax2, - 3: ax3, - } # hash map to access correct axes later on - - # ltr_unsorted -> unsorted lead time ranges - # ltr_sorted -> sorted lead time ranges (used for x-tick-labels later on ) - ltr_unsorted, ltr_sorted = list(data.keys()), [] - ltr_start_times_sorted = sorted([int(lt.split("-")[0]) for lt in ltr_unsorted]) - for idx, ltr_start in enumerate(ltr_start_times_sorted): - for ltr in ltr_unsorted: - if ltr.startswith(str(ltr_start).zfill(2)): - ltr_sorted.insert(idx, ltr) - - # re-name & create x_int list, s.t. np.arrays are plottet against each other - x_ticks = ltr_sorted - x_int = list(range(len(ltr_sorted))) - - if debug: - print("---\t4) create x-axis") - print("---\t\tUnsorted ltr list:\t\t{ltr_unsorted}") - print("---\t\tSorted ltr start times list:\t{ltr_start_times_sorted}") - print("---\t\tSorted ltr list (= x-ticks):\t{ltr_sorted}") - print("---\t\tx_int =\t\t\t\t{x_int}") - - # extract header from data & create title - header = data[ltr_sorted[-1]]["header"] - footer = f"""Model: {header['Model version'][0]} | - Period: {header['Start time'][0]} - {header['End time'][0]} | © MeteoSwiss""" - unit = header["Unit"][0] - - # initialise filename - filename = f"total_scores_{parameter}_" - - # scores = params_dict[parameter] # this is a list of lists - - # REGULAR SCORES PLOTTING PIPELINE - if debug: - print( - """---\t5) plot REGULAR parameter/scores - (Because, regular & categorical scores - should not be mixed on the same figure.)""" - ) - - if plot_scores and plot_params: - regular_scores = plot_scores.split(",") - # the idx of a score, maps the score to the corresponding subplot axes instance - for idx, score in enumerate(regular_scores): - # if debug: - # print(f"--- plotting:\t{param}/{score}") - - multiplt = False - - # save filled figure & re-set necessary for next iteration - if idx > 0 and idx % 4 == 0: - # add title to figure - plt.suptitle( - footer, - horizontalalignment="center", - verticalalignment="top", - fontdict={ - "size": 6, - "color": "k", - }, - bbox={"facecolor": "none", "edgecolor": "grey"}, - ) - _save_figure(output_dir=output_dir, filename=filename) - _, ((ax0, ax1), (ax2, ax3)) = plt.subplots( - nrows=2, ncols=2, tight_layout=True, figsize=(10, 10), dpi=200 - ) - subplot_axes = {0: ax0, 1: ax1, 2: ax2, 3: ax3} - # reset filename - filename = f"total_scores_{parameter}_" - + prev_threshold = None + if pattern is not None: + prev_threshold = pattern.group() + current_threshold = prev_threshold + current_plot_idx = 0 + for idx, score_setup in enumerate(plot_scores_setup): + prev_threshold = current_threshold + pattern = re.search(r"\(.*?\)", next(iter(score_setup))) + current_threshold = pattern.group() if pattern is not None else None + different_threshold = prev_threshold != current_threshold + if different_threshold: + _save_figure( + output_dir, filename, sup_title, fig, subplot_axes, current_plot_idx - 1 + ) + fig, subplot_axes = _initialize_plots( + models_color_lines, models_data.keys() + ) + filename = base_filename + current_plot_idx += current_plot_idx % 4 + for model_idx, data in enumerate(models_data.values()): + model_plot_color = plot_settings.modelcolors[model_idx] + # sorted lead time ranges + ltr_sorted = sorted(list(data.keys()), key=lambda x: int(x.split("-")[0])) + x_int = list(range(len(ltr_sorted))) + + # extract header from data & create title + header = data[ltr_sorted[-1]]["header"] + unit = header["Unit"][0] # get ax, to add plot to - ax = subplot_axes[idx % 4] - ax.set_xlim(x_int[0], x_int[-1]) - ax.set_ylabel(f"{score.upper()} ({unit})") - - # plot two scores on one sub-plot - if "/" in score: - multiplt = True - scores = score.split("/") - filename += f"{scores[0]}_{scores[1]}_" - _set_ylim(param=param, score=scores[0], ax=ax, debug=debug) - - # get y0, y1 from dfs - if debug and idx == 0: - print( - f"""---\t6) collect the data corresponding to {score} - from all dataframes in the data dict in y-list""" - ) - print("---\t7) plot y-list against x_int") - y0, y1 = [], [] - for ltr in ltr_sorted: - y0.append(data[ltr]["df"]["Total"].loc[scores[0]]) - y1.append(data[ltr]["df"]["Total"].loc[scores[1]]) - - # plot y0, y1 - ax.plot( - x_int, - y0, - color="red", - linestyle="-", - marker="^", - fillstyle="none", - label=f"{scores[0].upper()}", + ax = subplot_axes[current_plot_idx % 4] + y_label = ",".join(score_setup) + ax.set_ylabel(f"{y_label.upper()} ({unit})") + + if len(score_setup) > 2: + raise ValueError( + f"""Maximum two scores are allowed in one plot. + Got {len(score_setup)}""" ) + for score_idx, score in enumerate(score_setup): + if model_idx == 0: + filename += f"{score}_" + _set_ylim(param=parameter, score=score_setup[0], ax=ax, debug=debug) + y_values = [data[ltr]["df"]["Total"].loc[score] for ltr in ltr_sorted] ax.plot( x_int, - y1, - color="k", - linestyle="-", + y_values, + color=model_plot_color, + linestyle=plot_settings.line_styles[score_idx], marker="D", fillstyle="none", - label=f"{scores[1].upper()}", + label=f"{score_setup[0].upper()}", ) - # plot single score on sub-plot - if not multiplt: - filename += f"{score}_" - _set_ylim(param=param, score=score, ax=ax, debug=debug) - - y = [] - # extract y from different dfs - for ltr in ltr_sorted: - ltr_score = data[ltr]["df"]["Total"].loc[score] - y.append(ltr_score) - - ax.plot( - x_int, - y, - color="k", - linestyle="-", - marker="D", - fillstyle="none", - label=f"{score.upper()}", + # Generate a legend if two scores in one subplot + if len(score_setup) > 1: + sub_plot_legend = ax.legend( + score_setup, + loc="upper right", + markerscale=0.9, + bbox_to_anchor=(1.35, 1.05), ) + # make lines in the legend always black + for line in sub_plot_legend.get_lines(): + line.set_color("black") - # customise grid, title, xticks, legend of current ax - _customise_ax( - parameter=parameter, score=score, x_ticks=x_ticks, grid=True, ax=ax + # customise grid, title, xticks, legend of current ax + _customise_ax( + parameter=parameter, + scores=score_setup, + x_ticks=ltr_sorted, + grid=True, + ax=ax, + ) + + # save filled figure & re-set necessary for next iteration + full_figure = current_plot_idx > 0 and (current_plot_idx + 1) % 4 == 0 + last_plot = idx == len(plot_scores_setup) - 1 + if full_figure or last_plot: + _save_figure( + output_dir, filename, sup_title, fig, subplot_axes, current_plot_idx ) + fig, subplot_axes = _initialize_plots( + models_color_lines, models_data.keys() + ) + filename = base_filename - # save figure, if this is the last score - if idx == len(regular_scores) - 1: - # add title to figure - plt.suptitle( - footer, - horizontalalignment="center", - verticalalignment="top", - fontdict={ - "size": 6, - "color": "k", - }, - bbox={"facecolor": "none", "edgecolor": "grey"}, - ) - # clear empty subplots - _clear_empty_axes(subplot_axes=subplot_axes, idx=idx) - # save & clear figure - _save_figure(output_dir=output_dir, filename=filename) + current_plot_idx += 1 - # CATEGORICAL SCORES PLOTTING PIPELINE - # remark: include thresholds for categorical scores - if debug: - print( - """---\t10) repeat plotting pipeline for categorical - params/scores/thresh combinations""" - ) - print(plot_cat_params) - print(plot_cat_scores) - print(plot_cat_thresh) - print(plot_cat_params and plot_cat_scores and plot_cat_thresh) - - if plot_cat_params and plot_cat_scores and plot_cat_thresh: - print("--- should now create total scores plots for all cat params/scores") - cat_params = plot_cat_params.split( - "," - ) # pylint: disable=pointless-string-statement - """ - categorical parameters: - # TOT_PREC12,TOT_PREC6,TOT_PREC1,CLCT, - # T_2M,T_2M_KAL,TD_2M,TD_2M_KAL,FF_10M, - # FF_10M_KAL,VMAX_10M6,VMAX_10M1 - """ - cat_scores = plot_cat_scores.split( - "," - ) # categorical scores: FBI,MF,POD,FAR,THS,ETS - cat_threshs = plot_cat_thresh.split(":") # categorical thresholds: - # 0.1,1,10:0.2,1,5:0.2,0.5,2:2.5,6.5:0,15,25:0,15, - # 25:-5,5,15:-5,5,15:2.5,5,10:2.5,5,10:5,12.5,20:5,12.5,20 # noqa: E501 - cat_params_dict = {cat_param: [] for cat_param in cat_params} - for param, threshs in zip(cat_params, cat_threshs): - # first append all scores w/o thresholds to parameter - for score in plot_scores.split(","): - if "/" in score: - cat_params_dict[param].append(score.split("/")) - else: - cat_params_dict[param].append([score]) - # append all scores with threshold in their name # noqa: E501 - thresholds = threshs.split(",") - for threshold in thresholds: - for score in cat_scores: - if "/" in score: - cat_params_dict[param].append( - [x + f"({threshold})" for x in score.split("/")] - ) - - else: - cat_params_dict[param].append([f"{score}({threshold})"]) - - if True: # pylint: disable=using-constant-test - print("Categorical Parameter Dict: ") - pprint(cat_params_dict) - - # TODO: implement the total scores pipeline for categorical scores as well - - # ENSEMBLE SCORES PLOTTING PIPELINE - # remark: include thresholds for categorical scores - print( - """---\t11) repeat plotting pipeline for ensemble - params/scores/thresh combinations""" +def _generate_total_scores_plots( + plot_scores, + models_data, + parameter, + output_dir, + debug, +): + """Generate Total Score Plots.""" + model_plot_colors = plot_settings.modelcolors + model_versions = list(models_data.keys()) + custom_lines = [ + Line2D([0], [0], color=model_plot_colors[i], lw=2) + for i in range(len(model_versions)) + ] + + # initialise filename + base_filename = ( + f"total_scores_{model_versions[0]}_{parameter}_" + if len(model_versions) == 1 + else f"total_scores_{parameter}_" + ) + headers = [ + data[sorted(list(data.keys()), key=lambda x: int(x.split("-")[0]))[-1]][ + "header" + ] + for data in models_data.values() + ] + total_start_date, total_end_date = get_total_dates_from_headers(headers) + + model_info = ( + "" + if len(model_versions) > 1 + else f"Model: {headers[0]['Model version'][0]} | \n" + ) + # pylint: disable=line-too-long + period_info = f"""Period: {total_start_date.strftime("%Y-%m-%d")} - {total_end_date.strftime("%Y-%m-%d")} | © MeteoSwiss""" # noqa: E501 + # pylint: enable=line-too-long + sup_title = model_info + period_info + if debug: + print("Try to generate total score plots.") + # plot regular scores + _plot_and_save_scores( + output_dir, + base_filename, + parameter, + plot_scores["regular_scores"], + sup_title, + models_data, + custom_lines, + debug=False, + ) + # plot categorial scores + _plot_and_save_scores( + output_dir, + base_filename, + parameter, + plot_scores["cat_scores"], + sup_title, + models_data, + custom_lines, + debug=False, ) - # TODO: implement the total scores pipeline for categorical scores as well diff --git a/src/moveroplot/utils/parse_plot_synop_ch.py b/src/moveroplot/utils/parse_plot_synop_ch.py index bb916f4..ea99ed6 100644 --- a/src/moveroplot/utils/parse_plot_synop_ch.py +++ b/src/moveroplot/utils/parse_plot_synop_ch.py @@ -22,10 +22,11 @@ import pandas as pd verbose = False -path = Path(Path.cwd() / "src/moveroplot/utils/plot_synop_ch") +path = Path(__file__).with_name("plot_synop_ch") +# pylint: disable=unspecified-encoding # open plot_synop_ch file -with open(path, "r") as f: # pylint: disable=unspecified-encoding +with open(path, "r") as f: lines = [line.strip() for line in f.readlines()] # VERIFICATION SCORES; DATAFRAMES FOR SCORE RANGES AND COLOUR TABLE @@ -264,11 +265,9 @@ # ['param1_scores', 'param1_min', 'param1_max', # 'param2_scores', 'param2_min', 'param2_max',...] cat_station_score_range = cat_station_score_range[cat_columns_tmp] - # now that the columns are in the correct order, # create subcolumns (scores, min, max) for each parameter cat_station_score_range.columns = cat_columns # type: ignore - if verbose: print("\n Categorical Station Score Ranges") pprint(cat_station_score_range) diff --git a/tests/test_moveroplot/setup/references.py b/tests/test_moveroplot/setup/references.py new file mode 100644 index 0000000..4881236 --- /dev/null +++ b/tests/test_moveroplot/setup/references.py @@ -0,0 +1,150 @@ +"""Plot_setup references for tests.""" +# Standard library +from typing import Any +from typing import Dict + +DEFAULT_PLOT_SETUP: Dict[str, Any] = { + "model_versions": [["C-1E_ch", "C-1E-CTR_ch"], ["C-2E_alps"]], + "parameter": { + "VMAX_10M6": { + "regular_scores": [], + "cat_scores": [ + ["FBI(5)"], + ["MF(5)", "POD(5)"], + ["FAR(5)"], + ["THS(5)"], + ["ETS(5)"], + ], + "regular_ens_scores": [["OUTLIERS"], ["RANK"], ["RPS"], ["RPS_REF"]], + "ens_cat_scores": [ + ["REL(5)"], + ["RES(5)"], + ["BS(5)"], + ["BS_REF(5)"], + ["BSS(5)"], + ["BSSD(5)"], + ["REL_DIA(5)"], + ], + }, + "TOT_PREC6": { + "regular_scores": [["ME"], ["MMOD", "MOBS"], ["MAE"]], + "cat_scores": [ + ["FBI(0.2)"], + ["MF(0.2)", "POD(0.2)"], + ["FAR(0.2)"], + ["THS(0.2)"], + ["ETS(0.2)"], + ], + "regular_ens_scores": [["OUTLIERS"], ["RANK"], ["RPS"], ["RPS_REF"]], + "ens_cat_scores": [ + ["REL(0.2)"], + ["RES(0.2)"], + ["BS(0.2)"], + ["BS_REF(0.2)"], + ["BSS(0.2)"], + ["BSSD(0.2)"], + ["REL_DIA(0.2)"], + ], + }, + "T_2M": { + "regular_scores": [["ME"], ["MMOD", "MOBS"], ["MAE"]], + "cat_scores": [ + ["FBI(0)"], + ["MF(0)", "POD(0)"], + ["FAR(0)"], + ["THS(0)"], + ["ETS(0)"], + ], + "regular_ens_scores": [["OUTLIERS"], ["RANK"], ["RPS"], ["RPS_REF"]], + "ens_cat_scores": [ + ["REL(0)"], + ["RES(0)"], + ["BS(0)"], + ["BS_REF(0)"], + ["BSS(0)"], + ["BSSD(0)"], + ["REL_DIA(0)"], + ], + }, + "FF_10M": { + "regular_scores": [], + "cat_scores": [ + ["FBI(2.5)"], + ["MF(2.5)", "POD(2.5)"], + ["FAR(2.5)"], + ["THS(2.5)"], + ["ETS(2.5)"], + ], + "regular_ens_scores": [["OUTLIERS"], ["RANK"], ["RPS"], ["RPS_REF"]], + "ens_cat_scores": [ + ["REL(2.5)"], + ["RES(2.5)"], + ["BS(2.5)"], + ["BS_REF(2.5)"], + ["BSS(2.5)"], + ["BSSD(2.5)"], + ["REL_DIA(2.5)"], + ], + }, + "CLCT": { + "regular_scores": [["ME"], ["MMOD", "MOBS"], ["MAE"]], + "cat_scores": [ + ["FBI(2.5)"], + ["MF(2.5)", "POD(2.5)"], + ["FAR(2.5)"], + ["THS(2.5)"], + ["ETS(2.5)"], + ], + "regular_ens_scores": [["OUTLIERS"], ["RANK"], ["RPS"], ["RPS_REF"]], + "ens_cat_scores": [ + ["REL(2.5)"], + ["RES(2.5)"], + ["BS(2.5)"], + ["BS_REF(2.5)"], + ["BSS(2.5)"], + ["BSSD(2.5)"], + ["REL_DIA(2.5)"], + ], + }, + "TOT_PREC12": { + "regular_scores": [["ME"], ["MMOD", "MOBS"], ["MAE"]], + "cat_scores": [ + ["FBI(0.1)"], + ["MF(0.1)", "POD(0.1)"], + ["FAR(0.1)"], + ["THS(0.1)"], + ["ETS(0.1)"], + ], + "regular_ens_scores": [["OUTLIERS"], ["RANK"], ["RPS"], ["RPS_REF"]], + "ens_cat_scores": [ + ["REL(0.1)"], + ["RES(0.1)"], + ["BS(0.1)"], + ["BS_REF(0.1)"], + ["BSS(0.1)"], + ["BSSD(0.1)"], + ["REL_DIA(0.1)"], + ], + }, + "TD_2M": { + "regular_scores": [["ME"], ["MMOD", "MOBS"], ["MAE"]], + "cat_scores": [ + ["FBI(0)"], + ["MF(0)", "POD(0)"], + ["FAR(0)"], + ["THS(0)"], + ["ETS(0)"], + ], + "regular_ens_scores": [["OUTLIERS"], ["RANK"], ["RPS"], ["RPS_REF"]], + "ens_cat_scores": [ + ["REL(0)"], + ["RES(0)"], + ["BS(0)"], + ["BS_REF(0)"], + ["BSS(0)"], + ["BSSD(0)"], + ["REL_DIA(0)"], + ], + }, + }, +} diff --git a/tests/test_moveroplot/setup/test_parse_input.py b/tests/test_moveroplot/setup/test_parse_input.py new file mode 100644 index 0000000..d3cb459 --- /dev/null +++ b/tests/test_moveroplot/setup/test_parse_input.py @@ -0,0 +1,50 @@ +"""UNIT test to test the parse_input function.""" +# Third-party +import pytest +from references import DEFAULT_PLOT_SETUP + +# First-party +from moveroplot.config import plot_settings +from moveroplot.parse_inputs import _parse_inputs + + +# pylint: disable=redefined-outer-name +@pytest.fixture +def mock_input_dir(tmp_path): + d = tmp_path / "input_dir" + d.mkdir() + (d / "C-1E_ch").mkdir() + (d / "C-1E-CTR_ch").mkdir() + (d / "C-2E_alps").mkdir() + return d + + +@pytest.fixture +def test_input_dict(mock_input_dir): + return { + "debug": False, + "input_dir": mock_input_dir, + "model_versions": "C-1E_ch/C-1E-CTR_ch,C-2E_alps", + "plot_params": "TOT_PREC12,TOT_PREC6,CLCT,T_2M,TD_2M", + "plot_scores": "ME,MMOD/MOBS,MAE", + "plot_cat_params": "TOT_PREC12,TOT_PREC6,CLCT,T_2M,TD_2M,FF_10M,VMAX_10M6", + "plot_cat_thresh": "0.1:0.2:2.5:0:0:2.5:5", + "plot_cat_scores": "FBI,MF/POD,FAR,THS,ETS", + "plot_ens_params": "TOT_PREC12,TOT_PREC6,CLCT,T_2M,TD_2M,FF_10M,VMAX_10M6", + "plot_ens_scores": "OUTLIERS/RANK,RPS,RPS_REF", + "plot_ens_cat_params": "TOT_PREC12,TOT_PREC6,CLCT,T_2M,TD_2M,FF_10M,VMAX_10M6", + "plot_ens_cat_thresh": "0.1:0.2:2.5:0:0:2.5:5", + "plot_ens_cat_scores": "REL,RES,BS,BS_REF,BSS,BSSD/REL_DIA", + "plotcolors": "black,orange,blue", + "plot_type": "total,ensemble,station,time,daytime", + } + + +class TestInput: + def test_parse_input(self, test_input_dict): + # Test with valid inputs + result = _parse_inputs(**test_input_dict) + assert result == DEFAULT_PLOT_SETUP, "plot_setup differs." + assert plot_settings.modelcolors == test_input_dict["plotcolors"].split( + "," + ), "color input differs." diff --git a/tests/test_moveroplot/test_cli.py b/tests/test_moveroplot/test_cli.py deleted file mode 100644 index 0c10473..0000000 --- a/tests/test_moveroplot/test_cli.py +++ /dev/null @@ -1,79 +0,0 @@ -"""Test module ``moveroplot``.""" -# Third-party -from click.testing import CliRunner - -# First-party -from moveroplot import cli - - -class _TestCLI: - """Base class to test the command line interface.""" - - def call(self, args=None, *, exit_=0): - runner = CliRunner() - result = runner.invoke(cli.main, args) - assert result.exit_code == exit_ - return result - - -class TestNoCmd(_TestCLI): - """Test CLI without commands.""" - - def test_default(self): - result = self.call() - assert result.output.startswith("Usage: ") - assert "Show this message and exit." in result.output - - def test_help(self): - result = self.call(["--help"]) - assert result.output.startswith("Usage: ") - assert "Show this message and exit." in result.output - - def test_version(self): - result = self.call(["-V"]) - assert cli.__version__ in result.output - - def test_only_arg(self): - result = self.call(["4"]) - assert result.output.strip() == "4" - - def test_two_args(self): - result = self.call(["4", "5"], exit_=2) - assert "Error: No such command '5'." in result.output.split("\n") - - -class TestOneCmd(_TestCLI): - """Test CLI with single command.""" - - def test_add(self): - result = self.call(["4", "plus", "2"]) - assert result.output.strip() == "6" - - def test_subtract(self): - result = self.call(["4", "minus", "2"]) - assert result.output.strip() == "2" - - def test_multiply(self): - result = self.call(["4", "times", "2"]) - assert result.output.strip() == "8" - - def test_divide(self): - result = self.call(["4", "by", "2"]) - assert result.output.strip() == "2" - - -class TestMultCmds(_TestCLI): - """Test CLI with multiple commands.""" - - args = ["4", "plus", "1", "minus", "2", "times", "3", "by", "9"] - - def test_result_only(self): - result = self.call(self.args) - assert result.output.strip() == "1" - - def test_verbose(self): - result = self.call(["-v"] + self.args) - result = result.output.strip().split("\n") - solution = ["4 + 1 = 5", "5 - 2 = 3", "3 * 3 = 9", "9 / 9 = 1", "1"] - assert len(result) == len(solution) - assert result == solution diff --git a/tests/test_moveroplot/test_mutable_number.py b/tests/test_moveroplot/test_mutable_number.py deleted file mode 100644 index aed8b93..0000000 --- a/tests/test_moveroplot/test_mutable_number.py +++ /dev/null @@ -1,38 +0,0 @@ -"""Tests for `moveroplot` package.""" -# First-party -from moveroplot.mutable_number import MutableNumber - - -def test_get(): - n = MutableNumber(3) - assert n.get() == 3 - - -def test_add(): - n = MutableNumber(4).add(2) - assert n.get() == 6 - assert n.history == [4, 6] - - -def test_subtract(): - n = MutableNumber(4).subtract(2) - assert n.get() == 2 - assert n.history == [4, 2] - - -def test_multiply(): - n = MutableNumber(4).multiply(2) - assert n.get() == 8 - assert n.history == [4, 8] - - -def test_divide(): - n = MutableNumber(4).divide(2) - assert n.get() == 2 - assert n.history == [4, 2] - - -def test_chain(): - n = MutableNumber(4).add(1).subtract(2).multiply(3).divide(9) - assert n.get() == 1 - assert n.history == [4, 5, 3, 9, 1] diff --git a/tests/test_moveroplot/test_utils.py b/tests/test_moveroplot/test_utils.py deleted file mode 100644 index 854dc41..0000000 --- a/tests/test_moveroplot/test_utils.py +++ /dev/null @@ -1,14 +0,0 @@ -# pylint: skip-file -"""Test module ``moveroplot/utils.py``.""" -# Standard library -import logging - -# First-party -from moveroplot.utils import count_to_log_level - - -def test_count_to_log_level(): - assert count_to_log_level(0) == logging.ERROR - assert count_to_log_level(1) == logging.WARNING - assert count_to_log_level(2) == logging.INFO - assert count_to_log_level(3) == logging.DEBUG