Commit f76e1f5

WIP: Modify level construction.
1 parent 198b934 commit f76e1f5

File tree

10 files changed: +399 −282 lines changed


.github/workflows/publish.yml

+1 −1

@@ -1,4 +1,4 @@
-name: Publish
+name: Publish
 on:
   workflow_dispatch:
 jobs:

.gitignore

+1

@@ -50,6 +50,7 @@ coverage.xml
 .hypothesis/
 .pytest_cache/
 cover/
+junit/
 
 # Translations
 *.mo
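For context, the new junit/ entry presumably ignores JUnit-style test reports; an assumed invocation such as `pytest --junitxml=junit/results.xml` in CI would write into this directory.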

poetry.lock

+157 −160

Some generated files are not rendered by default.

src/finch/__init__.py

−13

@@ -1,16 +1,3 @@
-from .levels import (
-    Dense,
-    Element,
-    Pattern,
-    SparseList,
-    SparseByteMap,
-    RepeatRLE,
-    SparseVBL,
-    SparseCOO,
-    SparseHash,
-    Storage,
-    DenseStorage,
-)
 from .tensor import (
     Tensor,
     astype,
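With this removal, the level types are no longer re-exported from the package root; downstream code would import them from the submodule instead. A minimal sketch, assuming the module layout in this commit:

```python
# Hypothetical usage after this commit: level types live only in
# finch.levels (Dense, Element, Pattern, SparseList survive this refactor;
# Storage and DenseStorage are removed entirely).
from finch.levels import Dense, Element, Pattern, SparseList
```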

src/finch/formats.py

+84

@@ -0,0 +1,84 @@
+import abc
+import typing
+
+from .julia import jl
+from . import levels
+from . import utils
+from dataclasses import dataclass
+
+
+@dataclass
+class Format:
+    levels: tuple[levels.AbstractLevel, ...]
+    order: tuple[int, ...]
+    leaf: levels.AbstractLeafLevel
+
+    def __init__(
+        self,
+        *,
+        levels: tuple[levels.AbstractLevel, ...],
+        order: tuple[int, ...] | None,
+        leaf: levels.AbstractLeafLevel,
+    ) -> None:
+        if order is None:
+            order = tuple(range(len(levels)))
+
+        utils.check_valid_order(order, ndim=len(levels))
+        self.order = order
+        self.levels = levels
+        self.leaf = leaf
+
+    def _construct(self, *, fill_value, dtype: jl.DataType, data=None):
+        out_level = self.leaf._construct(dtype=dtype, fill_value=fill_value)
+        for level in reversed(self.levels):
+            out_level = level._construct(inner_level=out_level)
+
+        swizzle_args = map(lambda x: x + 1, reversed(self.order))
+        if data is None:
+            return jl.swizzle(jl.Tensor(out_level), *swizzle_args)
+
+        return jl.swizzle(jl.Tensor(out_level, data), *swizzle_args)
+
+class FlexibleFormat(abc.ABC):
+    def _construct(self, *, ndim: int, fill_value, dtype: jl.DataType, data=None):
+        return self._get_format(ndim)._construct(fill_value=fill_value, dtype=dtype, data=data)
+
+    @abc.abstractmethod
+    def _get_format(self, ndim: int, /) -> Format:
+        pass
+
+@dataclass
+class Dense(FlexibleFormat):
+    order: typing.Literal["C", "F"] | tuple[int, ...] = "C"
+    shape: tuple[int | None, ...] | None = None
+
+    def __post_init__(self) -> None:
+        if isinstance(self.order, tuple):
+            utils.check_valid_order(self.order)
+
+            if self.shape is not None and len(self.order) != len(self.shape):
+                raise ValueError(f"len(self.order) != len(self.shape), {self.order}, {self.shape}")
+
+    def _get_format(self, ndim: int) -> Format:
+        super()._get_format(ndim)
+        match self.order:
+            case "C":
+                order = tuple(range(ndim))
+            case "F":
+                order = tuple(reversed(range(ndim)))
+            case _:
+                order = self.order
+
+        utils.check_valid_order(order, ndim=ndim)
+
+        shape = self.shape
+        if shape is None:
+            shape = (None,) * ndim
+
+        if len(shape) != ndim:
+            raise ValueError(f"len(self.shape) != ndim, {shape=}, {ndim=}")
+
+        topological_shape = utils.get_topological_shape(shape, order=order)
+        lvls = tuple(levels.Dense(dim=dim) for dim in topological_shape)
+
+        return Format(levels=lvls, order=order, leaf=levels.Element())
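To make the new API concrete, here is a minimal usage sketch of the Format/FlexibleFormat split added above. It assumes this commit is installed with a working Julia backend behind `finch.julia.jl`, and that `finch.dtypes` exposes a `float64` member (an assumption; only `int64`, `uint64`, and `bool` appear in this diff):

```python
from finch import dtypes
from finch.formats import Dense, Format

# Dense is a FlexibleFormat: it fixes only the order/shape policy and is
# specialized into a concrete Format once the dimensionality is known.
fmt = Dense(order="F")                  # column-major policy, shape left open
concrete: Format = fmt._get_format(2)   # two levels.Dense levels over Element

# _construct builds the Julia level tree leaf-out (Element -> Dense -> Dense),
# wraps it in jl.Tensor, and swizzles with 1-based indices in reversed order.
# dtypes.float64 is assumed to exist alongside dtypes.int64/uint64.
tensor = fmt._construct(ndim=2, fill_value=0.0, dtype=dtypes.float64)
```

The `map(lambda x: x + 1, reversed(self.order))` in `Format._construct` shifts the zero-based Python axis order to Julia's one-based, innermost-first convention.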

src/finch/levels.py

+43 −66

@@ -1,93 +1,70 @@
-import numpy as np
+import abc
+
 
 from .julia import jl
-from .typing import OrderType
+from . import dtypes
+from dataclasses import dataclass
 
 
-class _Display:
+class _Display(abc.ABC):
     def __repr__(self):
         return jl.sprint(jl.show, self._obj)
 
     def __str__(self):
         return jl.sprint(jl.show, jl.MIME("text/plain"), self._obj)
 
 
-# LEVEL
+class AbstractLeafLevel(abc.ABC):
+    @abc.abstractmethod
+    def _construct(self, *, dtype, fill_value):
+        ...
+
 
-class AbstractLevel(_Display):
-    pass
+# LEVEL
+class AbstractLevel(abc.ABC):
+    @abc.abstractmethod
+    def _construct(self, *, inner_level):
+        ...
 
 
 # core levels
-
+@dataclass
 class Dense(AbstractLevel):
-    def __init__(self, lvl, shape=None):
-        args = [lvl._obj]
-        if shape is not None:
-            args.append(shape)
-        self._obj = jl.Dense(*args)
-
-
-class Element(AbstractLevel):
-    def __init__(self, fill_value, data=None):
-        args = [fill_value]
-        if data is not None:
-            args.append(data)
-        self._obj = jl.Element(*args)
-
-
-class Pattern(AbstractLevel):
-    def __init__(self):
-        self._obj = jl.Pattern()
-
-
-# advanced levels
-
-class SparseList(AbstractLevel):
-    def __init__(self, lvl):
-        self._obj = jl.SparseList(lvl._obj)
-
-
-class SparseByteMap(AbstractLevel):
-    def __init__(self, lvl):
-        self._obj = jl.SparseByteMap(lvl._obj)
-
+    dim: int | None = None
+    index_type: jl.DataType = dtypes.int64
 
-class RepeatRLE(AbstractLevel):
-    def __init__(self, lvl):
-        self._obj = jl.RepeatRLE(lvl._obj)
+    def _construct(self, *, inner_level) -> jl.Dense:
+        if self.dim is None:
+            return jl.Dense[self.index_type](inner_level)
 
+        return jl.Dense[self.index_type](inner_level, self.dim)
 
-class SparseVBL(AbstractLevel):
-    def __init__(self, lvl):
-        self._obj = jl.SparseVBL(lvl._obj)
 
+@dataclass
+class Element(AbstractLeafLevel):
+    def _construct(self, *, dtype: jl.DataType, fill_value) -> jl.Element:
+        return jl.Element[fill_value, dtype]()
 
-class SparseCOO(AbstractLevel):
-    def __init__(self, ndim, lvl):
-        self._obj = jl.SparseCOO[ndim](lvl._obj)
 
+@dataclass
+class Pattern(AbstractLeafLevel):
+    def _construct(self, *, dtype, fill_value) -> jl.Pattern:
+        from .dtypes import bool
 
-class SparseHash(AbstractLevel):
-    def __init__(self, ndim, lvl):
-        self._obj = jl.SparseHash[ndim](lvl._obj)
+        if dtype != bool:
+            raise TypeError("`Pattern` can only have `dtype=bool`.")
+        if dtype(fill_value) != dtype(False):
+            raise TypeError("`Pattern` can only have `fill_value=False`.")
 
+        return jl.Pattern()
 
-# STORAGE
 
-class Storage:
-    def __init__(self, levels_descr: AbstractLevel, order: OrderType = None):
-        self.levels_descr = levels_descr
-        self.order = order if order is not None else "C"
-
-    def __str__(self) -> str:
-        return f"Storage(lvl={str(self.levels_descr)}, order={self.order})"
-
-
-class DenseStorage(Storage):
-    def __init__(self, ndim: int, dtype: np.dtype, order: OrderType = None):
-        lvl = Element(np.int_(0).astype(dtype))
-        for _ in range(ndim):
-            lvl = Dense(lvl)
+# advanced levels
+@dataclass
+class SparseList(AbstractLevel):
+    index_type: jl.DataType = dtypes.int64
+    pos_type: jl.DataType = dtypes.uint64
+    crd_type: jl.DataType = dtypes.uint64
 
-        super().__init__(levels_descr=lvl, order=order)
+    def _construct(self, *, inner_level) -> jl.SparseList:
+        return jl.SparseList[self.index_type, self.pos_type, self.crd_type](inner_level)
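As a rough illustration of the reworked levels: each dataclass merely describes a level, and the Julia object is produced by chaining `_construct` calls from the leaf outward, exactly as `Format._construct` in formats.py does. A sketch under the same `dtypes.float64` assumption as above:

```python
from finch import dtypes, levels

# Describe a CSR-like chain: a Dense outer dimension over a SparseList of
# float64 elements. No Julia object exists until _construct is called.
leaf = levels.Element()
chain = (levels.Dense(), levels.SparseList())

# Build leaf-out: jl.Dense[...](jl.SparseList[...](jl.Element[0.0, float64]()))
lvl = leaf._construct(dtype=dtypes.float64, fill_value=0.0)
for level in reversed(chain):
    lvl = level._construct(inner_level=lvl)
```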
