From d42c71fb79ff2e6c69730f8f8929871052061060 Mon Sep 17 00:00:00 2001
From: Fabien Collas
Date: Tue, 9 Jan 2024 15:50:59 +0100
Subject: [PATCH] Changed triple_collocation module, added distance to
 collocation result and optimized tc function

---
 wavy/triple_collocation.py | 133 ++++++++++++++++++++++---------------
 1 file changed, 78 insertions(+), 55 deletions(-)

diff --git a/wavy/triple_collocation.py b/wavy/triple_collocation.py
index faf38b46..b4ff1269 100644
--- a/wavy/triple_collocation.py
+++ b/wavy/triple_collocation.py
@@ -97,6 +97,7 @@ def collocate_sat_and_insitu(sco, ico, twin=5, dist_max=200):
         dist = list(ds_tmp['dist_is_sat'].values)
         ds_tmp = ds_tmp.isel(time=dist.index(min_dist_tmp))
         list_time_sat.append(ds_tmp['time'].data)
+        list_dist_min.append(min_dist_tmp)
         list_time_insitu.append(ds_tmp['insitu_time'].data)
 
     sco_filter = copy.deepcopy(sco)
@@ -106,6 +107,14 @@
     # lists of times corresponding to minimum distance
     # between satellite and in-situ observation
     sco_filter.vars = sco_filter.vars.sel(time=list_time_sat)
+    sco_filter.vars = sco_filter.vars.assign(
+        {
+            'colloc_dist': (
+                'time',
+                list_dist_min
+            )
+        }
+    )
     ico_filter.vars = ico_filter.vars.sel(time=list_time_insitu)
 
     return sco_filter, ico_filter
@@ -209,7 +218,7 @@ def SNR_dB(A, B, C, var_err_A):
     return 10*np.log10(sensitivity_estimates(A, B, C)/var_err_A)
 
 
-def triple_collocation_validate(results_dict, ref=None):
+def triple_collocation_validate(result_dict, ref=None):
     '''
     Runs the triple collocation given a dictionary
     containing three measurements, returns results
@@ -222,60 +231,74 @@
     returns: dict of dict of the metrics for each measurement
              {'name of measurement': {'metric name':metric}}
     '''
-    validate_dict = {}
-
-    data_sources = results_dict.keys()
-
-    if len(list(data_sources)) != 3:
-        raise Exception("Exactly three data sources must be provided.")
-
-    # Check if ref argument is, given and correct
-    # if not, takes the first data source as reference
-    if ref is None:
-        ref = list(data_sources)[0]
-        print("Since no reference was given,", ref, "was taken as default.")
-
-    if ref not in data_sources:
-        print("Incorrect argument `ref` value given, \
-              should be one of the keys of `tc_validate`.")
-        ref = list(data_sources)[0]
-        print(ref, "was taken as reference instead.\n")
-
-    # Save data sources into a list
-    ds_list = list(results_dict.values())
-
-    # Iterates on key:value from data_sources dict
-    for i, k in enumerate(data_sources):
-
-        # Temporary dictionary to save
-        # results of current data source
-        dict_tmp = {}
-
-        # Rotates the data sources at each step i,
-        # so A is the ith element of `results_dict.values()`
-        ds_list.insert(0, ds_list.pop(i))
-        A = ds_list[0]
-        B = ds_list[1]
-        C = ds_list[2]
-
-        # Calculates all metrics for A which is, at step i,
-        # the ith element of `results_dict.values()`
-        dict_tmp["var_est"] = variance_estimates(A, B, C)
-        dict_tmp["RMSE"] = RMSE(A, B, C)
-        dict_tmp["SI"] = SI(A, B, C, results_dict[ref])
-        dict_tmp["sensitivity"] = sensitivity_estimates(A, B, C)
-        dict_tmp["fMSE"] = fMSE(A, B, C, dict_tmp["var_est"])
-        dict_tmp["SNR_dB"] = SNR_dB(A, B, C, dict_tmp["var_est"])
-        dict_tmp["mean"] = np.mean(A)
-        dict_tmp["std"] = np.std(A)
-
-        # Save the metric dictionary into validate_dict
-        validate_dict[k] = dict_tmp
-
-    validate_dict = {"data_sources": validate_dict}
-    validate_dict["reference_dataset"] = ref
-
-    return validate_dict
+    measure_names = list(result_dict.keys())
+    measures = list(result_dict.values())
+
+    if len(measures) != 3:
+        raise Exception("Exactly three data sources must be provided.")
+
+    # Fall back to the first data source if no valid reference is given
+    if ref not in measure_names:
+        ref = measure_names[0]
+
+    tc_validate = {'data_sources': {key: {} for key in measure_names},
+                   'reference_dataset': ref}
+
+    mean_ref = np.mean(result_dict[ref])
+
+    A = measures[0]
+    B = measures[1]
+    C = measures[2]
+
+    # Covariances of each pair of measurements, computed once
+    # and reused for all metrics below
+    cov_ab = np.cov(A, B)
+    cov_bc = np.cov(B, C)
+    cov_ac = np.cov(A, C)
+
+    # Sensitivity of each measurement to the true signal
+    sens = [(cov_ab[0][1]*cov_ac[0][1])/cov_bc[0][1],
+            (cov_ab[0][1]*cov_bc[0][1])/cov_ac[0][1],
+            (cov_bc[0][1]*cov_ac[0][1])/cov_ab[0][1]]
+
+    # Random error variance: total variance minus sensitivity
+    var_est = [cov_ab[0][0] - sens[0],
+               cov_ab[1][1] - sens[1],
+               cov_bc[1][1] - sens[2]]
+
+    rmse = [np.sqrt(var_est[0]),
+            np.sqrt(var_est[1]),
+            np.sqrt(var_est[2])]
+
+    # Scatter index in percent of the reference mean
+    si = [rmse[0]/mean_ref*100,
+          rmse[1]/mean_ref*100,
+          rmse[2]/mean_ref*100]
+
+    snr = [sens[0]/var_est[0],
+           sens[1]/var_est[1],
+           sens[2]/var_est[2]]
+
+    fmse = [1/(1+snr[0]),
+            1/(1+snr[1]),
+            1/(1+snr[2])]
+
+    snr_db = [10*np.log10(snr[0]),
+              10*np.log10(snr[1]),
+              10*np.log10(snr[2])]
+
+    for i, k in enumerate(measure_names):
+
+        tc_validate['data_sources'][k]['var_est'] = var_est[i]
+        tc_validate['data_sources'][k]['RMSE'] = rmse[i]
+        tc_validate['data_sources'][k]['SI'] = si[i]
+        tc_validate['data_sources'][k]['sensitivity'] = sens[i]
+        tc_validate['data_sources'][k]['fMSE'] = fmse[i]
+        tc_validate['data_sources'][k]['SNR_dB'] = snr_db[i]
+        tc_validate['data_sources'][k]['mean'] = np.mean(measures[i])
+        tc_validate['data_sources'][k]['std'] = np.std(measures[i])
+
+    return tc_validate
 
 
 def disp_tc_validation(tc_validate, dec=3):
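
--
A minimal usage sketch, not part of the patch. After
collocate_sat_and_insitu (which assumes list_dist_min is initialised
to an empty list before the loop shown in the first hunk), the new
minimum distances are available as sco_filter.vars['colloc_dist'].
The refactored triple_collocation_validate can then be driven as
below; the synthetic series and source names are purely illustrative,
standing in for three measurements already collocated on a common
time axis:

    import numpy as np
    from wavy.triple_collocation import triple_collocation_validate

    # One "true" signal observed by three independent systems,
    # each with its own additive noise level
    rng = np.random.default_rng(42)
    truth = 2.0 + rng.standard_normal(1000)
    result_dict = {
        'satellite': truth + 0.2*rng.standard_normal(1000),
        'insitu': truth + 0.1*rng.standard_normal(1000),
        'model': truth + 0.3*rng.standard_normal(1000),
    }

    tc_validate = triple_collocation_validate(result_dict, ref='insitu')

    # Every data source carries var_est, RMSE, SI, sensitivity,
    # fMSE, SNR_dB, mean and std
    for name, metrics in tc_validate['data_sources'].items():
        print(name, round(metrics['RMSE'], 3), round(metrics['SNR_dB'], 1))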