Skip to content

Commit

Permalink
Added data-truth correlation. Metrics can now be selected as an option in the triple_collocation_validate function.
Browse files Browse the repository at this point in the history
  • Loading branch information
fcollas committed Jan 10, 2024
1 parent d42c71f commit 48e6456
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 23 deletions.
2 changes: 1 addition & 1 deletion tests/test_triple_collocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def test_triple_collocation_simulated_data():
assert set(keys_lvl_2) == set(dict_data.keys())

assert ref in keys_lvl_2
assert len(tc_result['data_sources'][ref].keys()) == 8
assert len(tc_result['data_sources'][ref].keys()) == 6


# def test_triple_collocation_real_data():
Expand Down
77 changes: 55 additions & 22 deletions wavy/triple_collocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,13 +218,19 @@ def SNR_dB(A, B, C, var_err_A):
return 10*np.log10(sensitivity_estimates(A, B, C)/var_err_A)


def triple_collocation_validate(result_dict, ref=None):
def triple_collocation_validate(result_dict,
metric_list=['var_est','rmse','si',
'rho','mean','std'],
ref=None):
'''
Runs the triple collocation given a dictionary
containing three measurements, returns results
in a dictionary.
result_dict: {'name of measurement':list of values}
metric_list: List of the metrics to return, among 'var_est',
'rmse', 'si', 'rho', 'sens', 'snr', 'snr_db', 'fmse', 'mean',
'std'
ref: Name of one of the measurements, must correspond
to one key of result_dict
Expand All @@ -247,44 +253,71 @@ def triple_collocation_validate(result_dict, ref=None):
cov_bc = np.cov(B,C)
cov_ac = np.cov(A,C)

# Sensitivity
sens = [(cov_ab[0][1]*cov_ac[0][1])/cov_bc[0][1],
(cov_ab[0][1]*cov_bc[0][1])/cov_ac[0][1],
(cov_bc[0][1]*cov_ac[0][1])/cov_ab[0][1]]

# Estimate of the variance of random error
var_est = [cov_ab[0][0] - sens[0],
cov_ab[1][1] - sens[1],
cov_bc[1][1] - sens[2]]
cov_ab[1][1] - sens[1],
cov_bc[1][1] - sens[2]]

# Root Mean Square Error
rmse = [np.sqrt(var_est[0]),
np.sqrt(var_est[1]),
np.sqrt(var_est[2])]

si = [rmse[0]/mean_ref*100,
rmse[1]/mean_ref*100,
rmse[2]/mean_ref*100]
# Scatter Index
if 'si' in metric_list:
si = [rmse[0]/mean_ref*100,
rmse[1]/mean_ref*100,
rmse[2]/mean_ref*100]

# Signal to Noise Ratio
snr = [sens[0]/var_est[0],
sens[1]/var_est[1],
sens[2]/var_est[2]]

fmse = [1/(1+snr[0]),
1/(1+snr[1]),
1/(1+snr[2])]

snr_db = [10*np.log10(snr[0]),
10*np.log10(snr[1]),
10*np.log10(snr[2])]
# Fractional Mean Squared Error
if 'fmse' in metric_list:
fmse = [1/(1+snr[0]),
1/(1+snr[1]),
1/(1+snr[2])]

# Signal to Noise Ratio (dB)
if 'snr_db' in metric_list:
snr_db = [10*np.log10(snr[0]),
10*np.log10(snr[1]),
10*np.log10(snr[2])]

# Data truth correlation
if 'rho' in metric_list:
rho = [sens[0]/cov_ab[0][0],
sens[1]/cov_ab[1][1],
sens[2]/cov_bc[1][1]]

for i,k in enumerate(measure_names):

tc_validate['data_sources'][k]['var_est'] = var_est[i]
tc_validate['data_sources'][k]['RMSE'] = rmse[i]
tc_validate['data_sources'][k]['SI'] = si[i]
tc_validate['data_sources'][k]['sensitivity'] = sens[i]
tc_validate['data_sources'][k]['fMSE'] = fmse[i]
tc_validate['data_sources'][k]['SNR_dB'] = snr_db[i]
tc_validate['data_sources'][k]['mean'] = np.mean(measures[i])
tc_validate['data_sources'][k]['std'] = np.std(measures[i])
if 'var_est' in metric_list:
tc_validate['data_sources'][k]['var_est'] = var_est[i]
if 'rmse' in metric_list:
tc_validate['data_sources'][k]['RMSE'] = rmse[i]
if 'si' in metric_list:
tc_validate['data_sources'][k]['SI'] = si[i]
if 'sens' in metric_list:
tc_validate['data_sources'][k]['sensitivity'] = sens[i]
if 'rho' in metric_list:
tc_validate['data_sources'][k]['rho'] = rho[i]
if 'snr' in metric_list:
tc_validate['data_sources'][k]['SNR'] = snr[i]
if 'fmse' in metric_list:
tc_validate['data_sources'][k]['fMSE'] = fmse[i]
if 'snr_db' in metric_list:
tc_validate['data_sources'][k]['SNR_dB'] = snr_db[i]
if 'mean' in metric_list:
tc_validate['data_sources'][k]['mean'] = np.mean(measures[i])
if 'std' in metric_list:
tc_validate['data_sources'][k]['std'] = np.std(measures[i])

return tc_validate

Expand Down

0 comments on commit 48e6456

Please sign in to comment.