Skip to content

Commit

Permalink
update LCDB.debug to extract all tracebacks, error messages, and configs
Browse files Browse the repository at this point in the history
  • Loading branch information
ChengYan7 committed Oct 15, 2024
1 parent 9a6914b commit a654f82
Showing 1 changed file with 94 additions and 2 deletions.
96 changes: 94 additions & 2 deletions publications/2023-neurips/lcdb/db/_database.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import os
import pathlib

import re
import pandas as pd
from lcdb.db._repository import Repository
from lcdb.db._util import get_path_to_lcdb, CountAwareGenerator
Expand Down Expand Up @@ -176,4 +176,96 @@ def generator():
if workflows is not None and len(workflows) == 1:
return dfs_per_workflow[workflows[0]] if workflows[0] in dfs_per_workflow else None
else:
return dfs_per_workflow
return dfs_per_workflow


def debug(
    self,
    repositories=None,
    campaigns=None,
    workflows=None,
    openmlids=None,
    workflow_seeds=None,
    test_seeds=None,
    validation_seeds=None,
    show_progress=False
):
    """
    Collect all tracebacks in the queried results, together with the error
    message of each traceback and the configurations that produced them.

    The filters ``campaigns``, ``workflows``, ``openmlids``,
    ``workflow_seeds``, ``test_seeds`` and ``validation_seeds`` are forwarded
    unchanged to ``query_results_as_stream`` of every selected repository.

    :param repositories: repository name or iterable of repository names to
        search; ``None`` searches every repository known to this object.
    :param workflows: a single workflow name or a list of workflow names; a
        single string is wrapped into a one-element list.
    :param show_progress: if true, the tqdm progress bar is enabled while the
        result dataframes are consumed.
    :return: dict with three entries, each ``None`` if nothing was found:
        ``"configs"`` is a dataframe with the distinct ``p:`` columns of the
        rows that carry a traceback (de-duplicated per result dataframe),
        ``"tracebacks"`` is a series with every non-null ``m:traceback``
        value, and ``"errors"`` is a series with one entry per traceback, in
        the same order as ``"tracebacks"``.
    :raises Exception: if a requested repository name does not exist.
    """
    if not self.loaded:
        self._load()

    if repositories is None:
        repositories = list(self.repositories.values())
    else:
        # A bare string would otherwise be split into single characters.
        if isinstance(repositories, str):
            repositories = [repositories]
        # De-duplicate while keeping the order in which the caller listed them.
        requested_repository_names = list(dict.fromkeys(repositories))
        unknown_repository_names = set(requested_repository_names).difference(
            self.repositories.keys()
        )
        if unknown_repository_names:
            raise Exception(
                "The following repositories were included in the query but do not exist in this LCDB_debug: "
                f"{unknown_repository_names}"
            )
        repositories = [self.repositories[k] for k in requested_repository_names]

    if isinstance(workflows, str):
        workflows = [workflows]

    # One result stream per repository that actually exists.
    result_generators = [
        repository.query_results_as_stream(
            campaigns=campaigns,
            workflows=workflows,
            openmlids=openmlids,
            workflow_seeds=workflow_seeds,
            test_seeds=test_seeds,
            validation_seeds=validation_seeds,
        )
        for repository in repositories
        if repository.exists()
    ]

    def generator():
        for result_stream in result_generators:
            yield from result_stream

    gen = CountAwareGenerator(sum(len(g) for g in result_generators), generator())

    # Matches the "<Something>Error: <message>" line of a traceback.
    error_pattern = re.compile(r'(\w+Error): (.*)')

    tracebacks, configs, errors = [], [], []

    for df in tqdm(gen, disable=not show_progress):
        if "m:traceback" not in df.columns:
            print("Error: no traceback column in dataframe")
            continue

        traceback_rows = df[df["m:traceback"].notna()]
        if traceback_rows.empty:
            continue

        # Select from the filtered rows rather than by index label, so that
        # repeated index labels cannot pull in rows without a traceback.
        config_cols = [c for c in df.columns if c.startswith("p:")]
        configs.append(
            traceback_rows[config_cols].drop_duplicates().reset_index(drop=True)
        )

        traceback_messages = traceback_rows["m:traceback"]
        tracebacks.append(traceback_messages)

        # One error message per traceback keeps "errors" aligned with
        # "tracebacks"; fall back to the full traceback text when no
        # "<Something>Error: ..." line is present.
        for traceback_message in traceback_messages:
            traceback_str = str(traceback_message)
            match = error_pattern.search(traceback_str)
            errors.append(match.group(0) if match is not None else traceback_str)

    return {
        "configs": pd.concat(configs, ignore_index=True) if configs else None,
        "tracebacks": pd.concat(tracebacks, ignore_index=True) if tracebacks else None,
        "errors": pd.Series(errors) if errors else None
    }

0 comments on commit a654f82

Please sign in to comment.