From 8d12c57ee0c2f8e8f13bb4f313dbe399ea1ccc08 Mon Sep 17 00:00:00 2001 From: Abhijeeth Padarthi Date: Sat, 6 Jan 2024 07:52:22 +0530 Subject: [PATCH 01/12] Add Athena document loader --- .../document_loaders/athena.ipynb | 110 ++++++++++++++++ .../document_loaders/__init__.py | 2 + .../document_loaders/athena.py | 123 ++++++++++++++++++ .../document_loaders/test_imports.py | 1 + .../langchain/document_loaders/__init__.py | 2 + .../langchain/document_loaders/athena.py | 3 + .../document_loaders/test_imports.py | 1 + 7 files changed, 242 insertions(+) create mode 100644 docs/docs/integrations/document_loaders/athena.ipynb create mode 100644 libs/community/langchain_community/document_loaders/athena.py create mode 100644 libs/langchain/langchain/document_loaders/athena.py diff --git a/docs/docs/integrations/document_loaders/athena.ipynb b/docs/docs/integrations/document_loaders/athena.ipynb new file mode 100644 index 0000000000000..2636fa201944e --- /dev/null +++ b/docs/docs/integrations/document_loaders/athena.ipynb @@ -0,0 +1,110 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "MwTWzDxYgbrR" + }, + "source": [ + "# Athena\n", + "\n", + "This notebooks goes over how to load documents from AWS Athena" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "F0zaLR3xgWmO" + }, + "outputs": [], + "source": [ + "! pip install boto3" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "076NLjfngoWJ" + }, + "outputs": [], + "source": [ + "from langchain.document_loaders import AthenaLoader" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "XpMRQwU9gu44" + }, + "outputs": [], + "source": [ + "database_name = \"my_database\"\n", + "s3_output_path = \"s3://my_bucket/query_results/\"\n", + "query = f\"SELECT * FROM my_table\"\n", + "profile_name = \"my_profile\"\n", + "\n", + "loader = AthenaLoader(\n", + " query=query,\n", + " database=database_name,\n", + " s3_output_uri=s3_output_path,\n", + " profile_name=profile_name\n", + ")\n", + "\n", + "documents = loader.load()\n", + "print(documents)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "5IBapL3ejoEt" + }, + "source": [ + "Example with metadata columns" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "wMx6nI1qjryD" + }, + "outputs": [], + "source": [ + "database_name = \"my_database\"\n", + "s3_output_path = \"s3://my_bucket/query_results/\"\n", + "query = f\"SELECT * FROM my_table\"\n", + "profile_name = \"my_profile\"\n", + "metadata_columns = [\"_row\", \"_created_at\"]\n", + "\n", + "loader = AthenaLoader(\n", + " query=query,\n", + " database=database_name,\n", + " s3_output_uri=s3_output_path,\n", + " profile_name=profile_name,\n", + " metadata_columns=metadata_columns\n", + ")\n", + "\n", + "documents = loader.load()\n", + "print(documents)" + ] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/libs/community/langchain_community/document_loaders/__init__.py b/libs/community/langchain_community/document_loaders/__init__.py index ca295e538ebe4..ea6ea46703b3d 100644 --- a/libs/community/langchain_community/document_loaders/__init__.py +++ b/libs/community/langchain_community/document_loaders/__init__.py @@ -35,6 +35,7 @@ AssemblyAIAudioTranscriptLoader, ) from langchain_community.document_loaders.async_html import AsyncHtmlLoader +from langchain_community.document_loaders.athena import AthenaLoader from langchain_community.document_loaders.azlyrics import AZLyricsLoader from langchain_community.document_loaders.azure_ai_data import ( AzureAIDataLoader, @@ -249,6 +250,7 @@ "ArxivLoader", "AssemblyAIAudioTranscriptLoader", "AsyncHtmlLoader", + "AthenaLoader", "AzureAIDataLoader", "AzureAIDocumentIntelligenceLoader", "AzureBlobStorageContainerLoader", diff --git a/libs/community/langchain_community/document_loaders/athena.py b/libs/community/langchain_community/document_loaders/athena.py new file mode 100644 index 0000000000000..1b2307be68f42 --- /dev/null +++ b/libs/community/langchain_community/document_loaders/athena.py @@ -0,0 +1,123 @@ +from __future__ import annotations + +import time +from typing import Any, Dict, Iterator, List, Optional, Tuple + +from langchain_core.documents import Document + +from langchain_community.document_loaders.base import BaseLoader + + +class AthenaLoader(BaseLoader): + """Load documents from `AWS Athena`. + + Each document represents one row of the result. + By default, all columns are written into the `page_content` and none into the `metadata`. + If `metadata_columns` are provided then these columns are written into the `metadata` of the + document while the rest of the columns are written into the `page_content` of the document. + + To authenticate, the AWS client uses the following methods to automatically load credentials: + https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html + + If a specific credential profile should be used, you must pass + the name of the profile from the ~/.aws/credentials file that is to be used. + + Make sure the credentials / roles used have the required policies to + access the Amazon Textract service. + """ + + def __init__( + self, + query: str, + database: str, + s3_output_uri: str, + profile_name: str, + metadata_columns: Optional[List[str]] = None, + ): + """Initialize Athena document loader. + + Args: + query: The query to run in Athena. + database: Athena database + s3_output_uri: Athena output path + metadata_columns: Optional. Columns written to Document `metadata`. + """ + self.query = query + self.database = database + self.s3_output_uri = s3_output_uri + self.profile_name = profile_name + self.metadata_columns = metadata_columns if metadata_columns is not None else [] + + def _execute_query(self) -> List[Dict[str, Any]]: + import boto3 + + session = ( + boto3.Session(profile_name=self.profile_name) + if self.profile_name is not None + else boto3.Session() + ) + client = session.client("athena") + + response = client.start_query_execution( + QueryString=self.query, + QueryExecutionContext={"Database": self.database}, + ResultConfiguration={"OutputLocation": self.s3_output_uri}, + ) + query_execution_id = response["QueryExecutionId"] + + while True: + response = client.get_query_execution(QueryExecutionId=query_execution_id) + state = response["QueryExecution"]["Status"]["State"] + if state == "SUCCEEDED": + break + elif state == "FAILED": + raise Exception( + f"Query Failed: {response['QueryExecution']['Status']['StateChangeReason']}" + ) + elif state == "CANCELLED": + raise Exception("Query was cancelled by the user.") + else: + print(state) + time.sleep(1) + + results = [] + result_set = client.get_query_results(QueryExecutionId=query_execution_id)[ + "ResultSet" + ]["Rows"] + columns = [x["VarCharValue"] for x in result_set[0]["Data"]] + for i in range(1, len(result_set)): + row = result_set[i]["Data"] + row_dict = {} + for col_num in range(len(row)): + row_dict[columns[col_num]] = row[col_num]["VarCharValue"] + results.append(row_dict) + return results + + def _get_columns( + self, query_result: List[Dict[str, Any]] + ) -> Tuple[List[str], List[str]]: + content_columns = [] + metadata_columns = [] + all_columns = list(query_result[0].keys()) + for key in all_columns: + if key in self.metadata_columns: + metadata_columns.append(key) + else: + content_columns.append(key) + + return content_columns, metadata_columns + + def lazy_load(self) -> Iterator[Document]: + query_result = self._execute_query() + content_columns, metadata_columns = self._get_columns(query_result) + for row in query_result: + page_content = "\n".join( + f"{k}: {v}" for k, v in row.items() if k in content_columns + ) + metadata = {k: v for k, v in row.items() if k in metadata_columns} + doc = Document(page_content=page_content, metadata=metadata) + yield doc + + def load(self) -> List[Document]: + """Load data into document objects.""" + return list(self.lazy_load()) diff --git a/libs/community/tests/unit_tests/document_loaders/test_imports.py b/libs/community/tests/unit_tests/document_loaders/test_imports.py index a2101c8830d39..b9d7def3871fe 100644 --- a/libs/community/tests/unit_tests/document_loaders/test_imports.py +++ b/libs/community/tests/unit_tests/document_loaders/test_imports.py @@ -22,6 +22,7 @@ "ArxivLoader", "AssemblyAIAudioTranscriptLoader", "AsyncHtmlLoader", + "AthenaLoader", "AzureAIDataLoader", "AzureAIDocumentIntelligenceLoader", "AzureBlobStorageContainerLoader", diff --git a/libs/langchain/langchain/document_loaders/__init__.py b/libs/langchain/langchain/document_loaders/__init__.py index ba3867ffcb43c..83883eda27079 100644 --- a/libs/langchain/langchain/document_loaders/__init__.py +++ b/libs/langchain/langchain/document_loaders/__init__.py @@ -33,6 +33,7 @@ from langchain.document_loaders.arxiv import ArxivLoader from langchain.document_loaders.assemblyai import AssemblyAIAudioTranscriptLoader from langchain.document_loaders.async_html import AsyncHtmlLoader +from langchain.document_loaders.athena import AthenaLoader from langchain.document_loaders.azlyrics import AZLyricsLoader from langchain.document_loaders.azure_ai_data import ( AzureAIDataLoader, @@ -229,6 +230,7 @@ "ArxivLoader", "AssemblyAIAudioTranscriptLoader", "AsyncHtmlLoader", + "AthenaLoader", "AzureAIDataLoader", "AzureBlobStorageContainerLoader", "AzureBlobStorageFileLoader", diff --git a/libs/langchain/langchain/document_loaders/athena.py b/libs/langchain/langchain/document_loaders/athena.py new file mode 100644 index 0000000000000..361e0a52165d8 --- /dev/null +++ b/libs/langchain/langchain/document_loaders/athena.py @@ -0,0 +1,3 @@ +from langchain_community.document_loaders.athena import AthenaLoader + +__all__ = ["AthenaLoader"] \ No newline at end of file diff --git a/libs/langchain/tests/unit_tests/document_loaders/test_imports.py b/libs/langchain/tests/unit_tests/document_loaders/test_imports.py index 18f6f22a5f00d..5bb4d05cbe5c1 100644 --- a/libs/langchain/tests/unit_tests/document_loaders/test_imports.py +++ b/libs/langchain/tests/unit_tests/document_loaders/test_imports.py @@ -22,6 +22,7 @@ "ArxivLoader", "AssemblyAIAudioTranscriptLoader", "AsyncHtmlLoader", + "AthenaLoader", "AzureAIDataLoader", "AzureBlobStorageContainerLoader", "AzureBlobStorageFileLoader", From cbf7de34ef4d44d01f5eb2c0bcad322d5da42145 Mon Sep 17 00:00:00 2001 From: Abhijeeth Padarthi Date: Sat, 6 Jan 2024 07:58:00 +0530 Subject: [PATCH 02/12] fixed lint and formatting issues --- .../document_loaders/athena.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/libs/community/langchain_community/document_loaders/athena.py b/libs/community/langchain_community/document_loaders/athena.py index 1b2307be68f42..7940263b5be12 100644 --- a/libs/community/langchain_community/document_loaders/athena.py +++ b/libs/community/langchain_community/document_loaders/athena.py @@ -12,11 +12,13 @@ class AthenaLoader(BaseLoader): """Load documents from `AWS Athena`. Each document represents one row of the result. - By default, all columns are written into the `page_content` and none into the `metadata`. - If `metadata_columns` are provided then these columns are written into the `metadata` of the - document while the rest of the columns are written into the `page_content` of the document. + - By default, all columns are written into the `page_content` of the document + and none into the `metadata` of the document. + - If `metadata_columns` are provided then these columns are written + into the `metadata` of the document while the rest of the columns + are written into the `page_content` of the document. - To authenticate, the AWS client uses the following methods to automatically load credentials: + To authenticate, the AWS client uses this method to automatically load credentials: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html If a specific credential profile should be used, you must pass @@ -71,9 +73,10 @@ def _execute_query(self) -> List[Dict[str, Any]]: if state == "SUCCEEDED": break elif state == "FAILED": - raise Exception( - f"Query Failed: {response['QueryExecution']['Status']['StateChangeReason']}" - ) + resp_status = response["QueryExecution"]["Status"] + state_change_reason = resp_status["StateChangeReason"] + err = f"Query Failed: {state_change_reason}" + raise Exception(err) elif state == "CANCELLED": raise Exception("Query was cancelled by the user.") else: From 8afe195d9e675138e2876bcf74bfe34ed47eb8d4 Mon Sep 17 00:00:00 2001 From: Abhijeeth Padarthi Date: Sun, 7 Jan 2024 22:27:04 +0530 Subject: [PATCH 03/12] remove imports --- .../langchain/document_loaders/__init__.py | 186 ------------------ 1 file changed, 186 deletions(-) diff --git a/libs/langchain/langchain/document_loaders/__init__.py b/libs/langchain/langchain/document_loaders/__init__.py index e00108281f44f..de4b7a87af7e4 100644 --- a/libs/langchain/langchain/document_loaders/__init__.py +++ b/libs/langchain/langchain/document_loaders/__init__.py @@ -17,192 +17,6 @@ import warnings from typing import Any -from langchain.document_loaders.acreom import AcreomLoader -from langchain.document_loaders.airbyte import ( - AirbyteCDKLoader, - AirbyteGongLoader, - AirbyteHubspotLoader, - AirbyteSalesforceLoader, - AirbyteShopifyLoader, - AirbyteStripeLoader, - AirbyteTypeformLoader, - AirbyteZendeskSupportLoader, -) -from langchain.document_loaders.airbyte_json import AirbyteJSONLoader -from langchain.document_loaders.airtable import AirtableLoader -from langchain.document_loaders.apify_dataset import ApifyDatasetLoader -from langchain.document_loaders.arcgis_loader import ArcGISLoader -from langchain.document_loaders.arxiv import ArxivLoader -from langchain.document_loaders.assemblyai import AssemblyAIAudioTranscriptLoader -from langchain.document_loaders.async_html import AsyncHtmlLoader -from langchain.document_loaders.athena import AthenaLoader -from langchain.document_loaders.azlyrics import AZLyricsLoader -from langchain.document_loaders.azure_ai_data import ( - AzureAIDataLoader, -) -from langchain.document_loaders.azure_blob_storage_container import ( - AzureBlobStorageContainerLoader, -) -from langchain.document_loaders.azure_blob_storage_file import ( - AzureBlobStorageFileLoader, -) -from langchain.document_loaders.bibtex import BibtexLoader -from langchain.document_loaders.bigquery import BigQueryLoader -from langchain.document_loaders.bilibili import BiliBiliLoader -from langchain.document_loaders.blackboard import BlackboardLoader -from langchain.document_loaders.blob_loaders import ( - Blob, - BlobLoader, - FileSystemBlobLoader, - YoutubeAudioLoader, -) -from langchain.document_loaders.blockchain import BlockchainDocumentLoader -from langchain.document_loaders.brave_search import BraveSearchLoader -from langchain.document_loaders.browserless import BrowserlessLoader -from langchain.document_loaders.chatgpt import ChatGPTLoader -from langchain.document_loaders.chromium import AsyncChromiumLoader -from langchain.document_loaders.college_confidential import CollegeConfidentialLoader -from langchain.document_loaders.concurrent import ConcurrentLoader -from langchain.document_loaders.confluence import ConfluenceLoader -from langchain.document_loaders.conllu import CoNLLULoader -from langchain.document_loaders.couchbase import CouchbaseLoader -from langchain.document_loaders.csv_loader import CSVLoader, UnstructuredCSVLoader -from langchain.document_loaders.cube_semantic import CubeSemanticLoader -from langchain.document_loaders.datadog_logs import DatadogLogsLoader -from langchain.document_loaders.dataframe import DataFrameLoader -from langchain.document_loaders.diffbot import DiffbotLoader -from langchain.document_loaders.directory import DirectoryLoader -from langchain.document_loaders.discord import DiscordChatLoader -from langchain.document_loaders.docugami import DocugamiLoader -from langchain.document_loaders.docusaurus import DocusaurusLoader -from langchain.document_loaders.dropbox import DropboxLoader -from langchain.document_loaders.duckdb_loader import DuckDBLoader -from langchain.document_loaders.email import ( - OutlookMessageLoader, - UnstructuredEmailLoader, -) -from langchain.document_loaders.epub import UnstructuredEPubLoader -from langchain.document_loaders.etherscan import EtherscanLoader -from langchain.document_loaders.evernote import EverNoteLoader -from langchain.document_loaders.excel import UnstructuredExcelLoader -from langchain.document_loaders.facebook_chat import FacebookChatLoader -from langchain.document_loaders.fauna import FaunaLoader -from langchain.document_loaders.figma import FigmaFileLoader -from langchain.document_loaders.gcs_directory import GCSDirectoryLoader -from langchain.document_loaders.gcs_file import GCSFileLoader -from langchain.document_loaders.geodataframe import GeoDataFrameLoader -from langchain.document_loaders.git import GitLoader -from langchain.document_loaders.gitbook import GitbookLoader -from langchain.document_loaders.github import GitHubIssuesLoader -from langchain.document_loaders.google_speech_to_text import GoogleSpeechToTextLoader -from langchain.document_loaders.googledrive import GoogleDriveLoader -from langchain.document_loaders.gutenberg import GutenbergLoader -from langchain.document_loaders.hn import HNLoader -from langchain.document_loaders.html import UnstructuredHTMLLoader -from langchain.document_loaders.html_bs import BSHTMLLoader -from langchain.document_loaders.hugging_face_dataset import HuggingFaceDatasetLoader -from langchain.document_loaders.ifixit import IFixitLoader -from langchain.document_loaders.image import UnstructuredImageLoader -from langchain.document_loaders.image_captions import ImageCaptionLoader -from langchain.document_loaders.imsdb import IMSDbLoader -from langchain.document_loaders.iugu import IuguLoader -from langchain.document_loaders.joplin import JoplinLoader -from langchain.document_loaders.json_loader import JSONLoader -from langchain.document_loaders.lakefs import LakeFSLoader -from langchain.document_loaders.larksuite import LarkSuiteDocLoader -from langchain.document_loaders.markdown import UnstructuredMarkdownLoader -from langchain.document_loaders.mastodon import MastodonTootsLoader -from langchain.document_loaders.max_compute import MaxComputeLoader -from langchain.document_loaders.mediawikidump import MWDumpLoader -from langchain.document_loaders.merge import MergedDataLoader -from langchain.document_loaders.mhtml import MHTMLLoader -from langchain.document_loaders.modern_treasury import ModernTreasuryLoader -from langchain.document_loaders.mongodb import MongodbLoader -from langchain.document_loaders.news import NewsURLLoader -from langchain.document_loaders.notebook import NotebookLoader -from langchain.document_loaders.notion import NotionDirectoryLoader -from langchain.document_loaders.notiondb import NotionDBLoader -from langchain.document_loaders.obs_directory import OBSDirectoryLoader -from langchain.document_loaders.obs_file import OBSFileLoader -from langchain.document_loaders.obsidian import ObsidianLoader -from langchain.document_loaders.odt import UnstructuredODTLoader -from langchain.document_loaders.onedrive import OneDriveLoader -from langchain.document_loaders.onedrive_file import OneDriveFileLoader -from langchain.document_loaders.open_city_data import OpenCityDataLoader -from langchain.document_loaders.org_mode import UnstructuredOrgModeLoader -from langchain.document_loaders.pdf import ( - AmazonTextractPDFLoader, - MathpixPDFLoader, - OnlinePDFLoader, - PDFMinerLoader, - PDFMinerPDFasHTMLLoader, - PDFPlumberLoader, - PyMuPDFLoader, - PyPDFDirectoryLoader, - PyPDFium2Loader, - PyPDFLoader, - UnstructuredPDFLoader, -) -from langchain.document_loaders.polars_dataframe import PolarsDataFrameLoader -from langchain.document_loaders.powerpoint import UnstructuredPowerPointLoader -from langchain.document_loaders.psychic import PsychicLoader -from langchain.document_loaders.pubmed import PubMedLoader -from langchain.document_loaders.pyspark_dataframe import PySparkDataFrameLoader -from langchain.document_loaders.python import PythonLoader -from langchain.document_loaders.readthedocs import ReadTheDocsLoader -from langchain.document_loaders.recursive_url_loader import RecursiveUrlLoader -from langchain.document_loaders.reddit import RedditPostsLoader -from langchain.document_loaders.roam import RoamLoader -from langchain.document_loaders.rocksetdb import RocksetLoader -from langchain.document_loaders.rss import RSSFeedLoader -from langchain.document_loaders.rst import UnstructuredRSTLoader -from langchain.document_loaders.rtf import UnstructuredRTFLoader -from langchain.document_loaders.s3_directory import S3DirectoryLoader -from langchain.document_loaders.s3_file import S3FileLoader -from langchain.document_loaders.sharepoint import SharePointLoader -from langchain.document_loaders.sitemap import SitemapLoader -from langchain.document_loaders.slack_directory import SlackDirectoryLoader -from langchain.document_loaders.snowflake_loader import SnowflakeLoader -from langchain.document_loaders.spreedly import SpreedlyLoader -from langchain.document_loaders.srt import SRTLoader -from langchain.document_loaders.stripe import StripeLoader -from langchain.document_loaders.telegram import ( - TelegramChatApiLoader, - TelegramChatFileLoader, -) -from langchain.document_loaders.tencent_cos_directory import TencentCOSDirectoryLoader -from langchain.document_loaders.tencent_cos_file import TencentCOSFileLoader -from langchain.document_loaders.tensorflow_datasets import TensorflowDatasetLoader -from langchain.document_loaders.text import TextLoader -from langchain.document_loaders.tomarkdown import ToMarkdownLoader -from langchain.document_loaders.toml import TomlLoader -from langchain.document_loaders.trello import TrelloLoader -from langchain.document_loaders.tsv import UnstructuredTSVLoader -from langchain.document_loaders.twitter import TwitterTweetLoader -from langchain.document_loaders.unstructured import ( - UnstructuredAPIFileIOLoader, - UnstructuredAPIFileLoader, - UnstructuredFileIOLoader, - UnstructuredFileLoader, -) -from langchain.document_loaders.url import UnstructuredURLLoader -from langchain.document_loaders.url_playwright import PlaywrightURLLoader -from langchain.document_loaders.url_selenium import SeleniumURLLoader -from langchain.document_loaders.weather import WeatherDataLoader -from langchain.document_loaders.web_base import WebBaseLoader -from langchain.document_loaders.whatsapp_chat import WhatsAppChatLoader -from langchain.document_loaders.wikipedia import WikipediaLoader -from langchain.document_loaders.word_document import ( - Docx2txtLoader, - UnstructuredWordDocumentLoader, -) -from langchain.document_loaders.xml import UnstructuredXMLLoader -from langchain.document_loaders.xorbits import XorbitsLoader -from langchain.document_loaders.youtube import ( - GoogleApiClient, - GoogleApiYoutubeLoader, - YoutubeLoader, -) from langchain_core._api import LangChainDeprecationWarning from langchain.utils.interactive_env import is_interactive_env From 0116383e2774ac5ecceac37c39c4dfb3d9604882 Mon Sep 17 00:00:00 2001 From: Abhijeeth Padarthi Date: Thu, 11 Jan 2024 15:23:45 +0530 Subject: [PATCH 04/12] fix athena pagination issue Fixes pagination issue by downloading complete results to pandas dataframe --- .../document_loaders/athena.py | 35 ++++++++++--------- .../langchain/document_loaders/__init__.py | 1 - .../langchain/document_loaders/athena.py | 3 -- .../document_loaders/test_imports.py | 1 - 4 files changed, 19 insertions(+), 21 deletions(-) delete mode 100644 libs/langchain/langchain/document_loaders/athena.py diff --git a/libs/community/langchain_community/document_loaders/athena.py b/libs/community/langchain_community/document_loaders/athena.py index 7940263b5be12..db6a04163610a 100644 --- a/libs/community/langchain_community/document_loaders/athena.py +++ b/libs/community/langchain_community/document_loaders/athena.py @@ -1,10 +1,11 @@ from __future__ import annotations import time +import io +import pandas as pd +import json from typing import Any, Dict, Iterator, List, Optional, Tuple - from langchain_core.documents import Document - from langchain_community.document_loaders.base import BaseLoader @@ -66,11 +67,12 @@ def _execute_query(self) -> List[Dict[str, Any]]: ResultConfiguration={"OutputLocation": self.s3_output_uri}, ) query_execution_id = response["QueryExecutionId"] - + print(f"Query : {self.query}") while True: response = client.get_query_execution(QueryExecutionId=query_execution_id) state = response["QueryExecution"]["Status"]["State"] if state == "SUCCEEDED": + print(f"State : {state}") break elif state == "FAILED": resp_status = response["QueryExecution"]["Status"] @@ -80,21 +82,22 @@ def _execute_query(self) -> List[Dict[str, Any]]: elif state == "CANCELLED": raise Exception("Query was cancelled by the user.") else: - print(state) + print(f"State : {state}") time.sleep(1) - results = [] - result_set = client.get_query_results(QueryExecutionId=query_execution_id)[ - "ResultSet" - ]["Rows"] - columns = [x["VarCharValue"] for x in result_set[0]["Data"]] - for i in range(1, len(result_set)): - row = result_set[i]["Data"] - row_dict = {} - for col_num in range(len(row)): - row_dict[columns[col_num]] = row[col_num]["VarCharValue"] - results.append(row_dict) - return results + result_set = self.get_result_set(session, query_execution_id) + return json.loads(result_set.to_json(orient='records')) + + def get_result_set(self, session, query_execution_id): + s3c = session.client('s3') + + tokens = self.s3_output_uri.removeprefix("s3://").removesuffix("/").split("/") + bucket = tokens[0] + key = "/".join(tokens[1:]) +"/"+ query_execution_id + '.csv' + + obj = s3c.get_object(Bucket=bucket, Key=key) + df = pd.read_csv(io.BytesIO(obj['Body'].read()), encoding='utf8') + return df def _get_columns( self, query_result: List[Dict[str, Any]] diff --git a/libs/langchain/langchain/document_loaders/__init__.py b/libs/langchain/langchain/document_loaders/__init__.py index de4b7a87af7e4..5a5aec095e277 100644 --- a/libs/langchain/langchain/document_loaders/__init__.py +++ b/libs/langchain/langchain/document_loaders/__init__.py @@ -73,7 +73,6 @@ def __getattr__(name: str) -> Any: "ArxivLoader", "AssemblyAIAudioTranscriptLoader", "AsyncHtmlLoader", - "AthenaLoader", "AzureAIDataLoader", "AzureBlobStorageContainerLoader", "AzureBlobStorageFileLoader", diff --git a/libs/langchain/langchain/document_loaders/athena.py b/libs/langchain/langchain/document_loaders/athena.py deleted file mode 100644 index 361e0a52165d8..0000000000000 --- a/libs/langchain/langchain/document_loaders/athena.py +++ /dev/null @@ -1,3 +0,0 @@ -from langchain_community.document_loaders.athena import AthenaLoader - -__all__ = ["AthenaLoader"] \ No newline at end of file diff --git a/libs/langchain/tests/unit_tests/document_loaders/test_imports.py b/libs/langchain/tests/unit_tests/document_loaders/test_imports.py index 066aa20b6ae55..ad1b5a7ea34bf 100644 --- a/libs/langchain/tests/unit_tests/document_loaders/test_imports.py +++ b/libs/langchain/tests/unit_tests/document_loaders/test_imports.py @@ -23,7 +23,6 @@ "ArxivLoader", "AssemblyAIAudioTranscriptLoader", "AsyncHtmlLoader", - "AthenaLoader", "AzureAIDataLoader", "AzureBlobStorageContainerLoader", "AzureBlobStorageFileLoader", From b4c08c5c3497810be0f7ab46d266b328d7b748bc Mon Sep 17 00:00:00 2001 From: Abhijeeth Padarthi Date: Thu, 11 Jan 2024 15:25:44 +0530 Subject: [PATCH 05/12] fix linting issues --- .../langchain_community/document_loaders/athena.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/libs/community/langchain_community/document_loaders/athena.py b/libs/community/langchain_community/document_loaders/athena.py index db6a04163610a..4226919199062 100644 --- a/libs/community/langchain_community/document_loaders/athena.py +++ b/libs/community/langchain_community/document_loaders/athena.py @@ -1,11 +1,13 @@ from __future__ import annotations -import time import io -import pandas as pd import json +import time from typing import Any, Dict, Iterator, List, Optional, Tuple + +import pandas as pd from langchain_core.documents import Document + from langchain_community.document_loaders.base import BaseLoader @@ -86,17 +88,17 @@ def _execute_query(self) -> List[Dict[str, Any]]: time.sleep(1) result_set = self.get_result_set(session, query_execution_id) - return json.loads(result_set.to_json(orient='records')) + return json.loads(result_set.to_json(orient="records")) def get_result_set(self, session, query_execution_id): - s3c = session.client('s3') + s3c = session.client("s3") tokens = self.s3_output_uri.removeprefix("s3://").removesuffix("/").split("/") bucket = tokens[0] - key = "/".join(tokens[1:]) +"/"+ query_execution_id + '.csv' + key = "/".join(tokens[1:]) + "/" + query_execution_id + ".csv" obj = s3c.get_object(Bucket=bucket, Key=key) - df = pd.read_csv(io.BytesIO(obj['Body'].read()), encoding='utf8') + df = pd.read_csv(io.BytesIO(obj["Body"].read()), encoding="utf8") return df def _get_columns( From 614a3a12e92d276fbaeac0175dcff42cd0161a7e Mon Sep 17 00:00:00 2001 From: Abhijeeth Padarthi Date: Tue, 16 Jan 2024 01:12:43 +0530 Subject: [PATCH 06/12] change pandas import --- libs/community/langchain_community/document_loaders/athena.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/libs/community/langchain_community/document_loaders/athena.py b/libs/community/langchain_community/document_loaders/athena.py index 4226919199062..fea95c23dd8d2 100644 --- a/libs/community/langchain_community/document_loaders/athena.py +++ b/libs/community/langchain_community/document_loaders/athena.py @@ -5,7 +5,6 @@ import time from typing import Any, Dict, Iterator, List, Optional, Tuple -import pandas as pd from langchain_core.documents import Document from langchain_community.document_loaders.base import BaseLoader @@ -91,6 +90,8 @@ def _execute_query(self) -> List[Dict[str, Any]]: return json.loads(result_set.to_json(orient="records")) def get_result_set(self, session, query_execution_id): + import pandas as pd + s3c = session.client("s3") tokens = self.s3_output_uri.removeprefix("s3://").removesuffix("/").split("/") From 11e56d769990f29fc9daf39f83e4c59cddb6c374 Mon Sep 17 00:00:00 2001 From: Abhijeeth Padarthi Date: Tue, 23 Jan 2024 19:17:46 -0500 Subject: [PATCH 07/12] move client and session creation to init --- .../document_loaders/athena.py | 52 +++++++++++++------ 1 file changed, 35 insertions(+), 17 deletions(-) diff --git a/libs/community/langchain_community/document_loaders/athena.py b/libs/community/langchain_community/document_loaders/athena.py index fea95c23dd8d2..b30ec84b47f28 100644 --- a/libs/community/langchain_community/document_loaders/athena.py +++ b/libs/community/langchain_community/document_loaders/athena.py @@ -49,20 +49,34 @@ def __init__( self.query = query self.database = database self.s3_output_uri = s3_output_uri - self.profile_name = profile_name self.metadata_columns = metadata_columns if metadata_columns is not None else [] - def _execute_query(self) -> List[Dict[str, Any]]: - import boto3 + try: + import boto3 + except ImportError: + raise ModuleNotFoundError( + "Could not import boto3 python package. " + "Please install it with `pip install boto3`." + ) - session = ( - boto3.Session(profile_name=self.profile_name) - if self.profile_name is not None - else boto3.Session() - ) - client = session.client("athena") + try: + session = ( + boto3.Session(profile_name=profile_name) + if profile_name is not None + else boto3.Session() + ) + except Exception as e: + raise ValueError( + "Could not load credentials to authenticate with AWS client. " + "Please check that credentials in the specified " + "profile name are valid." + ) from e - response = client.start_query_execution( + self.athena_client = session.client("athena") + self.s3_client = session.client("s3") + + def _execute_query(self) -> List[Dict[str, Any]]: + response = self.athena_client.start_query_execution( QueryString=self.query, QueryExecutionContext={"Database": self.database}, ResultConfiguration={"OutputLocation": self.s3_output_uri}, @@ -70,7 +84,7 @@ def _execute_query(self) -> List[Dict[str, Any]]: query_execution_id = response["QueryExecutionId"] print(f"Query : {self.query}") while True: - response = client.get_query_execution(QueryExecutionId=query_execution_id) + response = self.athena_client.get_query_execution(QueryExecutionId=query_execution_id) state = response["QueryExecution"]["Status"]["State"] if state == "SUCCEEDED": print(f"State : {state}") @@ -86,19 +100,23 @@ def _execute_query(self) -> List[Dict[str, Any]]: print(f"State : {state}") time.sleep(1) - result_set = self.get_result_set(session, query_execution_id) + result_set = self._get_result_set(query_execution_id) return json.loads(result_set.to_json(orient="records")) - def get_result_set(self, session, query_execution_id): - import pandas as pd - - s3c = session.client("s3") + def _get_result_set(self, query_execution_id: str): + try: + import pandas as pd + except ImportError: + raise ModuleNotFoundError( + "Could not import pandas python package. " + "Please install it with `pip install pandas`." + ) tokens = self.s3_output_uri.removeprefix("s3://").removesuffix("/").split("/") bucket = tokens[0] key = "/".join(tokens[1:]) + "/" + query_execution_id + ".csv" - obj = s3c.get_object(Bucket=bucket, Key=key) + obj = self.s3_client.get_object(Bucket=bucket, Key=key) df = pd.read_csv(io.BytesIO(obj["Body"].read()), encoding="utf8") return df From de252d27a72f5e6aba59deec5d421ac1a2007bed Mon Sep 17 00:00:00 2001 From: Abhijeeth Padarthi Date: Tue, 23 Jan 2024 21:47:40 -0500 Subject: [PATCH 08/12] fix null metadata --- libs/community/langchain_community/document_loaders/athena.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/libs/community/langchain_community/document_loaders/athena.py b/libs/community/langchain_community/document_loaders/athena.py index b30ec84b47f28..a3f85ba1a0157 100644 --- a/libs/community/langchain_community/document_loaders/athena.py +++ b/libs/community/langchain_community/document_loaders/athena.py @@ -141,10 +141,10 @@ def lazy_load(self) -> Iterator[Document]: page_content = "\n".join( f"{k}: {v}" for k, v in row.items() if k in content_columns ) - metadata = {k: v for k, v in row.items() if k in metadata_columns} + metadata = {k: v for k, v in row.items() if k in metadata_columns and v is not None} doc = Document(page_content=page_content, metadata=metadata) yield doc def load(self) -> List[Document]: """Load data into document objects.""" - return list(self.lazy_load()) + return list(self.lazy_load()) \ No newline at end of file From 6e809b97fbce9d3c56f89db8cb15ea394f212542 Mon Sep 17 00:00:00 2001 From: Abhijeeth Padarthi Date: Tue, 23 Jan 2024 21:49:45 -0500 Subject: [PATCH 09/12] fix format and lint --- .../langchain_community/document_loaders/athena.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/libs/community/langchain_community/document_loaders/athena.py b/libs/community/langchain_community/document_loaders/athena.py index a3f85ba1a0157..5acc0db48ea5f 100644 --- a/libs/community/langchain_community/document_loaders/athena.py +++ b/libs/community/langchain_community/document_loaders/athena.py @@ -84,7 +84,9 @@ def _execute_query(self) -> List[Dict[str, Any]]: query_execution_id = response["QueryExecutionId"] print(f"Query : {self.query}") while True: - response = self.athena_client.get_query_execution(QueryExecutionId=query_execution_id) + response = self.athena_client.get_query_execution( + QueryExecutionId=query_execution_id + ) state = response["QueryExecution"]["Status"]["State"] if state == "SUCCEEDED": print(f"State : {state}") @@ -141,10 +143,12 @@ def lazy_load(self) -> Iterator[Document]: page_content = "\n".join( f"{k}: {v}" for k, v in row.items() if k in content_columns ) - metadata = {k: v for k, v in row.items() if k in metadata_columns and v is not None} + metadata = { + k: v for k, v in row.items() if k in metadata_columns and v is not None + } doc = Document(page_content=page_content, metadata=metadata) yield doc def load(self) -> List[Document]: """Load data into document objects.""" - return list(self.lazy_load()) \ No newline at end of file + return list(self.lazy_load()) From a76246333b2df8f7a537fefe80bcde882d91b057 Mon Sep 17 00:00:00 2001 From: Abhijeeth Padarthi Date: Mon, 29 Jan 2024 15:29:10 -0500 Subject: [PATCH 10/12] Update athena.ipynb --- .../document_loaders/athena.ipynb | 214 +++++++++--------- 1 file changed, 107 insertions(+), 107 deletions(-) diff --git a/docs/docs/integrations/document_loaders/athena.ipynb b/docs/docs/integrations/document_loaders/athena.ipynb index 2636fa201944e..3fd655c0f0a04 100644 --- a/docs/docs/integrations/document_loaders/athena.ipynb +++ b/docs/docs/integrations/document_loaders/athena.ipynb @@ -1,110 +1,110 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "MwTWzDxYgbrR" - }, - "source": [ - "# Athena\n", - "\n", - "This notebooks goes over how to load documents from AWS Athena" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "F0zaLR3xgWmO" - }, - "outputs": [], - "source": [ - "! pip install boto3" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "076NLjfngoWJ" - }, - "outputs": [], - "source": [ - "from langchain.document_loaders import AthenaLoader" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "XpMRQwU9gu44" - }, - "outputs": [], - "source": [ - "database_name = \"my_database\"\n", - "s3_output_path = \"s3://my_bucket/query_results/\"\n", - "query = f\"SELECT * FROM my_table\"\n", - "profile_name = \"my_profile\"\n", - "\n", - "loader = AthenaLoader(\n", - " query=query,\n", - " database=database_name,\n", - " s3_output_uri=s3_output_path,\n", - " profile_name=profile_name\n", - ")\n", - "\n", - "documents = loader.load()\n", - "print(documents)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "5IBapL3ejoEt" - }, - "source": [ - "Example with metadata columns" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "wMx6nI1qjryD" - }, - "outputs": [], - "source": [ - "database_name = \"my_database\"\n", - "s3_output_path = \"s3://my_bucket/query_results/\"\n", - "query = f\"SELECT * FROM my_table\"\n", - "profile_name = \"my_profile\"\n", - "metadata_columns = [\"_row\", \"_created_at\"]\n", - "\n", - "loader = AthenaLoader(\n", - " query=query,\n", - " database=database_name,\n", - " s3_output_uri=s3_output_path,\n", - " profile_name=profile_name,\n", - " metadata_columns=metadata_columns\n", - ")\n", - "\n", - "documents = loader.load()\n", - "print(documents)" - ] - } - ], - "metadata": { - "colab": { - "provenance": [] - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "name": "python" - } + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "MwTWzDxYgbrR" + }, + "source": [ + "# Athena\n", + "\n", + "This notebooks goes over how to load documents from AWS Athena" + ] }, - "nbformat": 4, - "nbformat_minor": 0 + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "F0zaLR3xgWmO" + }, + "outputs": [], + "source": [ + "! pip install boto3" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "076NLjfngoWJ" + }, + "outputs": [], + "source": [ + "from langchain_community.document_loaders.athena import AthenaLoader" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "XpMRQwU9gu44" + }, + "outputs": [], + "source": [ + "database_name = \"my_database\"\n", + "s3_output_path = \"s3://my_bucket/query_results/\"\n", + "query = \"SELECT * FROM my_table\"\n", + "profile_name = \"my_profile\"\n", + "\n", + "loader = AthenaLoader(\n", + " query=query,\n", + " database=database_name,\n", + " s3_output_uri=s3_output_path,\n", + " profile_name=profile_name,\n", + ")\n", + "\n", + "documents = loader.load()\n", + "print(documents)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "5IBapL3ejoEt" + }, + "source": [ + "Example with metadata columns" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "wMx6nI1qjryD" + }, + "outputs": [], + "source": [ + "database_name = \"my_database\"\n", + "s3_output_path = \"s3://my_bucket/query_results/\"\n", + "query = \"SELECT * FROM my_table\"\n", + "profile_name = \"my_profile\"\n", + "metadata_columns = [\"_row\", \"_created_at\"]\n", + "\n", + "loader = AthenaLoader(\n", + " query=query,\n", + " database=database_name,\n", + " s3_output_uri=s3_output_path,\n", + " profile_name=profile_name,\n", + " metadata_columns=metadata_columns,\n", + ")\n", + "\n", + "documents = loader.load()\n", + "print(documents)" + ] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 } From 1428649aaa330840aa15697a5211f445159f831e Mon Sep 17 00:00:00 2001 From: Abhijeeth Padarthi Date: Thu, 8 Feb 2024 18:59:50 -0500 Subject: [PATCH 11/12] fix for python 3.8 --- .../document_loaders/athena.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/libs/community/langchain_community/document_loaders/athena.py b/libs/community/langchain_community/document_loaders/athena.py index 5acc0db48ea5f..572bc6bba938a 100644 --- a/libs/community/langchain_community/document_loaders/athena.py +++ b/libs/community/langchain_community/document_loaders/athena.py @@ -105,6 +105,16 @@ def _execute_query(self) -> List[Dict[str, Any]]: result_set = self._get_result_set(query_execution_id) return json.loads(result_set.to_json(orient="records")) + def _remove_suffix(self, input_string, suffix): + if suffix and input_string.endswith(suffix): + return input_string[: -len(suffix)] + return input_string + + def _remove_prefix(self, input_string, suffix): + if suffix and input_string.startswith(suffix): + return input_string[len(suffix) :] + return input_string + def _get_result_set(self, query_execution_id: str): try: import pandas as pd @@ -114,7 +124,10 @@ def _get_result_set(self, query_execution_id: str): "Please install it with `pip install pandas`." ) - tokens = self.s3_output_uri.removeprefix("s3://").removesuffix("/").split("/") + output_uri = self.s3_output_uri + tokens = self._remove_prefix( + self._remove_suffix(output_uri, "/"), "s3://" + ).split("/") bucket = tokens[0] key = "/".join(tokens[1:]) + "/" + query_execution_id + ".csv" From e1a1de38c846636159298c92bd395e5219722099 Mon Sep 17 00:00:00 2001 From: Abhijeeth Padarthi Date: Thu, 8 Feb 2024 22:56:30 -0500 Subject: [PATCH 12/12] add types for methods --- .../langchain_community/document_loaders/athena.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/libs/community/langchain_community/document_loaders/athena.py b/libs/community/langchain_community/document_loaders/athena.py index 572bc6bba938a..1e33062c194ee 100644 --- a/libs/community/langchain_community/document_loaders/athena.py +++ b/libs/community/langchain_community/document_loaders/athena.py @@ -105,17 +105,17 @@ def _execute_query(self) -> List[Dict[str, Any]]: result_set = self._get_result_set(query_execution_id) return json.loads(result_set.to_json(orient="records")) - def _remove_suffix(self, input_string, suffix): + def _remove_suffix(self, input_string: str, suffix: str) -> str: if suffix and input_string.endswith(suffix): return input_string[: -len(suffix)] return input_string - def _remove_prefix(self, input_string, suffix): + def _remove_prefix(self, input_string: str, suffix: str) -> str: if suffix and input_string.startswith(suffix): return input_string[len(suffix) :] return input_string - def _get_result_set(self, query_execution_id: str): + def _get_result_set(self, query_execution_id: str) -> Any: try: import pandas as pd except ImportError: