Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

community: AWS Athena Document Loader #15625

Merged
merged 21 commits into from
Feb 12, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 110 additions & 0 deletions docs/docs/integrations/document_loaders/athena.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {

Check failure on line 5 in docs/docs/integrations/document_loaders/athena.ipynb

View workflow job for this annotation

GitHub Actions / lint / build (3.8)

Ruff (F541)

docs/docs/integrations/document_loaders/athena.ipynb:1:1: F541 f-string without any placeholders

Check failure on line 5 in docs/docs/integrations/document_loaders/athena.ipynb

View workflow job for this annotation

GitHub Actions / lint / build (3.11)

Ruff (F541)

docs/docs/integrations/document_loaders/athena.ipynb:1:1: F541 f-string without any placeholders
"id": "MwTWzDxYgbrR"
},
"source": [
"# Athena\n",
"\n",
"This notebook goes over how to load documents from AWS Athena"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "F0zaLR3xgWmO"
},

Check failure on line 19 in docs/docs/integrations/document_loaders/athena.ipynb

View workflow job for this annotation

GitHub Actions / lint / build (3.8)

Ruff (F541)

docs/docs/integrations/document_loaders/athena.ipynb:1:1: F541 f-string without any placeholders

Check failure on line 19 in docs/docs/integrations/document_loaders/athena.ipynb

View workflow job for this annotation

GitHub Actions / lint / build (3.11)

Ruff (F541)

docs/docs/integrations/document_loaders/athena.ipynb:1:1: F541 f-string without any placeholders
"outputs": [],
"source": [
"! pip install boto3"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "076NLjfngoWJ"
},
"outputs": [],
"source": [
"from langchain_community.document_loaders.athena import AthenaLoader"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "XpMRQwU9gu44"
},
"outputs": [],
"source": [
"database_name = \"my_database\"\n",
"s3_output_path = \"s3://my_bucket/query_results/\"\n",
"query = \"SELECT * FROM my_table\"\n",
"profile_name = \"my_profile\"\n",
"\n",
"loader = AthenaLoader(\n",
" query=query,\n",
" database=database_name,\n",
" s3_output_uri=s3_output_path,\n",
" profile_name=profile_name\n",
")\n",
"\n",
"documents = loader.load()\n",
"print(documents)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5IBapL3ejoEt"
},
"source": [
"Example with metadata columns"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "wMx6nI1qjryD"
},
"outputs": [],
"source": [
"database_name = \"my_database\"\n",
"s3_output_path = \"s3://my_bucket/query_results/\"\n",
"query = \"SELECT * FROM my_table\"\n",
"profile_name = \"my_profile\"\n",
"metadata_columns = [\"_row\", \"_created_at\"]\n",
"\n",
"loader = AthenaLoader(\n",
" query=query,\n",
" database=database_name,\n",
" s3_output_uri=s3_output_path,\n",
" profile_name=profile_name,\n",
" metadata_columns=metadata_columns\n",
")\n",
"\n",
"documents = loader.load()\n",
"print(documents)"
]
}
],
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
)
from langchain_community.document_loaders.astradb import AstraDBLoader
from langchain_community.document_loaders.async_html import AsyncHtmlLoader
from langchain_community.document_loaders.athena import AthenaLoader
from langchain_community.document_loaders.azlyrics import AZLyricsLoader
from langchain_community.document_loaders.azure_ai_data import (
AzureAIDataLoader,
Expand Down Expand Up @@ -252,6 +253,7 @@
"AssemblyAIAudioTranscriptLoader",
"AstraDBLoader",
"AsyncHtmlLoader",
"AthenaLoader",
"AzureAIDataLoader",
"AzureAIDocumentIntelligenceLoader",
"AzureBlobStorageContainerLoader",
Expand Down
132 changes: 132 additions & 0 deletions libs/community/langchain_community/document_loaders/athena.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
from __future__ import annotations

import io
import json
import time
from typing import Any, Dict, Iterator, List, Optional, Tuple

from langchain_core.documents import Document

from langchain_community.document_loaders.base import BaseLoader


class AthenaLoader(BaseLoader):
    """Load documents from `AWS Athena`.

    Each document represents one row of the result.
    - By default, all columns are written into the `page_content` of the document
      and none into the `metadata` of the document.
    - If `metadata_columns` are provided then these columns are written
      into the `metadata` of the document while the rest of the columns
      are written into the `page_content` of the document.

    To authenticate, the AWS client uses this method to automatically load credentials:
    https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html

    If a specific credential profile should be used, you must pass
    the name of the profile from the ~/.aws/credentials file that is to be used.

    Make sure the credentials / roles used have the required policies to
    run queries in Amazon Athena and to read the query results from the
    S3 output location.
    """

    def __init__(
        self,
        query: str,
        database: str,
        s3_output_uri: str,
        profile_name: Optional[str] = None,
        metadata_columns: Optional[List[str]] = None,
    ):
        """Initialize Athena document loader.

        Args:
            query: The query to run in Athena.
            database: Athena database.
            s3_output_uri: Athena output path (``s3://bucket/optional/prefix/``).
            profile_name: Optional. AWS credential profile to use. When
                omitted, boto3's default credential resolution is used.
            metadata_columns: Optional. Columns written to Document `metadata`.
        """
        self.query = query
        self.database = database
        self.s3_output_uri = s3_output_uri
        self.profile_name = profile_name
        self.metadata_columns = metadata_columns if metadata_columns is not None else []

    def _execute_query(self) -> List[Dict[str, Any]]:
        """Run the query in Athena and return its rows as a list of dicts.

        Starts the query, polls its state once per second until it reaches a
        terminal state, then reads the result set written to S3.

        Returns:
            One dict per result row, keyed by column name.

        Raises:
            Exception: If the query ends in the ``FAILED`` or ``CANCELLED``
                state.
        """
        # Imported lazily so the module can be imported without boto3.
        import logging

        import boto3

        logger = logging.getLogger(__name__)

        session = (
            boto3.Session(profile_name=self.profile_name)
            if self.profile_name is not None
            else boto3.Session()
        )
        client = session.client("athena")

        response = client.start_query_execution(
            QueryString=self.query,
            QueryExecutionContext={"Database": self.database},
            ResultConfiguration={"OutputLocation": self.s3_output_uri},
        )
        query_execution_id = response["QueryExecutionId"]
        logger.debug("Query : %s", self.query)
        while True:
            response = client.get_query_execution(QueryExecutionId=query_execution_id)
            status = response["QueryExecution"]["Status"]
            state = status["State"]
            logger.debug("State : %s", state)
            if state == "SUCCEEDED":
                break
            if state == "FAILED":
                # The reason is read defensively so a response without it
                # does not mask the real failure behind a KeyError.
                state_change_reason = status.get("StateChangeReason", "unknown reason")
                raise Exception(f"Query Failed: {state_change_reason}")
            if state == "CANCELLED":
                raise Exception("Query was cancelled by the user.")
            # Any other state: wait before polling again.
            time.sleep(1)

        result_set = self.get_result_set(session, query_execution_id)
        return json.loads(result_set.to_json(orient="records"))

    def get_result_set(self, session: Any, query_execution_id: str) -> Any:
        """Download the CSV result of a finished query from S3.

        Args:
            session: Session object whose ``client("s3")`` is used for the
                download (a ``boto3.Session`` when called by this class).
            query_execution_id: Id of the Athena query execution; the result
                object is looked up as ``<output prefix>/<id>.csv``.

        Returns:
            A pandas ``DataFrame`` parsed from the result CSV.
        """
        import pandas as pd

        s3c = session.client("s3")

        # Strip the scheme and one trailing slash by hand:
        # str.removeprefix/removesuffix only exist on Python 3.9+.
        location = self.s3_output_uri
        if location.startswith("s3://"):
            location = location[len("s3://") :]
        if location.endswith("/"):
            location = location[:-1]
        tokens = location.split("/")
        bucket = tokens[0]
        # Joining the id together with the prefix parts avoids a leading "/"
        # in the key when the output URI is just the bucket.
        key = "/".join(tokens[1:] + [query_execution_id]) + ".csv"

        obj = s3c.get_object(Bucket=bucket, Key=key)
        return pd.read_csv(io.BytesIO(obj["Body"].read()), encoding="utf8")

    def _get_columns(
        self, query_result: List[Dict[str, Any]]
    ) -> Tuple[List[str], List[str]]:
        """Split the result columns into content and metadata columns.

        Args:
            query_result: Rows returned by the query; column names are taken
                from the first row.

        Returns:
            ``(content_columns, metadata_columns)``. Both lists are empty when
            ``query_result`` has no rows.
        """
        content_columns: List[str] = []
        metadata_columns: List[str] = []
        if not query_result:
            # A query that returns zero rows has no columns to classify.
            return content_columns, metadata_columns

        for key in query_result[0]:
            if key in self.metadata_columns:
                metadata_columns.append(key)
            else:
                content_columns.append(key)

        return content_columns, metadata_columns

    def lazy_load(self) -> Iterator[Document]:
        """Run the query and yield one Document per result row.

        Content columns are rendered as ``"<column>: <value>"`` lines in
        `page_content`; metadata columns are copied into `metadata`.
        """
        query_result = self._execute_query()
        content_columns, metadata_columns = self._get_columns(query_result)
        # Sets give constant-time membership checks inside the row loop.
        content_keys = set(content_columns)
        metadata_keys = set(metadata_columns)
        for row in query_result:
            page_content = "\n".join(
                f"{k}: {v}" for k, v in row.items() if k in content_keys
            )
            metadata = {k: v for k, v in row.items() if k in metadata_keys}
            yield Document(page_content=page_content, metadata=metadata)

    def load(self) -> List[Document]:
        """Load data into document objects."""
        return list(self.lazy_load())
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
"AssemblyAIAudioTranscriptLoader",
"AstraDBLoader",
"AsyncHtmlLoader",
"AthenaLoader",
"AzureAIDataLoader",
"AzureAIDocumentIntelligenceLoader",
"AzureBlobStorageContainerLoader",
Expand Down
Loading