From e8cd8eef53bf6cc95451a755e319505cc8d9ea21 Mon Sep 17 00:00:00 2001 From: nils Date: Tue, 19 Oct 2021 11:17:23 +0200 Subject: [PATCH] fix saturated values when computing db rate --- compare.py | 123 +++++++++++++++++++++++++++++++++++------------------ metrics.py | 58 ++++++++----------------- 2 files changed, 99 insertions(+), 82 deletions(-) diff --git a/compare.py b/compare.py index dbcba32..efeade7 100755 --- a/compare.py +++ b/compare.py @@ -10,26 +10,55 @@ from anchor import VariantData, VariantMetricSet, Metric, iter_variants, AnchorTuple from download import AnchorTupleCtx -from metrics import SDR_METRICS, BD_RATE, Metric, sanitize_rd_data1, sanitize_rd_data2, sort_on_rates +from metrics import Metric import sys, csv -def rounded(v): - return f'{round(v, 2):.2f}' - -def rd_plot(r0, d0, r1, d1, dist='psnr', anchor_label='anchor', test_label='test', title='', show=True): - if title == '': - title = f'{dist} rd curve' - fig, axs = plt.subplots(1, 1, figsize=(20,15)) - axs.plot(r0, d0, 'o-', r1, d1, 'o-') - axs.set_xlabel('bitrate', fontsize=21) - axs.set_ylabel(dist, fontsize=21) - axs.grid(True) - axs.tick_params(axis='both', which='major', labelsize=21) - axs.set_title(title, fontdict={'fontsize': 24, 'fontweight': 'medium'}) - axs.legend([anchor_label, test_label]) - if show: - plt.show(block=True) - return fig + + +def sort_rates_on(rates, metric): + """ + sort [(rate, psnr), ...] samples based on psnr + """ + return rates[np.argsort(metric)], np.sort(metric) + +def sort_on_rates(rates, metric): + """ + sort [(rate, psnr), ...] samples based on rate + """ + return np.sort(rates), metric[np.argsort(rates)] + + +def sanitize_rd_data(rates, dist, step=0.001): + """workaround for saturated dist values. + returns sanitized rates & dist, sorted on dist. + consecutive samples that have the same values are modified with the given step + so that the sequence is increasing instead of stagnating. 
+ if a sequence is decreasing rather than stagnating, it is not modified. + eg. [50., 50., 50., 50., 50.] becomes [50., 50.001, 50.002, 50.003, 50.004] + [98.999, 99.999, 99.999, 99.999, 100.] becomes [98.999, 99.999, 100., 100.001, 100.002] + [98.999, 99.999, 99.999, 97., 100.] becomes [97., 98.999, 99.999, 100., 100.001] + """ + rate = np.array(rates) + dist = np.array(dist) + sorted = np.lexsort((rate, dist)) + rate = rate[sorted] + dist = dist[sorted] + dist_fix = [] + sanitized = False + for i, _ in enumerate(rate): + d = dist[i] + if i and (d == dist[i-1]): + d = dist_fix[-1] + step + sanitized = True + elif i and (d > dist[i-1]) and (d <= dist_fix[-1]): + d = dist_fix[-1] + step + sanitized = True + dist_fix.append(d) + if sanitized: + print("/!\ data has been sanitized:") + print(f" - replaced: {dist}") + print(f" - with : {dist_fix}") + return rate, np.array(dist_fix, dtype=np.float64) def rd_metrics(variants:List[VariantData], rate="BitrateLog", dist="PSNR") -> Iterable[Any]: return zip(*[(v.metrics[rate], v.metrics[dist]) for v in variants]) @@ -39,8 +68,7 @@ def compare_anchors_metrics( anchor:List[VariantData], test:List[VariantData], r test_metrics = [*rd_metrics(test, rate=rate, dist=dist)] try: print("#", dist, "#"*(32-len(str(dist)))) - # return BD_RATE(*anchor_metrics, *test_metrics, piecewise=piecewise, sanitize=sanitize) - return BD_RATE_PLOT(*anchor_metrics, *test_metrics, sanitize=sanitize, title=title, dist_label=dist) + return bd_rate_plot(*anchor_metrics, *test_metrics, sanitize=sanitize, title=title, dist_label=dist) except BaseException as e: if strict: raise @@ -229,34 +257,42 @@ def csv_dump(data, fp): ##################################################################################################### +def strictly_increasing(samples): + for i, v in enumerate(samples): + if i and v <= samples[i-1]: + return False + return True -def BD_RATE_PLOT(R1, PSNR1, R2, PSNR2, sanitize=False, title="", dist_label="dist"): +def bd_rate_plot(R1, DIST1, 
R2, DIST2, sanitize=False, title="", dist_label="dist"): if sanitize: - R1, PSNR1 = sanitize_rd_data1(R1, PSNR1) - R2, PSNR2 = sanitize_rd_data1(R2, PSNR2) + R1, DIST1 = sanitize_rd_data(R1, DIST1) + R2, DIST2 = sanitize_rd_data(R2, DIST2) + b = strictly_increasing(DIST1) and strictly_increasing(DIST2) + else: - PSNR1 = np.array(PSNR1) - PSNR2 = np.array(PSNR2) + DIST1 = np.array(DIST1) + DIST2 = np.array(DIST2) lR1 = np.log(R1) lR2 = np.log(R2) # integration interval - min_int = max(min(PSNR1), min(PSNR2)) - max_int = min(max(PSNR1), max(PSNR2)) + min_int = max(min(DIST1), min(DIST2)) + max_int = min(max(DIST1), max(DIST2)) samples, interval = np.linspace(min_int, max_int, num=100, retstep=True) - [y1, x1] = sort_on_rates(lR1, PSNR1) - [y2, x2] = sort_on_rates(lR2, PSNR2) - err = None - fig = None - avg_diff = None + [r1, d1] = sort_on_rates(lR1, DIST1) + assert strictly_increasing(d1) + + [r2, d2] = sort_on_rates(lR2, DIST2) + assert strictly_increasing(d2) + + v1, v2, avg_diff, fig = None, None, 0, None try: - v1 = scipy.interpolate.pchip_interpolate(x1, y1, samples) - v2 = scipy.interpolate.pchip_interpolate(x2, y2, samples) + v1 = scipy.interpolate.pchip_interpolate(d1, r1, samples) + v2 = scipy.interpolate.pchip_interpolate(d2, r2, samples) - # Calculate the integral using the trapezoid method on the samples. 
int1 = np.trapz(v1, dx=interval) int2 = np.trapz(v2, dx=interval) avg_exp_diff = (int2-int1)/(max_int-min_int) @@ -265,7 +301,7 @@ def BD_RATE_PLOT(R1, PSNR1, R2, PSNR2, sanitize=False, title="", dist_label="dis # plot it fig, axs = plt.subplots(2, 1, figsize=(10, 20)) - axs[0].plot(R1, PSNR1, 'o-', R2, PSNR2, 'o-') + axs[0].plot(R1, DIST1, 'o-', R2, DIST2, 'o-') axs[0].set_xlabel('bitrate', fontsize=21) axs[0].set_ylabel(dist_label, fontsize=21) axs[0].grid(True) @@ -275,10 +311,10 @@ def BD_RATE_PLOT(R1, PSNR1, R2, PSNR2, sanitize=False, title="", dist_label="dis axs[0].axhline(min_int, linestyle='dashed', color='red') axs[0].axhline(max_int, linestyle='dashed', color='red') - _ = axs[1].plot(y1, x1, 'o:', label="anchor (measured)") - _ = axs[1].plot(y2, x2, 'o:', label="test (measured)") - _ = axs[1].plot(v1, samples, '-', label="anchor (interpolated)") - _ = axs[1].plot(v2, samples, '-', label="test (interpolated)") + axs[1].plot(r1, d1, 'o:', label="anchor (measured)") + axs[1].plot(r2, d2, 'o:', label="test (measured)") + axs[1].plot(v1, samples, '-', label="anchor (interpolated)") + axs[1].plot(v2, samples, '-', label="test (interpolated)") axs[1].legend() axs[1].set_xlabel('bitrate (log)', fontsize=21) @@ -294,8 +330,10 @@ def BD_RATE_PLOT(R1, PSNR1, R2, PSNR2, sanitize=False, title="", dist_label="dis except ValueError as ve: print(ve) + print('d1:', d1) + print('d2:', d2) - return fig, avg_diff, R1, PSNR1, R2, PSNR2 + return fig, avg_diff, R1, DIST1, R2, DIST2 ##################################################################################################### @@ -321,7 +359,7 @@ def main(): if ref.is_dir() and (ref.parent / 'streams.csv').exists() \ and test.is_dir() and (test.parent / 'streams.csv').exists(): - # eg. compare.py ./scenario/codec1/a_key1 ./scenario/codec2/a_key2 + # eg. 
compare.py ./scenario/codec1/a_ker1 ./scenario/codec2/a_ker2 plots = [ m.key for m in ( Metric.PSNR_Y, Metric.PSNR, @@ -332,6 +370,7 @@ def main(): r['reference'] = seqid outp = test.parent / 'Metrics' / f'{ref.name}.{test.name}.csv'.lower() csv_dump([r], outp) + else: # eg. compare.py ./scenario/codec1@cfg_id1 ./scenario/codec2@cfg_id2 _, refkey = _parse_filter(ref) diff --git a/metrics.py b/metrics.py index ca060ad..693fb76 100644 --- a/metrics.py +++ b/metrics.py @@ -409,58 +409,36 @@ def sort_on_rates(rates, metric): return np.sort(rates), metric[np.argsort(rates)] -def sanitize_rd_data2(rates, PSNR, step=0.001): - """- sort samples for increasing rates - - fix saturated values, by adding a step value""" +def sanitize_rd_data(rates, dist, step=0.001): + """fixes saturated dist values by increasing sequence of equal values by the step value""" rate = np.array(rates) - dist = np.array(PSNR) + dist = np.array(dist) sorted = np.lexsort((rate, dist)) rate = rate[sorted] dist = dist[sorted] - dist_max = dist[-1] - dist_fix = np.array([], dtype=np.float64) - step = 0.001 - fix = 0.001 + dist_fix = [] + sanitized = False for i, _ in enumerate(rate): d = dist[i] - if d == dist_max: - d += fix - fix += step + if i and (d == dist[i-1]): + d = dist_fix[-1] + step + sanitized = True + elif i and (d > dist[i-1]) and (d <= dist_fix[-1]): + d = dist_fix[-1] + step + sanitized = True dist_fix.append(d) - return rate, dist_fix - - -def sanitize_rd_data1(rates, PSNR, preserve_last_sample=True): - """- sort samples for increasing rates - - drop samples on non increasing dist values""" - rate = np.array(rates) - dist = np.array(PSNR) - sorted = np.lexsort((rate, dist)) - rate = rate[sorted] - dist = dist[sorted] - _sorted = [] - for i, r in enumerate(rate): - d = dist[i] - if len(_sorted): - p = _sorted[-1] - if (d <= p[1]): - print(f'/!\ rate increased, but quality did not - dropping sample ( r:{r}, d:{d} ) !') - # keep the last sample on saturated values, drop the previous one - 
# if preserve_last_sample and (len(rate) == (i+1)): - # dist.pop() - # else: - # continue - continue - _sorted.append((r, d)) - return [np.array(arr) for arr in zip(*_sorted)] - + if sanitized: + print("/!\ data has been sanitized:") + print(f" - replaced: {dist}") + print(f" - with : {dist_fix}") + return rate, np.array(dist_fix, dtype=np.float64) def BD_RATE(R1, PSNR1, R2, PSNR2, piecewise=1, sanitize=False) -> float: if sanitize: - R1, PSNR1 = sanitize_rd_data1(R1, PSNR1) - R2, PSNR2 = sanitize_rd_data1(R2, PSNR2) + R1, PSNR1 = sanitize_rd_data(R1, PSNR1) + R2, PSNR2 = sanitize_rd_data(R2, PSNR2) else: PSNR1 = np.array(PSNR1) PSNR2 = np.array(PSNR2)