Skip to content

Commit

Permalink
[CLEANUP]
Browse files Browse the repository at this point in the history
  • Loading branch information
Kye Gomez authored and Kye Gomez committed Jul 8, 2024
1 parent a197d6f commit 8033cbd
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 24 deletions.
22 changes: 0 additions & 22 deletions Makefile

This file was deleted.

4 changes: 2 additions & 2 deletions swarms_memory/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from swarms_memory.chroma_db_wrapper import ChromaMemory
from swarms_memory.chroma_db_wrapper import ChromaDBMemory
from swarms_memory.pinecone_wrapper import PineconeMemory

__all__ = ["ChromaMemory", "PineconeMemory"]
__all__ = ["ChromaDBMemory", "PineconeMemory"]
65 changes: 65 additions & 0 deletions tests/test_chromadb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from unittest.mock import patch, MagicMock
from swarms_memory.chroma_db_wrapper import ChromaDBMemory


@patch("chromadb.PersistentClient")
@patch("chromadb.Client")
def test_init(mock_client, mock_persistent_client):
chroma_db = ChromaDBMemory(
metric="cosine",
output_dir="swarms",
limit_tokens=1000,
n_results=1,
docs_folder=None,
verbose=False,
)
assert chroma_db.metric == "cosine"
assert chroma_db.output_dir == "swarms"
assert chroma_db.limit_tokens == 1000
assert chroma_db.n_results == 1
assert chroma_db.docs_folder == None
assert chroma_db.verbose == False
mock_persistent_client.assert_called_once()
mock_client.assert_called_once()


@patch("chromadb.PersistentClient")
@patch("chromadb.Client")
def test_add(mock_client, mock_persistent_client):
chroma_db = ChromaDBMemory()
mock_collection = MagicMock()
chroma_db.collection = mock_collection
doc_id = chroma_db.add("test document")
mock_collection.add.assert_called_once_with(
ids=[doc_id], documents=["test document"]
)
assert isinstance(doc_id, str)


@patch("chromadb.PersistentClient")
@patch("chromadb.Client")
def test_query(mock_client, mock_persistent_client):
chroma_db = ChromaDBMemory()
mock_collection = MagicMock()
chroma_db.collection = mock_collection
mock_collection.query.return_value = {
"documents": ["test document"]
}
result = chroma_db.query("test query")
mock_collection.query.assert_called_once_with(
query_texts=["test query"], n_results=1
)
assert result == "test document\n"


@patch("chromadb.PersistentClient")
@patch("chromadb.Client")
@patch("os.walk")
@patch("swarms_memory.chroma_db_wrapper.ChromaDBMemory.add")
def test_traverse_directory(
mock_add, mock_walk, mock_client, mock_persistent_client
):
chroma_db = ChromaDBMemory(docs_folder="test_folder")
mock_walk.return_value = [("root", "dirs", ["file1", "file2"])]
chroma_db.traverse_directory()
assert mock_add.call_count == 2
61 changes: 61 additions & 0 deletions tests/test_pinecone.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from unittest.mock import patch
from swarms_memory.pinecone_wrapper import PineconeMemory


@patch("pinecone.init")
@patch("pinecone.list_indexes")
@patch("pinecone.create_index")
@patch("pinecone.Index")
@patch("sentence_transformers.SentenceTransformer")
def test_init(
mock_st,
mock_index,
mock_create_index,
mock_list_indexes,
mock_init,
):
mock_list_indexes.return_value = []
PineconeMemory("api_key", "environment", "index_name")
mock_init.assert_called_once_with(
api_key="api_key", environment="environment"
)
mock_create_index.assert_called_once()
mock_index.assert_called_once_with("index_name")
mock_st.assert_called_once_with("all-MiniLM-L6-v2")


@patch("loguru.logger.configure")
def test_setup_logger(mock_configure):
PineconeMemory._setup_logger(None)
mock_configure.assert_called_once()


@patch("sentence_transformers.SentenceTransformer.encode")
def test_default_embedding_function(mock_encode):
pm = PineconeMemory("api_key", "environment", "index_name")
pm._default_embedding_function("text")
mock_encode.assert_called_once_with("text")


def test_default_preprocess_function():
pm = PineconeMemory("api_key", "environment", "index_name")
assert pm._default_preprocess_function(" text ") == "text"


def test_default_postprocess_function():
pm = PineconeMemory("api_key", "environment", "index_name")
assert pm._default_postprocess_function("results") == "results"


@patch("pinecone.Index.upsert")
def test_add(mock_upsert):
pm = PineconeMemory("api_key", "environment", "index_name")
pm.add("doc")
mock_upsert.assert_called_once()


@patch("pinecone.Index.query")
def test_query(mock_query):
pm = PineconeMemory("api_key", "environment", "index_name")
pm.query("query")
mock_query.assert_called_once()

0 comments on commit 8033cbd

Please sign in to comment.