Skip to content

Commit

Permalink
Refactor consignment indexing logic and add tests to cover uncovered …
Browse files Browse the repository at this point in the history
…edge test cases
  • Loading branch information
anthonyhashemi committed Jan 9, 2025
1 parent 9332a12 commit 853a186
Show file tree
Hide file tree
Showing 6 changed files with 706 additions and 102 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
import logging
from typing import Dict, List, Optional, Tuple, Union

import pg8000
import sqlalchemy
from opensearchpy import OpenSearch, RequestsHttpConnection
from requests_aws4auth import AWS4Auth
from sqlalchemy import create_engine, text
from sqlalchemy.orm import declarative_base, sessionmaker
from sqlalchemy.orm import sessionmaker

from ..aws_helpers import (
_build_db_url,
Expand All @@ -21,6 +21,11 @@


class ConsignmentBulkIndexError(Exception):
"""
Custom exception raised when bulk indexing a consignment fails due to errors
in text extraction or OpenSearch bulk indexing.
"""

pass


Expand Down Expand Up @@ -63,40 +68,31 @@ def bulk_index_consignment(
bucket_name: str,
database_url: str,
open_search_host_url: str,
open_search_http_auth: Union[AWS4Auth, Tuple[str, str]],
open_search_bulk_index_timeout: int = 60,
open_search_http_auth: Union[Tuple[str, str], AWS4Auth],
open_search_bulk_index_timeout: int,
open_search_ca_certs: Optional[str] = None,
) -> None:
"""
Fetch files associated with a consignment and index them in OpenSearch.
Args:
consignment_reference (str): The consignment reference identifier.
bucket_name (str): The S3 bucket name containing files.
database_url (str): The connection string for the PostgreSQL database.
open_search_host_url (str): The host URL of the OpenSearch cluster.
open_search_http_auth (Union[AWS4Auth, Tuple[str, str]]): The authentication credentials for OpenSearch.
open_search_ca_certs (Optional[str]): Path to CA certificates for SSL verification.
Returns:
None
consignment_reference (str): The unique reference identifying the consignment to be indexed.
bucket_name (str): Name of the S3 bucket.
database_url (str): Database connection URL.
open_search_host_url (str): OpenSearch endpoint URL.
open_search_http_auth (AWS4Auth or tuple): Authentication details for OpenSearch.
open_search_bulk_index_timeout (int): Timeout for OpenSearch bulk indexing.
open_search_ca_certs (str, optional): Path to a file containing OpenSearch CA certificates for verification.
Raises:
ConsignmentBulkIndexError: If errors occur during text extraction or bulk indexing.
"""
files = _fetch_files_in_consignment(consignment_reference, database_url)
documents_to_index = _construct_documents(files, bucket_name)
files = fetch_files_in_consignment(consignment_reference, database_url)
documents_to_index = construct_documents(files, bucket_name)

document_text_extraction_exceptions_message = ""
for doc in documents_to_index:
if doc["document"]["text_extraction_status"] not in [
TextExtractionStatus.SKIPPED.value,
TextExtractionStatus.SUCCEEDED.value,
]:
if document_text_extraction_exceptions_message == "":
document_text_extraction_exceptions_message += (
"Text extraction failed on the following documents:"
)
document_text_extraction_exceptions_message += f"\n{doc['file_id']}"

bulk_indexing_exception_message = ""
text_extraction_error = validate_text_extraction(documents_to_index)

bulk_index_error = None
try:
bulk_index_files_in_opensearch(
documents_to_index,
Expand All @@ -106,80 +102,93 @@ def bulk_index_consignment(
open_search_ca_certs,
)
except Exception as bulk_indexing_exception:
bulk_indexing_exception_message = bulk_indexing_exception.text
logger.error("Bulk indexing of files resulted in some errors")

# Combine and raise all errors from failed attempts to extract text or index documents
if (
document_text_extraction_exceptions_message
or bulk_indexing_exception_message
):
consignment_bulk_index_error_message = (
"The following errors occurred when attempting to "
f"bulk index consignment reference: {consignment_reference}"
)
if document_text_extraction_exceptions_message:
consignment_bulk_index_error_message += (
f"\n{document_text_extraction_exceptions_message}"
)
if bulk_indexing_exception_message:
consignment_bulk_index_error_message += (
f"\n{bulk_indexing_exception_message}"
bulk_index_error = str(bulk_indexing_exception)

if text_extraction_error or bulk_index_error:
raise ConsignmentBulkIndexError(
format_bulk_indexing_error_message(
consignment_reference, text_extraction_error, bulk_index_error
)
)


def validate_text_extraction(documents: List[Dict]) -> Optional[str]:
"""
Validate document text extraction statuses and return an error message if any documents failed.
Args:
documents (list): A list of dictionaries, each containing metadata and content of a document.
raise ConsignmentBulkIndexError(consignment_bulk_index_error_message)
Returns:
Optional[str]: An error message if any documents failed text extraction, otherwise None.
"""
errors = [
f"\n{doc['file_id']}"
for doc in documents
if doc["document"]["text_extraction_status"]
not in [
TextExtractionStatus.SKIPPED.value,
TextExtractionStatus.SUCCEEDED.value,
]
]
if errors:
return "Text extraction failed on the following documents:" + "".join(
errors
)
return None


def _construct_documents(files: List[Dict], bucket_name: str) -> List[Dict]:
def construct_documents(files: List[Dict], bucket_name: str) -> List[Dict]:
"""
Construct a list of documents to be indexed in OpenSearch from file metadata.
Args:
files (List[Dict]): The list of file metadata dictionaries.
bucket_name (str): The S3 bucket name where the files are stored.
files (list): A list of file metadata dictionaries retrieved from the database.
bucket_name (str): The name of the S3 bucket containing the files.
Returns:
List[Dict]: A list of documents ready for indexing.
list: A list of dictionaries, each representing a document to be indexed in OpenSearch.
Raises:
Exception: If a file cannot be retrieved from S3.
"""
documents_to_index = []
for file in files:
object_key = file["consignment_reference"] + "/" + str(file["file_id"])
object_key = f"{file['consignment_reference']}/{str(file['file_id'])}"

logger.info(f"Processing file: {object_key}")

file_stream = None
document = file

try:
file_stream = get_s3_file(bucket_name, object_key)
except Exception as e:
logger.error(f"Failed to obtain file {object_key}: {e}")
raise e

document = add_text_content(file, file_stream)

documents_to_index.append(
{"file_id": file["file_id"], "document": document}
)

return documents_to_index


def _fetch_files_in_consignment(
def fetch_files_in_consignment(
consignment_reference: str, database_url: str
) -> List[Dict]:
"""
Fetch file metadata associated with the given consignment reference.
Args:
consignment_reference (str): The consignment reference identifier.
database_url (str): The connection string for the PostgreSQL database.
consignment_reference (str): The unique reference identifying the consignment.
database_url (str): The database connection URL.
Returns:
List[Dict]: A list of file metadata dictionaries.
list: A list of dictionaries, each containing metadata for a file in the consignment.
Raises:
pg8000.Error: If the database query fails.
"""
engine = create_engine(database_url)
Base = declarative_base()
Base.metadata.reflect(bind=engine)
Session = sessionmaker(bind=engine)
session = Session()

Expand Down Expand Up @@ -217,16 +226,15 @@ def _fetch_files_in_consignment(
result = session.execute(
text(query), {"consignment_reference": consignment_reference}
).fetchall()
except pg8000.Error as e:
except sqlalchemy.exc.ProgrammingError as e:
logger.error(
f"Failed to retrieve file metadata from database for consignment reference: {consignment_reference}"
f"Failed to retrieve file metadata for consignment reference: {consignment_reference}"
)
session.close()
raise e

session.close()

# Process query results
files_data = {}

for row in result:
Expand Down Expand Up @@ -256,62 +264,85 @@ def _fetch_files_in_consignment(


def bulk_index_files_in_opensearch(
documents: List[Dict[str, Union[str, Dict]]],
open_search_host_url: str,
open_search_http_auth: Union[AWS4Auth, Tuple[str, str]],
open_search_bulk_index_timeout: int = 60,
open_search_ca_certs: Optional[str] = None,
documents_to_index: List[Dict],
host_url: str,
http_auth: Union[Tuple[str, str], AWS4Auth],
timeout: int = 60,
ca_certs: Optional[str] = None,
) -> None:
"""
Perform bulk indexing of documents in OpenSearch.
Perform bulk indexing of documents into OpenSearch using the OpenSearch library.
Args:
documents (List[Dict[str, Union[str, Dict]]]): The documents to index.
open_search_host_url (str): The OpenSearch cluster URL.
open_search_http_auth (Union[AWS4Auth, Tuple[str, str]]): The authentication credentials.
open_search_ca_certs (Optional[str]): Path to CA certificates for SSL verification.
Returns:
None
documents_to_index (List[Dict[str, Union[str, Dict]]]): The documents to index.
host_url (str): (str): The OpenSearch cluster URL.
http_auth (Union[AWS4Auth, Tuple[str, str]]): The authentication credentials.
timeout (int): Timeout in seconds for bulk indexing operations.
ca_certs (Optional[str]): Path to CA certificates for SSL verification.
Raises:
Exception: If the bulk indexing operation fails.
"""
opensearch_index = "documents"
index = "documents"

bulk_payload = _prepare_bulk_index_payload(documents, opensearch_index)

open_search = OpenSearch(
open_search_host_url,
http_auth=open_search_http_auth,
client = OpenSearch(
host_url,
http_auth=http_auth,
use_ssl=True,
verify_certs=True,
ca_certs=open_search_ca_certs,
ca_certs=ca_certs,
connection_class=RequestsHttpConnection,
)

actions = prepare_bulk_index_payload(documents_to_index, index)

try:
response = open_search.bulk(
index=opensearch_index,
body=bulk_payload,
timeout=open_search_bulk_index_timeout,
)
bulk_response = client.bulk(index=index, body=actions, timeout=timeout)
except Exception as e:
logger.error(f"Opensearch bulk indexing call failed: {e}")
raise e

logger.info("Opensearch bulk indexing call completed with response")
logger.info(response)
logger.info(bulk_response)

if response["errors"]:
if bulk_response["errors"]:
logger.info("Opensearch bulk indexing completed with errors")
error_message = "Opensearch bulk indexing errors:"
for item in response["items"]:
for item in bulk_response["items"]:
if "error" in item.get("index", {}):
error_message += f"\nError for document ID {item['index']['_id']}: {item['index']['error']}"
raise Exception(error_message)
else:
logger.info("Opensearch bulk indexing completed successfully")


def _prepare_bulk_index_payload(
def format_bulk_indexing_error_message(
consignment_reference: str,
text_extraction_error: Optional[str],
bulk_index_error: Optional[str],
) -> str:
"""
Construct a detailed error message for bulk indexing failures.
Args:
consignment_reference (str): The unique reference identifying the consignment.
text_extraction_error (str, optional): Error message related to text extraction.
bulk_index_error (str, optional): Error message related to bulk indexing.
Returns:
str: A formatted error message detailing the issues.
"""
error_message = (
f"Bulk indexing failed for consignment {consignment_reference}:"
)
if text_extraction_error:
error_message += f"\nText Extraction Errors:\n{text_extraction_error}"
if bulk_index_error:
error_message += f"\nBulk Index Errors:\n{bulk_index_error}"
return error_message


def prepare_bulk_index_payload(
documents: List[Dict[str, Union[str, Dict]]], opensearch_index: str
) -> str:
bulk_data = []
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,24 +47,23 @@ class TextExtractionStatus(Enum):

def add_text_content(file: Dict, file_stream: bytes) -> Dict:
file_type = file["file_name"].split(".")[-1].lower()
file_id = file["file_id"]

if file_type not in SUPPORTED_TEXTRACT_FORMATS:
logger.info(
f"Text extraction skipped for unsupported file type: {file_type}"
f"Text extraction skipped for file {file_id} due to unsupported file type: {file_type}"
)
file["content"] = ""
file["text_extraction_status"] = TextExtractionStatus.SKIPPED.value
else:
try:
file["content"] = extract_text(file_stream, file_type)
logger.info(f"Text extraction succeeded for file {file['file_id']}")
logger.info(f"Text extraction succeeded for file {file_id}")
file["text_extraction_status"] = (
TextExtractionStatus.SUCCEEDED.value
)
except Exception as e:
logger.error(
f"Text extraction failed for file {file['file_id']}: {e}"
)
logger.error(f"Text extraction failed for file {file_id}: {e}")
file["content"] = ""
file["text_extraction_status"] = TextExtractionStatus.FAILED.value

Expand Down
2 changes: 1 addition & 1 deletion data_management/opensearch_indexer/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def temp_db():
return engine


@pytest.fixture(scope="session")
@pytest.fixture(scope="function")
def database(request):
# Launch new PostgreSQL server
postgresql = PostgresqlFactory(cache_initialized_db=True)()
Expand Down
Loading

0 comments on commit 853a186

Please sign in to comment.