From 6d6bccb6da979afabdb354f359999c0ff711cb6f Mon Sep 17 00:00:00 2001 From: Anthony Mahanna <43019056+aMahanna@users.noreply.github.com> Date: Thu, 30 Dec 2021 14:48:39 -0500 Subject: [PATCH] new: mirror networkx-adapter changes (#3) --- .github/workflows/analyze.yml | 25 +- .github/workflows/build.yml | 56 +-- .github/workflows/release.yml | 328 ++++++++---------- .gitignore | 127 ++++++- MANIFEST.in | 3 + README.md | 11 +- VERSION | 1 - .../{adbdgl_adapter => }/__init__.py | 0 adbdgl_adapter/abc.py | 82 +++++ .../adbdgl_adapter.py => adapter.py} | 304 ++++++++-------- adbdgl_adapter/adbdgl_adapter/abc.py | 69 ---- .../adbdgl_adapter/adbdgl_controller.py | 49 --- adbdgl_adapter/controller.py | 70 ++++ adbdgl_adapter/setup.cfg | 7 - adbdgl_adapter/tests/conftest.py | 119 ------- adbdgl_adapter/typings.py | 12 + examples/ArangoDB_DGL_Adapter.ipynb | 303 ++++++++++------ pyproject.toml | 23 ++ scripts/assert_version.py | 10 - scripts/extract_version.py | 9 - setup.cfg | 31 ++ adbdgl_adapter/setup.py => setup.py | 33 +- tests/__init__.py | 0 .../tests => tests}/assets/arangorestore | Bin tests/conftest.py | 123 +++++++ .../test_adapter.py | 93 +++-- 26 files changed, 1080 insertions(+), 808 deletions(-) create mode 100644 MANIFEST.in delete mode 100644 VERSION rename adbdgl_adapter/{adbdgl_adapter => }/__init__.py (100%) create mode 100644 adbdgl_adapter/abc.py rename adbdgl_adapter/{adbdgl_adapter/adbdgl_adapter.py => adapter.py} (61%) delete mode 100644 adbdgl_adapter/adbdgl_adapter/abc.py delete mode 100644 adbdgl_adapter/adbdgl_adapter/adbdgl_controller.py create mode 100644 adbdgl_adapter/controller.py delete mode 100644 adbdgl_adapter/setup.cfg delete mode 100755 adbdgl_adapter/tests/conftest.py create mode 100644 adbdgl_adapter/typings.py create mode 100644 pyproject.toml delete mode 100644 scripts/assert_version.py delete mode 100644 scripts/extract_version.py create mode 100644 setup.cfg rename adbdgl_adapter/setup.py => setup.py (58%) create mode 100644 tests/__init__.py rename {adbdgl_adapter/tests => tests}/assets/arangorestore (100%) create mode 100755 tests/conftest.py rename adbdgl_adapter/tests/test_adbdgl_adapter.py => tests/test_adapter.py (79%) diff --git a/.github/workflows/analyze.yml b/.github/workflows/analyze.yml index e0001fb..dc84535 100644 --- a/.github/workflows/analyze.yml +++ b/.github/workflows/analyze.yml @@ -11,23 +11,30 @@ # name: analyze on: - workflow_dispatch: + push: + branches: [ master ] + paths: + - 'adbdgl_adapter/**' + - 'tests/**' + - 'setup.py' + - 'setup.cfg' + - 'pyproject.toml' + - '.github/workflows/analyze.yml' pull_request: - # The branches below must be a subset of the branches above - branches: [master] + branches: [ master ] paths: - - "adbdgl_adapter/**" + - 'adbdgl_adapter/**' + - 'tests/**' + - 'setup.py' + - 'setup.cfg' + - 'pyproject.toml' + - '.github/workflows/analyze.yml' schedule: - cron: "00 9 * * 1" -env: - SOURCE_DIR: adbdgl_adapter jobs: analyze: name: Analyze runs-on: ubuntu-latest - defaults: - run: - working-directory: ${{env.SOURCE_DIR}} permissions: actions: read contents: read diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 6f6da1c..afbde43 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -1,40 +1,56 @@ name: build on: workflow_dispatch: + push: + branches: [ master ] + paths: + - 'adbdgl_adapter/**' + - 'tests/**' + - 'setup.py' + - 'setup.cfg' + - 'pyproject.toml' + - '.github/workflows/build.yml' pull_request: + branches: [ master ] paths: - - "adbdgl_adapter/adbdgl_adapter/**" - - "adbdgl_adapter/tests/**" + - 'adbdgl_adapter/**' + - 'tests/**' + - 'setup.py' + - 'setup.cfg' + - 'pyproject.toml' + - '.github/workflows/build.yml' env: - SOURCE_DIR: adbdgl_adapter PACKAGE_DIR: adbdgl_adapter + TESTS_DIR: tests jobs: build: runs-on: ubuntu-latest - defaults: - run: - working-directory: ${{env.SOURCE_DIR}} strategy: matrix: python: ["3.6", "3.7", "3.8", "3.9"] name: Python ${{ matrix.python }} - env: - COVERALLS_REPO_TOKEN: ${{secrets.COVERALLS_REPO_TOKEN}} - GITHUB_TOKEN: ${{secrets.GITHUB_TOKEN}} steps: - uses: actions/checkout@v2 - - name: Setup python + - name: Setup Python ${{ matrix.python }} uses: actions/setup-python@v2 with: python-version: ${{ matrix.python }} - - name: Lint with Black - uses: psf/black@stable - with: - options: "--check --verbose --diff --color" - src: ${{env.PACKAGE_DIR}} - - name: Install dependencies - run: pip install -e . pytest pytest-cov coveralls + - name: Setup pip + run: python -m pip install --upgrade pip setuptools wheel + - name: Install packages + run: pip install .[dev] + - name: Run black + run: black --check --verbose --diff --color ${{env.PACKAGE_DIR}} ${{env.TESTS_DIR}} + - name: Run flake8 + run: flake8 ${{env.PACKAGE_DIR}} ${{env.TESTS_DIR}} + - name: Run isort + run: isort --check --profile=black ${{env.PACKAGE_DIR}} ${{env.TESTS_DIR}} + - name: Run mypy + run: mypy ${{env.PACKAGE_DIR}} ${{env.TESTS_DIR}} - name: Run pytest - run: | - pytest --cov=${{env.PACKAGE_DIR}} --cov-report term-missing -v --color=yes --no-cov-on-fail --code-highlight=yes - coveralls + run: py.test --cov=${{env.PACKAGE_DIR}} --cov-report xml -v --color=yes --no-cov-on-fail --code-highlight=yes + - name: Publish to coveralls.io + if: matrix.python == '3.8' + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: coveralls --service=github \ No newline at end of file diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 9b2a48e..50cea35 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -1,192 +1,136 @@ -# name: release -# on: -# workflow_dispatch: -# push: -# branches: -# - master -# paths: -# - "adbdgl_adapter/adbdgl_adapter/**" -# env: -# SOURCE_DIR: adbdgl_adapter -# PACKAGE_DIR: adbdgl_adapter -# jobs: -# version: -# runs-on: ubuntu-latest -# name: Verify version increase -# steps: -# - uses: actions/checkout@v2 -# - uses: actions/setup-python@v2 -# with: -# python-version: "3.9" -# - name: Install dependencies -# run: pip install requests packaging -# - name: Set variables -# run: | -# echo "OLD_VERSION=$(python scripts/extract_version.py)" >> $GITHUB_ENV -# echo "NEW_VERSION=$(cat VERSION)" >> $GITHUB_ENV -# - name: Assert version increase -# id: verify -# run: echo "::set-output name=has_increased::$(python scripts/assert_version.py ${{env.OLD_VERSION}} ${{env.NEW_VERSION}})" -# - name: Fail on no version increase -# if: ${{ steps.verify.outputs.has_increased != 'true' }} -# uses: actions/github-script@v3 -# with: -# script: core.setFailed("Cannot build & release - VERSION has not been manually incremented") -# build: -# needs: version -# runs-on: ubuntu-latest -# defaults: -# run: -# working-directory: ${{env.SOURCE_DIR}} -# strategy: -# matrix: -# python: ["3.6", "3.7", "3.8", "3.9"] -# name: Python ${{ matrix.python }} -# env: -# COVERALLS_REPO_TOKEN: ${{secrets.COVERALLS_REPO_TOKEN}} -# GITHUB_TOKEN: ${{secrets.GITHUB_TOKEN}} -# steps: -# - uses: actions/checkout@v2 -# - uses: actions/setup-python@v2 -# with: -# python-version: ${{ matrix.python }} -# - name: Lint with Black -# uses: psf/black@stable -# with: -# options: "--check --verbose --diff --color" -# src: ${{env.PACKAGE_DIR}} -# - name: Install dependencies -# run: pip install -e . pytest pytest-cov coveralls -# - name: Run pytest -# run: | -# pytest --cov=${{env.PACKAGE_DIR}} --cov-report term-missing -v --color=yes --no-cov-on-fail --code-highlight=yes -# coveralls - -# release: -# needs: build -# runs-on: ubuntu-latest -# name: Release package -# env: -# TWINE_USERNAME: ${{ secrets.TWINE_USERNAME }} -# steps: -# - uses: actions/checkout@v2 -# with: -# fetch-depth: 0 - -# - name: Setup python -# uses: actions/setup-python@v2 -# with: -# python-version: "3.8" - -# - name: Copy static repo files -# run: cp {CHANGELOG.md,LICENSE,README.md,VERSION} ${{env.SOURCE_DIR}} - -# - name: Install release packages -# run: pip install wheel gitchangelog pystache twine - -# - name: Install dependencies -# run: pip install -e . -# working-directory: ${{env.SOURCE_DIR}} - -# - name: Set variables -# run: | -# echo "OLD_VERSION=$(python scripts/extract_version.py)" >> $GITHUB_ENV -# echo "NEW_VERSION=$(cat VERSION)" >> $GITHUB_ENV - -# - name: Ensure clean dist/ and build/ folders -# run: rm -rf dist build -# working-directory: ${{env.SOURCE_DIR}} - -# - name: Build package -# run: python setup.py sdist bdist_wheel -# working-directory: ${{env.SOURCE_DIR}} - -# - name: Extract wheel artifact name -# run: echo "wheel_name=$(echo ${{env.SOURCE_DIR}}/dist/*.whl)" >> $GITHUB_ENV - -# - name: Extract tar.gz artifact name -# run: echo "tar_name=$(echo ${{env.SOURCE_DIR}}/dist/*.tar.gz)" >> $GITHUB_ENV - -# - name: Pull tags from the repo -# run: git pull --tags - -# - name: Create version_changelog.md -# run: gitchangelog ${{env.OLD_VERSION}}..HEAD | sed "s/## (unreleased)/${{env.NEW_VERSION}} ($(date +"%Y-%m-%d"))/" > version_changelog.md - -# - name: Read version_changelog.md -# run: cat version_changelog.md - -# - name: TestPypi release -# run: twine upload --repository testpypi dist/* -p ${{ secrets.TWINE_PASSWORD_TEST }} #--skip-existing -# working-directory: ${{env.SOURCE_DIR}} - -# - name: Pypi release -# run: twine upload dist/* -p ${{ secrets.TWINE_PASSWORD }} #--skip-existing -# working-directory: ${{env.SOURCE_DIR}} - -# - name: Github release -# env: -# GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} -# run: hub release create -a $wheel_name -a $tar_name -F version_changelog.md ${{env.NEW_VERSION}} - -# changelog: -# needs: release -# runs-on: ubuntu-latest -# name: Update Changelog -# steps: -# - uses: actions/checkout@v2 -# with: -# fetch-depth: 0 - -# - name: Create new branch -# run: git checkout -b actions/changelog - -# - name: Set branch upstream -# run: git push -u origin actions/changelog -# env: -# GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - -# - name: Setup python -# uses: actions/setup-python@v2 -# with: -# python-version: "3.8" - -# - name: Install release packages -# run: pip install wheel gitchangelog pystache - -# - name: Install dependencies -# run: pip install -e . -# working-directory: ${{env.SOURCE_DIR}} - -# - name: Set variables -# run: echo "NEW_VERSION=$(cat VERSION)" >> $GITHUB_ENV - -# - name: Generate newest changelog -# run: gitchangelog ${{env.NEW_VERSION}} > CHANGELOG.md - -# - name: Make commit for auto-generated changelog -# uses: EndBug/add-and-commit@v7 -# env: -# GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} -# with: -# add: "CHANGELOG.md" -# branch: actions/changelog -# message: "!gitchangelog" - -# - name: Create pull request for the auto generated changelog -# run: | -# echo "PR_URL=$(gh pr create \ -# --title "changelog: release ${{env.NEW_VERSION}}" \ -# --body "beep boop, i am a robot" \ -# --label documentation)" >> $GITHUB_ENV -# env: -# GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - -# - name: Set pull request to auto-merge as rebase -# run: | -# gh pr merge $PR_URL \ -# --auto \ -# --delete-branch \ -# --rebase -# env: -# GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} +name: release +on: + workflow_dispatch: + release: + types: [published] +env: + PACKAGE_DIR: adbdgl_adapter + TESTS_DIR: tests +jobs: + build: + runs-on: ubuntu-latest + strategy: + matrix: + python: ["3.6", "3.7", "3.8", "3.9"] + name: Python ${{ matrix.python }} + steps: + - uses: actions/checkout@v2 + - name: Setup Python ${{ matrix.python }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python }} + - name: Setup pip + run: python -m pip install --upgrade pip setuptools wheel + - name: Install packages + run: pip install .[dev] + - name: Run black + run: black --check --verbose --diff --color ${{env.PACKAGE_DIR}} ${{env.TESTS_DIR}} + - name: Run flake8 + run: flake8 ${{env.PACKAGE_DIR}} ${{env.TESTS_DIR}} + - name: Run isort + run: isort --check --profile=black ${{env.PACKAGE_DIR}} ${{env.TESTS_DIR}} + - name: Run mypy + run: mypy ${{env.PACKAGE_DIR}} ${{env.TESTS_DIR}} + - name: Run pytest + run: py.test --cov=${{env.PACKAGE_DIR}} --cov-report xml -v --color=yes --no-cov-on-fail --code-highlight=yes + - name: Publish to coveralls.io + if: matrix.python == '3.8' + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: coveralls --service=github + + release: + needs: build + runs-on: ubuntu-latest + name: Release package + steps: + - uses: actions/checkout@v2 + + - name: Fetch complete history for all tags and branches + run: git fetch --prune --unshallow + + - name: Setup python + uses: actions/setup-python@v2 + with: + python-version: "3.8" + + - name: Install release packages + run: pip install setuptools wheel twine setuptools-scm[toml] + + - name: Install dependencies + run: pip install .[dev] + + - name: Build distribution + run: python setup.py sdist bdist_wheel + + - name: Publish to PyPI Test + env: + TWINE_USERNAME: __token__ + TWINE_PASSWORD: ${{ secrets.TWINE_PASSWORD_TEST }} + run: twine upload --repository testpypi dist/* #--skip-existing + - name: Publish to PyPI + env: + TWINE_USERNAME: __token__ + TWINE_PASSWORD: ${{ secrets.TWINE_PASSWORD }} + run: twine upload --repository pypi dist/* #--skip-existing + + changelog: + needs: release + runs-on: ubuntu-latest + name: Update Changelog + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + + - name: Create new branch + run: git checkout -b actions/changelog + + - name: Set branch upstream + run: git push -u origin actions/changelog + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + + - name: Setup python + uses: actions/setup-python@v2 + with: + python-version: "3.8" + + - name: Install release packages + run: pip install wheel gitchangelog pystache + + - name: Install dependencies + run: pip install .[dev] + + - name: Set variables + run: echo "VERSION=$(curl ${GITHUB_API_URL}/repos/${GITHUB_REPOSITORY}/releases/latest | python -c "import sys; import json; print(json.load(sys.stdin)['tag_name'])")" >> $GITHUB_ENV + + - name: Generate newest changelog + run: gitchangelog ${{env.VERSION}} > CHANGELOG.md + + - name: Make commit for auto-generated changelog + uses: EndBug/add-and-commit@v7 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + add: "CHANGELOG.md" + branch: actions/changelog + message: "!gitchangelog" + + - name: Create pull request for the auto generated changelog + run: | + echo "PR_URL=$(gh pr create \ + --title "changelog: release ${{env.VERSION}}" \ + --body "beep boop, i am a robot" \ + --label documentation)" >> $GITHUB_ENV + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + + - name: Set pull request to auto-merge as rebase + run: | + gh pr merge $PR_URL \ + --admin \ + --delete-branch \ + --rebase + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.gitignore b/.gitignore index a2571cb..efb3b0c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,13 +1,118 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook .ipynb_checkpoints -.tox + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ + +# MacOS .DS_Store -**/*.pyc -# log files -**/*.log -# Setuptools distribution folder. -adbdgl_adapter/dist/ -# Remove the build directory from repo -adbdgl_adapter/build/ -adbdgl_adapter/*.egg-info -.vscode -.venv \ No newline at end of file + +# PyCharm +.idea/ + +# ArangoDB Starter +localdata/ + +# setuptools_scm +adbdgl_adapter/version.py + +.vscode \ No newline at end of file diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000..3d73851 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,3 @@ +include README.md LICENSE +prune tests +prune examples \ No newline at end of file diff --git a/README.md b/README.md index 255cada..2c260a7 100644 --- a/README.md +++ b/README.md @@ -29,12 +29,12 @@ The Deep Graph Library (DGL) is an easy-to-use, high performance and scalable Py ## Quickstart -Get Started on Colab: Open In Colab +Get Started on Colab: Open In Colab ```py # Import the ArangoDB-DGL Adapter -from adbdgl_adapter.adbdgl_adapter import ArangoDB_DGL_Adapter +from adbdgl_adapter.adapter import ADBDGL_Adapter # Import a sample graph from DGL from dgl.data import KarateClubDataset @@ -51,7 +51,7 @@ con = { } # This instantiates your ADBDGL Adapter with your connection credentials -adbdgl_adapter = ArangoDB_DGL_Adapter(con) +adbdgl_adapter = ADBDGL_Adapter(con) # ArangoDB to DGL via Graph dgl_fraud_graph = adbdgl_adapter.arangodb_graph_to_dgl("fraud-detection") @@ -89,6 +89,5 @@ Prerequisite: `arangorestore` must be installed 2. `cd dgl-adapter` 3. `python -m venv .venv` 4. `source .venv/bin/activate` (MacOS) or `.venv/scripts/activate` (Windows) -5. `cd adbdgl_adapter` -6. `pip install -e . pytest` -7. `pytest` \ No newline at end of file +5. `pip install -e . pytest` +6. `pytest` \ No newline at end of file diff --git a/VERSION b/VERSION deleted file mode 100644 index bd52db8..0000000 --- a/VERSION +++ /dev/null @@ -1 +0,0 @@ -0.0.0 \ No newline at end of file diff --git a/adbdgl_adapter/adbdgl_adapter/__init__.py b/adbdgl_adapter/__init__.py similarity index 100% rename from adbdgl_adapter/adbdgl_adapter/__init__.py rename to adbdgl_adapter/__init__.py diff --git a/adbdgl_adapter/abc.py b/adbdgl_adapter/abc.py new file mode 100644 index 0000000..3219d71 --- /dev/null +++ b/adbdgl_adapter/abc.py @@ -0,0 +1,82 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +from abc import ABC +from typing import Any, List, Set, Union + +from arango.graph import Graph as ArangoDBGraph +from dgl import DGLGraph +from dgl.heterograph import DGLHeteroGraph +from torch.functional import Tensor + +from .typings import ArangoMetagraph, DGLCanonicalEType, Json + + +class Abstract_ADBDGL_Adapter(ABC): + def __init__(self) -> None: + raise NotImplementedError # pragma: no cover + + def arangodb_to_dgl( + self, name: str, metagraph: ArangoMetagraph, **query_options: Any + ) -> DGLHeteroGraph: + raise NotImplementedError # pragma: no cover + + def arangodb_collections_to_dgl( + self, name: str, v_cols: Set[str], e_cols: Set[str], **query_options: Any + ) -> DGLHeteroGraph: + raise NotImplementedError # pragma: no cover + + def arangodb_graph_to_dgl(self, name: str, **query_options: Any) -> DGLHeteroGraph: + raise NotImplementedError # pragma: no cover + + def dgl_to_arangodb( + self, name: str, dgl_g: Union[DGLGraph, DGLHeteroGraph], batch_size: int + ) -> ArangoDBGraph: + raise NotImplementedError # pragma: no cover + + def etypes_to_edefinitions( + self, canonical_etypes: List[DGLCanonicalEType] + ) -> List[Json]: + raise NotImplementedError # pragma: no cover + + def __prepare_dgl_features(self) -> None: + raise NotImplementedError # pragma: no cover + + def __insert_dgl_features(self) -> None: + raise NotImplementedError # pragma: no cover + + def __prepare_adb_attributes(self) -> None: + raise NotImplementedError # pragma: no cover + + def __insert_adb_docs(self) -> None: + raise NotImplementedError # pragma: no cover + + def __fetch_adb_docs(self) -> None: + raise NotImplementedError # pragma: no cover + + def __validate_attributes(self) -> None: + raise NotImplementedError # pragma: no cover + + @property + def DEFAULT_CANONICAL_ETYPE(self) -> List[DGLCanonicalEType]: + return [("_N", "_E", "_N")] + + @property + def CONNECTION_ATRIBS(self) -> Set[str]: + return {"hostname", "username", "password", "dbName"} + + @property + def METAGRAPH_ATRIBS(self) -> Set[str]: + return {"vertexCollections", "edgeCollections"} + + @property + def EDGE_DEFINITION_ATRIBS(self) -> Set[str]: + return {"edge_collection", "from_vertex_collections", "to_vertex_collections"} + + +class Abstract_ADBDGL_Controller(ABC): + def _adb_attribute_to_dgl_feature(self, key: str, col: str, val: Any) -> Any: + raise NotImplementedError # pragma: no cover + + def _dgl_feature_to_adb_attribute(self, key: str, col: str, val: Tensor) -> Any: + raise NotImplementedError # pragma: no cover diff --git a/adbdgl_adapter/adbdgl_adapter/adbdgl_adapter.py b/adbdgl_adapter/adapter.py similarity index 61% rename from adbdgl_adapter/adbdgl_adapter/adbdgl_adapter.py rename to adbdgl_adapter/adapter.py index 50d13fe..b9d3075 100644 --- a/adbdgl_adapter/adbdgl_adapter/adbdgl_adapter.py +++ b/adbdgl_adapter/adapter.py @@ -1,75 +1,73 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -""" -@author: Anthony Mahanna -""" -from .abc import ADBDGL_Adapter -from .adbdgl_controller import Base_ADBDGL_Controller +from collections import defaultdict +from typing import Any, DefaultDict, Dict, List, Set, Union from arango import ArangoClient +from arango.cursor import Cursor from arango.graph import Graph as ArangoDBGraph - -import dgl -from dgl import DGLGraph +from arango.result import Result +from dgl import DGLGraph, heterograph from dgl.heterograph import DGLHeteroGraph from dgl.view import HeteroEdgeDataView, HeteroNodeDataView - -import torch +from torch import tensor # type: ignore from torch.functional import Tensor -from typing import Union -from collections import defaultdict +from .abc import Abstract_ADBDGL_Adapter +from .controller import ADBDGL_Controller +from .typings import ArangoMetagraph, DGLCanonicalEType, DGLDataDict, Json -class ArangoDB_DGL_Adapter(ADBDGL_Adapter): +class ADBDGL_Adapter(Abstract_ADBDGL_Adapter): """ArangoDB-DGL adapter. :param conn: Connection details to an ArangoDB instance. - :type conn: dict - :param controller_class: The ArangoDB-DGL controller, for controlling how ArangoDB attributes are converted into DGL features, and vice-versa. Optionally re-defined by the user if needed (otherwise defaults to Base_ADBDGL_Controller). - :type controller_class: Base_ADBDGL_Controller + :type conn: adbdgl_adapter.typings.Json + :param controller: The ArangoDB-DGL controller, for controlling how + ArangoDB attributes are converted into DGL features, and vice-versa. + Optionally re-defined by the user if needed (otherwise defaults to + ADBDGL_Controller). + :type controller: adbdgl_adapter.controller.ADBDGL_Controller :raise ValueError: If missing required keys in conn """ def __init__( self, - conn: dict, - controller_class: Base_ADBDGL_Controller = Base_ADBDGL_Controller, + conn: Json, + controller: ADBDGL_Controller = ADBDGL_Controller(), ): self.__validate_attributes("connection", set(conn), self.CONNECTION_ATRIBS) - if issubclass(controller_class, Base_ADBDGL_Controller) is False: - msg = "controller_class must inherit from Base_ADBDGL_Controller" + if issubclass(type(controller), ADBDGL_Controller) is False: + msg = "controller must inherit from ADBDGL_Controller" raise TypeError(msg) - username = conn["username"] - password = conn["password"] - db_name = conn["dbName"] - - protocol = conn.get("protocol", "https") - host = conn["hostname"] + username: str = conn["username"] + password: str = conn["password"] + db_name: str = conn["dbName"] + host: str = conn["hostname"] + protocol: str = conn.get("protocol", "https") port = str(conn.get("port", 8529)) url = protocol + "://" + host + ":" + port print(f"Connecting to {url}") self.__db = ArangoClient(hosts=url).db(db_name, username, password, verify=True) - self.__cntrl: Base_ADBDGL_Controller = controller_class() + self.__cntrl: ADBDGL_Controller = controller def arangodb_to_dgl( - self, - name: str, - metagraph: dict, - **query_options, - ): - """Create a DGL graph from user-defined metagraph. + self, name: str, metagraph: ArangoMetagraph, **query_options: Any + ) -> DGLHeteroGraph: + """Create a DGLHeteroGraph from the user-defined metagraph. :param name: The DGL graph name. :type name: str - :param metagraph: An object defining vertex & edge collections to import to DGL, along with their associated attributes to keep. - :type metagraph: dict - :param query_options: Keyword arguments to specify AQL query options when fetching documents from the ArangoDB instance. - :type query_options: **kwargs + :param metagraph: An object defining vertex & edge collections to import + to DGL, along with their associated attributes to keep. + :type metagraph: adbdgl_adapter.typings.ArangoMetagraph + :param query_options: Keyword arguments to specify AQL query options when + fetching documents from the ArangoDB instance. + :type query_options: Any :return: A DGL Heterograph :rtype: dgl.heterograph.DGLHeteroGraph :raise ValueError: If missing required keys in metagraph @@ -85,55 +83,64 @@ def arangodb_to_dgl( }, "edgeCollections": { "accountHolder": {}, - "transaction": {}, + "transaction": { + "transaction_amt", "receiver_bank_id", "sender_bank_id" + }, }, } """ self.__validate_attributes("graph", set(metagraph), self.METAGRAPH_ATRIBS) - adb_map = dict() # Maps ArangoDB vertex IDs to DGL node IDs + # Maps ArangoDB vertex IDs to DGL node IDs + adb_map: Dict[str, Dict[str, Any]] = dict() # Dictionaries for constructing a heterogeneous graph. - data_dict = dict() - ndata = defaultdict(lambda: defaultdict(list)) - edata = defaultdict(lambda: defaultdict(list)) + data_dict: DGLDataDict = dict() + ndata: DefaultDict[Any, Any] = defaultdict(lambda: defaultdict(list)) + edata: DefaultDict[Any, Any] = defaultdict(lambda: defaultdict(list)) + adb_v: Json for v_col, atribs in metagraph["vertexCollections"].items(): - for i, v in enumerate(self.__fetch_adb_docs(v_col, atribs, query_options)): - adb_map[v["_id"]] = { + for i, adb_v in enumerate( + self.__fetch_adb_docs(v_col, atribs, query_options) + ): + adb_map[adb_v["_id"]] = { "id": i, "col": v_col, } - self.__prepare_dgl_features(ndata, atribs, v, v_col) + self.__prepare_dgl_features(ndata, atribs, adb_v, v_col) - from_col = set() - to_col = set() + adb_e: Json + from_col: Set[str] = set() + to_col: Set[str] = set() for e_col, atribs in metagraph["edgeCollections"].items(): - from_nodes = [] - to_nodes = [] - for e in self.__fetch_adb_docs(e_col, atribs, query_options): - from_node = adb_map[e["_from"]] - to_node = adb_map[e["_to"]] + from_nodes: List[int] = [] + to_nodes: List[int] = [] + for adb_e in self.__fetch_adb_docs(e_col, atribs, query_options): + from_node = adb_map[adb_e["_from"]] + to_node = adb_map[adb_e["_to"]] from_col.add(from_node["col"]) to_col.add(to_node["col"]) if len(from_col | to_col) > 2: raise ValueError( - f"Can't convert to DGL: too many '_from' & '_to' collections in {e_col}" + f"""Can't convert to DGL: + too many '_from' & '_to' collections in {e_col} + """ ) from_nodes.append(from_node["id"]) to_nodes.append(to_node["id"]) - self.__prepare_dgl_features(edata, atribs, e, e_col) + self.__prepare_dgl_features(edata, atribs, adb_e, e_col) data_dict[(from_col.pop(), e_col, to_col.pop())] = ( - torch.tensor(from_nodes), - torch.tensor(to_nodes), + tensor(from_nodes), + tensor(to_nodes), ) - dgl_g: DGLHeteroGraph = dgl.heterograph(data_dict) + dgl_g: DGLHeteroGraph = heterograph(data_dict) has_one_ntype = len(dgl_g.ntypes) == 1 has_one_etype = len(dgl_g.etypes) == 1 @@ -146,37 +153,40 @@ def arangodb_to_dgl( def arangodb_collections_to_dgl( self, name: str, - vertex_collections: set, - edge_collections: set, - **query_options, - ): + v_cols: Set[str], + e_cols: Set[str], + **query_options: Any, + ) -> DGLHeteroGraph: """Create a DGL graph from ArangoDB collections. :param name: The DGL graph name. :type name: str - :param vertex_collections: A set of ArangoDB vertex collections to import to DGL. - :type vertex_collections: set - :param edge_collections: A set of ArangoDB edge collections to import to DGL. - :type edge_collections: set - :param query_options: Keyword arguments to specify AQL query options when fetching documents from the ArangoDB instance. - :type query_options: **kwargs + :param v_cols: A set of ArangoDB vertex collections to + import to DGL. + :type v_cols: Set[str] + :param e_cols: A set of ArangoDB edge collections to import to DGL. + :type e_cols: Set[str] + :param query_options: Keyword arguments to specify AQL query options + when fetching documents from the ArangoDB instance. + :type query_options: Any :return: A DGL Heterograph :rtype: dgl.heterograph.DGLHeteroGraph """ - metagraph = { - "vertexCollections": {col: {} for col in vertex_collections}, - "edgeCollections": {col: {} for col in edge_collections}, + metagraph: ArangoMetagraph = { + "vertexCollections": {col: set() for col in v_cols}, + "edgeCollections": {col: set() for col in e_cols}, } return self.arangodb_to_dgl(name, metagraph, **query_options) - def arangodb_graph_to_dgl(self, name: str, **query_options): + def arangodb_graph_to_dgl(self, name: str, **query_options: Any) -> DGLHeteroGraph: """Create a DGL graph from an ArangoDB graph. :param name: The ArangoDB graph name. :type name: str - :param query_options: Keyword arguments to specify AQL query options when fetching documents from the ArangoDB instance. - :type query_options: **kwargs + :param query_options: Keyword arguments to specify AQL query options + when fetching documents from the ArangoDB instance. + :type query_options: Any :return: A DGL Heterograph :rtype: dgl.heterograph.DGLHeteroGraph """ @@ -188,7 +198,7 @@ def arangodb_graph_to_dgl(self, name: str, **query_options): def dgl_to_arangodb( self, name: str, dgl_g: Union[DGLGraph, DGLHeteroGraph], batch_size: int = 1000 - ): + ) -> ArangoDBGraph: """Create an ArangoDB graph from a DGL graph. :param name: The ArangoDB graph name. @@ -200,9 +210,9 @@ def dgl_to_arangodb( :return: The ArangoDB Graph API wrapper. :rtype: arango.graph.Graph """ - is_default_type = dgl_g.canonical_etypes == self.DEFAULT_CANONICAL_ETYPE - adb_v_cols = [name + dgl_g.ntypes[0]] if is_default_type else dgl_g.ntypes - adb_e_cols = [name + dgl_g.etypes[0]] if is_default_type else dgl_g.etypes + is_default = dgl_g.canonical_etypes == self.DEFAULT_CANONICAL_ETYPE + adb_v_cols: List[str] = [name + dgl_g.ntypes[0]] if is_default else dgl_g.ntypes + adb_e_cols: List[str] = [name + dgl_g.etypes[0]] if is_default else dgl_g.etypes e_definitions = self.etypes_to_edefinitions( [ ( @@ -211,16 +221,16 @@ def dgl_to_arangodb( adb_v_cols[0], ) ] - if is_default_type + if is_default else dgl_g.canonical_etypes ) has_one_ntype = len(dgl_g.ntypes) == 1 has_one_etype = len(dgl_g.etypes) == 1 - adb_documents = defaultdict(list) + adb_documents: DefaultDict[str, List[Json]] = defaultdict(list) for v_col in adb_v_cols: - ntype = None if is_default_type else v_col + ntype = None if is_default else v_col v_col_docs = adb_documents[v_col] if self.__db.has_collection(v_col) is False: @@ -228,7 +238,7 @@ def dgl_to_arangodb( node: Tensor for node in dgl_g.nodes(ntype): - dgl_node_id: int = node.item() + dgl_node_id = node.item() adb_vertex = {"_key": str(dgl_node_id)} self.__prepare_adb_attributes( dgl_g.ndata, @@ -246,13 +256,13 @@ def dgl_to_arangodb( from_nodes: Tensor to_nodes: Tensor for e_col in adb_e_cols: - etype = None if is_default_type else e_col + etype = None if is_default else e_col e_col_docs = adb_documents[e_col] if self.__db.has_collection(e_col) is False: self.__db.create_collection(e_col, edge=True) - if is_default_type: + if is_default: from_col = to_col = adb_v_cols[0] else: from_col, _, to_col = dgl_g.to_canonical_etype(e_col) @@ -286,26 +296,29 @@ def dgl_to_arangodb( print(f"ArangoDB: {name} created") return adb_graph - def etypes_to_edefinitions(self, canonical_etypes: list) -> list: + def etypes_to_edefinitions( + self, canonical_etypes: List[DGLCanonicalEType] + ) -> List[Json]: """Converts a DGL graph's canonical_etypes property to ArangoDB graph edge definitions - :param canonical_etypes: A list of string triplets (str, str, str) for source node type, edge type and destination node type. - :type canonical_etypes: list[tuple] + :param canonical_etypes: A list of string triplets (str, str, str) for + source node type, edge type and destination node type. + :type canonical_etypes: List[adbdgl_adapter.typings.DGLCanonicalEType] :return: ArangoDB Edge Definitions - :rtype: list[dict[str, Union[str, list[str]]]] + :rtype: List[adbdgl_adapter.typings.Json] Here is an example of **edge_definitions**: .. code-block:: python [ { - "edge_collection": "teach", - "from_vertex_collections": ["teachers"], - "to_vertex_collections": ["lectures"] + "edge_collection": "teaches", + "from_vertex_collections": ["Teacher"], + "to_vertex_collections": ["Lecture"] } ] """ - edge_definitions = [] + edge_definitions: List[Json] = [] for dgl_from, dgl_e, dgl_to in canonical_etypes: edge_definitions.append( { @@ -319,91 +332,99 @@ def etypes_to_edefinitions(self, canonical_etypes: list) -> list: def __prepare_dgl_features( self, - features_data: defaultdict, - attributes: set, - doc: dict, + features_data: DefaultDict[Any, Any], + attributes: Set[str], + doc: Json, col: str, - ): + ) -> None: """Convert a set of ArangoDB attributes into valid DGL features :param features_data: A dictionary storing the DGL features formatted as lists. - :type features_data: defaultdict[Any, defaultdict[Any, list]] - :param col: The collection the current document belongs to - :type col: str + :type features_data: Defaultdict[Any, Any] :param attributes: A set of ArangoDB attribute keys to convert into DGL features - :type attributes: set + :type attributes: Set[str] :param doc: The current ArangoDB document - :type doc: dict - + :type doc: adbdgl_adapter.typings.Json + :param col: The collection the current document belongs to + :type col: str """ key: str for key in attributes: - arr: list = features_data[key][col] + arr: List[Any] = features_data[key][col] arr.append( - self.__cntrl._adb_attribute_to_dgl_feature(key, col, doc.get(key, -1)) + self.__cntrl._adb_attribute_to_dgl_feature(key, col, doc.get(key, None)) ) def __insert_dgl_features( self, - features_data: defaultdict, + features_data: DefaultDict[Any, Any], data: Union[HeteroNodeDataView, HeteroEdgeDataView], has_one_type: bool, - ): + ) -> None: """Insert valid DGL features into a DGL graph. :param features_data: A dictionary storing the DGL features formatted as lists. - :type features_data: defaultdict[Any, defaultdict[Any, list]] - :param data: The (empty) ndata or edata instance attribute of a dgl graph, which is about to receive the **features_data**. - :type data: Union[HeteroNodeDataView, HeteroEdgeDataView] - :param has_one_type: Set to True if the DGL graph only has one ntype, or one etype. + :type features_data: Defaultdict[Any, Any] + :param data: The (empty) ndata or edata instance attribute of a dgl graph, + which is about to receive **features_data**. + :type data: Union[dgl.view.HeteroNodeDataView, dgl.view.HeteroEdgeDataView] + :param has_one_type: Set to True if the DGL graph only has one ntype, + or one etype. :type has_one_type: bool """ - col_dict: dict + col_dict: Dict[str, List[Any]] for key, col_dict in features_data.items(): for col, array in col_dict.items(): data[key] = ( - torch.tensor(array) - if has_one_type - else {**data[key], col: torch.tensor(array)} + tensor(array) if has_one_type else {**data[key], col: tensor(array)} ) def __prepare_adb_attributes( self, data: Union[HeteroNodeDataView, HeteroEdgeDataView], - features: set, - id: int, - doc: dict, + features: Set[Any], + id: Union[int, float, bool], + doc: Json, col: str, has_one_type: bool, - ): + ) -> None: """Convert DGL features into a set of ArangoDB attributes for a given document - :param data: The ndata or edata instance attribute of a dgl graph, filled with node or edge feature data. - :type data: Union[HeteroNodeDataView, HeteroEdgeDataView] + :param data: The ndata or edata instance attribute of a dgl graph, filled with + node or edge feature data. + :type data: Union[dgl.view.HeteroNodeDataView, dgl.view.HeteroEdgeDataView] :param features: A set of DGL feature keys to convert into ArangoDB attributes - :type features: set + :type features: Set[Any] :param id: The ID of the current DGL node / edge - :type id: int + :type id: Union[int, float, bool] :param doc: The current ArangoDB document - :type doc: dict + :type doc: adbdgl_adapter.typings.Json :param col: The collection the current document belongs to :type col: str - :param has_one_type: Set to True if the DGL graph only has one ntype, or one etype. + :param has_one_type: Set to True if the DGL graph only has one ntype, + or one etype. :type has_one_type: bool """ for key in features: tensor = data[key] if has_one_type else data[key][col] doc[key] = self.__cntrl._dgl_feature_to_adb_attribute(key, col, tensor[id]) - def __insert_adb_docs(self, col: str, col_docs: list, doc: dict, batch_size: int): - """Insert an ArangoDB document into a list. If the list exceeds batch_size documents, insert into the ArangoDB collection. + def __insert_adb_docs( + self, + col: str, + col_docs: List[Json], + doc: Json, + batch_size: int, + ) -> None: + """Insert an ArangoDB document into a list. If the list exceeds + batch_size documents, insert into the ArangoDB collection. :param col: The collection name :type col: str :param col_docs: The existing documents data belonging to the collection. - :type col_docs: list + :type col_docs: List[adbdgl_adapter.typings.Json] :param doc: The current document to insert. - :type doc: dict + :type doc: adbdgl_adapter.typings.Json :param batch_size: The maximum number of documents to insert at once :type batch_size: int """ @@ -413,38 +434,45 @@ def __insert_adb_docs(self, col: str, col_docs: list, doc: dict, batch_size: int self.__db.collection(col).import_bulk(col_docs, on_duplicate="replace") col_docs.clear() - def __fetch_adb_docs(self, col: str, attributes: set, query_options: dict): + def __fetch_adb_docs( + self, col: str, attributes: Set[str], query_options: Any + ) -> Result[Cursor]: """Fetches ArangoDB documents within a collection. :param col: The ArangoDB collection. :type col: str :param attributes: The set of document attributes. - :type attributes: set - :param query_options: Keyword arguments to specify AQL query options when fetching documents from the ArangoDB instance. - :type query_options: **kwargs + :type attributes: Set[str] + :param query_options: Keyword arguments to specify AQL query options + when fetching documents from the ArangoDB instance. + :type query_options: Any :return: Result cursor. :rtype: arango.cursor.Cursor """ aql = f""" FOR doc IN {col} RETURN MERGE( - KEEP(doc, {list(attributes)}), - {{"_id": doc._id}}, + KEEP(doc, {list(attributes)}), + {{"_id": doc._id}}, doc._from ? {{"_from": doc._from, "_to": doc._to}}: {{}} ) """ return self.__db.aql.execute(aql, **query_options) - def __validate_attributes(self, type: str, attributes: set, valid_attributes: set): - """Validates that a set of attributes includes the required valid attributes. + def __validate_attributes( + self, type: str, attributes: Set[str], valid_attributes: Set[str] + ) -> None: + """Validates that a set of attributes includes the required valid + attributes. - :param type: The context of the attribute validation (e.g connection attributes, graph attributes, etc). + :param type: The context of the attribute validation + (e.g connection attributes, graph attributes, etc). :type type: str :param attributes: The provided attributes, possibly invalid. - :type attributes: set + :type attributes: Set[str] :param valid_attributes: The valid attributes. - :type valid_attributes: set + :type valid_attributes: Set[str] :raise ValueError: If **valid_attributes** is not a subset of **attributes** """ if valid_attributes.issubset(attributes) is False: diff --git a/adbdgl_adapter/adbdgl_adapter/abc.py b/adbdgl_adapter/adbdgl_adapter/abc.py deleted file mode 100644 index 9caab5d..0000000 --- a/adbdgl_adapter/adbdgl_adapter/abc.py +++ /dev/null @@ -1,69 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -@author: Anthony Mahanna -""" - -from abc import ABC - - -class ADBDGL_Adapter(ABC): - def __init__(self): - raise NotImplementedError() # pragma: no cover - - def arangodb_to_dgl(self): - raise NotImplementedError() # pragma: no cover - - def arangodb_collections_to_dgl(self): - raise NotImplementedError() # pragma: no cover - - def arangodb_graph_to_dgl(self): - raise NotImplementedError() # pragma: no cover - - def dgl_to_arangodb(self): - raise NotImplementedError() # pragma: no cover - - def etypes_to_edefinitions(self): - raise NotImplementedError() # pragma: no cover - - def __prepare_dgl_features(self): - raise NotImplementedError() # pragma: no cover - - def __insert_dgl_features(self): - raise NotImplementedError() # pragma: no cover - - def __prepare_adb_attributes(self): - raise NotImplementedError() # pragma: no cover - - def __insert_adb_docs(self): - raise NotImplementedError() # pragma: no cover - - def __fetch_adb_docs(self): - raise NotImplementedError() # pragma: no cover - - def __validate_attributes(self): - raise NotImplementedError() # pragma: no cover - - @property - def DEFAULT_CANONICAL_ETYPE(self): - return [("_N", "_E", "_N")] - - @property - def CONNECTION_ATRIBS(self): - return {"hostname", "username", "password", "dbName"} - - @property - def METAGRAPH_ATRIBS(self): - return {"vertexCollections", "edgeCollections"} - - @property - def EDGE_DEFINITION_ATRIBS(self): - return {"edge_collection", "from_vertex_collections", "to_vertex_collections"} - - -class ADBDGL_Controller(ABC): - def _adb_attribute_to_dgl_feature(self): - raise NotImplementedError() # pragma: no cover - - def _dgl_feature_to_adb_attribute(self): - raise NotImplementedError() # pragma: no cover diff --git a/adbdgl_adapter/adbdgl_adapter/adbdgl_controller.py b/adbdgl_adapter/adbdgl_adapter/adbdgl_controller.py deleted file mode 100644 index ca0497b..0000000 --- a/adbdgl_adapter/adbdgl_adapter/adbdgl_controller.py +++ /dev/null @@ -1,49 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -from .abc import ADBDGL_Controller -from collections import defaultdict -from torch.functional import Tensor - -""" - -@author: Anthony Mahanna -""" - - -class Base_ADBDGL_Controller(ADBDGL_Controller): - """ArangoDB-DGL controller. - - Responsible for controlling how ArangoDB attributes - are converted into DGL features, and vice-versa. - - You can derive your own custom ADBDGL_Controller if you want to maintain - consistency between your ArangoDB attributes & your DGL features. - """ - - def _adb_attribute_to_dgl_feature(self, key: str, col: str, val): - """ - Given an ArangoDB attribute key, its assigned value (for an arbitrary document), - and the collection it belongs to, convert it to a valid - DGL feature: https://docs.dgl.ai/en/0.6.x/guide/graph-feature.html. - - NOTE: You must override this function if you want to transfer non-numerical ArangoDB - attributes to DGL (DGL only accepts 'attributes' (a.k.a features) of numerical types). - Read more about DGL features here: https://docs.dgl.ai/en/0.6.x/new-tutorial/2_dglgraph.html#assigning-node-and-edge-features-to-graph. - """ - try: - return float(val) - except: - return 0 - - def _dgl_feature_to_adb_attribute(self, key: str, col: str, val: Tensor): - """ - Given a DGL feature key, its assigned value (for an arbitrary node or edge), - and the collection it belongs to, convert it to a valid ArangoDB attribute (e.g string, list, number, ...). - - NOTE: No action is needed here if you want to keep the numerical-based values of your DGL features. - """ - try: - return val.item() - except ValueError: - print("HERERERERE") - return val.tolist() diff --git a/adbdgl_adapter/controller.py b/adbdgl_adapter/controller.py new file mode 100644 index 0000000..bd7e8d8 --- /dev/null +++ b/adbdgl_adapter/controller.py @@ -0,0 +1,70 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +from typing import Any + +from torch.functional import Tensor + +from .abc import Abstract_ADBDGL_Controller + + +class ADBDGL_Controller(Abstract_ADBDGL_Controller): + """ArangoDB-DGL controller. + + Responsible for controlling how ArangoDB attributes + are converted into DGL features, and vice-versa. + + You can derive your own custom ADBDGL_Controller if you want to maintain + consistency between your ArangoDB attributes & your DGL features. + """ + + def _adb_attribute_to_dgl_feature(self, key: str, col: str, val: Any) -> Any: + """ + Given an ArangoDB attribute key, its assigned value (for an arbitrary document), + and the collection it belongs to, convert it to a valid + DGL feature: https://docs.dgl.ai/en/0.6.x/guide/graph-feature.html. + + NOTE: You must override this function if you want to transfer non-numerical + ArangoDB attributes to DGL (DGL only accepts 'attributes' (a.k.a features) + of numerical types). Read more about DGL features here: + https://docs.dgl.ai/en/0.6.x/new-tutorial/2_dglgraph.html#assigning-node-and-edge-features-to-graph. + + :param key: The ArangoDB attribute key name + :type key: str + :param col: The ArangoDB collection of the ArangoDB document. + :type col: str + :param val: The assigned attribute value of the ArangoDB document. + :type val: Any + :return: The attribute's representation as a DGL Feature + :rtype: Any + """ + if type(val) in [int, float, bool]: + return val + + try: + return float(val) + except (ValueError, TypeError, SyntaxError): + return 0 + + def _dgl_feature_to_adb_attribute(self, key: str, col: str, val: Tensor) -> Any: + """ + Given a DGL feature key, its assigned value (for an arbitrary node or edge), + and the collection it belongs to, convert it to a valid ArangoDB attribute + (e.g string, list, number, ...). + + NOTE: No action is needed here if you want to keep the numerical-based values + of your DGL features. + + :param key: The DGL attribute key name + :type key: str + :param col: The ArangoDB collection of the (soon-to-be) ArangoDB document. + :type col: str + :param val: The assigned attribute value of the DGL node. + :type val: Tensor + :return: The feature's representation as an ArangoDB Attribute + :rtype: Any + """ + try: + return val.item() + except ValueError: + return val.tolist() diff --git a/adbdgl_adapter/setup.cfg b/adbdgl_adapter/setup.cfg deleted file mode 100644 index 5df7f3d..0000000 --- a/adbdgl_adapter/setup.cfg +++ /dev/null @@ -1,7 +0,0 @@ -[metadata] -description_file = README.md - -[tool:pytest] -markers = - unit: Marks a unit test -testpaths = tests \ No newline at end of file diff --git a/adbdgl_adapter/tests/conftest.py b/adbdgl_adapter/tests/conftest.py deleted file mode 100755 index d6e35b9..0000000 --- a/adbdgl_adapter/tests/conftest.py +++ /dev/null @@ -1,119 +0,0 @@ -import os -import time -import json -import requests -import subprocess -from pathlib import Path - -import torch -from dgl import remove_self_loop -from dgl.data import KarateClubDataset -from dgl.data import MiniGCDataset - -from arango import ArangoClient -from adbdgl_adapter.adbdgl_adapter import ArangoDB_DGL_Adapter - -PROJECT_DIR = Path(__file__).parent.parent.parent - - -def pytest_sessionstart(): - global conn - conn = get_oasis_crendetials() - # conn = { - # "username": "root", - # "password": "openSesame", - # "hostname": "localhost", - # "port": 8529, - # "protocol": "http", - # "dbName": "_system", - # } - print_connection_details(conn) - time.sleep(5) # Enough for the oasis instance to be ready. - - global adbdgl_adapter - adbdgl_adapter = ArangoDB_DGL_Adapter(conn) - - global db - url = ( - conn.get("protocol", "https") - + "://" - + conn["hostname"] - + ":" - + str(conn["port"]) - ) - client = ArangoClient(hosts=url) - db = client.db(conn["dbName"], conn["username"], conn["password"], verify=True) - - arango_restore("examples/data/fraud_dump") - db.create_graph( - "fraud-detection", - edge_definitions=[ - { - "edge_collection": "accountHolder", - "from_vertex_collections": ["customer"], - "to_vertex_collections": ["account"], - }, - { - "edge_collection": "transaction", - "from_vertex_collections": ["account"], - "to_vertex_collections": ["account"], - }, - ], - ) - - -def get_oasis_crendetials() -> dict: - url = "https://tutorials.arangodb.cloud:8529/_db/_system/tutorialDB/tutorialDB" - request = requests.post(url, data=json.dumps("{}")) - if request.status_code != 200: - raise Exception("Error retrieving login data.") - - return json.loads(request.text) - - -def arango_restore(path_to_data): - restore_prefix = "./assets/" if os.getenv("GITHUB_ACTIONS") else "" - - subprocess.check_call( - f'chmod -R 755 ./assets/arangorestore && {restore_prefix}arangorestore -c none --server.endpoint http+ssl://{conn["hostname"]}:{conn["port"]} --server.username {conn["username"]} --server.database {conn["dbName"]} --server.password {conn["password"]} --default-replication-factor 3 --input-directory "{PROJECT_DIR}/{path_to_data}"', - cwd=f"{PROJECT_DIR}/adbdgl_adapter/tests", - shell=True, - ) - - -def print_connection_details(conn): - print("----------------------------------------") - print("https://{}:{}".format(conn["hostname"], conn["port"])) - print("Username: " + conn["username"]) - print("Password: " + conn["password"]) - print("Database: " + conn["dbName"]) - print("----------------------------------------") - - -def get_karate_graph(): - return KarateClubDataset()[0] - - -def get_lollipop_graph(): - dgl_g = remove_self_loop(MiniGCDataset(8, 7, 8)[3][0]) - dgl_g.ndata["random_ndata"] = torch.tensor( - [[i, i, i] for i in range(0, dgl_g.num_nodes())] - ) - dgl_g.edata["random_edata"] = torch.rand(dgl_g.num_edges()) - return dgl_g - - -def get_hypercube_graph(): - dgl_g = remove_self_loop(MiniGCDataset(8, 8, 9)[4][0]) - dgl_g.ndata["random_ndata"] = torch.rand(dgl_g.num_nodes()) - dgl_g.edata["random_edata"] = torch.tensor( - [[[i], [i], [i]] for i in range(0, dgl_g.num_edges())] - ) - return dgl_g - - -def get_clique_graph(): - dgl_g = remove_self_loop(MiniGCDataset(8, 6, 7)[6][0]) - dgl_g.ndata["random_ndata"] = torch.ones(dgl_g.num_nodes()) - dgl_g.edata["random_edata"] = torch.zeros(dgl_g.num_edges()) - return dgl_g diff --git a/adbdgl_adapter/typings.py b/adbdgl_adapter/typings.py new file mode 100644 index 0000000..c3a7015 --- /dev/null +++ b/adbdgl_adapter/typings.py @@ -0,0 +1,12 @@ +__all__ = ["Json", "ArangoMetagraph", "DGLCanonicalEType"] + +from typing import Any, Dict, Set, Tuple + +from torch.functional import Tensor + +Json = Dict[str, Any] +ArangoMetagraph = Dict[str, Dict[str, Set[str]]] + + +DGLCanonicalEType = Tuple[str, str, str] +DGLDataDict = Dict[DGLCanonicalEType, Tuple[Tensor, Tensor]] diff --git a/examples/ArangoDB_DGL_Adapter.ipynb b/examples/ArangoDB_DGL_Adapter.ipynb index 3dd6ab4..57d2ae3 100644 --- a/examples/ArangoDB_DGL_Adapter.ipynb +++ b/examples/ArangoDB_DGL_Adapter.ipynb @@ -36,7 +36,7 @@ "source": [ "Version: 1.0.0\n", "\n", - "Objective: Export Graphs from [ArangoDB](https://www.arangodb.com/), a multi-model Graph Database, into [Deep Graph Library](https://www.dgl.ai/) (DGL), a python package for graph neural networks, and vice-versa." + "Objective: Export Graphs from [ArangoDB](https://www.arangodb.com/), a multi-model Graph Database, to [Deep Graph Library](https://www.dgl.ai/) (DGL), a python package for graph neural networks, and vice-versa." ] }, { @@ -58,10 +58,10 @@ "source": [ "%%capture\n", "!git clone -b oasis_connector --single-branch https://github.com/arangodb/interactive_tutorials.git\n", - "!git clone https://github.com/arangoml/dgl-adapter.git # !git clone -b 1.0.0 --single-branch https://github.com/arangoml/dgl-adapter.git\n", + "!git clone -b 1.0.0 --single-branch https://github.com/arangoml/dgl-adapter.git\n", "!rsync -av dgl-adapter/examples/ ./ --exclude=.git\n", "!rsync -av interactive_tutorials/ ./ --exclude=.git\n", - "!pip3 install \"git+https://github.com/arangoml/dgl-adapter.git#egg=adbdgl_adapter&subdirectory=adbdgl_adapter\" # pip3 install adbdgl_adapter==1.0.0\n", + "!pip3 install adbdgl_adapter==1.0.0\n", "!pip3 install matplotlib\n", "!pip3 install pyArango\n", "!pip3 install networkx ## For drawing purposes " @@ -71,7 +71,11 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "RpqvL4COeG8-" + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "RpqvL4COeG8-", + "outputId": "2df55e4e-03fa-47ed-c2c9-baf9f597e1d8" }, "outputs": [], "source": [ @@ -87,8 +91,9 @@ "from dgl.data import KarateClubDataset\n", "from dgl.data import MiniGCDataset\n", "\n", - "from adbdgl_adapter.adbdgl_adapter import ArangoDB_DGL_Adapter\n", - "from adbdgl_adapter.adbdgl_controller import Base_ADBDGL_Controller" + "from adbdgl_adapter.adapter import ADBDGL_Adapter\n", + "from adbdgl_adapter.controller import ADBDGL_Controller\n", + "from adbdgl_adapter.typings import Json, ArangoMetagraph, DGLCanonicalEType, DGLDataDict" ] }, { @@ -124,7 +129,7 @@ "base_uri": "https://localhost:8080/" }, "id": "vf0350qvj8up", - "outputId": "9c2e9905-7272-44f6-8e59-e5f568a57758" + "outputId": "a65f00d2-cd6e-4583-94d8-2c9884e2e2e2" }, "outputs": [], "source": [ @@ -157,7 +162,7 @@ "base_uri": "https://localhost:8080/" }, "id": "oOS3AVAnkQEV", - "outputId": "9589b7b3-0867-4ff2-9c9f-8d0f38633490" + "outputId": "4609cdef-25ce-4f00-94b5-482c76274f88" }, "outputs": [], "source": [ @@ -193,7 +198,7 @@ "base_uri": "https://localhost:8080/" }, "id": "meLon-KgkU4h", - "outputId": "9f2f8081-393f-4a1b-9ff7-3f6c13289a62" + "outputId": "976680a4-eadd-43f2-da17-e6a574fad8a7" }, "outputs": [], "source": [ @@ -231,7 +236,7 @@ "base_uri": "https://localhost:8080/" }, "id": "zTebQ0LOlsGA", - "outputId": "0f3d26db-1b50-4d65-8385-8d0ad7147ec5" + "outputId": "9c84cb84-f7ce-42b3-9174-01f38295c5dd" }, "outputs": [], "source": [ @@ -274,7 +279,7 @@ "base_uri": "https://localhost:8080/" }, "id": "KsxNujb0mSqZ", - "outputId": "0b10fb67-5193-49ef-8aee-d691f53fe5bf" + "outputId": "3f3fd2b1-e1d3-4b03-c6c4-43566672cbb5" }, "outputs": [], "source": [ @@ -317,7 +322,7 @@ "base_uri": "https://localhost:8080/" }, "id": "2ekGwnJDeG8-", - "outputId": "fb348f69-8321-40b3-9bf5-0085ea218492" + "outputId": "92e9d288-0259-45cc-e73d-a8e9f629063a" }, "outputs": [], "source": [ @@ -377,16 +382,13 @@ "cell_type": "code", "execution_count": null, "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "7bgGJ3QkeG8_", - "outputId": "93451001-15f8-463d-8341-5c65fe6bc178" + "id": "7bgGJ3QkeG8_" }, "outputs": [], "source": [ + "%%capture\n", "!chmod -R 755 ./tools\n", - "!./tools/arangorestore -c none --server.endpoint http+ssl://{con[\"hostname\"]}:{con[\"port\"]} --server.username {con[\"username\"]} --server.database {con[\"dbName\"]} --server.password {con[\"password\"]} --default-replication-factor 3 --input-directory \"data/fraud_dump\"" + "!./tools/arangorestore -c none --server.endpoint http+ssl://{con[\"hostname\"]}:{con[\"port\"]} --server.username {con[\"username\"]} --server.database {con[\"dbName\"]} --server.password {con[\"password\"]} --replication-factor 3 --input-directory \"data/fraud_dump\"" ] }, { @@ -424,7 +426,7 @@ "base_uri": "https://localhost:8080/" }, "id": "PybHP7jpeG8_", - "outputId": "724c9f23-c63b-4d34-f2f8-d9b7579cc985" + "outputId": "ba3bfc7c-ef56-47e7-8e98-3763d4f34afe" }, "outputs": [], "source": [ @@ -484,11 +486,11 @@ "base_uri": "https://localhost:8080/" }, "id": "oG496kBeeG9A", - "outputId": "164ea448-1117-4bde-e8e3-220558a1c0e3" + "outputId": "50ecbdf5-c82f-4540-d345-14eb4a488f2c" }, "outputs": [], "source": [ - "adbdgl_adapter = ArangoDB_DGL_Adapter(con)" + "adbdgl_adapter = ADBDGL_Adapter(con)" ] }, { @@ -518,7 +520,7 @@ "base_uri": "https://localhost:8080/" }, "id": "zZ-Hu3lLVHgd", - "outputId": "945d4971-cc05-4f97-eda3-b9e02cc05df8" + "outputId": "39f32c51-0753-45a8-a361-dcf46d4e6148" }, "outputs": [], "source": [ @@ -555,7 +557,7 @@ "base_uri": "https://localhost:8080/" }, "id": "i4XOpdRLUNlJ", - "outputId": "e445ea58-35ef-41ed-d4ac-717ac6f68e9c" + "outputId": "b58e75d1-e935-4abd-9bdb-bc8935d9cdc8" }, "outputs": [], "source": [ @@ -579,7 +581,7 @@ { "cell_type": "markdown", "metadata": { - "id": "umy25EsUU6Lg" + "id": "qEH6OdSB23Ya" }, "source": [ "## Via ArangoDB Metagraph" @@ -592,8 +594,59 @@ "colab": { "base_uri": "https://localhost:8080/" }, - "id": "UWX9-MsKeG9A", - "outputId": "f1ad45d9-d29a-4d7f-a853-b808ac898dfe" + "id": "7Kz8lXXq23Yk", + "outputId": "1458aef6-14e5-48c0-98bf-77f21431bc73" + }, + "outputs": [], + "source": [ + "# Define Metagraph\n", + "fraud_detection_metagraph = {\n", + " \"vertexCollections\": {\n", + " \"account\": {\"rank\", \"Balance\", \"customer_id\"},\n", + " \"Class\": {\"concrete\"},\n", + " \"customer\": {\"rank\"},\n", + " },\n", + " \"edgeCollections\": {\n", + " \"accountHolder\": {},\n", + " \"Relationship\": {},\n", + " \"transaction\": {\"receiver_bank_id\", \"sender_bank_id\", \"transaction_amt\"},\n", + " },\n", + "}\n", + "\n", + "# Create DGL Graph from attributes\n", + "dgl_g = adbdgl_adapter.arangodb_to_dgl('FraudDetection', fraud_detection_metagraph)\n", + "\n", + "# You can also provide valid Python-Arango AQL query options to the command above, like such:\n", + "# dgl_g = adbdgl_adapter.arangodb_to_dgl(graph_name = 'FraudDetection', fraud_detection_metagraph, ttl=1000, stream=True)\n", + "# See more here: https://docs.python-arango.com/en/main/specs.html#arango.aql.AQL.execute\n", + "\n", + "# Show graph data\n", + "print('\\n--------------')\n", + "print(dgl_g)\n", + "print('\\n--------------')\n", + "print(dgl_g.ndata)\n", + "print('--------------\\n')\n", + "print(dgl_g.edata)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "DqIKT1lO4ASw" + }, + "source": [ + "## Via ArangoDB Metagraph with a custom controller" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "U4_vSdU_4AS4", + "outputId": "b719b0d4-c0a4-43a0-915f-8ee765e1ec86" }, "outputs": [], "source": [ @@ -611,9 +664,9 @@ " },\n", "}\n", "\n", - "# When converting to DGL via an ArangoDB metagraph, a user-defined Controller class\n", - "# is required, to specify how ArangoDB attributes should be converted into DGL features.\n", - "class FraudDetection_ADBDGL_Controller(Base_ADBDGL_Controller):\n", + "# When converting to DGL via an ArangoDB Metagraph that contains non-numerical values, a user-defined \n", + "# Controller class is required to specify how ArangoDB attributes should be converted to DGL features.\n", + "class FraudDetection_ADBDGL_Controller(ADBDGL_Controller):\n", " \"\"\"ArangoDB-DGL controller.\n", "\n", " Responsible for controlling how ArangoDB attributes\n", @@ -629,49 +682,59 @@ " and the collection it belongs to, convert it to a valid\n", " DGL feature: https://docs.dgl.ai/en/0.6.x/guide/graph-feature.html.\n", "\n", - " NOTE: You must override this function if you want to transfer non-numerical ArangoDB\n", - " attributes to DGL (DGL only accepts 'attributes' (a.k.a features) of numerical types).\n", - " Read more about DGL features here: https://docs.dgl.ai/en/0.6.x/new-tutorial/2_dglgraph.html#assigning-node-and-edge-features-to-graph.\n", + " NOTE: You must override this function if you want to transfer non-numerical\n", + " ArangoDB attributes to DGL (DGL only accepts 'attributes' (a.k.a features)\n", + " of numerical types). Read more about DGL features here:\n", + " https://docs.dgl.ai/en/0.6.x/new-tutorial/2_dglgraph.html#assigning-node-and-edge-features-to-graph.\n", + "\n", + " :param key: The ArangoDB attribute key name\n", + " :type key: str\n", + " :param col: The ArangoDB collection of the ArangoDB document.\n", + " :type col: str\n", + " :param val: The assigned attribute value of the ArangoDB document.\n", + " :type val: Any\n", + " :return: The attribute's representation as a DGL Feature\n", + " :rtype: Any\n", " \"\"\"\n", - " if type(val) in [int, float, bool]:\n", - " return val\n", - "\n", - " if col == \"transaction\":\n", - " if key == \"transaction_date\":\n", - " return int(str(val).replace(\"-\", \"\"))\n", - " \n", - " if key == \"trans_time\":\n", - " return int(str(val).replace(\":\", \"\"))\n", - " \n", - " if col == \"customer\":\n", - " if key == \"Sex\":\n", - " return 0 if val == \"M\" else 1\n", - "\n", - " if key == \"Ssn\":\n", - " return int(str(val).replace(\"-\", \"\"))\n", - "\n", - " if col == \"Class\":\n", - " if key == \"name\":\n", - " if val == \"Bank\":\n", - " return 0\n", - " elif val == \"Branch\":\n", - " return 1\n", - " elif val == \"Account\":\n", - " return 2\n", - " elif val == \"Customer\":\n", - " return 3\n", - " else:\n", - " return -1\n", + " try:\n", + " if col == \"transaction\":\n", + " if key == \"transaction_date\":\n", + " return int(str(val).replace(\"-\", \"\"))\n", + " \n", + " if key == \"trans_time\":\n", + " return int(str(val).replace(\":\", \"\"))\n", + " \n", + " if col == \"customer\":\n", + " if key == \"Sex\":\n", + " return 0 if val == \"M\" else 1\n", + "\n", + " if key == \"Ssn\":\n", + " return int(str(val).replace(\"-\", \"\"))\n", + "\n", + " if col == \"Class\":\n", + " if key == \"name\":\n", + " if val == \"Bank\":\n", + " return 0\n", + " elif val == \"Branch\":\n", + " return 1\n", + " elif val == \"Account\":\n", + " return 2\n", + " elif val == \"Customer\":\n", + " return 3\n", + " else:\n", + " return -1\n", + " except (ValueError, TypeError, SyntaxError):\n", + " return 0\n", "\n", " return super()._adb_attribute_to_dgl_feature(key, col, val)\n", "\n", - "fraud_adbgl_adapter = ArangoDB_DGL_Adapter(con, FraudDetection_ADBDGL_Controller)\n", + "fraud_adbdgl_adapter = ADBDGL_Adapter(con, FraudDetection_ADBDGL_Controller())\n", "\n", "# Create DGL Graph from attributes\n", - "dgl_g = fraud_adbgl_adapter.arangodb_to_dgl('FraudDetection', fraud_detection_metagraph)\n", + "dgl_g = fraud_adbdgl_adapter.arangodb_to_dgl('FraudDetection', fraud_detection_metagraph)\n", "\n", "# You can also provide valid Python-Arango AQL query options to the command above, like such:\n", - "# dgl_g = adbdgl_adapter.arangodb_to_dgl(graph_name = 'FraudDetection', fraud_detection_metagraph, ttl=1000, stream=True)\n", + "# dgl_g = fraud_adbdgl_adapter.arangodb_to_dgl(graph_name = 'FraudDetection', fraud_detection_metagraph, ttl=1000, stream=True)\n", "# See more here: https://docs.python-arango.com/en/main/specs.html#arango.aql.AQL.execute\n", "\n", "# Show graph data\n", @@ -707,10 +770,10 @@ "metadata": { "colab": { "base_uri": "https://localhost:8080/", - "height": 0 + "height": 577 }, "id": "eRVbiBy4ZdE4", - "outputId": "bcbfce84-8bc0-4605-82d7-78fbaab53527" + "outputId": "d44eb9d9-e046-443b-8ded-79654f004e02" }, "outputs": [], "source": [ @@ -723,14 +786,13 @@ "python_arango_db_driver.delete_graph(name, drop_collections=True, ignore_missing=True)\n", "adb_karate_graph = adbdgl_adapter.dgl_to_arangodb(name, dgl_karate_graph)\n", "\n", - "\n", - "print(f\"\\nInspect the graph here: https://tutorials.arangodb.cloud:8529/_db/{con['dbName']}/_admin/aardvark/index.html#graph/{name}\\n\")\n", - "\n", + "print('\\n--------------------')\n", "print(\"https://{}:{}\".format(con[\"hostname\"], con[\"port\"]))\n", "print(\"Username: \" + con[\"username\"])\n", "print(\"Password: \" + con[\"password\"])\n", "print(\"Database: \" + con[\"dbName\"])\n", - "\n", + "print('--------------------\\n')\n", + "print(f\"\\nInspect the graph here: https://tutorials.arangodb.cloud:8529/_db/{con['dbName']}/_admin/aardvark/index.html#graph/{name}\\n\")\n", "print(f\"\\nView the original graph below:\")" ] }, @@ -750,10 +812,10 @@ "metadata": { "colab": { "base_uri": "https://localhost:8080/", - "height": 0 + "height": 1000 }, "id": "dADiexlAioGH", - "outputId": "286375c7-f2c9-4843-fb30-d086994985fc" + "outputId": "273988c8-1749-4fe0-85fe-51b0e1ab2058" }, "outputs": [], "source": [ @@ -783,16 +845,16 @@ "adb_hypercube_graph = adbdgl_adapter.dgl_to_arangodb(hypercube, dgl_hypercube_graph)\n", "adb_clique_graph = adbdgl_adapter.dgl_to_arangodb(clique, dgl_clique_graph)\n", "\n", - "print(\"\\nInspect the graphs here:\\n\")\n", - "print(f\"1) https://tutorials.arangodb.cloud:8529/_db/{con['dbName']}/_admin/aardvark/index.html#graph/{lollipop}\")\n", - "print(f\"2) https://tutorials.arangodb.cloud:8529/_db/{con['dbName']}/_admin/aardvark/index.html#graph/{hypercube}\")\n", - "print(f\"3) https://tutorials.arangodb.cloud:8529/_db/{con['dbName']}/_admin/aardvark/index.html#graph/{clique}\\n\")\n", - "\n", + "print('\\n--------------------')\n", "print(\"https://{}:{}\".format(con[\"hostname\"], con[\"port\"]))\n", "print(\"Username: \" + con[\"username\"])\n", "print(\"Password: \" + con[\"password\"])\n", "print(\"Database: \" + con[\"dbName\"])\n", - "\n", + "print('--------------------\\n')\n", + "print(\"\\nInspect the graphs here:\\n\")\n", + "print(f\"1) https://tutorials.arangodb.cloud:8529/_db/{con['dbName']}/_admin/aardvark/index.html#graph/{lollipop}\")\n", + "print(f\"2) https://tutorials.arangodb.cloud:8529/_db/{con['dbName']}/_admin/aardvark/index.html#graph/{hypercube}\")\n", + "print(f\"3) https://tutorials.arangodb.cloud:8529/_db/{con['dbName']}/_admin/aardvark/index.html#graph/{clique}\\n\")\n", "print(f\"\\nView the original graphs below:\")" ] }, @@ -803,7 +865,7 @@ }, "source": [ "\n", - "## Example 3: DGL MiniGCDataset Graphs (with attribute transfer)" + "## Example 3: DGL MiniGCDataset Graphs with a custom controller" ] }, { @@ -814,30 +876,39 @@ "base_uri": "https://localhost:8080/" }, "id": "jbJsvMMaoJoT", - "outputId": "f426cdaf-a53c-4ade-d3b6-cbb11dd39c67" + "outputId": "2ddca41f-9c8b-4db4-c0aa-c1b2cc124fa5" }, "outputs": [], "source": [ "from torch.functional import Tensor\n", "\n", - "# Load the dgl graphs & populate node data\n", + "# Load the dgl graphs\n", "dgl_lollipop_graph = remove_self_loop(MiniGCDataset(8, 7, 8)[3][0])\n", - "dgl_lollipop_graph.ndata['lollipop_ndata'] = torch.ones(7)\n", - "\n", "dgl_hypercube_graph = remove_self_loop(MiniGCDataset(8, 8, 9)[4][0])\n", - "dgl_hypercube_graph.ndata['hypercube_ndata'] = torch.zeros(8)\n", - "\n", "dgl_clique_graph = remove_self_loop(MiniGCDataset(8, 6, 7)[6][0])\n", + "\n", + " # Add DGL Node & Edge Features to each graph\n", + "dgl_lollipop_graph.ndata[\"random_ndata\"] = torch.tensor(\n", + " [[i, i, i] for i in range(0, dgl_lollipop_graph.num_nodes())]\n", + ")\n", + "dgl_lollipop_graph.edata[\"random_edata\"] = torch.rand(dgl_lollipop_graph.num_edges())\n", + "\n", + "dgl_hypercube_graph.ndata[\"random_ndata\"] = torch.rand(dgl_hypercube_graph.num_nodes())\n", + "dgl_hypercube_graph.edata[\"random_edata\"] = torch.tensor(\n", + " [[[i], [i], [i]] for i in range(0, dgl_hypercube_graph.num_edges())]\n", + ")\n", + "\n", "dgl_clique_graph.ndata['clique_ndata'] = torch.tensor([1,2,3,4,5,6])\n", + "dgl_clique_graph.edata['clique_edata'] = torch.tensor(\n", + " [1 if i % 2 == 0 else 0 for i in range(0, dgl_clique_graph.num_edges())]\n", + ")\n", "\n", "\n", "# When converting to ArangoDB from DGL, a user-defined Controller class\n", "# is required to specify how DGL features (aka attributes) should be converted \n", - "# into ArangoDB attributes.\n", - "\n", - "# NOTE: A custom Controller is NOT needed you want to keep the \n", - "# numerical-based values of your DGL features (which is the case for dgl_lollipop_graph and dgl_hypercube_graph)\n", - "class Clique_ADBDGL_Controller(Base_ADBDGL_Controller):\n", + "# into ArangoDB attributes. NOTE: A custom Controller is NOT needed you want to\n", + "# keep the numerical-based values of your DGL features.\n", + "class Clique_ADBDGL_Controller(ADBDGL_Controller):\n", " \"\"\"ArangoDB-DGL controller.\n", "\n", " Responsible for controlling how ArangoDB attributes\n", @@ -850,26 +921,40 @@ " def _dgl_feature_to_adb_attribute(self, key: str, col: str, val: Tensor):\n", " \"\"\"\n", " Given a DGL feature key, its assigned value (for an arbitrary node or edge),\n", - " and the collection it belongs to, convert it to a valid ArangoDB attribute (e.g string, list, number, ...).\n", - "\n", - " NOTE: No action is needed here if you want to keep the numerical-based values of your DGL features.\n", + " and the collection it belongs to, convert it to a valid ArangoDB attribute\n", + " (e.g string, list, number, ...).\n", + "\n", + " NOTE: No action is needed here if you want to keep the numerical-based values\n", + " of your DGL features.\n", + "\n", + " :param key: The DGL attribute key name\n", + " :type key: str\n", + " :param col: The ArangoDB collection of the (soon-to-be) ArangoDB document.\n", + " :type col: str\n", + " :param val: The assigned attribute value of the DGL node.\n", + " :type val: Tensor\n", + " :return: The feature's representation as an ArangoDB Attribute\n", + " :rtype: Any\n", " \"\"\"\n", " if key == \"clique_ndata\":\n", " if val == 1:\n", " return \"one is fun\"\n", " elif val == 2:\n", - " return \"but two is blue\"\n", + " return \"two is blue\"\n", " elif val == 3:\n", - " return \"yet three is free\"\n", + " return \"three is free\"\n", " elif val == 4:\n", - " return \"and four is more\"\n", - " else:\n", + " return \"four is more\"\n", + " else: # No special string for values 5 & 6\n", " return f\"ERROR! Unrecognized value, got {val}\"\n", "\n", + " if key == \"clique_edata\":\n", + " return bool(val)\n", + "\n", " return super()._dgl_feature_to_adb_attribute(key, col, val)\n", "\n", "# Re-instantiate a new adapter specifically for the Clique Graph Conversion\n", - "clique_adbgl_adapter = ArangoDB_DGL_Adapter(con, Clique_ADBDGL_Controller)\n", + "clique_adbgl_adapter = ADBDGL_Adapter(con, Clique_ADBDGL_Controller())\n", "\n", "# Create the ArangoDB graphs\n", "lollipop = \"Lollipop_With_Attributes\"\n", @@ -884,15 +969,16 @@ "adb_hypercube_graph = adbdgl_adapter.dgl_to_arangodb(hypercube, dgl_hypercube_graph)\n", "adb_clique_graph = clique_adbgl_adapter.dgl_to_arangodb(clique, dgl_clique_graph) # Notice the new adapter here!\n", "\n", - "print(\"\\nInspect the graphs here:\\n\")\n", - "print(f\"1) https://tutorials.arangodb.cloud:8529/_db/{con['dbName']}/_admin/aardvark/index.html#graph/{lollipop}\")\n", - "print(f\"2) https://tutorials.arangodb.cloud:8529/_db/{con['dbName']}/_admin/aardvark/index.html#graph/{hypercube}\")\n", - "print(f\"3) https://tutorials.arangodb.cloud:8529/_db/{con['dbName']}/_admin/aardvark/index.html#graph/{clique}\\n\")\n", - "\n", + "print('\\n--------------------')\n", "print(\"https://{}:{}\".format(con[\"hostname\"], con[\"port\"]))\n", "print(\"Username: \" + con[\"username\"])\n", "print(\"Password: \" + con[\"password\"])\n", - "print(\"Database: \" + con[\"dbName\"])" + "print(\"Database: \" + con[\"dbName\"])\n", + "print('--------------------\\n')\n", + "print(\"\\nInspect the graphs here:\\n\")\n", + "print(f\"1) https://tutorials.arangodb.cloud:8529/_db/{con['dbName']}/_admin/aardvark/index.html#graph/{lollipop}\")\n", + "print(f\"2) https://tutorials.arangodb.cloud:8529/_db/{con['dbName']}/_admin/aardvark/index.html#graph/{hypercube}\")\n", + "print(f\"3) https://tutorials.arangodb.cloud:8529/_db/{con['dbName']}/_admin/aardvark/index.html#graph/{clique}\\n\")" ] } ], @@ -900,16 +986,15 @@ "colab": { "collapsed_sections": [ "ot1oJqn7m78n", - "Oc__NAd1eG8-", "7y81WHO8eG8_", "227hLXnPeG8_", "QfE_tKxneG9A", - "umy25EsUU6Lg", + "ZrEDmtqCVD0W", + "qEH6OdSB23Ya", "UafSB_3JZNwK", - "gshTlSX_ZZsS", - "CNj1xKhwoJoL" + "gshTlSX_ZZsS" ], - "name": "Copy of ArangoDB_DGLAdapter.ipynb", + "name": "ArangoDB_DGL_Adapter_v1.0.0.ipynb", "provenance": [] }, "kernelspec": { diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..b9911d5 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,23 @@ +[build-system] +requires = [ + "setuptools>=42", + "setuptools_scm[toml]>=3.4", + "wheel", +] +build-backend = "setuptools.build_meta" + +[tool.coverage.run] +omit = [ + "adbdgl_adapter/version.py", + "setup.py", +] + +[tool.isort] +profile = "black" + +[tool.pytest.ini_options] +minversion = "6.0" +testpaths = ["tests"] + +[tool.setuptools_scm] +write_to = "adbdgl_adapter/version.py" diff --git a/scripts/assert_version.py b/scripts/assert_version.py deleted file mode 100644 index 6621ea1..0000000 --- a/scripts/assert_version.py +++ /dev/null @@ -1,10 +0,0 @@ -# -*- coding: utf-8 -*- -import sys -from packaging.version import Version - -if __name__ == "__main__": - old = Version(sys.argv[1]) - current = Version(sys.argv[2]) - if current > old: - print("true") - sys.exit(0) diff --git a/scripts/extract_version.py b/scripts/extract_version.py deleted file mode 100644 index 8901dbd..0000000 --- a/scripts/extract_version.py +++ /dev/null @@ -1,9 +0,0 @@ -# -*- coding: utf-8 -*- -import requests - -if __name__ == "__main__": - response = requests.get( - "https://api.github.com/repos/arangoml/dgl-adapter/releases/latest" - ) - response.raise_for_status() - print(response.json().get("tag_name", "0.0.0")) diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..2e8bd1d --- /dev/null +++ b/setup.cfg @@ -0,0 +1,31 @@ +[metadata] +name = adbdgl_adapter +author = Anthony Mahanna +author_email = anthony.mahanna@arangodb.com +description = Convert ArangoDB graphs to DGL & vice-versa. +long_description = file: README.md +long_description_content_type = text/markdown +url = https://github.com/arangoml/dgl-adapter +classifiers = + Intended Audience :: Developers + License :: OSI Approved :: Apache Software License + Operating System :: OS Independent + Programming Language :: Python :: 3 + Programming Language :: Python :: 3.6 + Programming Language :: Python :: 3.7 + Programming Language :: Python :: 3.8 + Programming Language :: Python :: 3.9 + Topic :: Utilities + Typing :: Typed + +[options] +python_requires = >=3.6 + +[flake8] +max-line-length = 88 +extend-ignore = E203, E741, W503 +exclude =.git .idea .*_cache dist venv + +[mypy] +ignore_missing_imports = True +strict = True diff --git a/adbdgl_adapter/setup.py b/setup.py similarity index 58% rename from adbdgl_adapter/setup.py rename to setup.py index 5a62704..882e4f8 100644 --- a/adbdgl_adapter/setup.py +++ b/setup.py @@ -1,30 +1,43 @@ from setuptools import setup -with open("../VERSION") as f: - version = f.read().strip() - -with open("../README.md", "r") as f: - long_description = f.read() +with open("./README.md") as fp: + long_description = fp.read() setup( name="adbdgl_adapter", - author="ArangoDB", - author_email="hackers@arangodb.com", - version=version, + author="Anthony Mahanna", + author_email="anthony.mahanna@arangodb.com", description="Convert ArangoDB graphs to DGL & vice-versa.", long_description=long_description, long_description_content_type="text/markdown", url="https://github.com/arangoml/dgl-adapter", + keywords=["arangodb", "dgl", "adapter"], packages=["adbdgl_adapter"], include_package_data=True, + use_scm_version=True, + setup_requires=["setuptools_scm"], python_requires=">=3.6", license="Apache Software License", install_requires=[ - "python-arango==7.2.0", + "python-arango==7.3.0", "torch==1.10.0", "dgl==0.6.1", + "setuptools>=42", + "setuptools_scm[toml]>=3.4", ], - tests_require=["pytest", "pytest-cov"], + extras_require={ + "dev": [ + "black", + "flake8>=3.8.0", + "isort>=5.0.0", + "mypy>=0.790", + "pytest>=6.0.0", + "pytest-cov>=2.0.0", + "coveralls>=3.3.1", + "types-setuptools", + "types-requests", + ], + }, classifiers=[ "Intended Audience :: Developers", "License :: OSI Approved :: Apache Software License", diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/adbdgl_adapter/tests/assets/arangorestore b/tests/assets/arangorestore similarity index 100% rename from adbdgl_adapter/tests/assets/arangorestore rename to tests/assets/arangorestore diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100755 index 0000000..0cdf2c1 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,123 @@ +import json +import os +import subprocess +import time +from pathlib import Path + +from arango import ArangoClient +from arango.database import StandardDatabase +from dgl import DGLGraph, remove_self_loop +from dgl.data import KarateClubDataset, MiniGCDataset +from requests import post +from torch import ones, rand, tensor, zeros # type: ignore + +from adbdgl_adapter.adapter import ADBDGL_Adapter +from adbdgl_adapter.typings import Json + +PROJECT_DIR = Path(__file__).parent.parent + +con: Json +adbdgl_adapter: ADBDGL_Adapter +db: StandardDatabase + + +def pytest_sessionstart() -> None: + global con + con = get_oasis_crendetials() + # con = { + # "username": "root", + # "password": "openSesame", + # "hostname": "localhost", + # "port": 8529, + # "protocol": "http", + # "dbName": "_system", + # } + print_connection_details(con) + time.sleep(5) # Enough for the oasis instance to be ready. + + global adbdgl_adapter + adbdgl_adapter = ADBDGL_Adapter(con) + + global db + url = "https://" + con["hostname"] + ":" + str(con["port"]) + client = ArangoClient(hosts=url) + db = client.db(con["dbName"], con["username"], con["password"], verify=True) + + arango_restore(con, "examples/data/fraud_dump") + db.create_graph( + "fraud-detection", + edge_definitions=[ + { + "edge_collection": "accountHolder", + "from_vertex_collections": ["customer"], + "to_vertex_collections": ["account"], + }, + { + "edge_collection": "transaction", + "from_vertex_collections": ["account"], + "to_vertex_collections": ["account"], + }, + ], + ) + + +def get_oasis_crendetials() -> Json: + url = "https://tutorials.arangodb.cloud:8529/_db/_system/tutorialDB/tutorialDB" + request = post(url, data=json.dumps("{}")) + if request.status_code != 200: + raise Exception("Error retrieving login data.") + + creds: Json = json.loads(request.text) + return creds + + +def arango_restore(con: Json, path_to_data: str) -> None: + restore_prefix = "./assets/" if os.getenv("GITHUB_ACTIONS") else "" + + subprocess.check_call( + f'chmod -R 755 ./assets/arangorestore && {restore_prefix}arangorestore \ + -c none --server.endpoint http+ssl://{con["hostname"]}:{con["port"]} \ + --server.username {con["username"]} --server.database {con["dbName"]} \ + --server.password {con["password"]} \ + --input-directory "{PROJECT_DIR}/{path_to_data}"', + cwd=f"{PROJECT_DIR}/tests", + shell=True, + ) + + +def print_connection_details(con: Json) -> None: + print("----------------------------------------") + print("https://{}:{}".format(con["hostname"], con["port"])) + print("Username: " + con["username"]) + print("Password: " + con["password"]) + print("Database: " + con["dbName"]) + print("----------------------------------------") + + +def get_karate_graph() -> DGLGraph: + return KarateClubDataset()[0] + + +def get_lollipop_graph() -> DGLGraph: + dgl_g = remove_self_loop(MiniGCDataset(8, 7, 8)[3][0]) + dgl_g.ndata["random_ndata"] = tensor( + [[i, i, i] for i in range(0, dgl_g.num_nodes())] + ) + dgl_g.edata["random_edata"] = rand(dgl_g.num_edges()) + return dgl_g + + +def get_hypercube_graph() -> DGLGraph: + dgl_g = remove_self_loop(MiniGCDataset(8, 8, 9)[4][0]) + dgl_g.ndata["random_ndata"] = rand(dgl_g.num_nodes()) + dgl_g.edata["random_edata"] = tensor( + [[[i], [i], [i]] for i in range(0, dgl_g.num_edges())] + ) + return dgl_g + + +def get_clique_graph() -> DGLGraph: + dgl_g = remove_self_loop(MiniGCDataset(8, 6, 7)[6][0]) + dgl_g.ndata["random_ndata"] = ones(dgl_g.num_nodes()) + dgl_g.edata["random_edata"] = zeros(dgl_g.num_edges()) + return dgl_g diff --git a/adbdgl_adapter/tests/test_adbdgl_adapter.py b/tests/test_adapter.py similarity index 79% rename from adbdgl_adapter/tests/test_adbdgl_adapter.py rename to tests/test_adapter.py index 8f7ae4b..f9fcf1c 100644 --- a/adbdgl_adapter/tests/test_adbdgl_adapter.py +++ b/tests/test_adapter.py @@ -1,27 +1,26 @@ -from typing import Union +from typing import Set, Union import pytest -from conftest import ( - ArangoDB_DGL_Adapter, - get_karate_graph, - get_lollipop_graph, - get_hypercube_graph, - get_clique_graph, - db, - conn, - adbdgl_adapter, -) - +from arango.graph import Graph as ArangoGraph from dgl import DGLGraph from dgl.heterograph import DGLHeteroGraph -from arango.graph import Graph as ArangoGraph - -import torch from torch.functional import Tensor +from adbdgl_adapter.adapter import ADBDGL_Adapter +from adbdgl_adapter.typings import ArangoMetagraph -@pytest.mark.unit -def test_validate_attributes(): +from .conftest import ( + adbdgl_adapter, + con, + db, + get_clique_graph, + get_hypercube_graph, + get_karate_graph, + get_lollipop_graph, +) + + +def test_validate_attributes() -> None: bad_connection = { "dbName": "_system", "hostname": "localhost", @@ -32,19 +31,17 @@ def test_validate_attributes(): } with pytest.raises(ValueError): - ArangoDB_DGL_Adapter(bad_connection) + ADBDGL_Adapter(bad_connection) -@pytest.mark.unit -def test_validate_controller_class(): +def test_validate_controller_class() -> None: class Bad_ADBDGL_Controller: pass with pytest.raises(TypeError): - ArangoDB_DGL_Adapter(conn, Bad_ADBDGL_Controller) + ADBDGL_Adapter(con, Bad_ADBDGL_Controller()) # type: ignore -@pytest.mark.unit @pytest.mark.parametrize( "adapter, name, metagraph", [ @@ -70,13 +67,13 @@ class Bad_ADBDGL_Controller: ), ], ) -def test_adb_to_dgl(adapter: ArangoDB_DGL_Adapter, name: str, metagraph: dict): - assert_adapter_type(adapter) +def test_adb_to_dgl( + adapter: ADBDGL_Adapter, name: str, metagraph: ArangoMetagraph +) -> None: dgl_g = adapter.arangodb_to_dgl(name, metagraph) - assert_dgl_data(dgl_g, metagraph["vertexCollections"], metagraph["edgeCollections"]) + assert_dgl_data(dgl_g, metagraph) -@pytest.mark.unit @pytest.mark.parametrize( "adapter, name, v_cols, e_cols", [ @@ -89,38 +86,41 @@ def test_adb_to_dgl(adapter: ArangoDB_DGL_Adapter, name: str, metagraph: dict): ], ) def test_adb_collections_to_dgl( - adapter: ArangoDB_DGL_Adapter, name: str, v_cols: set, e_cols: set -): - assert_adapter_type(adapter) + adapter: ADBDGL_Adapter, name: str, v_cols: Set[str], e_cols: Set[str] +) -> None: dgl_g = adapter.arangodb_collections_to_dgl( name, v_cols, e_cols, ) assert_dgl_data( - dgl_g, {v_col: {} for v_col in v_cols}, {e_col: {} for e_col in e_cols} + dgl_g, + metagraph={ + "vertexCollections": {col: set() for col in v_cols}, + "edgeCollections": {col: set() for col in e_cols}, + }, ) -@pytest.mark.unit @pytest.mark.parametrize( "adapter, name", [(adbdgl_adapter, "fraud-detection")], ) -def test_adb_graph_to_dgl(adapter: ArangoDB_DGL_Adapter, name: str): - assert_adapter_type(adapter) - +def test_adb_graph_to_dgl(adapter: ADBDGL_Adapter, name: str) -> None: arango_graph = db.graph(name) v_cols = arango_graph.vertex_collections() e_cols = {col["edge_collection"] for col in arango_graph.edge_definitions()} dgl_g: DGLGraph = adapter.arangodb_graph_to_dgl(name) assert_dgl_data( - dgl_g, {v_col: {} for v_col in v_cols}, {e_col: {} for e_col in e_cols} + dgl_g, + metagraph={ + "vertexCollections": {col: set() for col in v_cols}, + "edgeCollections": {col: set() for col in e_cols}, + }, ) -@pytest.mark.unit @pytest.mark.parametrize( "adapter, name, dgl_g, is_default_type, batch_size", [ @@ -131,26 +131,21 @@ def test_adb_graph_to_dgl(adapter: ArangoDB_DGL_Adapter, name: str): ], ) def test_dgl_to_adb( - adapter: ArangoDB_DGL_Adapter, + adapter: ADBDGL_Adapter, name: str, dgl_g: Union[DGLGraph, DGLHeteroGraph], is_default_type: bool, batch_size: int, -): - assert_adapter_type(adapter) +) -> None: adb_g = adapter.dgl_to_arangodb(name, dgl_g, batch_size) assert_arangodb_data(name, dgl_g, adb_g, is_default_type) -def assert_adapter_type(adapter: ArangoDB_DGL_Adapter): - assert type(adapter) is ArangoDB_DGL_Adapter - - -def assert_dgl_data(dgl_g: DGLGraph, v_cols: dict, e_cols: dict): - has_one_ntype = len(v_cols) == 1 - has_one_etype = len(e_cols) == 1 +def assert_dgl_data(dgl_g: DGLGraph, metagraph: ArangoMetagraph) -> None: + has_one_ntype = len(metagraph["vertexCollections"]) == 1 + has_one_etype = len(metagraph["edgeCollections"]) == 1 - for col, atribs in v_cols.items(): + for col, atribs in metagraph["vertexCollections"].items(): num_nodes = dgl_g.num_nodes(col) assert num_nodes == db.collection(col).count() @@ -162,7 +157,7 @@ def assert_dgl_data(dgl_g: DGLGraph, v_cols: dict, e_cols: dict): assert col in dgl_g.ndata[atrib] assert len(dgl_g.ndata[atrib][col]) == num_nodes - for col, atribs in e_cols.items(): + for col, atribs in metagraph["edgeCollections"].items(): num_edges = dgl_g.num_edges(col) assert num_edges == db.collection(col).count() @@ -181,7 +176,7 @@ def assert_arangodb_data( dgl_g: Union[DGLGraph, DGLHeteroGraph], adb_g: ArangoGraph, is_default_type: bool, -): +) -> None: for dgl_v_col in dgl_g.ntypes: adb_v_col = name + dgl_v_col if is_default_type else dgl_v_col attributes = dgl_g.node_attr_schemes(