diff --git a/src/lean_dojo/data_extraction/ExtractData.lean b/src/lean_dojo/data_extraction/ExtractData.lean index 4be15d0..d8beff3 100644 --- a/src/lean_dojo/data_extraction/ExtractData.lean +++ b/src/lean_dojo/data_extraction/ExtractData.lean @@ -373,10 +373,7 @@ private def visitInfo (ctx : ContextInfo) (i : Info) (parent : InfoTree) (env : private partial def traverseTree (ctx: ContextInfo) (tree : InfoTree) (parent : InfoTree) (env : Environment) : TraceM Unit := do match tree with - | .context ctx' t => - match ctx'.mergeIntoOuter? ctx with - | some ctx' => traverseTree ctx' t tree env - | none => panic! "fail to synthesis contextInfo when traversing infoTree" + | .context ctx' t => traverseTree ctx' t tree env | .node i children => visitInfo ctx i parent env for x in children do @@ -386,10 +383,7 @@ private partial def traverseTree (ctx: ContextInfo) (tree : InfoTree) private def traverseTopLevelTree (tree : InfoTree) (env : Environment) : TraceM Unit := do match tree with - | .context ctx t => - match ctx.mergeIntoOuter? none with - | some ctx => traverseTree ctx t tree env - | none => panic! "fail to synthesis contextInfo for top-level infoTree" + | .context ctx t => traverseTree ctx t tree env | _ => pure () diff --git a/src/lean_dojo/data_extraction/build_lean4_repo.py b/src/lean_dojo/data_extraction/build_lean4_repo.py index a129a97..0c02032 100644 --- a/src/lean_dojo/data_extraction/build_lean4_repo.py +++ b/src/lean_dojo/data_extraction/build_lean4_repo.py @@ -159,7 +159,7 @@ def main() -> None: num_procs = int(os.environ["NUM_PROCS"]) repo_name = args.repo_name - os.chdir(repo_name) + os.chdir(repo_name+"/cedar-lean") # Build the repo using lake. logger.info(f"Building {repo_name}") @@ -186,6 +186,7 @@ def main() -> None: if not args.no_deps: dirs_to_monitor.append(packages_path) logger.info(f"Tracing {repo_name}") + run_cmd("mv ../ExtractData.lean ExtractData.lean", capture_output=True) with launch_progressbar(dirs_to_monitor): cmd = f"lake env lean --threads {num_procs} --run ExtractData.lean" if args.no_deps: diff --git a/src/lean_dojo/data_extraction/lean.py b/src/lean_dojo/data_extraction/lean.py index 1679d59..3c088e8 100644 --- a/src/lean_dojo/data_extraction/lean.py +++ b/src/lean_dojo/data_extraction/lean.py @@ -410,6 +410,10 @@ class LeanGitRepo: You can also use tags such as ``v3.5.0``. They will be converted to commit hashes. """ + inner_lake_path: str = None + """If the lake project isn't at the repository root, the path to the lake project root. None otherwise. + """ + repo: Repository = field(init=False, repr=False) """A :class:`github.Repository` object. """ @@ -425,6 +429,9 @@ def __post_init__(self) -> None: raise ValueError(f"{self.url} is not a valid URL") object.__setattr__(self, "url", normalize_url(self.url)) object.__setattr__(self, "repo", url_to_repo(self.url)) + if self.inner_lake_path is not None and (self.inner_lake_path.startswith("/") or self.inner_lake_path.endswith("/")): + raise ValueError(f"{self.inner_lake_path} should not start or end with '/'") + object.__setattr__(self, "inner_lake_path", self.inner_lake_path) # Convert tags or branches to commit hashes if not (len(self.commit) == 40 and _COMMIT_REGEX.fullmatch(self.commit)): @@ -458,6 +465,13 @@ def from_path(cls, path: Path) -> "LeanGitRepo": url, commit = get_repo_info(path) return cls(url, commit) + @property + def path_to_lake_proj(self) -> str: + if self.inner_lake_path is None: + return self.repo.name + else: + return self.repo.name+"/"+self.inner_lake_path + @property def name(self) -> str: return self.repo.name @@ -576,11 +590,15 @@ def get_license(self) -> Optional[str]: def _get_config_url(self, filename: str) -> str: assert "github.com" in self.url, f"Unsupported URL: {self.url}" url = self.url.replace("github.com", "raw.githubusercontent.com") - return f"{url}/{self.commit}/{filename}" + if self.inner_lake_path is None: + return f"{url}/{self.commit}/{filename}" + else: + return f"{url}/{self.commit}/{self.inner_lake_path}/{filename}" def get_config(self, filename: str, num_retries: int = 2) -> Dict[str, Any]: """Return the repo's files.""" config_url = self._get_config_url(filename) + print(f"Get config with URL: {config_url}") content = read_url(config_url, num_retries) if filename.endswith(".toml"): return toml.loads(content) diff --git a/src/lean_dojo/data_extraction/trace.py b/src/lean_dojo/data_extraction/trace.py index d406c9e..a1da1ce 100644 --- a/src/lean_dojo/data_extraction/trace.py +++ b/src/lean_dojo/data_extraction/trace.py @@ -33,6 +33,7 @@ def _trace(repo: LeanGitRepo, build_deps: bool) -> None: logger.debug(f"Tracing {repo}") container = get_container() mts = { + # TODO: should we mount a different directory here? Path.cwd() / repo.name: f"/workspace/{repo.name}", LEAN4_BUILD_SCRIPT_PATH: f"/workspace/{LEAN4_BUILD_SCRIPT_PATH.name}", LEAN4_DATA_EXTRACTOR_PATH: f"/workspace/{repo.name}/{LEAN4_DATA_EXTRACTOR_PATH.name}", @@ -80,7 +81,7 @@ def get_traced_repo_path(repo: LeanGitRepo, build_deps: bool = True) -> Path: with working_directory() as tmp_dir: logger.debug(f"Working in the temporary directory {tmp_dir}") _trace(repo, build_deps) - traced_repo = TracedRepo.from_traced_files(tmp_dir / repo.name, build_deps) + traced_repo = TracedRepo.from_traced_files(tmp_dir / repo.path_to_lake_proj, build_deps) traced_repo.save_to_disk() path = cache.store(tmp_dir / repo.name) else: diff --git a/src/lean_dojo/data_extraction/traced_data.py b/src/lean_dojo/data_extraction/traced_data.py index 63eb3cd..7657c0c 100644 --- a/src/lean_dojo/data_extraction/traced_data.py +++ b/src/lean_dojo/data_extraction/traced_data.py @@ -1151,7 +1151,7 @@ def load_from_disk( cls, root_dir: Union[str, Path], build_deps: bool = True ) -> "TracedRepo": """Load a traced repo from :file:`*.trace.xml` files.""" - root_dir = Path(root_dir).resolve() + root_dir = Path(root_dir / "cedar-lean").resolve() if not is_git_repo(root_dir): raise RuntimeError(f"{root_dir} is not a Git repo.") repo = LeanGitRepo.from_path(root_dir)