Skip to content

Commit

Permalink
Merge pull request #150 from huggingface/haojun/num_samples
Browse files Browse the repository at this point in the history
num_samples
  • Loading branch information
NouamaneTazi authored Apr 29, 2024
2 parents 1c7f038 + 24cfbbd commit dab2b78
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 11 deletions.
9 changes: 6 additions & 3 deletions run_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def main():
tp=args.tp or config.parallelism.tp,
pp_engine=OneForwardOneBackwardPipelineEngine(),
tp_mode=TensorParallelLinearMode.ALL_REDUCE,
tp_linear_async_communication=True,
tp_linear_async_communication=False,
)

# Initialise all process groups
Expand Down Expand Up @@ -164,9 +164,12 @@ def main():
tokenizer.padding_side = "left"
tokenizer.truncation_side = "left" # TODO @nouamane: do we want this?
dummy_inputs = [
# "Passage: Daniel went back to the garden. Mary travelled to the kitchen. Sandra journeyed to the kitchen. Sandra went to the hallway. John went to the bedroom. Mary went back to the garden. Where is Mary?\nAnswer:",
"The future of AI is",
"Passage: Daniel went back to the garden. Mary travelled to the kitchen. Sandra journeyed to the kitchen. Sandra went to the hallway. John went to the bedroom. Mary went back to the garden. Where is Mary?\nAnswer:",
"def fib(n)",
# "This film was probably inspired by Godzilla",
'Here is an extract from a webpage: "Have you ever experienced heel pain after a heavy physical activity, or even right after a long period of standing? If you regard this as something usual and normal, then think again. Miscalled as heel pain, plantar fasciitis causes these frequent mild pains experienced in the soles of the feet. It is the inflammation and enlargement the plantar fascia tissue that is located in the heels of the feet, stretching to the base of the toes. This tissue is responsible for absorbing shock in the feet and for supporting the arches. It also plays a vital role in foot movements during walking and standing. Many factors such as excessive walking, standing, and running trigger heel pain and plantar fasciitis. A sudden increase in intensity of activities, increase in weight, and abrupt change of footwear also cause the swelling of the ligament. Non-supportive footwear lacking arch cushions and improper and worn out running or training can also lead to the problem. It is also most evident among those". Write an extensive and detailed course unit suitable for a textbook targeted at college students, related to the given extract, within the context of "Medicine". Do not just list concepts, but develop each one in detail before moving to the next, as we prioritize depth of understanding and comprehensive exploration of the subject matter over breadth. Focus on: - Rigor: Ensure in-depth coverage of the concepts/sections. - Engagement: Write with an academic, professional and engaging tone that captivates interest. - Application: Incorporate specific, practical examples, such as proofs in calculus or critical dates and figures in history. Do not include a title or an introduction, simply write the content without headlines and introductory phrases. Do not use images.',
"Advancements in technology will lead to",
"Tomorrow's world is shaped by",
]

outputs = decode_text(
Expand Down
2 changes: 1 addition & 1 deletion run_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def get_dataloader_from_data_stage(
)
assert num_tokens_needed_for_training <= total_tokens_dataset, (
f"Dataset is too small for steps ({total_tokens_dataset} < {num_tokens_needed_for_training}), "
f"Try train_steps<={len(dataloader.dataset) // trainer.global_batch_size + trainer.start_iteration_step}"
f"Try train_steps<={len(dataloader.dataset) // trainer.global_batch_size + trainer.iteration_step}"
)
else:
raise ValueError(f"Unhandled case of `self.config.data.dataset`. Got: {data.dataset}")
Expand Down
8 changes: 8 additions & 0 deletions src/nanotron/generation/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,14 @@ def decode_text(

p2p = model.p2p

# replicate input for n_samples times when using TOP_P or TOP_K samplers, in order to get diverse results
if generation_config and generation_config.n_samples:
if sampler_type != SamplerType.TOP_P and sampler_type != SamplerType.TOP_K:
raise ValueError("Only support n_samples for TOP_P and TOP_K sampler")
input_iter = [
GenerationInput(text=input.text) for input in input_iter for _ in range(generation_config.n_samples)
]

# That's annoying but I need this as soon as there's a change communication "cross"
pipeline_state = PipelineEvalBatchState()
with attach_pipeline_state_to_model(model=model, pipeline_state=pipeline_state):
Expand Down
10 changes: 3 additions & 7 deletions src/nanotron/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,13 +656,9 @@ def init_model(self) -> Union[NanotronModel, DistributedDataParallel]:
rank=0,
)
else:
log_rank(
f"Setting max_position_embeddings to {self.config.tokens.sequence_length}. Previous value was {self.model_config.max_position_embeddings}.",
logger=logger,
level=logging.INFO,
rank=0,
)
self.model_config.max_position_embeddings = self.config.tokens.sequence_length
assert (
self.config.tokens.sequence_length == self.model_config.max_position_embeddings
), "The tokenizer's sequence length does not match the model's maximum position embeddings."

log_rank("Config:\n" + pformat(self.config), logger=logger, level=logging.INFO, rank=0)
log_rank("Model Config:\n" + pformat(self.model_config), logger=logger, level=logging.INFO, rank=0)
Expand Down

0 comments on commit dab2b78

Please sign in to comment.