def detect_drift(self, data_name, arange, max_slope=None, period=None, plot=False):
    """
    Check a data series for drift by fitting a linear trend to it and
    comparing the slope of that trend with a maximum allowed slope.

    The series ``self.data[data_name]`` is sliced with ``arange``, non-finite
    values (NaN, inf) are dropped, and a least-squares line is fitted through
    the remaining values (scipy.signal.detrend). The slope is the rise of
    that line between its first and last point, divided by the number of
    calendar days spanned by ``arange`` (both end days included). When the
    slope exceeds ``max_slope`` a message is printed.

    Parameters
    ----------
    data_name : str
        label of the column in self.data to check for drift
    arange : array of two values
        the range in which to apply the function; used to slice the column
        and to count days, so both values need a ``toordinal()`` method
        (datetime.date, datetime.datetime, pandas Timestamp)
    max_slope : int or float
        the maximum slope a signal is expected to have over a certain period,
        expressed per day; required, although it defaults to None
    period : int
        the period over which a certain slope is allowed. If None (or the
        same object as arange), the whole range is evaluated at once.
        NOTE: evaluation per period is not implemented yet; when an integer
        is given no slope check is executed.
    plot : bool
        if true, a plot is made of the detrended data (red), the fitted
        linear trend (yellow) and the filtered original data (green)

    Returns
    -------
    None

    Raises
    ------
    ValueError
        if max_slope is not given, or if period is given but is not an
        integer. (Earlier versions returned the exception class instead of
        raising it.)
    """
    from scipy import signal
    import numpy as np

    # Validate the arguments before touching the data, so a wrong call fails
    # fast and with an actual exception.
    if max_slope is None:
        raise ValueError('Please specify a maximum slope')
    whole_range = period is None or period is arange
    if not whole_range:
        # bool is a subclass of int, but True/False is not a sensible period
        if isinstance(period, bool) or not isinstance(period, int):
            raise ValueError('period must be an integer')

    series = self.data[data_name][arange[0]:arange[1]].copy()

    # Remove NaNs and infs from the dataset; the least-squares fit inside
    # signal.detrend cannot handle them.
    finite_mask = np.isfinite(series.to_numpy(dtype=float))
    series = series[finite_mask]
    if series.empty:
        print('No finite data points found in the given range; '
              'no drift detection executed.')
        return None

    # signal.detrend returns data minus its linear fit, so subtracting the
    # result from the data gives back the fitted line itself. Both are kept
    # as Series so they share the index of the data when plotted.
    residuals = signal.detrend(series.to_numpy(dtype=float))
    line_segment = series - residuals
    detrended_values = series - line_segment

    if whole_range:
        # Count calendar days through ordinals: differencing the
        # day-of-month breaks as soon as the range crosses a month boundary.
        n_days = arange[1].toordinal() - arange[0].toordinal() + 1
        # Endpoints are kept as floats; truncating them to int would hide
        # any drift smaller than one unit over the range.
        rise = float(line_segment.iloc[-1]) - float(line_segment.iloc[0])
        slope = rise / n_days
        # NOTE(review): only upward drift is flagged; a negative slope never
        # exceeds a positive max_slope - confirm this is intended.
        if slope > max_slope:
            print('Based on the specified maximum slope, a drift was'
                  ' detected with a slope higher than the maximum one. \n'
                  'Slope detected: {}, maximum slope: {}'.format(slope, max_slope))
    else:
        # TODO: evaluate the slope per window of `period` values; until then
        # a given period only disables the whole-range check.
        pass

    if plot:
        # plt (matplotlib.pyplot) is expected to be available at module level
        plt.plot(detrended_values, 'r', line_segment, 'y', series, 'g')

    return None