diff --git a/skfem/mesh/mesh.py b/skfem/mesh/mesh.py index 4cf22bd24..93861f86a 100644 --- a/skfem/mesh/mesh.py +++ b/skfem/mesh/mesh.py @@ -1,4 +1,5 @@ import logging +import importlib from dataclasses import dataclass, replace from typing import Callable, Dict, List, Optional, Tuple, Type, Union @@ -906,5 +907,10 @@ def element_finder(self, mapping=None): def draw(self, *args, **kwargs): """Convenience wrapper for vedo.""" - from skfem.visuals.vedo import draw - return draw(self, *args, **kwargs) + if 'visuals' in kwargs: + visuals = kwargs['visuals'] + del kwargs['visuals'] + else: + visuals = 'vedo' + mod = importlib.import_module('skfem.visuals.{}'.format(visuals)) + return mod.draw(self, *args, **kwargs) diff --git a/skfem/visuals/matplotlib.py b/skfem/visuals/matplotlib.py index ccffa70ae..5eb3c8357 100644 --- a/skfem/visuals/matplotlib.py +++ b/skfem/visuals/matplotlib.py @@ -39,6 +39,7 @@ def draw_meshtet(m: MeshTet1, **kwargs) -> Axes: ax.plot_trisurf(m.p[0], m.p[1], m.p[2], triangles=indexing, cmap=plt.cm.viridis, edgecolor='k') ax.set_axis_off() + ax.show = lambda: plt.show() return ax @@ -118,6 +119,7 @@ def draw_mesh2d(m: Mesh2D, **kwargs) -> Axes: for itr in range(m.t.shape[1]): ax.text(mx[itr], my[itr], str(itr)) + ax.show = lambda: plt.show() return ax @@ -141,6 +143,7 @@ def plot_meshline(m: MeshLine1, z: ndarray, **kwargs): ix = np.argsort(m.p[0]) ax.plot(m.p[0][ix], z[ix], color) + ax.show = lambda: plt.show() return ax @@ -194,6 +197,8 @@ def plot_meshtri(m: MeshTri1, z: ndarray, **kwargs) -> Axes: if "colorbar" in kwargs: plt.colorbar(im) + + ax.show = lambda: plt.show() return ax @@ -255,6 +260,7 @@ def plot3_meshtri(m: MeshTri1, z: ndarray, **kwargs) -> Axes: cmap=plt.cm.viridis, antialiased=False) + ax.show = lambda: plt.show() return ax diff --git a/skfem/visuals/svg.py b/skfem/visuals/svg.py index e0c748437..c910d692f 100644 --- a/skfem/visuals/svg.py +++ b/skfem/visuals/svg.py @@ -1,12 +1,14 @@ """Drawing meshes using svg.""" +import webbrowser from functools import singledispatch from dataclasses import dataclass +from http.server import HTTPServer, BaseHTTPRequestHandler import numpy as np from numpy import ndarray -from ..assembly import InteriorBasis +from ..assembly import CellBasis from ..mesh import Mesh2D @@ -29,6 +31,15 @@ def points_to_figure(p, kwargs): return p, width, height, stroke +class Server(BaseHTTPRequestHandler): + + def do_GET(self): + self.send_response(200) + self.send_header("Content-type", "text/html") + self.end_headers() + self.wfile.write(self.svg.encode("utf8")) + + @dataclass class SvgPlot: @@ -37,6 +48,14 @@ class SvgPlot: def _repr_svg_(self) -> str: return self.svg + def show(self, port=8000): + server = Server + server.svg = self.svg + url = "http://localhost:{}".format(port) + print("Serving the plot at " + url) + webbrowser.open_new_tab(url) + HTTPServer(("localhost", port), Server).handle_request() + @singledispatch def draw(m, **kwargs) -> SvgPlot: @@ -81,13 +100,13 @@ def draw_mesh2d(m: Mesh2D, **kwargs) -> SvgPlot: @draw.register(Mesh2D) def draw_geometry2d(m: Mesh2D, **kwargs) -> SvgPlot: - nrefs = kwargs["nrefs"] if "nrefs" in kwargs else 1 + nrefs = kwargs["nrefs"] if "nrefs" in kwargs else 0 m = m._splitref(nrefs) return draw_mesh2d(m, **kwargs) -@draw.register(InteriorBasis) -def draw_basis(ib: InteriorBasis, **kwargs) -> SvgPlot: +@draw.register(CellBasis) +def draw_basis(ib: CellBasis, **kwargs) -> SvgPlot: nrefs = kwargs["nrefs"] if "nrefs" in kwargs else 2 m, _ = ib.refinterp(ib.mesh.p[0], nrefs=nrefs) return draw(m, boundaries_only=True, **kwargs) @@ -138,8 +157,8 @@ def plot_mesh2d(m: Mesh2D, x: ndarray, **kwargs) -> SvgPlot: elems)) -@plot.register(InteriorBasis) -def plot_basis(ib: InteriorBasis, x: ndarray, **kwargs) -> SvgPlot: +@plot.register(CellBasis) +def plot_basis(ib: CellBasis, x: ndarray, **kwargs) -> SvgPlot: nrefs = kwargs["nrefs"] if "nrefs" in kwargs else 0 m, X = ib.refinterp(x, nrefs=nrefs) return plot(m, X, **kwargs)