Skip to content

Commit

Permalink
community[minor]: AWS Athena Document Loader (#15625)
Browse files Browse the repository at this point in the history
- **Description:** Adds the document loader for [AWS
Athena](https://aws.amazon.com/athena/), a serverless and interactive
analytics service.
  - **Dependencies:** Added boto3 as a dependency
  • Loading branch information
abhijeethp authored Feb 12, 2024
1 parent 93da18b commit 584b647
Show file tree
Hide file tree
Showing 4 changed files with 280 additions and 0 deletions.
110 changes: 110 additions & 0 deletions docs/docs/integrations/document_loaders/athena.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "MwTWzDxYgbrR"
},
"source": [
"# Athena\n",
"\n",
"This notebook goes over how to load documents from AWS Athena."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "F0zaLR3xgWmO"
},
"outputs": [],
"source": [
"! pip install boto3"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "076NLjfngoWJ"
},
"outputs": [],
"source": [
"from langchain_community.document_loaders.athena import AthenaLoader"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "XpMRQwU9gu44"
},
"outputs": [],
"source": [
"database_name = \"my_database\"\n",
"s3_output_path = \"s3://my_bucket/query_results/\"\n",
"query = \"SELECT * FROM my_table\"\n",
"profile_name = \"my_profile\"\n",
"\n",
"loader = AthenaLoader(\n",
" query=query,\n",
" database=database_name,\n",
" s3_output_uri=s3_output_path,\n",
" profile_name=profile_name,\n",
")\n",
"\n",
"documents = loader.load()\n",
"print(documents)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5IBapL3ejoEt"
},
"source": [
"Example with metadata columns"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "wMx6nI1qjryD"
},
"outputs": [],
"source": [
"database_name = \"my_database\"\n",
"s3_output_path = \"s3://my_bucket/query_results/\"\n",
"query = \"SELECT * FROM my_table\"\n",
"profile_name = \"my_profile\"\n",
"metadata_columns = [\"_row\", \"_created_at\"]\n",
"\n",
"loader = AthenaLoader(\n",
" query=query,\n",
" database=database_name,\n",
" s3_output_uri=s3_output_path,\n",
" profile_name=profile_name,\n",
" metadata_columns=metadata_columns,\n",
")\n",
"\n",
"documents = loader.load()\n",
"print(documents)"
]
}
],
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
)
from langchain_community.document_loaders.astradb import AstraDBLoader
from langchain_community.document_loaders.async_html import AsyncHtmlLoader
from langchain_community.document_loaders.athena import AthenaLoader
from langchain_community.document_loaders.azlyrics import AZLyricsLoader
from langchain_community.document_loaders.azure_ai_data import (
AzureAIDataLoader,
Expand Down Expand Up @@ -257,6 +258,7 @@
"AssemblyAIAudioTranscriptLoader",
"AstraDBLoader",
"AsyncHtmlLoader",
"AthenaLoader",
"AzureAIDataLoader",
"AzureAIDocumentIntelligenceLoader",
"AzureBlobStorageContainerLoader",
Expand Down
167 changes: 167 additions & 0 deletions libs/community/langchain_community/document_loaders/athena.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
from __future__ import annotations

import io
import json
import time
from typing import Any, Dict, Iterator, List, Optional, Tuple

from langchain_core.documents import Document

from langchain_community.document_loaders.base import BaseLoader


class AthenaLoader(BaseLoader):
    """Load documents from `AWS Athena`.

    Each document represents one row of the result.
    - By default, all columns are written into the `page_content` of the document
      and none into the `metadata` of the document.
    - If `metadata_columns` are provided then these columns are written
      into the `metadata` of the document while the rest of the columns
      are written into the `page_content` of the document.

    To authenticate, the AWS client uses this method to automatically load credentials:
    https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html

    If a specific credential profile should be used, you must pass
    the name of the profile from the ~/.aws/credentials file that is to be used.

    Make sure the credentials / roles used have the required policies to
    access the Amazon Athena service.
    """

    def __init__(
        self,
        query: str,
        database: str,
        s3_output_uri: str,
        profile_name: Optional[str] = None,
        metadata_columns: Optional[List[str]] = None,
    ):
        """Initialize Athena document loader.

        Args:
            query: The query to run in Athena.
            database: Athena database name.
            s3_output_uri: S3 URI where Athena writes its query results.
            profile_name: Optional. AWS credentials profile to use; when ``None``
                the default boto3 credential resolution chain is used.
            metadata_columns: Optional. Columns written to Document `metadata`.

        Raises:
            ModuleNotFoundError: If boto3 is not installed.
            ValueError: If an AWS session cannot be created from the given
                credentials/profile.
        """
        self.query = query
        self.database = database
        self.s3_output_uri = s3_output_uri
        self.metadata_columns = metadata_columns if metadata_columns is not None else []

        try:
            import boto3
        except ImportError:
            raise ModuleNotFoundError(
                "Could not import boto3 python package. "
                "Please install it with `pip install boto3`."
            )

        try:
            session = (
                boto3.Session(profile_name=profile_name)
                if profile_name is not None
                else boto3.Session()
            )
        except Exception as e:
            raise ValueError(
                "Could not load credentials to authenticate with AWS client. "
                "Please check that credentials in the specified "
                "profile name are valid."
            ) from e

        self.athena_client = session.client("athena")
        self.s3_client = session.client("s3")

    def _execute_query(self) -> List[Dict[str, Any]]:
        """Run the configured query and return the result rows as dicts.

        Starts the query, then polls Athena once per second until the
        execution reaches a terminal state.

        Returns:
            One dict per result row, keyed by column name.

        Raises:
            Exception: If the query fails or is cancelled.
        """
        response = self.athena_client.start_query_execution(
            QueryString=self.query,
            QueryExecutionContext={"Database": self.database},
            ResultConfiguration={"OutputLocation": self.s3_output_uri},
        )
        query_execution_id = response["QueryExecutionId"]
        while True:
            response = self.athena_client.get_query_execution(
                QueryExecutionId=query_execution_id
            )
            state = response["QueryExecution"]["Status"]["State"]
            if state == "SUCCEEDED":
                break
            elif state == "FAILED":
                # StateChangeReason is not guaranteed to be present on every
                # failed response, so fall back to a generic message.
                resp_status = response["QueryExecution"]["Status"]
                state_change_reason = resp_status.get(
                    "StateChangeReason", "Unknown reason"
                )
                raise Exception(f"Query Failed: {state_change_reason}")
            elif state == "CANCELLED":
                raise Exception("Query was cancelled by the user.")
            # Still QUEUED or RUNNING; wait before polling again.
            time.sleep(1)

        result_set = self._get_result_set(query_execution_id)
        return json.loads(result_set.to_json(orient="records"))

    def _remove_suffix(self, input_string: str, suffix: str) -> str:
        """Return `input_string` with a trailing `suffix` removed, if present."""
        if suffix and input_string.endswith(suffix):
            return input_string[: -len(suffix)]
        return input_string

    def _remove_prefix(self, input_string: str, prefix: str) -> str:
        """Return `input_string` with a leading `prefix` removed, if present."""
        if prefix and input_string.startswith(prefix):
            return input_string[len(prefix) :]
        return input_string

    def _get_result_set(self, query_execution_id: str) -> Any:
        """Download the query's CSV result from S3 into a pandas DataFrame.

        Athena writes results to ``<s3_output_uri>/<query_execution_id>.csv``.

        Raises:
            ModuleNotFoundError: If pandas is not installed.
        """
        try:
            import pandas as pd
        except ImportError:
            raise ModuleNotFoundError(
                "Could not import pandas python package. "
                "Please install it with `pip install pandas`."
            )

        # Split "s3://bucket/prefix/" into bucket and key components.
        tokens = self._remove_prefix(
            self._remove_suffix(self.s3_output_uri, "/"), "s3://"
        ).split("/")
        bucket = tokens[0]
        key = "/".join(tokens[1:]) + "/" + query_execution_id + ".csv"

        obj = self.s3_client.get_object(Bucket=bucket, Key=key)
        return pd.read_csv(io.BytesIO(obj["Body"].read()), encoding="utf8")

    def _get_columns(
        self, query_result: List[Dict[str, Any]]
    ) -> Tuple[List[str], List[str]]:
        """Split the result's columns into (content_columns, metadata_columns).

        Columns listed in ``self.metadata_columns`` go to metadata; all
        others go to page content. An empty result yields two empty lists
        instead of raising IndexError.
        """
        if not query_result:
            return [], []
        content_columns: List[str] = []
        metadata_columns: List[str] = []
        for key in query_result[0]:
            if key in self.metadata_columns:
                metadata_columns.append(key)
            else:
                content_columns.append(key)

        return content_columns, metadata_columns

    def lazy_load(self) -> Iterator[Document]:
        """Lazily yield one Document per result row.

        Metadata values that are None are dropped from the metadata dict.
        """
        query_result = self._execute_query()
        content_columns, metadata_columns = self._get_columns(query_result)
        for row in query_result:
            page_content = "\n".join(
                f"{k}: {v}" for k, v in row.items() if k in content_columns
            )
            metadata = {
                k: v for k, v in row.items() if k in metadata_columns and v is not None
            }
            yield Document(page_content=page_content, metadata=metadata)

    def load(self) -> List[Document]:
        """Load data into document objects."""
        return list(self.lazy_load())
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
"AssemblyAIAudioTranscriptLoader",
"AstraDBLoader",
"AsyncHtmlLoader",
"AthenaLoader",
"AzureAIDataLoader",
"AzureAIDocumentIntelligenceLoader",
"AzureBlobStorageContainerLoader",
Expand Down

0 comments on commit 584b647

Please sign in to comment.