Skip to content

Commit

Permalink
feat(gfql): chain serialization
Browse files Browse the repository at this point in the history
  • Loading branch information
lmeyerov committed Dec 21, 2023
1 parent 90851d6 commit ece0924
Show file tree
Hide file tree
Showing 20 changed files with 856 additions and 11 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm
## [Development]

### Added

* GFQL query serialization: `graphistry.compute.from_json(graphistry.compute.to_json([...]))`
* GFQL predicate `is_year_end`

### Docs
Expand Down
144 changes: 143 additions & 1 deletion graphistry/compute/ast.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from abc import abstractmethod
import logging
from typing import Dict, Optional, Union, cast
from typing_extensions import Literal
import pandas as pd

from graphistry.Plottable import Plottable
from graphistry.util import setup_logger
from graphistry.util import is_json_serializable, setup_logger
from .predicates.ASTPredicate import ASTPredicate
from .predicates.is_in import (
is_in, IsIn
Expand Down Expand Up @@ -66,12 +68,43 @@ def __init__(self, name: Optional[str] = None):
self._name = name
pass

@abstractmethod
def __call__(
    self,
    g: Plottable,
    prev_node_wavefront: Optional[pd.DataFrame],
    target_wave_front: Optional[pd.DataFrame]
) -> Plottable:
    """Abstract call hook; concrete subclasses are expected to override it.

    Raises:
        RuntimeError: always, in this base implementation.
    """
    raise RuntimeError('__call__ not implemented')

@abstractmethod
def reverse(self) -> 'ASTObject':
    """Abstract hook returning an ASTObject; subclasses must override.

    Raises:
        RuntimeError: always, in this base implementation.
    """
    raise RuntimeError('reverse not implemented')

@abstractmethod
def to_json(self, validate=True) -> dict:
    """Abstract serializer to a dict; subclasses must override.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    raise NotImplementedError()

def validate(self) -> None:
    """Default validation hook: checks nothing and returns None."""
    return None


##############################################################################


def assert_record_match(d: Dict) -> None:
    """Assert that ``d`` is a dict of string keys to predicate or JSON-able values.

    Each value must be an ASTPredicate instance or pass
    ``is_json_serializable``.

    Raises:
        AssertionError: if ``d`` is not a dict, a key is not a str, or a
            value is neither a predicate nor JSON-serializable.
    """
    assert isinstance(d, dict)
    for field, expected in d.items():
        assert isinstance(field, str)
        if not isinstance(expected, ASTPredicate):
            assert is_json_serializable(expected)

def maybe_filter_dict_from_json(d: Dict, key: str) -> Optional[Dict]:
    """Read the optional filter dict stored under ``key`` in a serialized AST object.

    Dict-valued entries are treated as serialized predicates and rebuilt;
    all other values are passed through unchanged.

    Args:
        d: Serialized AST object.
        key: Name of the filter-dict field to read from ``d``.

    Returns:
        The decoded filter dict, or None when ``key`` is absent or maps to None.

    Raises:
        ValueError: if ``d[key]`` is present but is neither a dict nor None.
    """
    raw = d.get(key)
    if raw is None:
        return None
    if not isinstance(raw, dict):
        raise ValueError('filter_dict must be a dict or None')

    def decode(value):
        if not isinstance(value, dict):
            return value
        # ASTPredicate.from_json on the base class only raises
        # NotImplementedError, so dispatch on the 'type' field through the
        # predicate registry instead. Imported at call time so this function
        # does not depend on a new module-level import.
        from graphistry.compute.predicates.from_json import from_json as predicate_from_json
        return predicate_from_json(value)

    return {k: decode(v) for k, v in raw.items()}

##############################################################################

Expand All @@ -91,6 +124,36 @@ def __init__(self, filter_dict: Optional[dict] = None, name: Optional[str] = Non

def __repr__(self) -> str:
    """Return a debug string showing the node's filter dict and name."""
    text = f'ASTNode(filter_dict={self._filter_dict}, name={self._name})'
    return text

def validate(self) -> None:
    """Assert that the node's stored fields have the expected types.

    The filter dict, when set, must satisfy ``assert_record_match``; the
    name and query, when set, must be strings.

    Raises:
        AssertionError: if any check fails.
    """
    filter_dict = self._filter_dict
    if filter_dict is not None:
        assert_record_match(filter_dict)
    for text in (self._name, self._query):
        assert text is None or isinstance(text, str)

def to_json(self, validate=True) -> dict:
    """Serialize this node to a JSON-style dict.

    Args:
        validate: When true, run ``self.validate()`` before serializing, as
            the other ``to_json`` implementations do; previously the flag
            was accepted here but ignored.

    Returns:
        A dict with ``'type': 'Node'`` and a ``'filter_dict'`` entry (empty
        when no filter is set). Filter entries whose value is None are
        dropped and predicate values are serialized via their ``to_json``.
        ``'name'`` and ``'query'`` are included only when set.
    """
    if validate:
        self.validate()
    return {
        'type': 'Node',
        'filter_dict': {
            k: v.to_json() if isinstance(v, ASTPredicate) else v
            for k, v in self._filter_dict.items()
            if v is not None
        } if self._filter_dict is not None else {},
        **({'name': self._name} if self._name is not None else {}),
        **({'query': self._query} if self._query is not None else {})
    }

@classmethod
def from_json(cls, d: dict) -> 'ASTNode':
    """Rebuild an ASTNode from its serialized dict and validate it.

    Missing ``'name'`` / ``'query'`` keys are passed to the constructor as
    None; ``'filter_dict'`` is decoded by ``maybe_filter_dict_from_json``.
    """
    node = ASTNode(
        filter_dict=maybe_filter_dict_from_json(d, 'filter_dict'),
        name=d.get('name'),
        query=d.get('query')
    )
    node.validate()
    return node

def __call__(self, g: Plottable, prev_node_wavefront: Optional[pd.DataFrame], target_wave_front: Optional[pd.DataFrame]) -> Plottable:
out_g = (g
Expand Down Expand Up @@ -170,6 +233,71 @@ def __init__(
def __repr__(self) -> str:
    """Return a debug string listing the edge's configured fields."""
    text = f'ASTEdge(direction={self._direction}, edge_match={self._edge_match}, hops={self._hops}, to_fixed_point={self._to_fixed_point}, source_node_match={self._source_node_match}, destination_node_match={self._destination_node_match}, name={self._name}, source_node_query={self._source_node_query}, destination_node_query={self._destination_node_query}, edge_query={self._edge_query})'
    return text

def validate(self) -> None:
    """Assert that the edge's stored fields have the expected types and values.

    Checks hops (int or None), to_fixed_point (bool), direction (one of
    'forward', 'reverse', 'undirected'), the three optional match dicts
    (via ``assert_record_match``) and the optional name and query strings.

    Raises:
        AssertionError: if any check fails.
    """
    assert self._hops is None or isinstance(self._hops, int)
    assert isinstance(self._to_fixed_point, bool)
    assert self._direction in ['forward', 'reverse', 'undirected']

    matches = (
        self._source_node_match,
        self._edge_match,
        self._destination_node_match,
    )
    for match in matches:
        if match is not None:
            assert_record_match(match)

    texts = (
        self._name,
        self._source_node_query,
        self._destination_node_query,
        self._edge_query,
    )
    for text in texts:
        assert text is None or isinstance(text, str)

def to_json(self, validate=True) -> dict:
    """Serialize this edge to a JSON-style dict.

    Args:
        validate: When true, run ``self.validate()`` first.

    Returns:
        A dict with 'type', 'hops', 'to_fixed_point' and 'direction', plus
        each match dict, the name and each query string only when set.
        Match entries whose value is None are dropped and predicate values
        are serialized via their ``to_json``.
    """
    if validate:
        self.validate()

    def serialize_match(match: dict) -> dict:
        # Shared encoding for the three optional match dicts.
        return {
            k: v.to_json() if isinstance(v, ASTPredicate) else v
            for k, v in match.items()
            if v is not None
        }

    out = {
        'type': 'Edge',
        'hops': self._hops,
        'to_fixed_point': self._to_fixed_point,
        'direction': self._direction,
    }

    optional_matches = (
        ('source_node_match', self._source_node_match),
        ('edge_match', self._edge_match),
        ('destination_node_match', self._destination_node_match),
    )
    for field, match in optional_matches:
        if match is not None:
            out[field] = serialize_match(match)

    optional_values = (
        ('name', self._name),
        ('source_node_query', self._source_node_query),
        ('destination_node_query', self._destination_node_query),
        ('edge_query', self._edge_query),
    )
    for field, value in optional_values:
        if value is not None:
            out[field] = value

    return out

@classmethod
def from_json(cls, d: dict) -> 'ASTEdge':
    """Rebuild an ASTEdge from its serialized dict and validate it.

    Every field missing from ``d`` is passed to the constructor as None;
    the three match dicts are decoded by ``maybe_filter_dict_from_json``.
    """
    edge = ASTEdge(
        direction=d.get('direction'),
        edge_match=maybe_filter_dict_from_json(d, 'edge_match'),
        hops=d.get('hops'),
        to_fixed_point=d.get('to_fixed_point'),
        source_node_match=maybe_filter_dict_from_json(d, 'source_node_match'),
        destination_node_match=maybe_filter_dict_from_json(d, 'destination_node_match'),
        source_node_query=d.get('source_node_query'),
        destination_node_query=d.get('destination_node_query'),
        edge_query=d.get('edge_query'),
        name=d.get('name')
    )
    edge.validate()
    return edge

def __call__(self, g: Plottable, prev_node_wavefront: Optional[pd.DataFrame], target_wave_front: Optional[pd.DataFrame]) -> Plottable:

if logger.isEnabledFor(logging.DEBUG):
Expand Down Expand Up @@ -316,3 +444,17 @@ def __init__(self,

# Short module-level aliases, both bound to the ASTEdgeUndirected class.
e_undirected = ASTEdgeUndirected # noqa: E305
e = ASTEdgeUndirected # noqa: E305

###

def from_json(o: Dict) -> Union[ASTNode, ASTEdge]:
    """Rebuild a node or edge AST object from its serialized dict.

    Dispatches on ``o['type']``: 'Node' goes to ``ASTNode.from_json`` and
    'Edge' to ``ASTEdge.from_json``.

    Raises:
        AssertionError: if ``o`` is not a dict or has no 'type' key.
        ValueError: if the 'type' value is not recognised.
    """
    assert isinstance(o, dict)
    assert 'type' in o
    kind = o['type']
    if kind == 'Node':
        return ASTNode.from_json(o)
    if kind == 'Edge':
        return ASTEdge.from_json(o)
    raise ValueError(f'Unknown type {o["type"]}')
19 changes: 17 additions & 2 deletions graphistry/compute/chain.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import cast, List, Tuple
from typing import Dict, cast, List, Tuple
import pandas as pd

from graphistry.Plottable import Plottable
from graphistry.util import setup_logger
from .ast import ASTObject, ASTNode, ASTEdge
from .ast import ASTObject, ASTNode, ASTEdge, from_json as ASTObject_from_json

logger = setup_logger(__name__)

Expand Down Expand Up @@ -253,3 +253,18 @@ def chain(self: Plottable, ops: List[ASTObject]) -> Plottable:
g_out = g.nodes(final_nodes_df).edges(final_edges_df)

return g_out

###

def from_json(d: List[Dict]) -> List[ASTObject]:
    """
    Convert a JSON AST into a list of ASTObjects

    Args:
        d: List of serialized AST objects, one dict per operation. The
            annotation is List[Dict] because the body asserts a list and
            iterates it; the previous ``Dict`` hint contradicted that.

    Returns:
        One ASTObject per input entry, in the same order.

    Raises:
        AssertionError: if ``d`` is not a list.
    """
    assert isinstance(d, list)
    return [ASTObject_from_json(op) for op in d]

def to_json(ops: List[ASTObject]) -> List[Dict]:
    """
    Serialize each ASTObject in ``ops`` with its ``to_json`` method, preserving order
    """
    serialized: List[Dict] = []
    for op in ops:
        serialized.append(op.to_json())
    return serialized
11 changes: 11 additions & 0 deletions graphistry/compute/predicates/ASTPredicate.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,15 @@ class ASTPredicate():

@abstractmethod
def __call__(self, s: pd.Series) -> pd.Series:
    """Abstract predicate hook taking and returning a Series; subclasses must override.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    raise NotImplementedError()

@abstractmethod
def to_json(self, validate=True) -> dict:
    """Abstract serializer to a dict; subclasses must override.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    raise NotImplementedError()

@classmethod
def from_json(cls, d: dict) -> 'ASTPredicate':
    """Placeholder deserializer; the base implementation is not usable.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError()

def validate(self) -> None:
    """Default validation hook: checks nothing and returns None."""
    return None
15 changes: 15 additions & 0 deletions graphistry/compute/predicates/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,21 @@ def __init__(self, keep: Literal['first', 'last', False] = 'first') -> None:
def __call__(self, s: pd.Series) -> pd.Series:
    """Return ``s.duplicated(...)`` using this predicate's ``keep`` setting."""
    keep = self.keep
    return s.duplicated(keep=keep)

def validate(self) -> None:
    """Assert that ``keep`` is one of 'first', 'last' or False.

    Raises:
        AssertionError: if ``keep`` has any other value.
    """
    allowed = ['first', 'last', False]
    assert self.keep in allowed

def to_json(self, validate=True) -> dict:
    """Serialize to ``{'type': 'Duplicated', 'keep': ...}``.

    Args:
        validate: When true, run ``self.validate()`` first.
    """
    if validate:
        self.validate()
    payload = {'type': 'Duplicated'}
    payload['keep'] = self.keep
    return payload

@classmethod
def from_json(cls, d: dict) -> 'Duplicated':
    """Rebuild a Duplicated predicate from its serialized dict and validate it.

    Raises:
        AssertionError: if 'keep' is missing from ``d`` or validation fails.
    """
    assert 'keep' in d
    predicate = Duplicated(keep=d['keep'])
    predicate.validate()
    return predicate

def duplicated(keep: Literal['first', 'last', False] = 'first') -> Duplicated:
"""
Return whether a given value is duplicated
Expand Down
37 changes: 37 additions & 0 deletions graphistry/compute/predicates/from_json.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from typing import Dict, List, Type

from graphistry.compute.predicates.ASTPredicate import ASTPredicate
from graphistry.compute.predicates.categorical import Duplicated
from graphistry.compute.predicates.is_in import IsIn
from graphistry.compute.predicates.numeric import GT, LT, GE, LE, EQ, NE, Between, IsNA, NotNA
from graphistry.compute.predicates.str import (
Contains, Startswith, Endswith, Match, IsNumeric, IsAlpha, IsDecimal, IsDigit, IsLower, IsUpper,
IsSpace, IsAlnum, IsTitle, IsNull, NotNull
)
from graphistry.compute.predicates.temporal import (
IsMonthStart, IsMonthEnd, IsQuarterStart, IsQuarterEnd,
IsYearStart, IsYearEnd, IsLeapYear
)

# Every concrete predicate class that can be rebuilt from its JSON form.
# Each class appears once (IsDecimal was previously listed twice).
predicates : List[Type[ASTPredicate]] = [
    Duplicated,
    IsIn,
    GT, LT, GE, LE, EQ, NE, Between, IsNA, NotNA,
    Contains, Startswith, Endswith, Match, IsNumeric, IsAlpha, IsDecimal, IsDigit, IsLower, IsUpper,
    IsSpace, IsAlnum, IsTitle, IsNull, NotNull,
    IsMonthStart, IsMonthEnd, IsQuarterStart, IsQuarterEnd,
    IsYearStart, IsYearEnd, IsLeapYear
]

# Lookup keyed by class name; from_json matches it against the 'type' field
# of a serialized predicate.
type_to_predicate: Dict[str, Type[ASTPredicate]] = {
    cls.__name__: cls
    for cls in predicates
}

def from_json(d: Dict) -> ASTPredicate:
    """Rebuild a predicate from its serialized dict.

    Looks up the class registered under ``d['type']`` in
    ``type_to_predicate``, delegates to that class's ``from_json`` and
    validates the result.

    Raises:
        AssertionError: if ``d`` is not a dict, has no 'type' key, or names
            an unregistered type.
    """
    assert isinstance(d, dict)
    assert 'type' in d
    kind = d['type']
    assert kind in type_to_predicate
    predicate = type_to_predicate[kind].from_json(d)
    predicate.validate()
    return predicate
21 changes: 21 additions & 0 deletions graphistry/compute/predicates/is_in.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import Any, List
import pandas as pd

from graphistry.util import assert_json_serializable

from .ASTPredicate import ASTPredicate


Expand All @@ -10,6 +12,25 @@ def __init__(self, options: List[Any]) -> None:

def __call__(self, s: pd.Series) -> pd.Series:
    """Return ``s.isin(...)`` against this predicate's ``options``."""
    options = self.options
    return s.isin(options)

def validate(self) -> None:
    """Assert that ``options`` is a list and passes ``assert_json_serializable``.

    Raises:
        AssertionError: if ``options`` is not a list.
    """
    options = self.options
    assert isinstance(options, list)
    assert_json_serializable(options)

def to_json(self, validate=True) -> dict:
    """Serialize to ``{'type': 'IsIn', 'options': ...}``.

    Args:
        validate: When true, run ``self.validate()`` first.
    """
    if validate:
        self.validate()
    payload = {'type': 'IsIn'}
    payload['options'] = self.options
    return payload

@classmethod
def from_json(cls, d: dict) -> 'IsIn':
    """Rebuild an IsIn predicate from its serialized dict and validate it.

    Raises:
        AssertionError: if 'options' is missing from ``d`` or validation fails.
    """
    assert 'options' in d
    predicate = IsIn(options=d['options'])
    predicate.validate()
    return predicate

def is_in(options: List[Any]) -> IsIn:
    """Build an IsIn predicate over ``options``."""
    predicate = IsIn(options)
    return predicate
Loading

0 comments on commit ece0924

Please sign in to comment.