From bac261e1beffae67b3a34e6811f063c368ce3b0a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Davide=20Sandon=C3=A0?= <sandona.davide@gmail.com>
Date: Sat, 27 Apr 2024 13:00:30 +0200
Subject: [PATCH] added contour support to BokehBackend

---
 doc/source/changelog.rst                |  2 +
 spb/backends/bokeh/renderers/contour.py | 56 +++++++++++--------------
 tests/backends/test_bokeh.py            |  6 +--
 3 files changed, 29 insertions(+), 35 deletions(-)

diff --git a/doc/source/changelog.rst b/doc/source/changelog.rst
index 29b1192..2164067 100644
--- a/doc/source/changelog.rst
+++ b/doc/source/changelog.rst
@@ -7,6 +7,8 @@ v3.4.0
 
 * Implemented animations.
 
+* ``BokehBackend`` is now able to create contour plots.
+
 
 v3.3.0
 ======
diff --git a/spb/backends/bokeh/renderers/contour.py b/spb/backends/bokeh/renderers/contour.py
index 057f5a6..5b6c739 100644
--- a/spb/backends/bokeh/renderers/contour.py
+++ b/spb/backends/bokeh/renderers/contour.py
@@ -9,37 +9,30 @@ def _draw_contour_helper(renderer, data):
     if s.is_polar:
         raise NotImplementedError()
     x, y, z = data
-    x, y, zz = [t.flatten() for t in [x, y, z]]
-    minx, miny, minz = min(x), min(y), min(zz)
-    maxx, maxy, maxz = max(x), max(y), max(zz)
+
+    # NOTE: at the time of writing this, Bokeh doesn't support
+    # levels=int number.
+    if "levels" in s.rendering_kw.keys():
+        levels = s.rendering_kw["levels"]
+    else:
+        levels = p.np.linspace(z.min(), z.max(), 10)
+
+    if (not s.is_filled) and s.show_clabels:
+        warnings.warn("BokehBackend doesn't currently support contour labels.")
 
     cm = next(p._cm)
-    ckw = dict(palette=cm)
+    ckw = dict(fill_color=cm, line_color=cm, levels=levels)
     kw = p.merge({}, ckw, s.rendering_kw)
 
-    if not s.is_filled:
-        warnings.warn("Bokeh does not support line contours.")
-
-    h = p._fig.image(
-        image=[z],
-        x=minx,
-        y=miny,
-        dw=abs(maxx - minx),
-        dh=abs(maxy - miny),
-        **kw
-    )
-    handle.append(h)
+    h = p._fig.contour(x, y, z, **kw)
+    handle.extend([h, levels])
     p._fig.add_tools(p.bokeh.models.HoverTool(
-        tooltips=[("x", "$x"), ("y", "$y"), ("z", "@image")],
+        tooltips=[("x", "@x"), ("y", "@y"), ("z", "@z")],
         renderers=[handle[0]]
     ))
 
     if s.colorbar:
-        colormapper = p.bokeh.models.LinearColorMapper(
-            palette=cm, low=minz, high=maxz)
-        cbkw = dict(width=8, title=s.get_label(p._use_latex))
-        colorbar = p.bokeh.models.ColorBar(
-            color_mapper=colormapper, **cbkw)
+        colorbar = h.construct_color_bar(title=s.get_label(p._use_latex))
         p._fig.add_layout(colorbar, "right")
         handle.append(colorbar)
 
@@ -47,18 +40,17 @@ def _draw_contour_helper(renderer, data):
 
 
 def _update_contour_helper(renderer, data, handle):
-    s = renderer.series
+    p, s = renderer.plot, renderer.series
     x, y, z = data
-    minx, miny, minz = x.min(), y.min(), z.min()
-    maxx, maxy, maxz = x.max(), y.max(), z.max()
-    handle[0].data_source.data.update({"image": [z]})
-    handle[0].glyph.x = minx
-    handle[0].glyph.y = miny
-    handle[0].glyph.dw = abs(maxx - minx)
-    handle[0].glyph.dh = abs(maxy - miny)
+    countour_handle, levels, cb_handle = handle
+    levels = p.np.linspace(z.min(), z.max(), len(levels))
+
+    contour_data = p.bokeh.plotting.contour.contour_data(x, y, z, levels)
+    handle[0].set_data(contour_data)
     if s.colorbar:
-        cb = handle[1]
-        cb.color_mapper.update(low=minz, high=maxz)
+        # NOTE: as of Bokeh 3.4.1, there is a bug that prevents ticks to
+        # be updated.
+        handle[2].update(levels=list(levels))
 
 
 class ContourRenderer(Renderer):
diff --git a/tests/backends/test_bokeh.py b/tests/backends/test_bokeh.py
index d4db6df..540e1b4 100644
--- a/tests/backends/test_bokeh.py
+++ b/tests/backends/test_bokeh.py
@@ -239,7 +239,7 @@ def test_plot_contour(use_latex, xl, yl, label_func):
     assert len(p.backend.series) == 1
     f = p.fig
     assert len(f.renderers) == 1
-    assert isinstance(f.renderers[0].glyph, bokeh.models.glyphs.Image)
+    assert isinstance(f.renderers[0], bokeh.models.ContourRenderer)
     # 1 colorbar
     assert len(f.right) == 1
     assert f.right[0].title == label_func(use_latex, cos(a*x**2 + y**2))
@@ -267,7 +267,7 @@ def test_plot_vector_2d_quivers(pivot, success):
         assert len(p.backend.series) == 2
         f = p.fig
         assert len(f.renderers) == 2
-        assert isinstance(f.renderers[0].glyph, bokeh.models.glyphs.Image)
+        assert isinstance(f.renderers[0], bokeh.models.ContourRenderer)
         assert isinstance(f.renderers[1].glyph, bokeh.models.glyphs.Segment)
         # 1 colorbar
         assert len(f.right) == 1
@@ -301,7 +301,7 @@ def test_plot_vector_2d_streamlines_custom_scalar_field(
     assert len(p.backend.series) == 2
     f = p.fig
     assert len(f.renderers) == 2
-    assert isinstance(f.renderers[0].glyph, bokeh.models.glyphs.Image)
+    assert isinstance(f.renderers[0], bokeh.models.ContourRenderer)
     assert isinstance(f.renderers[1].glyph, bokeh.models.glyphs.MultiLine)
     # 1 colorbar
     assert len(f.right) == 1