diff --git a/data_management/opensearch_indexer/opensearch_indexer/index_consignment/bulk_index_consignment.py b/data_management/opensearch_indexer/opensearch_indexer/index_consignment/bulk_index_consignment.py index 82d08b9c..a87c24c9 100644 --- a/data_management/opensearch_indexer/opensearch_indexer/index_consignment/bulk_index_consignment.py +++ b/data_management/opensearch_indexer/opensearch_indexer/index_consignment/bulk_index_consignment.py @@ -2,11 +2,11 @@ import logging from typing import Dict, List, Optional, Tuple, Union -import pg8000 +import sqlalchemy from opensearchpy import OpenSearch, RequestsHttpConnection from requests_aws4auth import AWS4Auth from sqlalchemy import create_engine, text -from sqlalchemy.orm import declarative_base, sessionmaker +from sqlalchemy.orm import sessionmaker from ..aws_helpers import ( _build_db_url, @@ -21,6 +21,11 @@ class ConsignmentBulkIndexError(Exception): + """ + Custom exception raised when bulk indexing a consignment fails due to errors + in text extraction or OpenSearch bulk indexing. + """ + pass @@ -63,40 +68,31 @@ def bulk_index_consignment( bucket_name: str, database_url: str, open_search_host_url: str, - open_search_http_auth: Union[AWS4Auth, Tuple[str, str]], - open_search_bulk_index_timeout: int = 60, + open_search_http_auth: Union[Tuple[str, str], AWS4Auth], + open_search_bulk_index_timeout: int, open_search_ca_certs: Optional[str] = None, ) -> None: """ Fetch files associated with a consignment and index them in OpenSearch. Args: - consignment_reference (str): The consignment reference identifier. - bucket_name (str): The S3 bucket name containing files. - database_url (str): The connection string for the PostgreSQL database. - open_search_host_url (str): The host URL of the OpenSearch cluster. - open_search_http_auth (Union[AWS4Auth, Tuple[str, str]]): The authentication credentials for OpenSearch. - open_search_ca_certs (Optional[str]): Path to CA certificates for SSL verification. - - Returns: - None + consignment_reference (str): The unique reference identifying the consignment to be indexed. + bucket_name (str): Name of the S3 bucket. + database_url (str): Database connection URL. + open_search_host_url (str): OpenSearch endpoint URL. + open_search_http_auth (AWS4Auth or tuple): Authentication details for OpenSearch. + open_search_bulk_index_timeout (int): Timeout for OpenSearch bulk indexing. + open_search_ca_certs (str, optional): Path to a file containing OpenSearch CA certificates for verification. + + Raises: + ConsignmentBulkIndexError: If errors occur during text extraction or bulk indexing. """ - files = _fetch_files_in_consignment(consignment_reference, database_url) - documents_to_index = _construct_documents(files, bucket_name) + files = fetch_files_in_consignment(consignment_reference, database_url) + documents_to_index = construct_documents(files, bucket_name) - document_text_extraction_exceptions_message = "" - for doc in documents_to_index: - if doc["document"]["text_extraction_status"] not in [ - TextExtractionStatus.SKIPPED.value, - TextExtractionStatus.SUCCEEDED.value, - ]: - if document_text_extraction_exceptions_message == "": - document_text_extraction_exceptions_message += ( - "Text extraction failed on the following documents:" - ) - document_text_extraction_exceptions_message += f"\n{doc['file_id']}" - - bulk_indexing_exception_message = "" + text_extraction_error = validate_text_extraction(documents_to_index) + + bulk_index_error = None try: bulk_index_files_in_opensearch( documents_to_index, @@ -106,50 +102,62 @@ def bulk_index_consignment( open_search_ca_certs, ) except Exception as bulk_indexing_exception: - bulk_indexing_exception_message = bulk_indexing_exception.text - logger.error("Bulk indexing of files resulted in some errors") - - # Combine and raise all errors from failed attempts to extract text or index documents - if ( - document_text_extraction_exceptions_message - or bulk_indexing_exception_message - ): - consignment_bulk_index_error_message = ( - "The following errors occurred when attempting to " - f"bulk index consignment reference: {consignment_reference}" - ) - if document_text_extraction_exceptions_message: - consignment_bulk_index_error_message += ( - f"\n{document_text_extraction_exceptions_message}" - ) - if bulk_indexing_exception_message: - consignment_bulk_index_error_message += ( - f"\n{bulk_indexing_exception_message}" + bulk_index_error = str(bulk_indexing_exception) + + if text_extraction_error or bulk_index_error: + raise ConsignmentBulkIndexError( + format_bulk_indexing_error_message( + consignment_reference, text_extraction_error, bulk_index_error ) + ) + + +def validate_text_extraction(documents: List[Dict]) -> Optional[str]: + """ + Validate document text extraction statuses and return an error message if any documents failed. + + Args: + documents (list): A list of dictionaries, each containing metadata and content of a document. - raise ConsignmentBulkIndexError(consignment_bulk_index_error_message) + Returns: + Optional[str]: An error message if any documents failed text extraction, otherwise None. + """ + errors = [ + f"\n{doc['file_id']}" + for doc in documents + if doc["document"]["text_extraction_status"] + not in [ + TextExtractionStatus.SKIPPED.value, + TextExtractionStatus.SUCCEEDED.value, + ] + ] + if errors: + return "Text extraction failed on the following documents:" + "".join( + errors + ) + return None -def _construct_documents(files: List[Dict], bucket_name: str) -> List[Dict]: +def construct_documents(files: List[Dict], bucket_name: str) -> List[Dict]: """ Construct a list of documents to be indexed in OpenSearch from file metadata. Args: - files (List[Dict]): The list of file metadata dictionaries. - bucket_name (str): The S3 bucket name where the files are stored. + files (list): A list of file metadata dictionaries retrieved from the database. + bucket_name (str): The name of the S3 bucket containing the files. Returns: - List[Dict]: A list of documents ready for indexing. + list: A list of dictionaries, each representing a document to be indexed in OpenSearch. + + Raises: + Exception: If a file cannot be retrieved from S3. """ documents_to_index = [] for file in files: - object_key = file["consignment_reference"] + "/" + str(file["file_id"]) + object_key = f"{file['consignment_reference']}/{str(file['file_id'])}" logger.info(f"Processing file: {object_key}") - file_stream = None - document = file - try: file_stream = get_s3_file(bucket_name, object_key) except Exception as e: @@ -157,29 +165,30 @@ def _construct_documents(files: List[Dict], bucket_name: str) -> List[Dict]: raise e document = add_text_content(file, file_stream) - documents_to_index.append( {"file_id": file["file_id"], "document": document} ) + return documents_to_index -def _fetch_files_in_consignment( +def fetch_files_in_consignment( consignment_reference: str, database_url: str ) -> List[Dict]: """ Fetch file metadata associated with the given consignment reference. Args: - consignment_reference (str): The consignment reference identifier. - database_url (str): The connection string for the PostgreSQL database. + consignment_reference (str): The unique reference identifying the consignment. + database_url (str): The database connection URL. Returns: - List[Dict]: A list of file metadata dictionaries. + list: A list of dictionaries, each containing metadata for a file in the consignment. + + Raises: + pg8000.Error: If the database query fails. """ engine = create_engine(database_url) - Base = declarative_base() - Base.metadata.reflect(bind=engine) Session = sessionmaker(bind=engine) session = Session() @@ -217,16 +226,15 @@ def _fetch_files_in_consignment( result = session.execute( text(query), {"consignment_reference": consignment_reference} ).fetchall() - except pg8000.Error as e: + except sqlalchemy.exc.ProgrammingError as e: logger.error( - f"Failed to retrieve file metadata from database for consignment reference: {consignment_reference}" + f"Failed to retrieve file metadata for consignment reference: {consignment_reference}" ) session.close() raise e session.close() - # Process query results files_data = {} for row in result: @@ -256,54 +264,51 @@ def _fetch_files_in_consignment( def bulk_index_files_in_opensearch( - documents: List[Dict[str, Union[str, Dict]]], - open_search_host_url: str, - open_search_http_auth: Union[AWS4Auth, Tuple[str, str]], - open_search_bulk_index_timeout: int = 60, - open_search_ca_certs: Optional[str] = None, + documents_to_index: List[Dict], + host_url: str, + http_auth: Union[Tuple[str, str], AWS4Auth], + timeout: int = 60, + ca_certs: Optional[str] = None, ) -> None: """ - Perform bulk indexing of documents in OpenSearch. + Perform bulk indexing of documents into OpenSearch using the OpenSearch library. Args: - documents (List[Dict[str, Union[str, Dict]]]): The documents to index. - open_search_host_url (str): The OpenSearch cluster URL. - open_search_http_auth (Union[AWS4Auth, Tuple[str, str]]): The authentication credentials. - open_search_ca_certs (Optional[str]): Path to CA certificates for SSL verification. - - Returns: - None + documents_to_index (List[Dict[str, Union[str, Dict]]]): The documents to index. + host_url (str): (str): The OpenSearch cluster URL. + http_auth (Union[AWS4Auth, Tuple[str, str]]): The authentication credentials. + timeout (int): Timeout in seconds for bulk indexing operations. + ca_certs (Optional[str]): Path to CA certificates for SSL verification. + + Raises: + Exception: If the bulk indexing operation fails. """ - opensearch_index = "documents" + index = "documents" - bulk_payload = _prepare_bulk_index_payload(documents, opensearch_index) - - open_search = OpenSearch( - open_search_host_url, - http_auth=open_search_http_auth, + client = OpenSearch( + host_url, + http_auth=http_auth, use_ssl=True, verify_certs=True, - ca_certs=open_search_ca_certs, + ca_certs=ca_certs, connection_class=RequestsHttpConnection, ) + actions = prepare_bulk_index_payload(documents_to_index, index) + try: - response = open_search.bulk( - index=opensearch_index, - body=bulk_payload, - timeout=open_search_bulk_index_timeout, - ) + bulk_response = client.bulk(index=index, body=actions, timeout=timeout) except Exception as e: logger.error(f"Opensearch bulk indexing call failed: {e}") raise e logger.info("Opensearch bulk indexing call completed with response") - logger.info(response) + logger.info(bulk_response) - if response["errors"]: + if bulk_response["errors"]: logger.info("Opensearch bulk indexing completed with errors") error_message = "Opensearch bulk indexing errors:" - for item in response["items"]: + for item in bulk_response["items"]: if "error" in item.get("index", {}): error_message += f"\nError for document ID {item['index']['_id']}: {item['index']['error']}" raise Exception(error_message) @@ -311,7 +316,33 @@ def bulk_index_files_in_opensearch( logger.info("Opensearch bulk indexing completed successfully") -def _prepare_bulk_index_payload( +def format_bulk_indexing_error_message( + consignment_reference: str, + text_extraction_error: Optional[str], + bulk_index_error: Optional[str], +) -> str: + """ + Construct a detailed error message for bulk indexing failures. + + Args: + consignment_reference (str): The unique reference identifying the consignment. + text_extraction_error (str, optional): Error message related to text extraction. + bulk_index_error (str, optional): Error message related to bulk indexing. + + Returns: + str: A formatted error message detailing the issues. + """ + error_message = ( + f"Bulk indexing failed for consignment {consignment_reference}:" + ) + if text_extraction_error: + error_message += f"\nText Extraction Errors:\n{text_extraction_error}" + if bulk_index_error: + error_message += f"\nBulk Index Errors:\n{bulk_index_error}" + return error_message + + +def prepare_bulk_index_payload( documents: List[Dict[str, Union[str, Dict]]], opensearch_index: str ) -> str: bulk_data = [] diff --git a/data_management/opensearch_indexer/opensearch_indexer/text_extraction.py b/data_management/opensearch_indexer/opensearch_indexer/text_extraction.py index adfd22df..214d1874 100644 --- a/data_management/opensearch_indexer/opensearch_indexer/text_extraction.py +++ b/data_management/opensearch_indexer/opensearch_indexer/text_extraction.py @@ -47,24 +47,23 @@ class TextExtractionStatus(Enum): def add_text_content(file: Dict, file_stream: bytes) -> Dict: file_type = file["file_name"].split(".")[-1].lower() + file_id = file["file_id"] if file_type not in SUPPORTED_TEXTRACT_FORMATS: logger.info( - f"Text extraction skipped for unsupported file type: {file_type}" + f"Text extraction skipped for file {file_id} due to unsupported file type: {file_type}" ) file["content"] = "" file["text_extraction_status"] = TextExtractionStatus.SKIPPED.value else: try: file["content"] = extract_text(file_stream, file_type) - logger.info(f"Text extraction succeeded for file {file['file_id']}") + logger.info(f"Text extraction succeeded for file {file_id}") file["text_extraction_status"] = ( TextExtractionStatus.SUCCEEDED.value ) except Exception as e: - logger.error( - f"Text extraction failed for file {file['file_id']}: {e}" - ) + logger.error(f"Text extraction failed for file {file_id}: {e}") file["content"] = "" file["text_extraction_status"] = TextExtractionStatus.FAILED.value diff --git a/data_management/opensearch_indexer/tests/conftest.py b/data_management/opensearch_indexer/tests/conftest.py index e482a2de..465dba1c 100644 --- a/data_management/opensearch_indexer/tests/conftest.py +++ b/data_management/opensearch_indexer/tests/conftest.py @@ -86,7 +86,7 @@ def temp_db(): return engine -@pytest.fixture(scope="session") +@pytest.fixture(scope="function") def database(request): # Launch new PostgreSQL server postgresql = PostgresqlFactory(cache_initialized_db=True)() diff --git a/data_management/opensearch_indexer/tests/test_bulk_index_files_in_opensearch.py b/data_management/opensearch_indexer/tests/test_bulk_index_files_in_opensearch.py index ef8efa9b..f1be8fee 100644 --- a/data_management/opensearch_indexer/tests/test_bulk_index_files_in_opensearch.py +++ b/data_management/opensearch_indexer/tests/test_bulk_index_files_in_opensearch.py @@ -1,11 +1,26 @@ import re from unittest import mock +from uuid import uuid4 +import boto3 import pytest +import sqlalchemy +import sqlalchemy.exc +from moto import mock_aws from opensearch_indexer.index_consignment.bulk_index_consignment import ( + ConsignmentBulkIndexError, + bulk_index_consignment, bulk_index_files_in_opensearch, + construct_documents, + fetch_files_in_consignment, + format_bulk_indexing_error_message, + validate_text_extraction, ) from opensearchpy import RequestsHttpConnection +from sqlalchemy import create_engine, text +from sqlalchemy.orm import sessionmaker + +from .conftest import Body, Consignment, File, Series @mock.patch( @@ -335,3 +350,314 @@ def test_index_file_content_and_metadata_in_opensearch_with_bulk_api_exception( assert [rec.message for rec in caplog.records] == [ "Opensearch bulk indexing call failed: Simulated OpenSearch bulk API failure" ] + + +def test_validate_text_extraction_with_all_success_or_skipped_status(): + """ + Given a list of documents where all have either 'SUCCEEDED' or 'SKIPPED' text extraction status, + When validate_text_extraction is called, + Then it should return None, indicating no errors. + """ + documents = [ + { + "file_id": "1", + "document": { + "file_id": "1", + "content": "Test file content", + "text_extraction_status": "SUCCEEDED", + }, + }, + { + "file_id": "2", + "document": { + "file_id": "2", + "text_extraction_status": "SKIPPED", + }, + }, + ] + assert validate_text_extraction(documents) is None + + +def test_validate_text_extraction_with_failed_status_returns_error_message(): + """ + Given a list of documents where some have 'FAILED' text extraction status, + When validate_text_extraction is called, + Then it should return an error message listing the file IDs with failed status. + """ + documents = [ + { + "file_id": "1", + "document": { + "file_id": "1", + "content": "Test file content", + "text_extraction_status": "SUCCEEDED", + }, + }, + { + "file_id": "2", + "document": { + "file_id": "2", + "text_extraction_status": "FAILED", + }, + }, + { + "file_id": "3", + "document": { + "file_id": "3", + "text_extraction_status": "FAILED", + }, + }, + ] + expected_error_message = ( + "Text extraction failed on the following documents:\n2\n3" + ) + assert validate_text_extraction(documents) == expected_error_message + + +def test_format_bulk_indexing_error_message_with_both_errors(): + """ + Given a consignment reference, a text extraction error message, and a bulk index error message, + When format_bulk_indexing_error_message is called, + Then it should return a detailed error message including both errors. + """ + consignment_reference = "test-consignment" + text_extraction_error = "Failed to extract text for documents: 1, 2" + bulk_index_error = "Bulk index timeout occurred for documents: 3, 4" + + result = format_bulk_indexing_error_message( + consignment_reference, text_extraction_error, bulk_index_error + ) + + expected_message = ( + "Bulk indexing failed for consignment test-consignment:" + "\nText Extraction Errors:\nFailed to extract text for documents: 1, 2" + "\nBulk Index Errors:\nBulk index timeout occurred for documents: 3, 4" + ) + assert result == expected_message + + +def test_format_bulk_indexing_error_message_with_only_text_extraction_error(): + """ + Given a consignment reference and a text extraction error message, + When format_bulk_indexing_error_message is called with no bulk index error, + Then it should return a detailed error message including only the text extraction error. + """ + consignment_reference = "test-consignment" + text_extraction_error = "Failed to extract text for documents: 1, 2" + bulk_index_error = None + + result = format_bulk_indexing_error_message( + consignment_reference, text_extraction_error, bulk_index_error + ) + + expected_message = ( + "Bulk indexing failed for consignment test-consignment:" + "\nText Extraction Errors:\nFailed to extract text for documents: 1, 2" + ) + assert result == expected_message + + +def test_format_bulk_indexing_error_message_with_only_bulk_index_error(): + """ + Given a consignment reference and a bulk index error message, + When format_bulk_indexing_error_message is called with no text extraction error, + Then it should return a detailed error message including only the bulk index error. + """ + consignment_reference = "test-consignment" + text_extraction_error = None + bulk_index_error = "Bulk index timeout occurred for documents: 3, 4" + + result = format_bulk_indexing_error_message( + consignment_reference, text_extraction_error, bulk_index_error + ) + + expected_message = ( + "Bulk indexing failed for consignment test-consignment:" + "\nBulk Index Errors:\nBulk index timeout occurred for documents: 3, 4" + ) + assert result == expected_message + + +def test_format_bulk_indexing_error_message_with_no_errors(): + """ + Given a consignment reference with no text extraction error or bulk index error, + When format_bulk_indexing_error_message is called, + Then it should return an error message with no specific error details. + """ + consignment_reference = "test-consignment" + text_extraction_error = None + bulk_index_error = None + + result = format_bulk_indexing_error_message( + consignment_reference, text_extraction_error, bulk_index_error + ) + + expected_message = "Bulk indexing failed for consignment test-consignment:" + assert result == expected_message + + +def test_fetch_files_in_consignment_invalid_column_with_logging( + database, caplog +): + """ + Given a databse without correctly setup tables + When fetch_files_in_consignment is called with a consignment reference and the database url, + Then it should log an error indicating the failure to retrieve file metadata. + """ + database_url = database.url() + # Create the engine and session + engine = create_engine(database_url) + + # Minimal setup, omitting necessary schema parts + with engine.connect() as connection: + connection.execute( + text( + """ + CREATE TABLE Consignment ( + ConsignmentId SERIAL PRIMARY KEY, + ConsignmentReference TEXT + ); + """ + ) + ) + + consignment_reference = "reference-123" + with pytest.raises(sqlalchemy.exc.ProgrammingError): + fetch_files_in_consignment(consignment_reference, database_url) + + assert ( + "Failed to retrieve file metadata for consignment reference: reference-123" + in caplog.text + ) + + +@mock_aws +def test_construct_documents_s3_error_logging(caplog): + """ + Given a list of file metadata and a bucket name, + When the function attempts to retrieve files from S3 and encounters an error (e.g., file not found), + Then it should log the error and raise an exception. + """ + files = [ + { + "file_id": 1, + "consignment_reference": "consignment-123", + }, + ] + bucket_name = "test-bucket" + + # Create the mock bucket + s3 = boto3.client("s3", region_name="us-east-1") + s3.create_bucket(Bucket=bucket_name) + + # No files are uploaded to the bucket, so retrieving any file should fail. + # The file we are trying to access does not exist in the mocked bucket. + + with pytest.raises(Exception): + with caplog.at_level("ERROR"): + construct_documents(files, bucket_name) + + # Assert that the correct error message is logged + assert ( + "Failed to obtain file consignment-123/1: An error occurred (NoSuchKey) when calling " + "the GetObject operation: The specified key does not exist." + ) in caplog.text + + +# Test for bulk_index_consignment error handling when bulk index operation fails +@mock.patch( + "opensearch_indexer.index_consignment.bulk_index_consignment.bulk_index_files_in_opensearch" +) +@mock_aws +def test_bulk_index_consignment_error_handling( + mock_bulk_index_files_in_opensearch, caplog, database +): + """ + Given a consignment reference and files, + When the bulk indexing fails, + Then it should log the error and raise a ConsignmentBulkIndexError. + """ + + # Prepare input data + database_url = database.url() + consignment_reference = "consignment-123" + bucket_name = "test-bucket" + open_search_host_url = "https://opensearch.example.com" + open_search_http_auth = ("username", "password") + open_search_bulk_index_timeout = 30 + + # Insert mock data into PostgreSQL + engine = create_engine(database_url) + from data_management.opensearch_indexer.tests.conftest import Base + + Base.metadata.create_all(engine) # Create tables for the test + Session = sessionmaker(bind=engine) + session = Session() + + body_id = uuid4() + series_id = uuid4() + consignment_id = uuid4() + + consignment_reference = "TDR-2024-ABCD" + + file_id = uuid4() + + session.add_all( + [ + Consignment( + ConsignmentId=consignment_id, + ConsignmentType="foo", + ConsignmentReference=consignment_reference, + SeriesId=series_id, + ), + Series(SeriesId=series_id, Name="series-name", BodyId=body_id), + Body( + BodyId=body_id, + Name="body-name", + Description="transferring body description", + ), + File( + FileId=file_id, + FileType="File", + FileName="test-document.txt", + FileReference="file-123", + FilePath="/path/to/file", + CiteableReference="cite-ref-123", + ConsignmentId=consignment_id, + ), + ] + ) + # Insert mock consignment and file data + session.commit() + + # Upload a file to S3 (mocked by moto) + s3 = boto3.client("s3", region_name="us-east-1") + s3.create_bucket(Bucket=bucket_name) + s3.put_object( + Bucket=bucket_name, + Key=f"{consignment_reference}/{file_id}", + Body=b"Test file content", + ) + + # Mock OpenSearch bulk indexing to raise an error + mock_bulk_index_files_in_opensearch.side_effect = Exception( + "Some opensearch bulk indexing error string." + ) + + # Run the function and capture logs + with pytest.raises( + ConsignmentBulkIndexError, + match=( + "Bulk indexing failed for consignment TDR-2024-ABCD:\n" + "Bulk Index Errors:\nSome opensearch bulk indexing error string." + ), + ): + with caplog.at_level("ERROR"): + bulk_index_consignment( + consignment_reference, + bucket_name, + database_url, + open_search_host_url, + open_search_http_auth, + open_search_bulk_index_timeout, + ) diff --git a/data_management/opensearch_indexer/tests/test_consignment_lambda_handler.py b/data_management/opensearch_indexer/tests/test_consignment_lambda_handler.py index ad6f1c39..23fa9e7d 100644 --- a/data_management/opensearch_indexer/tests/test_consignment_lambda_handler.py +++ b/data_management/opensearch_indexer/tests/test_consignment_lambda_handler.py @@ -4,6 +4,7 @@ import boto3 import botocore +import pytest from moto import mock_aws from opensearch_indexer.index_consignment.lambda_function import lambda_handler from opensearch_indexer.text_extraction import TextExtractionStatus @@ -316,3 +317,105 @@ def test_lambda_handler_invokes_bulk_index_with_correct_file_data( assert args[3] == 600 assert args[4] is None + + +@mock_aws +def test_lambda_handler_raises_exception_when_no_consignment_reference_in_sns_message(): + """ + Test case for the lambda_handler function to ensure correct integration with the OpenSearch indexer. + + Given: + - An S3 bucket containing files. + - A secret stored in AWS Secrets Manager containing configuration details such as database connection, + OpenSearch host URL, and an IAM role for OpenSearch access. + + When: + - The lambda_handler function is invoked via an S3 event notification. + + Then: + - The bulk_index_files_in_opensearch function is called with the correct parameters for each file: + - Correct file metadata and content, including the extracted text, metadata properties, + and associated consignment details. + - The OpenSearch host URL. + - An AWS4Auth object with credentials derived from the assumed IAM role. + - The timeout for the OpenSearch bulk indexing operation. + """ + consignment_reference = None + sns_message = { + "properties": { + "messageType": "uk.gov.nationalarchives.da.messages.ayrmetadata.loaded", + "function": "ddt-ayrmetadataload-process", + }, + "parameters": { + "reference": consignment_reference, + "originator": "DDT", + }, + } + + event = { + "Records": [ + { + "Sns": { + "Message": json.dumps(sns_message), + }, + } + ] + } + + with pytest.raises( + Exception, + match="Missing reference in SNS Message required for indexing", + ): + lambda_handler(event, None) + + +@mock_aws +def test_lambda_handler_raises_exception_when_no_secret_id_env_var_set( + monkeypatch, +): + """ + Test case for the lambda_handler function to ensure correct integration with the OpenSearch indexer. + + Given: + - An S3 bucket containing files. + - A secret stored in AWS Secrets Manager containing configuration details such as database connection, + OpenSearch host URL, and an IAM role for OpenSearch access. + + When: + - The lambda_handler function is invoked via an S3 event notification. + + Then: + - The bulk_index_files_in_opensearch function is called with the correct parameters for each file: + - Correct file metadata and content, including the extracted text, metadata properties, + and associated consignment details. + - The OpenSearch host URL. + - An AWS4Auth object with credentials derived from the assumed IAM role. + - The timeout for the OpenSearch bulk indexing operation. + """ + consignment_reference = "TDR-2024-ABCD" + sns_message = { + "properties": { + "messageType": "uk.gov.nationalarchives.da.messages.ayrmetadata.loaded", + "function": "ddt-ayrmetadataload-process", + }, + "parameters": { + "reference": consignment_reference, + "originator": "DDT", + }, + } + + event = { + "Records": [ + { + "Sns": { + "Message": json.dumps(sns_message), + }, + } + ] + } + + with pytest.raises( + Exception, + match="Missing SECRET_ID environment variable required for indexing", + ): + lambda_handler(event, None) diff --git a/data_management/opensearch_indexer/tests/test_text_extraction.py b/data_management/opensearch_indexer/tests/test_text_extraction.py index abec5574..cb050961 100644 --- a/data_management/opensearch_indexer/tests/test_text_extraction.py +++ b/data_management/opensearch_indexer/tests/test_text_extraction.py @@ -1,7 +1,12 @@ from pathlib import Path +from unittest.mock import patch import pytest -from opensearch_indexer.text_extraction import extract_text +from opensearch_indexer.text_extraction import ( + TextExtractionStatus, + add_text_content, + extract_text, +) class TestExtractText: @@ -79,3 +84,143 @@ def test_extract_text(self, file_name, file_type, expected_output): file_stream = file.read() assert extract_text(file_stream, file_type) == expected_output + + +# Sample supported and unsupported file formats +SUPPORTED_TEXTRACT_FORMATS = [ + "pdf", + "txt", + "docx", +] # Example of supported formats + + +# Mock the extract_text function to simulate text extraction behavior +@pytest.fixture +def mock_extract_text(): + with patch("opensearch_indexer.text_extraction.extract_text") as mock: + yield mock + + +# Test for successfully extracting text from a supported file type +def test_add_text_content_success(mock_extract_text, caplog): + """ + Given a supported file type and a valid file stream, + When text extraction succeeds, + Then the content is updated and the status is set to SUCCEEDED. + """ + + # Given + file = { + "file_id": 1, + "file_name": "example.pdf", + "content": "", + "text_extraction_status": "", + } + file_stream = b"Some file content" + mock_extract_text.return_value = "Extracted text" + + # When + result = add_text_content(file, file_stream) + + # Then + assert result["content"] == "Extracted text" + assert ( + result["text_extraction_status"] == TextExtractionStatus.SUCCEEDED.value + ) + mock_extract_text.assert_called_once_with(file_stream, "pdf") + + assert "Text extraction succeeded for file 1" in caplog.text + + +# Test for unsupported file type +def test_add_text_content_unsupported_format(caplog): + """ + Given an unsupported file type, + When text extraction is skipped, + Then the content is set to an empty string and the status is set to SKIPPED. + """ + + # Given + file = { + "file_id": 2, + "file_name": "example.exe", # Unsupported file type + "content": "", + "text_extraction_status": "", + } + file_stream = b"Some content that won't be extracted" + + # When + result = add_text_content(file, file_stream) + + # Then + assert result["content"] == "" + assert ( + result["text_extraction_status"] == TextExtractionStatus.SKIPPED.value + ) + + assert ( + "Text extraction skipped for file 2 due to unsupported file type: exe" + in caplog.text + ) + + +# Test for text extraction failure +def test_add_text_content_failure(mock_extract_text, caplog): + """ + Given a supported file type and a failing text extraction, + When text extraction fails due to an error, + Then the content is set to an empty string and the status is set to FAILED. + """ + + # Given + file = { + "file_id": 3, + "file_name": "example.txt", # Supported file type + "content": "", + "text_extraction_status": "", + } + file_stream = b"Some content" + + # Simulate a failure in text extraction + mock_extract_text.side_effect = Exception("Text extraction failed") + + # When + with caplog.at_level("ERROR"): + result = add_text_content(file, file_stream) + + # Then + assert result["content"] == "" + assert result["text_extraction_status"] == TextExtractionStatus.FAILED.value + mock_extract_text.assert_called_once_with(file_stream, "txt") + + assert ( + "Text extraction failed for file 3: Text extraction failed" + in caplog.text + ) + + +# Test for file type without extension +def test_add_text_content_no_extension(): + """ + Given a file without an extension, + When trying to extract text, + Then the file is skipped and the status is set to SKIPPED. + """ + + # Given + file = { + "file_id": 4, + "file_name": "example", # No file extension + "content": "", + "text_extraction_status": "", + } + file_stream = b"Some content" + + # When + result = add_text_content(file, file_stream) + + # Then + assert result["content"] == "" + assert ( + result["text_extraction_status"] == TextExtractionStatus.SKIPPED.value + )