This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

[Issue #179] Incrementally load search data #180

Merged · 3 commits · Sep 13, 2024

71 changes: 66 additions & 5 deletions api/src/adapters/search/opensearch_client.py
@@ -1,5 +1,5 @@
import logging
from typing import Any, Sequence
from typing import Any, Generator, Iterable

import opensearchpy

@@ -75,7 +75,7 @@ def delete_index(self, index_name: str) -> None:
def bulk_upsert(
self,
index_name: str,
records: Sequence[dict[str, Any]],
records: Iterable[dict[str, Any]],
primary_key_field: str,
*,
refresh: bool = True
@@ -103,7 +103,29 @@ def bulk_upsert(
logger.info(
"Upserting records to %s",
index_name,
extra={"index_name": index_name, "record_count": int(len(bulk_operations) / 2)},
extra={
"index_name": index_name,
"record_count": int(len(bulk_operations) / 2),
"operation": "update",
},
)
self._client.bulk(index=index_name, body=bulk_operations, refresh=refresh)

def bulk_delete(self, index_name: str, ids: Iterable[Any], *, refresh: bool = True) -> None:
bulk_operations = []

for _id in ids:
# { "delete": { "_id": "tt2229499" } }
bulk_operations.append({"delete": {"_id": _id}})

logger.info(
"Deleting records from %s",
index_name,
extra={
"index_name": index_name,
"record_count": len(bulk_operations),
"operation": "delete",
},
)
self._client.bulk(index=index_name, body=bulk_operations, refresh=refresh)
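
A minimal usage sketch of the new bulk_delete method; the client variable and the index name are assumptions for illustration, not part of this change:

# Sketch only: assumes client is an already-constructed SearchClient and that
# "my-index" is a hypothetical index containing documents with these _ids.
client.bulk_delete("my-index", [1, 2, 3])
# refresh defaults to True; pass refresh=False when the deletes do not need to
# be immediately visible to searches.
client.bulk_delete("my-index", [4, 5], refresh=False)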

@@ -144,11 +166,50 @@ def search_raw(self, index_name: str, search_query: dict) -> dict:
return self._client.search(index=index_name, body=search_query)

def search(
self, index_name: str, search_query: dict, include_scores: bool = True
self,
index_name: str,
search_query: dict,
include_scores: bool = True,
params: dict | None = None,
) -> SearchResponse:
response = self._client.search(index=index_name, body=search_query)
if params is None:
params = {}

response = self._client.search(index=index_name, body=search_query, params=params)
return SearchResponse.from_opensearch_response(response, include_scores)

def scroll(
self,
index_name: str,
search_query: dict,
include_scores: bool = True,
duration: str = "10m",
) -> Generator[SearchResponse, None, None]:
# start scroll
response = self.search(
index_name=index_name,
search_query=search_query,
include_scores=include_scores,
params={"scroll": duration},
)
scroll_id = response.scroll_id

yield response

# iterate
while True:
raw_response = self._client.scroll({"scroll_id": scroll_id, "scroll": duration})
response = SearchResponse.from_opensearch_response(raw_response, include_scores)
scroll_id = response.scroll_id

if len(response.records) == 0:
break

yield response

# close scroll
self._client.clear_scroll(scroll_id=scroll_id)
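
For orientation, a hedged sketch of how a caller might consume the new scroll generator; the client variable, the index name, and the page size are illustrative assumptions:

# Sketch only: each yielded item is a SearchResponse page; iteration stops once
# a page comes back empty, and the scroll context is cleared when the generator
# finishes.
all_records = []
for page in client.scroll("my-index", {"size": 1000}, include_scores=False):
    all_records.extend(page.records)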


def _get_connection_parameters(opensearch_config: OpensearchConfig) -> dict[str, Any]:
# TODO - we'll want to add the AWS connection params here when we set that up
6 changes: 5 additions & 1 deletion api/src/adapters/search/opensearch_response.py
@@ -10,6 +10,8 @@ class SearchResponse:

aggregations: dict[str, dict[str, int]]

scroll_id: str | None

@classmethod
def from_opensearch_response(
cls, raw_json: dict[str, typing.Any], include_scores: bool = True
@@ -40,6 +42,8 @@ def from_opensearch_response(
]
}
"""
scroll_id = raw_json.get("_scroll_id", None)

hits = raw_json.get("hits", {})
hits_total = hits.get("total", {})
total_records = hits_total.get("value", 0)
@@ -59,7 +63,7 @@
raw_aggs: dict[str, dict[str, typing.Any]] = raw_json.get("aggregations", {})
aggregations = _parse_aggregations(raw_aggs)

return cls(total_records, records, aggregations)
return cls(total_records, records, aggregations, scroll_id)
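
For context, the scroll cursor arrives at the top level of the raw OpenSearch response, alongside "hits"; a trimmed illustration of the shape being parsed here (the values are made up):

# Trimmed illustration only; field values are made up.
raw_json = {
    "_scroll_id": "FGluY2x1ZGVfY29udGV4dF91dWlk...",
    "hits": {
        "total": {"value": 8, "relation": "eq"},
        "hits": [{"_id": "1", "_source": {"opportunity_id": 1}}],
    },
}
# from_opensearch_response(raw_json) now also captures "_scroll_id" as scroll_id.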


def _parse_aggregations(
56 changes: 53 additions & 3 deletions api/src/search/backend/load_opportunities_to_index.py
@@ -38,21 +38,50 @@
self,
db_session: db.Session,
search_client: search.SearchClient,
is_full_refresh: bool = True,
config: LoadOpportunitiesToIndexConfig | None = None,
) -> None:
super().__init__(db_session)

self.search_client = search_client
self.is_full_refresh = is_full_refresh

if config is None:
config = LoadOpportunitiesToIndexConfig()
self.config = config

current_timestamp = get_now_us_eastern_datetime().strftime("%Y-%m-%d_%H-%M-%S")
self.index_name = f"{self.config.index_prefix}-{current_timestamp}"
# TODO - determine if this is for a full refresh and set the index name based on that
if is_full_refresh:
current_timestamp = get_now_us_eastern_datetime().strftime("%Y-%m-%d_%H-%M-%S")
self.index_name = f"{self.config.index_prefix}-{current_timestamp}"
else:
self.index_name = self.config.alias_name
self.set_metrics({"index_name": self.index_name})
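
A hedged illustration of how the index name resolves in each mode; the prefix and alias values below are made up:

# Illustration only; these config values are hypothetical.
config = LoadOpportunitiesToIndexConfig(index_prefix="test-opps", alias_name="test-opps-alias")
# is_full_refresh=True  -> index_name like "test-opps-2024-09-13_10-30-00" (fresh timestamped index)
# is_full_refresh=False -> index_name == "test-opps-alias" (writes go through the existing alias)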

def run_task(self) -> None:
if self.is_full_refresh:
self.full_refresh()
else:
self.incremental_updates_and_deletes()

def incremental_updates_and_deletes(self) -> None:
existing_opportunity_ids = self.fetch_existing_opportunity_ids_in_index()

# load the records
# TODO - we should probably not load everything if what is in the search index
# is identical - otherwise this isn't much different from the full refresh
# BUT - need some sort of mechanism for determining that (timestamp?)
loaded_opportunity_ids = set()
for opp_batch in self.fetch_opportunities():
loaded_opportunity_ids.update(self.load_records(opp_batch))

# Delete
opportunity_ids_to_delete = existing_opportunity_ids - loaded_opportunity_ids

if len(opportunity_ids_to_delete) > 0:
self.search_client.bulk_delete(self.index_name, opportunity_ids_to_delete)
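
To make the delete step concrete, a small worked example of the set arithmetic above, with made-up IDs:

# Worked example only; the IDs are made up.
existing_opportunity_ids = {1, 2, 3, 4}   # currently present in the search index
loaded_opportunity_ids = {2, 3, 4, 5}     # just upserted from the database
opportunity_ids_to_delete = existing_opportunity_ids - loaded_opportunity_ids
assert opportunity_ids_to_delete == {1}   # only records no longer in the DB are removed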

def full_refresh(self) -> None:
# create the index
self.search_client.create_index(
self.index_name,
@@ -93,11 +122,28 @@
.partitions()
)

def load_records(self, records: Sequence[Opportunity]) -> None:
def fetch_existing_opportunity_ids_in_index(self) -> set[int]:
# TODO - check if the index exists already

opportunity_ids: set[int] = set()

for response in self.search_client.scroll(
self.config.alias_name,
{"size": 10000, "_source": ["opportunity_id"]},
include_scores=False,
):
for record in response.records:
opportunity_ids.add(record.get("opportunity_id"))

[Check warning on line 136 in api/src/search/backend/load_opportunities_to_index.py (GitHub Actions / API Lint, Format & Tests): Argument 1 to "add" of "set" has incompatible type "Any | None"; expected "int" [arg-type]]

return opportunity_ids

def load_records(self, records: Sequence[Opportunity]) -> set[int]:
logger.info("Loading batch of opportunities...")
schema = OpportunityV1Schema()
json_records = []

loaded_opportunity_ids = set()

for record in records:
logger.info(
"Preparing opportunity for upload to search index",
@@ -109,4 +155,8 @@
json_records.append(schema.dump(record))
self.increment(self.Metrics.RECORDS_LOADED)

loaded_opportunity_ids.add(record.opportunity_id)

self.search_client.bulk_upsert(self.index_name, json_records, "opportunity_id")

return loaded_opportunity_ids
45 changes: 45 additions & 0 deletions api/tests/src/adapters/search/test_opensearch_client.py
@@ -65,6 +65,25 @@ def test_bulk_upsert(search_client, generic_index):
assert search_client._client.get(generic_index, record["id"])["_source"] == record


def test_bulk_delete(search_client, generic_index):
records = [
{"id": 1, "title": "Green Eggs & Ham", "notes": "why are the eggs green?"},
{"id": 2, "title": "The Cat in the Hat", "notes": "silly cat wears a hat"},
{"id": 3, "title": "One Fish, Two Fish, Red Fish, Blue Fish", "notes": "fish"},
]

search_client.bulk_upsert(generic_index, records, primary_key_field="id")

search_client.bulk_delete(generic_index, [1])

resp = search_client.search(generic_index, {}, include_scores=False)
assert resp.records == records[1:]

search_client.bulk_delete(generic_index, [2, 3])
resp = search_client.search(generic_index, {}, include_scores=False)
assert resp.records == []


def test_swap_alias_index(search_client, generic_index):
alias_name = f"tmp-alias-{uuid.uuid4().int}"

@@ -101,3 +120,29 @@ def test_swap_alias_index(search_client, generic_index):

# Verify the tmp one was deleted
assert search_client._client.indices.exists(tmp_index) is False


def test_scroll(search_client, generic_index):
records = [
{"id": 1, "title": "Green Eggs & Ham", "notes": "why are the eggs green?"},
{"id": 2, "title": "The Cat in the Hat", "notes": "silly cat wears a hat"},
{"id": 3, "title": "One Fish, Two Fish, Red Fish, Blue Fish", "notes": "fish"},
{"id": 4, "title": "Fox in Socks", "notes": "why he wearing socks?"},
{"id": 5, "title": "The Lorax", "notes": "trees"},
{"id": 6, "title": "Oh, the Places You'll Go", "notes": "graduation gift"},
{"id": 7, "title": "Hop on Pop", "notes": "Let him sleep"},
{"id": 8, "title": "How the Grinch Stole Christmas", "notes": "who"},
]

search_client.bulk_upsert(generic_index, records, primary_key_field="id")

results = []

for response in search_client.scroll(generic_index, {"size": 3}):
assert response.total_records == 8
results.append(response)

assert len(results) == 3
assert len(results[0].records) == 3
assert len(results[1].records) == 3
assert len(results[2].records) == 2
@@ -4,17 +4,18 @@
LoadOpportunitiesToIndex,
LoadOpportunitiesToIndexConfig,
)
from src.util.datetime_util import get_now_us_eastern_datetime
from tests.conftest import BaseTestClass
from tests.src.db.models.factories import OpportunityFactory


class TestLoadOpportunitiesToIndex(BaseTestClass):
class TestLoadOpportunitiesToIndexFullRefresh(BaseTestClass):
@pytest.fixture(scope="class")
def load_opportunities_to_index(self, db_session, search_client, opportunity_index_alias):
config = LoadOpportunitiesToIndexConfig(
alias_name=opportunity_index_alias, index_prefix="test-load-opps"
)
return LoadOpportunitiesToIndex(db_session, search_client, config)
return LoadOpportunitiesToIndex(db_session, search_client, True, config)

def test_load_opportunities_to_index(
self,
@@ -83,3 +84,60 @@ def test_load_opportunities_to_index(
assert set([opp.opportunity_id for opp in opportunities]) == set(
[record["opportunity_id"] for record in resp.records]
)


class TestLoadOpportunitiesToIndexPartialRefresh(BaseTestClass):
@pytest.fixture(scope="class")
def load_opportunities_to_index(self, db_session, search_client, opportunity_index_alias):
config = LoadOpportunitiesToIndexConfig(
alias_name=opportunity_index_alias, index_prefix="test-load-opps"
)
return LoadOpportunitiesToIndex(db_session, search_client, False, config)

def test_load_opportunities_to_index(
self,
truncate_opportunities,
enable_factory_create,
db_session,
search_client,
opportunity_index_alias,
load_opportunities_to_index,
):
# TODO - need to test/modify logic to be better about handling not already having an index
index_name = "partial-refresh-index-" + get_now_us_eastern_datetime().strftime(
"%Y-%m-%d_%H-%M-%S"
)
search_client.create_index(index_name)
search_client.swap_alias_index(
index_name, load_opportunities_to_index.config.alias_name, delete_prior_indexes=True
)

# Load a bunch of records into the DB
opportunities = []
opportunities.extend(OpportunityFactory.create_batch(size=6, is_posted_summary=True))
opportunities.extend(OpportunityFactory.create_batch(size=3, is_forecasted_summary=True))
opportunities.extend(OpportunityFactory.create_batch(size=2, is_closed_summary=True))
opportunities.extend(
OpportunityFactory.create_batch(size=8, is_archived_non_forecast_summary=True)
)
opportunities.extend(
OpportunityFactory.create_batch(size=6, is_archived_forecast_summary=True)
)

load_opportunities_to_index.run()

resp = search_client.search(opportunity_index_alias, {"size": 100})
assert resp.total_records == len(opportunities)

# Add a few more opportunities that will be created
opportunities.extend(OpportunityFactory.create_batch(size=3, is_posted_summary=True))

# Delete some opportunities
opportunities_to_delete = [opportunities.pop(), opportunities.pop(), opportunities.pop()]
for opportunity in opportunities_to_delete:
db_session.delete(opportunity)

load_opportunities_to_index.run()

resp = search_client.search(opportunity_index_alias, {"size": 100})
assert resp.total_records == len(opportunities)