Skip to content

Commit

Permalink
fix(studio): wait for studio metrics publish to complete on end (#827)
Browse files Browse the repository at this point in the history
  • Loading branch information
shcheklein authored Jun 10, 2024
1 parent 2672989 commit b987e66
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 0 deletions.
7 changes: 7 additions & 0 deletions src/dvclive/live.py
Original file line number Diff line number Diff line change
Expand Up @@ -905,6 +905,11 @@ def worker():

self._studio_queue.put(self)

def _wait_for_studio_updates_posted(self):
if self._studio_queue:
logger.debug("Waiting for studio updates to be posted")
self._studio_queue.join()

def end(self):
"""
Signals that the current experiment has ended.
Expand Down Expand Up @@ -946,6 +951,8 @@ def end(self):

self.save_dvc_exp()

self._wait_for_studio_updates_posted()

# Mark experiment as done
post_to_studio(self, "done")

Expand Down
25 changes: 25 additions & 0 deletions tests/test_post_to_studio.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from pathlib import Path

import pytest
import time
from dvc.env import DVC_EXP_GIT_REMOTE
from dvc_studio_client import DEFAULT_STUDIO_URL
from dvc_studio_client.env import DVC_STUDIO_REPO_URL, DVC_STUDIO_TOKEN
Expand Down Expand Up @@ -186,6 +187,30 @@ def test_post_to_studio_done_only_once(tmp_dir, mocked_dvc_repo, mocked_studio_p
assert expected_done_calls == actual_done_calls


def test_studio_updates_posted_on_end(tmp_path, mocked_dvc_repo, mocked_studio_post):
mocked_post, valid_response = mocked_studio_post
metrics_file = tmp_path / "metrics.json"
metrics_content = "metrics"

def long_post(*args, **kwargs):
# in case of `data` `long_post` should be called from a separate thread,
# meanwhile main thread go forward without slowing down, so if there is no
# some kind of wait in the Live main thread, then it will complete before
# we even can have a chance to write the file below
if kwargs["json"]["type"] == "data":
time.sleep(1)
metrics_file.write_text(metrics_content)

return valid_response

mocked_post.side_effect = long_post

with Live() as live:
live.log_metric("foo", 1)

assert metrics_file.read_text() == metrics_content


@pytest.mark.studio()
def test_post_to_studio_skip_start_and_done_on_env_var(
tmp_dir, mocked_dvc_repo, mocked_studio_post, monkeypatch
Expand Down

0 comments on commit b987e66

Please sign in to comment.