From 48e645616195b500d6f09de54ccb17feaeeff639 Mon Sep 17 00:00:00 2001 From: Fabien Collas Date: Wed, 10 Jan 2024 16:45:45 +0100 Subject: [PATCH] Added data-truth correlation. Metrics can now be selected as an option in the triple_collocation_validate function. --- tests/test_triple_collocation.py | 2 +- wavy/triple_collocation.py | 77 +++++++++++++++++++++++--------- 2 files changed, 56 insertions(+), 23 deletions(-) diff --git a/tests/test_triple_collocation.py b/tests/test_triple_collocation.py index 55b75e8e..cad93044 100644 --- a/tests/test_triple_collocation.py +++ b/tests/test_triple_collocation.py @@ -43,7 +43,7 @@ def test_triple_collocation_simulated_data(): assert set(keys_lvl_2) == set(dict_data.keys()) assert ref in keys_lvl_2 - assert len(tc_result['data_sources'][ref].keys()) == 8 + assert len(tc_result['data_sources'][ref].keys()) == 6 # def test_triple_collocation_real_data(): diff --git a/wavy/triple_collocation.py b/wavy/triple_collocation.py index b4ff1269..ffa94568 100644 --- a/wavy/triple_collocation.py +++ b/wavy/triple_collocation.py @@ -218,13 +218,19 @@ def SNR_dB(A, B, C, var_err_A): return 10*np.log10(sensitivity_estimates(A, B, C)/var_err_A) -def triple_collocation_validate(result_dict, ref=None): +def triple_collocation_validate(result_dict, + metric_list=['var_est','rmse','si', + 'rho','mean','std'], + ref=None): ''' Runs the triple collocation given a dictionary containing three measurements, returns results in a dictionary. results_dict: {'name of measurement':list of values} + metric_list: List of the metrics to return, among 'var_est', + 'rmse', 'si', 'rho', 'sens', 'snr', 'snr_db', 'fmse', 'mean', + 'std' ref: Name of one of the measurements, must correspond to one key of results_dict @@ -247,44 +253,71 @@ def triple_collocation_validate(result_dict, ref=None): cov_bc = np.cov(B,C) cov_ac = np.cov(A,C) + # Sensitivity sens = [(cov_ab[0][1]*cov_ac[0][1])/cov_bc[0][1], (cov_ab[0][1]*cov_bc[0][1])/cov_ac[0][1], (cov_bc[0][1]*cov_ac[0][1])/cov_ab[0][1]] + # Estimate of the variance of random error var_est = [cov_ab[0][0] - sens[0], - cov_ab[1][1] - sens[1], - cov_bc[1][1] - sens[2]] + cov_ab[1][1] - sens[1], + cov_bc[1][1] - sens[2]] + # Root Mean Square Error rmse = [np.sqrt(var_est[0]), np.sqrt(var_est[1]), np.sqrt(var_est[2])] - si = [rmse[0]/mean_ref*100, - rmse[1]/mean_ref*100, - rmse[2]/mean_ref*100] + # Scatter Index + if 'si' in metric_list: + si = [rmse[0]/mean_ref*100, + rmse[1]/mean_ref*100, + rmse[2]/mean_ref*100] + # Signal to Noise Ratio snr = [sens[0]/var_est[0], sens[1]/var_est[1], sens[2]/var_est[2]] - fmse = [1/(1+snr[0]), - 1/(1+snr[1]), - 1/(1+snr[2])] - - snr_db = [10*np.log10(snr[0]), - 10*np.log10(snr[1]), - 10*np.log10(snr[2])] + # Fractional Mean Squared Error + if 'fmse' in metric_list: + fmse = [1/(1+snr[0]), + 1/(1+snr[1]), + 1/(1+snr[2])] + + # Signal to Noise Ratio (dB) + if 'snr_db' in metric_list: + snr_db = [10*np.log10(snr[0]), + 10*np.log10(snr[1]), + 10*np.log10(snr[2])] + + # Data truth correlation + if 'rho' in metric_list: + rho = [sens[0]/cov_ab[0][0], + sens[1]/cov_ab[1][1], + sens[2]/cov_bc[1][1]] for i,k in enumerate(measure_names): - - tc_validate['data_sources'][k]['var_est'] = var_est[i] - tc_validate['data_sources'][k]['RMSE'] = rmse[i] - tc_validate['data_sources'][k]['SI'] = si[i] - tc_validate['data_sources'][k]['sensitivity'] = sens[i] - tc_validate['data_sources'][k]['fMSE'] = fmse[i] - tc_validate['data_sources'][k]['SNR_dB'] = snr_db[i] - tc_validate['data_sources'][k]['mean'] = np.mean(measures[i]) - tc_validate['data_sources'][k]['std'] = np.std(measures[i]) + if 'var_est' in metric_list: + tc_validate['data_sources'][k]['var_est'] = var_est[i] + if 'rmse' in metric_list: + tc_validate['data_sources'][k]['RMSE'] = rmse[i] + if 'si' in metric_list: + tc_validate['data_sources'][k]['SI'] = si[i] + if 'sens' in metric_list: + tc_validate['data_sources'][k]['sensitivity'] = sens[i] + if 'rho' in metric_list: + tc_validate['data_sources'][k]['rho'] = rho[i] + if 'snr' in metric_list: + tc_validate['data_sources'][k]['SNR'] = snr[i] + if 'fmse' in metric_list: + tc_validate['data_sources'][k]['fMSE'] = fmse[i] + if 'snr_db' in metric_list: + tc_validate['data_sources'][k]['SNR_dB'] = snr_db[i] + if 'mean' in metric_list: + tc_validate['data_sources'][k]['mean'] = np.mean(measures[i]) + if 'std' in metric_list: + tc_validate['data_sources'][k]['std'] = np.std(measures[i]) return tc_validate