Skip to content
This repository has been archived by the owner on Dec 13, 2024. It is now read-only.

Commit

Permalink
Updated functions.
Browse files Browse the repository at this point in the history
  • Loading branch information
Exitare committed Aug 8, 2022
1 parent b6242ba commit 34f03dc
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 26 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ run_handler: RunHandler = RunHandler()
# Only deletes the first occurrence of the given run name. If multiple runs share the same name,
# this command needs to be executed multiple times
run_handler.delete_runs_and_child_runs(experiment_id=exp_id, run_name="My Run")
run_handler.delete_run(experiment_id=exp_id, run_name="My Run")
```
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "mlflow_wrapper"
version = "0.0.1.5"
version = "0.0.1.6"
authors = [
{ name="Exitare", email="[email protected]" },
]
Expand Down
48 changes: 33 additions & 15 deletions src/mlflow_wrapper/run_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,13 +124,14 @@ def get_run_by_name(self, experiment_id: str, run_name: str, parent_run_id: str
# Run not found
return None

def get_run_and_child_runs(self, experiment_id: str, run_name: str) -> List:
def get_run(self, experiment_id: str, run_name: str, include_children: bool = True) -> List:
"""
Get all runs for a specific experiment id and run name.
Key is the parent run, values are the children runs
@param experiment_id:
@param run_name:
@return: Returns a dictionary with the parent run as key and the children runs as values
@param include_children:
@return: Returns a list of runs. The first run is the parent run
Returned values are run objects
"""

Expand All @@ -141,7 +142,12 @@ def get_run_and_child_runs(self, experiment_id: str, run_name: str) -> List:
return runs

parent_run: Run = self._client.get_run(parent_run_id)
runs.append(parent_run)

if parent_run.info.lifecycle_stage == 'active':
runs.append(parent_run)

if not include_children:
return runs

# Run not cached
all_run_infos: [] = reversed(self._client.list_run_infos(experiment_id))
Expand All @@ -155,18 +161,30 @@ def get_run_and_child_runs(self, experiment_id: str, run_name: str) -> List:

return runs

def delete_runs_and_child_runs(self, experiment_id: str, run_name: str):
    """
    Delete the run with the given name together with its child runs.

    :param experiment_id: The experiment id where the run is located
    :param run_name: The name of the parent run to be deleted
    :return:
    """
    found = self.get_run_and_child_runs(experiment_id=experiment_id, run_name=run_name)

    # An empty result simply means there is nothing to iterate over
    for entry in found:
        # Drop the run from the local cache, whether or not it is cached
        self.__runs.pop(entry.info.run_id, None)
        # Runs that are not active are not sent to mlflow for deletion
        if entry.info.lifecycle_stage != 'active':
            continue
        self._client.delete_run(entry.info.run_id)
def delete_run(self, experiment_id: str, run_name: str, delete_children: bool = True):
    """
    Deletes the run with the given name. If multiple runs share the same name, only the first one is deleted.
    Does nothing when no run with the given name is found.

    :param experiment_id: The experiment id where the run is located
    :param run_name: The name of the run to be deleted
    :param delete_children: Should the children runs also be deleted?
    :return: None
    """
    if delete_children:
        runs = self.get_run(experiment_id=experiment_id, run_name=run_name, include_children=True)
    else:
        # get_run_by_name returns None when the run is not found. Wrapping that None in a
        # list would crash on `run.info` below, so treat it as "nothing to delete".
        single_run = self.get_run_by_name(experiment_id=experiment_id, run_name=run_name)
        runs = [] if single_run is None else [single_run]

    for run in runs:
        # Remove run from local cache
        self.__runs.pop(run.info.run_id, None)
        # Delete run from mlflow; runs that are not active are left untouched
        if run.info.lifecycle_stage == 'active':
            self._client.delete_run(run.info.run_id)

def download_artifacts(self, save_path: Union[Path, str], run: Run = None, runs: [] = None,
mlflow_folder: str = None) -> dict:
Expand Down
16 changes: 9 additions & 7 deletions tests/test_run_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,14 @@ def test_get_run_by_name(self):
experiment_id: str = experiment_handler.get_experiment_id_by_name(experiment_name="Library Test Experiment")
run_name: str = "Test run"

run_handler.delete_runs_and_child_runs(experiment_id=experiment_id, run_name=run_name)
run_handler.delete_run(experiment_id=experiment_id, run_name=run_name)
with mlflow.start_run(experiment_id=experiment_id, run_name=run_name) as run:
mlflow.log_param("TestRun", 1)

run: Run = run_handler.get_run_by_name(experiment_id=experiment_id, run_name=run_name)
self.assertIsNotNone(run)

run_handler.delete_runs_and_child_runs(experiment_id=experiment_id, run_name=run_name)
run_handler.delete_run(experiment_id=experiment_id, run_name=run_name)

def test_get_run_by_id(self):
run_handler: RunHandler = RunHandler()
Expand All @@ -43,7 +43,7 @@ def test_get_run_by_id(self):
run: Run = run_handler.get_run_by_id(experiment_id=experiment_id, run_id=run.info.run_id)
self.assertIsNotNone(run)

run_handler.delete_runs_and_child_runs(experiment_id=experiment_id, run_name=run_name)
run_handler.delete_run(experiment_id=experiment_id, run_name=run_name)

def test_get_run_id_by_name(self):
run_handler: RunHandler = RunHandler()
Expand All @@ -58,7 +58,7 @@ def test_get_run_id_by_name(self):
run_id: str = run_handler.get_run_id_by_name(experiment_id=experiment_id, run_name=run_name)
self.assertIsNotNone(run_id)

run_handler.delete_runs_and_child_runs(experiment_id=experiment_id, run_name=run_name)
run_handler.delete_run(experiment_id=experiment_id, run_name=run_name)

def test_get_run_and_child_runs(self):
run_handler: RunHandler = RunHandler()
Expand All @@ -73,10 +73,12 @@ def test_get_run_and_child_runs(self):
with mlflow.start_run(experiment_id=experiment_id, run_name="child_run", nested=True) as child_run:
mlflow.log_param("Child Run", 1)

runs: List = run_handler.get_run_and_child_runs(experiment_id=experiment_id, run_name=run_name)
runs: List = run_handler.get_run(experiment_id=experiment_id, run_name=run_name, include_children=True)
self.assertEqual(2, len(runs))

run_handler.delete_runs_and_child_runs(experiment_id=experiment_id, run_name=run_name)
run_handler.delete_run(experiment_id=experiment_id, run_name=run_name)
runs: List = run_handler.get_run(experiment_id=experiment_id, run_name=run_name, include_children=True)
self.assertEqual(0, len(runs))

def test_download_artifacts(self):
run_handler: RunHandler = RunHandler()
Expand Down Expand Up @@ -112,7 +114,7 @@ def test_download_artifacts(self):
shutil.rmtree(store_folder)
shutil.rmtree(save_path)

run_handler.delete_runs_and_child_runs(experiment_id=experiment_id, run_name=run_name)
run_handler.delete_run(experiment_id=experiment_id, run_name=run_name)


if __name__ == '__main__':
Expand Down
4 changes: 2 additions & 2 deletions tests/test_upload_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def test_upload_dataframe(self):
upload_handler.upload_dataframe(data=pd.DataFrame(columns=['A', 'B']), file_name="Test Upload.csv")

time.sleep(1)
run_handler.delete_runs_and_child_runs(experiment_id=experiment_id, run_name="Upload Test")
run_handler.delete_run(experiment_id=experiment_id, run_name="Upload Test")

shutil.rmtree(save_path)

Expand All @@ -47,7 +47,7 @@ def test_upload_file(self):
upload_handler.upload_file(file_name="test.csv")

time.sleep(1)
run_handler.delete_runs_and_child_runs(experiment_id=experiment_id, run_name=run_name)
run_handler.delete_run(experiment_id=experiment_id, run_name=run_name)

shutil.rmtree(save_path)

Expand Down

0 comments on commit 34f03dc

Please sign in to comment.