Skip to content

Commit

Permalink
Changed triple_collocation module, added distance to collocation resu…
Browse files Browse the repository at this point in the history
…lt and optimized tc function
  • Loading branch information
fcollas committed Jan 9, 2024
1 parent 074ae59 commit d42c71f
Showing 1 changed file with 66 additions and 55 deletions.
121 changes: 66 additions & 55 deletions wavy/triple_collocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def collocate_sat_and_insitu(sco, ico, twin=5, dist_max=200):
dist = list(ds_tmp['dist_is_sat'].values)
ds_tmp = ds_tmp.isel(time=dist.index(min_dist_tmp))
list_time_sat.append(ds_tmp['time'].data)
list_dist_min.append(min_dist_tmp)
list_time_insitu.append(ds_tmp['insitu_time'].data)

sco_filter = copy.deepcopy(sco)
Expand All @@ -106,6 +107,14 @@ def collocate_sat_and_insitu(sco, ico, twin=5, dist_max=200):
# lists of times corresponding to minimum distance
# between satellite and in-situ observation
sco_filter.vars = sco_filter.vars.sel(time=list_time_sat)
sco_filter.vars = sco_filter.vars.assign(
{
'colloc_dist': (
'time',
list_dist_min
)
}
)
ico_filter.vars = ico_filter.vars.sel(time=list_time_insitu)

return sco_filter, ico_filter
Expand Down Expand Up @@ -209,7 +218,7 @@ def SNR_dB(A, B, C, var_err_A):
return 10*np.log10(sensitivity_estimates(A, B, C)/var_err_A)


def triple_collocation_validate(result_dict, ref=None):
    '''
    Runs the triple collocation given a dictionary
    containing three measurements, returns results
    in a dictionary.

    param:
        result_dict - dict {'name of measurement': values},
                      must contain exactly three entries whose
                      values are 1D sequences of equal length
        ref - key of `result_dict` whose mean is used to
              normalise the scatter index (SI). If None or not
              a key of `result_dict`, the first entry is used
              and a message is printed.

    returns: dict
        {'data_sources': {'name of measurement':
                              {'metric name': metric}},
         'reference_dataset': name of the reference actually used}
        with the metrics 'var_est', 'RMSE', 'SI', 'sensitivity',
        'fMSE', 'SNR_dB', 'mean' and 'std'.

    raises: ValueError if `result_dict` does not hold exactly
            three data sources.
    '''
    measure_names = list(result_dict.keys())
    measures = list(result_dict.values())

    if len(measure_names) != 3:
        raise ValueError("Exactly three data sources must be provided.")

    # Check if ref argument is given and correct,
    # if not, takes the first data source as reference
    if ref is None:
        ref = measure_names[0]
        print("Since no reference was given,", ref, "was taken as default.")
    elif ref not in measure_names:
        print("Incorrect argument `ref` value given, "
              "should be one of the keys of `result_dict`.")
        ref = measure_names[0]
        print(ref, "was taken as reference instead.\n")

    mean_ref = np.mean(result_dict[ref])

    # One 3x3 covariance matrix holds every variance and
    # pairwise covariance needed below
    cov = np.cov(np.vstack(measures))
    cov_ab = cov[0, 1]
    cov_ac = cov[0, 2]
    cov_bc = cov[1, 2]

    # Sensitivity of each measurement: product of its two
    # covariances divided by the covariance of the other two
    sens = np.array([cov_ab*cov_ac/cov_bc,
                     cov_ab*cov_bc/cov_ac,
                     cov_bc*cov_ac/cov_ab])

    # Error variance: total variance minus sensitivity.
    # It can come out negative, in which case RMSE, SI
    # and SNR_dB are nan for that measurement.
    var_est = np.diag(cov) - sens

    rmse = np.sqrt(var_est)
    si = rmse/mean_ref*100
    snr = sens/var_est
    fmse = 1/(1 + snr)
    snr_db = 10*np.log10(snr)

    tc_validate = {'data_sources': {},
                   'reference_dataset': ref}

    for i, k in enumerate(measure_names):
        tc_validate['data_sources'][k] = {
            'var_est': var_est[i],
            'RMSE': rmse[i],
            'SI': si[i],
            'sensitivity': sens[i],
            'fMSE': fmse[i],
            'SNR_dB': snr_db[i],
            'mean': np.mean(measures[i]),
            'std': np.std(measures[i]),
        }

    return tc_validate


def disp_tc_validation(tc_validate, dec=3):
Expand Down

0 comments on commit d42c71f

Please sign in to comment.