From 7f3f390925ce0b290629aed76941a84a14666d8d Mon Sep 17 00:00:00 2001
From: Tom White
Date: Fri, 6 Oct 2023 17:00:43 +0100
Subject: [PATCH] TensorStore

---
 .github/workflows/tensorstore-tests.yml |  47 +++++++
 cubed/storage/backend.py                |   4 +
 cubed/storage/backends/tensorstore.py   | 155 ++++++++++++++++++++++++
 setup.cfg                               |   2 +
 4 files changed, 208 insertions(+)
 create mode 100644 .github/workflows/tensorstore-tests.yml
 create mode 100644 cubed/storage/backends/tensorstore.py

diff --git a/.github/workflows/tensorstore-tests.yml b/.github/workflows/tensorstore-tests.yml
new file mode 100644
index 00000000..f6a1d914
--- /dev/null
+++ b/.github/workflows/tensorstore-tests.yml
@@ -0,0 +1,47 @@
+name: TensorStore tests
+
+on:
+  schedule:
+    # Every weekday at 03:58 UTC, see https://crontab.guru/
+    - cron: "58 3 * * 1-5"
+  workflow_dispatch:
+
+concurrency:
+  group: ${{ github.workflow }}-${{ github.ref }}
+  cancel-in-progress: true
+
+jobs:
+  test:
+    runs-on: ${{ matrix.os }}
+    strategy:
+      fail-fast: false
+      matrix:
+        os: ["ubuntu-latest"]
+        python-version: ["3.11"]
+
+    steps:
+      - name: Checkout source
+        uses: actions/checkout@v3
+        with:
+          fetch-depth: 0
+
+      - name: Set up Python
+        uses: actions/setup-python@v3
+        with:
+          python-version: ${{ matrix.python-version }}
+          architecture: x64
+
+      - name: Setup Graphviz
+        uses: ts-graphviz/setup-graphviz@v2
+
+      - name: Install
+        run: |
+          python -m pip install --upgrade pip
+          python -m pip install -e '.[test]' 'tensorstore'
+
+      - name: Run tests
+        run: |
+          # exclude tests that rely on the nchunks_initialized array attribute
+          pytest -k "not test_resume"
+        env:
+          CUBED_STORAGE_NAME: tensorstore
diff --git a/cubed/storage/backend.py b/cubed/storage/backend.py
index 2bb19168..a341a398 100644
--- a/cubed/storage/backend.py
+++ b/cubed/storage/backend.py
@@ -22,6 +22,10 @@ def open_backend_array(
         from cubed.storage.backends.zarr_python import open_zarr_array
 
         open_func = open_zarr_array
+    elif storage_name == "tensorstore":
+        from cubed.storage.backends.tensorstore import open_tensorstore_array
+
+        open_func = open_tensorstore_array
     else:
         raise ValueError(f"Unrecognized storage name: {storage_name}")
 
diff --git a/cubed/storage/backends/tensorstore.py b/cubed/storage/backends/tensorstore.py
new file mode 100644
index 00000000..efd2adad
--- /dev/null
+++ b/cubed/storage/backends/tensorstore.py
@@ -0,0 +1,155 @@
+import dataclasses
+import math
+from typing import Any, Dict, Optional
+
+import numpy as np
+import tensorstore
+
+from cubed.types import T_DType, T_RegularChunks, T_Shape, T_Store
+from cubed.utils import join_path
+
+
+@dataclasses.dataclass(frozen=True)
+class TensorStoreArray:
+    array: tensorstore.TensorStore
+
+    @property
+    def shape(self) -> tuple[int, ...]:
+        return self.array.shape
+
+    @property
+    def dtype(self) -> np.dtype:
+        return self.array.dtype.numpy_dtype
+
+    @property
+    def chunks(self) -> tuple[int, ...]:
+        return self.array.chunk_layout.read_chunk.shape or ()
+
+    @property
+    def ndim(self) -> int:
+        return len(self.shape)
+
+    @property
+    def size(self) -> int:
+        return math.prod(self.shape)
+
+    @property
+    def oindex(self):
+        return self.array.oindex
+
+    def __getitem__(self, key):
+        # read eagerly
+        return self.array.__getitem__(key).read().result()
+
+    def __setitem__(self, key, value):
+        self.array.__setitem__(key, value)
+
+
+class TensorStoreGroup(dict):
+    def __init__(
+        self,
+        shape: Optional[T_Shape] = None,
+        dtype: Optional[T_DType] = None,
+        chunks: Optional[T_RegularChunks] = None,
+    ):
+        dict.__init__(self)
+        self.shape = shape
+        self.dtype = dtype
+        self.chunks = chunks
+
+    def __getitem__(self, key):
+        if isinstance(key, str):
+            return super().__getitem__(key)
+        return {field: zarray[key] for field, zarray in self.items()}
+
+    def set_basic_selection(self, selection, value, fields=None):
+        self[fields][selection] = value
+
+
+def encode_dtype(d):
+    if d.fields is None:
+        return d.str
+    else:
+        return d.descr
+
+
+def get_metadata(dtype, chunks):
+    metadata = {}
+    if dtype is not None:
+        dtype = np.dtype(dtype)
+        metadata["dtype"] = encode_dtype(dtype)
+    if chunks is not None:
+        if isinstance(chunks, int):
+            chunks = (chunks,)
+        metadata["chunks"] = chunks
+    return metadata
+
+
+def open_tensorstore_array(
+    store: T_Store,
+    mode: str,
+    *,
+    shape: Optional[T_Shape] = None,
+    dtype: Optional[T_DType] = None,
+    chunks: Optional[T_RegularChunks] = None,
+    path: Optional[str] = None,
+    **kwargs,
+):
+    store = str(store)  # TODO: check if Path or str
+
+    spec: Dict[str, Any]
+    if "://" in store:
+        spec = {"driver": "zarr", "kvstore": store}
+    else:
+        spec = {
+            "driver": "zarr",
+            "kvstore": {"driver": "file", "path": store},
+            "path": path or "",
+        }
+
+    if mode == "r":
+        open_kwargs = dict(read=True, open=True)
+    elif mode == "r+":
+        open_kwargs = dict(read=True, write=True, open=True)
+    elif mode == "a":
+        open_kwargs = dict(read=True, write=True, open=True, create=True)
+    elif mode == "w":
+        open_kwargs = dict(read=True, write=True, create=True, delete_existing=True)
+    elif mode == "w-":
+        open_kwargs = dict(read=True, write=True, create=True)
+    else:
+        raise ValueError(f"Mode not supported: {mode}")
+
+    if dtype is None or not hasattr(dtype, "fields") or dtype.fields is None:
+        metadata = get_metadata(dtype, chunks)
+        if metadata:
+            spec["metadata"] = metadata
+
+        return TensorStoreArray(
+            tensorstore.open(
+                spec,
+                shape=shape,
+                dtype=dtype,
+                **open_kwargs,
+            ).result()
+        )
+    else:
+        ret = TensorStoreGroup(shape=shape, dtype=dtype, chunks=chunks)
+        for field in dtype.fields:
+            field_path = field if path is None else join_path(path, field)
+            spec["path"] = field_path
+
+            field_dtype, _ = dtype.fields[field]
+            metadata = get_metadata(field_dtype, chunks)
+            if metadata:
+                spec["metadata"] = metadata
+
+            ret[field] = TensorStoreArray(
+                tensorstore.open(
+                    spec,
+                    shape=shape,
+                    dtype=field_dtype,
+                    **open_kwargs,
+                ).result()
+            )
+        return ret
diff --git a/setup.cfg b/setup.cfg
index 8606c1e8..76b6a455 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -64,6 +64,8 @@ ignore_missing_imports = True
 ignore_missing_imports = True
 [mypy-tenacity.*]
 ignore_missing_imports = True
+[mypy-tensorstore.*]
+ignore_missing_imports = True
 [mypy-tlz.*]
 ignore_missing_imports = True
 [mypy-toolz.*]
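
A minimal usage sketch (not part of the patch), assuming the patch is applied and the tensorstore package is installed; the store path and array values are illustrative. Within Cubed itself the backend is selected by setting CUBED_STORAGE_NAME=tensorstore, which routes through the dispatcher change in cubed/storage/backend.py, as the workflow above does for the test suite; here open_tensorstore_array is called directly.

import os
import tempfile

import numpy as np

from cubed.storage.backends.tensorstore import open_tensorstore_array

# Illustrative local store path; a kvstore URL (containing "://") also works.
store = os.path.join(tempfile.mkdtemp(), "a.zarr")

# Create a 4x4 float64 array with 2x2 chunks, write it, and read a block back.
arr = open_tensorstore_array(
    store,
    mode="w",
    shape=(4, 4),
    dtype=np.dtype("float64"),
    chunks=(2, 2),
)
arr[:, :] = np.arange(16, dtype="float64").reshape(4, 4)
print(arr[2:, 2:])  # __getitem__ reads eagerly and returns a NumPy array

Structured dtypes take the other branch of open_tensorstore_array and return a TensorStoreGroup, backed by one zarr array per field.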