Skip to content

Commit

Permalink
fix: compute checksums through GitHub provider
Browse files Browse the repository at this point in the history
  • Loading branch information
davhofer committed Aug 20, 2024
1 parent eb7618c commit 41f2863
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 35 deletions.
23 changes: 15 additions & 8 deletions src/py2spack/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ class PyProject:
dependency_errors: list[pyproject_parsing.ConfigurationError] = dataclasses.field(
default_factory=list
)
provider: package_providers.PyProjectProvider | None = None

@classmethod
def from_toml(
Expand Down Expand Up @@ -491,7 +492,7 @@ def build_from_pyprojects(
self,
name: str,
pyprojects: list[PyProject],
provider: package_providers.PyProjectProvider,
pypi_provider: package_providers.PyPIProvider,
use_test_prefix: bool = False,
) -> None:
"""Build the spack package from pyprojects."""
Expand All @@ -505,18 +506,21 @@ def build_from_pyprojects(
# s.t. newest version is on top in package.py
for p in pyprojects:
spack_version = conversion_tools.packaging_to_spack_version(p.version)
hashdict = provider.get_sdist_hash(name, p.version)
if isinstance(hashdict, dict) and hashdict:
hash_key, hash_value = next(iter(hashdict.items()))

if hash_key in SPACK_CHECKSUM_HASHES:
self._versions_with_checksum.append((spack_version, hash_key, hash_value))
continue
if p.provider is not None:
hashdict = p.provider.get_sdist_hash(name, p.version)

if isinstance(hashdict, dict) and hashdict:
hash_key, hash_value = next(iter(hashdict.items()))

if hash_key in SPACK_CHECKSUM_HASHES:
self._versions_with_checksum.append((spack_version, hash_key, hash_value))
continue

self._versions_missing_checksum.append(spack_version)

# convert all dependencies (for the selected versions)
self._dependencies_from_pyprojects(pyprojects, provider)
self._dependencies_from_pyprojects(pyprojects, pypi_provider)

def print_pkg(self, outfile: TextIO = sys.stdout) -> None: # noqa: C901, PLR0912, PLR0915
"""Format and write the package to 'outfile'.
Expand Down Expand Up @@ -726,6 +730,9 @@ def _convert_single(
)
continue

# add provider to pyproject for convenience
pyproject.provider = provider

pyprojects.append(pyproject)

if not pyprojects:
Expand Down
84 changes: 57 additions & 27 deletions src/py2spack/package_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import abc
import dataclasses
import functools
import hashlib
import io
import re
from collections.abc import Hashable
from typing import Protocol
Expand All @@ -15,7 +17,7 @@
from py2spack import utils


KNOWN_ARCHIVE_FORMATS = [
TARBALL_ARCHIVE_FORMATS = [
".tar",
".tar.gz",
".tar.bz2",
Expand Down Expand Up @@ -107,6 +109,33 @@ def _get(self, repo_specifier: str) -> dict | PyProjectProviderQueryError:
data: dict = r.json()
return data

@functools.cache  # noqa: B019
def _get_pyproject_and_checksum(
    self, name: str, version: vn.Version
) -> tuple[dict | PyProjectProviderQueryError, str] | PyProjectProviderQueryError:
    """Download the archive for ``name`` at ``version`` and derive its pyproject and checksum.

    On success the result is a ``(pyproject, checksum)`` tuple: ``pyproject`` is
    whatever ``_extract_pyproject`` produced (which may itself be a
    ``PyProjectProviderQueryError``) and ``checksum`` is the hex SHA-256 digest
    of the downloaded archive bytes.

    A bare ``PyProjectProviderQueryError`` is returned instead when the version
    list cannot be obtained, when no non-empty download url matches ``version``,
    or when the download yields nothing.

    Results are memoised by ``functools.cache`` on ``(self, name, version)``.
    """
    candidates = self._get_versions_with_urls(name)
    if isinstance(candidates, PyProjectProviderQueryError):
        return candidates

    # First non-empty url whose version matches; empty urls are never selected,
    # so ``None`` unambiguously means "nothing found".
    download_url = next(
        (link for candidate, link in candidates if candidate == version and link),
        None,
    )
    if download_url is None:
        return PyProjectProviderQueryError(f"No download url found for {name} v{version}")

    archive = utils.download_bytes(download_url)
    if archive is None:
        return PyProjectProviderQueryError(
            f"Unable to download package {name} from {download_url}"
        )

    # The archive is always treated as a gzipped tarball here.
    contents: dict | PyProjectProviderQueryError = _extract_pyproject(archive, ".tar.gz")

    # getvalue() returns the whole buffer regardless of the current read position.
    digest: str = hashlib.sha256(archive.getvalue()).hexdigest()
    return (contents, digest)

def parse_repo_name(self, name: str) -> str | None:
"""Parse the github repository name.
Expand Down Expand Up @@ -186,6 +215,7 @@ def get_versions(self, name: str) -> list[vn.Version] | PyProjectProviderQueryEr

return result

@functools.cache # noqa: B019
def _get_versions_with_urls(
self, name: str
) -> list[tuple[vn.Version, str]] | PyProjectProviderQueryError:
Expand All @@ -211,25 +241,19 @@ def _get_versions_with_urls(

def get_pyproject(self, name: str, version: vn.Version) -> dict | PyProjectProviderQueryError:
"""Get the contents of the pyproject.toml file for the specified version."""
versions_with_urls = self._get_versions_with_urls(name)
if isinstance(versions_with_urls, PyProjectProviderQueryError):
return versions_with_urls

result = [url for v, url in versions_with_urls if v == version and url]
if not result:
return PyProjectProviderQueryError(f"No download url found for {name} v{version}")

url = result[0]

return _try_load_pyproject(url, name, ".tar.gz")
result = self._get_pyproject_and_checksum(name, version)
if isinstance(result, PyProjectProviderQueryError):
return result
return result[0]

def get_sdist_hash(
self, name: str, version: vn.Version
) -> dict[str, str] | PyProjectProviderQueryError:
"""Get the sdist hash (sha256 if available) for the specified version."""
return PyProjectProviderQueryError(
f"GitHub doesn't provide sdist checksums ({name} v{version})"
)
result = self._get_pyproject_and_checksum(name, version)
if isinstance(result, PyProjectProviderQueryError):
return result
return {"sha256": result[1]}


@dataclasses.dataclass(frozen=True)
Expand Down Expand Up @@ -447,7 +471,7 @@ def _normalize_package_name(name: str) -> str:


def _parse_archive_extension(filename: str) -> str | PyProjectProviderQueryError:
extension_list = [ext for ext in KNOWN_ARCHIVE_FORMATS if filename.endswith(ext)]
extension_list = [ext for ext in TARBALL_ARCHIVE_FORMATS if filename.endswith(ext)]

if not extension_list:
# we return an API error here because the filenames are obtained through
Expand All @@ -459,25 +483,19 @@ def _parse_archive_extension(filename: str) -> str | PyProjectProviderQueryError


def _is_archive_format_known(filename: str) -> bool:
return any(filename.endswith(ext) for ext in KNOWN_ARCHIVE_FORMATS)
return any(filename.endswith(ext) for ext in TARBALL_ARCHIVE_FORMATS)


# TODO @davhofer: handle zip archives
# TODO @davhofer: should this function be placed in the PyProjectProvider Protocol? or in utils?
# generic function for loading any file/filetype?
def _try_load_pyproject(
url: str, name: str, archive_ext: str
def _extract_pyproject(
file_bytes: io.BytesIO, archive_ext: str
) -> dict | PyProjectProviderQueryError:
"""Load sdist from url and extract pyproject.toml contents."""
sdist_file_obj = utils.download_bytes(url)

result: dict | None | PyProjectProviderQueryError

if sdist_file_obj is None:
result = PyProjectProviderQueryError(f"Unable to download package {name} from {url}")

elif archive_ext in KNOWN_ARCHIVE_FORMATS:
result = utils.extract_toml_from_tar(sdist_file_obj, "pyproject.toml")
if archive_ext in TARBALL_ARCHIVE_FORMATS:
result = utils.extract_toml_from_tar(file_bytes, "pyproject.toml")
if result is None:
result = PyProjectProviderQueryError("Unable to extract pyproject.toml from archive")

Expand All @@ -487,3 +505,15 @@ def _try_load_pyproject(
)

return result


def _try_load_pyproject(
    url: str, name: str, archive_ext: str
) -> dict | PyProjectProviderQueryError:
    """Fetch the sdist at ``url`` and hand it to ``_extract_pyproject``.

    Returns the result of ``_extract_pyproject`` for the downloaded bytes and
    ``archive_ext``, or a ``PyProjectProviderQueryError`` naming ``name`` and
    ``url`` when the download returns ``None``.
    """
    archive = utils.download_bytes(url)

    if archive is not None:
        return _extract_pyproject(archive, archive_ext)

    return PyProjectProviderQueryError(f"Unable to download package {name} from {url}")

0 comments on commit 41f2863

Please sign in to comment.