diff --git a/torchx/runner/test/api_test.py b/torchx/runner/test/api_test.py index a53bd314b..3a877cbcb 100644 --- a/torchx/runner/test/api_test.py +++ b/torchx/runner/test/api_test.py @@ -105,9 +105,9 @@ def test_validate_invalid_replicas(self, _) -> None: with self.assertRaises(ValueError): runner.run(app, scheduler="local_dir") - @patch("torchx.util.session.uuid") + @patch("torchx.util.session.uuid.uuid4") def test_session_id(self, uuid_mock: MagicMock, record_mock: MagicMock) -> None: - uuid_mock.uuid4.return_value = "test_session_id" + uuid_mock.return_value = "test_session_id" test_file = self.tmpdir / "test_file" with self.get_runner() as runner: @@ -128,15 +128,15 @@ def test_session_id(self, uuid_mock: MagicMock, record_mock: MagicMock) -> None: none_throws(runner.wait(app_handle_2, wait_interval=0.1)) self.assertEqual(get_session_id(), "test_session_id") - uuid_mock.uuid4.assert_called_once() + uuid_mock.assert_called_once() record_mock.assert_called() for i in range(record_mock.call_count): event = record_mock.call_args_list[i].args[0] self.assertEqual(event.session, "test_session_id") - @patch("torchx.util.session.uuid") + @patch("torchx.util.session.uuid.uuid4") def test_run(self, uuid_mock: MagicMock, _) -> None: - uuid_mock.uuid4.return_value = "test_session_id" + uuid_mock.return_value = "test_session_id" test_file = self.tmpdir / "test_file" with self.get_runner() as runner: