Skip to content

Commit

Permalink
log_image: Support matplotlib.figure.Figure as input. (#658)
Browse files Browse the repository at this point in the history
Closes #224
  • Loading branch information
daavoo authored Aug 14, 2023
1 parent e8d008e commit ee7968b
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 6 deletions.
19 changes: 13 additions & 6 deletions src/dvclive/plots/image.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from pathlib import Path, PurePath

from dvclive.utils import isinstance_without_import

from .base import Data


Expand All @@ -17,20 +19,25 @@ def output_path(self) -> Path:
def could_log(val: object) -> bool:
acceptable = {
("numpy", "ndarray"),
("matplotlib.figure", "Figure"),
("PIL.Image", "Image"),
}
for cls in type(val).mro():
if (cls.__module__, cls.__name__) in acceptable:
if any(isinstance_without_import(val, *cls) for cls in acceptable):
return True
if isinstance(val, (PurePath, str)):
return True
return False

def dump(self, val, **kwargs) -> None: # noqa: ARG002
if val.__class__.__module__ == "numpy":
if isinstance_without_import(val, "numpy", "ndarray"):
from PIL import Image as ImagePIL

pil_image = ImagePIL.fromarray(val)
else:
pil_image = val
pil_image.save(self.output_path)
ImagePIL.fromarray(val).save(self.output_path)
elif isinstance_without_import(val, "matplotlib.figure", "Figure"):
import matplotlib.pyplot as plt

plt.savefig(self.output_path)
plt.close(val)
elif isinstance_without_import(val, "PIL.Image", "Image"):
val.save(self.output_path)
7 changes: 7 additions & 0 deletions src/dvclive/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,3 +148,10 @@ def clean_and_copy_into(src: StrPath, dst: StrPath) -> str:
shutil.copy2(src, dst_path)

return str(dst_path)


def isinstance_without_import(val, module, name):
for cls in type(val).mro():
if (cls.__module__, cls.__name__) == (module, name):
return True
return False
15 changes: 15 additions & 0 deletions tests/plots/test_image.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import matplotlib.pyplot as plt
import numpy as np
import pytest
from PIL import Image
Expand Down Expand Up @@ -100,3 +101,17 @@ def test_custom_class(tmp_dir):
live.log_image("image.png", extended_img)

assert (tmp_dir / live.plots_dir / LiveImage.subfolder / "image.png").exists()


def test_matplotlib(tmp_dir):
live = Live()
fig, ax = plt.subplots()
ax.plot([1, 2, 3, 4])

assert plt.fignum_exists(fig.number)

live.log_image("image.png", fig)

assert not plt.fignum_exists(fig.number)

assert (tmp_dir / live.plots_dir / LiveImage.subfolder / "image.png").exists()

0 comments on commit ee7968b

Please sign in to comment.