Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Support cedar-lean #152

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 2 additions & 8 deletions src/lean_dojo/data_extraction/ExtractData.lean
Original file line number Diff line number Diff line change
Expand Up @@ -373,10 +373,7 @@ private def visitInfo (ctx : ContextInfo) (i : Info) (parent : InfoTree) (env :
private partial def traverseTree (ctx: ContextInfo) (tree : InfoTree)
(parent : InfoTree) (env : Environment) : TraceM Unit := do
match tree with
| .context ctx' t =>
match ctx'.mergeIntoOuter? ctx with
| some ctx' => traverseTree ctx' t tree env
| none => panic! "fail to synthesis contextInfo when traversing infoTree"
| .context ctx' t => traverseTree ctx' t tree env
| .node i children =>
visitInfo ctx i parent env
for x in children do
Expand All @@ -386,10 +383,7 @@ private partial def traverseTree (ctx: ContextInfo) (tree : InfoTree)

private def traverseTopLevelTree (tree : InfoTree) (env : Environment) : TraceM Unit := do
match tree with
| .context ctx t =>
match ctx.mergeIntoOuter? none with
| some ctx => traverseTree ctx t tree env
| none => panic! "fail to synthesis contextInfo for top-level infoTree"
| .context ctx t => traverseTree ctx t tree env
| _ => pure ()


Expand Down
3 changes: 2 additions & 1 deletion src/lean_dojo/data_extraction/build_lean4_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def main() -> None:

num_procs = int(os.environ["NUM_PROCS"])
repo_name = args.repo_name
os.chdir(repo_name)
os.chdir(repo_name+"/cedar-lean")

# Build the repo using lake.
logger.info(f"Building {repo_name}")
Expand All @@ -186,6 +186,7 @@ def main() -> None:
if not args.no_deps:
dirs_to_monitor.append(packages_path)
logger.info(f"Tracing {repo_name}")
run_cmd("mv ../ExtractData.lean ExtractData.lean", capture_output=True)
with launch_progressbar(dirs_to_monitor):
cmd = f"lake env lean --threads {num_procs} --run ExtractData.lean"
if args.no_deps:
Expand Down
20 changes: 19 additions & 1 deletion src/lean_dojo/data_extraction/lean.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,10 @@ class LeanGitRepo:
You can also use tags such as ``v3.5.0``. They will be converted to commit hashes.
"""

inner_lake_path: str = None
"""If the lake project isn't at the repository root, the path to the lake project root. None otherwise.
"""

repo: Repository = field(init=False, repr=False)
"""A :class:`github.Repository` object.
"""
Expand All @@ -425,6 +429,9 @@ def __post_init__(self) -> None:
raise ValueError(f"{self.url} is not a valid URL")
object.__setattr__(self, "url", normalize_url(self.url))
object.__setattr__(self, "repo", url_to_repo(self.url))
if self.inner_lake_path is not None and (self.inner_lake_path.startswith("/") or self.inner_lake_path.endswith("/")):
raise ValueError(f"{self.inner_lake_path} should not start or end with '/'")
object.__setattr__(self, "inner_lake_path", self.inner_lake_path)

# Convert tags or branches to commit hashes
if not (len(self.commit) == 40 and _COMMIT_REGEX.fullmatch(self.commit)):
Expand Down Expand Up @@ -458,6 +465,13 @@ def from_path(cls, path: Path) -> "LeanGitRepo":
url, commit = get_repo_info(path)
return cls(url, commit)

@property
def path_to_lake_proj(self) -> str:
    """Return the path of the Lake project, starting with the repo name.

    Returns:
        ``self.repo.name`` alone when ``inner_lake_path`` is ``None``;
        otherwise the repo name and ``inner_lake_path`` joined by ``/``.
    """
    if self.inner_lake_path is None:
        return self.repo.name
    return f"{self.repo.name}/{self.inner_lake_path}"

@property
def name(self) -> str:
    """Return the name of the underlying repository object (``self.repo.name``)."""
    repository = self.repo
    return repository.name
Expand Down Expand Up @@ -576,11 +590,15 @@ def get_license(self) -> Optional[str]:
def _get_config_url(self, filename: str) -> str:
assert "github.com" in self.url, f"Unsupported URL: {self.url}"
url = self.url.replace("github.com", "raw.githubusercontent.com")
return f"{url}/{self.commit}/{filename}"
if self.inner_lake_path is None:
return f"{url}/{self.commit}/{filename}"
else:
return f"{url}/{self.commit}/{self.inner_lake_path}/{filename}"

def get_config(self, filename: str, num_retries: int = 2) -> Dict[str, Any]:
"""Return the repo's files."""
config_url = self._get_config_url(filename)
print(f"Get config with URL: {config_url}")
content = read_url(config_url, num_retries)
if filename.endswith(".toml"):
return toml.loads(content)
Expand Down
3 changes: 2 additions & 1 deletion src/lean_dojo/data_extraction/trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def _trace(repo: LeanGitRepo, build_deps: bool) -> None:
logger.debug(f"Tracing {repo}")
container = get_container()
mts = {
# TODO: should we mount a different directory here?
Path.cwd() / repo.name: f"/workspace/{repo.name}",
LEAN4_BUILD_SCRIPT_PATH: f"/workspace/{LEAN4_BUILD_SCRIPT_PATH.name}",
LEAN4_DATA_EXTRACTOR_PATH: f"/workspace/{repo.name}/{LEAN4_DATA_EXTRACTOR_PATH.name}",
Expand Down Expand Up @@ -80,7 +81,7 @@ def get_traced_repo_path(repo: LeanGitRepo, build_deps: bool = True) -> Path:
with working_directory() as tmp_dir:
logger.debug(f"Working in the temporary directory {tmp_dir}")
_trace(repo, build_deps)
traced_repo = TracedRepo.from_traced_files(tmp_dir / repo.name, build_deps)
traced_repo = TracedRepo.from_traced_files(tmp_dir / repo.path_to_lake_proj, build_deps)
traced_repo.save_to_disk()
path = cache.store(tmp_dir / repo.name)
else:
Expand Down
2 changes: 1 addition & 1 deletion src/lean_dojo/data_extraction/traced_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1151,7 +1151,7 @@ def load_from_disk(
cls, root_dir: Union[str, Path], build_deps: bool = True
) -> "TracedRepo":
"""Load a traced repo from :file:`*.trace.xml` files."""
root_dir = Path(root_dir).resolve()
root_dir = Path(root_dir / "cedar-lean").resolve()
if not is_git_repo(root_dir):
raise RuntimeError(f"{root_dir} is not a Git repo.")
repo = LeanGitRepo.from_path(root_dir)
Expand Down