community[minor]: AWS Athena Document Loader (#15625)
- **Description:** Adds the document loader for [AWS Athena](https://aws.amazon.com/athena/), a serverless and interactive analytics service.
- **Dependencies:** Added boto3 as a dependency.
1 parent 93da18b · commit 584b647
Showing 4 changed files with 280 additions and 0 deletions.
@@ -0,0 +1,110 @@
{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MwTWzDxYgbrR"
      },
      "source": [
        "# Athena\n",
        "\n",
        "This notebook goes over how to load documents from AWS Athena."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "F0zaLR3xgWmO"
      },
      "outputs": [],
      "source": [
        "! pip install boto3"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "076NLjfngoWJ"
      },
      "outputs": [],
      "source": [
        "from langchain_community.document_loaders.athena import AthenaLoader"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XpMRQwU9gu44"
      },
      "outputs": [],
      "source": [
        "database_name = \"my_database\"\n",
        "s3_output_path = \"s3://my_bucket/query_results/\"\n",
        "query = \"SELECT * FROM my_table\"\n",
        "profile_name = \"my_profile\"\n",
        "\n",
        "loader = AthenaLoader(\n",
        "    query=query,\n",
        "    database=database_name,\n",
        "    s3_output_uri=s3_output_path,\n",
        "    profile_name=profile_name,\n",
        ")\n",
        "\n",
        "documents = loader.load()\n",
        "print(documents)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5IBapL3ejoEt"
      },
      "source": [
        "Example with metadata columns:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wMx6nI1qjryD"
      },
      "outputs": [],
      "source": [
        "database_name = \"my_database\"\n",
        "s3_output_path = \"s3://my_bucket/query_results/\"\n",
        "query = \"SELECT * FROM my_table\"\n",
        "profile_name = \"my_profile\"\n",
        "metadata_columns = [\"_row\", \"_created_at\"]\n",
        "\n",
        "loader = AthenaLoader(\n",
        "    query=query,\n",
        "    database=database_name,\n",
        "    s3_output_uri=s3_output_path,\n",
        "    profile_name=profile_name,\n",
        "    metadata_columns=metadata_columns,\n",
        ")\n",
        "\n",
        "documents = loader.load()\n",
        "print(documents)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
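The loader added in this commit also implements `lazy_load()` (see the implementation below), so result rows can be streamed one `Document` at a time instead of materializing the full list with `load()`. A minimal sketch, reusing the placeholder values from the notebook above (database, bucket, and profile names are placeholders, not real resources):

```python
from langchain_community.document_loaders.athena import AthenaLoader

# Placeholder values; substitute your own database, bucket, and profile.
loader = AthenaLoader(
    query="SELECT * FROM my_table",
    database="my_database",
    s3_output_uri="s3://my_bucket/query_results/",
    profile_name="my_profile",
)

# Stream documents row by row rather than building the whole list up front.
for doc in loader.lazy_load():
    print(doc.page_content)
    print(doc.metadata)
```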
libs/community/langchain_community/document_loaders/athena.py (167 additions, 0 deletions)
@@ -0,0 +1,167 @@
from __future__ import annotations

import io
import json
import time
from typing import Any, Dict, Iterator, List, Optional, Tuple

from langchain_core.documents import Document

from langchain_community.document_loaders.base import BaseLoader


class AthenaLoader(BaseLoader):
    """Load documents from `AWS Athena`.

    Each document represents one row of the result.
    - By default, all columns are written into the `page_content` of the document
      and none into the `metadata` of the document.
    - If `metadata_columns` are provided then these columns are written
      into the `metadata` of the document while the rest of the columns
      are written into the `page_content` of the document.

    To authenticate, the AWS client uses this method to automatically load credentials:
    https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html

    If a specific credential profile should be used, you must pass
    the name of the profile from the ~/.aws/credentials file that is to be used.

    Make sure the credentials / roles used have the required policies to
    access the Amazon Athena and Amazon S3 services.
    """

    def __init__(
        self,
        query: str,
        database: str,
        s3_output_uri: str,
        profile_name: str,
        metadata_columns: Optional[List[str]] = None,
    ):
        """Initialize Athena document loader.

        Args:
            query: The query to run in Athena.
            database: Athena database.
            s3_output_uri: Athena output path.
            profile_name: Name of the AWS credentials profile to use.
            metadata_columns: Optional. Columns written to Document `metadata`.
        """
        self.query = query
        self.database = database
        self.s3_output_uri = s3_output_uri
        self.metadata_columns = metadata_columns if metadata_columns is not None else []

        try:
            import boto3
        except ImportError:
            raise ModuleNotFoundError(
                "Could not import boto3 python package. "
                "Please install it with `pip install boto3`."
            )

        try:
            session = (
                boto3.Session(profile_name=profile_name)
                if profile_name is not None
                else boto3.Session()
            )
        except Exception as e:
            raise ValueError(
                "Could not load credentials to authenticate with AWS client. "
                "Please check that credentials in the specified "
                "profile name are valid."
            ) from e

        self.athena_client = session.client("athena")
        self.s3_client = session.client("s3")

    def _execute_query(self) -> List[Dict[str, Any]]:
        response = self.athena_client.start_query_execution(
            QueryString=self.query,
            QueryExecutionContext={"Database": self.database},
            ResultConfiguration={"OutputLocation": self.s3_output_uri},
        )
        query_execution_id = response["QueryExecutionId"]
        print(f"Query : {self.query}")

        # Poll Athena until the query reaches a terminal state.
        while True:
            response = self.athena_client.get_query_execution(
                QueryExecutionId=query_execution_id
            )
            state = response["QueryExecution"]["Status"]["State"]
            if state == "SUCCEEDED":
                print(f"State : {state}")
                break
            elif state == "FAILED":
                resp_status = response["QueryExecution"]["Status"]
                state_change_reason = resp_status["StateChangeReason"]
                err = f"Query Failed: {state_change_reason}"
                raise Exception(err)
            elif state == "CANCELLED":
                raise Exception("Query was cancelled by the user.")
            else:
                print(f"State : {state}")
            time.sleep(1)

        result_set = self._get_result_set(query_execution_id)
        return json.loads(result_set.to_json(orient="records"))

    def _remove_suffix(self, input_string: str, suffix: str) -> str:
        if suffix and input_string.endswith(suffix):
            return input_string[: -len(suffix)]
        return input_string

    def _remove_prefix(self, input_string: str, prefix: str) -> str:
        if prefix and input_string.startswith(prefix):
            return input_string[len(prefix) :]
        return input_string

    def _get_result_set(self, query_execution_id: str) -> Any:
        try:
            import pandas as pd
        except ImportError:
            raise ModuleNotFoundError(
                "Could not import pandas python package. "
                "Please install it with `pip install pandas`."
            )

        # Athena writes the query result as <query_execution_id>.csv under
        # the configured S3 output location.
        output_uri = self.s3_output_uri
        tokens = self._remove_prefix(
            self._remove_suffix(output_uri, "/"), "s3://"
        ).split("/")
        bucket = tokens[0]
        key = "/".join(tokens[1:]) + "/" + query_execution_id + ".csv"

        obj = self.s3_client.get_object(Bucket=bucket, Key=key)
        df = pd.read_csv(io.BytesIO(obj["Body"].read()), encoding="utf8")
        return df

    def _get_columns(
        self, query_result: List[Dict[str, Any]]
    ) -> Tuple[List[str], List[str]]:
        content_columns = []
        metadata_columns = []
        all_columns = list(query_result[0].keys())
        for key in all_columns:
            if key in self.metadata_columns:
                metadata_columns.append(key)
            else:
                content_columns.append(key)

        return content_columns, metadata_columns

    def lazy_load(self) -> Iterator[Document]:
        query_result = self._execute_query()
        content_columns, metadata_columns = self._get_columns(query_result)
        for row in query_result:
            page_content = "\n".join(
                f"{k}: {v}" for k, v in row.items() if k in content_columns
            )
            metadata = {
                k: v for k, v in row.items() if k in metadata_columns and v is not None
            }
            doc = Document(page_content=page_content, metadata=metadata)
            yield doc

    def load(self) -> List[Document]:
        """Load data into document objects."""
        return list(self.lazy_load())
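To make the row-to-document mapping concrete, the sketch below mirrors the logic in `_get_columns()` and `lazy_load()` without touching AWS: columns listed in `metadata_columns` go into `Document.metadata`, and all remaining columns are serialized as `key: value` lines in `page_content`. The sample row is a made-up example, not real query output:

```python
from langchain_core.documents import Document

# Hypothetical row, shaped like one record returned by _execute_query().
row = {"id": 1, "name": "alice", "_row": "7", "_created_at": "2024-01-01"}
metadata_columns = ["_row", "_created_at"]

# Same split that the loader performs per row.
page_content = "\n".join(
    f"{k}: {v}" for k, v in row.items() if k not in metadata_columns
)
metadata = {k: v for k, v in row.items() if k in metadata_columns and v is not None}

doc = Document(page_content=page_content, metadata=metadata)
print(doc.page_content)  # id: 1\nname: alice
print(doc.metadata)      # {'_row': '7', '_created_at': '2024-01-01'}
```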