From 72156d18864b43bc976ae675e581de7a06db9c28 Mon Sep 17 00:00:00 2001
From: Sebastian Husch Lee
Date: Tue, 6 Jun 2023 16:01:35 +0200
Subject: [PATCH] Use mocking to skip the actual training part

---
Review note: the hunk below was edited during review so that the test
asserts the patched `SentenceTransformer.fit` is actually called. The blob
hashes on the `index` line still refer to the originally submitted blobs.

 test/nodes/test_retriever.py | 91 ++++++++++++------------------------
 1 file changed, 31 insertions(+), 60 deletions(-)

diff --git a/test/nodes/test_retriever.py b/test/nodes/test_retriever.py
index b13852e9b8..dd62bbf6ec 100644
--- a/test/nodes/test_retriever.py
+++ b/test/nodes/test_retriever.py
@@ -650,66 +650,37 @@ def test_table_text_retriever_training(tmp_path, document_store, samples_path):
     )
 
 
-def test_sentence_transformers_retriever_training():
-    retriever = EmbeddingRetriever(
-        embedding_model="sentence-transformers/all-MiniLM-L6-v2", model_format="sentence_transformers", use_gpu=False
-    )
-    retriever.train(
-        training_data=[
-            {
-                "question": "What color is the sky?",
-                "pos_doc": "The color of the sky is blue.",
-                "neg_doc": "The color of the grass is brown.",
-            },
-            {
-                "question": "What is the capital of Germany?",
-                "pos_doc": "The capital of Germany is Berlin.",
-                "neg_doc": "The capital of France is Paris.",
-            },
-        ],
-        train_loss="mnrl",
-        n_epochs=1,
-        batch_size=1,
-        num_warmup_steps=0,
-    )
-
-
-# # TODO Mock _SentenceTransformerEmbeddingEncoder, create a normal EmbeddingRetriever
-# # When creating the mock _SentenceTransformerEmbeddingEncoder add a train method that checks the correct
-# # parameters were passed?
-# def test_sentence_transformers_retriever_training_with_gradient_checkpointing():
-#     def mock_train(model_name_or_path, **kwargs):
-#         return model_name_or_path == model_name
-#
-#     class MockSentenceTransformers():
-#         def train(self):
-#
-#
-#     with patch("haystack.nodes._embedding_encoder._SentenceTransformersEmbeddingEncoder", mock_encoder):
-#         retriever = EmbeddingRetriever(
-#             embedding_model="sentence-transformers/all-MiniLM-L6-v2",
-#             model_format="sentence_transformers",
-#             use_gpu=False,
-#         )
-#         retriever.train(
-#             training_data=[
-#                 {
-#                     "question": "What color is the sky?",
-#                     "pos_doc": "The color of the sky is blue.",
-#                     "neg_doc": "The color of the grass is brown.",
-#                 },
-#                 {
-#                     "question": "What is the capital of Germany?",
-#                     "pos_doc": "The capital of Germany is Berlin.",
-#                     "neg_doc": "The capital of France is Paris.",
-#                 },
-#             ],
-#             train_loss="mnrl",
-#             n_epochs=1,
-#             batch_size=1,
-#             num_warmup_steps=0,
-#             gradient_checkpointing=True,
-#         )
+def test_sentence_transformers_retriever_training_with_gradient_checkpointing():
+    # Replace SentenceTransformer.fit with the MagicMock that patch() creates by default,
+    # so train() runs its setup but no actual training loop is executed.
+    with patch("sentence_transformers.SentenceTransformer.fit") as mock_fit:
+        retriever = EmbeddingRetriever(
+            embedding_model="sentence-transformers/all-MiniLM-L6-v2",
+            model_format="sentence_transformers",
+            use_gpu=False,
+        )
+        retriever.train(
+            training_data=[
+                {
+                    "question": "What color is the sky?",
+                    "pos_doc": "The color of the sky is blue.",
+                    "neg_doc": "The color of the grass is brown.",
+                },
+                {
+                    "question": "What is the capital of Germany?",
+                    "pos_doc": "The capital of Germany is Berlin.",
+                    "neg_doc": "The capital of France is Paris.",
+                },
+            ],
+            train_loss="mnrl",
+            n_epochs=1,
+            batch_size=1,
+            num_warmup_steps=0,
+            gradient_checkpointing=True,
+        )
+
+    # Without this check the test would pass even if train() never reached fit().
+    mock_fit.assert_called_once()
 
 
 @pytest.mark.elasticsearch