Skip to content

Commit

Permalink
Add show methods to everything returned by visuals.* (#804)
Browse files Browse the repository at this point in the history
  • Loading branch information
kinnala authored Nov 21, 2021
1 parent ddeac3e commit d9d5847
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 8 deletions.
10 changes: 8 additions & 2 deletions skfem/mesh/mesh.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import importlib

from dataclasses import dataclass, replace
from typing import Callable, Dict, List, Optional, Tuple, Type, Union
Expand Down Expand Up @@ -906,5 +907,10 @@ def element_finder(self, mapping=None):

def draw(self, *args, **kwargs):
"""Convenience wrapper for vedo."""
from skfem.visuals.vedo import draw
return draw(self, *args, **kwargs)
if 'visuals' in kwargs:
visuals = kwargs['visuals']
del kwargs['visuals']
else:
visuals = 'vedo'
mod = importlib.import_module('skfem.visuals.{}'.format(visuals))
return mod.draw(self, *args, **kwargs)
6 changes: 6 additions & 0 deletions skfem/visuals/matplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def draw_meshtet(m: MeshTet1, **kwargs) -> Axes:
ax.plot_trisurf(m.p[0], m.p[1], m.p[2],
triangles=indexing, cmap=plt.cm.viridis, edgecolor='k')
ax.set_axis_off()
ax.show = lambda: plt.show()
return ax


Expand Down Expand Up @@ -118,6 +119,7 @@ def draw_mesh2d(m: Mesh2D, **kwargs) -> Axes:
for itr in range(m.t.shape[1]):
ax.text(mx[itr], my[itr], str(itr))

ax.show = lambda: plt.show()
return ax


Expand All @@ -141,6 +143,7 @@ def plot_meshline(m: MeshLine1, z: ndarray, **kwargs):
ix = np.argsort(m.p[0])
ax.plot(m.p[0][ix], z[ix], color)

ax.show = lambda: plt.show()
return ax


Expand Down Expand Up @@ -194,6 +197,8 @@ def plot_meshtri(m: MeshTri1, z: ndarray, **kwargs) -> Axes:

if "colorbar" in kwargs:
plt.colorbar(im)

ax.show = lambda: plt.show()
return ax


Expand Down Expand Up @@ -255,6 +260,7 @@ def plot3_meshtri(m: MeshTri1, z: ndarray, **kwargs) -> Axes:
cmap=plt.cm.viridis,
antialiased=False)

ax.show = lambda: plt.show()
return ax


Expand Down
31 changes: 25 additions & 6 deletions skfem/visuals/svg.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
"""Drawing meshes using svg."""

import webbrowser
from functools import singledispatch
from dataclasses import dataclass
from http.server import HTTPServer, BaseHTTPRequestHandler

import numpy as np
from numpy import ndarray

from ..assembly import InteriorBasis
from ..assembly import CellBasis
from ..mesh import Mesh2D


Expand All @@ -29,6 +31,15 @@ def points_to_figure(p, kwargs):
return p, width, height, stroke


class Server(BaseHTTPRequestHandler):

def do_GET(self):
self.send_response(200)
self.send_header("Content-type", "text/html")
self.end_headers()
self.wfile.write(self.svg.encode("utf8"))


@dataclass
class SvgPlot:

Expand All @@ -37,6 +48,14 @@ class SvgPlot:
def _repr_svg_(self) -> str:
return self.svg

def show(self, port=8000):
server = Server
server.svg = self.svg
url = "http://localhost:{}".format(port)
print("Serving the plot at " + url)
webbrowser.open_new_tab(url)
HTTPServer(("localhost", port), Server).handle_request()


@singledispatch
def draw(m, **kwargs) -> SvgPlot:
Expand Down Expand Up @@ -81,13 +100,13 @@ def draw_mesh2d(m: Mesh2D, **kwargs) -> SvgPlot:

@draw.register(Mesh2D)
def draw_geometry2d(m: Mesh2D, **kwargs) -> SvgPlot:
nrefs = kwargs["nrefs"] if "nrefs" in kwargs else 1
nrefs = kwargs["nrefs"] if "nrefs" in kwargs else 0
m = m._splitref(nrefs)
return draw_mesh2d(m, **kwargs)


@draw.register(InteriorBasis)
def draw_basis(ib: InteriorBasis, **kwargs) -> SvgPlot:
@draw.register(CellBasis)
def draw_basis(ib: CellBasis, **kwargs) -> SvgPlot:
nrefs = kwargs["nrefs"] if "nrefs" in kwargs else 2
m, _ = ib.refinterp(ib.mesh.p[0], nrefs=nrefs)
return draw(m, boundaries_only=True, **kwargs)
Expand Down Expand Up @@ -138,8 +157,8 @@ def plot_mesh2d(m: Mesh2D, x: ndarray, **kwargs) -> SvgPlot:
elems))


@plot.register(InteriorBasis)
def plot_basis(ib: InteriorBasis, x: ndarray, **kwargs) -> SvgPlot:
@plot.register(CellBasis)
def plot_basis(ib: CellBasis, x: ndarray, **kwargs) -> SvgPlot:
nrefs = kwargs["nrefs"] if "nrefs" in kwargs else 0
m, X = ib.refinterp(x, nrefs=nrefs)
return plot(m, X, **kwargs)

0 comments on commit d9d5847

Please sign in to comment.