From 6d6bccb6da979afabdb354f359999c0ff711cb6f Mon Sep 17 00:00:00 2001
From: Anthony Mahanna <43019056+aMahanna@users.noreply.github.com>
Date: Thu, 30 Dec 2021 14:48:39 -0500
Subject: [PATCH] new: mirror networkx-adapter changes (#3)
---
.github/workflows/analyze.yml | 25 +-
.github/workflows/build.yml | 56 +--
.github/workflows/release.yml | 328 ++++++++----------
.gitignore | 127 ++++++-
MANIFEST.in | 3 +
README.md | 11 +-
VERSION | 1 -
.../{adbdgl_adapter => }/__init__.py | 0
adbdgl_adapter/abc.py | 82 +++++
.../adbdgl_adapter.py => adapter.py} | 304 ++++++++--------
adbdgl_adapter/adbdgl_adapter/abc.py | 69 ----
.../adbdgl_adapter/adbdgl_controller.py | 49 ---
adbdgl_adapter/controller.py | 70 ++++
adbdgl_adapter/setup.cfg | 7 -
adbdgl_adapter/tests/conftest.py | 119 -------
adbdgl_adapter/typings.py | 12 +
examples/ArangoDB_DGL_Adapter.ipynb | 303 ++++++++++------
pyproject.toml | 23 ++
scripts/assert_version.py | 10 -
scripts/extract_version.py | 9 -
setup.cfg | 31 ++
adbdgl_adapter/setup.py => setup.py | 33 +-
tests/__init__.py | 0
.../tests => tests}/assets/arangorestore | Bin
tests/conftest.py | 123 +++++++
.../test_adapter.py | 93 +++--
26 files changed, 1080 insertions(+), 808 deletions(-)
create mode 100644 MANIFEST.in
delete mode 100644 VERSION
rename adbdgl_adapter/{adbdgl_adapter => }/__init__.py (100%)
create mode 100644 adbdgl_adapter/abc.py
rename adbdgl_adapter/{adbdgl_adapter/adbdgl_adapter.py => adapter.py} (61%)
delete mode 100644 adbdgl_adapter/adbdgl_adapter/abc.py
delete mode 100644 adbdgl_adapter/adbdgl_adapter/adbdgl_controller.py
create mode 100644 adbdgl_adapter/controller.py
delete mode 100644 adbdgl_adapter/setup.cfg
delete mode 100755 adbdgl_adapter/tests/conftest.py
create mode 100644 adbdgl_adapter/typings.py
create mode 100644 pyproject.toml
delete mode 100644 scripts/assert_version.py
delete mode 100644 scripts/extract_version.py
create mode 100644 setup.cfg
rename adbdgl_adapter/setup.py => setup.py (58%)
create mode 100644 tests/__init__.py
rename {adbdgl_adapter/tests => tests}/assets/arangorestore (100%)
create mode 100755 tests/conftest.py
rename adbdgl_adapter/tests/test_adbdgl_adapter.py => tests/test_adapter.py (79%)
diff --git a/.github/workflows/analyze.yml b/.github/workflows/analyze.yml
index e0001fb..dc84535 100644
--- a/.github/workflows/analyze.yml
+++ b/.github/workflows/analyze.yml
@@ -11,23 +11,30 @@
#
name: analyze
on:
- workflow_dispatch:
+ push:
+ branches: [ master ]
+ paths:
+ - 'adbdgl_adapter/**'
+ - 'tests/**'
+ - 'setup.py'
+ - 'setup.cfg'
+ - 'pyproject.toml'
+ - '.github/workflows/analyze.yml'
pull_request:
- # The branches below must be a subset of the branches above
- branches: [master]
+ branches: [ master ]
paths:
- - "adbdgl_adapter/**"
+ - 'adbdgl_adapter/**'
+ - 'tests/**'
+ - 'setup.py'
+ - 'setup.cfg'
+ - 'pyproject.toml'
+ - '.github/workflows/analyze.yml'
schedule:
- cron: "00 9 * * 1"
-env:
- SOURCE_DIR: adbdgl_adapter
jobs:
analyze:
name: Analyze
runs-on: ubuntu-latest
- defaults:
- run:
- working-directory: ${{env.SOURCE_DIR}}
permissions:
actions: read
contents: read
diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index 6f6da1c..afbde43 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -1,40 +1,56 @@
name: build
on:
workflow_dispatch:
+ push:
+ branches: [ master ]
+ paths:
+ - 'adbdgl_adapter/**'
+ - 'tests/**'
+ - 'setup.py'
+ - 'setup.cfg'
+ - 'pyproject.toml'
+ - '.github/workflows/build.yml'
pull_request:
+ branches: [ master ]
paths:
- - "adbdgl_adapter/adbdgl_adapter/**"
- - "adbdgl_adapter/tests/**"
+ - 'adbdgl_adapter/**'
+ - 'tests/**'
+ - 'setup.py'
+ - 'setup.cfg'
+ - 'pyproject.toml'
+ - '.github/workflows/build.yml'
env:
- SOURCE_DIR: adbdgl_adapter
PACKAGE_DIR: adbdgl_adapter
+ TESTS_DIR: tests
jobs:
build:
runs-on: ubuntu-latest
- defaults:
- run:
- working-directory: ${{env.SOURCE_DIR}}
strategy:
matrix:
python: ["3.6", "3.7", "3.8", "3.9"]
name: Python ${{ matrix.python }}
- env:
- COVERALLS_REPO_TOKEN: ${{secrets.COVERALLS_REPO_TOKEN}}
- GITHUB_TOKEN: ${{secrets.GITHUB_TOKEN}}
steps:
- uses: actions/checkout@v2
- - name: Setup python
+ - name: Setup Python ${{ matrix.python }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python }}
- - name: Lint with Black
- uses: psf/black@stable
- with:
- options: "--check --verbose --diff --color"
- src: ${{env.PACKAGE_DIR}}
- - name: Install dependencies
- run: pip install -e . pytest pytest-cov coveralls
+ - name: Setup pip
+ run: python -m pip install --upgrade pip setuptools wheel
+ - name: Install packages
+ run: pip install .[dev]
+ - name: Run black
+ run: black --check --verbose --diff --color ${{env.PACKAGE_DIR}} ${{env.TESTS_DIR}}
+ - name: Run flake8
+ run: flake8 ${{env.PACKAGE_DIR}} ${{env.TESTS_DIR}}
+ - name: Run isort
+ run: isort --check --profile=black ${{env.PACKAGE_DIR}} ${{env.TESTS_DIR}}
+ - name: Run mypy
+ run: mypy ${{env.PACKAGE_DIR}} ${{env.TESTS_DIR}}
- name: Run pytest
- run: |
- pytest --cov=${{env.PACKAGE_DIR}} --cov-report term-missing -v --color=yes --no-cov-on-fail --code-highlight=yes
- coveralls
+ run: py.test --cov=${{env.PACKAGE_DIR}} --cov-report xml -v --color=yes --no-cov-on-fail --code-highlight=yes
+ - name: Publish to coveralls.io
+ if: matrix.python == '3.8'
+ env:
+ GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+ run: coveralls --service=github
\ No newline at end of file
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index 9b2a48e..50cea35 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -1,192 +1,136 @@
-# name: release
-# on:
-# workflow_dispatch:
-# push:
-# branches:
-# - master
-# paths:
-# - "adbdgl_adapter/adbdgl_adapter/**"
-# env:
-# SOURCE_DIR: adbdgl_adapter
-# PACKAGE_DIR: adbdgl_adapter
-# jobs:
-# version:
-# runs-on: ubuntu-latest
-# name: Verify version increase
-# steps:
-# - uses: actions/checkout@v2
-# - uses: actions/setup-python@v2
-# with:
-# python-version: "3.9"
-# - name: Install dependencies
-# run: pip install requests packaging
-# - name: Set variables
-# run: |
-# echo "OLD_VERSION=$(python scripts/extract_version.py)" >> $GITHUB_ENV
-# echo "NEW_VERSION=$(cat VERSION)" >> $GITHUB_ENV
-# - name: Assert version increase
-# id: verify
-# run: echo "::set-output name=has_increased::$(python scripts/assert_version.py ${{env.OLD_VERSION}} ${{env.NEW_VERSION}})"
-# - name: Fail on no version increase
-# if: ${{ steps.verify.outputs.has_increased != 'true' }}
-# uses: actions/github-script@v3
-# with:
-# script: core.setFailed("Cannot build & release - VERSION has not been manually incremented")
-# build:
-# needs: version
-# runs-on: ubuntu-latest
-# defaults:
-# run:
-# working-directory: ${{env.SOURCE_DIR}}
-# strategy:
-# matrix:
-# python: ["3.6", "3.7", "3.8", "3.9"]
-# name: Python ${{ matrix.python }}
-# env:
-# COVERALLS_REPO_TOKEN: ${{secrets.COVERALLS_REPO_TOKEN}}
-# GITHUB_TOKEN: ${{secrets.GITHUB_TOKEN}}
-# steps:
-# - uses: actions/checkout@v2
-# - uses: actions/setup-python@v2
-# with:
-# python-version: ${{ matrix.python }}
-# - name: Lint with Black
-# uses: psf/black@stable
-# with:
-# options: "--check --verbose --diff --color"
-# src: ${{env.PACKAGE_DIR}}
-# - name: Install dependencies
-# run: pip install -e . pytest pytest-cov coveralls
-# - name: Run pytest
-# run: |
-# pytest --cov=${{env.PACKAGE_DIR}} --cov-report term-missing -v --color=yes --no-cov-on-fail --code-highlight=yes
-# coveralls
-
-# release:
-# needs: build
-# runs-on: ubuntu-latest
-# name: Release package
-# env:
-# TWINE_USERNAME: ${{ secrets.TWINE_USERNAME }}
-# steps:
-# - uses: actions/checkout@v2
-# with:
-# fetch-depth: 0
-
-# - name: Setup python
-# uses: actions/setup-python@v2
-# with:
-# python-version: "3.8"
-
-# - name: Copy static repo files
-# run: cp {CHANGELOG.md,LICENSE,README.md,VERSION} ${{env.SOURCE_DIR}}
-
-# - name: Install release packages
-# run: pip install wheel gitchangelog pystache twine
-
-# - name: Install dependencies
-# run: pip install -e .
-# working-directory: ${{env.SOURCE_DIR}}
-
-# - name: Set variables
-# run: |
-# echo "OLD_VERSION=$(python scripts/extract_version.py)" >> $GITHUB_ENV
-# echo "NEW_VERSION=$(cat VERSION)" >> $GITHUB_ENV
-
-# - name: Ensure clean dist/ and build/ folders
-# run: rm -rf dist build
-# working-directory: ${{env.SOURCE_DIR}}
-
-# - name: Build package
-# run: python setup.py sdist bdist_wheel
-# working-directory: ${{env.SOURCE_DIR}}
-
-# - name: Extract wheel artifact name
-# run: echo "wheel_name=$(echo ${{env.SOURCE_DIR}}/dist/*.whl)" >> $GITHUB_ENV
-
-# - name: Extract tar.gz artifact name
-# run: echo "tar_name=$(echo ${{env.SOURCE_DIR}}/dist/*.tar.gz)" >> $GITHUB_ENV
-
-# - name: Pull tags from the repo
-# run: git pull --tags
-
-# - name: Create version_changelog.md
-# run: gitchangelog ${{env.OLD_VERSION}}..HEAD | sed "s/## (unreleased)/${{env.NEW_VERSION}} ($(date +"%Y-%m-%d"))/" > version_changelog.md
-
-# - name: Read version_changelog.md
-# run: cat version_changelog.md
-
-# - name: TestPypi release
-# run: twine upload --repository testpypi dist/* -p ${{ secrets.TWINE_PASSWORD_TEST }} #--skip-existing
-# working-directory: ${{env.SOURCE_DIR}}
-
-# - name: Pypi release
-# run: twine upload dist/* -p ${{ secrets.TWINE_PASSWORD }} #--skip-existing
-# working-directory: ${{env.SOURCE_DIR}}
-
-# - name: Github release
-# env:
-# GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-# run: hub release create -a $wheel_name -a $tar_name -F version_changelog.md ${{env.NEW_VERSION}}
-
-# changelog:
-# needs: release
-# runs-on: ubuntu-latest
-# name: Update Changelog
-# steps:
-# - uses: actions/checkout@v2
-# with:
-# fetch-depth: 0
-
-# - name: Create new branch
-# run: git checkout -b actions/changelog
-
-# - name: Set branch upstream
-# run: git push -u origin actions/changelog
-# env:
-# GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-
-# - name: Setup python
-# uses: actions/setup-python@v2
-# with:
-# python-version: "3.8"
-
-# - name: Install release packages
-# run: pip install wheel gitchangelog pystache
-
-# - name: Install dependencies
-# run: pip install -e .
-# working-directory: ${{env.SOURCE_DIR}}
-
-# - name: Set variables
-# run: echo "NEW_VERSION=$(cat VERSION)" >> $GITHUB_ENV
-
-# - name: Generate newest changelog
-# run: gitchangelog ${{env.NEW_VERSION}} > CHANGELOG.md
-
-# - name: Make commit for auto-generated changelog
-# uses: EndBug/add-and-commit@v7
-# env:
-# GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-# with:
-# add: "CHANGELOG.md"
-# branch: actions/changelog
-# message: "!gitchangelog"
-
-# - name: Create pull request for the auto generated changelog
-# run: |
-# echo "PR_URL=$(gh pr create \
-# --title "changelog: release ${{env.NEW_VERSION}}" \
-# --body "beep boop, i am a robot" \
-# --label documentation)" >> $GITHUB_ENV
-# env:
-# GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-
-# - name: Set pull request to auto-merge as rebase
-# run: |
-# gh pr merge $PR_URL \
-# --auto \
-# --delete-branch \
-# --rebase
-# env:
-# GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+name: release
+on:
+ workflow_dispatch:
+ release:
+ types: [published]
+env:
+ PACKAGE_DIR: adbdgl_adapter
+ TESTS_DIR: tests
+jobs:
+ build:
+ runs-on: ubuntu-latest
+ strategy:
+ matrix:
+ python: ["3.6", "3.7", "3.8", "3.9"]
+ name: Python ${{ matrix.python }}
+ steps:
+ - uses: actions/checkout@v2
+ - name: Setup Python ${{ matrix.python }}
+ uses: actions/setup-python@v2
+ with:
+ python-version: ${{ matrix.python }}
+ - name: Setup pip
+ run: python -m pip install --upgrade pip setuptools wheel
+ - name: Install packages
+ run: pip install .[dev]
+ - name: Run black
+ run: black --check --verbose --diff --color ${{env.PACKAGE_DIR}} ${{env.TESTS_DIR}}
+ - name: Run flake8
+ run: flake8 ${{env.PACKAGE_DIR}} ${{env.TESTS_DIR}}
+ - name: Run isort
+ run: isort --check --profile=black ${{env.PACKAGE_DIR}} ${{env.TESTS_DIR}}
+ - name: Run mypy
+ run: mypy ${{env.PACKAGE_DIR}} ${{env.TESTS_DIR}}
+ - name: Run pytest
+ run: py.test --cov=${{env.PACKAGE_DIR}} --cov-report xml -v --color=yes --no-cov-on-fail --code-highlight=yes
+ - name: Publish to coveralls.io
+ if: matrix.python == '3.8'
+ env:
+ GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+ run: coveralls --service=github
+
+ release:
+ needs: build
+ runs-on: ubuntu-latest
+ name: Release package
+ steps:
+ - uses: actions/checkout@v2
+
+ - name: Fetch complete history for all tags and branches
+ run: git fetch --prune --unshallow
+
+ - name: Setup python
+ uses: actions/setup-python@v2
+ with:
+ python-version: "3.8"
+
+ - name: Install release packages
+ run: pip install setuptools wheel twine setuptools-scm[toml]
+
+ - name: Install dependencies
+ run: pip install .[dev]
+
+ - name: Build distribution
+ run: python setup.py sdist bdist_wheel
+
+ - name: Publish to PyPI Test
+ env:
+ TWINE_USERNAME: __token__
+ TWINE_PASSWORD: ${{ secrets.TWINE_PASSWORD_TEST }}
+ run: twine upload --repository testpypi dist/* #--skip-existing
+ - name: Publish to PyPI
+ env:
+ TWINE_USERNAME: __token__
+ TWINE_PASSWORD: ${{ secrets.TWINE_PASSWORD }}
+ run: twine upload --repository pypi dist/* #--skip-existing
+
+ changelog:
+ needs: release
+ runs-on: ubuntu-latest
+ name: Update Changelog
+ steps:
+ - uses: actions/checkout@v2
+ with:
+ fetch-depth: 0
+
+ - name: Create new branch
+ run: git checkout -b actions/changelog
+
+ - name: Set branch upstream
+ run: git push -u origin actions/changelog
+ env:
+ GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+
+ - name: Setup python
+ uses: actions/setup-python@v2
+ with:
+ python-version: "3.8"
+
+ - name: Install release packages
+ run: pip install wheel gitchangelog pystache
+
+ - name: Install dependencies
+ run: pip install .[dev]
+
+ - name: Set variables
+ run: echo "VERSION=$(curl ${GITHUB_API_URL}/repos/${GITHUB_REPOSITORY}/releases/latest | python -c "import sys; import json; print(json.load(sys.stdin)['tag_name'])")" >> $GITHUB_ENV
+
+ - name: Generate newest changelog
+ run: gitchangelog ${{env.VERSION}} > CHANGELOG.md
+
+ - name: Make commit for auto-generated changelog
+ uses: EndBug/add-and-commit@v7
+ env:
+ GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+ with:
+ add: "CHANGELOG.md"
+ branch: actions/changelog
+ message: "!gitchangelog"
+
+ - name: Create pull request for the auto generated changelog
+ run: |
+ echo "PR_URL=$(gh pr create \
+ --title "changelog: release ${{env.VERSION}}" \
+ --body "beep boop, i am a robot" \
+ --label documentation)" >> $GITHUB_ENV
+ env:
+ GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+
+ - name: Set pull request to auto-merge as rebase
+ run: |
+ gh pr merge $PR_URL \
+ --admin \
+ --delete-branch \
+ --rebase
+ env:
+ GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
diff --git a/.gitignore b/.gitignore
index a2571cb..efb3b0c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,13 +1,118 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
.ipynb_checkpoints
-.tox
+
+# pyenv
+.python-version
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+
+# MacOS
.DS_Store
-**/*.pyc
-# log files
-**/*.log
-# Setuptools distribution folder.
-adbdgl_adapter/dist/
-# Remove the build directory from repo
-adbdgl_adapter/build/
-adbdgl_adapter/*.egg-info
-.vscode
-.venv
\ No newline at end of file
+
+# PyCharm
+.idea/
+
+# ArangoDB Starter
+localdata/
+
+# setuptools_scm
+adbdgl_adapter/version.py
+
+.vscode
\ No newline at end of file
diff --git a/MANIFEST.in b/MANIFEST.in
new file mode 100644
index 0000000..3d73851
--- /dev/null
+++ b/MANIFEST.in
@@ -0,0 +1,3 @@
+include README.md LICENSE
+prune tests
+prune examples
\ No newline at end of file
diff --git a/README.md b/README.md
index 255cada..2c260a7 100644
--- a/README.md
+++ b/README.md
@@ -29,12 +29,12 @@ The Deep Graph Library (DGL) is an easy-to-use, high performance and scalable Py
## Quickstart
-Get Started on Colab:
+Get Started on Colab:
```py
# Import the ArangoDB-DGL Adapter
-from adbdgl_adapter.adbdgl_adapter import ArangoDB_DGL_Adapter
+from adbdgl_adapter.adapter import ADBDGL_Adapter
# Import a sample graph from DGL
from dgl.data import KarateClubDataset
@@ -51,7 +51,7 @@ con = {
}
# This instantiates your ADBDGL Adapter with your connection credentials
-adbdgl_adapter = ArangoDB_DGL_Adapter(con)
+adbdgl_adapter = ADBDGL_Adapter(con)
# ArangoDB to DGL via Graph
dgl_fraud_graph = adbdgl_adapter.arangodb_graph_to_dgl("fraud-detection")
@@ -89,6 +89,5 @@ Prerequisite: `arangorestore` must be installed
2. `cd dgl-adapter`
3. `python -m venv .venv`
4. `source .venv/bin/activate` (MacOS) or `.venv/scripts/activate` (Windows)
-5. `cd adbdgl_adapter`
-6. `pip install -e . pytest`
-7. `pytest`
\ No newline at end of file
+5. `pip install -e . pytest`
+6. `pytest`
\ No newline at end of file
diff --git a/VERSION b/VERSION
deleted file mode 100644
index bd52db8..0000000
--- a/VERSION
+++ /dev/null
@@ -1 +0,0 @@
-0.0.0
\ No newline at end of file
diff --git a/adbdgl_adapter/adbdgl_adapter/__init__.py b/adbdgl_adapter/__init__.py
similarity index 100%
rename from adbdgl_adapter/adbdgl_adapter/__init__.py
rename to adbdgl_adapter/__init__.py
diff --git a/adbdgl_adapter/abc.py b/adbdgl_adapter/abc.py
new file mode 100644
index 0000000..3219d71
--- /dev/null
+++ b/adbdgl_adapter/abc.py
@@ -0,0 +1,82 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+from abc import ABC
+from typing import Any, List, Set, Union
+
+from arango.graph import Graph as ArangoDBGraph
+from dgl import DGLGraph
+from dgl.heterograph import DGLHeteroGraph
+from torch.functional import Tensor
+
+from .typings import ArangoMetagraph, DGLCanonicalEType, Json
+
+
+class Abstract_ADBDGL_Adapter(ABC):
+ def __init__(self) -> None:
+ raise NotImplementedError # pragma: no cover
+
+ def arangodb_to_dgl(
+ self, name: str, metagraph: ArangoMetagraph, **query_options: Any
+ ) -> DGLHeteroGraph:
+ raise NotImplementedError # pragma: no cover
+
+ def arangodb_collections_to_dgl(
+ self, name: str, v_cols: Set[str], e_cols: Set[str], **query_options: Any
+ ) -> DGLHeteroGraph:
+ raise NotImplementedError # pragma: no cover
+
+ def arangodb_graph_to_dgl(self, name: str, **query_options: Any) -> DGLHeteroGraph:
+ raise NotImplementedError # pragma: no cover
+
+ def dgl_to_arangodb(
+ self, name: str, dgl_g: Union[DGLGraph, DGLHeteroGraph], batch_size: int
+ ) -> ArangoDBGraph:
+ raise NotImplementedError # pragma: no cover
+
+ def etypes_to_edefinitions(
+ self, canonical_etypes: List[DGLCanonicalEType]
+ ) -> List[Json]:
+ raise NotImplementedError # pragma: no cover
+
+ def __prepare_dgl_features(self) -> None:
+ raise NotImplementedError # pragma: no cover
+
+ def __insert_dgl_features(self) -> None:
+ raise NotImplementedError # pragma: no cover
+
+ def __prepare_adb_attributes(self) -> None:
+ raise NotImplementedError # pragma: no cover
+
+ def __insert_adb_docs(self) -> None:
+ raise NotImplementedError # pragma: no cover
+
+ def __fetch_adb_docs(self) -> None:
+ raise NotImplementedError # pragma: no cover
+
+ def __validate_attributes(self) -> None:
+ raise NotImplementedError # pragma: no cover
+
+ @property
+ def DEFAULT_CANONICAL_ETYPE(self) -> List[DGLCanonicalEType]:
+ return [("_N", "_E", "_N")]
+
+ @property
+ def CONNECTION_ATRIBS(self) -> Set[str]:
+ return {"hostname", "username", "password", "dbName"}
+
+ @property
+ def METAGRAPH_ATRIBS(self) -> Set[str]:
+ return {"vertexCollections", "edgeCollections"}
+
+ @property
+ def EDGE_DEFINITION_ATRIBS(self) -> Set[str]:
+ return {"edge_collection", "from_vertex_collections", "to_vertex_collections"}
+
+
+class Abstract_ADBDGL_Controller(ABC):
+ def _adb_attribute_to_dgl_feature(self, key: str, col: str, val: Any) -> Any:
+ raise NotImplementedError # pragma: no cover
+
+ def _dgl_feature_to_adb_attribute(self, key: str, col: str, val: Tensor) -> Any:
+ raise NotImplementedError # pragma: no cover
diff --git a/adbdgl_adapter/adbdgl_adapter/adbdgl_adapter.py b/adbdgl_adapter/adapter.py
similarity index 61%
rename from adbdgl_adapter/adbdgl_adapter/adbdgl_adapter.py
rename to adbdgl_adapter/adapter.py
index 50d13fe..b9d3075 100644
--- a/adbdgl_adapter/adbdgl_adapter/adbdgl_adapter.py
+++ b/adbdgl_adapter/adapter.py
@@ -1,75 +1,73 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
-"""
-@author: Anthony Mahanna
-"""
-from .abc import ADBDGL_Adapter
-from .adbdgl_controller import Base_ADBDGL_Controller
+from collections import defaultdict
+from typing import Any, DefaultDict, Dict, List, Set, Union
from arango import ArangoClient
+from arango.cursor import Cursor
from arango.graph import Graph as ArangoDBGraph
-
-import dgl
-from dgl import DGLGraph
+from arango.result import Result
+from dgl import DGLGraph, heterograph
from dgl.heterograph import DGLHeteroGraph
from dgl.view import HeteroEdgeDataView, HeteroNodeDataView
-
-import torch
+from torch import tensor # type: ignore
from torch.functional import Tensor
-from typing import Union
-from collections import defaultdict
+from .abc import Abstract_ADBDGL_Adapter
+from .controller import ADBDGL_Controller
+from .typings import ArangoMetagraph, DGLCanonicalEType, DGLDataDict, Json
-class ArangoDB_DGL_Adapter(ADBDGL_Adapter):
+class ADBDGL_Adapter(Abstract_ADBDGL_Adapter):
"""ArangoDB-DGL adapter.
:param conn: Connection details to an ArangoDB instance.
- :type conn: dict
- :param controller_class: The ArangoDB-DGL controller, for controlling how ArangoDB attributes are converted into DGL features, and vice-versa. Optionally re-defined by the user if needed (otherwise defaults to Base_ADBDGL_Controller).
- :type controller_class: Base_ADBDGL_Controller
+ :type conn: adbdgl_adapter.typings.Json
+ :param controller: The ArangoDB-DGL controller, for controlling how
+ ArangoDB attributes are converted into DGL features, and vice-versa.
+ Optionally re-defined by the user if needed (otherwise defaults to
+ ADBDGL_Controller).
+ :type controller: adbdgl_adapter.controller.ADBDGL_Controller
:raise ValueError: If missing required keys in conn
"""
def __init__(
self,
- conn: dict,
- controller_class: Base_ADBDGL_Controller = Base_ADBDGL_Controller,
+ conn: Json,
+ controller: ADBDGL_Controller = ADBDGL_Controller(),
):
self.__validate_attributes("connection", set(conn), self.CONNECTION_ATRIBS)
- if issubclass(controller_class, Base_ADBDGL_Controller) is False:
- msg = "controller_class must inherit from Base_ADBDGL_Controller"
+ if issubclass(type(controller), ADBDGL_Controller) is False:
+ msg = "controller must inherit from ADBDGL_Controller"
raise TypeError(msg)
- username = conn["username"]
- password = conn["password"]
- db_name = conn["dbName"]
-
- protocol = conn.get("protocol", "https")
- host = conn["hostname"]
+ username: str = conn["username"]
+ password: str = conn["password"]
+ db_name: str = conn["dbName"]
+ host: str = conn["hostname"]
+ protocol: str = conn.get("protocol", "https")
port = str(conn.get("port", 8529))
url = protocol + "://" + host + ":" + port
print(f"Connecting to {url}")
self.__db = ArangoClient(hosts=url).db(db_name, username, password, verify=True)
- self.__cntrl: Base_ADBDGL_Controller = controller_class()
+ self.__cntrl: ADBDGL_Controller = controller
def arangodb_to_dgl(
- self,
- name: str,
- metagraph: dict,
- **query_options,
- ):
- """Create a DGL graph from user-defined metagraph.
+ self, name: str, metagraph: ArangoMetagraph, **query_options: Any
+ ) -> DGLHeteroGraph:
+ """Create a DGLHeteroGraph from the user-defined metagraph.
:param name: The DGL graph name.
:type name: str
- :param metagraph: An object defining vertex & edge collections to import to DGL, along with their associated attributes to keep.
- :type metagraph: dict
- :param query_options: Keyword arguments to specify AQL query options when fetching documents from the ArangoDB instance.
- :type query_options: **kwargs
+ :param metagraph: An object defining vertex & edge collections to import
+ to DGL, along with their associated attributes to keep.
+ :type metagraph: adbdgl_adapter.typings.ArangoMetagraph
+ :param query_options: Keyword arguments to specify AQL query options when
+ fetching documents from the ArangoDB instance.
+ :type query_options: Any
:return: A DGL Heterograph
:rtype: dgl.heterograph.DGLHeteroGraph
:raise ValueError: If missing required keys in metagraph
@@ -85,55 +83,64 @@ def arangodb_to_dgl(
},
"edgeCollections": {
"accountHolder": {},
- "transaction": {},
+ "transaction": {
+ "transaction_amt", "receiver_bank_id", "sender_bank_id"
+ },
},
}
"""
self.__validate_attributes("graph", set(metagraph), self.METAGRAPH_ATRIBS)
- adb_map = dict() # Maps ArangoDB vertex IDs to DGL node IDs
+ # Maps ArangoDB vertex IDs to DGL node IDs
+ adb_map: Dict[str, Dict[str, Any]] = dict()
# Dictionaries for constructing a heterogeneous graph.
- data_dict = dict()
- ndata = defaultdict(lambda: defaultdict(list))
- edata = defaultdict(lambda: defaultdict(list))
+ data_dict: DGLDataDict = dict()
+ ndata: DefaultDict[Any, Any] = defaultdict(lambda: defaultdict(list))
+ edata: DefaultDict[Any, Any] = defaultdict(lambda: defaultdict(list))
+ adb_v: Json
for v_col, atribs in metagraph["vertexCollections"].items():
- for i, v in enumerate(self.__fetch_adb_docs(v_col, atribs, query_options)):
- adb_map[v["_id"]] = {
+ for i, adb_v in enumerate(
+ self.__fetch_adb_docs(v_col, atribs, query_options)
+ ):
+ adb_map[adb_v["_id"]] = {
"id": i,
"col": v_col,
}
- self.__prepare_dgl_features(ndata, atribs, v, v_col)
+ self.__prepare_dgl_features(ndata, atribs, adb_v, v_col)
- from_col = set()
- to_col = set()
+ adb_e: Json
+ from_col: Set[str] = set()
+ to_col: Set[str] = set()
for e_col, atribs in metagraph["edgeCollections"].items():
- from_nodes = []
- to_nodes = []
- for e in self.__fetch_adb_docs(e_col, atribs, query_options):
- from_node = adb_map[e["_from"]]
- to_node = adb_map[e["_to"]]
+ from_nodes: List[int] = []
+ to_nodes: List[int] = []
+ for adb_e in self.__fetch_adb_docs(e_col, atribs, query_options):
+ from_node = adb_map[adb_e["_from"]]
+ to_node = adb_map[adb_e["_to"]]
from_col.add(from_node["col"])
to_col.add(to_node["col"])
if len(from_col | to_col) > 2:
raise ValueError(
- f"Can't convert to DGL: too many '_from' & '_to' collections in {e_col}"
+ f"""Can't convert to DGL:
+ too many '_from' & '_to' collections in {e_col}
+ """
)
from_nodes.append(from_node["id"])
to_nodes.append(to_node["id"])
- self.__prepare_dgl_features(edata, atribs, e, e_col)
+ self.__prepare_dgl_features(edata, atribs, adb_e, e_col)
data_dict[(from_col.pop(), e_col, to_col.pop())] = (
- torch.tensor(from_nodes),
- torch.tensor(to_nodes),
+ tensor(from_nodes),
+ tensor(to_nodes),
)
- dgl_g: DGLHeteroGraph = dgl.heterograph(data_dict)
+ dgl_g: DGLHeteroGraph = heterograph(data_dict)
has_one_ntype = len(dgl_g.ntypes) == 1
has_one_etype = len(dgl_g.etypes) == 1
@@ -146,37 +153,40 @@ def arangodb_to_dgl(
def arangodb_collections_to_dgl(
self,
name: str,
- vertex_collections: set,
- edge_collections: set,
- **query_options,
- ):
+ v_cols: Set[str],
+ e_cols: Set[str],
+ **query_options: Any,
+ ) -> DGLHeteroGraph:
"""Create a DGL graph from ArangoDB collections.
:param name: The DGL graph name.
:type name: str
- :param vertex_collections: A set of ArangoDB vertex collections to import to DGL.
- :type vertex_collections: set
- :param edge_collections: A set of ArangoDB edge collections to import to DGL.
- :type edge_collections: set
- :param query_options: Keyword arguments to specify AQL query options when fetching documents from the ArangoDB instance.
- :type query_options: **kwargs
+ :param v_cols: A set of ArangoDB vertex collections to
+ import to DGL.
+ :type v_cols: Set[str]
+ :param e_cols: A set of ArangoDB edge collections to import to DGL.
+ :type e_cols: Set[str]
+ :param query_options: Keyword arguments to specify AQL query options
+ when fetching documents from the ArangoDB instance.
+ :type query_options: Any
:return: A DGL Heterograph
:rtype: dgl.heterograph.DGLHeteroGraph
"""
- metagraph = {
- "vertexCollections": {col: {} for col in vertex_collections},
- "edgeCollections": {col: {} for col in edge_collections},
+ metagraph: ArangoMetagraph = {
+ "vertexCollections": {col: set() for col in v_cols},
+ "edgeCollections": {col: set() for col in e_cols},
}
return self.arangodb_to_dgl(name, metagraph, **query_options)
- def arangodb_graph_to_dgl(self, name: str, **query_options):
+ def arangodb_graph_to_dgl(self, name: str, **query_options: Any) -> DGLHeteroGraph:
"""Create a DGL graph from an ArangoDB graph.
:param name: The ArangoDB graph name.
:type name: str
- :param query_options: Keyword arguments to specify AQL query options when fetching documents from the ArangoDB instance.
- :type query_options: **kwargs
+ :param query_options: Keyword arguments to specify AQL query options
+ when fetching documents from the ArangoDB instance.
+ :type query_options: Any
:return: A DGL Heterograph
:rtype: dgl.heterograph.DGLHeteroGraph
"""
@@ -188,7 +198,7 @@ def arangodb_graph_to_dgl(self, name: str, **query_options):
def dgl_to_arangodb(
self, name: str, dgl_g: Union[DGLGraph, DGLHeteroGraph], batch_size: int = 1000
- ):
+ ) -> ArangoDBGraph:
"""Create an ArangoDB graph from a DGL graph.
:param name: The ArangoDB graph name.
@@ -200,9 +210,9 @@ def dgl_to_arangodb(
:return: The ArangoDB Graph API wrapper.
:rtype: arango.graph.Graph
"""
- is_default_type = dgl_g.canonical_etypes == self.DEFAULT_CANONICAL_ETYPE
- adb_v_cols = [name + dgl_g.ntypes[0]] if is_default_type else dgl_g.ntypes
- adb_e_cols = [name + dgl_g.etypes[0]] if is_default_type else dgl_g.etypes
+ is_default = dgl_g.canonical_etypes == self.DEFAULT_CANONICAL_ETYPE
+ adb_v_cols: List[str] = [name + dgl_g.ntypes[0]] if is_default else dgl_g.ntypes
+ adb_e_cols: List[str] = [name + dgl_g.etypes[0]] if is_default else dgl_g.etypes
e_definitions = self.etypes_to_edefinitions(
[
(
@@ -211,16 +221,16 @@ def dgl_to_arangodb(
adb_v_cols[0],
)
]
- if is_default_type
+ if is_default
else dgl_g.canonical_etypes
)
has_one_ntype = len(dgl_g.ntypes) == 1
has_one_etype = len(dgl_g.etypes) == 1
- adb_documents = defaultdict(list)
+ adb_documents: DefaultDict[str, List[Json]] = defaultdict(list)
for v_col in adb_v_cols:
- ntype = None if is_default_type else v_col
+ ntype = None if is_default else v_col
v_col_docs = adb_documents[v_col]
if self.__db.has_collection(v_col) is False:
@@ -228,7 +238,7 @@ def dgl_to_arangodb(
node: Tensor
for node in dgl_g.nodes(ntype):
- dgl_node_id: int = node.item()
+ dgl_node_id = node.item()
adb_vertex = {"_key": str(dgl_node_id)}
self.__prepare_adb_attributes(
dgl_g.ndata,
@@ -246,13 +256,13 @@ def dgl_to_arangodb(
from_nodes: Tensor
to_nodes: Tensor
for e_col in adb_e_cols:
- etype = None if is_default_type else e_col
+ etype = None if is_default else e_col
e_col_docs = adb_documents[e_col]
if self.__db.has_collection(e_col) is False:
self.__db.create_collection(e_col, edge=True)
- if is_default_type:
+ if is_default:
from_col = to_col = adb_v_cols[0]
else:
from_col, _, to_col = dgl_g.to_canonical_etype(e_col)
@@ -286,26 +296,29 @@ def dgl_to_arangodb(
print(f"ArangoDB: {name} created")
return adb_graph
- def etypes_to_edefinitions(self, canonical_etypes: list) -> list:
+ def etypes_to_edefinitions(
+ self, canonical_etypes: List[DGLCanonicalEType]
+ ) -> List[Json]:
"""Converts a DGL graph's canonical_etypes property to ArangoDB graph edge definitions
- :param canonical_etypes: A list of string triplets (str, str, str) for source node type, edge type and destination node type.
- :type canonical_etypes: list[tuple]
+ :param canonical_etypes: A list of string triplets (str, str, str) for
+ source node type, edge type and destination node type.
+ :type canonical_etypes: List[adbdgl_adapter.typings.DGLCanonicalEType]
:return: ArangoDB Edge Definitions
- :rtype: list[dict[str, Union[str, list[str]]]]
+ :rtype: List[adbdgl_adapter.typings.Json]
Here is an example of **edge_definitions**:
.. code-block:: python
[
{
- "edge_collection": "teach",
- "from_vertex_collections": ["teachers"],
- "to_vertex_collections": ["lectures"]
+ "edge_collection": "teaches",
+ "from_vertex_collections": ["Teacher"],
+ "to_vertex_collections": ["Lecture"]
}
]
"""
- edge_definitions = []
+ edge_definitions: List[Json] = []
for dgl_from, dgl_e, dgl_to in canonical_etypes:
edge_definitions.append(
{
@@ -319,91 +332,99 @@ def etypes_to_edefinitions(self, canonical_etypes: list) -> list:
def __prepare_dgl_features(
self,
- features_data: defaultdict,
- attributes: set,
- doc: dict,
+ features_data: DefaultDict[Any, Any],
+ attributes: Set[str],
+ doc: Json,
col: str,
- ):
+ ) -> None:
"""Convert a set of ArangoDB attributes into valid DGL features
:param features_data: A dictionary storing the DGL features formatted as lists.
- :type features_data: defaultdict[Any, defaultdict[Any, list]]
- :param col: The collection the current document belongs to
- :type col: str
+ :type features_data: Defaultdict[Any, Any]
:param attributes: A set of ArangoDB attribute keys to convert into DGL features
- :type attributes: set
+ :type attributes: Set[str]
:param doc: The current ArangoDB document
- :type doc: dict
-
+ :type doc: adbdgl_adapter.typings.Json
+ :param col: The collection the current document belongs to
+ :type col: str
"""
key: str
for key in attributes:
- arr: list = features_data[key][col]
+ arr: List[Any] = features_data[key][col]
arr.append(
- self.__cntrl._adb_attribute_to_dgl_feature(key, col, doc.get(key, -1))
+ self.__cntrl._adb_attribute_to_dgl_feature(key, col, doc.get(key, None))
)
def __insert_dgl_features(
self,
- features_data: defaultdict,
+ features_data: DefaultDict[Any, Any],
data: Union[HeteroNodeDataView, HeteroEdgeDataView],
has_one_type: bool,
- ):
+ ) -> None:
"""Insert valid DGL features into a DGL graph.
:param features_data: A dictionary storing the DGL features formatted as lists.
- :type features_data: defaultdict[Any, defaultdict[Any, list]]
- :param data: The (empty) ndata or edata instance attribute of a dgl graph, which is about to receive the **features_data**.
- :type data: Union[HeteroNodeDataView, HeteroEdgeDataView]
- :param has_one_type: Set to True if the DGL graph only has one ntype, or one etype.
+ :type features_data: Defaultdict[Any, Any]
+ :param data: The (empty) ndata or edata instance attribute of a dgl graph,
+ which is about to receive **features_data**.
+ :type data: Union[dgl.view.HeteroNodeDataView, dgl.view.HeteroEdgeDataView]
+ :param has_one_type: Set to True if the DGL graph only has one ntype,
+ or one etype.
:type has_one_type: bool
"""
- col_dict: dict
+ col_dict: Dict[str, List[Any]]
for key, col_dict in features_data.items():
for col, array in col_dict.items():
data[key] = (
- torch.tensor(array)
- if has_one_type
- else {**data[key], col: torch.tensor(array)}
+ tensor(array) if has_one_type else {**data[key], col: tensor(array)}
)
def __prepare_adb_attributes(
self,
data: Union[HeteroNodeDataView, HeteroEdgeDataView],
- features: set,
- id: int,
- doc: dict,
+ features: Set[Any],
+ id: Union[int, float, bool],
+ doc: Json,
col: str,
has_one_type: bool,
- ):
+ ) -> None:
"""Convert DGL features into a set of ArangoDB attributes for a given document
- :param data: The ndata or edata instance attribute of a dgl graph, filled with node or edge feature data.
- :type data: Union[HeteroNodeDataView, HeteroEdgeDataView]
+ :param data: The ndata or edata instance attribute of a dgl graph, filled with
+ node or edge feature data.
+ :type data: Union[dgl.view.HeteroNodeDataView, dgl.view.HeteroEdgeDataView]
:param features: A set of DGL feature keys to convert into ArangoDB attributes
- :type features: set
+ :type features: Set[Any]
:param id: The ID of the current DGL node / edge
- :type id: int
+ :type id: Union[int, float, bool]
:param doc: The current ArangoDB document
- :type doc: dict
+ :type doc: adbdgl_adapter.typings.Json
:param col: The collection the current document belongs to
:type col: str
- :param has_one_type: Set to True if the DGL graph only has one ntype, or one etype.
+ :param has_one_type: Set to True if the DGL graph only has one ntype,
+ or one etype.
:type has_one_type: bool
"""
for key in features:
tensor = data[key] if has_one_type else data[key][col]
doc[key] = self.__cntrl._dgl_feature_to_adb_attribute(key, col, tensor[id])
- def __insert_adb_docs(self, col: str, col_docs: list, doc: dict, batch_size: int):
- """Insert an ArangoDB document into a list. If the list exceeds batch_size documents, insert into the ArangoDB collection.
+ def __insert_adb_docs(
+ self,
+ col: str,
+ col_docs: List[Json],
+ doc: Json,
+ batch_size: int,
+ ) -> None:
+ """Insert an ArangoDB document into a list. If the list exceeds
+ batch_size documents, insert into the ArangoDB collection.
:param col: The collection name
:type col: str
:param col_docs: The existing documents data belonging to the collection.
- :type col_docs: list
+ :type col_docs: List[adbdgl_adapter.typings.Json]
:param doc: The current document to insert.
- :type doc: dict
+ :type doc: adbdgl_adapter.typings.Json
:param batch_size: The maximum number of documents to insert at once
:type batch_size: int
"""
@@ -413,38 +434,45 @@ def __insert_adb_docs(self, col: str, col_docs: list, doc: dict, batch_size: int
self.__db.collection(col).import_bulk(col_docs, on_duplicate="replace")
col_docs.clear()
- def __fetch_adb_docs(self, col: str, attributes: set, query_options: dict):
+ def __fetch_adb_docs(
+ self, col: str, attributes: Set[str], query_options: Any
+ ) -> Result[Cursor]:
"""Fetches ArangoDB documents within a collection.
:param col: The ArangoDB collection.
:type col: str
:param attributes: The set of document attributes.
- :type attributes: set
- :param query_options: Keyword arguments to specify AQL query options when fetching documents from the ArangoDB instance.
- :type query_options: **kwargs
+ :type attributes: Set[str]
+ :param query_options: Keyword arguments to specify AQL query options
+ when fetching documents from the ArangoDB instance.
+ :type query_options: Any
:return: Result cursor.
:rtype: arango.cursor.Cursor
"""
aql = f"""
FOR doc IN {col}
RETURN MERGE(
- KEEP(doc, {list(attributes)}),
- {{"_id": doc._id}},
+ KEEP(doc, {list(attributes)}),
+ {{"_id": doc._id}},
doc._from ? {{"_from": doc._from, "_to": doc._to}}: {{}}
)
"""
return self.__db.aql.execute(aql, **query_options)
- def __validate_attributes(self, type: str, attributes: set, valid_attributes: set):
- """Validates that a set of attributes includes the required valid attributes.
+ def __validate_attributes(
+ self, type: str, attributes: Set[str], valid_attributes: Set[str]
+ ) -> None:
+ """Validates that a set of attributes includes the required valid
+ attributes.
- :param type: The context of the attribute validation (e.g connection attributes, graph attributes, etc).
+ :param type: The context of the attribute validation
+ (e.g connection attributes, graph attributes, etc).
:type type: str
:param attributes: The provided attributes, possibly invalid.
- :type attributes: set
+ :type attributes: Set[str]
:param valid_attributes: The valid attributes.
- :type valid_attributes: set
+ :type valid_attributes: Set[str]
:raise ValueError: If **valid_attributes** is not a subset of **attributes**
"""
if valid_attributes.issubset(attributes) is False:
diff --git a/adbdgl_adapter/adbdgl_adapter/abc.py b/adbdgl_adapter/adbdgl_adapter/abc.py
deleted file mode 100644
index 9caab5d..0000000
--- a/adbdgl_adapter/adbdgl_adapter/abc.py
+++ /dev/null
@@ -1,69 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-"""
-@author: Anthony Mahanna
-"""
-
-from abc import ABC
-
-
-class ADBDGL_Adapter(ABC):
- def __init__(self):
- raise NotImplementedError() # pragma: no cover
-
- def arangodb_to_dgl(self):
- raise NotImplementedError() # pragma: no cover
-
- def arangodb_collections_to_dgl(self):
- raise NotImplementedError() # pragma: no cover
-
- def arangodb_graph_to_dgl(self):
- raise NotImplementedError() # pragma: no cover
-
- def dgl_to_arangodb(self):
- raise NotImplementedError() # pragma: no cover
-
- def etypes_to_edefinitions(self):
- raise NotImplementedError() # pragma: no cover
-
- def __prepare_dgl_features(self):
- raise NotImplementedError() # pragma: no cover
-
- def __insert_dgl_features(self):
- raise NotImplementedError() # pragma: no cover
-
- def __prepare_adb_attributes(self):
- raise NotImplementedError() # pragma: no cover
-
- def __insert_adb_docs(self):
- raise NotImplementedError() # pragma: no cover
-
- def __fetch_adb_docs(self):
- raise NotImplementedError() # pragma: no cover
-
- def __validate_attributes(self):
- raise NotImplementedError() # pragma: no cover
-
- @property
- def DEFAULT_CANONICAL_ETYPE(self):
- return [("_N", "_E", "_N")]
-
- @property
- def CONNECTION_ATRIBS(self):
- return {"hostname", "username", "password", "dbName"}
-
- @property
- def METAGRAPH_ATRIBS(self):
- return {"vertexCollections", "edgeCollections"}
-
- @property
- def EDGE_DEFINITION_ATRIBS(self):
- return {"edge_collection", "from_vertex_collections", "to_vertex_collections"}
-
-
-class ADBDGL_Controller(ABC):
- def _adb_attribute_to_dgl_feature(self):
- raise NotImplementedError() # pragma: no cover
-
- def _dgl_feature_to_adb_attribute(self):
- raise NotImplementedError() # pragma: no cover
diff --git a/adbdgl_adapter/adbdgl_adapter/adbdgl_controller.py b/adbdgl_adapter/adbdgl_adapter/adbdgl_controller.py
deleted file mode 100644
index ca0497b..0000000
--- a/adbdgl_adapter/adbdgl_adapter/adbdgl_controller.py
+++ /dev/null
@@ -1,49 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-from .abc import ADBDGL_Controller
-from collections import defaultdict
-from torch.functional import Tensor
-
-"""
-
-@author: Anthony Mahanna
-"""
-
-
-class Base_ADBDGL_Controller(ADBDGL_Controller):
- """ArangoDB-DGL controller.
-
- Responsible for controlling how ArangoDB attributes
- are converted into DGL features, and vice-versa.
-
- You can derive your own custom ADBDGL_Controller if you want to maintain
- consistency between your ArangoDB attributes & your DGL features.
- """
-
- def _adb_attribute_to_dgl_feature(self, key: str, col: str, val):
- """
- Given an ArangoDB attribute key, its assigned value (for an arbitrary document),
- and the collection it belongs to, convert it to a valid
- DGL feature: https://docs.dgl.ai/en/0.6.x/guide/graph-feature.html.
-
- NOTE: You must override this function if you want to transfer non-numerical ArangoDB
- attributes to DGL (DGL only accepts 'attributes' (a.k.a features) of numerical types).
- Read more about DGL features here: https://docs.dgl.ai/en/0.6.x/new-tutorial/2_dglgraph.html#assigning-node-and-edge-features-to-graph.
- """
- try:
- return float(val)
- except:
- return 0
-
- def _dgl_feature_to_adb_attribute(self, key: str, col: str, val: Tensor):
- """
- Given a DGL feature key, its assigned value (for an arbitrary node or edge),
- and the collection it belongs to, convert it to a valid ArangoDB attribute (e.g string, list, number, ...).
-
- NOTE: No action is needed here if you want to keep the numerical-based values of your DGL features.
- """
- try:
- return val.item()
- except ValueError:
- print("HERERERERE")
- return val.tolist()
diff --git a/adbdgl_adapter/controller.py b/adbdgl_adapter/controller.py
new file mode 100644
index 0000000..bd7e8d8
--- /dev/null
+++ b/adbdgl_adapter/controller.py
@@ -0,0 +1,70 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+from typing import Any
+
+from torch.functional import Tensor
+
+from .abc import Abstract_ADBDGL_Controller
+
+
+class ADBDGL_Controller(Abstract_ADBDGL_Controller):
+ """ArangoDB-DGL controller.
+
+ Responsible for controlling how ArangoDB attributes
+ are converted into DGL features, and vice-versa.
+
+ You can derive your own custom ADBDGL_Controller if you want to maintain
+ consistency between your ArangoDB attributes & your DGL features.
+ """
+
+ def _adb_attribute_to_dgl_feature(self, key: str, col: str, val: Any) -> Any:
+ """
+ Given an ArangoDB attribute key, its assigned value (for an arbitrary document),
+ and the collection it belongs to, convert it to a valid
+ DGL feature: https://docs.dgl.ai/en/0.6.x/guide/graph-feature.html.
+
+ NOTE: You must override this function if you want to transfer non-numerical
+ ArangoDB attributes to DGL (DGL only accepts 'attributes' (a.k.a features)
+ of numerical types). Read more about DGL features here:
+ https://docs.dgl.ai/en/0.6.x/new-tutorial/2_dglgraph.html#assigning-node-and-edge-features-to-graph.
+
+ :param key: The ArangoDB attribute key name
+ :type key: str
+ :param col: The ArangoDB collection of the ArangoDB document.
+ :type col: str
+ :param val: The assigned attribute value of the ArangoDB document.
+ :type val: Any
+ :return: The attribute's representation as a DGL Feature
+ :rtype: Any
+ """
+ if type(val) in [int, float, bool]:
+ return val
+
+ try:
+ return float(val)
+ except (ValueError, TypeError, SyntaxError):
+ return 0
+
+ def _dgl_feature_to_adb_attribute(self, key: str, col: str, val: Tensor) -> Any:
+ """
+ Given a DGL feature key, its assigned value (for an arbitrary node or edge),
+ and the collection it belongs to, convert it to a valid ArangoDB attribute
+ (e.g string, list, number, ...).
+
+ NOTE: No action is needed here if you want to keep the numerical-based values
+ of your DGL features.
+
+ :param key: The DGL attribute key name
+ :type key: str
+ :param col: The ArangoDB collection of the (soon-to-be) ArangoDB document.
+ :type col: str
+ :param val: The assigned attribute value of the DGL node.
+ :type val: Tensor
+ :return: The feature's representation as an ArangoDB Attribute
+ :rtype: Any
+ """
+ try:
+ return val.item()
+ except ValueError:
+ return val.tolist()
diff --git a/adbdgl_adapter/setup.cfg b/adbdgl_adapter/setup.cfg
deleted file mode 100644
index 5df7f3d..0000000
--- a/adbdgl_adapter/setup.cfg
+++ /dev/null
@@ -1,7 +0,0 @@
-[metadata]
-description_file = README.md
-
-[tool:pytest]
-markers =
- unit: Marks a unit test
-testpaths = tests
\ No newline at end of file
diff --git a/adbdgl_adapter/tests/conftest.py b/adbdgl_adapter/tests/conftest.py
deleted file mode 100755
index d6e35b9..0000000
--- a/adbdgl_adapter/tests/conftest.py
+++ /dev/null
@@ -1,119 +0,0 @@
-import os
-import time
-import json
-import requests
-import subprocess
-from pathlib import Path
-
-import torch
-from dgl import remove_self_loop
-from dgl.data import KarateClubDataset
-from dgl.data import MiniGCDataset
-
-from arango import ArangoClient
-from adbdgl_adapter.adbdgl_adapter import ArangoDB_DGL_Adapter
-
-PROJECT_DIR = Path(__file__).parent.parent.parent
-
-
-def pytest_sessionstart():
- global conn
- conn = get_oasis_crendetials()
- # conn = {
- # "username": "root",
- # "password": "openSesame",
- # "hostname": "localhost",
- # "port": 8529,
- # "protocol": "http",
- # "dbName": "_system",
- # }
- print_connection_details(conn)
- time.sleep(5) # Enough for the oasis instance to be ready.
-
- global adbdgl_adapter
- adbdgl_adapter = ArangoDB_DGL_Adapter(conn)
-
- global db
- url = (
- conn.get("protocol", "https")
- + "://"
- + conn["hostname"]
- + ":"
- + str(conn["port"])
- )
- client = ArangoClient(hosts=url)
- db = client.db(conn["dbName"], conn["username"], conn["password"], verify=True)
-
- arango_restore("examples/data/fraud_dump")
- db.create_graph(
- "fraud-detection",
- edge_definitions=[
- {
- "edge_collection": "accountHolder",
- "from_vertex_collections": ["customer"],
- "to_vertex_collections": ["account"],
- },
- {
- "edge_collection": "transaction",
- "from_vertex_collections": ["account"],
- "to_vertex_collections": ["account"],
- },
- ],
- )
-
-
-def get_oasis_crendetials() -> dict:
- url = "https://tutorials.arangodb.cloud:8529/_db/_system/tutorialDB/tutorialDB"
- request = requests.post(url, data=json.dumps("{}"))
- if request.status_code != 200:
- raise Exception("Error retrieving login data.")
-
- return json.loads(request.text)
-
-
-def arango_restore(path_to_data):
- restore_prefix = "./assets/" if os.getenv("GITHUB_ACTIONS") else ""
-
- subprocess.check_call(
- f'chmod -R 755 ./assets/arangorestore && {restore_prefix}arangorestore -c none --server.endpoint http+ssl://{conn["hostname"]}:{conn["port"]} --server.username {conn["username"]} --server.database {conn["dbName"]} --server.password {conn["password"]} --default-replication-factor 3 --input-directory "{PROJECT_DIR}/{path_to_data}"',
- cwd=f"{PROJECT_DIR}/adbdgl_adapter/tests",
- shell=True,
- )
-
-
-def print_connection_details(conn):
- print("----------------------------------------")
- print("https://{}:{}".format(conn["hostname"], conn["port"]))
- print("Username: " + conn["username"])
- print("Password: " + conn["password"])
- print("Database: " + conn["dbName"])
- print("----------------------------------------")
-
-
-def get_karate_graph():
- return KarateClubDataset()[0]
-
-
-def get_lollipop_graph():
- dgl_g = remove_self_loop(MiniGCDataset(8, 7, 8)[3][0])
- dgl_g.ndata["random_ndata"] = torch.tensor(
- [[i, i, i] for i in range(0, dgl_g.num_nodes())]
- )
- dgl_g.edata["random_edata"] = torch.rand(dgl_g.num_edges())
- return dgl_g
-
-
-def get_hypercube_graph():
- dgl_g = remove_self_loop(MiniGCDataset(8, 8, 9)[4][0])
- dgl_g.ndata["random_ndata"] = torch.rand(dgl_g.num_nodes())
- dgl_g.edata["random_edata"] = torch.tensor(
- [[[i], [i], [i]] for i in range(0, dgl_g.num_edges())]
- )
- return dgl_g
-
-
-def get_clique_graph():
- dgl_g = remove_self_loop(MiniGCDataset(8, 6, 7)[6][0])
- dgl_g.ndata["random_ndata"] = torch.ones(dgl_g.num_nodes())
- dgl_g.edata["random_edata"] = torch.zeros(dgl_g.num_edges())
- return dgl_g
diff --git a/adbdgl_adapter/typings.py b/adbdgl_adapter/typings.py
new file mode 100644
index 0000000..c3a7015
--- /dev/null
+++ b/adbdgl_adapter/typings.py
@@ -0,0 +1,12 @@
+__all__ = ["Json", "ArangoMetagraph", "DGLCanonicalEType"]
+
+from typing import Any, Dict, Set, Tuple
+
+from torch.functional import Tensor
+
+Json = Dict[str, Any]
+ArangoMetagraph = Dict[str, Dict[str, Set[str]]]
+
+
+DGLCanonicalEType = Tuple[str, str, str]
+DGLDataDict = Dict[DGLCanonicalEType, Tuple[Tensor, Tensor]]
diff --git a/examples/ArangoDB_DGL_Adapter.ipynb b/examples/ArangoDB_DGL_Adapter.ipynb
index 3dd6ab4..57d2ae3 100644
--- a/examples/ArangoDB_DGL_Adapter.ipynb
+++ b/examples/ArangoDB_DGL_Adapter.ipynb
@@ -36,7 +36,7 @@
"source": [
"Version: 1.0.0\n",
"\n",
- "Objective: Export Graphs from [ArangoDB](https://www.arangodb.com/), a multi-model Graph Database, into [Deep Graph Library](https://www.dgl.ai/) (DGL), a python package for graph neural networks, and vice-versa."
+ "Objective: Export Graphs from [ArangoDB](https://www.arangodb.com/), a multi-model Graph Database, to [Deep Graph Library](https://www.dgl.ai/) (DGL), a python package for graph neural networks, and vice-versa."
]
},
{
@@ -58,10 +58,10 @@
"source": [
"%%capture\n",
"!git clone -b oasis_connector --single-branch https://github.com/arangodb/interactive_tutorials.git\n",
- "!git clone https://github.com/arangoml/dgl-adapter.git # !git clone -b 1.0.0 --single-branch https://github.com/arangoml/dgl-adapter.git\n",
+ "!git clone -b 1.0.0 --single-branch https://github.com/arangoml/dgl-adapter.git\n",
"!rsync -av dgl-adapter/examples/ ./ --exclude=.git\n",
"!rsync -av interactive_tutorials/ ./ --exclude=.git\n",
- "!pip3 install \"git+https://github.com/arangoml/dgl-adapter.git#egg=adbdgl_adapter&subdirectory=adbdgl_adapter\" # pip3 install adbdgl_adapter==1.0.0\n",
+ "!pip3 install adbdgl_adapter==1.0.0\n",
"!pip3 install matplotlib\n",
"!pip3 install pyArango\n",
"!pip3 install networkx ## For drawing purposes "
@@ -71,7 +71,11 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
- "id": "RpqvL4COeG8-"
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "RpqvL4COeG8-",
+ "outputId": "2df55e4e-03fa-47ed-c2c9-baf9f597e1d8"
},
"outputs": [],
"source": [
@@ -87,8 +91,9 @@
"from dgl.data import KarateClubDataset\n",
"from dgl.data import MiniGCDataset\n",
"\n",
- "from adbdgl_adapter.adbdgl_adapter import ArangoDB_DGL_Adapter\n",
- "from adbdgl_adapter.adbdgl_controller import Base_ADBDGL_Controller"
+ "from adbdgl_adapter.adapter import ADBDGL_Adapter\n",
+ "from adbdgl_adapter.controller import ADBDGL_Controller\n",
+ "from adbdgl_adapter.typings import Json, ArangoMetagraph, DGLCanonicalEType, DGLDataDict"
]
},
{
@@ -124,7 +129,7 @@
"base_uri": "https://localhost:8080/"
},
"id": "vf0350qvj8up",
- "outputId": "9c2e9905-7272-44f6-8e59-e5f568a57758"
+ "outputId": "a65f00d2-cd6e-4583-94d8-2c9884e2e2e2"
},
"outputs": [],
"source": [
@@ -157,7 +162,7 @@
"base_uri": "https://localhost:8080/"
},
"id": "oOS3AVAnkQEV",
- "outputId": "9589b7b3-0867-4ff2-9c9f-8d0f38633490"
+ "outputId": "4609cdef-25ce-4f00-94b5-482c76274f88"
},
"outputs": [],
"source": [
@@ -193,7 +198,7 @@
"base_uri": "https://localhost:8080/"
},
"id": "meLon-KgkU4h",
- "outputId": "9f2f8081-393f-4a1b-9ff7-3f6c13289a62"
+ "outputId": "976680a4-eadd-43f2-da17-e6a574fad8a7"
},
"outputs": [],
"source": [
@@ -231,7 +236,7 @@
"base_uri": "https://localhost:8080/"
},
"id": "zTebQ0LOlsGA",
- "outputId": "0f3d26db-1b50-4d65-8385-8d0ad7147ec5"
+ "outputId": "9c84cb84-f7ce-42b3-9174-01f38295c5dd"
},
"outputs": [],
"source": [
@@ -274,7 +279,7 @@
"base_uri": "https://localhost:8080/"
},
"id": "KsxNujb0mSqZ",
- "outputId": "0b10fb67-5193-49ef-8aee-d691f53fe5bf"
+ "outputId": "3f3fd2b1-e1d3-4b03-c6c4-43566672cbb5"
},
"outputs": [],
"source": [
@@ -317,7 +322,7 @@
"base_uri": "https://localhost:8080/"
},
"id": "2ekGwnJDeG8-",
- "outputId": "fb348f69-8321-40b3-9bf5-0085ea218492"
+ "outputId": "92e9d288-0259-45cc-e73d-a8e9f629063a"
},
"outputs": [],
"source": [
@@ -377,16 +382,13 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "7bgGJ3QkeG8_",
- "outputId": "93451001-15f8-463d-8341-5c65fe6bc178"
+ "id": "7bgGJ3QkeG8_"
},
"outputs": [],
"source": [
+ "%%capture\n",
"!chmod -R 755 ./tools\n",
- "!./tools/arangorestore -c none --server.endpoint http+ssl://{con[\"hostname\"]}:{con[\"port\"]} --server.username {con[\"username\"]} --server.database {con[\"dbName\"]} --server.password {con[\"password\"]} --default-replication-factor 3 --input-directory \"data/fraud_dump\""
+ "!./tools/arangorestore -c none --server.endpoint http+ssl://{con[\"hostname\"]}:{con[\"port\"]} --server.username {con[\"username\"]} --server.database {con[\"dbName\"]} --server.password {con[\"password\"]} --replication-factor 3 --input-directory \"data/fraud_dump\""
]
},
{
@@ -424,7 +426,7 @@
"base_uri": "https://localhost:8080/"
},
"id": "PybHP7jpeG8_",
- "outputId": "724c9f23-c63b-4d34-f2f8-d9b7579cc985"
+ "outputId": "ba3bfc7c-ef56-47e7-8e98-3763d4f34afe"
},
"outputs": [],
"source": [
@@ -484,11 +486,11 @@
"base_uri": "https://localhost:8080/"
},
"id": "oG496kBeeG9A",
- "outputId": "164ea448-1117-4bde-e8e3-220558a1c0e3"
+ "outputId": "50ecbdf5-c82f-4540-d345-14eb4a488f2c"
},
"outputs": [],
"source": [
- "adbdgl_adapter = ArangoDB_DGL_Adapter(con)"
+ "adbdgl_adapter = ADBDGL_Adapter(con)"
]
},
{
@@ -518,7 +520,7 @@
"base_uri": "https://localhost:8080/"
},
"id": "zZ-Hu3lLVHgd",
- "outputId": "945d4971-cc05-4f97-eda3-b9e02cc05df8"
+ "outputId": "39f32c51-0753-45a8-a361-dcf46d4e6148"
},
"outputs": [],
"source": [
@@ -555,7 +557,7 @@
"base_uri": "https://localhost:8080/"
},
"id": "i4XOpdRLUNlJ",
- "outputId": "e445ea58-35ef-41ed-d4ac-717ac6f68e9c"
+ "outputId": "b58e75d1-e935-4abd-9bdb-bc8935d9cdc8"
},
"outputs": [],
"source": [
@@ -579,7 +581,7 @@
{
"cell_type": "markdown",
"metadata": {
- "id": "umy25EsUU6Lg"
+ "id": "qEH6OdSB23Ya"
},
"source": [
"## Via ArangoDB Metagraph"
@@ -592,8 +594,59 @@
"colab": {
"base_uri": "https://localhost:8080/"
},
- "id": "UWX9-MsKeG9A",
- "outputId": "f1ad45d9-d29a-4d7f-a853-b808ac898dfe"
+ "id": "7Kz8lXXq23Yk",
+ "outputId": "1458aef6-14e5-48c0-98bf-77f21431bc73"
+ },
+ "outputs": [],
+ "source": [
+ "# Define Metagraph\n",
+ "fraud_detection_metagraph = {\n",
+ " \"vertexCollections\": {\n",
+ " \"account\": {\"rank\", \"Balance\", \"customer_id\"},\n",
+ " \"Class\": {\"concrete\"},\n",
+ " \"customer\": {\"rank\"},\n",
+ " },\n",
+ " \"edgeCollections\": {\n",
+ " \"accountHolder\": {},\n",
+ " \"Relationship\": {},\n",
+ " \"transaction\": {\"receiver_bank_id\", \"sender_bank_id\", \"transaction_amt\"},\n",
+ " },\n",
+ "}\n",
+ "\n",
+ "# Create DGL Graph from attributes\n",
+ "dgl_g = adbdgl_adapter.arangodb_to_dgl('FraudDetection', fraud_detection_metagraph)\n",
+ "\n",
+ "# You can also provide valid Python-Arango AQL query options to the command above, like such:\n",
+ "# dgl_g = adbdgl_adapter.arangodb_to_dgl(graph_name = 'FraudDetection', fraud_detection_metagraph, ttl=1000, stream=True)\n",
+ "# See more here: https://docs.python-arango.com/en/main/specs.html#arango.aql.AQL.execute\n",
+ "\n",
+ "# Show graph data\n",
+ "print('\\n--------------')\n",
+ "print(dgl_g)\n",
+ "print('\\n--------------')\n",
+ "print(dgl_g.ndata)\n",
+ "print('--------------\\n')\n",
+ "print(dgl_g.edata)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "DqIKT1lO4ASw"
+ },
+ "source": [
+ "## Via ArangoDB Metagraph with a custom controller"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "U4_vSdU_4AS4",
+ "outputId": "b719b0d4-c0a4-43a0-915f-8ee765e1ec86"
},
"outputs": [],
"source": [
@@ -611,9 +664,9 @@
" },\n",
"}\n",
"\n",
- "# When converting to DGL via an ArangoDB metagraph, a user-defined Controller class\n",
- "# is required, to specify how ArangoDB attributes should be converted into DGL features.\n",
- "class FraudDetection_ADBDGL_Controller(Base_ADBDGL_Controller):\n",
+ "# When converting to DGL via an ArangoDB Metagraph that contains non-numerical values, a user-defined \n",
+ "# Controller class is required to specify how ArangoDB attributes should be converted to DGL features.\n",
+ "class FraudDetection_ADBDGL_Controller(ADBDGL_Controller):\n",
" \"\"\"ArangoDB-DGL controller.\n",
"\n",
" Responsible for controlling how ArangoDB attributes\n",
@@ -629,49 +682,59 @@
" and the collection it belongs to, convert it to a valid\n",
" DGL feature: https://docs.dgl.ai/en/0.6.x/guide/graph-feature.html.\n",
"\n",
- " NOTE: You must override this function if you want to transfer non-numerical ArangoDB\n",
- " attributes to DGL (DGL only accepts 'attributes' (a.k.a features) of numerical types).\n",
- " Read more about DGL features here: https://docs.dgl.ai/en/0.6.x/new-tutorial/2_dglgraph.html#assigning-node-and-edge-features-to-graph.\n",
+ " NOTE: You must override this function if you want to transfer non-numerical\n",
+ " ArangoDB attributes to DGL (DGL only accepts 'attributes' (a.k.a features)\n",
+ " of numerical types). Read more about DGL features here:\n",
+ " https://docs.dgl.ai/en/0.6.x/new-tutorial/2_dglgraph.html#assigning-node-and-edge-features-to-graph.\n",
+ "\n",
+ " :param key: The ArangoDB attribute key name\n",
+ " :type key: str\n",
+ " :param col: The ArangoDB collection of the ArangoDB document.\n",
+ " :type col: str\n",
+ " :param val: The assigned attribute value of the ArangoDB document.\n",
+ " :type val: Any\n",
+ " :return: The attribute's representation as a DGL Feature\n",
+ " :rtype: Any\n",
" \"\"\"\n",
- " if type(val) in [int, float, bool]:\n",
- " return val\n",
- "\n",
- " if col == \"transaction\":\n",
- " if key == \"transaction_date\":\n",
- " return int(str(val).replace(\"-\", \"\"))\n",
- " \n",
- " if key == \"trans_time\":\n",
- " return int(str(val).replace(\":\", \"\"))\n",
- " \n",
- " if col == \"customer\":\n",
- " if key == \"Sex\":\n",
- " return 0 if val == \"M\" else 1\n",
- "\n",
- " if key == \"Ssn\":\n",
- " return int(str(val).replace(\"-\", \"\"))\n",
- "\n",
- " if col == \"Class\":\n",
- " if key == \"name\":\n",
- " if val == \"Bank\":\n",
- " return 0\n",
- " elif val == \"Branch\":\n",
- " return 1\n",
- " elif val == \"Account\":\n",
- " return 2\n",
- " elif val == \"Customer\":\n",
- " return 3\n",
- " else:\n",
- " return -1\n",
+ " try:\n",
+ " if col == \"transaction\":\n",
+ " if key == \"transaction_date\":\n",
+ " return int(str(val).replace(\"-\", \"\"))\n",
+ " \n",
+ " if key == \"trans_time\":\n",
+ " return int(str(val).replace(\":\", \"\"))\n",
+ " \n",
+ " if col == \"customer\":\n",
+ " if key == \"Sex\":\n",
+ " return 0 if val == \"M\" else 1\n",
+ "\n",
+ " if key == \"Ssn\":\n",
+ " return int(str(val).replace(\"-\", \"\"))\n",
+ "\n",
+ " if col == \"Class\":\n",
+ " if key == \"name\":\n",
+ " if val == \"Bank\":\n",
+ " return 0\n",
+ " elif val == \"Branch\":\n",
+ " return 1\n",
+ " elif val == \"Account\":\n",
+ " return 2\n",
+ " elif val == \"Customer\":\n",
+ " return 3\n",
+ " else:\n",
+ " return -1\n",
+ " except (ValueError, TypeError, SyntaxError):\n",
+ " return 0\n",
"\n",
" return super()._adb_attribute_to_dgl_feature(key, col, val)\n",
"\n",
- "fraud_adbgl_adapter = ArangoDB_DGL_Adapter(con, FraudDetection_ADBDGL_Controller)\n",
+ "fraud_adbdgl_adapter = ADBDGL_Adapter(con, FraudDetection_ADBDGL_Controller())\n",
"\n",
"# Create DGL Graph from attributes\n",
- "dgl_g = fraud_adbgl_adapter.arangodb_to_dgl('FraudDetection', fraud_detection_metagraph)\n",
+ "dgl_g = fraud_adbdgl_adapter.arangodb_to_dgl('FraudDetection', fraud_detection_metagraph)\n",
"\n",
"# You can also provide valid Python-Arango AQL query options to the command above, like such:\n",
- "# dgl_g = adbdgl_adapter.arangodb_to_dgl(graph_name = 'FraudDetection', fraud_detection_metagraph, ttl=1000, stream=True)\n",
+ "# dgl_g = fraud_adbdgl_adapter.arangodb_to_dgl(graph_name = 'FraudDetection', fraud_detection_metagraph, ttl=1000, stream=True)\n",
"# See more here: https://docs.python-arango.com/en/main/specs.html#arango.aql.AQL.execute\n",
"\n",
"# Show graph data\n",
@@ -707,10 +770,10 @@
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
- "height": 0
+ "height": 577
},
"id": "eRVbiBy4ZdE4",
- "outputId": "bcbfce84-8bc0-4605-82d7-78fbaab53527"
+ "outputId": "d44eb9d9-e046-443b-8ded-79654f004e02"
},
"outputs": [],
"source": [
@@ -723,14 +786,13 @@
"python_arango_db_driver.delete_graph(name, drop_collections=True, ignore_missing=True)\n",
"adb_karate_graph = adbdgl_adapter.dgl_to_arangodb(name, dgl_karate_graph)\n",
"\n",
- "\n",
- "print(f\"\\nInspect the graph here: https://tutorials.arangodb.cloud:8529/_db/{con['dbName']}/_admin/aardvark/index.html#graph/{name}\\n\")\n",
- "\n",
+ "print('\\n--------------------')\n",
"print(\"https://{}:{}\".format(con[\"hostname\"], con[\"port\"]))\n",
"print(\"Username: \" + con[\"username\"])\n",
"print(\"Password: \" + con[\"password\"])\n",
"print(\"Database: \" + con[\"dbName\"])\n",
- "\n",
+ "print('--------------------\\n')\n",
+ "print(f\"\\nInspect the graph here: https://tutorials.arangodb.cloud:8529/_db/{con['dbName']}/_admin/aardvark/index.html#graph/{name}\\n\")\n",
"print(f\"\\nView the original graph below:\")"
]
},
@@ -750,10 +812,10 @@
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
- "height": 0
+ "height": 1000
},
"id": "dADiexlAioGH",
- "outputId": "286375c7-f2c9-4843-fb30-d086994985fc"
+ "outputId": "273988c8-1749-4fe0-85fe-51b0e1ab2058"
},
"outputs": [],
"source": [
@@ -783,16 +845,16 @@
"adb_hypercube_graph = adbdgl_adapter.dgl_to_arangodb(hypercube, dgl_hypercube_graph)\n",
"adb_clique_graph = adbdgl_adapter.dgl_to_arangodb(clique, dgl_clique_graph)\n",
"\n",
- "print(\"\\nInspect the graphs here:\\n\")\n",
- "print(f\"1) https://tutorials.arangodb.cloud:8529/_db/{con['dbName']}/_admin/aardvark/index.html#graph/{lollipop}\")\n",
- "print(f\"2) https://tutorials.arangodb.cloud:8529/_db/{con['dbName']}/_admin/aardvark/index.html#graph/{hypercube}\")\n",
- "print(f\"3) https://tutorials.arangodb.cloud:8529/_db/{con['dbName']}/_admin/aardvark/index.html#graph/{clique}\\n\")\n",
- "\n",
+ "print('\\n--------------------')\n",
"print(\"https://{}:{}\".format(con[\"hostname\"], con[\"port\"]))\n",
"print(\"Username: \" + con[\"username\"])\n",
"print(\"Password: \" + con[\"password\"])\n",
"print(\"Database: \" + con[\"dbName\"])\n",
- "\n",
+ "print('--------------------\\n')\n",
+ "print(\"\\nInspect the graphs here:\\n\")\n",
+ "print(f\"1) https://tutorials.arangodb.cloud:8529/_db/{con['dbName']}/_admin/aardvark/index.html#graph/{lollipop}\")\n",
+ "print(f\"2) https://tutorials.arangodb.cloud:8529/_db/{con['dbName']}/_admin/aardvark/index.html#graph/{hypercube}\")\n",
+ "print(f\"3) https://tutorials.arangodb.cloud:8529/_db/{con['dbName']}/_admin/aardvark/index.html#graph/{clique}\\n\")\n",
"print(f\"\\nView the original graphs below:\")"
]
},
@@ -803,7 +865,7 @@
},
"source": [
"\n",
- "## Example 3: DGL MiniGCDataset Graphs (with attribute transfer)"
+ "## Example 3: DGL MiniGCDataset Graphs with a custom controller"
]
},
{
@@ -814,30 +876,39 @@
"base_uri": "https://localhost:8080/"
},
"id": "jbJsvMMaoJoT",
- "outputId": "f426cdaf-a53c-4ade-d3b6-cbb11dd39c67"
+ "outputId": "2ddca41f-9c8b-4db4-c0aa-c1b2cc124fa5"
},
"outputs": [],
"source": [
"from torch.functional import Tensor\n",
"\n",
- "# Load the dgl graphs & populate node data\n",
+ "# Load the dgl graphs\n",
"dgl_lollipop_graph = remove_self_loop(MiniGCDataset(8, 7, 8)[3][0])\n",
- "dgl_lollipop_graph.ndata['lollipop_ndata'] = torch.ones(7)\n",
- "\n",
"dgl_hypercube_graph = remove_self_loop(MiniGCDataset(8, 8, 9)[4][0])\n",
- "dgl_hypercube_graph.ndata['hypercube_ndata'] = torch.zeros(8)\n",
- "\n",
"dgl_clique_graph = remove_self_loop(MiniGCDataset(8, 6, 7)[6][0])\n",
+ "\n",
+ " # Add DGL Node & Edge Features to each graph\n",
+ "dgl_lollipop_graph.ndata[\"random_ndata\"] = torch.tensor(\n",
+ " [[i, i, i] for i in range(0, dgl_lollipop_graph.num_nodes())]\n",
+ ")\n",
+ "dgl_lollipop_graph.edata[\"random_edata\"] = torch.rand(dgl_lollipop_graph.num_edges())\n",
+ "\n",
+ "dgl_hypercube_graph.ndata[\"random_ndata\"] = torch.rand(dgl_hypercube_graph.num_nodes())\n",
+ "dgl_hypercube_graph.edata[\"random_edata\"] = torch.tensor(\n",
+ " [[[i], [i], [i]] for i in range(0, dgl_hypercube_graph.num_edges())]\n",
+ ")\n",
+ "\n",
"dgl_clique_graph.ndata['clique_ndata'] = torch.tensor([1,2,3,4,5,6])\n",
+ "dgl_clique_graph.edata['clique_edata'] = torch.tensor(\n",
+ " [1 if i % 2 == 0 else 0 for i in range(0, dgl_clique_graph.num_edges())]\n",
+ ")\n",
"\n",
"\n",
"# When converting to ArangoDB from DGL, a user-defined Controller class\n",
"# is required to specify how DGL features (aka attributes) should be converted \n",
- "# into ArangoDB attributes.\n",
- "\n",
- "# NOTE: A custom Controller is NOT needed you want to keep the \n",
- "# numerical-based values of your DGL features (which is the case for dgl_lollipop_graph and dgl_hypercube_graph)\n",
- "class Clique_ADBDGL_Controller(Base_ADBDGL_Controller):\n",
+ "# into ArangoDB attributes. NOTE: A custom Controller is NOT needed you want to\n",
+ "# keep the numerical-based values of your DGL features.\n",
+ "class Clique_ADBDGL_Controller(ADBDGL_Controller):\n",
" \"\"\"ArangoDB-DGL controller.\n",
"\n",
" Responsible for controlling how ArangoDB attributes\n",
@@ -850,26 +921,40 @@
" def _dgl_feature_to_adb_attribute(self, key: str, col: str, val: Tensor):\n",
" \"\"\"\n",
" Given a DGL feature key, its assigned value (for an arbitrary node or edge),\n",
- " and the collection it belongs to, convert it to a valid ArangoDB attribute (e.g string, list, number, ...).\n",
- "\n",
- " NOTE: No action is needed here if you want to keep the numerical-based values of your DGL features.\n",
+ " and the collection it belongs to, convert it to a valid ArangoDB attribute\n",
+ " (e.g string, list, number, ...).\n",
+ "\n",
+ " NOTE: No action is needed here if you want to keep the numerical-based values\n",
+ " of your DGL features.\n",
+ "\n",
+ " :param key: The DGL attribute key name\n",
+ " :type key: str\n",
+ " :param col: The ArangoDB collection of the (soon-to-be) ArangoDB document.\n",
+ " :type col: str\n",
+ " :param val: The assigned attribute value of the DGL node.\n",
+ " :type val: Tensor\n",
+ " :return: The feature's representation as an ArangoDB Attribute\n",
+ " :rtype: Any\n",
" \"\"\"\n",
" if key == \"clique_ndata\":\n",
" if val == 1:\n",
" return \"one is fun\"\n",
" elif val == 2:\n",
- " return \"but two is blue\"\n",
+ " return \"two is blue\"\n",
" elif val == 3:\n",
- " return \"yet three is free\"\n",
+ " return \"three is free\"\n",
" elif val == 4:\n",
- " return \"and four is more\"\n",
- " else:\n",
+ " return \"four is more\"\n",
+ " else: # No special string for values 5 & 6\n",
" return f\"ERROR! Unrecognized value, got {val}\"\n",
"\n",
+ " if key == \"clique_edata\":\n",
+ " return bool(val)\n",
+ "\n",
" return super()._dgl_feature_to_adb_attribute(key, col, val)\n",
"\n",
"# Re-instantiate a new adapter specifically for the Clique Graph Conversion\n",
- "clique_adbgl_adapter = ArangoDB_DGL_Adapter(con, Clique_ADBDGL_Controller)\n",
+ "clique_adbgl_adapter = ADBDGL_Adapter(con, Clique_ADBDGL_Controller())\n",
"\n",
"# Create the ArangoDB graphs\n",
"lollipop = \"Lollipop_With_Attributes\"\n",
@@ -884,15 +969,16 @@
"adb_hypercube_graph = adbdgl_adapter.dgl_to_arangodb(hypercube, dgl_hypercube_graph)\n",
"adb_clique_graph = clique_adbgl_adapter.dgl_to_arangodb(clique, dgl_clique_graph) # Notice the new adapter here!\n",
"\n",
- "print(\"\\nInspect the graphs here:\\n\")\n",
- "print(f\"1) https://tutorials.arangodb.cloud:8529/_db/{con['dbName']}/_admin/aardvark/index.html#graph/{lollipop}\")\n",
- "print(f\"2) https://tutorials.arangodb.cloud:8529/_db/{con['dbName']}/_admin/aardvark/index.html#graph/{hypercube}\")\n",
- "print(f\"3) https://tutorials.arangodb.cloud:8529/_db/{con['dbName']}/_admin/aardvark/index.html#graph/{clique}\\n\")\n",
- "\n",
+ "print('\\n--------------------')\n",
"print(\"https://{}:{}\".format(con[\"hostname\"], con[\"port\"]))\n",
"print(\"Username: \" + con[\"username\"])\n",
"print(\"Password: \" + con[\"password\"])\n",
- "print(\"Database: \" + con[\"dbName\"])"
+ "print(\"Database: \" + con[\"dbName\"])\n",
+ "print('--------------------\\n')\n",
+ "print(\"\\nInspect the graphs here:\\n\")\n",
+ "print(f\"1) https://tutorials.arangodb.cloud:8529/_db/{con['dbName']}/_admin/aardvark/index.html#graph/{lollipop}\")\n",
+ "print(f\"2) https://tutorials.arangodb.cloud:8529/_db/{con['dbName']}/_admin/aardvark/index.html#graph/{hypercube}\")\n",
+ "print(f\"3) https://tutorials.arangodb.cloud:8529/_db/{con['dbName']}/_admin/aardvark/index.html#graph/{clique}\\n\")"
]
}
],
@@ -900,16 +986,15 @@
"colab": {
"collapsed_sections": [
"ot1oJqn7m78n",
- "Oc__NAd1eG8-",
"7y81WHO8eG8_",
"227hLXnPeG8_",
"QfE_tKxneG9A",
- "umy25EsUU6Lg",
+ "ZrEDmtqCVD0W",
+ "qEH6OdSB23Ya",
"UafSB_3JZNwK",
- "gshTlSX_ZZsS",
- "CNj1xKhwoJoL"
+ "gshTlSX_ZZsS"
],
- "name": "Copy of ArangoDB_DGLAdapter.ipynb",
+ "name": "ArangoDB_DGL_Adapter_v1.0.0.ipynb",
"provenance": []
},
"kernelspec": {
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..b9911d5
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,23 @@
+[build-system]
+requires = [
+ "setuptools>=42",
+ "setuptools_scm[toml]>=3.4",
+ "wheel",
+]
+build-backend = "setuptools.build_meta"
+
+[tool.coverage.run]
+omit = [
+ "adbdgl_adapter/version.py",
+ "setup.py",
+]
+
+[tool.isort]
+profile = "black"
+
+[tool.pytest.ini_options]
+minversion = "6.0"
+testpaths = ["tests"]
+
+[tool.setuptools_scm]
+write_to = "adbdgl_adapter/version.py"
diff --git a/scripts/assert_version.py b/scripts/assert_version.py
deleted file mode 100644
index 6621ea1..0000000
--- a/scripts/assert_version.py
+++ /dev/null
@@ -1,10 +0,0 @@
-# -*- coding: utf-8 -*-
-import sys
-from packaging.version import Version
-
-if __name__ == "__main__":
- old = Version(sys.argv[1])
- current = Version(sys.argv[2])
- if current > old:
- print("true")
- sys.exit(0)
diff --git a/scripts/extract_version.py b/scripts/extract_version.py
deleted file mode 100644
index 8901dbd..0000000
--- a/scripts/extract_version.py
+++ /dev/null
@@ -1,9 +0,0 @@
-# -*- coding: utf-8 -*-
-import requests
-
-if __name__ == "__main__":
- response = requests.get(
- "https://api.github.com/repos/arangoml/dgl-adapter/releases/latest"
- )
- response.raise_for_status()
- print(response.json().get("tag_name", "0.0.0"))
diff --git a/setup.cfg b/setup.cfg
new file mode 100644
index 0000000..2e8bd1d
--- /dev/null
+++ b/setup.cfg
@@ -0,0 +1,31 @@
+[metadata]
+name = adbdgl_adapter
+author = Anthony Mahanna
+author_email = anthony.mahanna@arangodb.com
+description = Convert ArangoDB graphs to DGL & vice-versa.
+long_description = file: README.md
+long_description_content_type = text/markdown
+url = https://github.com/arangoml/dgl-adapter
+classifiers =
+ Intended Audience :: Developers
+ License :: OSI Approved :: Apache Software License
+ Operating System :: OS Independent
+ Programming Language :: Python :: 3
+ Programming Language :: Python :: 3.6
+ Programming Language :: Python :: 3.7
+ Programming Language :: Python :: 3.8
+ Programming Language :: Python :: 3.9
+ Topic :: Utilities
+ Typing :: Typed
+
+[options]
+python_requires = >=3.6
+
+[flake8]
+max-line-length = 88
+extend-ignore = E203, E741, W503
+exclude =.git .idea .*_cache dist venv
+
+[mypy]
+ignore_missing_imports = True
+strict = True
diff --git a/adbdgl_adapter/setup.py b/setup.py
similarity index 58%
rename from adbdgl_adapter/setup.py
rename to setup.py
index 5a62704..882e4f8 100644
--- a/adbdgl_adapter/setup.py
+++ b/setup.py
@@ -1,30 +1,43 @@
from setuptools import setup
-with open("../VERSION") as f:
- version = f.read().strip()
-
-with open("../README.md", "r") as f:
- long_description = f.read()
+with open("./README.md") as fp:
+ long_description = fp.read()
setup(
name="adbdgl_adapter",
- author="ArangoDB",
- author_email="hackers@arangodb.com",
- version=version,
+ author="Anthony Mahanna",
+ author_email="anthony.mahanna@arangodb.com",
description="Convert ArangoDB graphs to DGL & vice-versa.",
long_description=long_description,
long_description_content_type="text/markdown",
url="https://github.com/arangoml/dgl-adapter",
+ keywords=["arangodb", "dgl", "adapter"],
packages=["adbdgl_adapter"],
include_package_data=True,
+ use_scm_version=True,
+ setup_requires=["setuptools_scm"],
python_requires=">=3.6",
license="Apache Software License",
install_requires=[
- "python-arango==7.2.0",
+ "python-arango==7.3.0",
"torch==1.10.0",
"dgl==0.6.1",
+ "setuptools>=42",
+ "setuptools_scm[toml]>=3.4",
],
- tests_require=["pytest", "pytest-cov"],
+ extras_require={
+ "dev": [
+ "black",
+ "flake8>=3.8.0",
+ "isort>=5.0.0",
+ "mypy>=0.790",
+ "pytest>=6.0.0",
+ "pytest-cov>=2.0.0",
+ "coveralls>=3.3.1",
+ "types-setuptools",
+ "types-requests",
+ ],
+ },
classifiers=[
"Intended Audience :: Developers",
"License :: OSI Approved :: Apache Software License",
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/adbdgl_adapter/tests/assets/arangorestore b/tests/assets/arangorestore
similarity index 100%
rename from adbdgl_adapter/tests/assets/arangorestore
rename to tests/assets/arangorestore
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100755
index 0000000..0cdf2c1
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,123 @@
+import json
+import os
+import subprocess
+import time
+from pathlib import Path
+
+from arango import ArangoClient
+from arango.database import StandardDatabase
+from dgl import DGLGraph, remove_self_loop
+from dgl.data import KarateClubDataset, MiniGCDataset
+from requests import post
+from torch import ones, rand, tensor, zeros # type: ignore
+
+from adbdgl_adapter.adapter import ADBDGL_Adapter
+from adbdgl_adapter.typings import Json
+
+PROJECT_DIR = Path(__file__).parent.parent
+
+con: Json
+adbdgl_adapter: ADBDGL_Adapter
+db: StandardDatabase
+
+
+def pytest_sessionstart() -> None:
+ global con
+ con = get_oasis_crendetials()
+ # con = {
+ # "username": "root",
+ # "password": "openSesame",
+ # "hostname": "localhost",
+ # "port": 8529,
+ # "protocol": "http",
+ # "dbName": "_system",
+ # }
+ print_connection_details(con)
+ time.sleep(5) # Enough for the oasis instance to be ready.
+
+ global adbdgl_adapter
+ adbdgl_adapter = ADBDGL_Adapter(con)
+
+ global db
+ url = "https://" + con["hostname"] + ":" + str(con["port"])
+ client = ArangoClient(hosts=url)
+ db = client.db(con["dbName"], con["username"], con["password"], verify=True)
+
+ arango_restore(con, "examples/data/fraud_dump")
+ db.create_graph(
+ "fraud-detection",
+ edge_definitions=[
+ {
+ "edge_collection": "accountHolder",
+ "from_vertex_collections": ["customer"],
+ "to_vertex_collections": ["account"],
+ },
+ {
+ "edge_collection": "transaction",
+ "from_vertex_collections": ["account"],
+ "to_vertex_collections": ["account"],
+ },
+ ],
+ )
+
+
+def get_oasis_crendetials() -> Json:
+ url = "https://tutorials.arangodb.cloud:8529/_db/_system/tutorialDB/tutorialDB"
+ request = post(url, data=json.dumps("{}"))
+ if request.status_code != 200:
+ raise Exception("Error retrieving login data.")
+
+ creds: Json = json.loads(request.text)
+ return creds
+
+
+def arango_restore(con: Json, path_to_data: str) -> None:
+ restore_prefix = "./assets/" if os.getenv("GITHUB_ACTIONS") else ""
+
+ subprocess.check_call(
+ f'chmod -R 755 ./assets/arangorestore && {restore_prefix}arangorestore \
+ -c none --server.endpoint http+ssl://{con["hostname"]}:{con["port"]} \
+ --server.username {con["username"]} --server.database {con["dbName"]} \
+ --server.password {con["password"]} \
+ --input-directory "{PROJECT_DIR}/{path_to_data}"',
+ cwd=f"{PROJECT_DIR}/tests",
+ shell=True,
+ )
+
+
+def print_connection_details(con: Json) -> None:
+ print("----------------------------------------")
+ print("https://{}:{}".format(con["hostname"], con["port"]))
+ print("Username: " + con["username"])
+ print("Password: " + con["password"])
+ print("Database: " + con["dbName"])
+ print("----------------------------------------")
+
+
+def get_karate_graph() -> DGLGraph:
+ return KarateClubDataset()[0]
+
+
+def get_lollipop_graph() -> DGLGraph:
+ dgl_g = remove_self_loop(MiniGCDataset(8, 7, 8)[3][0])
+ dgl_g.ndata["random_ndata"] = tensor(
+ [[i, i, i] for i in range(0, dgl_g.num_nodes())]
+ )
+ dgl_g.edata["random_edata"] = rand(dgl_g.num_edges())
+ return dgl_g
+
+
+def get_hypercube_graph() -> DGLGraph:
+ dgl_g = remove_self_loop(MiniGCDataset(8, 8, 9)[4][0])
+ dgl_g.ndata["random_ndata"] = rand(dgl_g.num_nodes())
+ dgl_g.edata["random_edata"] = tensor(
+ [[[i], [i], [i]] for i in range(0, dgl_g.num_edges())]
+ )
+ return dgl_g
+
+
+def get_clique_graph() -> DGLGraph:
+ dgl_g = remove_self_loop(MiniGCDataset(8, 6, 7)[6][0])
+ dgl_g.ndata["random_ndata"] = ones(dgl_g.num_nodes())
+ dgl_g.edata["random_edata"] = zeros(dgl_g.num_edges())
+ return dgl_g
diff --git a/adbdgl_adapter/tests/test_adbdgl_adapter.py b/tests/test_adapter.py
similarity index 79%
rename from adbdgl_adapter/tests/test_adbdgl_adapter.py
rename to tests/test_adapter.py
index 8f7ae4b..f9fcf1c 100644
--- a/adbdgl_adapter/tests/test_adbdgl_adapter.py
+++ b/tests/test_adapter.py
@@ -1,27 +1,26 @@
-from typing import Union
+from typing import Set, Union
import pytest
-from conftest import (
- ArangoDB_DGL_Adapter,
- get_karate_graph,
- get_lollipop_graph,
- get_hypercube_graph,
- get_clique_graph,
- db,
- conn,
- adbdgl_adapter,
-)
-
+from arango.graph import Graph as ArangoGraph
from dgl import DGLGraph
from dgl.heterograph import DGLHeteroGraph
-from arango.graph import Graph as ArangoGraph
-
-import torch
from torch.functional import Tensor
+from adbdgl_adapter.adapter import ADBDGL_Adapter
+from adbdgl_adapter.typings import ArangoMetagraph
-@pytest.mark.unit
-def test_validate_attributes():
+from .conftest import (
+ adbdgl_adapter,
+ con,
+ db,
+ get_clique_graph,
+ get_hypercube_graph,
+ get_karate_graph,
+ get_lollipop_graph,
+)
+
+
+def test_validate_attributes() -> None:
bad_connection = {
"dbName": "_system",
"hostname": "localhost",
@@ -32,19 +31,17 @@ def test_validate_attributes():
}
with pytest.raises(ValueError):
- ArangoDB_DGL_Adapter(bad_connection)
+ ADBDGL_Adapter(bad_connection)
-@pytest.mark.unit
-def test_validate_controller_class():
+def test_validate_controller_class() -> None:
class Bad_ADBDGL_Controller:
pass
with pytest.raises(TypeError):
- ArangoDB_DGL_Adapter(conn, Bad_ADBDGL_Controller)
+ ADBDGL_Adapter(con, Bad_ADBDGL_Controller()) # type: ignore
-@pytest.mark.unit
@pytest.mark.parametrize(
"adapter, name, metagraph",
[
@@ -70,13 +67,13 @@ class Bad_ADBDGL_Controller:
),
],
)
-def test_adb_to_dgl(adapter: ArangoDB_DGL_Adapter, name: str, metagraph: dict):
- assert_adapter_type(adapter)
+def test_adb_to_dgl(
+ adapter: ADBDGL_Adapter, name: str, metagraph: ArangoMetagraph
+) -> None:
dgl_g = adapter.arangodb_to_dgl(name, metagraph)
- assert_dgl_data(dgl_g, metagraph["vertexCollections"], metagraph["edgeCollections"])
+ assert_dgl_data(dgl_g, metagraph)
-@pytest.mark.unit
@pytest.mark.parametrize(
"adapter, name, v_cols, e_cols",
[
@@ -89,38 +86,41 @@ def test_adb_to_dgl(adapter: ArangoDB_DGL_Adapter, name: str, metagraph: dict):
],
)
def test_adb_collections_to_dgl(
- adapter: ArangoDB_DGL_Adapter, name: str, v_cols: set, e_cols: set
-):
- assert_adapter_type(adapter)
+ adapter: ADBDGL_Adapter, name: str, v_cols: Set[str], e_cols: Set[str]
+) -> None:
dgl_g = adapter.arangodb_collections_to_dgl(
name,
v_cols,
e_cols,
)
assert_dgl_data(
- dgl_g, {v_col: {} for v_col in v_cols}, {e_col: {} for e_col in e_cols}
+ dgl_g,
+ metagraph={
+ "vertexCollections": {col: set() for col in v_cols},
+ "edgeCollections": {col: set() for col in e_cols},
+ },
)
-@pytest.mark.unit
@pytest.mark.parametrize(
"adapter, name",
[(adbdgl_adapter, "fraud-detection")],
)
-def test_adb_graph_to_dgl(adapter: ArangoDB_DGL_Adapter, name: str):
- assert_adapter_type(adapter)
-
+def test_adb_graph_to_dgl(adapter: ADBDGL_Adapter, name: str) -> None:
arango_graph = db.graph(name)
v_cols = arango_graph.vertex_collections()
e_cols = {col["edge_collection"] for col in arango_graph.edge_definitions()}
dgl_g: DGLGraph = adapter.arangodb_graph_to_dgl(name)
assert_dgl_data(
- dgl_g, {v_col: {} for v_col in v_cols}, {e_col: {} for e_col in e_cols}
+ dgl_g,
+ metagraph={
+ "vertexCollections": {col: set() for col in v_cols},
+ "edgeCollections": {col: set() for col in e_cols},
+ },
)
-@pytest.mark.unit
@pytest.mark.parametrize(
"adapter, name, dgl_g, is_default_type, batch_size",
[
@@ -131,26 +131,21 @@ def test_adb_graph_to_dgl(adapter: ArangoDB_DGL_Adapter, name: str):
],
)
def test_dgl_to_adb(
- adapter: ArangoDB_DGL_Adapter,
+ adapter: ADBDGL_Adapter,
name: str,
dgl_g: Union[DGLGraph, DGLHeteroGraph],
is_default_type: bool,
batch_size: int,
-):
- assert_adapter_type(adapter)
+) -> None:
adb_g = adapter.dgl_to_arangodb(name, dgl_g, batch_size)
assert_arangodb_data(name, dgl_g, adb_g, is_default_type)
-def assert_adapter_type(adapter: ArangoDB_DGL_Adapter):
- assert type(adapter) is ArangoDB_DGL_Adapter
-
-
-def assert_dgl_data(dgl_g: DGLGraph, v_cols: dict, e_cols: dict):
- has_one_ntype = len(v_cols) == 1
- has_one_etype = len(e_cols) == 1
+def assert_dgl_data(dgl_g: DGLGraph, metagraph: ArangoMetagraph) -> None:
+ has_one_ntype = len(metagraph["vertexCollections"]) == 1
+ has_one_etype = len(metagraph["edgeCollections"]) == 1
- for col, atribs in v_cols.items():
+ for col, atribs in metagraph["vertexCollections"].items():
num_nodes = dgl_g.num_nodes(col)
assert num_nodes == db.collection(col).count()
@@ -162,7 +157,7 @@ def assert_dgl_data(dgl_g: DGLGraph, v_cols: dict, e_cols: dict):
assert col in dgl_g.ndata[atrib]
assert len(dgl_g.ndata[atrib][col]) == num_nodes
- for col, atribs in e_cols.items():
+ for col, atribs in metagraph["edgeCollections"].items():
num_edges = dgl_g.num_edges(col)
assert num_edges == db.collection(col).count()
@@ -181,7 +176,7 @@ def assert_arangodb_data(
dgl_g: Union[DGLGraph, DGLHeteroGraph],
adb_g: ArangoGraph,
is_default_type: bool,
-):
+) -> None:
for dgl_v_col in dgl_g.ntypes:
adb_v_col = name + dgl_v_col if is_default_type else dgl_v_col
attributes = dgl_g.node_attr_schemes(