From 584b647b964825418713b78d41dcb3ef3a536e4f Mon Sep 17 00:00:00 2001
From: Abhijeeth Padarthi
Date: Mon, 12 Feb 2024 14:53:40 -0600
Subject: [PATCH] community[minor]: AWS Athena Document Loader (#15625)

- **Description:** Adds a document loader for [AWS Athena](https://aws.amazon.com/athena/), a serverless, interactive analytics service.
- **Dependencies:** Added `boto3` as a dependency.
---
 .../document_loaders/athena.ipynb             | 110 ++++++++++++
 .../document_loaders/__init__.py              |   2 +
 .../document_loaders/athena.py                | 170 ++++++++++++++++++
 .../document_loaders/test_imports.py          |   1 +
 4 files changed, 283 insertions(+)
 create mode 100644 docs/docs/integrations/document_loaders/athena.ipynb
 create mode 100644 libs/community/langchain_community/document_loaders/athena.py

diff --git a/docs/docs/integrations/document_loaders/athena.ipynb b/docs/docs/integrations/document_loaders/athena.ipynb
new file mode 100644
index 0000000000000..3fd655c0f0a04
--- /dev/null
+++ b/docs/docs/integrations/document_loaders/athena.ipynb
@@ -0,0 +1,110 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "MwTWzDxYgbrR"
+   },
+   "source": [
+    "# Athena\n",
+    "\n",
+    "This notebook goes over how to load documents from AWS Athena."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "F0zaLR3xgWmO"
+   },
+   "outputs": [],
+   "source": [
+    "! pip install boto3"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "076NLjfngoWJ"
+   },
+   "outputs": [],
+   "source": [
+    "from langchain_community.document_loaders.athena import AthenaLoader"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "XpMRQwU9gu44"
+   },
+   "outputs": [],
+   "source": [
+    "database_name = \"my_database\"\n",
+    "s3_output_path = \"s3://my_bucket/query_results/\"\n",
+    "query = \"SELECT * FROM my_table\"\n",
+    "profile_name = \"my_profile\"\n",
+    "\n",
+    "loader = AthenaLoader(\n",
+    "    query=query,\n",
+    "    database=database_name,\n",
+    "    s3_output_uri=s3_output_path,\n",
+    "    profile_name=profile_name,\n",
+    ")\n",
+    "\n",
+    "documents = loader.load()\n",
+    "print(documents)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "5IBapL3ejoEt"
+   },
+   "source": [
+    "Example with metadata columns"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "wMx6nI1qjryD"
+   },
+   "outputs": [],
+   "source": [
+    "database_name = \"my_database\"\n",
+    "s3_output_path = \"s3://my_bucket/query_results/\"\n",
+    "query = \"SELECT * FROM my_table\"\n",
+    "profile_name = \"my_profile\"\n",
+    "metadata_columns = [\"_row\", \"_created_at\"]\n",
+    "\n",
+    "loader = AthenaLoader(\n",
+    "    query=query,\n",
+    "    database=database_name,\n",
+    "    s3_output_uri=s3_output_path,\n",
+    "    profile_name=profile_name,\n",
+    "    metadata_columns=metadata_columns,\n",
+    ")\n",
+    "\n",
+    "documents = loader.load()\n",
+    "print(documents)"
+   ]
+  }
+ ],
+ "metadata": {
+  "colab": {
+   "provenance": []
+  },
+  "kernelspec": {
+   "display_name": "Python 3",
+   "name": "python3"
+  },
+  "language_info": {
+   "name": "python"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/libs/community/langchain_community/document_loaders/__init__.py b/libs/community/langchain_community/document_loaders/__init__.py
index 869b0dca0329b..ac1a197b3d259 100644
--- a/libs/community/langchain_community/document_loaders/__init__.py
+++ b/libs/community/langchain_community/document_loaders/__init__.py
@@ -36,6 +36,7 @@
 )
 from langchain_community.document_loaders.astradb import AstraDBLoader
 from langchain_community.document_loaders.async_html import AsyncHtmlLoader
+from langchain_community.document_loaders.athena import AthenaLoader
 from langchain_community.document_loaders.azlyrics import AZLyricsLoader
 from langchain_community.document_loaders.azure_ai_data import (
     AzureAIDataLoader,
@@ -257,6 +258,7 @@
     "AssemblyAIAudioTranscriptLoader",
     "AstraDBLoader",
     "AsyncHtmlLoader",
+    "AthenaLoader",
     "AzureAIDataLoader",
     "AzureAIDocumentIntelligenceLoader",
     "AzureBlobStorageContainerLoader",
diff --git a/libs/community/langchain_community/document_loaders/athena.py b/libs/community/langchain_community/document_loaders/athena.py
new file mode 100644
index 0000000000000..1e33062c194ee
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/athena.py
@@ -0,0 +1,170 @@
+from __future__ import annotations
+
+import io
+import json
+import time
+from typing import Any, Dict, Iterator, List, Optional, Tuple
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+
+
+class AthenaLoader(BaseLoader):
+    """Load documents from `AWS Athena`.
+
+    Each document represents one row of the result.
+    - By default, all columns are written into the `page_content` of the document
+      and none into the `metadata` of the document.
+    - If `metadata_columns` are provided then these columns are written
+      into the `metadata` of the document while the rest of the columns
+      are written into the `page_content` of the document.
+
+    To authenticate, the AWS client uses this method to automatically load credentials:
+    https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
+
+    If a specific credential profile should be used, you must pass
+    the name of the profile from the ~/.aws/credentials file that is to be used.
+
+    Make sure the credentials / roles used have the required policies to
+    access the Amazon Athena service and the S3 output location.
+    """
+
+    def __init__(
+        self,
+        query: str,
+        database: str,
+        s3_output_uri: str,
+        profile_name: Optional[str] = None,
+        metadata_columns: Optional[List[str]] = None,
+    ):
+        """Initialize Athena document loader.
+
+        Args:
+            query: The query to run in Athena.
+            database: Athena database name.
+            s3_output_uri: S3 URI where Athena writes the query results.
+            profile_name: Optional. AWS profile from the ~/.aws/credentials file.
+            metadata_columns: Optional. Columns written to Document `metadata`.
+        """
+        self.query = query
+        self.database = database
+        self.s3_output_uri = s3_output_uri
+        self.metadata_columns = metadata_columns if metadata_columns is not None else []
+
+        try:
+            import boto3
+        except ImportError:
+            raise ModuleNotFoundError(
+                "Could not import boto3 python package. "
+                "Please install it with `pip install boto3`."
+            )
+
+        try:
+            session = (
+                boto3.Session(profile_name=profile_name)
+                if profile_name is not None
+                else boto3.Session()
+            )
+        except Exception as e:
+            raise ValueError(
+                "Could not load credentials to authenticate with AWS client. "
+                "Please check that credentials in the specified "
+                "profile name are valid."
+            ) from e
+
+        self.athena_client = session.client("athena")
+        self.s3_client = session.client("s3")
+
+    def _execute_query(self) -> List[Dict[str, Any]]:
+        response = self.athena_client.start_query_execution(
+            QueryString=self.query,
+            QueryExecutionContext={"Database": self.database},
+            ResultConfiguration={"OutputLocation": self.s3_output_uri},
+        )
+        query_execution_id = response["QueryExecutionId"]
+        print(f"Query : {self.query}")
+        while True:
+            response = self.athena_client.get_query_execution(
+                QueryExecutionId=query_execution_id
+            )
+            state = response["QueryExecution"]["Status"]["State"]
+            if state == "SUCCEEDED":
+                print(f"State : {state}")
+                break
+            elif state == "FAILED":
+                resp_status = response["QueryExecution"]["Status"]
+                state_change_reason = resp_status["StateChangeReason"]
+                err = f"Query Failed: {state_change_reason}"
+                raise Exception(err)
+            elif state == "CANCELLED":
+                raise Exception("Query was cancelled by the user.")
+            else:
+                print(f"State : {state}")
+                time.sleep(1)
+
+        result_set = self._get_result_set(query_execution_id)
+        return json.loads(result_set.to_json(orient="records"))
+
+    def _remove_suffix(self, input_string: str, suffix: str) -> str:
+        if suffix and input_string.endswith(suffix):
+            return input_string[: -len(suffix)]
+        return input_string
+
+    def _remove_prefix(self, input_string: str, prefix: str) -> str:
+        if prefix and input_string.startswith(prefix):
+            return input_string[len(prefix) :]
+        return input_string
+
+    def _get_result_set(self, query_execution_id: str) -> Any:
+        try:
+            import pandas as pd
+        except ImportError:
+            raise ModuleNotFoundError(
+                "Could not import pandas python package. "
+                "Please install it with `pip install pandas`."
+            )
+
+        output_uri = self.s3_output_uri
+        tokens = self._remove_prefix(
+            self._remove_suffix(output_uri, "/"), "s3://"
+        ).split("/")
+        bucket = tokens[0]
+        key = "/".join(tokens[1:]) + "/" + query_execution_id + ".csv"
+
+        obj = self.s3_client.get_object(Bucket=bucket, Key=key)
+        df = pd.read_csv(io.BytesIO(obj["Body"].read()), encoding="utf8")
+        return df
+
+    def _get_columns(
+        self, query_result: List[Dict[str, Any]]
+    ) -> Tuple[List[str], List[str]]:
+        content_columns = []
+        metadata_columns = []
+        if not query_result:
+            return content_columns, metadata_columns
+        all_columns = list(query_result[0].keys())
+        for key in all_columns:
+            if key in self.metadata_columns:
+                metadata_columns.append(key)
+            else:
+                content_columns.append(key)
+
+        return content_columns, metadata_columns
+
+    def lazy_load(self) -> Iterator[Document]:
+        query_result = self._execute_query()
+        content_columns, metadata_columns = self._get_columns(query_result)
+        for row in query_result:
+            page_content = "\n".join(
+                f"{k}: {v}" for k, v in row.items() if k in content_columns
+            )
+            metadata = {
+                k: v for k, v in row.items() if k in metadata_columns and v is not None
+            }
+            doc = Document(page_content=page_content, metadata=metadata)
+            yield doc
+
+    def load(self) -> List[Document]:
+        """Load data into document objects."""
+        return list(self.lazy_load())
diff --git a/libs/community/tests/unit_tests/document_loaders/test_imports.py b/libs/community/tests/unit_tests/document_loaders/test_imports.py
index d22f81aa19ba3..98865797e3b78 100644
--- a/libs/community/tests/unit_tests/document_loaders/test_imports.py
+++ b/libs/community/tests/unit_tests/document_loaders/test_imports.py
@@ -23,6 +23,7 @@
     "AssemblyAIAudioTranscriptLoader",
     "AstraDBLoader",
     "AsyncHtmlLoader",
+    "AthenaLoader",
     "AzureAIDataLoader",
     "AzureAIDocumentIntelligenceLoader",
     "AzureBlobStorageContainerLoader",
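
---

Usage note: besides `load()`, which the notebook demonstrates, the loader also exposes `lazy_load()`, which yields one `Document` per result row instead of materializing the whole list. A minimal sketch, reusing the hypothetical database, bucket, and profile names from the notebook above:

```python
from langchain_community.document_loaders.athena import AthenaLoader

# Hypothetical names, matching the notebook example above.
loader = AthenaLoader(
    query="SELECT * FROM my_table",
    database="my_database",
    s3_output_uri="s3://my_bucket/query_results/",
    profile_name="my_profile",
)

# lazy_load() yields Documents one row at a time, so large result
# sets can be processed without holding every row in memory at once.
for doc in loader.lazy_load():
    print(doc.page_content)
```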