From e1ece334ff10e8bfae8492f7c49df0fe9f8ea4f8 Mon Sep 17 00:00:00 2001 From: David McClosky Date: Thu, 12 Sep 2024 13:51:27 -0400 Subject: [PATCH] Format code in black style (ruff format) Also includes small number of other lint-style cleanups. --- atmospy/__init__.py | 13 +- atmospy/distributions.py | 17 +- atmospy/rcmod.py | 45 +-- atmospy/rose.py | 137 +++++---- atmospy/stats.py | 72 ++--- atmospy/trends.py | 414 +++++++++++++++++----------- atmospy/utils.py | 104 +++---- examples/calendar_by_day.py | 17 +- examples/calendar_by_hour.py | 18 +- examples/diel_by_weekend_weekday.py | 17 +- examples/dielplot.py | 14 +- examples/pollution_rose.py | 15 +- examples/regression.py | 8 +- examples/rose_by_month.py | 34 ++- tests/test_relational.py | 19 +- tests/test_rose.py | 18 +- tests/test_stats.py | 23 +- tests/test_trends.py | 49 ++-- tests/test_utils.py | 46 ++-- 19 files changed, 615 insertions(+), 465 deletions(-) diff --git a/atmospy/__init__.py b/atmospy/__init__.py index 3bcd565..bc4ea67 100644 --- a/atmospy/__init__.py +++ b/atmospy/__init__.py @@ -1,21 +1,16 @@ from importlib.metadata import version -# import warnings -# import pandas as pd -# import numpy as np -# import math -# import os +import matplotlib as mpl -from .utils import * from .rcmod import * from .relational import * -from .trends import * from .rose import * from .stats import * +from .trends import * +from .utils import * # Capture the original matplotlib rcParams -import matplotlib as mpl _orig_rc_params = mpl.rcParams.copy() # Determine the atmospy version -__version__ = version('atmospy') +__version__ = version("atmospy") diff --git a/atmospy/distributions.py b/atmospy/distributions.py index 942f5ab..db1ea7e 100644 --- a/atmospy/distributions.py +++ b/atmospy/distributions.py @@ -1,18 +1,7 @@ -from .utils import ( - remove_na, -) -from seaborn import ( - FacetGrid, -) +__all__ = [] -__all__ = ["psdplot", ] - -def psdplot( - data=None, *, - x=None, y=None, row=None, col=None, - **kwargs -): +def psdplot(data=None, *, x=None, y=None, row=None, col=None, **kwargs): """Plot a particle size distribution. Parameters @@ -32,4 +21,4 @@ def psdplot( def bananaplot(): - return \ No newline at end of file + return diff --git a/atmospy/rcmod.py b/atmospy/rcmod.py index 048603b..d796407 100644 --- a/atmospy/rcmod.py +++ b/atmospy/rcmod.py @@ -2,16 +2,24 @@ __all__ = ["set_theme"] -def set_theme(context="notebook", style='ticks', palette='colorblind', - font='sans-serif', font_scale=1., color_codes=True, rc=None): + +def set_theme( + context="notebook", + style="ticks", + palette="colorblind", + font="sans-serif", + font_scale=1.0, + color_codes=True, + rc=None, +): """Change the look and feel of your plots with one simple function. - - This is a simple pass-through function to the Seaborn function of the - same name, but with different default parameters. For complete information + + This is a simple pass-through function to the Seaborn function of the + same name, but with different default parameters. For complete information and a better description that I can provide, please see the Seaborn docs `here `_. - - This mostly passes down to the seaborn function of the same name, but with a + + This mostly passes down to the seaborn function of the same name, but with a few opinions mixed in. Parameters @@ -23,28 +31,31 @@ def set_theme(context="notebook", style='ticks', palette='colorblind', palette : string or sequence, optional Set the color palette, by default 'colorblind' font : string, optional - Set the font family, by default 'sans-serif'. See the + Set the font family, by default 'sans-serif'. See the matplotlib font manager for more information. font_scale : float, optional Independently scale the font size, by default 1 color_codes : bool, optional - If `True`, remap the shorthand color codes assuming you are + If `True`, remap the shorthand color codes assuming you are using a seaborn palette, by default True rc : dict or None, optional - Pass through a dictionary of rc parameter mappings to override + Pass through a dictionary of rc parameter mappings to override the defaults, by default None - + """ - default_rcparams = { - "mathtext.default": "regular" - } - + default_rcparams = {"mathtext.default": "regular"} + if rc is not None: rc = dict(default_rcparams, **rc) else: rc = default_rcparams sns.set_theme( - context=context, style=style, palette=palette, - font=font, font_scale=font_scale, color_codes=color_codes, rc=rc + context=context, + style=style, + palette=palette, + font=font, + font_scale=font_scale, + color_codes=color_codes, + rc=rc, ) diff --git a/atmospy/rose.py b/atmospy/rose.py index 298cbe9..168382c 100644 --- a/atmospy/rose.py +++ b/atmospy/rose.py @@ -1,32 +1,47 @@ """This file contains the wind and pollution rose figures.""" -import seaborn as sns -import pandas as pd -import numpy as np -import matplotlib.pyplot as plt + import math -from .utils import ( - check_for_numeric_cols, -) + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import seaborn as sns + +from .utils import check_for_numeric_cols # Turn off chained assignment warnings pd.options.mode.chained_assignment = None __all__ = ["pollutionroseplot"] -def pollutionroseplot(data=None, *, ws=None, wd=None, pollutant=None, - faceted=False, segments=12, bins=[0, 10, 100, 1000], suffix="a.u.", - calm=0., lw=1, legend=True, palette="flare", - title=None, **kwargs): + +def pollutionroseplot( + data=None, + *, + ws=None, + wd=None, + pollutant=None, + faceted=False, + segments=12, + bins=[0, 10, 100, 1000], + suffix="a.u.", + calm=0.0, + lw=1, + legend=True, + palette="flare", + title=None, + **kwargs, +): """Plot the intensity and directionality of a variable on a traditional wind-rose plot. This function is a modified version of `Phil Hobson's work `_. - Traditionally, a wind rose plots wind speed and direction so that you can see from what - direction is the wind coming from and at what velocity. For air quality purposes, we - often wonder whether or not there is directionality to the intensity of a certain - air pollutant. Well, look no further. This plot allows you to easily visualize the + Traditionally, a wind rose plots wind speed and direction so that you can see from what + direction is the wind coming from and at what velocity. For air quality purposes, we + often wonder whether or not there is directionality to the intensity of a certain + air pollutant. Well, look no further. This plot allows you to easily visualize the directionality of a pollutant. - + Parameters ---------- data : :class:`pandas.DataFrame` @@ -44,13 +59,13 @@ def pollutionroseplot(data=None, *, ws=None, wd=None, pollutant=None, , by default 12 bins : list or array of floats, optional An array of floats corresponding to the bin boundaries - for `pollutant`; if the last entry is not inf, it will be + for `pollutant`; if the last entry is not inf, it will be automatically added, by default [0, 10, 100, 1000] suffix : str, optional The suffix (or units) to use on the labels for `pollutant`, by default "a.u." calm : float, optional - Set the definition of calm conditions; data under calm winds - will not be used to compute the statistics and + Set the definition of calm conditions; data under calm winds + will not be used to compute the statistics and will be shown on the plot as blank in the center, by default 0. lw : int, optional Set the line width, by default 1 @@ -60,34 +75,34 @@ def pollutionroseplot(data=None, *, ws=None, wd=None, pollutant=None, Select the color palette to use, by default "flare" title : str, optional Set the figure title, by default None - + Returns ------- :class:`matplotlib.axes._axes.Axes` - + Examples -------- Using defaults, plot the pollution rose for PM2.5: - + >>> df = atmospy.load_dataset("air-sensors-met") >>> atmospy.pollutionroseplot(data=df, ws="ws", wd="wd", pollutant="pm25") - + """ check_for_numeric_cols(data, [ws, wd, pollutant]) # if the bins don't end in inf, add it if not np.isinf(bins[-1]): bins.append(np.inf) - - # + + # bins = np.asarray(bins) - + # setup the color palette - cp = sns.color_palette(palette, n_colors=bins.shape[0]-1) - + cp = sns.color_palette(palette, n_colors=bins.shape[0] - 1) + # convert the number of segments into the wind bins - wd_segments = np.linspace(0, 360, segments+1) - + wd_segments = np.linspace(0, 360, segments + 1) + def _cat_pollutant_labels(bins, suffix): """_summary_ @@ -104,30 +119,30 @@ def _cat_pollutant_labels(bins, suffix): list_of_labels.append(f">{lowerbound:.0f} {suffix}") else: list_of_labels.append(f"{lowerbound:.0f} to {upperbound:.0f} {suffix}") - + return list_of_labels - + def _compute_bar_dims(thetas): - thetas = (thetas[:-1] + thetas[1:]) / 2. - + thetas = (thetas[:-1] + thetas[1:]) / 2.0 + midpoints = [math.radians(theta) for theta in thetas] - width = math.radians(360./thetas.shape[0]) - + width = math.radians(360.0 / thetas.shape[0]) + return midpoints, width - + # compute the percentage of calm datapoints # where calm is anything with a windspeed < `calm` - pct_calm = 100*data[data[ws] <= calm].shape[0] / data.shape[0] - + pct_calm = 100 * data[data[ws] <= calm].shape[0] / data.shape[0] + # group the data by bin and normalize rv = ( data[data[ws] > calm] .assign( _cp=lambda x: pd.cut( - data[pollutant], - bins=bins, + data[pollutant], + bins=bins, right=True, - labels=_cat_pollutant_labels(bins, suffix) + labels=_cat_pollutant_labels(bins, suffix), ) ) .assign( @@ -135,35 +150,35 @@ def _compute_bar_dims(thetas): data[wd], bins=wd_segments, right=True, - labels=(wd_segments[:-1]+wd_segments[1:])/2. + labels=(wd_segments[:-1] + wd_segments[1:]) / 2.0, ) ) .groupby(["_cp", "_wb"]) .size() .unstack(level="_cp") - .fillna(0.) + .fillna(0.0) .sort_index(axis=1) .applymap(lambda x: 100 * x / data.shape[0]) ) - + # compute the bar dims bar_midpoints, bar_width = _compute_bar_dims(wd_segments) - + # if plotting onto a FacetGrid, get the current axis, otherwise create one if faceted: ax = plt.gca() else: fig = plt.gcf() - ax = fig.add_subplot(111, projection='polar') - + ax = fig.add_subplot(111, projection="polar") + ax.set_theta_direction("clockwise") ax.set_theta_zero_location("N") - + # determine the buffer at the center of the plot - # this is where ws <= `calm` and is evenly spread + # this is where ws <= `calm` and is evenly spread # across all angles buffer = pct_calm / segments - + for i, (innerbar, outerbar) in enumerate(zip(rv.columns[:-1], rv.columns[1:])): if i == 0: # for the first bar, we need to plot the calm hole in the center first @@ -174,9 +189,9 @@ def _compute_bar_dims(thetas): bottom=buffer, label=innerbar, lw=lw, - color=cp[i] + color=cp[i], ) - + ax.bar( bar_midpoints, rv[outerbar].values, @@ -184,23 +199,23 @@ def _compute_bar_dims(thetas): label=outerbar, bottom=buffer + rv.cumsum(axis=1)[innerbar].values, lw=lw, - color=cp[i+1] + color=cp[i + 1], ) - + if legend: ax.legend( loc="center left", - handlelength=1, + handlelength=1, handleheight=1, - bbox_to_anchor=(1.1, 0, 0.5, 1) + bbox_to_anchor=(1.1, 0, 0.5, 1), ) - + if title: ax.set_title(title) - + # clean up the ticks and things ax.set_xticks([math.radians(x) for x in [0, 45, 90, 135, 180, 225, 270, 315]]) ax.set_xticklabels(["N", "NE", "E", "SE", "S", "SW", "W", "NW"]) ax.set_yticklabels([]) - - return ax \ No newline at end of file + + return ax diff --git a/atmospy/stats.py b/atmospy/stats.py index f62dbfc..a4f50d4 100644 --- a/atmospy/stats.py +++ b/atmospy/stats.py @@ -1,18 +1,21 @@ """Utility functions for internal use.""" +from dataclasses import asdict, dataclass + import numpy as np import pandas as pd from scipy.stats import linregress -from dataclasses import dataclass, asdict __all__ = [ - "fleet_precision", - "air_sensor_stats", + "fleet_precision", + "air_sensor_stats", ] + @dataclass class SensorStatsResults: """""" + slope: float intercept: float pearson_r2: float @@ -20,36 +23,42 @@ class SensorStatsResults: rmse: float nrmse: float nobs: int - + asdict = asdict + def _error(actual: np.ndarray, predicted: np.ndarray): return actual - predicted + def mae(actual: np.ndarray, predicted: np.ndarray): return np.mean(np.abs(_error(actual, predicted))) + def mse(actual: np.ndarray, predicted: np.ndarray): return np.mean(np.square(_error(actual, predicted))) + def rmse(actual: np.ndarray, predicted: np.ndarray): return np.sqrt(mse(actual, predicted)) + def nrmse(actual: np.ndarray, predicted: np.ndarray): return rmse(actual, predicted) / np.mean(actual) + def fleet_precision(data: pd.DataFrame): """Compute the precision across a fleet of at least three (3) devices. - - The math used here comes from the `EPA's Air Sensor Performance Targets + + The math used here comes from the `EPA's Air Sensor Performance Targets and Testing Protocols guidelines `_. Parameters ---------- data : pd.DataFrame - A dataframe containing a wide-form dataframe as a time series where the - index is a timestamp. Each column should represent the same pollutant - or other measure for a unique device, whether it is an air sensor + A dataframe containing a wide-form dataframe as a time series where the + index is a timestamp. Each column should represent the same pollutant + or other measure for a unique device, whether it is an air sensor or other instrument. Returns @@ -58,28 +67,27 @@ def fleet_precision(data: pd.DataFrame): Returns the standard deviation and coeficient of variation. """ # drop any record with a NaN present - data = data.dropna(how='any') - + data = data.dropna(how="any") + # ensure there are at least 3 devices if data.shape[1] < 3: - raise ValueError(f"You must have at least three columns; you provided {data.shape[1]}.") - + raise ValueError( + f"You must have at least three columns; you provided {data.shape[1]}." + ) + # compute the standard deviation - sum_of_squares = ( - data.sub(data.mean(axis=1).values, axis=0)**2 - ).values.sum() - - stdev = np.sqrt( - ((1 / (data.shape[0]*data.shape[1] - 1))) * sum_of_squares - ) - + sum_of_squares = (data.sub(data.mean(axis=1).values, axis=0) ** 2).values.sum() + + stdev = np.sqrt((1 / (data.shape[0] * data.shape[1] - 1)) * sum_of_squares) + # compute the coefficient of variation cv = stdev / data.values.mean() - + return stdev, cv + def air_sensor_stats(actual: np.ndarray, predicted: np.ndarray): - """Compute the statistical measures required by EPA for + """Compute the statistical measures required by EPA for Air Sensor NSIM evaluation per their guidebooks. Parameters @@ -88,28 +96,28 @@ def air_sensor_stats(actual: np.ndarray, predicted: np.ndarray): An array of numeric types with the reference values (i.e., y_true in sklearn language). predicted : np.ndarray An array of numeric types with the air sensor values (i.e., y_pred in sklearn language). - + Returns ------- results: SensorStatsResults - An instance of the SensorStatsResults dataclass with + An instance of the SensorStatsResults dataclass with fit data results including slope, intercept, MAE, RMSE, NRMSE, NOBS, and Pearson-R2. - + """ # force to arrays actual = np.asarray(actual) predicted = np.asarray(predicted) - + if np.isnan(actual).any(): raise ValueError("You cannot have NaN's present in your `actual` array.") - + if np.isnan(predicted).any(): raise ValueError("You cannot have NaN's present in your `predicted` array.") - + # fit the data to a linear model fit = linregress(predicted, actual) - + return SensorStatsResults( fit.slope, fit.intercept, @@ -117,5 +125,5 @@ def air_sensor_stats(actual: np.ndarray, predicted: np.ndarray): mae(predicted, actual), rmse(predicted, actual), nrmse(predicted, actual), - actual.shape[0] - ) \ No newline at end of file + actual.shape[0], + ) diff --git a/atmospy/trends.py b/atmospy/trends.py index 339f144..10752c0 100644 --- a/atmospy/trends.py +++ b/atmospy/trends.py @@ -1,214 +1,268 @@ +import math + import matplotlib as mpl import matplotlib.pyplot as plt -import seaborn as sns -import pandas as pd import numpy as np -import math -from .utils import ( - check_for_numeric_cols, - check_for_timestamp_col -) +import pandas as pd + +from .utils import check_for_numeric_cols, check_for_timestamp_col # Turn off chained assignment warnings pd.options.mode.chained_assignment = None __all__ = ["dielplot", "calendarplot"] + @mpl.ticker.FuncFormatter def custom_month_formatter(x, pos): return str(math.ceil(x)) -def _yearplot(data, x, y, ax=None, agg="mean", cmap="crest", - height=2, aspect=5, vmin=None, vmax=None, - linecolor="white", linewidths=0, cbar=True, cbar_kws=None, - units="", faceted=False, **kwargs): - """Plot a full year of time series data on a heatmap by month. - """ + +def _yearplot( + data, + x, + y, + ax=None, + agg="mean", + cmap="crest", + height=2, + aspect=5, + vmin=None, + vmax=None, + linecolor="white", + linewidths=0, + cbar=True, + cbar_kws=None, + units="", + faceted=False, + **kwargs, +): + """Plot a full year of time series data on a heatmap by month.""" if ax is None: ax = plt.gca() - + if not faceted: - ax.figure.set_size_inches(height*aspect, height) - + ax.figure.set_size_inches(height * aspect, height) + # if more than 1Y of data was provided, limit to 1Y years = np.unique(data.index.year) if years.size > 1: # warn data = data[data.index.year == years[0]] - + data.loc[:, "Day of Week"] = data.index.weekday data.loc[:, "Week of Year"] = data.index.isocalendar().week - + # compute pivoted data pivot = data.pivot_table( - index="Day of Week", - columns="Week of Year", - values=y, - aggfunc=agg + index="Day of Week", columns="Week of Year", values=y, aggfunc=agg ) - + # adjust the index to ensure we have a properly-sized array - pivot = pivot.reindex( - index=range(0, 7), - columns=range(1, 53) - ) - + pivot = pivot.reindex(index=range(0, 7), columns=range(1, 53)) + # reverse the array along the yaxis so that Monday ends up at the top of the fig pivot = pivot[::-1] - + # set the min and max of the colorbar if vmin is None: vmin = np.nanmin(pivot.values) - + if vmax is None: vmax = np.nanmax(pivot.values) - + # plot the heatmap im = ax.pcolormesh( - pivot, - vmin=vmin, vmax=vmax, cmap=cmap, - linewidth=linewidths, edgecolors=linecolor + pivot, + vmin=vmin, + vmax=vmax, + cmap=cmap, + linewidth=linewidths, + edgecolors=linecolor, ) - + # modify the axes ticks ax.xaxis.set_major_locator(mpl.ticker.LinearLocator(14)) - ax.xaxis.set_ticklabels([ - "", - "Jan", "Feb", "Mar", "Apr", "May", "Jun", - "Jul", "Aug", "Sep", "Oct", "Nov", "Dec" - "" - ], rotation="horizontal", va="center") - + ax.xaxis.set_ticklabels( + [ + "", + "Jan", + "Feb", + "Mar", + "Apr", + "May", + "Jun", + "Jul", + "Aug", + "Sep", + "Oct", + "Nov", + "Dec" "", + ], + rotation="horizontal", + va="center", + ) + ax.yaxis.tick_right() ax.yaxis.set_ticks([0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5]) - ax.yaxis.set_ticklabels([ - "Sun", "Sat", "Fri", "Thu", "Wed", "Tue", "Mon" - ], rotation="horizontal", va="center", ha="left") + ax.yaxis.set_ticklabels( + ["Sun", "Sat", "Fri", "Thu", "Wed", "Tue", "Mon"], + rotation="horizontal", + va="center", + ha="left", + ) ax.yaxis.set_tick_params(right=False) - + # add a big ol' year on the left-hand side - ax.set_ylabel( - f"{years[0]}", - fontsize=28, color="gray", ha="center" - ) - + ax.set_ylabel(f"{years[0]}", fontsize=28, color="gray", ha="center") + # add a colorbar if set if cbar: cbar_kws["pad"] = cbar_kws.get("pad", 0.05) cb = ax.figure.colorbar(im, ax=ax, **cbar_kws) cb.outline.set_visible(False) - + # adjust the colorbar ticklabels cb.ax.yaxis.set_major_locator(mpl.ticker.MaxNLocator(4)) - + # modify the tick labels # TODO: this is currently not working ticklabels = [x.get_text() for x in cb.ax.get_yticklabels()] ticklabels[-1] = f"{ticklabels[-1]} {units}" cb.set_ticks(cb.get_ticks()) cb.set_ticklabels(ticklabels) - + return ax - -def _monthplot(data, x, y, ax=None, agg="mean", height=3, aspect=1, - vmin=None, vmax=None, cmap="crest", linewidths=0.1, - linecolor="white", cbar=True, cbar_kws=None, - units=None, faceted=False, **kwargs): - """Plot a full month of time series data on a heatmap by hour. - """ + +def _monthplot( + data, + x, + y, + ax=None, + agg="mean", + height=3, + aspect=1, + vmin=None, + vmax=None, + cmap="crest", + linewidths=0.1, + linecolor="white", + cbar=True, + cbar_kws=None, + units=None, + faceted=False, + **kwargs, +): + """Plot a full month of time series data on a heatmap by hour.""" if ax is None: ax = plt.gca() - + if not faceted: - ax.figure.set_size_inches(height*aspect, height) - + ax.figure.set_size_inches(height * aspect, height) + # if more than 1mo of data was provided, limit to 1mo months = np.unique(data.index.month) if months.size > 1: # TODO: log warning data = data[data.index.month == months[0]] - + # add pivot columns data.loc[:, "Day of Month"] = data.index.day data.loc[:, "Hour of Day"] = data.index.hour - + # compute the pivot data pivot = data.pivot_table( - index="Hour of Day", - columns="Day of Month", - values=y, - aggfunc=agg + index="Hour of Day", columns="Day of Month", values=y, aggfunc=agg ) - + # get the total number of available days in the month days_in_month = data.index.days_in_month[0] - + # adjust the index to ensure we have a properly-sized array - pivot = pivot.reindex( - index=range(0, 24), - columns=range(1, days_in_month + 1) - ) - + pivot = pivot.reindex(index=range(0, 24), columns=range(1, days_in_month + 1)) + # reverse the order of the matrix along the y-axis so that Monday is at the top pivot = pivot[::-1] - + # set the min and max values for the colorbar if vmin is None: vmin = np.nanmin(pivot.values) - + if vmax is None: vmax = np.nanmax(pivot.values) - + # plot the heatmap im = ax.pcolormesh( pivot, - cmap=cmap, vmin=vmin, vmax=vmax, - linewidth=linewidths, edgecolors=linecolor + cmap=cmap, + vmin=vmin, + vmax=vmax, + linewidth=linewidths, + edgecolors=linecolor, ) - + # add a colorbar if set if cbar: cb = ax.figure.colorbar(im, ax=ax, **cbar_kws) cb.outline.set_visible(False) - + # adjust the tick labels ticklabels = [x.get_text() for x in cb.ax.get_yticklabels()] ticklabels[-1] = f"{ticklabels[-1]} {units}" cb.set_ticks(cb.get_ticks()) cb.set_ticklabels(ticklabels) - + # adjust the axes labels - ax.xaxis.set_major_locator(mpl.ticker.FixedLocator([x - 0.5 for x in list(range(1, days_in_month, 4))])) + ax.xaxis.set_major_locator( + mpl.ticker.FixedLocator([x - 0.5 for x in list(range(1, days_in_month, 4))]) + ) ax.xaxis.set_major_formatter(custom_month_formatter) ax.set_yticks([0, 6, 12, 18, 24]) - ax.set_yticklabels([ - "12 AM", "6 PM", "12 PM", "6 AM", "12 AM" - ]) - + ax.set_yticklabels(["12 AM", "6 PM", "12 PM", "6 AM", "12 AM"]) + return ax -def calendarplot(data, x, y, freq="day", agg="mean", vmin=None, vmax=None, cmap="crest", - ax=None, linecolor="white", linewidths=0, cbar=True, cbar_kws=None, - xlabel=None, ylabel=None, title=None, units="", height=2, - aspect=5.0, faceted=False, **kwargs): +def calendarplot( + data, + x, + y, + freq="day", + agg="mean", + vmin=None, + vmax=None, + cmap="crest", + ax=None, + linecolor="white", + linewidths=0, + cbar=True, + cbar_kws=None, + xlabel=None, + ylabel=None, + title=None, + units="", + height=2, + aspect=5.0, + faceted=False, + **kwargs, +): """Visualize data as a heatmap on a monthly or annual basis. - - Calendar plots can be a useful way to visualize trends in data over longer periods - of time. This function is quite generic and allows you to visualize data either by - month (where the x-axis is day of month and y-axis is hour of day) or year (where + + Calendar plots can be a useful way to visualize trends in data over longer periods + of time. This function is quite generic and allows you to visualize data either by + month (where the x-axis is day of month and y-axis is hour of day) or year (where x-axis is the week of the year and y-axis is the day of the week). Configure the plot to aggregrate the data any way you choose (e.g., sum, mean, max). - - Currently, you can only plot a single month or single year at a time depending on - configuration. To facet these, please set up a Seaborn FacetGrid and call the + + Currently, you can only plot a single month or single year at a time depending on + configuration. To facet these, please set up a Seaborn FacetGrid and call the calendarplot separately. - - This function is heavily influenced by the `calplot `_ + + This function is heavily influenced by the `calplot `_ python library. - + Parameters ---------- data : :class:`pandas.DataFrame` @@ -251,78 +305,112 @@ def calendarplot(data, x, y, freq="day", agg="mean", vmin=None, vmax=None, cmap= The aspect ratio of the figure, by default 5.0 faceted : bool optional Set to `True` if combining with a FacetGrid, optional - + Returns ------- :class:`matplotlib.axes._axes.Axes` Examples -------- - + Plot a simple heatmap for the entire year. >>> df = atmospy.load_dataset("us-bc") >>> atmospy.calendarplot(df, x="Timestamp GMT", y="Sample Measurement") - + """ check_for_timestamp_col(data, x) check_for_numeric_cols(data, [y]) - + if freq not in ("hour", "day"): raise ValueError("Invalid argument for `freq`") - - cbar_kws_default = { - "shrink": 0.67, - "drawedges": False - } - + + cbar_kws_default = {"shrink": 0.67, "drawedges": False} + if cbar_kws is None: cbar_kws = {} - + cbar_kws = dict(cbar_kws_default, **cbar_kws) - + # select only the data that is needed df = data[[x, y]].copy(deep=True) df = df.set_index(x) - + if freq == "day": ax = _yearplot( - df, x, y, - agg=agg, height=height, aspect=aspect, - vmin=vmin, vmax=vmax, linewidths=linewidths, linecolor=linecolor, - cbar=cbar, cbar_kws=cbar_kws, units=units, cmap=cmap, faceted=faceted, **kwargs + df, + x, + y, + agg=agg, + height=height, + aspect=aspect, + vmin=vmin, + vmax=vmax, + linewidths=linewidths, + linecolor=linecolor, + cbar=cbar, + cbar_kws=cbar_kws, + units=units, + cmap=cmap, + faceted=faceted, + **kwargs, ) elif freq == "hour": ax = _monthplot( - df, x, y, - agg=agg, height=height, aspect=aspect, - vmin=vmin, vmax=vmax, linewidths=linewidths, linecolor=linecolor, - cbar=cbar, cbar_kws=cbar_kws, units=units, cmap=cmap, faceted=faceted, **kwargs + df, + x, + y, + agg=agg, + height=height, + aspect=aspect, + vmin=vmin, + vmax=vmax, + linewidths=linewidths, + linecolor=linecolor, + cbar=cbar, + cbar_kws=cbar_kws, + units=units, + cmap=cmap, + faceted=faceted, + **kwargs, ) - + ax.set_aspect("equal") # remove the spines for spine in ("top", "bottom", "right", "left"): ax.spines[spine].set_visible(False) - + if title: ax.set_title(title) - + if xlabel: ax.set_xlabel(xlabel) - + if ylabel: ax.set_ylabel(ylabel) - + return ax -def dielplot(data=None, *, x=None, y=None, ax=None, ylim=None, xlabel=None, - ylabel=None, title=None, color=None, show_iqr=True, plot_kws=None, **kwargs): +def dielplot( + data=None, + *, + x=None, + y=None, + ax=None, + ylim=None, + xlabel=None, + ylabel=None, + title=None, + color=None, + show_iqr=True, + plot_kws=None, + **kwargs, +): """Plot the diel (e.g., diurnal) trend for a pollutant. - - Diel plots can be incredibly useful for understanding daily + + Diel plots can be incredibly useful for understanding daily patterns of air pollutants. Parameters @@ -350,68 +438,72 @@ def dielplot(data=None, *, x=None, y=None, ax=None, ylim=None, xlabel=None, plot_kws : dict or None, optional Additional keyword arguments are passed directly to the underlying plot call , by default None - + Returns ------- :class:`matplotlib.axes._axes.Axes` - + Examples -------- - + Plot a simple heatmap for the entire year. >>> df = atmospy.load_dataset("us-bc") >>> atmospy.dielplot(data=df, x="Timestamp GMT", y="Sample Measurement") - + """ default_plot_kws = { "lw": 3, } - + # complete some initial data quality checks check_for_timestamp_col(data, x) check_for_numeric_cols(data, [y]) - - # + + # plot_kws = {} if plot_kws is None else dict(default_plot_kws, **plot_kws) if color is not None: plot_kws.update(dict(c=color)) - + # copy over only the needed data _data = data[[x, y]].copy(deep=True) _data = _data.set_index(x) - - # + + # # figure setup if ax is None: ax = plt.gca() - + # compute the diel statistics - stats = _data.groupby([_data.index.hour, _data.index.minute], as_index=False).describe() - + stats = _data.groupby( + [_data.index.hour, _data.index.minute], as_index=False + ).describe() + # append the first record so the first and last records are identical stats.loc[len(stats.index)] = stats.loc[0] - + # build an index we can use to make the figure index = stats.index.values freq = int(60 / ((index.size - 1) / 24)) - figure_index = pd.date_range(start='2020-01-01', periods=index.size, freq=f"{freq}min") - + figure_index = pd.date_range( + start="2020-01-01", periods=index.size, freq=f"{freq}min" + ) + # plot the diel average - ax.plot(figure_index, stats[y]['mean'], **plot_kws) - + ax.plot(figure_index, stats[y]["mean"], **plot_kws) + # add the IQR as a shaded region if show_iqr: ax.fill_between( figure_index, - y1=stats[y]['25%'], - y2=stats[y]['75%'], + y1=stats[y]["25%"], + y2=stats[y]["75%"], alpha=0.25, lw=2, - color=plt.gca().lines[-1].get_color() + color=plt.gca().lines[-1].get_color(), ) - + # adjust plot parameters xticks = ax.get_xticks() ax.set_xticks(np.linspace(xticks[0], xticks[-1], 5)) @@ -422,14 +514,14 @@ def dielplot(data=None, *, x=None, y=None, ax=None, ylim=None, xlabel=None, # add optional labels if xlabel: ax.set_xlabel(xlabel) - + if ylabel: ax.set_ylabel(ylabel) - + if title: ax.set_title(title) - + if ylim: ax.set_ylim(ylim) - - return ax \ No newline at end of file + + return ax diff --git a/atmospy/utils.py b/atmospy/utils.py index 5afcb78..c4e10b7 100644 --- a/atmospy/utils.py +++ b/atmospy/utils.py @@ -1,37 +1,35 @@ """Utility functions for internal use.""" import os -from urllib.request import ( - urlopen, - urlretrieve -) -from seaborn.external.appdirs import ( - user_cache_dir, -) +from urllib.request import urlopen, urlretrieve + import pandas as pd +from seaborn.external.appdirs import user_cache_dir DATASET_SOURCE = "https://raw.githubusercontent.com/dhhagan/atmospy-data/main" DATASET_NAMES_URL = f"{DATASET_SOURCE}/dataset_names.txt" + def get_dataset_names(): """List the available example datasets. - + Requires an internet connection. """ with urlopen(DATASET_NAMES_URL) as resp: txt = resp.read() - + dataset_names = [name.strip() for name in txt.decode().split("\n")] - + return list(filter(None, dataset_names)) + def get_data_home(data_home=None): """Return a path to the cache directory for the sample datasets. - If the ``data_home`` argument is not provided and the `ATMOSPY_DATA` - environment variable is not set, an OS-appropriate folder will be created + If the ``data_home`` argument is not provided and the `ATMOSPY_DATA` + environment variable is not set, an OS-appropriate folder will be created and used. - + Parameters ---------- data_home : Path, optional @@ -39,102 +37,109 @@ def get_data_home(data_home=None): """ if data_home is None: data_home = os.environ.get("ATMOSPY_DATA", user_cache_dir("atmospy")) - + data_home = os.path.expanduser(data_home) if not os.path.exists(data_home): os.makedirs(data_home) - + return data_home + def load_dataset(name, cache=True, data_home=None, **kwargs): """Load an example dataset from the online repository (requires internet). - - This function provides quick access to a number of example datasets that - can be used to either explore the plotting functionality of atmospy or + + This function provides quick access to a number of example datasets that + can be used to either explore the plotting functionality of atmospy or to report issues without needing to upload your own data. - - This function also handles some basic data pre-processing to ensure they + + This function also handles some basic data pre-processing to ensure they are ready-to-go. Parameters ---------- name : str - The name of the dataset. Dataset names can be found on - https://github.com/dhhagan/atmospy-data or by running + The name of the dataset. Dataset names can be found on + https://github.com/dhhagan/atmospy-data or by running the `get_dataset_names` function. cache : bool, optional - If `True`, the dataset will be loaded from local cache if - available and it will save to local cache if it needs to + If `True`, the dataset will be loaded from local cache if + available and it will save to local cache if it needs to be downloaded, by default True data_home : str, optional - The directory to store the cached data; if not set, - it will be determined for your operating system using the + The directory to store the cached data; if not set, + it will be determined for your operating system using the `get_data_home` function, by default None - + Returns ------- df : :class:`pandas.DataFrame` Tabular data. - + """ if not isinstance(name, str): err = ( "This function only accepts strings and the string must be one of the example datasets.", ) raise TypeError(err) - + available_dataset_names = get_dataset_names() if name not in available_dataset_names: - raise ValueError(f"{name} is not a valid option. Please choose one of {available_dataset_names}.") - + raise ValueError( + f"{name} is not a valid option. Please choose one of {available_dataset_names}." + ) + url = f"{DATASET_SOURCE}/{name}.csv" - + if cache: cache_path = os.path.join(get_data_home(data_home), os.path.basename(url)) - + # Check for the existence of a locally cached version if not os.path.exists(cache_path): urlretrieve(url, cache_path) - + full_path = cache_path else: full_path = url - + # Load the data into a DataFrame df = pd.read_csv(full_path, **kwargs) - + # Here is where we can/should place any dataset-dependent modifications if name == "us-ozone": df["Timestamp GMT"] = pd.to_datetime(df["Timestamp GMT"]) - + # conver ozone from ppm to ppb df["Sample Measurement"] *= 1e3 - + # add a column for local time df["Timestamp Local"] = df.apply( - lambda x: x["Timestamp GMT"] + pd.Timedelta(hours=x["GMT Offset"]), axis=1) - + lambda x: x["Timestamp GMT"] + pd.Timedelta(hours=x["GMT Offset"]), axis=1 + ) + if name == "us-bc": df["Timestamp GMT"] = pd.to_datetime(df["Timestamp GMT"]) - + # add a column for local time df["Timestamp Local"] = df.apply( - lambda x: x["Timestamp GMT"] + pd.Timedelta(hours=x["GMT Offset"]), axis=1) - + lambda x: x["Timestamp GMT"] + pd.Timedelta(hours=x["GMT Offset"]), axis=1 + ) + if name == "air-sensors-pm": df["timestamp"] = pd.to_datetime(df["timestamp"]) - + if name == "air-sensors-met": df["timestamp_local"] = pd.to_datetime(df["timestamp_local"]) - + # only keep data after april df = df[df["timestamp_local"] >= "2023-05-01"].copy() - + return df + def remove_na(vec): return + def check_for_timestamp_col(data, col): """Make sure the column is a proper timestamp according to Pandas. @@ -147,7 +152,8 @@ def check_for_timestamp_col(data, col): """ if not pd.core.dtypes.common.is_datetime64_any_dtype(data[col]): raise TypeError(f"Column `{col}` is not a proper timestamp.") - + + def check_for_numeric_cols(data, cols): """Make sure the column(s) is/are numeric. @@ -160,4 +166,6 @@ def check_for_numeric_cols(data, cols): """ for col in cols: if not pd.core.dtypes.common.is_numeric_dtype(data[col]): - raise TypeError(f"Column `{col}` is not numeric. Please convert to a numeric dtype before proceeding.") + raise TypeError( + f"Column `{col}` is not numeric. Please convert to a numeric dtype before proceeding." + ) diff --git a/examples/calendar_by_day.py b/examples/calendar_by_day.py index cd77b16..f84b2a5 100644 --- a/examples/calendar_by_day.py +++ b/examples/calendar_by_day.py @@ -4,24 +4,23 @@ _thumb: .8, .8 """ + import atmospy -import pandas as pd + atmospy.set_theme() # Load the example dataset df = atmospy.load_dataset("us-ozone") # Select a single location -single_site_ozone = df[ - df["Local Site Name"] == df["Local Site Name"].unique()[0] -] +single_site_ozone = df[df["Local Site Name"] == df["Local Site Name"].unique()[0]] atmospy.calendarplot( - data=single_site_ozone, - x="Timestamp Local", - y="Sample Measurement", + data=single_site_ozone, + x="Timestamp Local", + y="Sample Measurement", freq="day", cbar=False, height=2.5, - linewidths=.1 -); \ No newline at end of file + linewidths=0.1, +) diff --git a/examples/calendar_by_hour.py b/examples/calendar_by_hour.py index 649fbf1..732b147 100644 --- a/examples/calendar_by_hour.py +++ b/examples/calendar_by_hour.py @@ -4,26 +4,26 @@ _thumb: .8, .8 """ + import atmospy -import pandas as pd + atmospy.set_theme() # Load the example dataset df = atmospy.load_dataset("us-ozone") # Select a single location -single_site_ozone = df[ - df["Local Site Name"] == df["Local Site Name"].unique()[0] -] +single_site_ozone = df[df["Local Site Name"] == df["Local Site Name"].unique()[0]] atmospy.calendarplot( - data=single_site_ozone, - x="Timestamp Local", - y="Sample Measurement", + data=single_site_ozone, + x="Timestamp Local", + y="Sample Measurement", freq="hour", xlabel="Day of Month", height=4, cmap="flare", - vmin=0, vmax=80, + vmin=0, + vmax=80, title="Ozone in [Month]", -); \ No newline at end of file +) diff --git a/examples/diel_by_weekend_weekday.py b/examples/diel_by_weekend_weekday.py index 5e49e63..6b3f9e7 100644 --- a/examples/diel_by_weekend_weekday.py +++ b/examples/diel_by_weekend_weekday.py @@ -4,8 +4,11 @@ _thumb: .4, .4 """ -import atmospy + import seaborn as sns + +import atmospy + atmospy.set_theme() # load the data @@ -15,12 +18,14 @@ bc_multi_site = bc[bc["Local Site Name"].isin(bc["Local Site Name"].unique()[0:2])] # create a column that sets a bool if the date is a weekend -bc_multi_site.loc[:, "Is Weekend"] = bc_multi_site["Timestamp Local"].dt.day_name().isin(["Saturday", "Sunday"]) +bc_multi_site.loc[:, "Is Weekend"] = ( + bc_multi_site["Timestamp Local"].dt.day_name().isin(["Saturday", "Sunday"]) +) # convert to long-form for faceting bc_long_form = bc_multi_site.melt( - id_vars=["Timestamp Local", "Is Weekend", "Local Site Name"], - value_vars=["Sample Measurement"] + id_vars=["Timestamp Local", "Is Weekend", "Local Site Name"], + value_vars=["Sample Measurement"], ) g = sns.FacetGrid( @@ -34,7 +39,7 @@ g.map_dataframe(atmospy.dielplot, x="Timestamp Local", y="value") # update the y-axis limit to force to zero -g.set(ylim=(0, None), ylabel='Black Carbon') +g.set(ylim=(0, None), ylabel="Black Carbon") # update the titles to take up less space -g.set_titles("{row_name} | Weekend = {col_name}") \ No newline at end of file +g.set_titles("{row_name} | Weekend = {col_name}") diff --git a/examples/dielplot.py b/examples/dielplot.py index c699721..ad08342 100644 --- a/examples/dielplot.py +++ b/examples/dielplot.py @@ -4,23 +4,23 @@ _thumb: .8, .8 """ + import atmospy -import pandas as pd + atmospy.set_theme() # Load the example dataset df = atmospy.load_dataset("us-ozone") # Select a single location -single_site_ozone = df[ - df["Local Site Name"] == df["Local Site Name"].unique()[0] -] +single_site_ozone = df[df["Local Site Name"] == df["Local Site Name"].unique()[0]] # Plot the diel trend atmospy.dielplot( single_site_ozone, - y="Sample Measurement", x="Timestamp Local", + y="Sample Measurement", + x="Timestamp Local", ylabel="$O_3 \; [ppm]$", plot_kws={"c": "g"}, - ylim=(0, None) -) \ No newline at end of file + ylim=(0, None), +) diff --git a/examples/pollution_rose.py b/examples/pollution_rose.py index 56aacdf..b3a2297 100644 --- a/examples/pollution_rose.py +++ b/examples/pollution_rose.py @@ -4,7 +4,9 @@ _thumb: .8, .8 """ + import atmospy + atmospy.set_theme() # Load the example dataset @@ -12,7 +14,12 @@ # Plot a pollution rose example for PM2.5 atmospy.pollutionroseplot( - data=df, wd="wd", ws="ws", pollutant="pm25", - suffix="$µgm^{-3}$", segments=30, calm=0.1, - bins=[0, 8, 15, 25, 35, 100] -) \ No newline at end of file + data=df, + wd="wd", + ws="ws", + pollutant="pm25", + suffix="$µgm^{-3}$", + segments=30, + calm=0.1, + bins=[0, 8, 15, 25, 35, 100], +) diff --git a/examples/regression.py b/examples/regression.py index b6f07e0..fc4e72d 100644 --- a/examples/regression.py +++ b/examples/regression.py @@ -4,7 +4,9 @@ _thumb: .4, .4 """ + import atmospy + atmospy.set_theme() # Load the example dataset @@ -12,8 +14,10 @@ # Plot a pollution rose example for PM2.5 atmospy.regplot( - df, x="Reference", y="Sensor A", + df, + x="Reference", + y="Sensor A", ylim=(0, 60), color="g", # title="Performance of Sensor A vs US EPA FEM Reference" -) \ No newline at end of file +) diff --git a/examples/rose_by_month.py b/examples/rose_by_month.py index a404666..edc1029 100644 --- a/examples/rose_by_month.py +++ b/examples/rose_by_month.py @@ -4,8 +4,11 @@ _thumb: .4, .4 """ -import atmospy + import seaborn as sns + +import atmospy + atmospy.set_theme() # Load the example dataset @@ -15,30 +18,31 @@ met.loc[:, "Month"] = met["timestamp_local"].dt.month_name() # conver to long form data -met_long_form = met.melt(id_vars=["timestamp_local", "Month", "ws", "wd"], value_vars=["pm25"]) +met_long_form = met.melt( + id_vars=["timestamp_local", "Month", "ws", "wd"], value_vars=["pm25"] +) # set up the FacetGrid g = sns.FacetGrid( - data=met_long_form, - col="Month", + data=met_long_form, + col="Month", col_wrap=3, subplot_kws={"projection": "polar"}, - despine=False + despine=False, ) # map the dataframe using the pollutionroseplot function g.map_dataframe( - atmospy.pollutionroseplot, - ws="ws", wd="wd", pollutant="value", - faceted=True, - segments=20, - suffix="$µgm^{-3}$" + atmospy.pollutionroseplot, + ws="ws", + wd="wd", + pollutant="value", + faceted=True, + segments=20, + suffix="$µgm^{-3}$", ) # add the legend and place it where it looks nice g.add_legend( - title="$PM_{2.5}$", - bbox_to_anchor=(.535, 0.2), - handlelength=1, - handleheight=1 -) \ No newline at end of file + title="$PM_{2.5}$", bbox_to_anchor=(0.535, 0.2), handlelength=1, handleheight=1 +) diff --git a/tests/test_relational.py b/tests/test_relational.py index d2e548a..3203226 100644 --- a/tests/test_relational.py +++ b/tests/test_relational.py @@ -1,16 +1,17 @@ """Test the relational plots.""" -import pytest -from atmospy import regplot -import pandas as pd -import matplotlib as mpl -from atmospy import load_dataset + import seaborn as sns +from atmospy import load_dataset, regplot + + def test_scatter_basics(): df = load_dataset("air-sensors-pm") - + ax = regplot( - df, x="Reference", y='Sensor A', + df, + x="Reference", + y="Sensor A", ) - - assert isinstance(ax, sns.axisgrid.JointGrid) \ No newline at end of file + + assert isinstance(ax, sns.axisgrid.JointGrid) diff --git a/tests/test_rose.py b/tests/test_rose.py index 8529960..065f07a 100644 --- a/tests/test_rose.py +++ b/tests/test_rose.py @@ -1,17 +1,15 @@ """Test the relational plots.""" -import pytest -from atmospy import pollutionroseplot -import pandas as pd + import matplotlib as mpl -from atmospy import load_dataset -import seaborn as sns + +from atmospy import load_dataset, pollutionroseplot + def test_scatter_basics(): df = load_dataset("air-sensors-met") - + ax = pollutionroseplot( - data=df, ws="ws", wd='wd', pollutant="pm1", suffix="ppb", - segments=30, calm=0.1 + data=df, ws="ws", wd="wd", pollutant="pm1", suffix="ppb", segments=30, calm=0.1 ) - - assert isinstance(ax, mpl.axes._axes.Axes) \ No newline at end of file + + assert isinstance(ax, mpl.axes._axes.Axes) diff --git a/tests/test_stats.py b/tests/test_stats.py index 3acf375..98966ec 100644 --- a/tests/test_stats.py +++ b/tests/test_stats.py @@ -1,30 +1,33 @@ """Test the statistical functions.""" -import pytest import numpy as np -from atmospy.stats import * +import pytest + from atmospy import load_dataset +from atmospy.stats import air_sensor_stats, fleet_precision + def test_fleet_precision(): df = load_dataset("air-sensors-pm") - + # a ValueError should be raised if fewer than 3 columns are present with pytest.raises(ValueError): fleet_precision(df[["Sensor A", "Sensor B"]]) - + stdev, cv = fleet_precision(df[["Sensor A", "Sensor B", "Sensor C"]]) - + assert cv <= 1.0 + def test_sensor_stats(): df = load_dataset("air-sensors-pm") # Compute the linear fit for a single device with pytest.raises(ValueError): - fit = air_sensor_stats(df["Reference"], df["Sensor A"]) - + air_sensor_stats(df["Reference"], df["Sensor A"]) + df = df[["Reference", "Sensor A"]].dropna() - + stats = air_sensor_stats(df["Reference"], df["Sensor A"]) assert ~np.isnan(stats.slope) @@ -34,7 +37,5 @@ def test_sensor_stats(): assert ~np.isnan(stats.rmse) assert ~np.isnan(stats.nrmse) assert ~np.isnan(stats.nobs) - + assert isinstance(stats.asdict(), dict) - - \ No newline at end of file diff --git a/tests/test_trends.py b/tests/test_trends.py index 605f725..17c08c9 100644 --- a/tests/test_trends.py +++ b/tests/test_trends.py @@ -1,31 +1,42 @@ """Test the trend plots.""" -import pytest -from atmospy import dielplot -import pandas as pd + import matplotlib as mpl -from atmospy import load_dataset +import pandas as pd + +from atmospy import dielplot, load_dataset -def prep_diel_dataset(rs='1min'): + +def prep_diel_dataset(rs="1min"): df = load_dataset("us-ozone") - - single_site_ozone = df[df['Local Site Name'] == df['Local Site Name'].unique()[0]] - + + single_site_ozone = df[df["Local Site Name"] == df["Local Site Name"].unique()[0]] + # Adjust the timezone - single_site_ozone.loc[:, 'Timestamp Local'] = single_site_ozone['Timestamp GMT'].apply(lambda x: x + pd.Timedelta(hours=-7)) - + single_site_ozone.loc[:, "Timestamp Local"] = single_site_ozone[ + "Timestamp GMT" + ].apply(lambda x: x + pd.Timedelta(hours=-7)) + # Resample to {rs}min - single_site_ozone = single_site_ozone.set_index("Timestamp Local").resample(rs).interpolate('linear').reset_index() - + single_site_ozone = ( + single_site_ozone.set_index("Timestamp Local") + .resample(rs) + .interpolate("linear") + .reset_index() + ) + # Adjust to ppb - single_site_ozone['Sample Measurement'] *= 1e3 - + single_site_ozone["Sample Measurement"] *= 1e3 + return single_site_ozone + def test_dielplot_basics(): - df = prep_diel_dataset('15min') - + df = prep_diel_dataset("15min") + ax = dielplot( - df, x="Timestamp Local", y='Sample Measurement', + df, + x="Timestamp Local", + y="Sample Measurement", ) - - assert isinstance(ax, mpl.axes._axes.Axes) \ No newline at end of file + + assert isinstance(ax, mpl.axes._axes.Axes) diff --git a/tests/test_utils.py b/tests/test_utils.py index 254b417..1843ddb 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,19 +1,15 @@ """Test the atmospy utility functions.""" + import tempfile -from urllib.request import urlopen from http.client import HTTPException +from urllib.request import urlopen -import pytest import pandas as pd -from pandas.testing import ( - assert_series_equal, - assert_frame_equal, -) -from atmospy.utils import ( - get_dataset_names, - load_dataset, - DATASET_NAMES_URL -) +import pytest +from pandas.testing import assert_frame_equal + +from atmospy.utils import DATASET_NAMES_URL, get_dataset_names, load_dataset + def _network(t=None, url="https://github.com"): """_summary_ @@ -27,7 +23,7 @@ def _network(t=None, url="https://github.com"): """ if t is None: return lambda x: _network(x, url=url) - + def wrapper(*args, **kwargs): try: f = urlopen(url) @@ -36,7 +32,7 @@ def wrapper(*args, **kwargs): else: f.close() return t(*args, **kwargs) - + return wrapper @@ -44,39 +40,45 @@ def check_load_dataset(name): dataset = load_dataset(name, cache=False) assert isinstance(dataset, pd.DataFrame) + def check_load_cached_dataset(name): with tempfile.TemporaryDirectory() as tmpdir: dataset = load_dataset(name, cache=True, data_home=tmpdir) - + cached_dataset = load_dataset(name, cache=True, data_home=tmpdir) - + assert_frame_equal(dataset, cached_dataset) - + + @_network(url=DATASET_NAMES_URL) def test_get_dataset_names(): names = get_dataset_names() assert names assert "us-ozone" in names - + + @_network(url=DATASET_NAMES_URL) def test_load_datasets(): for name in get_dataset_names(): check_load_dataset(name) - + + @_network(url=DATASET_NAMES_URL) def test_load_cached_dataset_names(): for name in get_dataset_names(): check_load_cached_dataset(name) - + + @_network(url=DATASET_NAMES_URL) def test_load_dataset_string_error(): name = "invalid_name" with pytest.raises(ValueError): load_dataset(name) - + + @_network(url=DATASET_NAMES_URL) def test_load_dataset_type_error(): name = pd.DataFrame() - + with pytest.raises(TypeError): - load_dataset(name) \ No newline at end of file + load_dataset(name)