Skip to content

Commit

Permalink
feat: Retrieval module support string as input (#817)
Browse files Browse the repository at this point in the history
  • Loading branch information
Wendong-Fan authored Aug 12, 2024
1 parent 6472678 commit 03612e3
Show file tree
Hide file tree
Showing 17 changed files with 75 additions and 81 deletions.
2 changes: 1 addition & 1 deletion .github/ISSUE_TEMPLATE/bug_report.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ body:
attributes:
label: What version of camel are you using?
description: Run command `python3 -c 'print(__import__("camel").__version__)'` in your shell and paste the output here.
placeholder: E.g., 0.1.6.2
placeholder: E.g., 0.1.6.3
validations:
required: true

Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ conda create --name camel python=3.9
conda activate camel
# Clone github repo
git clone -b v0.1.6.2 https://github.com/camel-ai/camel.git
git clone -b v0.1.6.3 https://github.com/camel-ai/camel.git
# Change directory into project directory
cd camel
Expand Down
2 changes: 1 addition & 1 deletion camel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========

__version__ = '0.1.6.2'
__version__ = '0.1.6.3'

__all__ = [
'__version__',
Expand Down
60 changes: 25 additions & 35 deletions camel/retrievers/auto_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,36 +97,36 @@ def _initialize_vector_storage(
f"Unsupported vector storage type: {self.storage_type}"
)

def _collection_name_generator(self, content_input_path: str) -> str:
def _collection_name_generator(self, content: str) -> str:
r"""Generates a valid collection name from a given file path, URL, or string content.
Args:
content_input_path: str. The input URL or file path from which to
generate the collection name.
content (str): Local file path, remote URL or string content.
Returns:
str: A sanitized, valid collection name suitable for use.
"""
# Check path type
parsed_url = urlparse(content_input_path)
self.is_url = all([parsed_url.scheme, parsed_url.netloc])
# Check if the content is URL
parsed_url = urlparse(content)
is_url = all([parsed_url.scheme, parsed_url.netloc])

# Convert given path into a collection name, ensuring it only
# contains numbers, letters, and underscores
if self.is_url:
if is_url:
# For URLs, remove https://, replace /, and any characters not
# allowed by Milvus with _
collection_name = re.sub(
r'[^0-9a-zA-Z]+',
'_',
content_input_path.replace("https://", ""),
content.replace("https://", ""),
)
else:
elif os.path.exists(content):
# For file paths, get the stem and replace spaces with _, also
# ensuring only allowed characters are present
collection_name = re.sub(
r'[^0-9a-zA-Z]+', '_', Path(content_input_path).stem
)
collection_name = re.sub(r'[^0-9a-zA-Z]+', '_', Path(content).stem)
else:
# the content is string input
collection_name = content[:10]

# Ensure the collection name does not start or end with underscore
collection_name = collection_name.strip("_")
Expand Down Expand Up @@ -193,7 +193,7 @@ def _get_file_modified_date_from_storage(
def run_vector_retriever(
self,
query: str,
content_input_paths: Union[str, List[str]],
contents: Union[str, List[str]],
top_k: int = DEFAULT_TOP_K_RESULTS,
similarity_threshold: float = DEFAULT_SIMILARITY_THRESHOLD,
return_detailed_info: bool = False,
Expand All @@ -203,8 +203,8 @@ def run_vector_retriever(
Args:
query (str): Query string for information retriever.
content_input_paths (Union[str, List[str]]): Paths to local
files or remote URLs.
contents (Union[str, List[str]]): Local file paths, remote URLs or
string contents.
top_k (int, optional): The number of top results to return during
retrieve. Must be a positive integer. Defaults to
`DEFAULT_TOP_K_RESULTS`.
Expand All @@ -223,24 +223,18 @@ def run_vector_retriever(
Raises:
ValueError: If a vector storage already exists with the content
name in the vector path but the payload is None. If
`content_input_paths` is empty.
`contents` is empty.
RuntimeError: If any errors occur during the retrieve process.
"""
if not content_input_paths:
raise ValueError("content_input_paths cannot be empty.")
if not contents:
raise ValueError("content cannot be empty.")

content_input_paths = (
[content_input_paths]
if isinstance(content_input_paths, str)
else content_input_paths
)
contents = [contents] if isinstance(contents, str) else contents

all_retrieved_info = []
for content_input_path in content_input_paths:
for content in contents:
# Generate a valid collection name
collection_name = self._collection_name_generator(
content_input_path
)
collection_name = self._collection_name_generator(content)
try:
vector_storage_instance = self._initialize_vector_storage(
collection_name
Expand All @@ -251,13 +245,11 @@ def run_vector_retriever(
file_is_modified = False # initialize with a default value
if (
vector_storage_instance.status().vector_count != 0
and not self.is_url
and os.path.exists(content)
):
# Get original modified date from file
modified_date_from_file = (
self._get_file_modified_date_from_file(
content_input_path
)
self._get_file_modified_date_from_file(content)
)
# Get modified date from vector storage
modified_date_from_storage = (
Expand All @@ -280,18 +272,16 @@ def run_vector_retriever(
# Process and store the content to the vector storage
vr = VectorRetriever(
storage=vector_storage_instance,
similarity_threshold=similarity_threshold,
embedding_model=self.embedding_model,
)
vr.process(content_input_path)
vr.process(content)
else:
vr = VectorRetriever(
storage=vector_storage_instance,
similarity_threshold=similarity_threshold,
embedding_model=self.embedding_model,
)
# Retrieve info by given query from the vector storage
retrieved_info = vr.query(query, top_k)
retrieved_info = vr.query(query, top_k, similarity_threshold)
all_retrieved_info.extend(retrieved_info)
except Exception as e:
raise RuntimeError(
Expand Down
38 changes: 20 additions & 18 deletions camel/retrievers/vector_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
import os
from typing import Any, Dict, List, Optional
from urllib.parse import urlparse

from camel.embeddings import BaseEmbedding, OpenAIEmbedding
from camel.loaders import UnstructuredIO
Expand All @@ -38,24 +40,18 @@ class VectorRetriever(BaseRetriever):
embedding_model (BaseEmbedding): Embedding model used to generate
vector embeddings.
storage (BaseVectorStorage): Vector storage to query.
similarity_threshold (float, optional): The similarity threshold
for filtering results. Defaults to `DEFAULT_SIMILARITY_THRESHOLD`.
unstructured_modules (UnstructuredIO): A module for parsing files and
URLs and chunking content based on specified parameters.
"""

def __init__(
self,
similarity_threshold: float = DEFAULT_SIMILARITY_THRESHOLD,
embedding_model: Optional[BaseEmbedding] = None,
storage: Optional[BaseVectorStorage] = None,
) -> None:
r"""Initializes the retriever class with an optional embedding model.
Args:
similarity_threshold (float, optional): The similarity threshold
for filtering results. Defaults to
`DEFAULT_SIMILARITY_THRESHOLD`.
embedding_model (Optional[BaseEmbedding]): The embedding model
instance. Defaults to `OpenAIEmbedding` if not provided.
storage (BaseVectorStorage): Vector storage to query.
Expand All @@ -68,12 +64,11 @@ def __init__(
vector_dim=self.embedding_model.get_output_dim()
)
)
self.similarity_threshold = similarity_threshold
self.unstructured_modules: UnstructuredIO = UnstructuredIO()
self.uio: UnstructuredIO = UnstructuredIO()

def process(
self,
content_input_path: str,
content: str,
chunk_type: str = "chunk_by_title",
**kwargs: Any,
) -> None:
Expand All @@ -82,16 +77,19 @@ def process(
vector storage.
Args:
content_input_path (str): File path or URL of the content to be
processed.
content (str): Local file path, remote URL or string content.
chunk_type (str): Type of chunking going to apply. Defaults to
"chunk_by_title".
**kwargs (Any): Additional keyword arguments for content parsing.
"""
elements = self.unstructured_modules.parse_file_or_url(
content_input_path, **kwargs
)
chunks = self.unstructured_modules.chunk_elements(
# Check if the content is URL
parsed_url = urlparse(content)
is_url = all([parsed_url.scheme, parsed_url.netloc])
if is_url or os.path.exists(content):
elements = self.uio.parse_file_or_url(content, **kwargs)
else:
elements = [self.uio.create_element_from_text(text=content)]
chunks = self.uio.chunk_elements(
chunk_type=chunk_type, elements=elements
)
# Iterate to process and store embeddings, set batch of 50
Expand All @@ -105,7 +103,7 @@ def process(
# Prepare the payload for each vector record, includes the content
# path, chunk metadata, and chunk text
for vector, chunk in zip(batch_vectors, batch_chunks):
content_path_info = {"content path": content_input_path}
content_path_info = {"content path": content}
chunk_metadata = {"metadata": chunk.metadata.to_dict()}
chunk_text = {"text": str(chunk)}
combined_dict = {
Expand All @@ -124,12 +122,16 @@ def query(
self,
query: str,
top_k: int = DEFAULT_TOP_K_RESULTS,
similarity_threshold: float = DEFAULT_SIMILARITY_THRESHOLD,
) -> List[Dict[str, Any]]:
r"""Executes a query in vector storage and compiles the retrieved
results into a dictionary.
Args:
query (str): Query string for information retriever.
similarity_threshold (float, optional): The similarity threshold
for filtering results. Defaults to
`DEFAULT_SIMILARITY_THRESHOLD`.
top_k (int, optional): The number of top results to return during
retrieval. Must be a positive integer. Defaults to `DEFAULT_TOP_K_RESULTS`.
Expand Down Expand Up @@ -161,7 +163,7 @@ def query(
formatted_results = []
for result in query_results:
if (
result.similarity >= self.similarity_threshold
result.similarity >= similarity_threshold
and result.record.payload is not None
):
result_dict = {
Expand All @@ -182,7 +184,7 @@ def query(
'text': (
f"No suitable information retrieved "
f"from {content_path} with similarity_threshold"
f" = {self.similarity_threshold}."
f" = {similarity_threshold}."
)
}
]
Expand Down
10 changes: 5 additions & 5 deletions camel/toolkits/retrieval_toolkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class RetrievalToolkit(BaseToolkit):
"""

def information_retrieval(
self, query: str, content_input_paths: Union[str, List[str]]
self, query: str, contents: Union[str, List[str]]
) -> str:
r"""Retrieves information from a local vector storage based on the
specified query. This function connects to a local vector storage
Expand All @@ -37,8 +37,8 @@ def information_retrieval(
Args:
query (str): The question or query for which an answer is required.
content_input_paths (Union[str, List[str]]): Paths to local
files or remote URLs.
contents (Union[str, List[str]]): Local file paths, remote URLs or
string contents.
Returns:
str: The information retrieved in response to the query, aggregated
Expand All @@ -47,15 +47,15 @@ def information_retrieval(
Example:
# Retrieve information about CAMEL AI.
information_retrieval(query = "what is CAMEL AI?",
content_input_paths="https://www.camel-ai.org/")
contents="https://www.camel-ai.org/")
"""
auto_retriever = AutoRetriever(
vector_storage_local_path="camel/temp_storage",
storage_type=StorageType.QDRANT,
)

retrieved_info = auto_retriever.run_vector_retriever(
query=query, content_input_paths=content_input_paths, top_k=3
query=query, contents=contents, top_k=3
)
return retrieved_info

Expand Down
8 changes: 4 additions & 4 deletions camel/toolkits/search_toolkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def search_wiki(self, entity: str) -> str:
return result

def search_duckduckgo(
self, query: str, source: str = "text", max_results: int = 10
self, query: str, source: str = "text", max_results: int = 5
) -> List[Dict[str, Any]]:
r"""Use DuckDuckGo search engine to search information for
the given query.
Expand All @@ -78,7 +78,7 @@ def search_duckduckgo(
query (str): The query to be searched.
source (str): The type of information to query (e.g., "text",
"images", "videos"). Defaults to "text".
max_results (int): Max number of results, defaults to `10`.
max_results (int): Max number of results, defaults to `5`.
Returns:
List[Dict[str, Any]]: A list of dictionaries where each dictionary
Expand Down Expand Up @@ -152,7 +152,7 @@ def search_duckduckgo(
return responses

def search_google(
self, query: str, num_result_pages: int = 10
self, query: str, num_result_pages: int = 5
) -> List[Dict[str, Any]]:
r"""Use Google search engine to search information for the given query.
Expand Down Expand Up @@ -196,7 +196,7 @@ def search_google(
# Different language may get different result
search_language = "en"
# How many pages to return
num_result_pages = 10
num_result_pages = num_result_pages
# Constructing the URL
# Doc: https://developers.google.com/custom-search/v1/using_rest
url = (
Expand Down
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
project = 'CAMEL'
copyright = '2023, CAMEL-AI.org'
author = 'CAMEL-AI.org'
release = '0.1.6.2'
release = '0.1.6.3'

html_favicon = (
'https://raw.githubusercontent.com/camel-ai/camel/master/misc/favicon.png'
Expand Down
2 changes: 1 addition & 1 deletion docs/get_started/setup.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ conda create --name camel python=3.10
conda activate camel
# Clone github repo
git clone -b v0.1.6.2 https://github.com/camel-ai/camel.git
git clone -b v0.1.6.3 https://github.com/camel-ai/camel.git
# Change directory into project directory
cd camel
Expand Down
2 changes: 1 addition & 1 deletion docs/key_modules/retrievers.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ ar = AutoRetriever(vector_storage_local_path="camel/retrievers",storage_type=Sto

# Run the auto vector retriever
retrieved_info = ar.run_vector_retriever(
content_input_paths=[
contents=[
"https://www.camel-ai.org/", # Example remote url
],
query="What is CAMEL-AI",
Expand Down
Loading

0 comments on commit 03612e3

Please sign in to comment.