Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

[Issue #16] Connect the API to use the search index #63

Merged
merged 20 commits into from
Jun 27, 2024
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
248 changes: 210 additions & 38 deletions api/openapi.generated.yml

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion api/src/adapters/search/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from src.adapters.search.opensearch_client import SearchClient
from src.adapters.search.opensearch_config import get_opensearch_config
from src.adapters.search.opensearch_query_builder import SearchQueryBuilder

__all__ = ["SearchClient", "get_opensearch_config"]
__all__ = ["SearchClient", "get_opensearch_config", "SearchQueryBuilder"]
47 changes: 47 additions & 0 deletions api/src/adapters/search/flask_opensearch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from functools import wraps
from typing import Callable, Concatenate, ParamSpec, TypeVar

from flask import Flask, current_app

from src.adapters.search import SearchClient

_SEARCH_CLIENT_KEY = "search-client"


def register_search_client(search_client: SearchClient, app: Flask) -> None:
    """Attach the shared SearchClient to the Flask app's extensions registry.

    Stored under a module-private key so only the companion
    get_search_client()/with_search_client() helpers need to know it.
    Call once at app startup, before any request handlers run.
    """
    app.extensions[_SEARCH_CLIENT_KEY] = search_client


def get_search_client(app: Flask) -> SearchClient:
    """Fetch the SearchClient previously stored by register_search_client().

    Raises KeyError if register_search_client() was never called on this app.
    """
    return app.extensions[_SEARCH_CLIENT_KEY]


P = ParamSpec("P")
T = TypeVar("T")


def with_search_client() -> Callable[[Callable[Concatenate[SearchClient, P], T]], Callable[P, T]]:
    """
    Decorator that injects the app-wide search client as the first argument.

    The injected client is the single shared instance registered on the
    Flask app (it maintains an internal connection pool), fetched from
    ``current_app`` at call time rather than at decoration time.

    Usage:
        @with_search_client()
        def foo(search_client: search.SearchClient):
            ...

        @with_search_client()
        def bar(search_client: search.SearchClient, x: int, y: int):
            ...
    """

    def _decorate(func: Callable[Concatenate[SearchClient, P], T]) -> Callable[P, T]:
        @wraps(func)
        def _inner(*args: P.args, **kwargs: P.kwargs) -> T:
            # Resolve lazily so the decorator works at import time,
            # before the app (and its client) exist.
            client = get_search_client(current_app)
            return func(client, *args, **kwargs)

        return _inner

    return _decorate
53 changes: 46 additions & 7 deletions api/src/adapters/search/opensearch_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,31 @@
import opensearchpy

from src.adapters.search.opensearch_config import OpensearchConfig, get_opensearch_config
from src.adapters.search.opensearch_response import SearchResponse

logger = logging.getLogger(__name__)

# By default, we'll override the default analyzer+tokenization
# for a search index. You can provide your own when calling create_index
DEFAULT_INDEX_ANALYSIS = {
"analyzer": {
"default": {
"type": "custom",
"filter": ["lowercase", "custom_stemmer"],
"tokenizer": "standard",
Copy link
Member

@acouch acouch Jun 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this need to match the stemmer chosen in the utils?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the utils, I found some issues with what I had configured in the prior PR when setting it up with our actual data and fixed it here.

}
},
# Change the default stemming to use snowball which handles plural
# queries better than the default
# TODO - there are a lot of stemmers, we should take some time to figure out
# which one works best with our particular dataset. Snowball is really
# basic and naive (literally just adjusting suffixes on words in common patterns)
# which might be fine generally, but we work with a lot of acronyms
# and should verify that doesn't cause any issues.
# see: https://opensearch.org/docs/latest/analyzers/token-filters/index/
"filter": {"custom_stemmer": {"type": "snowball", "name": "english"}},
}


class SearchClient:
def __init__(self, opensearch_config: OpensearchConfig | None = None) -> None:
Expand All @@ -17,15 +39,27 @@ def __init__(self, opensearch_config: OpensearchConfig | None = None) -> None:
self._client = opensearchpy.OpenSearch(**_get_connection_parameters(opensearch_config))

def create_index(
self, index_name: str, *, shard_count: int = 1, replica_count: int = 1
self,
index_name: str,
*,
shard_count: int = 1,
replica_count: int = 1,
analysis: dict | None = None
) -> None:
"""
Create an empty search index
"""

# Allow the user to adjust how the index analyzer + tokenization works
# but also have a general default.
if analysis is None:
analysis = DEFAULT_INDEX_ANALYSIS

body = {
"settings": {
"index": {"number_of_shards": shard_count, "number_of_replicas": replica_count}
}
"index": {"number_of_shards": shard_count, "number_of_replicas": replica_count},
"analysis": analysis,
},
}

logger.info("Creating search index %s", index_name, extra={"index_name": index_name})
Expand Down Expand Up @@ -104,12 +138,17 @@ def swap_alias_index(
for index in existing_indexes:
self.delete_index(index)

def search(self, index_name: str, search_query: dict) -> dict:
# TODO - add more when we build out the request/response parsing logic
# we use something like Pydantic to help reorganize the response
# object into something easier to parse.
def search_raw(self, index_name: str, search_query: dict) -> dict:
# Simple wrapper around search if you don't want the request or response
# object handled in any special way.
return self._client.search(index=index_name, body=search_query)

def search(
self, index_name: str, search_query: dict, include_scores: bool = True
) -> SearchResponse:
response = self._client.search(index=index_name, body=search_query)
return SearchResponse.from_opensearch_response(response, include_scores)


def _get_connection_parameters(opensearch_config: OpensearchConfig) -> dict[str, Any]:
# TODO - we'll want to add the AWS connection params here when we set that up
Expand Down
217 changes: 217 additions & 0 deletions api/src/adapters/search/opensearch_query_builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,217 @@
import typing

from src.pagination.pagination_models import SortDirection


class SearchQueryBuilder:
"""
Utility to help build queries to OpenSearch

This helps with making sure everything we want in a search query goes
to the right spot in the large JSON object we're building. Note that
it still requires some understanding of OpenSearch (eg. when to add ".keyword" to a field name)

For example, if you wanted to build a query against a search index containing
books with the following:
* Page size of 5, page number 1
* Sorted by relevancy score descending
* Scored on titles containing "king"
* Where the author is one of Brandon Sanderson or J R.R. Tolkien
* Returning aggregate counts of books by those authors in the full results

This query could either be built manually and look like:

{
"size": 5,
"from": 0,
"track_scores": true,
"sort": [
{
"_score": {
"order": "desc"
}
}
],
"query": {
"bool": {
"must": [
{
"simple_query_string": {
"query": "king",
"fields": [
"title.keyword"
],
"default_operator": "AND"
}
}
],
"filter": [
{
"terms": {
"author.keyword": [
"Brandon Sanderson",
"J R.R. Tolkien"
]
}
}
]
}
},
"aggs": {
"author": {
"terms": {
"field": "author.keyword",
"size": 25,
"min_doc_count": 0
}
}
}
}

Or you could use the builder and produce the same result:

search_query = SearchQueryBuilder()
.pagination(page_size=5, page_number=1)
.sort_by([("relevancy", SortDirection.DESCENDING)])
.simple_query("king", fields=["title.keyword"])
.filter_terms("author.keyword", terms=["Brandon Sanderson", "J R.R. Tolkien"])
.aggregation_terms(aggregation_name="author", field_name="author.keyword", minimum_count=0)
.build()
"""

def __init__(self) -> None:
self.page_size = 25
self.page_number = 1

self.sort_values: list[dict[str, dict[str, str]]] = []

self.must: list[dict] = []
self.filters: list[dict] = []

self.aggregations: dict[str, dict] = {}

def pagination(self, page_size: int, page_number: int) -> typing.Self:
"""
Set the pagination for the search request.

Note that page number should be the human-readable page number
and start counting from 1.
"""
self.page_size = page_size
self.page_number = page_number
return self

def sort_by(self, sort_values: list[typing.Tuple[str, SortDirection]]) -> typing.Self:
"""
List of tuples of field name + sort direction to sort by. If you wish to sort by the relevancy
score provide a field name of "relevancy".

The order of the tuples matters, and the earlier values will take precedence - or put another way
the first tuple is the "primary sort", the second is the "secondary sort", and so on. If
all of the primary sort values are unique, then the secondary sorts won't be relevant.

If this method is not called, no sort info will be added to the request, and OpenSearch
will internally default to sorting by relevancy score. If there is no scores calculated,
then the order is likely the IDs of the documents in the index.

Note that multiple calls to this method will erase any info provided in a prior call.
"""
for field, sort_direction in sort_values:
if field == "relevancy":
field = "_score"

self.sort_values.append({field: {"order": sort_direction.short_form()}})

return self

def simple_query(self, query: str, fields: list[str]) -> typing.Self:
"""
Adds a simple_query_string which queries against the provided fields.

The fields must include the full path to the object, and can include optional suffixes
to adjust the weighting. For example "opportunity_title^4" would increase any scores
derived from that field by 4x.

See: https://opensearch.org/docs/latest/query-dsl/full-text/simple-query-string/
"""
self.must.append(
{"simple_query_string": {"query": query, "fields": fields, "default_operator": "AND"}}
)

return self

def filter_terms(self, field: str, terms: list) -> typing.Self:
"""
For a given field, filter to a set of values.

These filters do not affect the relevancy score, they are purely
a binary filter on the overall results.
"""
self.filters.append({"terms": {field: terms}})
return self

def aggregation_terms(
self, aggregation_name: str, field_name: str, size: int = 25, minimum_count: int = 1
) -> typing.Self:
"""
Add a term aggregation to the request. Aggregations are the counts of particular fields in the
full response and are often displayed next to filters in a search UI.

Size determines how many different values can be returned.
Minimum count determines how many occurrences need to occur to include in the response.
If you pass in 0 for this, then values that don't occur at all in the full result set will be returned.

see: https://opensearch.org/docs/latest/aggregations/bucket/terms/
"""
self.aggregations[aggregation_name] = {
"terms": {"field": field_name, "size": size, "min_doc_count": minimum_count}
}
return self

def build(self) -> dict:
"""
Build the search request
"""

# Base request
page_offset = self.page_size * (self.page_number - 1)
request: dict[str, typing.Any] = {
"size": self.page_size,
"from": page_offset,
# Always include the scores in the response objects
# even if we're sorting by non-relevancy
"track_scores": True,
}

# Add sorting if any was provided
if len(self.sort_values) > 0:
request["sort"] = self.sort_values

# Add a bool query
#
# The "must" block contains anything relevant to scoring
# The "filter" block contains filters that don't affect scoring and act
# as just binary filters
#
# See: https://opensearch.org/docs/latest/query-dsl/compound/bool/
bool_query = {}
if len(self.must) > 0:
bool_query["must"] = self.must

if len(self.filters) > 0:
bool_query["filter"] = self.filters

# Add the query object which wraps the bool query
query_obj = {}
if len(bool_query) > 0:
query_obj["bool"] = bool_query

if len(query_obj) > 0:
request["query"] = query_obj

# Add any aggregations
# see: https://opensearch.org/docs/latest/aggregations/
if len(self.aggregations) > 0:
request["aggs"] = self.aggregations

return request
Loading
Loading