Skip to content

Commit

Permalink
feat: models api
Browse files Browse the repository at this point in the history
  • Loading branch information
wpbonelli committed Mar 5, 2025
1 parent e64ce4c commit 99cbd8a
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 32 deletions.
55 changes: 55 additions & 0 deletions autotest/test_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from pathlib import Path

import pytest
import tomli

import modflow_devtools.models as models

MODELS_TOML_PATH = Path(models.DATA_PATH) / models.MODELMAP_NAME


@pytest.fixture
def models_toml():
with MODELS_TOML_PATH.open("rb") as f:
return tomli.load(f)


@pytest.fixture
def temp_cache_dir(tmpdir, monkeypatch):
temp_dir = tmpdir.mkdir("pooch_cache")
monkeypatch.setenv("MF_DATA_DIR", str(temp_dir))
models.FETCHER.path = temp_dir # Update the fetcher path
return temp_dir


def test_registry_loaded():
assert models.FETCHER.registry is not None, "Registry was not loaded"
assert len(models.FETCHER.registry) > 0, "Registry is empty"


def test_generated_functions_exist(models_toml):
for model_name in models_toml.keys():
assert hasattr(models, model_name), (
f"Function {model_name} not found in models module"
)


def test_generated_functions_return_files(models_toml, temp_cache_dir):
for model_name, files in models_toml.items():
model_function = getattr(models, model_name)
fetched_files = model_function()
cached_files = temp_cache_dir.listdir()
assert isinstance(fetched_files, list), (
f"Function {model_name} did not return a list"
)
assert len(fetched_files) == len(files), (
f"Function {model_name} did not return the correct number of files"
)
for fetched_file in fetched_files:
assert Path(fetched_file).exists(), (
f"Fetched file {fetched_file} does not exist"
)
assert Path(temp_cache_dir) / Path(fetched_file).name in cached_files, (
f"Fetched file {fetched_file} is not in the temp cache directory"
)
break # just the first one so we dont ddos github
10 changes: 1 addition & 9 deletions modflow_devtools/make_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,9 @@
import tomli_w as tomli

from modflow_devtools.misc import get_model_paths
from modflow_devtools.models import BASE_URL, DATA_PATH

REPO_OWNER = "MODFLOW-ORG"
REPO_NAME = "modflow-devtools"
REPO_REF = "develop"
PROJ_ROOT = Path(__file__).parents[1]
DATA_RELPATH = "data"
DATA_PATH = PROJ_ROOT / REPO_NAME / DATA_RELPATH
REGISTRY_PATH = DATA_PATH / "registry.txt"
MODELS_PATH = DATA_PATH / "models.toml"
BASE_URL = f"https://github.com/{REPO_OWNER}/{REPO_NAME}/raw/{REPO_REF}/{DATA_RELPATH}/"


def _sha256(path: Path) -> str:
Expand Down Expand Up @@ -86,5 +79,4 @@ def write_registry(
args = parser.parse_args()
path = Path(args.path) if args.path else DATA_PATH
base_url = args.base_url if args.base_url else BASE_URL

write_registry(path, REGISTRY_PATH, base_url, args.append)
57 changes: 34 additions & 23 deletions modflow_devtools/models.py
Original file line number Diff line number Diff line change
@@ -1,49 +1,60 @@
import importlib.resources as pkg_resources
from io import IOBase
from pathlib import Path

import pooch
import tomli

import modflow_devtools

REPO_OWNER = "MODFLOW-ORG"
REPO_NAME = "modflow-devtools"
REPO_REF = "develop"
PROJ_ROOT = Path(__file__).parents[1]
PROJ_OWNER = "MODFLOW-ORG"
PROJ_NAME = "modflow-devtools"
MODULE_NAME = PROJ_NAME.replace("-", "_")
PROJ_REF = "develop"
DATA_RELPATH = "data"
DATA_PATH = PROJ_ROOT / REPO_NAME / DATA_RELPATH
REGISTRY_PATH = DATA_PATH / "registry.txt"
MODELS_PATH = DATA_PATH / "models.toml"
BASE_URL = f"https://github.com/{REPO_OWNER}/{REPO_NAME}/raw/{REPO_REF}/{DATA_RELPATH}/"
DATA_PATH = PROJ_ROOT / MODULE_NAME / DATA_RELPATH
DATA_ANCHOR = f"{MODULE_NAME}.{DATA_RELPATH}"
REGISTRY_NAME = "registry.txt"
MODELMAP_NAME = "models.toml"
BASE_URL = f"https://github.com/{PROJ_OWNER}/{PROJ_NAME}/raw/{PROJ_REF}/{DATA_RELPATH}/"
VERSION = modflow_devtools.__version__.rpartition(".dev")[0]
CACHE_VAR_NAME = "MF_DATA_DIR"
FETCHER = pooch.create(
path=pooch.os_cache(REPO_NAME),
path=pooch.os_cache(PROJ_NAME),
base_url=BASE_URL,
version=VERSION,
registry=None,
env=CACHE_VAR_NAME,
)

if not REGISTRY_PATH.exists():
raise FileNotFoundError(f"Registry file {REGISTRY_PATH} not found.")
try:
with pkg_resources.open_text(DATA_ANCHOR, REGISTRY_NAME) as f:
FETCHER.load_registry(f)
except: # noqa: E722
print(f"Could not load registry from {DATA_PATH}/{REGISTRY_NAME}.")

if not MODELS_PATH.exists():
raise FileNotFoundError(f"Models file {MODELS_PATH} not found.")

FETCHER.load_registry(REGISTRY_PATH)


def _generate_function(model_name: str, files: list) -> callable:
def model_function() -> list:
def _generate_function(model_name, files) -> callable:
def model_function():
return [FETCHER.fetch(file) for file in files]

model_function.__name__ = model_name
return model_function


def _make_functions(models_path: Path, registry_path: Path):
with models_path.open("rb") as f:
models = tomli.load(f)
for model_name, files in models.items():
globals()[model_name] = _generate_function(model_name, files)
def _make_functions(models):
if isinstance(models, IOBase):
models = tomli.load(models)
else:
with Path(models).open("rb") as f:
models = tomli.load(f)
for name, files in models.items():
globals()[name] = _generate_function(name, files)


_make_functions(MODELS_PATH, REGISTRY_PATH)
try:
with pkg_resources.open_binary(DATA_ANCHOR, MODELMAP_NAME) as f:
_make_functions(f)
except: # noqa: E722
print(f"Could not load model mapping from {DATA_PATH}/{MODELMAP_NAME}.")
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,11 @@ only-include = ["modflow_devtools"]
[tool.hatch.build.targets.wheel]
packages = ["modflow_devtools"]

[tool.hatch.build]
include = [
"modflow_devtools/data/*"
]

[tool.hatch.version]
path = "modflow_devtools/__init__.py"

Expand Down

0 comments on commit 99cbd8a

Please sign in to comment.