Skip to content

Commit

Permalink
Use mocking to skip the actual training part
Browse files Browse the repository at this point in the history
  • Loading branch information
sjrl committed Jun 6, 2023
1 parent 3e07190 commit 72156d1
Showing 1 changed file with 28 additions and 59 deletions.
87 changes: 28 additions & 59 deletions test/nodes/test_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,66 +650,35 @@ def test_table_text_retriever_training(tmp_path, document_store, samples_path):
)


def test_sentence_transformers_retriever_training():
    """Smoke-test ``EmbeddingRetriever.train`` with a sentence-transformers model.

    A CPU-only retriever is trained for one epoch on two
    question / positive-document / negative-document triples using the
    ``"mnrl"`` loss. The test makes no assertions: it passes as long as
    building the retriever and calling ``train`` do not raise.
    """
    triples = [
        {
            "question": "What color is the sky?",
            "pos_doc": "The color of the sky is blue.",
            "neg_doc": "The color of the grass is brown.",
        },
        {
            "question": "What is the capital of Germany?",
            "pos_doc": "The capital of Germany is Berlin.",
            "neg_doc": "The capital of France is Paris.",
        },
    ]
    # Smallest possible run: one epoch, one sample per batch, no warm-up steps.
    train_settings = {"train_loss": "mnrl", "n_epochs": 1, "batch_size": 1, "num_warmup_steps": 0}

    retriever = EmbeddingRetriever(
        embedding_model="sentence-transformers/all-MiniLM-L6-v2",
        model_format="sentence_transformers",
        use_gpu=False,
    )
    retriever.train(training_data=triples, **train_settings)
def test_sentence_transformers_retriever_training_with_gradient_checkpointing():
    """Check that ``EmbeddingRetriever.train`` runs with ``gradient_checkpointing=True``.

    ``sentence_transformers.SentenceTransformer.fit`` is patched for the
    duration of the test so the actual training step is skipped. The test
    then verifies that ``train`` reached the patched ``fit`` exactly once,
    i.e. that the retriever got as far as handing work to the model.
    """
    training_data = [
        {
            "question": "What color is the sky?",
            "pos_doc": "The color of the sky is blue.",
            "neg_doc": "The color of the grass is brown.",
        },
        {
            "question": "What is the capital of Germany?",
            "pos_doc": "The capital of Germany is Berlin.",
            "neg_doc": "The capital of France is Paris.",
        },
    ]

    # No replacement object is given, so ``patch`` installs a MagicMock that
    # records every call; a plain stub function could not be asserted on.
    with patch("sentence_transformers.SentenceTransformer.fit") as mock_fit:
        retriever = EmbeddingRetriever(
            embedding_model="sentence-transformers/all-MiniLM-L6-v2",
            model_format="sentence_transformers",
            use_gpu=False,
        )
        retriever.train(
            training_data=training_data,
            train_loss="mnrl",
            n_epochs=1,
            batch_size=1,
            num_warmup_steps=0,
            gradient_checkpointing=True,
        )

    # Without this check the test would keep passing even if ``train`` never
    # called the (patched) fit method.
    mock_fit.assert_called_once()
    # TODO: also verify the parameters forwarded to the training call.


@pytest.mark.elasticsearch
Expand Down

0 comments on commit 72156d1

Please sign in to comment.