diff --git a/Makefile b/Makefile deleted file mode 100644 index a99809c..0000000 --- a/Makefile +++ /dev/null @@ -1,22 +0,0 @@ -.PHONY: style check_code_quality - -export PYTHONPATH = . -check_dirs := src - -style: - black $(check_dirs) - isort --profile black $(check_dirs) - -check_code_quality: - black --check $(check_dirs) - isort --check-only --profile black $(check_dirs) - # stop the build if there are Python syntax errors or undefined names - flake8 $(check_dirs) --count --select=E9,F63,F7,F82 --show-source --statistics - # exit-zero treats all errors as warnings. E203 for black, E501 for docstring, W503 for line breaks before logical operators - flake8 $(check_dirs) --count --max-line-length=88 --exit-zero --ignore=D --extend-ignore=E203,E501,W503 --statistics - -publish: - python setup.py sdist bdist_wheel - twine upload -r testpypi dist/* -u ${PYPI_USERNAME} -p ${PYPI_TEST_PASSWORD} --verbose - twine check dist/* - twine upload dist/* -u ${PYPI_USERNAME} -p ${PYPI_PASSWORD} --verbose \ No newline at end of file diff --git a/swarms_memory/__init__.py b/swarms_memory/__init__.py index 467e68b..4fe0d69 100644 --- a/swarms_memory/__init__.py +++ b/swarms_memory/__init__.py @@ -1,4 +1,4 @@ -from swarms_memory.chroma_db_wrapper import ChromaMemory +from swarms_memory.chroma_db_wrapper import ChromaDBMemory from swarms_memory.pinecone_wrapper import PineconeMemory -__all__ = ["ChromaMemory", "PineconeMemory"] +__all__ = ["ChromaDBMemory", "PineconeMemory"] diff --git a/tests/test_chromadb.py b/tests/test_chromadb.py index e69de29..e0fcea4 100644 --- a/tests/test_chromadb.py +++ b/tests/test_chromadb.py @@ -0,0 +1,65 @@ +from unittest.mock import patch, MagicMock +from swarms_memory.chroma_db_wrapper import ChromaDBMemory + + +@patch("chromadb.PersistentClient") +@patch("chromadb.Client") +def test_init(mock_client, mock_persistent_client): + chroma_db = ChromaDBMemory( + metric="cosine", + output_dir="swarms", + limit_tokens=1000, + n_results=1, + docs_folder=None, + verbose=False, + ) + assert chroma_db.metric == "cosine" + assert chroma_db.output_dir == "swarms" + assert chroma_db.limit_tokens == 1000 + assert chroma_db.n_results == 1 + assert chroma_db.docs_folder == None + assert chroma_db.verbose == False + mock_persistent_client.assert_called_once() + mock_client.assert_called_once() + + +@patch("chromadb.PersistentClient") +@patch("chromadb.Client") +def test_add(mock_client, mock_persistent_client): + chroma_db = ChromaDBMemory() + mock_collection = MagicMock() + chroma_db.collection = mock_collection + doc_id = chroma_db.add("test document") + mock_collection.add.assert_called_once_with( + ids=[doc_id], documents=["test document"] + ) + assert isinstance(doc_id, str) + + +@patch("chromadb.PersistentClient") +@patch("chromadb.Client") +def test_query(mock_client, mock_persistent_client): + chroma_db = ChromaDBMemory() + mock_collection = MagicMock() + chroma_db.collection = mock_collection + mock_collection.query.return_value = { + "documents": ["test document"] + } + result = chroma_db.query("test query") + mock_collection.query.assert_called_once_with( + query_texts=["test query"], n_results=1 + ) + assert result == "test document\n" + + +@patch("chromadb.PersistentClient") +@patch("chromadb.Client") +@patch("os.walk") +@patch("swarms_memory.chroma_db_wrapper.ChromaDBMemory.add") +def test_traverse_directory( + mock_add, mock_walk, mock_client, mock_persistent_client +): + chroma_db = ChromaDBMemory(docs_folder="test_folder") + mock_walk.return_value = [("root", "dirs", ["file1", "file2"])] + chroma_db.traverse_directory() + assert mock_add.call_count == 2 diff --git a/tests/test_pinecone.py b/tests/test_pinecone.py index e69de29..dd80da8 100644 --- a/tests/test_pinecone.py +++ b/tests/test_pinecone.py @@ -0,0 +1,61 @@ +from unittest.mock import patch +from swarms_memory.pinecone_wrapper import PineconeMemory + + +@patch("pinecone.init") +@patch("pinecone.list_indexes") +@patch("pinecone.create_index") +@patch("pinecone.Index") +@patch("sentence_transformers.SentenceTransformer") +def test_init( + mock_st, + mock_index, + mock_create_index, + mock_list_indexes, + mock_init, +): + mock_list_indexes.return_value = [] + PineconeMemory("api_key", "environment", "index_name") + mock_init.assert_called_once_with( + api_key="api_key", environment="environment" + ) + mock_create_index.assert_called_once() + mock_index.assert_called_once_with("index_name") + mock_st.assert_called_once_with("all-MiniLM-L6-v2") + + +@patch("loguru.logger.configure") +def test_setup_logger(mock_configure): + PineconeMemory._setup_logger(None) + mock_configure.assert_called_once() + + +@patch("sentence_transformers.SentenceTransformer.encode") +def test_default_embedding_function(mock_encode): + pm = PineconeMemory("api_key", "environment", "index_name") + pm._default_embedding_function("text") + mock_encode.assert_called_once_with("text") + + +def test_default_preprocess_function(): + pm = PineconeMemory("api_key", "environment", "index_name") + assert pm._default_preprocess_function(" text ") == "text" + + +def test_default_postprocess_function(): + pm = PineconeMemory("api_key", "environment", "index_name") + assert pm._default_postprocess_function("results") == "results" + + +@patch("pinecone.Index.upsert") +def test_add(mock_upsert): + pm = PineconeMemory("api_key", "environment", "index_name") + pm.add("doc") + mock_upsert.assert_called_once() + + +@patch("pinecone.Index.query") +def test_query(mock_query): + pm = PineconeMemory("api_key", "environment", "index_name") + pm.query("query") + mock_query.assert_called_once()