#
# Copyright (c) 2012-2019 Snowflake Computing Inc. All rights reserved.
#
# cython: profile=False
# cython: language_level=3

from base64 import b64decode
import io
from logging import getLogger

from .telemetry import TelemetryField
from .time_util import get_time_millis
from .arrow_iterator import PyArrowIterator
from .arrow_iterator import EmptyPyArrowIterator
from .arrow_iterator import ROW_UNIT, TABLE_UNIT, EMPTY_UNIT
from .arrow_context import ArrowConverterContext
from .options import pandas, installed_pandas

logger = getLogger(__name__)

if installed_pandas:
    from pyarrow import concat_tables
else:
    logger.info("Failed to import optional package: pyarrow")


cdef class ArrowResult:
    cdef:
        object _cursor                 # owning cursor object
        object _connection             # connection the cursor belongs to
        int total_row_index            # running index of the last fetched row (-1 before any fetch)
        int _chunk_index               # index of the chunk currently being consumed
        int _chunk_count               # number of additional chunks to download
        int _current_chunk_row_count   # row count of the current chunk
        list _description              # column description metadata
        list _column_idx_to_name       # column names, ordered by column position
        object _current_chunk_row      # iterator over the current chunk (PyArrowIterator or EmptyPyArrowIterator)
        object _chunk_downloader       # downloader used to fetch the remaining chunks
        object _arrow_context          # Arrow conversion context built from session parameters
        str _iter_unit                 # ROW_UNIT, TABLE_UNIT, or EMPTY_UNIT
        object _use_dict_result        # return rows as dicts instead of tuples
        object _use_numpy              # convert fetched values to numpy types

    def __init__(self, raw_response, cursor, use_dict_result=False, _chunk_downloader=None):
        self._reset()
        self._cursor = cursor
        self._connection = cursor.connection
        self._use_dict_result = use_dict_result
        self._use_numpy = self._connection._numpy
        self._column_idx_to_name = []
        for idx, column in enumerate(raw_response.get(u'rowtype')):
            self._column_idx_to_name.append(column[u'name'])
        self._chunk_info(raw_response, _chunk_downloader)

    def _chunk_info(self, data, _chunk_downloader=None):
        self.total_row_index = -1  # last fetched number of rows

        self._chunk_index = 0
        self._chunk_count = 0
        # first result set, returned inline as an Arrow chunk
        rowset_b64 = data.get(u'rowsetBase64')
        if rowset_b64:
            arrow_bytes = b64decode(rowset_b64)
            self._arrow_context = ArrowConverterContext(self._connection._session_parameters)
            self._current_chunk_row = PyArrowIterator(self._cursor, io.BytesIO(arrow_bytes),
                                                      self._arrow_context, self._use_dict_result,
                                                      self._use_numpy)
        else:
            logger.debug("Data from first gs response is empty")
            self._current_chunk_row = EmptyPyArrowIterator()
        self._iter_unit = EMPTY_UNIT

        if u'chunks' in data:
            chunks = data[u'chunks']
            self._chunk_count = len(chunks)
            logger.debug(u'chunk count=%s', self._chunk_count)
            # prepare the downloader for further fetches
            qrmk = data[u'qrmk'] if u'qrmk' in data else None
            chunk_headers = None
            if u'chunkHeaders' in data:
                chunk_headers = {}
                for header_key, header_value in data[u'chunkHeaders'].items():
                    chunk_headers[header_key] = header_value
                    logger.debug(
                        u'added chunk header: key=%s, value=%s',
                        header_key,
                        header_value)

            logger.debug(u'qrmk=%s', qrmk)
            self._chunk_downloader = _chunk_downloader if _chunk_downloader \
                else self._connection._chunk_downloader_class(
                    chunks, self._connection, self._cursor, qrmk, chunk_headers,
                    query_result_format='arrow',
                    prefetch_threads=self._connection.client_prefetch_threads)

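    # For reference, a sketch of the response dictionary consumed by __init__ and
    # _chunk_info, inferred from the lookups above; the field names are real, the
    # example values are hypothetical:
    #
    #   {
    #       'rowtype': [{'name': 'COL1', ...}, ...],          # column metadata; only 'name' is read here
    #       'rowsetBase64': '<base64-encoded Arrow stream>',  # inline first result set (may be absent or empty)
    #       'chunks': [...],                                  # descriptors of remaining chunks, handed to the downloader
    #       'qrmk': '<key>',                                  # optional; passed through to the chunk downloader
    #       'chunkHeaders': {'<header>': '<value>', ...},     # optional HTTP headers for chunk downloads
    #   }
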
    def __iter__(self):
        return self

    def __next__(self):
        if self._iter_unit == EMPTY_UNIT:
            self._iter_unit = ROW_UNIT
            self._current_chunk_row.init(self._iter_unit)
        elif self._iter_unit == TABLE_UNIT:
            logger.debug(u'The iterator has been built for fetching arrow table')
            raise RuntimeError

        is_done = False
        try:
            row = None
            self.total_row_index += 1
            try:
                row = self._current_chunk_row.__next__()
            except StopIteration:
                if self._chunk_index < self._chunk_count:
                    logger.debug(
                        u"chunk index: %s, chunk_count: %s",
                        self._chunk_index, self._chunk_count)
                    next_chunk = self._chunk_downloader.next_chunk()
                    self._current_chunk_row = next_chunk.result_data
                    self._current_chunk_row.init(self._iter_unit)
                    self._chunk_index += 1
                    try:
                        row = self._current_chunk_row.__next__()
                    except StopIteration:
                        is_done = True
                        raise IndexError
                else:
                    if self._chunk_count > 0 and \
                            self._chunk_downloader is not None:
                        self._chunk_downloader.terminate()
                        self._cursor._log_telemetry_job_data(
                            TelemetryField.TIME_DOWNLOADING_CHUNKS,
                            self._chunk_downloader._total_millis_downloading_chunks)
                        self._cursor._log_telemetry_job_data(
                            TelemetryField.TIME_PARSING_CHUNKS,
                            self._chunk_downloader._total_millis_parsing_chunks)
                    self._chunk_downloader = None
                    self._chunk_count = 0
                    self._current_chunk_row = EmptyPyArrowIterator()
                    is_done = True

            if is_done:
                raise StopIteration
            return row
        except IndexError:
            # returns None if the iteration is completed so that iter() stops
            return None
        finally:
            if is_done and self._cursor._first_chunk_time:
                logger.info("fetching data done")
                time_consume_last_result = get_time_millis() - self._cursor._first_chunk_time
                self._cursor._log_telemetry_job_data(
                    TelemetryField.TIME_CONSUME_LAST_RESULT,
                    time_consume_last_result)

    def _reset(self):
        self.total_row_index = -1  # last fetched number of rows
        self._current_chunk_row_count = 0
        self._current_chunk_row = EmptyPyArrowIterator()
        self._chunk_index = 0

        if hasattr(self, u'_chunk_count') and self._chunk_count > 0 and \
                self._chunk_downloader is not None:
            self._chunk_downloader.terminate()

        self._chunk_count = 0
        self._chunk_downloader = None
        self._arrow_context = None
        self._iter_unit = EMPTY_UNIT

    def _fetch_arrow_batches(self):
        '''
            Fetch Arrow Tables in batches, where a 'batch' corresponds to a
            Snowflake chunk; the batch size (the number of rows per table)
            may therefore vary from batch to batch.
        '''
        if self._iter_unit == EMPTY_UNIT:
            self._iter_unit = TABLE_UNIT
        elif self._iter_unit == ROW_UNIT:
            logger.debug(u'The iterator has been built for fetching row')
            raise RuntimeError

        try:
            self._current_chunk_row.init(self._iter_unit)
            logger.debug(u'Init table iterator successfully, current chunk index: %s, '
                         u'chunk count: %s', self._chunk_index, self._chunk_count)
            while self._chunk_index <= self._chunk_count:
                stop_iteration_except = False
                try:
                    table = self._current_chunk_row.__next__()
                except StopIteration:
                    stop_iteration_except = True

                if self._chunk_index < self._chunk_count:  # multiple chunks
                    logger.debug(
                        u"chunk index: %s, chunk_count: %s",
                        self._chunk_index, self._chunk_count)
                    next_chunk = self._chunk_downloader.next_chunk()
                    self._current_chunk_row = next_chunk.result_data
                    self._current_chunk_row.init(self._iter_unit)
                self._chunk_index += 1
                if stop_iteration_except:
                    continue
                else:
                    yield table
            else:
                if self._chunk_count > 0 and \
                        self._chunk_downloader is not None:
                    self._chunk_downloader.terminate()
                    self._cursor._log_telemetry_job_data(
                        TelemetryField.TIME_DOWNLOADING_CHUNKS,
                        self._chunk_downloader._total_millis_downloading_chunks)
                    self._cursor._log_telemetry_job_data(
                        TelemetryField.TIME_PARSING_CHUNKS,
                        self._chunk_downloader._total_millis_parsing_chunks)
                self._chunk_downloader = None
                self._chunk_count = 0
                self._current_chunk_row = EmptyPyArrowIterator()
        finally:
            if self._cursor._first_chunk_time:
                logger.info("fetching data into pandas dataframe done")
                time_consume_last_result = get_time_millis() - self._cursor._first_chunk_time
                self._cursor._log_telemetry_job_data(
                    TelemetryField.TIME_CONSUME_LAST_RESULT,
                    time_consume_last_result)

    def _fetch_arrow_all(self):
        """
        Fetch a single Arrow Table
        """
        tables = list(self._fetch_arrow_batches())
        if tables:
            return concat_tables(tables)
        else:
            return None

    def _fetch_pandas_batches(self, **kwargs):
        u"""
        Fetch pandas DataFrames in batches, where a 'batch' corresponds to a
        Snowflake chunk; the batch size (the number of rows per DataFrame) is
        optimized by the Snowflake Python Connector rather than set by the caller.
        """
        for table in self._fetch_arrow_batches():
            yield table.to_pandas(**kwargs)

    def _fetch_pandas_all(self, **kwargs):
        """
        Fetch a single Pandas dataframe
        """
        table = self._fetch_arrow_all()
        if table:
            return table.to_pandas(**kwargs)
        else:
            return pandas.DataFrame(columns=self._column_idx_to_name)
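
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of this module). ArrowResult is
# constructed internally by the connector when the server returns results in
# Arrow format; application code reaches it through the cursor, whose
# fetch_pandas_all / fetch_pandas_batches methods are assumed here to delegate
# to _fetch_pandas_all / _fetch_pandas_batches above. Connection parameters
# are placeholders.
#
#   import snowflake.connector
#
#   con = snowflake.connector.connect(account='<account>', user='<user>', password='<password>')
#   try:
#       cur = con.cursor()
#
#       # Row-by-row iteration drives ArrowResult.__next__ (ROW_UNIT).
#       cur.execute("SELECT seq4() AS n FROM table(generator(rowcount => 1000))")
#       for row in cur:
#           pass
#
#       # Whole-result fetch drives _fetch_pandas_all / _fetch_arrow_all (TABLE_UNIT).
#       cur.execute("SELECT seq4() AS n FROM table(generator(rowcount => 1000))")
#       df = cur.fetch_pandas_all()
#   finally:
#       con.close()
# ---------------------------------------------------------------------------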