Skip to content

Commit

Permalink
Remove some unnecessary logic, add usage output, fix docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
silvanocerza committed Mar 7, 2024
1 parent bf685ca commit 70fa20d
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,6 @@
from ._schema import GenerationRequest, GenerationResponse, Message
from .models import NvidiaGeneratorModel

SUPPORTED_MODELS: List[NvidiaGeneratorModel] = [
NvidiaGeneratorModel.NV_LLAMA2_RLHF_70B,
NvidiaGeneratorModel.STEERLM_LLAMA_70B,
NvidiaGeneratorModel.NEMOTRON_STEERLM_8B,
NvidiaGeneratorModel.NEMOTRON_QA_8B,
]


@component
class NvidiaGenerator:
Expand Down Expand Up @@ -85,11 +78,6 @@ def __init__(
},
)

if self._model not in SUPPORTED_MODELS:
models = ", ".join(e.value for e in NvidiaGeneratorModel)
msg = f"Model {self._model} is not supported, available models are: {models}"
raise ValueError(msg)

def warm_up(self):
"""
Initializes the component.
Expand Down Expand Up @@ -123,7 +111,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "NvidiaGenerator":
deserialize_secrets_inplace(init_params, ["api_key"])
return default_from_dict(cls, data)

@component.output_types(replies=List[str], meta=List[Dict[str, Any]])
@component.output_types(replies=List[str], meta=List[Dict[str, Any]], usage=Dict[str, int])
def run(self, prompt: str):
"""
Queries the model with the provided prompt.
Expand All @@ -132,8 +120,9 @@ def run(self, prompt: str):
Text to be sent to the generative model.
:returns:
A dictionary with the following keys:
- "replies": replies generated by the model.
- "meta": metadata for each reply.
- `replies` - Replies generated by the model.
- `meta` - Metadata for each reply.
- `usage` - Usage statistics for the model.
"""
if self._model_id is None:
msg = "The generation model has not been loaded. Call warm_up() before running."
Expand All @@ -152,11 +141,13 @@ def run(self, prompt: str):
{
"role": choice.message.role,
"finish_reason": choice.finish_reason,
# The usage field is not part of each choice, so we reuse it each time
"completion_tokens": data.usage.completion_tokens,
"prompt_tokens": data.usage.prompt_tokens,
"total_tokens": data.usage.total_tokens,
}
)

return {"replies": replies, "meta": meta}
usage = {
"completion_tokens": data.usage.completion_tokens,
"prompt_tokens": data.usage.prompt_tokens,
"total_tokens": data.usage.total_tokens,
}

return {"replies": replies, "meta": meta, "usage": usage}
9 changes: 6 additions & 3 deletions integrations/nvidia/tests/test_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,11 +142,13 @@ def test_run(self, mock_client):
{
"finish_reason": "stop",
"role": "assistant",
"total_tokens": 21,
"prompt_tokens": 19,
"completion_tokens": 2,
},
],
"usage": {
"total_tokens": 21,
"prompt_tokens": 19,
"completion_tokens": 2,
},
}

@pytest.mark.skipif(
Expand All @@ -171,3 +173,4 @@ def test_run_integration(self):

assert result["replies"]
assert result["meta"]
assert result["usage"]

0 comments on commit 70fa20d

Please sign in to comment.