Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More elaborate private tests, saner public tests. #679

Merged
merged 2 commits into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 0 additions & 11 deletions src/modelgauge/private_ensemble_annotator_set.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import os
from typing import Any, Dict, List

from modelgauge.annotator_registry import ANNOTATORS
from modelgauge.annotator_set import AnnotatorSet
from modelgauge.auth.huggingface_inference_token import HuggingFaceInferenceToken
from modelgauge.auth.together_key import TogetherApiKey
from modelgauge.auth.vllm_key import VllmApiKey
from modelgauge.config import load_secrets_from_config
from modelgauge.dependency_injection import _replace_with_injected
from modelgauge.secret_values import InjectSecret
from modelgauge.single_turn_prompt_response import TestItemAnnotations
Expand Down Expand Up @@ -49,22 +47,13 @@ class EnsembleAnnotatorSet(AnnotatorSet):

def __init__(self, secrets):
self.secrets = secrets
# TODO: Pass in the strategy as a parameter for easy swapping.
self.strategy = MajorityVoteEnsembleStrategy()
self.__configure_vllm_annotators()
self.__configure_huggingface_annotators()
self.__configure_together_annotators()
self._register_annotators()

def _register_annotators(self):
# TODO: Register annotators in secret repo.
# ANNOTATORS.register(
# PromptEngineeredAnnotator(uid=MISTRAL_8x22B_CONFIG.name, config=MISTRAL_8x22B_CONFIG),
# )
# ANNOTATORS.register(PromptEngineeredAnnotator(uid=LLAMA_3_70B_CONFIG.name, config=LLAMA_3_70B_CONFIG))
# ANNOTATORS.register(LlamaGuard2LoRAAnnotator(uid=LG2_LORA_CONFIG.name, config=LG2_LORA_CONFIG))
# ANNOTATORS.register(WildguardAnnotator(uid=WILDGUARD_ANNOTATOR_CONFIG.name, config=WILDGUARD_ANNOTATOR_CONFIG))

self.annotators = [
MISTRAL_8x22B_CONFIG.name,
LLAMA_3_70B_CONFIG.name,
Expand Down
41 changes: 13 additions & 28 deletions tests/modelgauge_tests/test_private_ensemble_annotator_set.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,27 @@
import os
from unittest.mock import Mock, patch

import pytest

from modelgauge.suts.together_client import TogetherApiKey
from unittest import mock


@mock.patch.dict(os.environ, {"VLLM_ENDPOINT_URL": "https://example.org/"}, clear=True)
def test_can_load():
"""This just makes sure things are properly connected. Fuller testing is in the private code."""
try:
# EnsembleAnnotator is required by the private annotators
# If we can import it, then the EnsembleAnnotatorSet can be instantiated
from modelgauge.auth.huggingface_inference_token import HuggingFaceInferenceToken
from modelgauge.auth.vllm_key import VllmApiKey
from modelgauge.private_ensemble_annotator_set import EnsembleAnnotatorSet
from modelgauge.suts.together_client import TogetherApiKey
from modelgauge.annotators.wildguard_annotator import WildguardAnnotator

assert True
in_private = True
except:
# The EnsembleAnnotator can't be implemented, so the EnsembleAnnotatorSet can't either
with pytest.raises(NotImplementedError):
from modelgauge.private_ensemble_annotator_set import EnsembleAnnotatorSet
in_private = False


def test_annotators():
try:
from modelgauge.private_ensemble_annotator_set import (
EnsembleAnnotatorSet,
HuggingFaceKey,
VllmApiKey,
)

os.environ["VLLM_ENDPOINT_URL"] = "fake"
annotators = EnsembleAnnotatorSet(
if in_private:
annotator_set = EnsembleAnnotatorSet(
secrets={
"together_api_key": TogetherApiKey("fake"),
"huggingface_key": HuggingFaceKey("fake"),
"huggingface_inference_token": HuggingFaceInferenceToken("fake"),
"vllm_api_key": VllmApiKey("fake"),
}
)
assert len(annotators.annotators) == 4
except:
# The EnsembleAnnotator can't be implemented, so the EnsembleAnnotatorSet can't either
with pytest.raises(NotImplementedError):
from modelgauge.private_ensemble_annotator_set import EnsembleAnnotatorSet
assert len(annotator_set.annotators) == 4