-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathneo4j_arrow_client.py
208 lines (179 loc) · 7.75 KB
/
neo4j_arrow_client.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
from typing import Any, Dict, Iterable, Union, Tuple
from enum import Enum
import json, os, sys
import pyarrow as pa
import pyarrow.flight as flight
class ClientState(Enum):
    """Stages a Neo4jArrowClient passes through while loading a graph."""

    # Initial state assigned by Neo4jArrowClient.__init__.
    READY = "ready"

    # Entered after create_database / create_projection returns a result.
    FEEDING_NODES = "feeding_nodes"

    # Entered after nodes_done returns a result.
    FEEDING_EDGES = "feeding_edges"

    # Entered after edges_done returns a result.
    AWAITING_GRAPH = "awaiting_graph"

    # Note that the value is "done", not "graph_ready".
    GRAPH_READY = "done"
class Neo4jArrowClient():
    """Client that feeds nodes and relationships to a service over Arrow Flight.

    The instance carries a small state machine (``ClientState``) and every
    public step asserts that it is called in the expected state:

    ``READY`` -> ``create_database`` / ``create_projection`` -> ``FEEDING_NODES``
    -> ``nodes_done`` -> ``FEEDING_EDGES`` -> ``edges_done`` -> ``AWAITING_GRAPH``.

    The underlying ``FlightClient`` is built lazily (see ``_client``) and is
    dropped by ``__getstate__`` so that instances can be pickled.
    """

    def __init__(self, host: str, *, port: int, user: str,
                 password: str, tls: bool = False,
                 concurrency: int = 4, database: str, projection: str = None):
        """Store connection settings; no network connection is made here.

        Args:
            host: Flight server host name.
            port: Flight server port.
            user: User name for basic-token authentication.
            password: Password for basic-token authentication.
            tls: Use a gRPC+TLS location when true, plain gRPC/TCP otherwise.
            concurrency: Value sent as "concurrency" in the default configs
                built by ``create_database`` / ``create_projection``.
            database: Database name; also the write target when no
                projection is set.
            projection: Optional projection (graph) name; when set it is used
                as the write target instead of ``database``.
        """
        self.host = host
        self.port = port
        self.user = user
        self.password = password
        self.tls = tls
        self.client: flight.FlightClient = None
        self.call_opts = None
        self.database = database
        self.projection = projection
        self.concurrency = concurrency
        self.state = ClientState.READY

    def __str__(self):
        return f"Neo4jArrowClient{{{self.user}@{self.host}:{self.port}/{self.database}}}"

    def __getstate__(self):
        """Return the instance dict without the non-serializable Flight objects."""
        state = self.__dict__.copy()
        # The FlightClient and call options are rebuilt on demand by _client().
        state.pop("client", None)
        state.pop("call_opts", None)
        return state

    def copy(self):
        """Return a new client with the same settings and state.

        The copy does not share the FlightClient; it builds its own lazily.
        """
        client = Neo4jArrowClient(self.host, port=self.port, user=self.user,
                                  password=self.password,
                                  tls=self.tls, concurrency=self.concurrency,
                                  database=self.database, projection=self.projection)
        client.state = self.state
        return client

    def _client(self):
        """Lazy client construction to help pickle this class."""
        # After unpickling the "client" attribute is missing entirely, hence
        # getattr with a default rather than a plain attribute read.
        if not getattr(self, "client", None):
            self.call_opts = None
            if self.tls:
                location = flight.Location.for_grpc_tls(self.host, self.port)
            else:
                location = flight.Location.for_grpc_tcp(self.host, self.port)
            client = flight.FlightClient(location)
            if self.user and self.password:
                (header, token) = client.authenticate_basic_token(self.user, self.password)
                if header:
                    # Reuse the returned auth header on every subsequent call.
                    self.call_opts = flight.FlightCallOptions(
                        timeout=None, headers=[(header, token)]
                    )
            self.client = client
        return self.client

    def _graph_name(self) -> str:
        """Return the write target: the projection if set, else the database."""
        return self.database if self.projection is None else self.projection

    def _send_action(self, action: str, body: Dict[str, Any]) -> dict:
        """
        Communicates an Arrow Action message to the GDS Arrow Service.

        ``body`` is JSON-encoded as the action payload and the first result
        is JSON-decoded and returned. Any error is printed and re-raised.
        """
        client = self._client()
        try:
            payload = json.dumps(body).encode("utf-8")
            result = client.do_action(
                flight.Action(action, payload),
                options=self.call_opts
            )
            return json.loads(next(result).body.to_pybytes().decode())
        except Exception as e:
            print(f"send_action error: {e}")
            # Bare raise keeps the original traceback intact.
            raise

    def _write_table(self, desc: Dict[str, Any], table: "pa.Table",
                     mappingfn=None) -> Tuple[int, int]:
        """
        Write a PyArrow Table to the GDS Flight service.

        ``desc`` is JSON-encoded into the Flight command descriptor.
        Returns ``(rows, bytes)`` written, or ``(0, 0)`` if the write raised
        (the error is printed, not propagated).

        NOTE(review): ``mappingfn`` is accepted for signature parity with
        ``_write_batches`` but is not applied to the table - confirm whether
        it should be.
        """
        client = self._client()
        upload_descriptor = flight.FlightDescriptor.for_command(
            json.dumps(desc).encode("utf-8")
        )
        writer, _ = client.do_put(upload_descriptor, table.schema, options=self.call_opts)
        with writer:
            try:
                writer.write_table(table)
                return table.num_rows, table.get_total_buffer_size()
            except Exception as e:
                print(f"_write_table error: {e}")
        return 0, 0

    @classmethod
    def _nop(cls, *args, **kwargs):
        """Identity mapping used when no mapping function is supplied.

        Returns the first positional argument unchanged (``None`` if there is
        none) so that ``fn(batch)`` yields the batch itself.
        """
        return args[0] if args else None

    def _write_batches(self, desc: Dict[str, Any], batches,
                       mappingfn=None) -> Tuple[int, int]:
        """
        Write PyArrow RecordBatches to the GDS Flight service.

        Each batch is passed through ``mappingfn`` (identity when omitted)
        before being written; the schema is taken from the first mapped batch.
        Returns ``(rows, bytes)`` of the batches written so far; if a write
        raises, the error is printed and the partial totals are returned.

        Raises:
            Exception: if ``batches`` yields nothing.
        """
        batches = iter(batches)
        fn = mappingfn or self._nop

        # Check for exhaustion before mapping so the mapping function is
        # never handed None, and a zero-row batch is not mistaken for "empty".
        head = next(batches, None)
        if head is None:
            raise Exception("empty iterable of record batches provided")
        first = fn(head)

        client = self._client()
        upload_descriptor = flight.FlightDescriptor.for_command(
            json.dumps(desc).encode("utf-8")
        )
        rows, nbytes = 0, 0
        writer, reader = client.do_put(upload_descriptor, first.schema, options=self.call_opts)
        with writer:
            try:
                writer.write_batch(first)
                rows += first.num_rows
                nbytes += first.get_total_buffer_size()
                for remaining in batches:
                    # Count what was actually written, i.e. the mapped batch.
                    mapped = fn(remaining)
                    writer.write_batch(mapped)
                    rows += mapped.num_rows
                    nbytes += mapped.get_total_buffer_size()
            except Exception as e:
                print(f"_write_batches error: {e}")
        return rows, nbytes

    def create_database(self, action: str = "CREATE_DATABASE",
                        config: Union[Dict[str, Any], None] = None) -> Dict[str, Any]:
        """Send the database-creation action and advance to FEEDING_NODES.

        When ``config`` is empty or omitted a default configuration is built
        from ``self.database`` and ``self.concurrency``.
        """
        assert self.state == ClientState.READY
        if not config:
            config = {
                "name": self.database,
                "concurrency": self.concurrency,
                "high_io": True,
                "force": True,
                "record_format": "aligned",
                "id_property": "id",
                "id_type": "INTEGER"
            }
        result = self._send_action(action, config)
        if result:
            self.state = ClientState.FEEDING_NODES
        return result

    def create_projection(self, action: str = "CREATE_GRAPH",
                          config: Union[Dict[str, Any], None] = None) -> Dict[str, Any]:
        """Send the graph-creation action and advance to FEEDING_NODES.

        When ``config`` is empty or omitted a default configuration is built
        from ``self.projection``, ``self.database`` and ``self.concurrency``.
        """
        assert self.state == ClientState.READY
        if not config:
            config = {
                "name": self.projection,
                "database_name": self.database,
                "concurrency": self.concurrency,
            }
        result = self._send_action(action, config)
        if result:
            self.state = ClientState.FEEDING_NODES
        return result

    def write_nodes(self, nodes: "Union[pa.Table, Iterable[pa.RecordBatch]]",
                    mappingfn=None) -> Tuple[int, int]:
        """Upload nodes (a Table or an iterable of RecordBatches).

        Returns the ``(rows, bytes)`` tuple from the underlying writer.
        """
        assert self.state == ClientState.FEEDING_NODES
        desc = {"name": self._graph_name(), "entity_type": "node"}
        if isinstance(nodes, pa.Table):
            return self._write_table(desc, nodes, mappingfn)
        return self._write_batches(desc, nodes, mappingfn)

    def nodes_done(self) -> Dict[str, Any]:
        """Signal that node loading is finished and advance to FEEDING_EDGES."""
        assert self.state == ClientState.FEEDING_NODES
        result = self._send_action("NODE_LOAD_DONE", {"name": self._graph_name()})
        if result:
            self.state = ClientState.FEEDING_EDGES
        return result

    def write_edges(self, edges: "Union[pa.Table, Iterable[pa.RecordBatch]]",
                    mappingfn=None) -> Tuple[int, int]:
        """Upload relationships (a Table or an iterable of RecordBatches).

        Returns the ``(rows, bytes)`` tuple from the underlying writer.
        """
        assert self.state == ClientState.FEEDING_EDGES
        desc = {"name": self._graph_name(), "entity_type": "relationship"}
        if isinstance(edges, pa.Table):
            return self._write_table(desc, edges, mappingfn)
        return self._write_batches(desc, edges, mappingfn)

    def edges_done(self) -> Dict[str, Any]:
        """Signal that relationship loading is finished; advance to AWAITING_GRAPH."""
        assert self.state == ClientState.FEEDING_EDGES
        result = self._send_action("RELATIONSHIP_LOAD_DONE",
                                   {"name": self._graph_name()})
        if result:
            self.state = ClientState.AWAITING_GRAPH
        return result

    def wait(self, timeout: int = 0):
        """wait for completion

        Currently a stub: it only checks the state and does not block or use
        ``timeout``.
        """
        assert self.state == ClientState.AWAITING_GRAPH
        # NOTE(review): this re-assigns the state it just asserted, and
        # nothing in this class ever sets ClientState.GRAPH_READY - confirm
        # whether GRAPH_READY was intended here.
        self.state = ClientState.AWAITING_GRAPH