Properly support batched/non-batched with vllm/llama.cpp #77
Conversation
and other streamlining
src/instructlab/sdg/llmblock.py
parsed_outputs = self._parse(output)
# pylint: disable=consider-using-generator
max_length = max([len(value) for value in parsed_outputs.values()])
GitHub lint is suggesting to use max(len(value) for value in parsed_outputs.values()) instead.
This logic still needs to be cleaned up anyhow, I think; it's not doing what it was intended to.
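For reference, a minimal sketch of the generator form pylint is asking for (the parsed_outputs dict below is made up for illustration):

# Stand-in data: parsed_outputs maps output column names to lists of parsed values.
parsed_outputs = {"question": ["q1", "q2"], "answer": ["a1"]}

# Generator expression: no intermediate list is built just to take max().
# Note that max() raises ValueError on an empty dict; max(..., default=0) avoids that.
max_length = max(len(value) for value in parsed_outputs.values())
print(max_length)  # 2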
def validate(self, prompt_template: str, input_dict: Dict[str, Any]) -> bool:
    if isinstance(prompt_template, dict):
        prompt_template = prompt_template[input_dict[self.selector_column_name]]
    return super()._validate(prompt_template, input_dict)

def server_supports_batched(client, model_id: str) -> bool:
This might be a nitpick, but can we use server_supports_batching instead of server_supports_batched?
I think batched is better, since it's referring to the inputs. Even without batched inputs, the server might do batching internally.
for prompt in prompts:
    for _ in range(n):
        response = self.client.completions.create(
            prompt=prompt, **generate_args
        )
We could rewrite this as

responses = [
    self.client.completions.create(prompt=prompt, **generate_args)
    for prompt in prompts
    for _ in range(n)
]

wdyt?
Yes, but then we would require an additional loop anyhow.
Yeah
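For context, a rough sketch of what the comprehension plus the extra collection loop might look like (this is illustrative, not the PR's code, and it assumes an OpenAI-style completions client):

def generate_sequentially(client, prompts, n, generate_args):
    """Sketch only: one request per (prompt, repetition), then a second pass
    to pull the generated text out of each response."""
    responses = [
        client.completions.create(prompt=prompt, **generate_args)
        for prompt in prompts
        for _ in range(n)
    ]
    # The additional loop mentioned above: flatten the choices into plain strings.
    return [
        choice.text.strip()
        for response in responses
        for choice in response.choices
    ]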
Thanks @npalaska, I addressed most of those comments.
Cool, looks like a great direction. At least resolve the "#TODO remove sample from samples" thing.
}

# Whether the LLM server supports a list of input prompts
# and supports the n parameter to generate n outputs per input
self.server_supports_batched = server_supports_batched(client, model_id)
The FlowParams in #64 would give us a place to do this once rather than for every LLMBlock, but that can be fixed up later.
See PipelineContext in #86 now.
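To illustrate the idea (not the actual #64/#86 code, and the names here are assumptions), a shared context object could run the probe once and hand the result to every block:

from dataclasses import dataclass, field

@dataclass
class PipelineContext:
    """Hypothetical sketch: per-pipeline state, created once per run."""

    client: object
    model_id: str
    batched: bool = field(init=False)

    def __post_init__(self):
        # server_supports_batched() is the helper added in this PR.
        self.batched = server_supports_batched(self.client, self.model_id)

# Each LLMBlock would then read ctx.batched instead of probing the server itself.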
@@ -45,8 +46,13 @@ def __init__(
    "model": self.model,
    "temperature": 0,
    "max_tokens": 12000,
    #"seed": 12345, TBD
Delete? Or add an explanation to the comment
)
return [choice.text.strip() for choice in response.choices]

n = gen_kwargs.get("n", 1)
I would have imagined doing this in reverse: including "num_instructions_to_generate" in the block config and adding 'n' to gen_kwargs if batching is supported. No biggie, though.
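A rough sketch of that reversal, with illustrative names (block_config and the helper below are not the PR's code):

def build_gen_kwargs(block_config, base_gen_kwargs, server_supports_batched):
    """Sketch: 'n' is derived from the block config and only passed on
    when the server can honour it."""
    num_samples = block_config.get("num_instructions_to_generate", 1)
    gen_kwargs = dict(base_gen_kwargs)
    if server_supports_batched:
        # One request per prompt; the server generates num_samples completions each.
        gen_kwargs["n"] = num_samples
    # If batching isn't supported, the caller loops num_samples times per prompt instead.
    return gen_kwargs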
@@ -113,21 +132,30 @@ def generate(self, samples, **gen_kwargs) -> Dataset:
         # validate each sample
         for sample in samples:
             if not self._validate(self.prompt_template, sample):
-                return None
+                logger.warning("Sample failed validation")  #TODO add details
+                #TODO remove sample from samples
Hmm, should this be in a separate PR? If it stays in this PR, the TODO should be resolved.
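One way the TODO could be resolved (sketched here, not taken from the PR) is to drop the failing samples and keep going:

import logging

logger = logging.getLogger(__name__)

def drop_invalid_samples(samples, validate):
    """Sketch: keep only samples that pass validation, logging each failure
    instead of aborting the whole batch."""
    valid = []
    for sample in samples:
        if validate(sample):
            valid.append(sample)
        else:
            logger.warning("Sample failed validation: %s", sample)
    return valid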
outputs = self._generate(samples, **gen_kwargs)
logger.debug("Generated outputs: %s", outputs)

num_parallel_samples = gen_kwargs.get("n", 1)
Hmm, and here's a reason to make num_parallel_samples part of the block config ... and add 'n' to gen_kwargs based on that
    supported = len(response.choices) == 6
except openai.InternalServerError:
    supported = False
client.server_supports_batched = supported
Hmm, I understand that you want to cache this... but I don't like setting a new attribute on a class we don't own. I guess this could be removed with a move to FlowParams.
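As an alternative to setting an attribute on the client, a module-level cache keyed on the client instance would avoid touching a class we don't own. This is only a sketch of the idea (cached_server_supports_batched is a made-up name):

import weakref

# Entries disappear automatically when the client object is garbage-collected.
_batched_support = weakref.WeakKeyDictionary()

def cached_server_supports_batched(client, model_id):
    """Illustrative memoization per client instance, not the PR's code."""
    if client not in _batched_support:
        # server_supports_batched() is the helper added in this PR.
        _batched_support[client] = server_supports_batched(client, model_id)
    return _batched_support[client]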
Also, please squash those fixup commits as per instructlab/dev-docs#110
This pull request has merge conflicts that must be resolved before it can be merged.
As per instructlab/sdg#77

Signed-off-by: Mark McLoughlin <[email protected]>
Closing in favor of #105
This is based on @npalaska's PR #58. With these changes we auto-detect whether the server supports batched inputs and, if not, send the prompts sequentially.
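For readers skimming the thread, the overall shape of the change is roughly this. It is a sketch assembled from the fragments above; helper names such as self._format_prompt and self.defaults are assumptions, not the merged code:

def _generate(self, samples, **gen_kwargs):
    generate_args = {**self.defaults, **gen_kwargs}
    prompts = [self._format_prompt(sample) for sample in samples]

    if self.server_supports_batched:
        # One call with the full list of prompts; the server returns
        # n completions per prompt together in response.choices.
        response = self.client.completions.create(prompt=prompts, **generate_args)
        return [choice.text.strip() for choice in response.choices]

    # Sequential fallback: drop 'n' from the request and loop ourselves.
    n = generate_args.pop("n", 1)
    results = []
    for prompt in prompts:
        for _ in range(n):
            response = self.client.completions.create(prompt=prompt, **generate_args)
            results.append(response.choices[0].text.strip())
    return results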