Skip to content

Commit

Permalink
More elaborate private tests, saner public tests. (#679)
Browse files Browse the repository at this point in the history
* More elaborate private tests, saner public tests.

* Fixing failure in public-land.
  • Loading branch information
wpietri authored Nov 13, 2024
1 parent 21b43b8 commit 5050006
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 39 deletions.
11 changes: 0 additions & 11 deletions src/modelgauge/private_ensemble_annotator_set.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import os
from typing import Any, Dict, List

from modelgauge.annotator_registry import ANNOTATORS
from modelgauge.annotator_set import AnnotatorSet
from modelgauge.auth.huggingface_inference_token import HuggingFaceInferenceToken
from modelgauge.auth.together_key import TogetherApiKey
from modelgauge.auth.vllm_key import VllmApiKey
from modelgauge.config import load_secrets_from_config
from modelgauge.dependency_injection import _replace_with_injected
from modelgauge.secret_values import InjectSecret
from modelgauge.single_turn_prompt_response import TestItemAnnotations
Expand Down Expand Up @@ -49,22 +47,13 @@ class EnsembleAnnotatorSet(AnnotatorSet):

def __init__(self, secrets):
self.secrets = secrets
# TODO: Pass in the strategy as a parameter for easy swapping.
self.strategy = MajorityVoteEnsembleStrategy()
self.__configure_vllm_annotators()
self.__configure_huggingface_annotators()
self.__configure_together_annotators()
self._register_annotators()

def _register_annotators(self):
# TODO: Register annotators in secret repo.
# ANNOTATORS.register(
# PromptEngineeredAnnotator(uid=MISTRAL_8x22B_CONFIG.name, config=MISTRAL_8x22B_CONFIG),
# )
# ANNOTATORS.register(PromptEngineeredAnnotator(uid=LLAMA_3_70B_CONFIG.name, config=LLAMA_3_70B_CONFIG))
# ANNOTATORS.register(LlamaGuard2LoRAAnnotator(uid=LG2_LORA_CONFIG.name, config=LG2_LORA_CONFIG))
# ANNOTATORS.register(WildguardAnnotator(uid=WILDGUARD_ANNOTATOR_CONFIG.name, config=WILDGUARD_ANNOTATOR_CONFIG))

self.annotators = [
MISTRAL_8x22B_CONFIG.name,
LLAMA_3_70B_CONFIG.name,
Expand Down
41 changes: 13 additions & 28 deletions tests/modelgauge_tests/test_private_ensemble_annotator_set.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,27 @@
import os
from unittest.mock import Mock, patch

import pytest

from modelgauge.suts.together_client import TogetherApiKey
from unittest import mock


@mock.patch.dict(os.environ, {"VLLM_ENDPOINT_URL": "https://example.org/"}, clear=True)
def test_can_load():
"""This just makes sure things are properly connected. Fuller testing is in the private code."""
try:
# EnsembleAnnotator is required by the private annotators
# If we can import it, then the EnsembleAnnotatorSet can be instantiated
from modelgauge.auth.huggingface_inference_token import HuggingFaceInferenceToken
from modelgauge.auth.vllm_key import VllmApiKey
from modelgauge.private_ensemble_annotator_set import EnsembleAnnotatorSet
from modelgauge.suts.together_client import TogetherApiKey
from modelgauge.annotators.wildguard_annotator import WildguardAnnotator

assert True
in_private = True
except:
# The EnsembleAnnotator can't be implemented, so the EnsembleAnnotatorSet can't either
with pytest.raises(NotImplementedError):
from modelgauge.private_ensemble_annotator_set import EnsembleAnnotatorSet
in_private = False


def test_annotators():
try:
from modelgauge.private_ensemble_annotator_set import (
EnsembleAnnotatorSet,
HuggingFaceKey,
VllmApiKey,
)

os.environ["VLLM_ENDPOINT_URL"] = "fake"
annotators = EnsembleAnnotatorSet(
if in_private:
annotator_set = EnsembleAnnotatorSet(
secrets={
"together_api_key": TogetherApiKey("fake"),
"huggingface_key": HuggingFaceKey("fake"),
"huggingface_inference_token": HuggingFaceInferenceToken("fake"),
"vllm_api_key": VllmApiKey("fake"),
}
)
assert len(annotators.annotators) == 4
except:
# The EnsembleAnnotator can't be implemented, so the EnsembleAnnotatorSet can't either
with pytest.raises(NotImplementedError):
from modelgauge.private_ensemble_annotator_set import EnsembleAnnotatorSet
assert len(annotator_set.annotators) == 4

0 comments on commit 5050006

Please sign in to comment.