Fix backend typing, simplify dask (#32)
* Fix generics

* Remove npartitions
dobraczka authored Mar 5, 2024
1 parent 6c5ad86 commit f30a957
Showing 8 changed files with 266 additions and 258 deletions.
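
The thread running through all eight files: the runtime `backend` switch plus `npartitions` plumbing is replaced by static typing. Dataset classes become generic over the backend's frame type, and `@overload` signatures annotated on `self` let a type checker infer the concrete type parameter from the `backend` literal. A minimal, self-contained sketch of the pattern with a hypothetical `Dataset` class (the real base classes live in sylloge/base.py, whose diff is not rendered below):

from typing import Generic, Literal, TypeVar, overload

import dask.dataframe as dd
import pandas as pd

# Constrained TypeVar: resolves to exactly one of the two frame types.
DataFrameType = TypeVar("DataFrameType", pd.DataFrame, dd.DataFrame)


class Dataset(Generic[DataFrameType]):
    @overload
    def __init__(
        self: "Dataset[pd.DataFrame]", backend: Literal["pandas"] = "pandas"
    ) -> None:
        ...

    @overload
    def __init__(
        self: "Dataset[dd.DataFrame]", backend: Literal["dask"] = "dask"
    ) -> None:
        ...

    def __init__(self, backend: Literal["pandas", "dask"] = "pandas") -> None:
        # Runtime behaviour is unchanged; only the static view differs.
        self.backend = backend


ds_pd = Dataset()                # a checker infers Dataset[pd.DataFrame]
ds_dd = Dataset(backend="dask")  # a checker infers Dataset[dd.DataFrame]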
339 changes: 181 additions & 158 deletions sylloge/base.py

Large diffs are not rendered by default.

34 changes: 26 additions & 8 deletions sylloge/med_bbk_loader.py
@@ -1,12 +1,15 @@
 import pathlib
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Literal, Optional, overload

-from .base import BACKEND_LITERAL, BASE_DATASET_MODULE, ZipEADataset
+import dask.dataframe as dd
+import pandas as pd
+
+from .base import BACKEND_LITERAL, BASE_DATASET_MODULE, DataFrameType, ZipEADataset

 MED_BBK_MODULE = BASE_DATASET_MODULE.module("med_bbk")


-class MED_BBK(ZipEADataset):
+class MED_BBK(ZipEADataset[DataFrameType]):
     """Class containing the MED-BBK dataset.

     Published in `Zhang, Z. et. al. (2020) An Industry Evaluation of Embedding-based Entity Alignment <A Benchmarking Study of Embedding-based Entity Alignment for Knowledge Graphs>`_,
@@ -21,17 +24,33 @@ class MED_BBK(ZipEADataset):
     #: The hex digest for the zip file
     _SHA512: str = "da1ee2b025070fd6890fb7e77b07214af3767b5ae85bcdc1bb36958b4b8dd935bc636e3466b94169158940a960541f96284e3217d32976bfeefa56e29d4a9e0d"

+    @overload
+    def __init__(
+        self: "MED_BBK[pd.DataFrame]",
+        backend: Literal["pandas"] = "pandas",
+        use_cache: bool = True,
+        cache_path: Optional[pathlib.Path] = None,
+    ):
+        ...
+
+    @overload
+    def __init__(
+        self: "MED_BBK[dd.DataFrame]",
+        backend: Literal["dask"] = "dask",
+        use_cache: bool = True,
+        cache_path: Optional[pathlib.Path] = None,
+    ):
+        ...
+
     def __init__(
         self,
         backend: BACKEND_LITERAL = "pandas",
-        npartitions: int = 1,
         use_cache: bool = True,
         cache_path: Optional[pathlib.Path] = None,
     ):
         """Initialize an MED-BBK dataset.

         :param backend: Whether to use "pandas" or "dask"
-        :param npartitions: how many partitions to use for each frame, when using dask
         :param use_cache: whether to use cache or not
         :param cache_path: Path where cache will be stored/loaded
         """
@@ -45,13 +64,12 @@ def __init__(
         actual_cache_path = self.create_cache_path(
             MED_BBK_MODULE, inner_path, cache_path
         )
-        super().__init__(
+        super().__init__(  # type: ignore[misc]
             cache_path=actual_cache_path,
             use_cache=use_cache,
             zip_path=zip_path,
             inner_path=pathlib.PurePosixPath(inner_path),
-            backend=backend,
-            npartitions=npartitions,
+            backend=backend,  # type: ignore[arg-type]
             dataset_names=("MED", "BBK"),
         )

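With the overloads in place, the inferred type of a MED_BBK instance follows the `backend` argument, and the old `npartitions` knob is gone. A usage sketch (assuming the package's usual top-level re-exports; constructing the object downloads the data on first use):

import dask.dataframe as dd
import pandas as pd

from sylloge import MED_BBK

ds_pd = MED_BBK()                # checked as MED_BBK[pd.DataFrame]
ds_dd = MED_BBK(backend="dask")  # checked as MED_BBK[dd.DataFrame]

assert isinstance(ds_pd.rel_triples_left, pd.DataFrame)
assert isinstance(ds_dd.rel_triples_left, dd.DataFrame)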
9 changes: 3 additions & 6 deletions sylloge/moviegraph_benchmark_loader.py
@@ -1,6 +1,7 @@
 import pathlib
 from typing import Literal, Optional, Tuple

+import pandas as pd
 from moviegraphbenchmark import load_data

 from .base import (
@@ -21,7 +22,7 @@
 GRAPH_PAIRS: Tuple[GraphPair, ...] = (IMDB_TMDB, IMDB_TVDB, TMDB_TVDB)


-class MovieGraphBenchmark(CacheableEADataset):
+class MovieGraphBenchmark(CacheableEADataset[pd.DataFrame]):
     """Class containing the movie graph benchmark.

     Published in `Obraczka, D. et. al. (2021) Embedding-Assisted Entity Resolution for Knowledge Graphs <http://ceur-ws.org/Vol-2873/paper8.pdf>`_,
@@ -31,16 +32,13 @@ class MovieGraphBenchmark(CacheableEADataset):
     def __init__(
         self,
         graph_pair: GraphPair = "imdb-tmdb",
-        backend: BACKEND_LITERAL = "pandas",
-        npartitions: int = 1,
         use_cache: bool = True,
         cache_path: Optional[pathlib.Path] = None,
     ):
         """Initialize a MovieGraphBenchmark dataset.

         :param graph_pair: which graph pair to use of "imdb-tdmb","imdb-tvdb" or "tmdb-tvdb"
-        :param backend: Whether to use "pandas" or "dask"
-        :param npartitions: how many partitions to use for each frame, when using dask
         :param use_cache: whether to use cache or not
         :param cache_path: Path where cache will be stored/loaded
         :raises ValueError: if unknown graph pair
@@ -58,8 +56,7 @@ def __init__(
         super().__init__(
             cache_path=actual_cache_path,
             use_cache=use_cache,
-            backend=backend,
-            npartitions=npartitions,
+            backend="pandas",
             dataset_names=(left_name, right_name),
         )

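MovieGraphBenchmark takes the simpler route: rather than overloads, it pins the parent generic to `pd.DataFrame` and drops the `backend` parameter entirely, so this loader is pandas-only now. Sketch under the same re-export assumption:

from sylloge import MovieGraphBenchmark

ds = MovieGraphBenchmark(graph_pair="imdb-tmdb")
# Always pandas-backed; a checker sees pd.DataFrame-typed triples.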
21 changes: 10 additions & 11 deletions sylloge/oaei_loader.py
@@ -11,7 +10,6 @@
 from .base import (
     BASE_DATASET_MODULE,
     CacheableEADataset,
-    DataFrameType,
     DatasetStatistics,
 )
 from .dask import read_dask_bag_from_archive_text
@@ -40,8 +39,12 @@
 REFLINE_TYPE_LITERAL = Literal["dismiss", "entity", "property", "class", "unknown"]
 REL_ATTR_LITERAL = Literal["rel", "attr"]

-reflinetype = pd.CategoricalDtype(categories=typing.get_args(REFLINE_TYPE_LITERAL))
-relattrlinetype = pd.CategoricalDtype(categories=typing.get_args(REL_ATTR_LITERAL))
+reflinetype = pd.CategoricalDtype(
+    categories=list(typing.get_args(REFLINE_TYPE_LITERAL))
+)
+relattrlinetype = pd.CategoricalDtype(
+    categories=list(typing.get_args(REL_ATTR_LITERAL))
+)


 class URL_SHA512_HASH(NamedTuple):
@@ -98,7 +101,7 @@ def fault_tolerant_parse_nt(
     return subj, pred, obj, triple_type


-class OAEI(CacheableEADataset[DataFrameType]):
+class OAEI(CacheableEADataset[dd.DataFrame]):
     """The OAEI (Ontology Alignment Evaluation Initiative) Knowledge Graph Track tasks contain graphs created from fandom wikis.

     Five integration tasks are available:
@@ -111,8 +114,8 @@ class OAEI(CacheableEADataset[DataFrameType]):
     More information can be found at the `website <http://oaei.ontologymatching.org/2019/knowledgegraph/index.html>`_.
     """

-    class_links: pd.DataFrame
-    property_links: pd.DataFrame
+    class_links: dd.DataFrame
+    property_links: dd.DataFrame

     _CLASS_LINKS_PATH: str = "class_links_parquet"
     _PROPERTY_LINKS_PATH: str = "property_links_parquet"
@@ -246,16 +249,13 @@ class OAEI(CacheableEADataset[DataFrameType]):
     def __init__(
         self,
         task: OAEI_TASK_NAME = "starwars-swg",
-        backend: BACKEND_LITERAL = "dask",
-        npartitions: int = 1,
         use_cache: bool = True,
         cache_path: Optional[pathlib.Path] = None,
     ):
         """Initialize a OAEI Knowledge Graph Track task.

         :param task: Name of the task. Has to be one of {starwars-swg,starwars-swtor,marvelcinematicuniverse-marvel,memoryalpha-memorybeta, memoryalpha-stexpanded}
-        :param backend: Whether to use "pandas" or "dask"
-        :param npartitions: how many partitions to use for each frame, when using dask
         :param use_cache: whether to use cache or not
         :param cache_path: Path where cache will be stored/loaded
         :raises ValueError: if unknown task value is provided
@@ -274,8 +274,7 @@ def __init__(
             use_cache=use_cache,
             cache_path=actual_cache_path,
             dataset_names=(left_name, right_name),
-            backend=backend,
-            npartitions=npartitions,
+            backend="dask",
         )

     def initial_read(self, backend: BACKEND_LITERAL):
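OAEI is pinned the opposite way, to `CacheableEADataset[dd.DataFrame]`: its `backend`/`npartitions` parameters disappear and `class_links`/`property_links` are re-annotated as dask frames. Sketch under the same assumptions:

import dask.dataframe as dd

from sylloge import OAEI

ds = OAEI(task="starwars-swg")  # no backend argument anymore
assert isinstance(ds.rel_triples_left, dd.DataFrame)
# Results stay lazy dask frames until you call .compute() on them.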
45 changes: 37 additions & 8 deletions sylloge/open_ea_loader.py
@@ -1,9 +1,17 @@
 # largely adapted from pykeen.datasets.ea.openea
 import pathlib
 from types import MappingProxyType
-from typing import Literal, Optional, Tuple
+from typing import Literal, Optional, Tuple, overload

-from .base import BACKEND_LITERAL, BASE_DATASET_MODULE, ZipEADatasetWithPreSplitFolds
+import dask.dataframe as dd
+import pandas as pd
+
+from .base import (
+    BACKEND_LITERAL,
+    BASE_DATASET_MODULE,
+    DataFrameType,
+    ZipEADatasetWithPreSplitFolds,
+)

 OPEN_EA_MODULE = BASE_DATASET_MODULE.module("open_ea")

@@ -28,7 +36,7 @@
 GRAPH_VERSIONS = (V1, V2)


-class OpenEA(ZipEADatasetWithPreSplitFolds):
+class OpenEA(ZipEADatasetWithPreSplitFolds[DataFrameType]):
     """Class containing the OpenEA dataset family.

     Published in `Sun, Z. et. al. (2020) A Benchmarking Study of Embedding-based Entity Alignment for Knowledge Graphs <http://www.vldb.org/pvldb/vol13/p2326-sun.pdf>`_,
@@ -53,13 +61,36 @@ class OpenEA(ZipEADatasetWithPreSplitFolds):
         }
     )

+    @overload
+    def __init__(
+        self: "OpenEA[pd.DataFrame]",
+        graph_pair: GraphPair = "D_W",
+        size: GraphSize = "15K",
+        version: GraphVersion = "V1",
+        backend: Literal["pandas"] = "pandas",
+        use_cache: bool = True,
+        cache_path: Optional[pathlib.Path] = None,
+    ):
+        ...
+
+    @overload
+    def __init__(
+        self: "OpenEA[dd.DataFrame]",
+        graph_pair: GraphPair = "D_W",
+        size: GraphSize = "15K",
+        version: GraphVersion = "V1",
+        backend: Literal["dask"] = "dask",
+        use_cache: bool = True,
+        cache_path: Optional[pathlib.Path] = None,
+    ):
+        ...
+
     def __init__(
         self,
         graph_pair: GraphPair = "D_W",
         size: GraphSize = "15K",
         version: GraphVersion = "V1",
         backend: BACKEND_LITERAL = "pandas",
-        npartitions: int = 1,
         use_cache: bool = True,
         cache_path: Optional[pathlib.Path] = None,
     ):
@@ -69,7 +100,6 @@ def __init__(
         :param size: what size ("15K" or "100K")
         :param version: which version to use ("V1" or "V2")
         :param backend: Whether to use "pandas" or "dask"
-        :param npartitions: how many partitions to use for each frame, when using dask
         :param use_cache: whether to use cache or not
         :param cache_path: Path where cache will be stored/loaded
         :raises ValueError: if unknown graph_pair,size or version values are provided
@@ -98,13 +128,12 @@ def __init__(
         actual_cache_path = self.create_cache_path(
             OPEN_EA_MODULE, inner_cache_path, cache_path
         )
-        super().__init__(
+        super().__init__(  # type: ignore[misc]
             cache_path=actual_cache_path,
             use_cache=use_cache,
             zip_path=zip_path,
             inner_path=inner_path,
-            backend=backend,
-            npartitions=npartitions,
+            backend=backend,  # type: ignore[arg-type]
             dataset_names=OpenEA._GRAPH_PAIR_TO_DS_NAMES[graph_pair],
         )
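OpenEA keeps both backends through the same overload pair as MED_BBK, with the extra dataset selectors threaded through each signature. The `# type: ignore[misc]` and `# type: ignore[arg-type]` comments appear to be needed because the implementation forwards a `BACKEND_LITERAL`-typed `backend` to the parent's overloaded `__init__`, which mypy cannot narrow to a single overload. Usage sketch:

import dask.dataframe as dd

from sylloge import OpenEA

ds = OpenEA(graph_pair="D_W", size="15K", version="V1", backend="dask")
# checked as OpenEA[dd.DataFrame]; ent_links and folds follow the same type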
7 changes: 6 additions & 1 deletion sylloge/typing.py
@@ -1,4 +1,7 @@
-from typing import Literal, Tuple
+from typing import Literal, Tuple, TypeVar
+
+import dask.dataframe as dd
+import pandas as pd

 # borrowed from pykeen.typing
 Target = Literal["head", "relation", "tail"]
@@ -10,4 +13,6 @@
 EA_SIDE_RIGHT: EASide = "right"
 EA_SIDES: Tuple[EASide, EASide] = (EA_SIDE_LEFT, EA_SIDE_RIGHT)
 COLUMNS = [LABEL_HEAD, LABEL_RELATION, LABEL_TAIL]
+
 BACKEND_LITERAL = Literal["pandas", "dask"]
+DataFrameType = TypeVar("DataFrameType", pd.DataFrame, dd.DataFrame)
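
`DataFrameType` is a value-constrained TypeVar rather than a bound one: it resolves to exactly `pd.DataFrame` or exactly `dd.DataFrame`, never a union or a subclass, which is what lets `ZipEADataset[DataFrameType]` and friends track the backend precisely. A small illustration of how such a TypeVar propagates through a helper (hypothetical function, not part of this diff):

from typing import TypeVar

import dask.dataframe as dd
import pandas as pd

DataFrameType = TypeVar("DataFrameType", pd.DataFrame, dd.DataFrame)


def drop_self_links(df: DataFrameType) -> DataFrameType:
    # Boolean-mask filtering works on pandas and dask frames alike, and a
    # checker resolves the return type to whichever backend the caller passed.
    return df[df["head"] != df["tail"]]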
16 changes: 3 additions & 13 deletions tests/test_oaei.py
@@ -1,5 +1,4 @@
 import dask.dataframe as dd
-import pandas as pd
 import pytest
 from mocks import ResourceMocker

@@ -16,8 +15,7 @@
         "memoryalpha-stexpanded",
     ],
 )
-@pytest.mark.parametrize("backend", ["pandas", "dask"])
-def test_oaei_mock(task, backend, mocker, tmp_path):
+def test_oaei_mock(task, mocker, tmp_path):
     rm = ResourceMocker()
     mocker.patch(
         "sylloge.oaei_loader.read_dask_bag_from_archive_text",
@@ -30,7 +28,7 @@ def test_oaei_mock(task, backend, mocker, tmp_path):
         "sylloge.oaei_loader.read_dask_bag_from_archive_text",
         rm.assert_not_called,
     )
-    ds = OAEI(backend=backend, task=task, use_cache=use_cache, cache_path=tmp_path)
+    ds = OAEI(task=task, use_cache=use_cache, cache_path=tmp_path)
     assert ds.__repr__() is not None
     assert ds.canonical_name
     assert ds.rel_triples_left is not None
@@ -39,12 +37,4 @@ def test_oaei_mock(task, backend, mocker, tmp_path):
     assert ds.attr_triples_right is not None
     assert ds.ent_links is not None
     assert ds.dataset_names == tuple(task.split("-"))
-
-    if backend == "pandas":
-        assert isinstance(ds.rel_triples_left, pd.DataFrame)
-        ds.backend = "dask"
-        assert isinstance(ds.rel_triples_left, dd.DataFrame)
-    else:
-        assert isinstance(ds.rel_triples_left, dd.DataFrame)
-        ds.backend = "pandas"
-        assert isinstance(ds.rel_triples_left, pd.DataFrame)
+    assert isinstance(ds.rel_triples_left, dd.DataFrame)
53 changes: 0 additions & 53 deletions tests/test_open_ea.py
@@ -232,56 +232,3 @@ def test_open_ea_mock(
         assert isinstance(ds.rel_triples_left, pd.DataFrame)
     else:
         assert isinstance(ds.rel_triples_left, dd.DataFrame)
-        assert ds.rel_triples_left.npartitions == ds.npartitions
-
-
-@pytest.mark.parametrize("use_cache", [False, True])
-def test_backend_handling(use_cache, mocker, tmp_path):
-    rm = ResourceMocker()
-    mocker.patch("sylloge.base.read_zipfile_csv", rm.mock_read_zipfile_csv)
-    mocker.patch(
-        "sylloge.base.read_dask_df_archive_csv", rm.mock_read_dask_df_archive_csv
-    )
-    # test repartitioning
-    new_npartitions = 10
-    # run twice to check if caching works here
-    rerun = 2 if use_cache else 1
-    for _ in range(rerun):
-        ds = OpenEA(
-            backend="dask",
-            npartitions=new_npartitions,
-            use_cache=use_cache,
-            cache_path=tmp_path,
-        )
-        assert isinstance(ds.rel_triples_left, dd.DataFrame)
-        assert isinstance(ds.rel_triples_right, dd.DataFrame)
-        assert isinstance(ds.attr_triples_left, dd.DataFrame)
-        assert isinstance(ds.attr_triples_right, dd.DataFrame)
-        assert isinstance(ds.ent_links, dd.DataFrame)
-        for fold in ds.folds:
-            assert isinstance(fold.train, dd.DataFrame)
-            assert isinstance(fold.test, dd.DataFrame)
-            assert isinstance(fold.val, dd.DataFrame)
-
-        assert ds.rel_triples_left.npartitions == new_npartitions
-        assert ds.rel_triples_right.npartitions == new_npartitions
-        assert ds.attr_triples_right.npartitions == new_npartitions
-        assert ds.attr_triples_right.npartitions == new_npartitions
-        assert ds.ent_links.npartitions == new_npartitions
-        assert ds.folds
-        for fold in ds.folds:
-            assert fold.train.npartitions == new_npartitions
-            assert fold.test.npartitions == new_npartitions
-            assert fold.val.npartitions == new_npartitions
-
-    # test backend changing
-    ds.backend = "pandas"
-    assert isinstance(ds.rel_triples_left, pd.DataFrame)
-    assert isinstance(ds.rel_triples_right, pd.DataFrame)
-    assert isinstance(ds.attr_triples_left, pd.DataFrame)
-    assert isinstance(ds.attr_triples_right, pd.DataFrame)
-    assert isinstance(ds.ent_links, pd.DataFrame)
-    for fold in ds.folds:
-        assert isinstance(fold.train, pd.DataFrame)
-        assert isinstance(fold.test, pd.DataFrame)
-        assert isinstance(fold.val, pd.DataFrame)