From f30a9579351e68947549f2dd01631a0c6341a2e7 Mon Sep 17 00:00:00 2001
From: Daniel Obraczka
Date: Tue, 5 Mar 2024 12:16:04 +0100
Subject: [PATCH] Fix backend typing, simplify dask (#32)

* Fix generics

* Remove npartitions
---
 sylloge/base.py                        | 339 +++++++++++++------------
 sylloge/med_bbk_loader.py              |  34 ++-
 sylloge/moviegraph_benchmark_loader.py |   9 +-
 sylloge/oaei_loader.py                 |  21 +-
 sylloge/open_ea_loader.py              |  45 +++-
 sylloge/typing.py                      |   7 +-
 tests/test_oaei.py                     |  16 +-
 tests/test_open_ea.py                  |  53 ----
 8 files changed, 266 insertions(+), 258 deletions(-)

diff --git a/sylloge/base.py b/sylloge/base.py
index 285b83e..b18dde1 100644
--- a/sylloge/base.py
+++ b/sylloge/base.py
@@ -4,8 +4,8 @@
 from abc import abstractmethod
 from dataclasses import dataclass
 from typing import (
-    TYPE_CHECKING,
     Any,
+    Callable,
     Dict,
     Generic,
     Iterable,
@@ -14,7 +14,6 @@
     Optional,
     Sequence,
     Tuple,
-    TypeVar,
     Union,
     cast,
     overload,
@@ -34,6 +33,7 @@
     LABEL_HEAD,
     LABEL_RELATION,
     LABEL_TAIL,
+    DataFrameType,
 )
 from .utils import fix_dataclass_init_docs

@@ -41,11 +41,6 @@
 BASE_DATASET_MODULE = pystow.module(BASE_DATASET_KEY)

-DataFrameType = TypeVar("DataFrameType", pd.DataFrame, dd.DataFrame)
-
-
-if TYPE_CHECKING:
-    import dask.dataframe as dd

 logger = logging.getLogger(__name__)


@@ -78,7 +73,8 @@ class TrainTestValSplit(Generic[DataFrameType]):


 @fix_dataclass_init_docs
-class EADataset(Generic[DataFrameType]):
+@dataclass
+class BaseEADataset(Generic[DataFrameType]):
     """Dataset class holding information of the alignment class."""

     rel_triples_left: DataFrameType
@@ -89,6 +85,10 @@ class EADataset(Generic[DataFrameType]):
     dataset_names: Tuple[str, str]
     folds: Optional[Sequence[TrainTestValSplit[DataFrameType]]] = None

+
+class EADataset(BaseEADataset[DataFrameType]):
+    """Dataset class holding information of the alignment class."""
+
     _REL_TRIPLES_LEFT_PATH: str = "rel_triples_left_parquet"
     _REL_TRIPLES_RIGHT_PATH: str = "rel_triples_right_parquet"
     _ATTR_TRIPLES_LEFT_PATH: str = "attr_triples_left_parquet"
@@ -100,6 +100,36 @@ class EADataset(Generic[DataFrameType]):
     _VAL_LINKS_PATH: str = "val_parquet"
     _DATASET_NAMES_PATH: str = "dataset_names.txt"

+    @overload
+    def __init__(
+        self: "EADataset[pd.DataFrame]",
+        *,
+        rel_triples_left: DataFrameType,
+        rel_triples_right: DataFrameType,
+        attr_triples_left: DataFrameType,
+        attr_triples_right: DataFrameType,
+        ent_links: DataFrameType,
+        dataset_names: Tuple[str, str],
+        folds: Optional[Sequence[TrainTestValSplit[DataFrameType]]] = None,
+        backend: Literal["pandas"] = "pandas",
+    ):
+        ...
+
+    @overload
+    def __init__(
+        self: "EADataset[dd.DataFrame]",
+        *,
+        rel_triples_left: DataFrameType,
+        rel_triples_right: DataFrameType,
+        attr_triples_left: DataFrameType,
+        attr_triples_right: DataFrameType,
+        ent_links: DataFrameType,
+        dataset_names: Tuple[str, str],
+        folds: Optional[Sequence[TrainTestValSplit[DataFrameType]]] = None,
+        backend: Literal["dask"] = "dask",
+    ):
+        ...
+
     def __init__(
         self,
         *,
@@ -111,7 +141,6 @@ def __init__(
         dataset_names: Tuple[str, str],
         folds: Optional[Sequence[TrainTestValSplit[DataFrameType]]] = None,
         backend: BACKEND_LITERAL = "pandas",
-        npartitions: int = 1,
     ) -> None:
         """Create an entity alignment dataclass.

@@ -123,18 +152,16 @@ def __init__(
         :param ent_links: gold standard entity links of alignment
         :param folds: optional pre-split folds of the gold standard
         :param backend: which backend to use, either 'pandas' or 'dask'
-        :param npartitions: how many partitions to use for each frame, when using dask
         """
-        self.rel_triples_left = rel_triples_left
-        self.rel_triples_right = rel_triples_right
-        self.attr_triples_left = attr_triples_left
-        self.attr_triples_right = attr_triples_right
-        self.ent_links = ent_links
-        self.dataset_names = dataset_names
-        self.folds = folds
-        self.npartitions: int = npartitions
-        self._backend: BACKEND_LITERAL = backend
-        # trigger possible transformation
+        super().__init__(
+            rel_triples_left=rel_triples_left,  # type: ignore[arg-type]
+            rel_triples_right=rel_triples_right,  # type: ignore[arg-type]
+            attr_triples_left=attr_triples_left,  # type: ignore[arg-type]
+            attr_triples_right=attr_triples_right,  # type: ignore[arg-type]
+            ent_links=ent_links,  # type: ignore[arg-type]
+            dataset_names=dataset_names,
+            folds=folds,  # type: ignore[arg-type]
+        )
         self.backend = backend

     @property
@@ -201,95 +228,6 @@ def __repr__(self) -> str:
         left_ds_stats, right_ds_stats, num_ent_links = self.statistics()
         return f"{self.__class__.__name__}(backend={self.backend}, {self._param_repr}rel_triples_left={left_ds_stats.rel_triples}, rel_triples_right={right_ds_stats.rel_triples}, attr_triples_left={left_ds_stats.attr_triples}, attr_triples_right={right_ds_stats.attr_triples}, ent_links={num_ent_links}, folds={len(self.folds) if self.folds else None})"

-    def _additional_backend_handling(self, backend: BACKEND_LITERAL):
-        pass
-
-    @property
-    def backend(self) -> BACKEND_LITERAL:
-        return self._backend
-
-    @backend.setter
-    def backend(self, backend: BACKEND_LITERAL):
-        """Set backend and transform data if needed."""
-        if backend == "pandas":
-            self._backend = "pandas"
-            if isinstance(self.rel_triples_left, pd.DataFrame):
-                return
-            self.rel_triples_left = self.rel_triples_left.compute()
-            self.rel_triples_right = self.rel_triples_right.compute()
-            self.attr_triples_left = self.attr_triples_left.compute()
-            self.attr_triples_right = self.attr_triples_right.compute()
-            self.ent_links = self.ent_links.compute()
-            if self.folds:
-                for fold in self.folds:
-                    fold.train = fold.train.compute()
-                    fold.test = fold.test.compute()
-                    fold.val = fold.val.compute()
-
-        elif backend == "dask":
-            self._backend = "dask"
-            if isinstance(self.rel_triples_left, dd.DataFrame):
-                if self.rel_triples_left.npartitions != self.npartitions:
-                    self.rel_triples_left = self.rel_triples_left.repartition(
-                        npartitions=self.npartitions
-                    )
-                    self.rel_triples_right = self.rel_triples_right.repartition(
-                        npartitions=self.npartitions
-                    )
-                    self.attr_triples_left = self.attr_triples_left.repartition(
-                        npartitions=self.npartitions
-                    )
-                    self.attr_triples_right = self.attr_triples_right.repartition(
-                        npartitions=self.npartitions
-                    )
-                    self.ent_links = self.ent_links.repartition(
-                        npartitions=self.npartitions
-                    )
-                    if self.folds:
-                        for fold in self.folds:
-                            fold.train = fold.train.repartition(
-                                npartitions=self.npartitions
-                            )
-                            fold.test = fold.test.repartition(
-                                npartitions=self.npartitions
-                            )
-                            fold.val = fold.val.repartition(
-                                npartitions=self.npartitions
-                            )
-                else:
-                    return
-
-            else:
-                self.rel_triples_left = dd.from_pandas(
-                    self.rel_triples_left, npartitions=self.npartitions
-                )
-                self.rel_triples_right = dd.from_pandas(
-                    self.rel_triples_right, npartitions=self.npartitions
-                )
-                self.attr_triples_left = dd.from_pandas(
-                    self.attr_triples_left, npartitions=self.npartitions
-                )
-                self.attr_triples_right = dd.from_pandas(
-                    self.attr_triples_right, npartitions=self.npartitions
-                )
-                self.ent_links = dd.from_pandas(
-                    self.ent_links, npartitions=self.npartitions
-                )
-                if self.folds:
-                    for fold in self.folds:
-                        fold.train = dd.from_pandas(
-                            fold.train, npartitions=self.npartitions
-                        )
-                        fold.test = dd.from_pandas(
-                            fold.test, npartitions=self.npartitions
-                        )
-                        fold.val = dd.from_pandas(
-                            fold.val, npartitions=self.npartitions
-                        )
-        else:
-            raise ValueError(f"Unknown backend {backend}")
-        self._additional_backend_handling(backend)
-
     def to_parquet(self, path: Union[str, pathlib.Path], **kwargs):
         """Write dataset to path as several parquet files.

@@ -353,7 +291,10 @@ def _read_parquet_values(
         if not isinstance(path, pathlib.Path):
             path = pathlib.Path(path)

-        read_parquet_fn = pd.read_parquet if backend == "pandas" else dd.read_parquet
+        read_parquet_fn = cast(  # did not find another way to get the correct type
+            Callable[[Any], DataFrameType],
+            pd.read_parquet if backend == "pandas" else dd.read_parquet,
+        )

         # read dataset names
         with open(path.joinpath(cls._DATASET_NAMES_PATH)) as fh:
@@ -388,7 +329,7 @@ def _read_parquet_values(
         folds = []
         for tmp_fold_dir in sorted(sub_dir for sub_dir in os.listdir(fold_path)):
             fold_dir = fold_path.joinpath(tmp_fold_dir)
-            train_test_val = {}
+            train_test_val: Dict[str, DataFrameType] = {}
             for links, link_path in zip(
                 ["train", "test", "val"],
                 [
@@ -401,15 +342,11 @@ def _read_parquet_values(
                     fold_dir.joinpath(link_path), **kwargs
                 )
             folds.append(TrainTestValSplit(**train_test_val))
-        npartitions = 1
-        if backend == "dask":
-            npartitions = tables["rel_triples_left"].npartitions
         return (
             dict(
                 dataset_names=dataset_names,
                 folds=folds,
                 backend=backend,
-                npartitions=npartitions,
                 **tables,
             ),
             {},
@@ -444,6 +381,32 @@ def read_parquet(
         )


 class CacheableEADataset(EADataset[DataFrameType]):
+    @overload
+    def __init__(
+        self: "CacheableEADataset[pd.DataFrame]",
+        *,
+        cache_path: pathlib.Path,
+        use_cache: bool = True,
+        parquet_load_options: Optional[Mapping] = None,
+        parquet_store_options: Optional[Mapping] = None,
+        backend: Literal["pandas"],
+        **init_kwargs,
+    ):
+        ...
+
+    @overload
+    def __init__(
+        self: "CacheableEADataset[dd.DataFrame]",
+        *,
+        cache_path: pathlib.Path,
+        use_cache: bool = True,
+        parquet_load_options: Optional[Mapping] = None,
+        parquet_store_options: Optional[Mapping] = None,
+        backend: Literal["dask"],
+        **init_kwargs,
+    ):
+        ...
+
     def __init__(
         self,
         *,
@@ -451,6 +414,7 @@ def __init__(
         use_cache: bool = True,
         parquet_load_options: Optional[Mapping] = None,
         parquet_store_options: Optional[Mapping] = None,
+        backend: BACKEND_LITERAL = "pandas",
         **init_kwargs,
     ):
         """EADataset that uses caching after initial read.
@@ -459,13 +423,12 @@ def __init__(
         :param use_cache: whether to use cache
         :param parquet_load_options: handed through to parquet loading function
         :param parquet_store_options: handed through to parquet writing function
+        :param backend: Whether to use pandas or dask for reading/writing
         :param init_kwargs: other arguments for creating the EADataset instance
         """
         self.cache_path = cache_path
         self.parquet_load_options = parquet_load_options or {}
         self.parquet_store_options = parquet_store_options or {}
-        backend = init_kwargs["backend"]
-        specific_npartitions = init_kwargs["npartitions"]
         update_cache = False
         additional_kwargs: Dict[str, Any] = {}
         if use_cache:
@@ -481,10 +444,10 @@ def __init__(
                 update_cache = True
         else:
             init_kwargs.update(self.initial_read(backend=backend))
-        if specific_npartitions != 1:
-            init_kwargs["npartitions"] = specific_npartitions
         self.__dict__.update(additional_kwargs)
-        super().__init__(**init_kwargs)
+        if "backend" in init_kwargs:
+            backend = init_kwargs.pop("backend")
+        super().__init__(backend=backend, **init_kwargs)  # type: ignore[arg-type]
         if update_cache:
             logger.info(f"Caching dataset at {self.cache_path}")
             self.store_cache()
@@ -521,9 +484,45 @@ def store_cache(self):
         self.to_parquet(self.cache_path, **self.parquet_store_options)


-class ZipEADataset(CacheableEADataset):
+class ZipEADataset(CacheableEADataset[DataFrameType]):
     """Dataset created from zip file which is downloaded."""

+    @overload
+    def __init__(
+        self: "ZipEADataset[pd.DataFrame]",
+        *,
+        cache_path: pathlib.Path,
+        zip_path: str,
+        inner_path: pathlib.PurePosixPath,
+        dataset_names: Tuple[str, str],
+        file_name_rel_triples_left: str = "rel_triples_1",
+        file_name_rel_triples_right: str = "rel_triples_2",
+        file_name_attr_triples_left: str = "attr_triples_1",
+        file_name_attr_triples_right: str = "attr_triples_2",
+        file_name_ent_links: str = "ent_links",
+        backend: Literal["pandas"],
+        use_cache: bool = True,
+    ):
+        ...
+
+    @overload
+    def __init__(
+        self: "ZipEADataset[dd.DataFrame]",
+        *,
+        cache_path: pathlib.Path,
+        zip_path: str,
+        inner_path: pathlib.PurePosixPath,
+        dataset_names: Tuple[str, str],
+        file_name_rel_triples_left: str = "rel_triples_1",
+        file_name_rel_triples_right: str = "rel_triples_2",
+        file_name_attr_triples_left: str = "attr_triples_1",
+        file_name_attr_triples_right: str = "attr_triples_2",
+        file_name_ent_links: str = "ent_links",
+        backend: Literal["dask"],
+        use_cache: bool = True,
+    ):
+        ...
+
     def __init__(
         self,
         *,
@@ -537,7 +536,6 @@ def __init__(
         file_name_attr_triples_right: str = "attr_triples_2",
         file_name_ent_links: str = "ent_links",
         backend: BACKEND_LITERAL = "pandas",
-        npartitions: int = 1,
         use_cache: bool = True,
     ):
         """Initialize ZipEADataset.
@@ -552,7 +550,6 @@ def __init__(
         :param file_name_rel_triples_right: file name of right relation triples
         :param file_name_attr_triples_left: file name of left attribute triples
         :param file_name_attr_triples_right: file name of right attribute triples
         :param file_name_ent_links: file name of gold standard containing all entity links
         :param backend: Whether to use "pandas" or "dask"
-        :param npartitions: how many partitions to use for each frame, when using dask
         :param use_cache: whether to use cache or not
         """
         self.zip_path = zip_path
@@ -563,11 +560,10 @@ def __init__(
         self.file_name_attr_triples_left = file_name_attr_triples_left
         self.file_name_attr_triples_right = file_name_attr_triples_right

-        super().__init__(
+        super().__init__(  # type: ignore[misc]
             dataset_names=dataset_names,
             cache_path=cache_path,
-            backend=backend,
-            npartitions=npartitions,
+            backend=backend,  # type: ignore[arg-type]
             use_cache=use_cache,
         )

@@ -590,30 +586,12 @@ def initial_read(self, backend: BACKEND_LITERAL) -> Dict[str, Any]:
             ),
         }

-    @overload
-    def _read_triples(
-        self,
-        file_name: Union[str, pathlib.Path],
-        backend: Literal["pandas"],
-        is_links: bool = False,
-    ) -> pd.DataFrame:
-        ...
-
-    @overload
-    def _read_triples(
-        self,
-        file_name: Union[str, pathlib.Path],
-        backend: Literal["dask"],
-        is_links: bool = False,
-    ) -> "dd.DataFrame":
-        ...
-
     def _read_triples(
         self,
         file_name: Union[str, pathlib.Path],
         backend: BACKEND_LITERAL,
         is_links: bool = False,
-    ) -> Union[pd.DataFrame, "dd.DataFrame"]:
+    ) -> DataFrameType:
         columns = list(EA_SIDES) if is_links else COLUMNS
         read_csv_kwargs = dict(  # noqa: C408
             header=None,
@@ -623,22 +601,70 @@ def _read_triples(
             dtype=str,
         )
         if backend == "pandas":
-            return read_zipfile_csv(
+            trip = read_zipfile_csv(
                 path=self.zip_path,
                 inner_path=str(self.inner_path.joinpath(file_name)),
                 **read_csv_kwargs,
             )
-        return read_dask_df_archive_csv(
-            path=self.zip_path,
-            inner_path=str(self.inner_path.joinpath(file_name)),
-            protocol="zip",
-            **read_csv_kwargs,
-        )
+        else:
+            trip = read_dask_df_archive_csv(
+                path=self.zip_path,
+                inner_path=str(self.inner_path.joinpath(file_name)),
+                protocol="zip",
+                **read_csv_kwargs,
+            )
+        return cast(DataFrameType, trip)


-class ZipEADatasetWithPreSplitFolds(ZipEADataset):
+class ZipEADatasetWithPreSplitFolds(ZipEADataset[DataFrameType]):
     """Dataset with pre-split folds created from zip file which is downloaded."""

+    @overload
+    def __init__(
+        self: "ZipEADatasetWithPreSplitFolds[pd.DataFrame]",
+        *,
+        cache_path: pathlib.Path,
+        zip_path: str,
+        inner_path: pathlib.PurePosixPath,
+        dataset_names: Tuple[str, str],
+        file_name_rel_triples_left: str = "rel_triples_1",
+        file_name_rel_triples_right: str = "rel_triples_2",
+        file_name_ent_links: str = "ent_links",
+        file_name_attr_triples_left: str = "attr_triples_1",
+        file_name_attr_triples_right: str = "attr_triples_2",
+        backend: Literal["pandas"],
+        directory_name_folds: str = "721_5fold",
+        directory_names_individual_folds: Sequence[str] = ("1", "2", "3", "4", "5"),
+        file_name_test_links: str = "test_links",
+        file_name_train_links: str = "train_links",
+        file_name_valid_links: str = "valid_links",
+        use_cache: bool = True,
+    ):
+        ...
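
Aside: the @overload pairs added throughout base.py all follow the same recipe — a constrained TypeVar dispatched on a Literal-typed `backend` argument. A minimal, self-contained sketch of that recipe (the names `Dataset` and `make` are hypothetical, purely for illustration, not part of the patch):

    from typing import Generic, Literal, TypeVar, Union, overload

    import dask.dataframe as dd
    import pandas as pd

    DataFrameType = TypeVar("DataFrameType", pd.DataFrame, dd.DataFrame)


    class Dataset(Generic[DataFrameType]):
        """Toy container whose generic parameter tracks the backend frame type."""

        def __init__(self, table: DataFrameType) -> None:
            self.table = table


    @overload
    def make(backend: Literal["pandas"] = ...) -> Dataset[pd.DataFrame]:
        ...


    @overload
    def make(backend: Literal["dask"]) -> Dataset[dd.DataFrame]:
        ...


    def make(backend: str = "pandas") -> Union[Dataset[pd.DataFrame], Dataset[dd.DataFrame]]:
        # Runtime behavior is the same for both branches; the overloads above are
        # what let a type checker narrow the result type from the passed literal.
        df = pd.DataFrame({"head": ["a"], "relation": ["r"], "tail": ["b"]})
        if backend == "pandas":
            return Dataset(df)
        return Dataset(dd.from_pandas(df, npartitions=1))

With this pattern, a checker infers make("dask").table as a dask DataFrame and make("pandas").table as a pandas one without casts at the call sites — the same effect the overloads give EADataset and its subclasses.
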
+
+    @overload
+    def __init__(
+        self: "ZipEADatasetWithPreSplitFolds[dd.DataFrame]",
+        *,
+        cache_path: pathlib.Path,
+        zip_path: str,
+        inner_path: pathlib.PurePosixPath,
+        dataset_names: Tuple[str, str],
+        file_name_rel_triples_left: str = "rel_triples_1",
+        file_name_rel_triples_right: str = "rel_triples_2",
+        file_name_ent_links: str = "ent_links",
+        file_name_attr_triples_left: str = "attr_triples_1",
+        file_name_attr_triples_right: str = "attr_triples_2",
+        backend: Literal["dask"],
+        directory_name_folds: str = "721_5fold",
+        directory_names_individual_folds: Sequence[str] = ("1", "2", "3", "4", "5"),
+        file_name_test_links: str = "test_links",
+        file_name_train_links: str = "train_links",
+        file_name_valid_links: str = "valid_links",
+        use_cache: bool = True,
+    ):
+        ...
+
     def __init__(
         self,
         *,
@@ -652,7 +678,6 @@ def __init__(
         file_name_attr_triples_left: str = "attr_triples_1",
         file_name_attr_triples_right: str = "attr_triples_2",
         backend: BACKEND_LITERAL = "pandas",
-        npartitions: int = 1,
         directory_name_folds: str = "721_5fold",
         directory_names_individual_folds: Sequence[str] = ("1", "2", "3", "4", "5"),
         file_name_test_links: str = "test_links",
@@ -672,7 +697,6 @@ def __init__(
         :param file_name_attr_triples_right: file name of right attribute triples
         :param file_name_ent_links: file name of gold standard containing all entity links
         :param backend: Whether to use "pandas" or "dask"
-        :param npartitions: how many partitions to use for each frame, when using dask
         :param directory_name_folds: name of the folds directory
         :param directory_names_individual_folds: name of individual folds
         :param file_name_test_links: name of test link file
@@ -688,13 +712,12 @@ def __init__(
         self.file_name_test_links = file_name_test_links
         self.file_name_valid_links = file_name_valid_links

-        super().__init__(
+        super().__init__(  # type: ignore[misc]
             dataset_names=dataset_names,
             zip_path=zip_path,
             inner_path=inner_path,
             cache_path=cache_path,
-            backend=backend,
-            npartitions=npartitions,
+            backend=backend,  # type: ignore[arg-type]
             use_cache=use_cache,
             file_name_rel_triples_left=file_name_rel_triples_left,
             file_name_rel_triples_right=file_name_rel_triples_right,
diff --git a/sylloge/med_bbk_loader.py b/sylloge/med_bbk_loader.py
index b1a825d..60ede7f 100644
--- a/sylloge/med_bbk_loader.py
+++ b/sylloge/med_bbk_loader.py
@@ -1,12 +1,15 @@
 import pathlib
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Literal, Optional, overload

-from .base import BACKEND_LITERAL, BASE_DATASET_MODULE, ZipEADataset
+import dask.dataframe as dd
+import pandas as pd
+
+from .base import BACKEND_LITERAL, BASE_DATASET_MODULE, DataFrameType, ZipEADataset

 MED_BBK_MODULE = BASE_DATASET_MODULE.module("med_bbk")


-class MED_BBK(ZipEADataset):
+class MED_BBK(ZipEADataset[DataFrameType]):
     """Class containing the MED-BBK dataset.

     Published in `Zhang, Z. et. al. (2020) An Industry Evaluation of Embedding-based Entity Alignment `_,
@@ -21,17 +24,33 @@ class MED_BBK(ZipEADataset):
     #: The hex digest for the zip file
     _SHA512: str = "da1ee2b025070fd6890fb7e77b07214af3767b5ae85bcdc1bb36958b4b8dd935bc636e3466b94169158940a960541f96284e3217d32976bfeefa56e29d4a9e0d"

+    @overload
+    def __init__(
+        self: "MED_BBK[pd.DataFrame]",
+        backend: Literal["pandas"] = "pandas",
+        use_cache: bool = True,
+        cache_path: Optional[pathlib.Path] = None,
+    ):
+        ...
+
+    @overload
+    def __init__(
+        self: "MED_BBK[dd.DataFrame]",
+        backend: Literal["dask"] = "dask",
+        use_cache: bool = True,
+        cache_path: Optional[pathlib.Path] = None,
+    ):
+        ...
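
Aside: sketched usage of the loader overloads above, assuming the top-level package re-exports MED_BBK as in released sylloge distributions (downloading the real data is required, so this is illustrative rather than a doctest):

    import dask.dataframe as dd
    import pandas as pd

    from sylloge import MED_BBK

    ds_pd = MED_BBK()  # inferred as MED_BBK[pd.DataFrame]
    assert isinstance(ds_pd.rel_triples_left, pd.DataFrame)

    ds_dd = MED_BBK(backend="dask")  # inferred as MED_BBK[dd.DataFrame]
    assert isinstance(ds_dd.rel_triples_left, dd.DataFrame)

    # npartitions is gone from the loader API; with the dask backend,
    # repartition explicitly where needed:
    rel = ds_dd.rel_triples_left.repartition(npartitions=4)
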
+
     def __init__(
         self,
         backend: BACKEND_LITERAL = "pandas",
-        npartitions: int = 1,
         use_cache: bool = True,
         cache_path: Optional[pathlib.Path] = None,
     ):
         """Initialize an MED-BBK dataset.

         :param backend: Whether to use "pandas" or "dask"
-        :param npartitions: how many partitions to use for each frame, when using dask
         :param use_cache: whether to use cache or not
         :param cache_path: Path where cache will be stored/loaded
         """
@@ -45,13 +64,12 @@ def __init__(
         actual_cache_path = self.create_cache_path(
             MED_BBK_MODULE, inner_path, cache_path
         )
-        super().__init__(
+        super().__init__(  # type: ignore[misc]
             cache_path=actual_cache_path,
             use_cache=use_cache,
             zip_path=zip_path,
             inner_path=pathlib.PurePosixPath(inner_path),
-            backend=backend,
-            npartitions=npartitions,
+            backend=backend,  # type: ignore[arg-type]
             dataset_names=("MED", "BBK"),
         )

diff --git a/sylloge/moviegraph_benchmark_loader.py b/sylloge/moviegraph_benchmark_loader.py
index 1accaa8..8fbeb7a 100644
--- a/sylloge/moviegraph_benchmark_loader.py
+++ b/sylloge/moviegraph_benchmark_loader.py
@@ -1,6 +1,7 @@
 import pathlib
 from typing import Literal, Optional, Tuple

+import pandas as pd
 from moviegraphbenchmark import load_data

 from .base import (
@@ -21,7 +22,7 @@
 GRAPH_PAIRS: Tuple[GraphPair, ...] = (IMDB_TMDB, IMDB_TVDB, TMDB_TVDB)


-class MovieGraphBenchmark(CacheableEADataset):
+class MovieGraphBenchmark(CacheableEADataset[pd.DataFrame]):
     """Class containing the movie graph benchmark.

     Published in `Obraczka, D. et. al. (2021) Embedding-Assisted Entity Resolution for Knowledge Graphs `_,
@@ -31,8 +32,6 @@ class MovieGraphBenchmark(CacheableEADataset):
     def __init__(
         self,
         graph_pair: GraphPair = "imdb-tmdb",
-        backend: BACKEND_LITERAL = "pandas",
-        npartitions: int = 1,
         use_cache: bool = True,
         cache_path: Optional[pathlib.Path] = None,
     ):
@@ -40,7 +39,5 @@ def __init__(

         :param graph_pair: which graph pair to use of "imdb-tmdb","imdb-tvdb" or "tmdb-tvdb"
-        :param backend: Whether to use "pandas" or "dask"
-        :param npartitions: how many partitions to use for each frame, when using dask
         :param use_cache: whether to use cache or not
         :param cache_path: Path where cache will be stored/loaded
         :raises ValueError: if unknown graph pair
@@ -58,8 +55,7 @@ def __init__(
         super().__init__(
             cache_path=actual_cache_path,
             use_cache=use_cache,
-            backend=backend,
-            npartitions=npartitions,
+            backend="pandas",
             dataset_names=(left_name, right_name),
         )

diff --git a/sylloge/oaei_loader.py b/sylloge/oaei_loader.py
index cfa1cf6..6994e49 100644
--- a/sylloge/oaei_loader.py
+++ b/sylloge/oaei_loader.py
@@ -11,7 +11,6 @@
 from .base import (
     BASE_DATASET_MODULE,
     CacheableEADataset,
-    DataFrameType,
     DatasetStatistics,
 )
 from .dask import read_dask_bag_from_archive_text
@@ -40,8 +39,12 @@
 REFLINE_TYPE_LITERAL = Literal["dismiss", "entity", "property", "class", "unknown"]
 REL_ATTR_LITERAL = Literal["rel", "attr"]

-reflinetype = pd.CategoricalDtype(categories=typing.get_args(REFLINE_TYPE_LITERAL))
-relattrlinetype = pd.CategoricalDtype(categories=typing.get_args(REL_ATTR_LITERAL))
+reflinetype = pd.CategoricalDtype(
+    categories=list(typing.get_args(REFLINE_TYPE_LITERAL))
+)
+relattrlinetype = pd.CategoricalDtype(
+    categories=list(typing.get_args(REL_ATTR_LITERAL))
+)


 class URL_SHA512_HASH(NamedTuple):
@@ -98,7 +101,7 @@ def fault_tolerant_parse_nt(
     return subj, pred, obj, triple_type


-class OAEI(CacheableEADataset[DataFrameType]):
+class OAEI(CacheableEADataset[dd.DataFrame]):
     """The OAEI (Ontology Alignment Evaluation Initiative) Knowledge Graph Track tasks contain
 graphs created from fandom wikis. Five integration tasks are available:
@@ -111,8 +114,8 @@ class OAEI(CacheableEADataset[DataFrameType]):
     More information can be found at the `website `_.
     """

-    class_links: pd.DataFrame
-    property_links: pd.DataFrame
+    class_links: dd.DataFrame
+    property_links: dd.DataFrame

     _CLASS_LINKS_PATH: str = "class_links_parquet"
     _PROPERTY_LINKS_PATH: str = "property_links_parquet"
@@ -246,8 +249,6 @@ class OAEI(CacheableEADataset[DataFrameType]):
     def __init__(
         self,
         task: OAEI_TASK_NAME = "starwars-swg",
-        backend: BACKEND_LITERAL = "dask",
-        npartitions: int = 1,
         use_cache: bool = True,
         cache_path: Optional[pathlib.Path] = None,
     ):
@@ -255,7 +256,5 @@ def __init__(

         :param task: Name of the task. Has to be one of {starwars-swg,starwars-swtor,marvelcinematicuniverse-marvel,memoryalpha-memorybeta, memoryalpha-stexpanded}
-        :param backend: Whether to use "pandas" or "dask"
-        :param npartitions: how many partitions to use for each frame, when using dask
         :param use_cache: whether to use cache or not
         :param cache_path: Path where cache will be stored/loaded
         :raises ValueError: if unknown task value is provided
@@ -274,8 +273,7 @@ def __init__(
             use_cache=use_cache,
             cache_path=actual_cache_path,
             dataset_names=(left_name, right_name),
-            backend=backend,
-            npartitions=npartitions,
+            backend="dask",
         )

     def initial_read(self, backend: BACKEND_LITERAL):
diff --git a/sylloge/open_ea_loader.py b/sylloge/open_ea_loader.py
index 67fc7bb..7f03199 100644
--- a/sylloge/open_ea_loader.py
+++ b/sylloge/open_ea_loader.py
@@ -1,9 +1,17 @@
 # largely adapted from pykeen.datasets.ea.openea
 import pathlib
 from types import MappingProxyType
-from typing import Literal, Optional, Tuple
+from typing import Literal, Optional, Tuple, overload

-from .base import BACKEND_LITERAL, BASE_DATASET_MODULE, ZipEADatasetWithPreSplitFolds
+import dask.dataframe as dd
+import pandas as pd
+
+from .base import (
+    BACKEND_LITERAL,
+    BASE_DATASET_MODULE,
+    DataFrameType,
+    ZipEADatasetWithPreSplitFolds,
+)

 OPEN_EA_MODULE = BASE_DATASET_MODULE.module("open_ea")

@@ -28,7 +36,7 @@
 GRAPH_VERSIONS = (V1, V2)


-class OpenEA(ZipEADatasetWithPreSplitFolds):
+class OpenEA(ZipEADatasetWithPreSplitFolds[DataFrameType]):
     """Class containing the OpenEA dataset family.

     Published in `Sun, Z. et. al. (2020) A Benchmarking Study of Embedding-based Entity Alignment for Knowledge Graphs `_,
@@ -53,13 +61,36 @@ class OpenEA(ZipEADatasetWithPreSplitFolds):
         }
     )

+    @overload
+    def __init__(
+        self: "OpenEA[pd.DataFrame]",
+        graph_pair: GraphPair = "D_W",
+        size: GraphSize = "15K",
+        version: GraphVersion = "V1",
+        backend: Literal["pandas"] = "pandas",
+        use_cache: bool = True,
+        cache_path: Optional[pathlib.Path] = None,
+    ):
+        ...
+
+    @overload
+    def __init__(
+        self: "OpenEA[dd.DataFrame]",
+        graph_pair: GraphPair = "D_W",
+        size: GraphSize = "15K",
+        version: GraphVersion = "V1",
+        backend: Literal["dask"] = "dask",
+        use_cache: bool = True,
+        cache_path: Optional[pathlib.Path] = None,
+    ):
+        ...
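
Aside on the CategoricalDtype change in oaei_loader.py above: typing.get_args returns a tuple, which pandas accepts at runtime, so wrapping it in list() is behavior-preserving — presumably it is there to satisfy the stricter list-like type that type checkers expect for `categories`. A small runnable check (plain pandas, no sylloge imports):

    import typing
    from typing import Literal

    import pandas as pd

    REFLINE_TYPE_LITERAL = Literal["dismiss", "entity", "property", "class", "unknown"]

    # list(...) keeps type checkers happy; the resulting dtype is identical.
    reflinetype = pd.CategoricalDtype(
        categories=list(typing.get_args(REFLINE_TYPE_LITERAL))
    )
    s = pd.Series(["entity", "class", "entity"], dtype=reflinetype)
    assert list(s.cat.categories) == ["dismiss", "entity", "property", "class", "unknown"]
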
+
     def __init__(
         self,
         graph_pair: GraphPair = "D_W",
         size: GraphSize = "15K",
         version: GraphVersion = "V1",
         backend: BACKEND_LITERAL = "pandas",
-        npartitions: int = 1,
         use_cache: bool = True,
         cache_path: Optional[pathlib.Path] = None,
     ):
@@ -69,7 +100,6 @@ def __init__(
         :param size: what size ("15K" or "100K")
         :param version: which version to use ("V1" or "V2")
         :param backend: Whether to use "pandas" or "dask"
-        :param npartitions: how many partitions to use for each frame, when using dask
         :param use_cache: whether to use cache or not
         :param cache_path: Path where cache will be stored/loaded
         :raises ValueError: if unknown graph_pair, size or version values are provided
@@ -98,13 +128,12 @@ def __init__(
         actual_cache_path = self.create_cache_path(
             OPEN_EA_MODULE, inner_cache_path, cache_path
         )
-        super().__init__(
+        super().__init__(  # type: ignore[misc]
             cache_path=actual_cache_path,
             use_cache=use_cache,
             zip_path=zip_path,
             inner_path=inner_path,
-            backend=backend,
-            npartitions=npartitions,
+            backend=backend,  # type: ignore[arg-type]
             dataset_names=OpenEA._GRAPH_PAIR_TO_DS_NAMES[graph_pair],
         )

diff --git a/sylloge/typing.py b/sylloge/typing.py
index bd917bc..6814c60 100644
--- a/sylloge/typing.py
+++ b/sylloge/typing.py
@@ -1,4 +1,7 @@
-from typing import Literal, Tuple
+from typing import Literal, Tuple, TypeVar
+
+import dask.dataframe as dd
+import pandas as pd

 # borrowed from pykeen.typing
 Target = Literal["head", "relation", "tail"]
@@ -10,4 +13,6 @@
 EA_SIDE_RIGHT: EASide = "right"
 EA_SIDES: Tuple[EASide, EASide] = (EA_SIDE_LEFT, EA_SIDE_RIGHT)
 COLUMNS = [LABEL_HEAD, LABEL_RELATION, LABEL_TAIL]
+
 BACKEND_LITERAL = Literal["pandas", "dask"]
+DataFrameType = TypeVar("DataFrameType", pd.DataFrame, dd.DataFrame)
diff --git a/tests/test_oaei.py b/tests/test_oaei.py
index 49d029e..cd03c4b 100644
--- a/tests/test_oaei.py
+++ b/tests/test_oaei.py
@@ -1,5 +1,4 @@
 import dask.dataframe as dd
-import pandas as pd
 import pytest
 from mocks import ResourceMocker

@@ -16,8 +15,7 @@
         "memoryalpha-stexpanded",
     ],
 )
-@pytest.mark.parametrize("backend", ["pandas", "dask"])
-def test_oaei_mock(task, backend, mocker, tmp_path):
+def test_oaei_mock(task, mocker, tmp_path):
     rm = ResourceMocker()
     mocker.patch(
         "sylloge.oaei_loader.read_dask_bag_from_archive_text",
@@ -30,7 +28,7 @@ def test_oaei_mock(task, backend, mocker, tmp_path):
         "sylloge.oaei_loader.read_dask_bag_from_archive_text",
         rm.assert_not_called,
     )
-    ds = OAEI(backend=backend, task=task, use_cache=use_cache, cache_path=tmp_path)
+    ds = OAEI(task=task, use_cache=use_cache, cache_path=tmp_path)
     assert ds.__repr__() is not None
     assert ds.canonical_name
     assert ds.rel_triples_left is not None
@@ -39,12 +37,4 @@ def test_oaei_mock(task, backend, mocker, tmp_path):
     assert ds.attr_triples_right is not None
     assert ds.ent_links is not None
     assert ds.dataset_names == tuple(task.split("-"))
-
-    if backend == "pandas":
-        assert isinstance(ds.rel_triples_left, pd.DataFrame)
-        ds.backend = "dask"
-        assert isinstance(ds.rel_triples_left, dd.DataFrame)
-    else:
-        assert isinstance(ds.rel_triples_left, dd.DataFrame)
-        ds.backend = "pandas"
-        assert isinstance(ds.rel_triples_left, pd.DataFrame)
+    assert isinstance(ds.rel_triples_left, dd.DataFrame)
diff --git a/tests/test_open_ea.py b/tests/test_open_ea.py
index 5a31eeb..5ffca62 100644
--- a/tests/test_open_ea.py
+++ b/tests/test_open_ea.py
@@ -232,56 +232,3 @@ def test_open_ea_mock(
         assert isinstance(ds.rel_triples_left, pd.DataFrame)
     else:
         assert isinstance(ds.rel_triples_left, dd.DataFrame)
-        assert ds.rel_triples_left.npartitions == ds.npartitions
-
-
-@pytest.mark.parametrize("use_cache", [False, True])
-def test_backend_handling(use_cache, mocker, tmp_path):
-    rm = ResourceMocker()
-    mocker.patch("sylloge.base.read_zipfile_csv", rm.mock_read_zipfile_csv)
-    mocker.patch(
-        "sylloge.base.read_dask_df_archive_csv", rm.mock_read_dask_df_archive_csv
-    )
-    # test repartitioning
-    new_npartitions = 10
-    # run twice to check if caching works here
-    rerun = 2 if use_cache else 1
-    for _ in range(rerun):
-        ds = OpenEA(
-            backend="dask",
-            npartitions=new_npartitions,
-            use_cache=use_cache,
-            cache_path=tmp_path,
-        )
-        assert isinstance(ds.rel_triples_left, dd.DataFrame)
-        assert isinstance(ds.rel_triples_right, dd.DataFrame)
-        assert isinstance(ds.attr_triples_left, dd.DataFrame)
-        assert isinstance(ds.attr_triples_right, dd.DataFrame)
-        assert isinstance(ds.ent_links, dd.DataFrame)
-        for fold in ds.folds:
-            assert isinstance(fold.train, dd.DataFrame)
-            assert isinstance(fold.test, dd.DataFrame)
-            assert isinstance(fold.val, dd.DataFrame)
-
-        assert ds.rel_triples_left.npartitions == new_npartitions
-        assert ds.rel_triples_right.npartitions == new_npartitions
-        assert ds.attr_triples_right.npartitions == new_npartitions
-        assert ds.attr_triples_right.npartitions == new_npartitions
-        assert ds.ent_links.npartitions == new_npartitions
-        assert ds.folds
-        for fold in ds.folds:
-            assert fold.train.npartitions == new_npartitions
-            assert fold.test.npartitions == new_npartitions
-            assert fold.val.npartitions == new_npartitions
-
-        # test backend changing
-        ds.backend = "pandas"
-        assert isinstance(ds.rel_triples_left, pd.DataFrame)
-        assert isinstance(ds.rel_triples_right, pd.DataFrame)
-        assert isinstance(ds.attr_triples_left, pd.DataFrame)
-        assert isinstance(ds.attr_triples_right, pd.DataFrame)
-        assert isinstance(ds.ent_links, pd.DataFrame)
-        for fold in ds.folds:
-            assert isinstance(fold.train, pd.DataFrame)
-            assert isinstance(fold.test, pd.DataFrame)
-            assert isinstance(fold.val, pd.DataFrame)
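
Closing aside: the deleted test above exercised the removed mutable `backend` setter and the `npartitions` plumbing. After this patch the backend is fixed at construction time by the overloads, and conversions happen explicitly on individual frames. A minimal sketch of the equivalent manual round trip (plain pandas/dask, no sylloge required):

    import dask.dataframe as dd
    import pandas as pd

    pdf = pd.DataFrame({"head": ["e1", "e2"], "relation": ["r", "r"], "tail": ["e3", "e4"]})

    # pandas -> dask: the partition count is now chosen at the call site,
    # which is what the removed npartitions argument used to control.
    ddf = dd.from_pandas(pdf, npartitions=2)

    # dask -> pandas: what the removed setter did internally via .compute()
    roundtripped = ddf.compute()
    assert roundtripped.equals(pdf)
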