Skip to content

Commit

Permalink
Merge pull request #1737 from RissyRan:profileUploading
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 483414670
  • Loading branch information
copybara-github committed Oct 24, 2022
2 parents 254b37e + 27465bf commit bdaca1f
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 17 deletions.
31 changes: 22 additions & 9 deletions google/cloud/aiplatform/tensorboard/uploader.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,17 @@ def create_experiment(self):
additional_senders=self._additional_senders,
)

def _should_profile(self) -> bool:
"""Indicate if profile plugin should be enabled."""
if "profile" in self._allowed_plugins:
if not self._one_shot:
raise ValueError(
"Profile plugin currently only supported for one shot."
)
logger.info("Profile plugin is enalbed.")
return True
return False

def _create_additional_senders(self) -> Dict[str, uploader_utils.RequestSender]:
"""Create any additional senders for non traditional event files.
Expand All @@ -338,11 +349,7 @@ def _create_additional_senders(self) -> Dict[str, uploader_utils.RequestSender]:
plugin. If there are any items that cannot be searched for via the
`_BatchedRequestSender`, add them here.
"""
if "profile" in self._allowed_plugins:
if not self._one_shot:
raise ValueError(
"Profile plugin currently only supported for one shot."
)
if self._should_profile():
source_bucket = uploader_utils.get_source_bucket(self._logdir)

self._additional_senders["profile"] = functools.partial(
Expand Down Expand Up @@ -458,6 +465,11 @@ def _upload_once(self):
run_to_events = {
self._run_name_prefix + k: v for k, v in run_to_events.items()
}

# Add a profile event to trigger send_request in _additional_senders
if self._should_profile():
run_to_events[self._run_name_prefix] = None

with self._tracker.send_tracker():
self._dispatcher.dispatch_requests(run_to_events)

Expand Down Expand Up @@ -714,10 +726,11 @@ def dispatch_requests(
"""
for (run_name, events) in run_to_events.items():
self._dispatch_additional_senders(run_name)
for event in events:
_filter_graph_defs(event)
for value in event.summary.value:
self._request_sender.send_request(run_name, event, value)
if events is not None:
for event in events:
_filter_graph_defs(event)
for value in event.summary.value:
self._request_sender.send_request(run_name, event, value)
self._request_sender.flush()


Expand Down
33 changes: 25 additions & 8 deletions tests/unit/aiplatform/test_uploader.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@ def _create_uploader(
verbosity=0, # Use 0 to minimize littering the test output.
one_shot=None,
allowed_plugins=_SCALARS_HISTOGRAMS_AND_GRAPHS,
run_name_prefix=None,
):
if writer_client is _USE_DEFAULT:
writer_client = _create_mock_client()
Expand Down Expand Up @@ -242,6 +243,7 @@ def _create_uploader(
description=description,
verbosity=verbosity,
one_shot=one_shot,
run_name_prefix=run_name_prefix,
)


Expand Down Expand Up @@ -1053,14 +1055,29 @@ def create_time_series(tensorboard_time_series, parent=None):
self.assertLen(actual_blobs, 2)

def test_add_profile_plugin(self):
uploader = _create_uploader(
_create_mock_client(),
_TEST_LOG_DIR_NAME,
one_shot=True,
allowed_plugins=frozenset(("profile",)),
)
uploader.create_experiment()
self.assertIn("profile", uploader._dispatcher._additional_senders)
run_name = "profile_test_run"
with tempfile.TemporaryDirectory() as logdir:
prof_path = os.path.join(
logdir, run_name, profile_uploader.ProfileRequestSender.PROFILE_PATH
)
os.makedirs(prof_path)

uploader = _create_uploader(
_create_mock_client(),
logdir,
one_shot=True,
allowed_plugins=frozenset(("profile",)),
run_name_prefix=run_name,
)

uploader.create_experiment()
uploader._upload_once()
senders = uploader._dispatcher._additional_senders
self.assertIn("profile", senders.keys())

profile_sender = senders["profile"]
self.assertIn(run_name, profile_sender._run_to_profile_loaders)
self.assertIn(run_name, profile_sender._run_to_file_request_sender)


class BatchedRequestSenderTest(tf.test.TestCase):
Expand Down

0 comments on commit bdaca1f

Please sign in to comment.