Properly support batched/non-batched with vllm/llama.cpp #77

Closed
wants to merge 10 commits
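Reviewer context, with a hedged sketch: the old flows wrapped `LLMBlock` in an `IterBlock` that re-ran generation `num_iters` times, alongside a `batched` flag that llama.cpp's server could not honor. This PR drops both and instead passes the OpenAI-style `n` sampling parameter through `gen_kwargs`, so a single request asks for `num_instructions_to_generate` completions. Roughly, assuming an OpenAI-compatible endpoint (the URL and model id below are placeholders, not from this PR):

```python
from openai import OpenAI

# Placeholder endpoint/model; vLLM and llama.cpp both expose this API shape.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# Old shape: one completion per call, looped num_iters times by IterBlock.
one = client.completions.create(
    model="my-model", prompt="Generate a QA pair.", max_tokens=2048, temperature=0.7
)

# New shape: ask for n samples in one call. A batching server (vLLM) generates
# them together; a non-batching server can still produce them sequentially.
many = client.completions.create(
    model="my-model",
    prompt="Generate a QA pair.",
    max_tokens=2048,
    temperature=0.7,
    n=5,  # num_instructions_to_generate
)
print(len(many.choices))  # -> 5
```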
src/instructlab/sdg/default_flows.py — 35 additions, 103 deletions
@@ -7,7 +7,6 @@

 # Local
 from .filterblock import FilterByValueBlock
-from .iterblock import IterBlock
 from .llmblock import LLMBlock
 from .utilblocks import CombineColumnsBlock

@@ -30,12 +29,11 @@ def _get_model_prompt(model_family):


 class Flow(ABC):
-    def __init__(self, client, model_family, model_id, num_iters, batched=True) -> None:
+    def __init__(self, client, model_family, model_id, num_instructions_to_generate) -> None:
         self.client = client
         self.model_family = model_family
         self.model_id = model_id
-        self.num_iters = num_iters
-        self.batched = batched
+        self.num_instructions_to_generate = num_instructions_to_generate
         self.sdg_base = resources.files(__package__)

     @abstractmethod
@@ -47,62 +45,51 @@ class _SimpleFlow(Flow):
     def get_flow(self) -> list:
         return [
             {
-                "block_type": IterBlock,
+                "block_type": LLMBlock,
                 "block_config": {
                     "block_name": "",  # must be set by subclass
-                    "num_iters": self.num_iters,
-                    "block_type": LLMBlock,
-                    "block_kwargs": {
-                        "block_name": "",  # must be set by subclass
-                        "config_path": "",  # must be set by subclass
-                        "client": self.client,
-                        "model_id": self.model_id,
-                        "model_prompt": _get_model_prompt(self.model_family),
-                        "output_cols": ["output"],
-                        "batch_kwargs": {
-                            "num_procs": 8,
-                            "batched": self.batched,
-                        },
-                    },
-                    "gen_kwargs": {
-                        "max_tokens": 2048,
-                        "temperature": 0.7,
-                    },
-                    "drop_duplicates": ["output"],
+                    "config_path": "",  # must be set by subclass
+                    "client": self.client,
+                    "model_id": self.model_id,
+                    "model_prompt": _get_model_prompt(self.model_family),
+                    "output_cols": ["output"],
                 },
+                "gen_kwargs": {
+                    "max_tokens": 2048,
+                    "temperature": 0.7,
+                    "n": self.num_instructions_to_generate
+                },
+                "drop_duplicates": ["output"],
             }
         ]


 class SimpleKnowledgeFlow(_SimpleFlow):
     def get_flow(self) -> list:
         flow = super().get_flow()
-        flow[0]["block_config"]["block_kwargs"]["config_path"] = os.path.join(
+        flow[0]["block_config"]["config_path"] = os.path.join(
             self.sdg_base, "configs/knowledge/simple_generate_qa.yaml"
         )
-        flow[0]["block_config"]["block_kwargs"]["block_name"] = "gen_knowledge"
+        flow[0]["block_config"]["block_name"] = "gen_knowledge"
         return flow


 class SimpleFreeformSkillFlow(_SimpleFlow):
     def get_flow(self) -> list:
         flow = super().get_flow()
-        flow[0]["block_config"]["block_kwargs"]["config_path"] = os.path.join(
+        flow[0]["block_config"]["config_path"] = os.path.join(
             self.sdg_base, "configs/skills/simple_generate_qa_freeform.yaml"
         )
-        flow[0]["block_config"]["block_kwargs"]["block_name"] = "gen_skill_freeform"
+        flow[0]["block_config"]["block_name"] = "gen_skill_freeform"
         return flow


 class SimpleGroundedSkillFlow(_SimpleFlow):
     def get_flow(self) -> list:
         flow = super().get_flow()
-        flow[0]["block_config"]["block_kwargs"]["config_path"] = os.path.join(
+        flow[0]["block_config"]["config_path"] = os.path.join(
             self.sdg_base, "configs/skills/simple_generate_qa_grounded.yaml"
         )
-        flow[0]["block_config"]["block_kwargs"]["block_name"] = "gen_skill_grounded"
+        flow[0]["block_config"]["block_name"] = "gen_skill_grounded"
         return flow
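For reference, a usage sketch of the reworked simple flows (the module paths follow this repo's layout; the endpoint, model id, and "merlinite" family string are assumptions for illustration):

```python
from openai import OpenAI

from instructlab.sdg.default_flows import SimpleKnowledgeFlow
from instructlab.sdg.pipeline import Pipeline

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# num_instructions_to_generate=5 now lands in gen_kwargs["n"] instead of an
# IterBlock's num_iters.
flow = SimpleKnowledgeFlow(client, "merlinite", "my-model", 5)
pipeline = Pipeline(flow.get_flow())
```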

@@ -122,10 +109,6 @@ def get_flow(self) -> list:
                     "model_id": self.model_id,
                     "model_prompt": _get_model_prompt(self.model_family),
                     "output_cols": ["mmlubench_question", "mmlubench_answer"],
-                    "batch_kwargs": {
-                        "num_procs": 8,
-                        "batched": self.batched,
-                    },
                 },
                 "gen_kwargs": {
                     "temperature": 0,
@@ -151,10 +134,6 @@ def get_flow(self) -> list:
                     "model_id": self.model_id,
                     "model_prompt": _get_model_prompt(self.model_family),
                     "output_cols": ["question", "response"],
-                    "batch_kwargs": {
-                        "num_procs": 8,
-                        "batched": self.batched,
-                    },
                     "parser_kwargs": {
                         "parser_name": "custom",
                         "parsing_pattern": r"\[(?:Question|QUESTION)\]\s*(.*?)\s*\[(?:Answer|ANSWER)\]\s*(.*?)\s*(?=\[(?:Question|QUESTION)\]|$)",
@@ -177,10 +156,6 @@ def get_flow(self) -> list:
                     "model_id": self.model_id,
                     "model_prompt": _get_model_prompt(self.model_family),
                     "output_cols": ["explanation", "judgment"],
-                    "batch_kwargs": {
-                        "num_procs": 8,
-                        "batched": self.batched,
-                    },
                 },
                 "gen_kwargs": {
                     "max_tokens": 2048,
@@ -210,10 +185,6 @@ def get_flow(self) -> list:
                     "model_id": self.model_id,
                     "model_prompt": _get_model_prompt(self.model_family),
                     "output_cols": ["feedback", "score"],
-                    "batch_kwargs": {
-                        "num_procs": 8,
-                        "batched": self.batched,
-                    },
                 },
                 "gen_kwargs": {
                     "max_tokens": 2048,
@@ -244,10 +215,6 @@ def get_flow(self) -> list:
                     "model_id": self.model_id,
                     "model_prompt": _get_model_prompt(self.model_family),
                     "output_cols": ["explanation", "rating"],
-                    "batch_kwargs": {
-                        "num_procs": 8,
-                        "batched": self.batched,
-                    },
                 },
                 "gen_kwargs": {
                     "max_tokens": 2048,
@@ -286,9 +253,7 @@ def get_flow(self) -> list:
                     "model_prompt": _get_model_prompt(self.model_family),
                     "output_cols": ["question"],
                     "batch_kwargs": {
-                        "num_procs": 8,
-                        "num_samples": 30,
-                        "batched": self.batched,
+                        "num_samples": self.num_instructions_to_generate,
                     },
                 },
                 "drop_duplicates": ["question"],
@@ -305,10 +270,6 @@ def get_flow(self) -> list:
                     "model_id": self.model_id,
                     "model_prompt": _get_model_prompt(self.model_family),
                     "output_cols": ["evaluation", "score"],
-                    "batch_kwargs": {
-                        "num_procs": 8,
-                        "batched": self.batched,
-                    },
                 },
             },
             {
@@ -337,10 +298,6 @@ def get_flow(self) -> list:
                     "model_id": self.model_id,
                     "model_prompt": _get_model_prompt(self.model_family),
                     "output_cols": ["response"],
-                    "batch_kwargs": {
-                        "num_procs": 8,
-                        "batched": self.batched,
-                    },
                 },
             },
             {
@@ -355,10 +312,6 @@ def get_flow(self) -> list:
                     "model_id": self.model_id,
                     "model_prompt": _get_model_prompt(self.model_family),
                     "output_cols": ["evaluation", "score"],
-                    "batch_kwargs": {
-                        "num_procs": 8,
-                        "batched": self.batched,
-                    },
                 },
             },
             {
@@ -382,31 +335,24 @@ class SynthGroundedSkillsFlow(Flow):
     def get_flow(self) -> list:
         return [
             {
-                "block_type": IterBlock,
+                "block_type": LLMBlock,
                 "block_config": {
-                    "block_name": "context_iter",
-                    "num_iters": 10,
-                    "block_type": LLMBlock,
-                    "block_kwargs": {
-                        "block_name": "gen_contexts",
-                        "config_path": os.path.join(
-                            self.sdg_base,
-                            "configs/skills/contexts.yaml",
-                        ),
-                        "client": self.client,
-                        "model_id": self.model_id,
-                        "model_prompt": _get_model_prompt(self.model_family),
-                        "output_cols": ["context"],
-                        "batch_kwargs": {
-                            "num_procs": 8,
-                            "batched": self.batched,
-                        },
-                    },
-                    "gen_kwargs": {
-                        "temperature": 0.7,
-                        "max_tokens": 2048,
-                    },
+                    "block_name": "gen_contexts",
+                    "config_path": os.path.join(
+                        self.sdg_base,
+                        "configs/skills/contexts.yaml",
+                    ),
+                    "client": self.client,
+                    "model_id": self.model_id,
+                    "model_prompt": _get_model_prompt(self.model_family),
+                    "output_cols": ["context"],
                 },
+                "gen_kwargs": {
+                    "temperature": 0.7,
+                    "max_tokens": 2048,
+                    "n": self.num_instructions_to_generate
+                },
+                "drop_duplicates": ["context"],
             },
             {
                 "block_type": LLMBlock,
@@ -421,8 +367,7 @@ def get_flow(self) -> list:
                     "model_prompt": _get_model_prompt(self.model_family),
                     "output_cols": ["question"],
                     "batch_kwargs": {
-                        "num_procs": 8,
-                        "batched": self.batched,
+                        "num_samples": 3,
                     },
                 },
                 "drop_duplicates": ["question"],
@@ -439,11 +384,6 @@ def get_flow(self) -> list:
                     "model_id": self.model_id,
                     "model_prompt": _get_model_prompt(self.model_family),
                     "output_cols": ["evaluation", "score"],
-                    "batch_kwargs": {
-                        "num_procs": 8,
-                        "batched": self.batched,
-                        "num_samples": 10,
-                    },
                 },
             },
             {
@@ -472,10 +412,6 @@ def get_flow(self) -> list:
                     "model_id": self.model_id,
                     "model_prompt": _get_model_prompt(self.model_family),
                     "output_cols": ["response"],
-                    "batch_kwargs": {
-                        "num_procs": 8,
-                        "batched": self.batched,
-                    },
                 },
             },
             {
@@ -490,10 +426,6 @@ def get_flow(self) -> list:
                     "model_id": self.model_id,
                     "model_prompt": _get_model_prompt(self.model_family),
                     "output_cols": ["evaluation", "score"],
-                    "batch_kwargs": {
-                        "num_procs": 8,
-                        "batched": self.batched,
-                    },
                 },
             },
             {
src/instructlab/sdg/generate_data.py — 4 additions, 10 deletions
@@ -124,7 +124,7 @@ def _gen_test_data(
         outfile.write("\n")


-def _sdg_init(pipeline, client, model_family, model_name, num_iters, batched):
+def _sdg_init(pipeline, client, model_family, model_name, num_instructions_to_generate):
     knowledge_flow_types = []
     freeform_skill_flow_types = []
     grounded_skill_flow_types = []
@@ -144,7 +144,7 @@ def _sdg_init(pipeline, client, model_family, model_name, num_iters, batched):
             [
                 Pipeline(
                     flow_type(
-                        client, model_family, model_name, num_iters, batched
+                        client, model_family, model_name, num_instructions_to_generate
                     ).get_flow()
                 )
                 for flow_type in knowledge_flow_types
@@ -154,7 +154,7 @@ def _sdg_init(pipeline, client, model_family, model_name, num_iters, batched):
             [
                 Pipeline(
                     flow_type(
-                        client, model_family, model_name, num_iters, batched
+                        client, model_family, model_name, num_instructions_to_generate
                    ).get_flow()
                 )
                 for flow_type in freeform_skill_flow_types
@@ -164,7 +164,7 @@ def _sdg_init(pipeline, client, model_family, model_name, num_iters, batched):
             [
                 Pipeline(
                     flow_type(
-                        client, model_family, model_name, num_iters, batched
+                        client, model_family, model_name, num_instructions_to_generate
                     ).get_flow()
                 )
                 for flow_type in grounded_skill_flow_types
@@ -242,17 +242,12 @@ def generate_data(
     else:
         model_family = MODEL_FAMILY_MERLINITE

-    # TODO -- llama-cpp doesn't support batching, we need to get a hint from the CLI
-    # about whether we can turn this on (whether vllm is used or not)
-    batched = False
-
     sdg_knowledge, sdg_freeform_skill, sdg_grounded_skill = _sdg_init(
         pipeline,
         client,
         model_family,
         model_name,
         num_instructions_to_generate,
-        batched,
     )

     if console_output:
@@ -267,7 +262,6 @@ def generate_data(
     if not samples:
         raise utils.GenerateException("Error: No samples found in leaf node.")

-    sdg = None
     if samples[0].get("document"):
         sdg = sdg_knowledge
     elif samples[0].get("context"):
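Taken together, a hedged sketch of the updated call path (the "simple" pipeline name and "merlinite" family string are assumptions based on the surrounding code, not shown in this diff):

```python
# Hypothetical call with the new signature.
sdg_knowledge, sdg_freeform_skill, sdg_grounded_skill = _sdg_init(
    "simple",     # pipeline selector (assumed value)
    client,       # OpenAI-compatible client (vLLM or llama.cpp server)
    "merlinite",  # model_family, used to pick the prompt template
    "my-model",   # model_name (placeholder)
    5,            # num_instructions_to_generate -> gen_kwargs["n"]
)
```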
src/instructlab/sdg/iterblock.py — 0 additions, 29 deletions

This file was deleted.
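For reviewers without the original file handy, the deleted block behaved roughly like this (a reconstruction from its call sites in the diff above, not the verbatim 29 deleted lines):

```python
# Rough reconstruction of the deleted IterBlock -- illustrative, not verbatim.
from datasets import Dataset

from .block import Block  # assumed base class, matching the other *block modules


class IterBlock(Block):
    def __init__(self, block_name, num_iters, block_type, block_kwargs, **kwargs):
        super().__init__(block_name)
        self.num_iters = num_iters
        self.block = block_type(**block_kwargs)
        self.gen_kwargs = kwargs.get("gen_kwargs", {})

    def generate(self, samples, **gen_kwargs) -> Dataset:
        generated_samples = []
        for _ in range(self.num_iters):
            # Re-run the wrapped block over the same inputs; the `n` gen_kwarg
            # introduced by this PR makes this loop unnecessary.
            generated_samples.extend(self.block.generate(samples, **self.gen_kwargs))
        return Dataset.from_list(generated_samples)
```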
