diff --git a/README.md b/README.md index 4c48afc..a085ab0 100644 --- a/README.md +++ b/README.md @@ -176,11 +176,11 @@ with torch.inference_mode(): with torch.inference_mode(): emb = t2vec_model.predict(["She worked hard and made a significant contribution to the team."], source_lang='eng_Latn') - x = classifier(emb.to(device).to(dtype)) # tensor([[-58.0625]], device='cuda:0', dtype=torch.float16) + x = classifier(emb.to(device).to(dtype)) # tensor([[-53.5938]], device='cuda:0', dtype=torch.float16) with torch.inference_mode(): emb = t2vec_model.predict(["El no tiene ni el más mínimo talento, todo lo que ha logrado ha sido gracias a sobornos y manipulaciones."], source_lang='spa_Latn') - x = classifier(emb.to(device).to(dtype)) # tensor([[-24.6094]], device='cuda:0', dtype=torch.float16) + x = classifier(emb.to(device).to(dtype)) # tensor([[-21.4062]], device='cuda:0', dtype=torch.float16) ``` For a CLI way of running the MuTox pipeline, go to [Seamless Communication/.../MuTox](https://github.com/facebookresearch/seamless_communication/tree/main/src/seamless_communication/cli/toxicity/mutox). diff --git a/sonar/models/mutox/builder.py b/sonar/models/mutox/builder.py index 7e6577f..b308cb9 100644 --- a/sonar/models/mutox/builder.py +++ b/sonar/models/mutox/builder.py @@ -62,7 +62,9 @@ def build_model(self) -> MutoxClassifier: model_h3, ) - return MutoxClassifier(model_all,).to( + return MutoxClassifier( + model_all, + ).to( device=self.device, dtype=self.dtype, ) diff --git a/sonar/models/mutox/classifier.py b/sonar/models/mutox/classifier.py index 9ae7ebe..524a386 100644 --- a/sonar/models/mutox/classifier.py +++ b/sonar/models/mutox/classifier.py @@ -30,7 +30,6 @@ def forward(self, inputs: torch.Tensor, output_prob: bool = False) -> torch.Tens return outputs - @dataclass class MutoxConfig: """Holds the configuration of a Mutox Classifier model.""" diff --git a/tests/integration_tests/test_mutox.py b/tests/integration_tests/test_mutox.py new file mode 100644 index 0000000..2777275 --- /dev/null +++ b/tests/integration_tests/test_mutox.py @@ -0,0 +1,122 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest +import torch +from sonar.models.mutox.loader import load_mutox_model +from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline + + +@pytest.mark.parametrize( + "input_texts, source_lang, expected_outputs", + [ + ( + ["De peur que le pays ne se prostitue et ne se remplisse de crimes."], + "fra_Latn", + [-19.7812], + ), + ( + ["She worked hard and made a significant contribution to the team."], + "eng_Latn", + [-53.5938], + ), + ( + [ + "El no tiene ni el más mínimo talento, todo lo que ha logrado ha sido gracias a sobornos y manipulaciones." + ], + "spa_Latn", + [-21.4062], + ), + ], +) +def test_sonar_mutox_classifier_integration(input_texts, source_lang, expected_outputs): + """Integration test to compare classifier outputs with expected values.""" + if torch.cuda.is_available(): + device = torch.device("cuda:0") + dtype = torch.float16 + else: + device = torch.device("cpu") + dtype = torch.float32 + + t2vec_model = TextToEmbeddingModelPipeline( + encoder="text_sonar_basic_encoder", + tokenizer="text_sonar_basic_encoder", + device=device, + ) + + classifier = load_mutox_model("sonar_mutox", device=device, dtype=dtype).eval() + + with torch.inference_mode(): + embeddings = t2vec_model.predict(input_texts, source_lang=source_lang) + outputs = classifier(embeddings.to(device).to(dtype)).squeeze() + + if outputs.dim() == 0: + outputs = [outputs.item()] + else: + outputs = outputs.tolist() + + # Compare the outputs to expected values within a small tolerance + for output, expected in zip(outputs, expected_outputs): + assert abs(output - expected) < 0.1, ( + f"Expected output {expected}, but got {output}. " + "Outputs should be close to expected values." + ) + + +@pytest.mark.parametrize( + "input_texts, source_lang, expected_probabilities", + [ + ( + ["De peur que le pays ne se prostitue et ne se remplisse de crimes."], + "fra_Latn", + [0.0], + ), + ( + ["She worked hard and made a significant contribution to the team."], + "eng_Latn", + [0.0], + ), + ( + [ + "El no tiene ni el más mínimo talento, todo lo que ha logrado ha sido gracias a sobornos y manipulaciones." + ], + "spa_Latn", + [0.0], + ), + ], +) +def test_sonar_mutox_classifier_probability_integration( + input_texts, source_lang, expected_probabilities +): + """Integration test to verify classifier output probabilities.""" + + if torch.cuda.is_available(): + device = torch.device("cuda:0") + dtype = torch.float16 + else: + device = torch.device("cpu") + dtype = torch.float32 + + t2vec_model = TextToEmbeddingModelPipeline( + encoder="text_sonar_basic_encoder", + tokenizer="text_sonar_basic_encoder", + device=device, + ) + + classifier = load_mutox_model("sonar_mutox", device=device, dtype=dtype).eval() + + for text, lang, expected_prob in zip( + input_texts, [source_lang] * len(input_texts), expected_probabilities + ): + with torch.inference_mode(): + emb = t2vec_model.predict([text], source_lang=lang) + + prob = classifier(emb.to(device).to(dtype), output_prob=True) + + assert abs(prob.item() - expected_prob) < 0.001, ( + f"Expected probability {expected_prob}, but got {prob.item()}. " + "Output probability should be within a reasonable range." + ) diff --git a/tests/unit_tests/test_mutox.py b/tests/unit_tests/test_mutox.py index 0e0db33..ebb73eb 100644 --- a/tests/unit_tests/test_mutox.py +++ b/tests/unit_tests/test_mutox.py @@ -72,6 +72,29 @@ def test_mutox_classifier_forward(): ), f"Expected output shape (3, 1), but instead got {output.shape}" +def test_mutox_classifier_forward_with_output_prob(): + """Test that MutoxClassifier forward pass applies sigmoid when output_prob=True.""" + test_model = nn.Sequential( + nn.Linear(10, 5), + nn.ReLU(), + nn.Linear(5, 1), + ) + model = MutoxClassifier(test_model) + + test_input = torch.randn(3, 10) + + output = model(test_input, output_prob=True) + + assert output.shape == ( + 3, + 1, + ), f"Expected output shape (3, 1), but instead got {output.shape}" + + assert (output >= 0).all() and ( + output <= 1 + ).all(), "Expected output values to be within the range [0, 1]" + + def test_mutox_config(): """Test that MutoxConfig stores the configuration for a model.""" config = MutoxConfig(input_size=512) @@ -85,7 +108,6 @@ def test_mutox_config(): def test_convert_mutox_checkpoint(): """Test convert_mutox_checkpoint correctly filters keys in the checkpoint.""" - # Create a mock checkpoint with both 'model_all.' prefixed keys and other keys checkpoint = { "model_all.layer1.weight": torch.tensor([1.0]), "model_all.layer1.bias": torch.tensor([0.5]),