Improve performance with efficient use of FieldMappingCache #389

Open · wants to merge 2 commits into base: main
18 changes: 7 additions & 11 deletions eland/field_mappings.py
@@ -653,29 +653,25 @@ def date_field_format(self, es_field_name: str) -> str:
             self._mappings_capabilities.es_field_name == es_field_name
         ].es_date_format.squeeze()
 
-    def field_name_pd_dtype(self, es_field_name: str) -> str:
+    def field_name_pd_dtype(self, es_field_name: str) -> Tuple[bool, Optional[str]]:
         """
         Parameters
         ----------
         es_field_name: str
 
         Returns
         -------
-        pd_dtype: str
-            The pandas data type we map to
+        Tuple[bool, Optional[str]]
+            Whether es_field_name is a source field, and the pandas data type we map to
 
-        Raises
-        ------
-        KeyError
-            If es_field_name does not exist in mapping
         """
         if es_field_name not in self._mappings_capabilities.es_field_name:
-            raise KeyError(f"es_field_name {es_field_name} does not exist")
+            return False, "object"
 
-        pd_dtype = self._mappings_capabilities.loc[
+        df: pd.DataFrame = self._mappings_capabilities.loc[
             self._mappings_capabilities.es_field_name == es_field_name
-        ].pd_dtype.squeeze()
-        return pd_dtype
+        ]
+        return df.is_source.squeeze(), df.pd_dtype.squeeze()
 
     def add_scripted_field(
         self, scripted_field_name: str, display_name: str, pd_dtype: str
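The contract change above deserves a callout: `field_name_pd_dtype` now reports whether the field is a source field, and unknown names yield a `(False, "object")` fallback instead of raising `KeyError`, so hot-path callers can drop their `try`/`except` blocks. A minimal sketch of the new behavior (cluster URL, index, and field names are illustrative assumptions, not part of the diff):

```python
from elasticsearch import Elasticsearch

from eland.field_mappings import FieldMappings

es_client = Elasticsearch("http://localhost:9200")  # hypothetical cluster
mappings = FieldMappings(client=es_client, index_pattern="flights")

# Known source field: (is_source, pandas dtype), e.g. (True, "datetime64[ns]")
is_source, pd_dtype = mappings.field_name_pd_dtype("timestamp")

# Unknown field: no KeyError anymore, just the documented fallback
assert mappings.field_name_pd_dtype("no_such_field") == (False, "object")
```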
10 changes: 5 additions & 5 deletions eland/operations.py
@@ -416,7 +416,7 @@ def _metric_aggs(
             fields=fields,
             es_aggs=es_aggs,
             pd_aggs=pd_aggs,
-            response=response,
+            response=response,  # type: ignore
             numeric_only=numeric_only,
             is_dataframe_agg=is_dataframe_agg,
             percentiles=percentiles,
@@ -453,7 +453,7 @@ def _terms_aggs(
             body.terms_aggs(field, func, field, es_size=es_size)
 
         response = query_compiler._client.search(
-            index=query_compiler._index_pattern, size=0, body=body.to_search_body()
+            index=query_compiler._index_pattern, size=0, **body.to_search_body()
         )
 
         results = {}
@@ -499,7 +499,7 @@ def _hist_aggs(
             body.hist_aggs(field, field, min_aggs[field], max_aggs[field], num_bins)
 
         response = query_compiler._client.search(
-            index=query_compiler._index_pattern, size=0, body=body.to_search_body()
+            index=query_compiler._index_pattern, size=0, **body.to_search_body()
         )
         # results are like
         # "aggregations" : {
@@ -1040,7 +1040,7 @@ def bucket_generator(
         res = query_compiler._client.search(
             index=query_compiler._index_pattern,
             size=0,
-            body=body.to_search_body(),
+            **body.to_search_body(),
         )
 
         # Pagination Logic
@@ -1539,7 +1539,7 @@ def _search_yield_hits(
 
     try:
         pit_id = client.open_point_in_time(
-            index=query_compiler._index_pattern, keep_alive=DEFAULT_PIT_KEEP_ALIVE
+            index=query_compiler._index_pattern, keep_alive=DEFAULT_PIT_KEEP_ALIVE  # type: ignore
        )["id"]
 
        # Modify the search with the new point in time ID and keep-alive time.
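All four search call sites above move from a single `body=` parameter to splatting the search body as keyword arguments, presumably to track elasticsearch-py 8.x, where the bare `body` parameter is deprecated in favor of per-field arguments like `query` and `aggs`. A rough sketch of the equivalence (cluster URL, index, and the aggregation are illustrative, not from the diff):

```python
from elasticsearch import Elasticsearch

client = Elasticsearch("http://localhost:9200")  # hypothetical cluster
search_body = {"aggs": {"max_price": {"max": {"field": "price"}}}}  # made-up aggregation

# Old style: one opaque, deprecated `body` parameter
client.search(index="flights", size=0, body=search_body)

# New style: the body's top-level keys become keyword arguments,
# i.e. client.search(index="flights", size=0, aggs={...})
client.search(index="flights", size=0, **search_body)
```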
4 changes: 2 additions & 2 deletions eland/query.py
@@ -322,11 +322,11 @@ def to_search_body(self) -> Dict[str, Any]:
             body["query"] = self._query.build()
         return body
 
-    def to_count_body(self) -> Optional[Dict[str, Any]]:
+    def to_count_body(self) -> Dict[str, Any]:
         if len(self._aggs) > 0:
             warnings.warn(f"Requesting count for agg query {self}")
         if self._query.empty():
-            return None
+            return {}
         else:
             return {"query": self._query.build()}
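Tightening `to_count_body` from `Optional[Dict]` to always-a-dict pairs with the splatting pattern above: `**{}` is a no-op, while `**None` raises `TypeError`, so callers no longer need a `None` branch. A hypothetical call site, not code from this PR:

```python
from typing import Any, Dict

def count_matching(client: Any, index: str, body: Dict[str, Any]) -> int:
    # `body` is {} for match-all or {"query": {...}} otherwise; both splat cleanly.
    return client.count(index=index, **body)["count"]

# Under the old Optional contract this needed:
#     client.count(index=index) if body is None else client.count(index=index, **body)
```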
71 changes: 34 additions & 37 deletions eland/query_compiler.py
@@ -81,17 +81,20 @@ def __init__(
             Union[str, List[str], Tuple[str, ...], "Elasticsearch"]
         ] = None,
         index_pattern: Optional[str] = None,
-        display_names=None,
-        index_field=None,
-        to_copy=None,
+        display_names: Optional[List[str]] = None,
+        index_field: Optional[str] = None,
+        to_copy: Optional["QueryCompiler"] = None,
     ) -> None:
         # Implement copy as we don't deep copy the client
         if to_copy is not None:
-            self._client = to_copy._client
-            self._index_pattern = to_copy._index_pattern
+            self._client: "Elasticsearch" = to_copy._client
+            self._index_pattern: Optional[str] = to_copy._index_pattern
             self._index: "Index" = Index(self, to_copy._index.es_index_field)
             self._operations: "Operations" = copy.deepcopy(to_copy._operations)
             self._mappings: FieldMappings = copy.deepcopy(to_copy._mappings)
+            self._field_mapping_cache: Optional["FieldMappingCache"] = copy.deepcopy(
+                to_copy._field_mapping_cache
+            )
         else:
             self._client = ensure_es_client(client)
             self._index_pattern = index_pattern
@@ -104,6 +107,8 @@ def __init__(
             )
             self._index = Index(self, index_field)
             self._operations = Operations()
+            # This should only be initialized when ETL is done
+            self._field_mapping_cache = None
 
     @property
     def index(self) -> Index:
@@ -239,7 +244,8 @@ def _es_results_to_pandas(
         # This is one of the most performance critical areas of eland, and it repeatedly calls
         # self._mappings.field_name_pd_dtype and self._mappings.date_field_format
         # therefore create a simple cache for this data
-        field_mapping_cache = FieldMappingCache(self._mappings)
+        if self._field_mapping_cache is None:
+            self._field_mapping_cache = FieldMappingCache(self._mappings)
 
         rows = []
         index = []
@@ -266,7 +272,7 @@ def _es_results_to_pandas(
                 index.append(index_field)
 
             # flatten row to map correctly to 2D DataFrame
-            rows.append(self._flatten_dict(row, field_mapping_cache))
+            rows.append(self._flatten_dict(row))
 
         # Create pandas DataFrame
         df = pd.DataFrame(data=rows, index=index)
@@ -279,7 +285,7 @@ def _es_results_to_pandas(
         )
 
         for missing in missing_field_names:
-            pd_dtype = self._mappings.field_name_pd_dtype(missing)
+            _, pd_dtype = self._field_mapping_cache.field_name_pd_dtype(missing)
             df[missing] = pd.Series(dtype=pd_dtype)
 
         # Rename columns
@@ -291,7 +297,7 @@ def _es_results_to_pandas(
 
         return df
 
-    def _flatten_dict(self, y, field_mapping_cache: "FieldMappingCache"):
+    def _flatten_dict(self, y):
         out = {}
 
         def flatten(x, name=""):
@@ -301,12 +307,10 @@ def flatten(x, name=""):
                 is_source_field = False
                 pd_dtype = "object"
             else:
-                try:
-                    pd_dtype = field_mapping_cache.field_name_pd_dtype(name[:-1])
-                    is_source_field = True
-                except KeyError:
-                    is_source_field = False
-                    pd_dtype = "object"
+                (
+                    is_source_field,
+                    pd_dtype,
+                ) = self._field_mapping_cache.field_name_pd_dtype(name[:-1])
 
             if not is_source_field and isinstance(x, dict):
                 for a in x:
@@ -321,7 +325,7 @@ def flatten(x, name=""):
             # Coerce types - for now just datetime
             if pd_dtype == "datetime64[ns]":
                 x = elasticsearch_date_to_pandas_date(
-                    x, field_mapping_cache.date_field_format(field_name)
+                    x, self._field_mapping_cache.date_field_format(field_name)
                 )
 
             # Elasticsearch can have multiple values for a field. These are represented as lists, so
@@ -791,28 +795,21 @@ class FieldMappingCache:
 
     def __init__(self, mappings: "FieldMappings") -> None:
         self._mappings = mappings
+        # This returns all the es_field_names
+        self._es_field_names: List[str] = mappings.get_field_names()
+        # Cache these to re-use later
+        self._field_name_pd_dtype: Dict[str, Tuple[bool, Optional[str]]] = {
+            i: mappings.field_name_pd_dtype(i) for i in self._es_field_names
+        }
+        self._date_field_format: Dict[str, str] = {
+            i: mappings.date_field_format(i) for i in self._es_field_names
+        }
 
-        self._field_name_pd_dtype: Dict[str, str] = dict()
-        self._date_field_format: Dict[str, str] = dict()
-
-    def field_name_pd_dtype(self, es_field_name: str) -> str:
-        if es_field_name in self._field_name_pd_dtype:
+    def field_name_pd_dtype(self, es_field_name: str) -> Tuple[bool, Optional[str]]:
+        if es_field_name not in self._field_name_pd_dtype:
+            return False, "object"
+        else:
             return self._field_name_pd_dtype[es_field_name]
 
-        pd_dtype = self._mappings.field_name_pd_dtype(es_field_name)
-
-        # cache this
-        self._field_name_pd_dtype[es_field_name] = pd_dtype
-
-        return pd_dtype
-
     def date_field_format(self, es_field_name: str) -> str:
-        if es_field_name in self._date_field_format:
-            return self._date_field_format[es_field_name]
-
-        es_date_field_format = self._mappings.date_field_format(es_field_name)
-
-        # cache this
-        self._date_field_format[es_field_name] = es_date_field_format
-
-        return es_date_field_format
+        return self._date_field_format[es_field_name]
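The `FieldMappingCache` rewrite is the heart of this PR: instead of filling its dicts lazily on first miss (with `KeyError` doubling as the "not a source field" signal), it precomputes every known field once in `__init__`, and the compiler reuses one cache instance across `_es_results_to_pandas` calls rather than rebuilding it per call. The pattern in isolation, with `EagerCache` and its inputs as hypothetical stand-ins rather than eland API:

```python
from typing import Dict, Optional, Set, Tuple

class EagerCache:
    """Sketch of the pattern: precompute every known key once, answer misses cheaply."""

    def __init__(self, dtypes: Dict[str, str], source_fields: Set[str]) -> None:
        # One pass at construction replaces per-call lookups and try/except blocks.
        self._cache: Dict[str, Tuple[bool, Optional[str]]] = {
            name: (name in source_fields, dtype) for name, dtype in dtypes.items()
        }

    def lookup(self, name: str) -> Tuple[bool, Optional[str]]:
        # A plain dict access on the hot path; no exception-driven control flow.
        return self._cache.get(name, (False, "object"))

cache = EagerCache({"timestamp": "datetime64[ns]"}, {"timestamp"})
assert cache.lookup("timestamp") == (True, "datetime64[ns]")
assert cache.lookup("unknown") == (False, "object")
```

The trade-off is paying for all fields up front even when only a few are read; for row flattening, where every hit touches the same fields repeatedly, that cost is amortized almost immediately.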
6 changes: 2 additions & 4 deletions tests/field_mappings/test_field_name_pd_dtype_pytest.py
@@ -16,7 +16,6 @@
 # under the License.
 
 # File called _pytest for PyCharm compatability
-import pytest
 from pandas.testing import assert_series_equal
 
 from eland.field_mappings import FieldMappings
@@ -35,7 +34,7 @@ def test_all_formats(self):
         assert_series_equal(pd_flights.dtypes, ed_field_mappings.dtypes())
 
         for es_field_name in FLIGHTS_MAPPING["mappings"]["properties"].keys():
-            pd_dtype = ed_field_mappings.field_name_pd_dtype(es_field_name)
+            _, pd_dtype = ed_field_mappings.field_name_pd_dtype(es_field_name)
 
             assert pd_flights[es_field_name].dtype == pd_dtype
 
@@ -44,5 +43,4 @@ def test_non_existant(self):
             client=ES_TEST_CLIENT, index_pattern=FLIGHTS_INDEX_NAME
         )
 
-        with pytest.raises(KeyError):
-            ed_field_mappings.field_name_pd_dtype("unknown")
+        assert (False, "object") == ed_field_mappings.field_name_pd_dtype("unknown")
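To sanity-check the performance claim in the title, a rough micro-benchmark against a live cluster could compare cached and uncached lookups. Everything below (cluster URL, index, field name, and the `FieldMappingCache` import path, which this diff places in `eland.query_compiler`) is an assumption rather than part of the change:

```python
import timeit

from elasticsearch import Elasticsearch

from eland.field_mappings import FieldMappings
from eland.query_compiler import FieldMappingCache

es_client = Elasticsearch("http://localhost:9200")  # hypothetical cluster
mappings = FieldMappings(client=es_client, index_pattern="flights")
cache = FieldMappingCache(mappings)  # eagerly precomputes all field lookups

uncached = timeit.timeit(lambda: mappings.field_name_pd_dtype("timestamp"), number=10_000)
cached = timeit.timeit(lambda: cache.field_name_pd_dtype("timestamp"), number=10_000)
print(f"uncached: {uncached:.3f}s  cached: {cached:.3f}s")
```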