Properly support batched/non-batched with vllm/llama.cpp #77

Closed
njhill wants to merge 10 commits

Conversation

njhill (Contributor) commented Jul 3, 2024

This is based on @npalaska's PR #58.

With these changes we will auto-detect whether the server supports batched inputs and, if not, will send them sequentially.
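
In rough terms, the detection plus fallback looks like the sketch below. This is an illustrative reconstruction pieced together from the snippets quoted in the review comments, not the PR's exact code; the probe shape (3 prompts with n=2) and helper names like _format_prompt are placeholders.

    import openai


    def server_supports_batched(client, model_id: str) -> bool:
        """Probe whether the server accepts a list of prompts plus the n parameter."""
        try:
            # A batching-capable server (e.g. vLLM) returns one choice per prompt
            # per n; llama.cpp's OpenAI-compatible server rejects the request.
            response = client.completions.create(
                model=model_id, prompt=["test"] * 3, n=2, max_tokens=1
            )
            supported = len(response.choices) == 6
        except openai.InternalServerError:
            supported = False
        return supported


    class _BatchAwareBlock:
        """Minimal stand-in for LLMBlock, just to show the batched/sequential split."""

        def __init__(self, client, model_id: str):
            self.client = client
            self.model_id = model_id
            self.server_supports_batched = server_supports_batched(client, model_id)

        def _format_prompt(self, sample) -> str:
            # Placeholder for LLMBlock's real prompt templating.
            return str(sample)

        def _generate(self, samples, **gen_kwargs):
            prompts = [self._format_prompt(sample) for sample in samples]
            gen_kwargs.setdefault("model", self.model_id)

            if self.server_supports_batched:
                # One request carrying every prompt; n outputs per prompt.
                response = self.client.completions.create(prompt=prompts, **gen_kwargs)
                return [choice.text.strip() for choice in response.choices]

            # Sequential fallback: one request per prompt per requested completion.
            n = gen_kwargs.pop("n", 1)
            results = []
            for prompt in prompts:
                for _ in range(n):
                    response = self.client.completions.create(
                        prompt=prompt, **gen_kwargs
                    )
                    results.append(response.choices[0].text.strip())
            return results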

src/instructlab/sdg/llmblock.py
parsed_outputs = self._parse(output)
# pylint: disable=consider-using-generator

max_length = max([len(value) for value in parsed_outputs.values()])

Contributor:

GitHub lint suggests using max(len(value) for value in parsed_outputs.values()) here.

njhill (author):

This logic still needs to be cleaned up anyway, I think; it's not doing what it was intended to do.


def validate(self, prompt_template: str, input_dict: Dict[str, Any]) -> bool:
    if isinstance(prompt_template, dict):
        prompt_template = prompt_template[input_dict[self.selector_column_name]]
    return super()._validate(prompt_template, input_dict)


def server_supports_batched(client, model_id: str) -> bool:

Contributor:

This might be a nitpick, but can we use server_supports_batching instead of server_supports_batched?

njhill (author):

I think batched is better, since it's referring to the inputs. Even without batched inputs, the server might do batching internally.

Comment on lines +111 to +115
for prompt in prompts:
    for _ in range(n):
        response = self.client.completions.create(
            prompt=prompt, **generate_args
        )

Contributor:

We could rewrite this as

responses = [
    self.client.completions.create(prompt=prompt, **generate_args)
    for prompt in prompts
    for _ in range(n)
]

wdyt?

njhill (author):

Yes, but then we would need an additional loop anyway.

Contributor:

Yeah
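
The extra loop in question: even with that comprehension, the response objects still need a second pass to pull the text out, roughly as below (an illustrative sketch reusing the names from the snippet above).

    responses = [
        self.client.completions.create(prompt=prompt, **generate_args)
        for prompt in prompts
        for _ in range(n)
    ]
    # Second pass: flatten the response objects into output strings.
    results = [response.choices[0].text.strip() for response in responses]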

njhill commented Jul 3, 2024

Thanks @npalaska, I addressed most of those comments.

markmc (Contributor) left a comment

Cool, looks like a great direction

At least resolve the "#TODO remove sample from samples" thing.

}

# Whether the LLM server supports a list of input prompts
# and supports the n parameter to generate n outputs per input
self.server_supports_batched = server_supports_batched(client, model_id)

Contributor:

The FlowParams in #64 would give us a place to do this once rather than for every LLMBlock, but that can be fixed up later

Contributor:

See PipelineContext in #86 now
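
For illustration only (names here are indicative of the #64/#86 proposals, not their actual APIs), the idea is to run the probe once when the shared context is built and let each LLMBlock read the result:

    class PipelineContext:
        """Hypothetical shape of the shared context; see #86 for the real one."""

        def __init__(self, client, model_id: str):
            self.client = client
            self.model_id = model_id
            # Probe once here instead of inside every LLMBlock.__init__.
            self.batched = server_supports_batched(client, model_id)


    # Inside LLMBlock.__init__ the flag would then just be copied:
    #     self.server_supports_batched = ctx.batched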

@@ -45,8 +46,13 @@ def __init__(
"model": self.model,
"temperature": 0,
"max_tokens": 12000,
#"seed": 12345, TBD

Contributor:

Delete? Or add an explanation to the comment

)
return [choice.text.strip() for choice in response.choices]

n = gen_kwargs.get("n", 1)

Contributor:

I would have imagined doing this in reverse: including "num_instructions_to_generate" in the block config and adding "n" to gen_kwargs if batching is supported. No biggie though.
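
A sketch of that inversion (illustrative only; block_config and num_instructions_to_generate are names taken from the suggestion, not code in this PR):

    # Inside LLMBlock.generate(): the block config carries the count,
    # and n is only passed through when the server can actually use it.
    num_samples = self.block_config.get("num_instructions_to_generate", 1)
    if self.server_supports_batched:
        gen_kwargs["n"] = num_samples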

@@ -113,21 +132,30 @@ def generate(self, samples, **gen_kwargs) -> Dataset:
         # validate each sample
         for sample in samples:
             if not self._validate(self.prompt_template, sample):
-                return None
+                logger.warning("Sample failed validation") #TODO add details
+                #TODO remove sample from samples

Contributor:

Hmm. Should this be in a separate PR? If it's in this PR, the TODO should be resolved.

outputs = self._generate(samples, **gen_kwargs)
logger.debug("Generated outputs: %s", outputs)

num_parallel_samples = gen_kwargs.get("n", 1)

Contributor:

Hmm, and here's a reason to make num_parallel_samples part of the block config ... and add 'n' to gen_kwargs based on that

    supported = len(response.choices) == 6
except openai.InternalServerError:
    supported = False
client.server_supports_batched = supported

Contributor:

Hmm, I understand that you want to cache this ... but I don't like setting a new attribute on a class we don't own

I guess this could be removed with a move to FlowParams
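
One way to keep the caching without setting attributes on an object we don't own, shown purely as an illustration (the thread's preferred direction is moving this into FlowParams/PipelineContext anyway), is a small module-level cache:

    import openai

    # Cache keyed by the client object's id(), so the openai client itself is
    # never modified.  (id() reuse after garbage collection is a known caveat,
    # acceptable for long-lived clients.)
    _BATCHED_SUPPORT: dict = {}


    def server_supports_batched(client, model_id: str) -> bool:
        key = id(client)
        if key in _BATCHED_SUPPORT:
            return _BATCHED_SUPPORT[key]
        try:
            # Same probe as in the diff above.
            response = client.completions.create(
                model=model_id, prompt=["test"] * 3, n=2, max_tokens=1
            )
            supported = len(response.choices) == 6
        except openai.InternalServerError:
            supported = False
        _BATCHED_SUPPORT[key] = supported
        return supported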

markmc commented Jul 4, 2024

Also, please squash those fixup commits as per instructlab/dev-docs#110

mergify bot added the needs-rebase label on Jul 6, 2024

mergify bot commented Jul 6, 2024

This pull request has merge conflicts that must be resolved before it can be merged. @njhill please rebase it. https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

markmc added a commit to markmc/dev-docs that referenced this pull request Jul 8, 2024

aakankshaduggal (Member):

Closing in favor of #105
