Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Commit

Permalink
[Issue #179] Incrementally load search data (#180)
Browse files Browse the repository at this point in the history
## Summary
Fixes #179

### Time to review: __10 mins__

## Changes proposed
Updated the load search data task to partially support incrementally
loading + deleting records in the search index rather than just fully
remaking it.

Various changes to the search utilities to support this work

## Context for reviewers
Technically this doesn't fully support a true incremental load as it
updates every record rather than just the ones with changes. I think the
logic necessary to detect changes both deserves its own ticket, and may
evolve when we later support indexing files to OpenSearch, so I think it
makes sense to hold off on that for now.
  • Loading branch information
chouinar authored Sep 13, 2024
1 parent 3a19603 commit 2854d43
Show file tree
Hide file tree
Showing 6 changed files with 341 additions and 13 deletions.
111 changes: 106 additions & 5 deletions api/src/adapters/search/opensearch_client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Any, Sequence
from typing import Any, Generator, Iterable

import opensearchpy

Expand Down Expand Up @@ -75,7 +75,7 @@ def delete_index(self, index_name: str) -> None:
def bulk_upsert(
self,
index_name: str,
records: Sequence[dict[str, Any]],
records: Iterable[dict[str, Any]],
primary_key_field: str,
*,
refresh: bool = True
Expand Down Expand Up @@ -103,10 +103,51 @@ def bulk_upsert(
logger.info(
"Upserting records to %s",
index_name,
extra={"index_name": index_name, "record_count": int(len(bulk_operations) / 2)},
extra={
"index_name": index_name,
"record_count": int(len(bulk_operations) / 2),
"operation": "update",
},
)
self._client.bulk(index=index_name, body=bulk_operations, refresh=refresh)

def bulk_delete(self, index_name: str, ids: Iterable[Any], *, refresh: bool = True) -> None:
"""
Bulk delete records from an index
See: https://opensearch.org/docs/latest/api-reference/document-apis/bulk/ for details.
In this method, we delete records based on the IDs passed in.
"""
bulk_operations = []

for _id in ids:
# { "delete": { "_id": "tt2229499" } }
bulk_operations.append({"delete": {"_id": _id}})

logger.info(
"Deleting records from %s",
index_name,
extra={
"index_name": index_name,
"record_count": len(bulk_operations),
"operation": "delete",
},
)
self._client.bulk(index=index_name, body=bulk_operations, refresh=refresh)

def index_exists(self, index_name: str) -> bool:
"""
Check if an index OR alias exists by a given name
"""
return self._client.indices.exists(index_name)

def alias_exists(self, alias_name: str) -> bool:
"""
Check if an alias exists
"""
existing_index_mapping = self._client.cat.aliases(alias_name, format="json")
return len(existing_index_mapping) > 0

def swap_alias_index(
self, index_name: str, alias_name: str, *, delete_prior_indexes: bool = False
) -> None:
Expand Down Expand Up @@ -144,11 +185,71 @@ def search_raw(self, index_name: str, search_query: dict) -> dict:
return self._client.search(index=index_name, body=search_query)

def search(
self, index_name: str, search_query: dict, include_scores: bool = True
self,
index_name: str,
search_query: dict,
include_scores: bool = True,
params: dict | None = None,
) -> SearchResponse:
response = self._client.search(index=index_name, body=search_query)
if params is None:
params = {}

response = self._client.search(index=index_name, body=search_query, params=params)
return SearchResponse.from_opensearch_response(response, include_scores)

def scroll(
self,
index_name: str,
search_query: dict,
include_scores: bool = True,
duration: str = "10m",
) -> Generator[SearchResponse, None, None]:
"""
Scroll (iterate) over a large result set a given search query.
This query uses additional resources to keep the response open, but
keeps a consistent set of results and is useful for backend processes
that need to fetch a large amount of search data. After processing the results,
the scroll lock is closed for you.
This method is setup as a generator method and the results can be iterated over::
for response in search_client.scroll("my_index", {"size": 10000}):
for record in response.records:
process_record(record)
See: https://opensearch.org/docs/latest/api-reference/scroll/
"""

# start scroll
response = self.search(
index_name=index_name,
search_query=search_query,
include_scores=include_scores,
params={"scroll": duration},
)
scroll_id = response.scroll_id

yield response

# iterate
while True:
raw_response = self._client.scroll({"scroll_id": scroll_id, "scroll": duration})
response = SearchResponse.from_opensearch_response(raw_response, include_scores)

# The scroll ID can change between queries according to the docs, so we
# keep updating the value while iterating in case they change.
scroll_id = response.scroll_id

if len(response.records) == 0:
break

yield response

# close scroll
self._client.clear_scroll(scroll_id=scroll_id)


def _get_connection_parameters(opensearch_config: OpensearchConfig) -> dict[str, Any]:
# TODO - we'll want to add the AWS connection params here when we set that up
Expand Down
6 changes: 5 additions & 1 deletion api/src/adapters/search/opensearch_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ class SearchResponse:

aggregations: dict[str, dict[str, int]]

scroll_id: str | None

@classmethod
def from_opensearch_response(
cls, raw_json: dict[str, typing.Any], include_scores: bool = True
Expand Down Expand Up @@ -40,6 +42,8 @@ def from_opensearch_response(
]
}
"""
scroll_id = raw_json.get("_scroll_id", None)

hits = raw_json.get("hits", {})
hits_total = hits.get("total", {})
total_records = hits_total.get("value", 0)
Expand All @@ -59,7 +63,7 @@ def from_opensearch_response(
raw_aggs: dict[str, dict[str, typing.Any]] = raw_json.get("aggregations", {})
aggregations = _parse_aggregations(raw_aggs)

return cls(total_records, records, aggregations)
return cls(total_records, records, aggregations, scroll_id)


def _parse_aggregations(
Expand Down
62 changes: 59 additions & 3 deletions api/src/search/backend/load_opportunities_to_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,21 +38,52 @@ def __init__(
self,
db_session: db.Session,
search_client: search.SearchClient,
is_full_refresh: bool = True,
config: LoadOpportunitiesToIndexConfig | None = None,
) -> None:
super().__init__(db_session)

self.search_client = search_client
self.is_full_refresh = is_full_refresh

if config is None:
config = LoadOpportunitiesToIndexConfig()
self.config = config

current_timestamp = get_now_us_eastern_datetime().strftime("%Y-%m-%d_%H-%M-%S")
self.index_name = f"{self.config.index_prefix}-{current_timestamp}"
if is_full_refresh:
current_timestamp = get_now_us_eastern_datetime().strftime("%Y-%m-%d_%H-%M-%S")
self.index_name = f"{self.config.index_prefix}-{current_timestamp}"
else:
self.index_name = self.config.alias_name
self.set_metrics({"index_name": self.index_name})

def run_task(self) -> None:
if self.is_full_refresh:
logger.info("Running full refresh")
self.full_refresh()
else:
logger.info("Running incremental load")
self.incremental_updates_and_deletes()

def incremental_updates_and_deletes(self) -> None:
existing_opportunity_ids = self.fetch_existing_opportunity_ids_in_index()

# load the records incrementally
# TODO - The point of this incremental load is to support upcoming work
# to load only opportunities that have changes as we'll eventually be indexing
# files which will take longer. However - the structure of the data isn't yet
# known so I want to hold on actually setting up any change-detection logic
loaded_opportunity_ids = set()
for opp_batch in self.fetch_opportunities():
loaded_opportunity_ids.update(self.load_records(opp_batch))

# Delete
opportunity_ids_to_delete = existing_opportunity_ids - loaded_opportunity_ids

if len(opportunity_ids_to_delete) > 0:
self.search_client.bulk_delete(self.index_name, opportunity_ids_to_delete)

def full_refresh(self) -> None:
# create the index
self.search_client.create_index(
self.index_name,
Expand Down Expand Up @@ -93,11 +124,32 @@ def fetch_opportunities(self) -> Iterator[Sequence[Opportunity]]:
.partitions()
)

def load_records(self, records: Sequence[Opportunity]) -> None:
def fetch_existing_opportunity_ids_in_index(self) -> set[int]:
if not self.search_client.alias_exists(self.index_name):
raise RuntimeError(
"Alias %s does not exist, please run the full refresh job before the incremental job"
% self.index_name
)

opportunity_ids: set[int] = set()

for response in self.search_client.scroll(
self.config.alias_name,
{"size": 10000, "_source": ["opportunity_id"]},
include_scores=False,
):
for record in response.records:
opportunity_ids.add(record["opportunity_id"])

return opportunity_ids

def load_records(self, records: Sequence[Opportunity]) -> set[int]:
logger.info("Loading batch of opportunities...")
schema = OpportunityV1Schema()
json_records = []

loaded_opportunity_ids = set()

for record in records:
logger.info(
"Preparing opportunity for upload to search index",
Expand All @@ -109,4 +161,8 @@ def load_records(self, records: Sequence[Opportunity]) -> None:
json_records.append(schema.dump(record))
self.increment(self.Metrics.RECORDS_LOADED)

loaded_opportunity_ids.add(record.opportunity_id)

self.search_client.bulk_upsert(self.index_name, json_records, "opportunity_id")

return loaded_opportunity_ids
11 changes: 9 additions & 2 deletions api/src/search/backend/load_search_data.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import click

import src.adapters.db as db
import src.adapters.search as search
from src.adapters.db import flask_db
Expand All @@ -8,8 +10,13 @@
@load_search_data_blueprint.cli.command(
"load-opportunity-data", help="Load opportunity data from our database to the search index"
)
@click.option(
"--full-refresh/--incremental",
default=True,
help="Whether to run a full refresh, or only incrementally update oppportunities",
)
@flask_db.with_db_session()
def load_opportunity_data(db_session: db.Session) -> None:
def load_opportunity_data(db_session: db.Session, full_refresh: bool) -> None:
search_client = search.SearchClient()

LoadOpportunitiesToIndex(db_session, search_client).run()
LoadOpportunitiesToIndex(db_session, search_client, full_refresh).run()
92 changes: 92 additions & 0 deletions api/tests/src/adapters/search/test_opensearch_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,25 @@ def test_bulk_upsert(search_client, generic_index):
assert search_client._client.get(generic_index, record["id"])["_source"] == record


def test_bulk_delete(search_client, generic_index):
records = [
{"id": 1, "title": "Green Eggs & Ham", "notes": "why are the eggs green?"},
{"id": 2, "title": "The Cat in the Hat", "notes": "silly cat wears a hat"},
{"id": 3, "title": "One Fish, Two Fish, Red Fish, Blue Fish", "notes": "fish"},
]

search_client.bulk_upsert(generic_index, records, primary_key_field="id")

search_client.bulk_delete(generic_index, [1])

resp = search_client.search(generic_index, {}, include_scores=False)
assert resp.records == records[1:]

search_client.bulk_delete(generic_index, [2, 3])
resp = search_client.search(generic_index, {}, include_scores=False)
assert resp.records == []


def test_swap_alias_index(search_client, generic_index):
alias_name = f"tmp-alias-{uuid.uuid4().int}"

Expand Down Expand Up @@ -101,3 +120,76 @@ def test_swap_alias_index(search_client, generic_index):

# Verify the tmp one was deleted
assert search_client._client.indices.exists(tmp_index) is False


def test_index_or_alias_exists(search_client, generic_index):
# Create a few aliased indexes
index_a = f"test-index-a-{uuid.uuid4().int}"
index_b = f"test-index-b-{uuid.uuid4().int}"
index_c = f"test-index-c-{uuid.uuid4().int}"

search_client.create_index(index_a)
search_client.create_index(index_b)
search_client.create_index(index_c)

alias_index_a = f"test-alias-a-{uuid.uuid4().int}"
alias_index_b = f"test-alias-b-{uuid.uuid4().int}"
alias_index_c = f"test-alias-c-{uuid.uuid4().int}"

search_client.swap_alias_index(index_a, alias_index_a)
search_client.swap_alias_index(index_b, alias_index_b)
search_client.swap_alias_index(index_c, alias_index_c)

# Checking the indexes directly - we expect the index method to return true
# and the alias method to not
assert search_client.index_exists(index_a) is True
assert search_client.index_exists(index_b) is True
assert search_client.index_exists(index_c) is True

assert search_client.alias_exists(index_a) is False
assert search_client.alias_exists(index_b) is False
assert search_client.alias_exists(index_c) is False

# We just created these aliases, they should exist
assert search_client.index_exists(alias_index_a) is True
assert search_client.index_exists(alias_index_b) is True
assert search_client.index_exists(alias_index_c) is True

assert search_client.alias_exists(alias_index_a) is True
assert search_client.alias_exists(alias_index_b) is True
assert search_client.alias_exists(alias_index_c) is True

# Other random things won't be found for either case
assert search_client.index_exists("test-index-a") is False
assert search_client.index_exists("asdasdasd") is False
assert search_client.index_exists(alias_index_a + "-other") is False

assert search_client.alias_exists("test-index-a") is False
assert search_client.alias_exists("asdasdasd") is False
assert search_client.alias_exists(alias_index_a + "-other") is False


def test_scroll(search_client, generic_index):
records = [
{"id": 1, "title": "Green Eggs & Ham", "notes": "why are the eggs green?"},
{"id": 2, "title": "The Cat in the Hat", "notes": "silly cat wears a hat"},
{"id": 3, "title": "One Fish, Two Fish, Red Fish, Blue Fish", "notes": "fish"},
{"id": 4, "title": "Fox in Socks", "notes": "why he wearing socks?"},
{"id": 5, "title": "The Lorax", "notes": "trees"},
{"id": 6, "title": "Oh, the Places You'll Go", "notes": "graduation gift"},
{"id": 7, "title": "Hop on Pop", "notes": "Let him sleep"},
{"id": 8, "title": "How the Grinch Stole Christmas", "notes": "who"},
]

search_client.bulk_upsert(generic_index, records, primary_key_field="id")

results = []

for response in search_client.scroll(generic_index, {"size": 3}):
assert response.total_records == 8
results.append(response)

assert len(results) == 3
assert len(results[0].records) == 3
assert len(results[1].records) == 3
assert len(results[2].records) == 2
Loading

0 comments on commit 2854d43

Please sign in to comment.