Skip to content

Commit

Permalink
fix saturated values when computing BD rate
Browse files Browse the repository at this point in the history
  • Loading branch information
nlsdvl committed Oct 19, 2021
1 parent 71cfd4f commit e8cd8ee
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 82 deletions.
123 changes: 81 additions & 42 deletions compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,26 +10,55 @@

from anchor import VariantData, VariantMetricSet, Metric, iter_variants, AnchorTuple
from download import AnchorTupleCtx
from metrics import SDR_METRICS, BD_RATE, Metric, sanitize_rd_data1, sanitize_rd_data2, sort_on_rates
from metrics import Metric
import sys, csv

def rounded(v):
    """Return ``v`` rounded to two decimals as a fixed-point string (e.g. ``'3.14'``)."""
    two_decimals = round(v, 2)
    return '{:.2f}'.format(two_decimals)

def rd_plot(r0, d0, r1, d1, dist='psnr', anchor_label='anchor', test_label='test', title='', show=True):
    """Draw two rate-distortion curves (anchor and test) on a single axes.

    r0/d0 are the x/y values of the first curve (labelled ``anchor_label``),
    r1/d1 those of the second curve (labelled ``test_label``).
    ``dist`` is used as the y-axis label and, when ``title`` is the empty
    string, to build the default title ``'<dist> rd curve'``.
    When ``show`` is truthy, ``plt.show(block=True)`` is called before returning.

    Returns the created figure.
    """
    # NOTE(review): `plt` is presumably matplotlib.pyplot, imported outside this view.
    heading = f'{dist} rd curve' if title == '' else title
    label_size = 21
    fig, ax = plt.subplots(1, 1, figsize=(20,15))
    ax.plot(r0, d0, 'o-', r1, d1, 'o-')
    ax.set_xlabel('bitrate', fontsize=label_size)
    ax.set_ylabel(dist, fontsize=label_size)
    ax.grid(True)
    ax.tick_params(axis='both', which='major', labelsize=label_size)
    ax.set_title(heading, fontdict=dict(fontsize=24, fontweight='medium'))
    ax.legend([anchor_label, test_label])
    if not show:
        return fig
    plt.show(block=True)
    return fig


def sort_rates_on(rates, metric):
    """
    sort [(rate, metric), ...] samples based on the metric (e.g. psnr)

    rates and metric are parallel sequences; plain lists/tuples are accepted
    as well as numpy arrays.
    returns (rates, metric) as numpy arrays, both reordered so that metric is
    ascending.
    """
    rates = np.asarray(rates)
    metric = np.asarray(metric)
    # a single permutation applied to both arrays keeps the samples paired
    order = np.argsort(metric)
    return rates[order], metric[order]

def sort_on_rates(rates, metric):
    """
    sort [(rate, metric), ...] samples based on rate

    rates and metric are parallel sequences; plain lists/tuples are accepted
    as well as numpy arrays.
    returns (rates, metric) as numpy arrays, both reordered so that rates are
    ascending.
    """
    rates = np.asarray(rates)
    metric = np.asarray(metric)
    # a single permutation applied to both arrays keeps the samples paired
    order = np.argsort(rates)
    return rates[order], metric[order]


def sanitize_rd_data(rates, dist, step=0.001):
    """workaround for saturated dist values.

    samples are sorted on dist (ties broken on rate). a sample whose dist equals
    the previous sample's dist, or that was caught up by a previously bumped
    value, is replaced by the previous sanitized value plus `step`, so that the
    sequence is increasing instead of stagnating.
    eg. (up to floating point rounding)
    [50., 50., 50., 50., 50.] becomes [50., 50.001, 50.002, 50.003, 50.004]
    [98.999, 99.999, 99.999, 99.999, 100.] becomes [98.999, 99.999, 100., 100.001, 100.002]
    [98.999, 99.999, 99.999, 97., 100.] becomes [97., 98.999, 99.999, 100., 100.001]

    returns (rate, dist): the rates reordered to match, and the sanitized dist
    values as a float64 array. a report is printed when any value was modified.
    """
    rate = np.array(rates)
    dist = np.array(dist)
    # lexsort uses the LAST key as primary: sort on dist, then on rate
    order = np.lexsort((rate, dist))
    rate = rate[order]
    dist = dist[order]
    dist_fix = []
    sanitized = False
    for i, d in enumerate(dist):
        if i:
            previous = dist[i-1]
            stagnating = d == previous
            # a larger raw value can still be at or below the last bumped value
            overtaken = (d > previous) and (d <= dist_fix[-1])
            if stagnating or overtaken:
                d = dist_fix[-1] + step
                sanitized = True
        dist_fix.append(d)
    if sanitized:
        # the backslash is escaped explicitly: "\ " is an invalid escape sequence
        print("/!\\ data has been sanitized:")
        print(f" - replaced: {dist}")
        print(f" - with : {dist_fix}")
    return rate, np.array(dist_fix, dtype=np.float64)

def rd_metrics(variants:List[VariantData], rate="BitrateLog", dist="PSNR") -> Iterable[Any]:
    """Transpose the per-variant (rate, dist) metric pairs.

    For each variant, ``metrics[rate]`` and ``metrics[dist]`` are looked up.
    The returned zip iterator yields first the tuple of all rate values, then
    the tuple of all dist values (it yields nothing when ``variants`` is empty).
    """
    pairs = []
    for variant in variants:
        pairs.append((variant.metrics[rate], variant.metrics[dist]))
    return zip(*pairs)
Expand All @@ -39,8 +68,7 @@ def compare_anchors_metrics( anchor:List[VariantData], test:List[VariantData], r
test_metrics = [*rd_metrics(test, rate=rate, dist=dist)]
try:
print("#", dist, "#"*(32-len(str(dist))))
# return BD_RATE(*anchor_metrics, *test_metrics, piecewise=piecewise, sanitize=sanitize)
return BD_RATE_PLOT(*anchor_metrics, *test_metrics, sanitize=sanitize, title=title, dist_label=dist)
return bd_rate_plot(*anchor_metrics, *test_metrics, sanitize=sanitize, title=title, dist_label=dist)
except BaseException as e:
if strict:
raise
Expand Down Expand Up @@ -229,34 +257,42 @@ def csv_dump(data, fp):

#####################################################################################################

def strictly_increasing(samples):
    """Tell whether each sample is strictly greater than its predecessor.

    Returns True for empty and single-element sequences, False as soon as one
    sample is less than or equal to the one before it.
    """
    consecutive = zip(samples, samples[1:])
    return not any(current <= previous for previous, current in consecutive)

def BD_RATE_PLOT(R1, PSNR1, R2, PSNR2, sanitize=False, title="", dist_label="dist"):
def bd_rate_plot(R1, DIST1, R2, DIST2, sanitize=False, title="", dist_label="dist"):

if sanitize:
R1, PSNR1 = sanitize_rd_data1(R1, PSNR1)
R2, PSNR2 = sanitize_rd_data1(R2, PSNR2)
R1, DIST1 = sanitize_rd_data(R1, DIST1)
R2, DIST2 = sanitize_rd_data(R2, DIST2)
b = strictly_increasing(DIST1) and strictly_increasing(DIST2)

else:
PSNR1 = np.array(PSNR1)
PSNR2 = np.array(PSNR2)
DIST1 = np.array(DIST1)
DIST2 = np.array(DIST2)

lR1 = np.log(R1)
lR2 = np.log(R2)

# integration interval
min_int = max(min(PSNR1), min(PSNR2))
max_int = min(max(PSNR1), max(PSNR2))
min_int = max(min(DIST1), min(DIST2))
max_int = min(max(DIST1), max(DIST2))

samples, interval = np.linspace(min_int, max_int, num=100, retstep=True)
[y1, x1] = sort_on_rates(lR1, PSNR1)
[y2, x2] = sort_on_rates(lR2, PSNR2)
err = None
fig = None
avg_diff = None
[r1, d1] = sort_on_rates(lR1, DIST1)
assert strictly_increasing(d1)

[r2, d2] = sort_on_rates(lR2, DIST2)
assert strictly_increasing(d2)

v1, v2, avg_diff, fig = None, None, 0, None
try:
v1 = scipy.interpolate.pchip_interpolate(x1, y1, samples)
v2 = scipy.interpolate.pchip_interpolate(x2, y2, samples)
v1 = scipy.interpolate.pchip_interpolate(d1, r1, samples)
v2 = scipy.interpolate.pchip_interpolate(d2, r2, samples)

# Calculate the integral using the trapezoid method on the samples.
int1 = np.trapz(v1, dx=interval)
int2 = np.trapz(v2, dx=interval)
avg_exp_diff = (int2-int1)/(max_int-min_int)
Expand All @@ -265,7 +301,7 @@ def BD_RATE_PLOT(R1, PSNR1, R2, PSNR2, sanitize=False, title="", dist_label="dis
# plot it
fig, axs = plt.subplots(2, 1, figsize=(10, 20))

axs[0].plot(R1, PSNR1, 'o-', R2, PSNR2, 'o-')
axs[0].plot(R1, DIST1, 'o-', R2, DIST2, 'o-')
axs[0].set_xlabel('bitrate', fontsize=21)
axs[0].set_ylabel(dist_label, fontsize=21)
axs[0].grid(True)
Expand All @@ -275,10 +311,10 @@ def BD_RATE_PLOT(R1, PSNR1, R2, PSNR2, sanitize=False, title="", dist_label="dis
axs[0].axhline(min_int, linestyle='dashed', color='red')
axs[0].axhline(max_int, linestyle='dashed', color='red')

_ = axs[1].plot(y1, x1, 'o:', label="anchor (measured)")
_ = axs[1].plot(y2, x2, 'o:', label="test (measured)")
_ = axs[1].plot(v1, samples, '-', label="anchor (interpolated)")
_ = axs[1].plot(v2, samples, '-', label="test (interpolated)")
axs[1].plot(r1, d1, 'o:', label="anchor (measured)")
axs[1].plot(r2, d2, 'o:', label="test (measured)")
axs[1].plot(v1, samples, '-', label="anchor (interpolated)")
axs[1].plot(v2, samples, '-', label="test (interpolated)")

axs[1].legend()
axs[1].set_xlabel('bitrate (log)', fontsize=21)
Expand All @@ -294,8 +330,10 @@ def BD_RATE_PLOT(R1, PSNR1, R2, PSNR2, sanitize=False, title="", dist_label="dis

except ValueError as ve:
print(ve)
print('d1:', d1)
print('d2:', d2)

return fig, avg_diff, R1, PSNR1, R2, PSNR2
return fig, avg_diff, R1, DIST1, R2, DIST2


#####################################################################################################
Expand All @@ -321,7 +359,7 @@ def main():

if ref.is_dir() and (ref.parent / 'streams.csv').exists() \
and test.is_dir() and (test.parent / 'streams.csv').exists():
# eg. compare.py ./scenario/codec1/a_key1 ./scenario/codec2/a_key2
# eg. compare.py ./scenario/codec1/a_ker1 ./scenario/codec2/a_ker2
plots = [ m.key for m in (
Metric.PSNR_Y,
Metric.PSNR,
Expand All @@ -332,6 +370,7 @@ def main():
r['reference'] = seqid
outp = test.parent / 'Metrics' / f'{ref.name}.{test.name}.csv'.lower()
csv_dump([r], outp)

else:
# eg. compare.py ./scenario/codec1@cfg_id1 ./scenario/codec2@cfg_id2
_, refkey = _parse_filter(ref)
Expand Down
58 changes: 18 additions & 40 deletions metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,58 +409,36 @@ def sort_on_rates(rates, metric):
return np.sort(rates), metric[np.argsort(rates)]


def sanitize_rd_data2(rates, PSNR, step=0.001):
"""- sort samples for increasing rates
- fix saturated values, by adding a step value"""
def sanitize_rd_data(rates, dist, step=0.001):
"""fixes saturated dist values by increasing sequence of equal values by the step value"""
rate = np.array(rates)
dist = np.array(PSNR)
dist = np.array(dist)
sorted = np.lexsort((rate, dist))
rate = rate[sorted]
dist = dist[sorted]
dist_max = dist[-1]
dist_fix = np.array([], dtype=np.float64)
step = 0.001
fix = 0.001
dist_fix = []
sanitized = False
for i, _ in enumerate(rate):
d = dist[i]
if d == dist_max:
d += fix
fix += step
if i and (d == dist[i-1]):
d = dist_fix[-1] + step
sanitized = True
elif i and (d > dist[i-1]) and (d <= dist_fix[-1]):
d = dist_fix[-1] + step
sanitized = True
dist_fix.append(d)
return rate, dist_fix


def sanitize_rd_data1(rates, PSNR, preserve_last_sample=True):
    """- sort samples on increasing dist (PSNR), ties broken on rate
    - drop samples on non increasing dist values

    a sample whose dist is not strictly greater than the last kept sample's
    dist is dropped and a warning is printed, so the first (lowest rate)
    sample of a saturated run is the one that is kept.

    `preserve_last_sample` is accepted for backward compatibility; it is
    currently unused.

    returns a list [rates, dist] of two numpy arrays, both empty when no
    samples are given.
    """
    rate = np.array(rates)
    dist = np.array(PSNR)
    if rate.size == 0 and dist.size == 0:
        # keep the two-array shape so callers can always unpack the result
        return [np.array([]), np.array([])]
    # lexsort uses the LAST key as primary: sort on dist, then on rate
    order = np.lexsort((rate, dist))
    kept_rates = []
    kept_dist = []
    for r, d in zip(rate[order], dist[order]):
        if kept_dist and d <= kept_dist[-1]:
            # the backslash is escaped explicitly: "\ " is an invalid escape sequence
            print(f'/!\\ rate increased, but quality did not - dropping sample ( r:{r}, d:{d} ) !')
            continue
        kept_rates.append(r)
        kept_dist.append(d)
    return [np.array(kept_rates), np.array(kept_dist)]

if sanitized:
print("/!\ data has been sanitized:")
print(f" - replaced: {dist}")
print(f" - with : {dist_fix}")
return rate, np.array(dist_fix, dtype=np.float64)


def BD_RATE(R1, PSNR1, R2, PSNR2, piecewise=1, sanitize=False) -> float:

if sanitize:
R1, PSNR1 = sanitize_rd_data1(R1, PSNR1)
R2, PSNR2 = sanitize_rd_data1(R2, PSNR2)
R1, PSNR1 = sanitize_rd_data(R1, PSNR1)
R2, PSNR2 = sanitize_rd_data(R2, PSNR2)
else:
PSNR1 = np.array(PSNR1)
PSNR2 = np.array(PSNR2)
Expand Down

0 comments on commit e8cd8ee

Please sign in to comment.