This repository was archived by the owner on Mar 14, 2024. It is now read-only.

Add support for distributed file system (HDFS) #195

Open · wants to merge 2 commits into base: main
5 changes: 5 additions & 0 deletions docs/source/input_output.rst
@@ -107,3 +107,8 @@ the entities, with the first dimension being the number of entities and the
second being the dimension of the embedding.

Just like for the model parameters file, the optimizer state dict and additional metadata are also included.

HDFS Format
^^^^^^^^^^^

Include the ``hdfs://`` prefix in the entity, edge, and checkpoint paths when running on a distributed HDFS cluster.
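
For illustration (the values below are placeholders mirroring the sample config added in this change, not real paths), the I/O settings in a config file would then look like::

    entity_path='hdfs://<entity_path>'
    edge_paths=['hdfs://<edge_path>']
    checkpoint_path='hdfs://<checkpoint_path>'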
Binary file added test/resources/edges_0_0.h5
Binary file not shown.
Empty file added test/resources/invalidFile.h5
Empty file.
Empty file added test/resources/text.txt
Empty file.
257 changes: 257 additions & 0 deletions test/test_storage_manager.py
@@ -0,0 +1,257 @@
import shutil
import tempfile
from contextlib import AbstractContextManager
from io import TextIOWrapper
from pathlib import Path
from unittest import TestCase, main
import h5py

from torchbiggraph.storage_repository import CUSTOM_PATH
from torchbiggraph.storage_repository import LocalPath, HDFSPath, HDFSFileContextManager, LocalFileContextManager
from torchbiggraph.util import run_external_cmd, url_scheme

HDFS_TEST_PATH = '<valid hdfs path>'


def _touch_file(name: str):
file_path = HDFS_TEST_PATH + "/" + name
run_external_cmd("hadoop fs -touchz " + file_path)
return file_path


class TestLocalFileContextManager(TestCase):
def setUp(self):
self.resource_dir = Path(__file__).parent.absolute() / 'resources'

def test_get_resource_valid_h5(self):
filepath_h5 = self.resource_dir / 'edges_0_0.h5'
file_path = str(filepath_h5)
self.assertIs(type(LocalFileContextManager.get_resource(file_path, 'r')), h5py.File)

def test_get_resource_invalid_h5(self):
filepath_h5 = self.resource_dir / 'invalidFile.h5'
file_path = str(filepath_h5)

with self.assertRaises(ValueError):
LocalFileContextManager.get_resource(str(file_path), 'r')

def test_get_resource_valid_text_file(self):
filepath_txt = self.resource_dir / 'text.txt'
file_path = str(filepath_txt)
self.assertIs(type(LocalFileContextManager.get_resource(str(file_path), 'r')), TextIOWrapper)


class TestHDFSFileContextManager(TestCase):
def setUp(self):
if not HDFSFileContextManager.hdfs_file_exists(HDFS_TEST_PATH):
self.skipTest('HDFS test path is not reachable; skipping HDFS-dependent tests')

self.resource_dir = Path(__file__).parent.absolute() / 'resources'
run_external_cmd("hadoop fs -mkdir -p " + HDFS_TEST_PATH)

def tearDown(self):
run_external_cmd("hadoop fs -rm -r " + HDFS_TEST_PATH)

def test_prepare_hdfs_path(self):
actual = HDFSFileContextManager.get_hdfs_path(Path.cwd() / '/some/path')
expected = '/some/path'
self.assertEqual(str(expected), actual)

def test_hdfs_file_exists(self):
valid_path = _touch_file('abc')
self.assertTrue(HDFSFileContextManager.hdfs_file_exists(valid_path))

def test_hdfs_file_doesnt_exist(self):
invalid_path = HDFS_TEST_PATH + "/invalid_loc"
self.assertFalse(HDFSFileContextManager.hdfs_file_exists(invalid_path))

def test_get_from_hdfs_valid(self):
valid_hdfs_file = _touch_file('valid.file')
local_file = Path(str(Path.cwd()) + valid_hdfs_file)
file_ctx = HDFSFileContextManager(local_file, 'r')

# valid path
file_ctx.get_from_hdfs(reload=True)
self.assertTrue(Path(file_ctx._path).exists())

def test_get_from_hdfs_valid_dont_reload(self):
valid_hdfs_file = _touch_file('valid.file')
local_file = Path(str(Path.cwd()) + valid_hdfs_file)
file_ctx = HDFSFileContextManager(local_file, 'r')

# valid path
file_ctx.get_from_hdfs(reload=False)
self.assertTrue(Path(file_ctx._path).exists())

def test_get_from_hdfs_invalid(self):
invalid_hdfs_file = Path('./' + HDFS_TEST_PATH + "/invalid_loc").resolve()
file_ctx = HDFSFileContextManager(invalid_hdfs_file, 'r')

# invalid path
with self.assertRaises(FileNotFoundError):
file_ctx.get_from_hdfs(reload=True)

def test_put_to_hdfs(self):
local_file_name = 'test_local.file'
local_file = Path(str(Path.cwd()) + HDFS_TEST_PATH + '/' + local_file_name)
file_ctx = HDFSFileContextManager(local_file, 'w')

# clean up local
if local_file.exists():
local_file.unlink()

# local file does not exist yet, so the upload should fail
with self.assertRaises(FileNotFoundError):
file_ctx.put_to_hdfs()

# create local file
local_file.touch()
file_ctx.put_to_hdfs()
self.assertTrue(HDFSFileContextManager.hdfs_file_exists(HDFS_TEST_PATH + '/' + local_file_name))


class TestLocalPath(TestCase):
def setUp(self):
self.resource_dir = Path(__file__).parent.absolute() / 'resources'

def test_init(self):
path = LocalPath(Path.cwd())
self.assertIs(type(path), LocalPath)

path = LocalPath('some/path')
self.assertIs(type(path), LocalPath)

def test_stem_suffix(self):
path = LocalPath('some/path/name.txt')
self.assertTrue(path.stem == 'name')
self.assertTrue(path.suffix == '.txt')
self.assertIsInstance(path.stem, str)

def test_name(self):
path = LocalPath('some/path/name')
self.assertTrue(path.name == 'name')
self.assertIsInstance(path.name, str)

def test_resolve(self):
path = LocalPath('some/path/name')
actual = path.resolve(strict=False)
expected = Path.cwd() / Path(str(path))
self.assertTrue(str(actual) == str(expected))

def test_exists(self):
invalid_path = LocalPath('some/path/name')
self.assertFalse(invalid_path.exists())

valid_path = LocalPath(Path(__file__))
self.assertTrue(valid_path.exists())

def test_append_path(self):
path = LocalPath('/some/path/name')
actual = path / 'storage_manager.py'
expected = '/some/path/name/storage_manager.py'
self.assertTrue(str(actual) == expected)

def test_open(self):
file_path = Path(__file__)
with file_path.open('r') as fh:
self.assertGreater(len(fh.readlines()), 0)

def test_mkdir(self):
path = LocalPath(self.resource_dir)
path.parent.mkdir(parents=True, exist_ok=True)

def test_with_plugin_empty_scheme(self):
local_path = '/some/path/file.txt'
actual_path = CUSTOM_PATH.get_class(url_scheme(local_path))(local_path)
expected_path = '/some/path/file.txt'
self.assertEqual(str(actual_path), str(expected_path))

def test_with_plugin_file_scheme(self):
local_path = 'file:///some/path/file.txt'
actual_path = CUSTOM_PATH.get_class(url_scheme(local_path))(local_path)
expected_path = '/some/path/file.txt'
self.assertEqual(str(expected_path), str(actual_path))


class TestHDFSDataPath(TestCase):

def setUp(self):
if not HDFSFileContextManager.hdfs_file_exists(HDFS_TEST_PATH):
self.skipTest('HDFS test path is not reachable; skipping HDFS-dependent tests')

self.resource_dir = Path(__file__).parent.absolute() / 'resources'
run_external_cmd("hadoop fs -mkdir -p " + HDFS_TEST_PATH)

def tearDown(self):
run_external_cmd("hadoop fs -rm -r " + HDFS_TEST_PATH)

def test_delete_valid(self):
valid_path = _touch_file('abc.txt')
local_temp_dir = str(Path.cwd()) + '/' + 'axp'

# create resolved path based on the hdfs path
remote_path = HDFSPath(valid_path).resolve(strict=False)
remote_path.parent.mkdir(parents=True, exist_ok=True)
remote_path.touch()
remote_path.unlink()

# remove local path
shutil.rmtree(local_temp_dir, ignore_errors=True)

def test_delete_invalid(self):
invalid_path = HDFSPath(HDFS_TEST_PATH + '/invalid.file')
with self.assertRaises(FileNotFoundError):
invalid_path.unlink()

def test_open(self):
filepath_h5 = self.resource_dir / 'edges_0_0.h5'
hdfs = HDFSPath(filepath_h5).resolve(strict=False)
with hdfs.open('r') as fh:
self.assertEqual(len(fh.keys()), 3)
self.assertIsInstance(fh, AbstractContextManager)

def test_open_reload_False(self):
filepath_h5 = self.resource_dir / 'edges_0_0.h5'
hdfs = HDFSPath(filepath_h5).resolve(strict=False)
with hdfs.open('r', reload=False) as fh:
self.assertEqual(len(fh.keys()), 3)
self.assertIsInstance(fh, AbstractContextManager)

def test_name(self):
hdfs = HDFSPath('/some/path/file.txt')
self.assertEqual(hdfs.name, 'file.txt')

def test_with_plugin(self):
hdfs_path = 'hdfs:///some/path/file.txt'
actual_path = CUSTOM_PATH.get_class(url_scheme(hdfs_path))(hdfs_path).resolve(strict=False)
expected_path = Path.cwd() / 'some/path/file.txt'
self.assertEqual(str(actual_path), str(expected_path))

def test_append_path(self):
path = HDFSPath('/some/path/name')
actual = path.resolve(strict=False) / 'storage_manager.py'
expected = str(Path.cwd() / 'some/path/name/storage_manager.py')
self.assertEqual(expected, str(actual))

def test_stem_suffix(self):
path = HDFSPath('some/path/name.txt')
self.assertTrue(path.stem == 'name')
self.assertTrue(path.suffix == '.txt')
self.assertIsInstance(path.stem, str)

def test_cleardir(self):
# create empty files inside the temporary directory
tempdir = tempfile.mkdtemp()
Path(tempdir, 'file1.txt').touch()
Path(tempdir, 'file2.txt').touch()
Path(tempdir, 'file3.txt').touch()

dirpath = HDFSPath(tempdir)
dirpath.cleardir()

self.assertFalse(any(dirpath.iterdir()))


if __name__ == "__main__":
main()
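
A note on running these tests (an assumption about the intended workflow, not stated in the diff): the HDFS-dependent cases skip themselves unless HDFS_TEST_PATH is edited to point at a reachable HDFS directory and the hadoop CLI is on the PATH; with that in place, the suite can be launched directly, e.g.:

    python test/test_storage_manager.py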
22 changes: 13 additions & 9 deletions torchbiggraph/checkpoint_storage.py
@@ -10,15 +10,15 @@
import logging
import os
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, Generator, List, NamedTuple, Optional, Tuple

import h5py
import numpy as np
import torch
from torchbiggraph.plugin import URLPluginRegistry
from torchbiggraph.types import EntityName, FloatTensorType, ModuleStateDict, Partition
from torchbiggraph.util import CouldNotLoadData, allocate_shared_tensor
from torchbiggraph.util import CouldNotLoadData, allocate_shared_tensor, url_scheme
from torchbiggraph.storage_repository import CUSTOM_PATH, AbstractPath as Path


logger = logging.getLogger("torchbiggraph")
@@ -208,6 +208,7 @@ def process_dataset(public_name, dataset) -> None:

@CHECKPOINT_STORAGES.register_as("") # No scheme
@CHECKPOINT_STORAGES.register_as("file")
@CHECKPOINT_STORAGES.register_as("hdfs")
class FileCheckpointStorage(AbstractCheckpointStorage):

"""Reads and writes checkpoint data to/from disk.
@@ -241,9 +242,8 @@ class FileCheckpointStorage(AbstractCheckpointStorage):
"""

def __init__(self, path: str) -> None:
if path.startswith("file://"):
path = path[len("file://") :]
self.path: Path = Path(path).resolve(strict=False)
self.path: Path = CUSTOM_PATH.get_class(url_scheme(path))(path).resolve(strict=False)
self.prepare()

def get_version_file(self, *, path: Optional[Path] = None) -> Path:
if path is None:
@@ -319,7 +319,7 @@ def save_entity_partition(
) -> None:
path = self.get_entity_partition_file(version, entity_name, partition)
logger.debug(f"Saving to {path}")
with h5py.File(path, "w") as hf:
with path.open("w") as hf:
hf.attrs[FORMAT_VERSION_ATTR] = FORMAT_VERSION
for k, v in metadata.items():
hf.attrs[k] = v
@@ -338,11 +338,13 @@ def load_entity_partition(
path = self.get_entity_partition_file(version, entity_name, partition)
logger.debug(f"Loading from {path}")
try:
with h5py.File(path, "r") as hf:
with path.open("r") as hf:
if hf.attrs.get(FORMAT_VERSION_ATTR, None) != FORMAT_VERSION:
raise RuntimeError(f"Version mismatch in embeddings file {path}")
embs = load_embeddings(hf, out=out)
optim_state = load_optimizer_state_dict(hf)
except FileNotFoundError as err:
raise CouldNotLoadData() from err
except OSError as err:
# h5py refuses to make it easy to figure out what went wrong. The errno
# attribute is set to None. See https://github.com/h5py/h5py/issues/493.
@@ -368,7 +370,7 @@ def save_model(
) -> None:
path = self.get_model_file(version)
logger.debug(f"Saving to {path}")
with h5py.File(path, "w") as hf:
with path.open("w") as hf:
hf.attrs[FORMAT_VERSION_ATTR] = FORMAT_VERSION
for k, v in metadata.items():
hf.attrs[k] = v
@@ -383,11 +385,13 @@ def load_model(
path = self.get_model_file(version)
logger.debug(f"Loading from {path}")
try:
with h5py.File(path, "r") as hf:
with path.open("r") as hf:
if hf.attrs.get(FORMAT_VERSION_ATTR, None) != FORMAT_VERSION:
raise RuntimeError(f"Version mismatch in model file {path}")
state_dict = load_model_state_dict(hf)
optim_state = load_optimizer_state_dict(hf)
except FileNotFoundError as err:
raise CouldNotLoadData() from err
except OSError as err:
# h5py refuses to make it easy to figure out what went wrong. The errno
# attribute is set to None. See https://github.com/h5py/h5py/issues/493.
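
As a rough sketch of what the new path dispatch amounts to (based only on how CUSTOM_PATH, url_scheme, HDFSPath and LocalPath are exercised in the tests above; the checkpoint path below is illustrative, not part of this diff):

    from torchbiggraph.storage_repository import CUSTOM_PATH
    from torchbiggraph.util import url_scheme

    # "hdfs://..." resolves to an HDFSPath, while a bare or "file://..." path
    # resolves to a LocalPath; both expose the same Path-like interface.
    checkpoint_path = "hdfs:///user/pbg/checkpoint/model.h5"  # illustrative
    path_cls = CUSTOM_PATH.get_class(url_scheme(checkpoint_path))
    local_mirror = path_cls(checkpoint_path).resolve(strict=False)

    # For HDFSPath, open() first copies the file from HDFS into a local mirror
    # under the current working directory and then hands back an h5py.File.
    with local_mirror.open("r") as hf:
        print(sorted(hf.attrs))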
40 changes: 40 additions & 0 deletions torchbiggraph/examples/configs/distributedCluster_config.py
@@ -0,0 +1,40 @@
#!/usr/bin/env python3

def get_torchbiggraph_config():

config = dict( # noqa
# I/O data
entity_path='hdfs://<entity_path>', # set entity_path
edge_paths=['hdfs://<edge_path>'], # set edge_path
checkpoint_path='hdfs://<checkpoint_path>', # set checkpoint_path
# Graph structure
entities={"all": {"num_partitions": 20}},
relations=[
{
"name": "all_edges",
"lhs": "all",
"rhs": "all",
"operator": "complex_diagonal",
}
],
dynamic_relations=True,
verbose=1,
# Scoring model
dimension=100,
batch_size=1000,
workers=10,
global_emb=False,
# Training
num_epochs=25,
num_machines=10,
num_uniform_negs=100,
num_batch_negs=50,
comparator='cos',
loss_fn='softmax',
distributed_init_method='env://',
lr=0.02,
eval_fraction=0.01 # fraction of edges withheld from training and used for evaluation
)

return config

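
Assuming the standard torchbiggraph_train entry point (launching it is not covered by this diff), each of the num_machines workers in the cluster would then start training with a command along the lines of:

    torchbiggraph_train torchbiggraph/examples/configs/distributedCluster_config.py --rank <machine_rank>

with <machine_rank> running from 0 to num_machines - 1.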