Skip to content

Commit

Permalink
fix run_generate
Browse files Browse the repository at this point in the history
  • Loading branch information
3outeille committed Apr 24, 2024
1 parent cb17862 commit f933736
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions examples/mamba/run_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,9 @@
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--ckpt-path", type=Path, required=True, help="Checkpoint path")
parser.add_argument("--dp", type=int, default=0)
parser.add_argument("--pp", type=int, default=0)
parser.add_argument("--tp", type=int, default=0)
parser.add_argument("--dp", type=int, default=1)
parser.add_argument("--pp", type=int, default=1)
parser.add_argument("--tp", type=int, default=1)
parser.add_argument("--max-new-tokens", type=int, default=128, help="Maximum number of new tokens to generate")
return parser.parse_args()

Expand All @@ -77,9 +77,9 @@ def main():
tokenizer_path = config.tokenizer.tokenizer_name_or_path

parallel_config = ParallelismArgs(
dp=args.dp or config.parallelism.dp,
pp=args.pp or config.parallelism.pp,
tp=args.tp or config.parallelism.tp,
dp=args.dp,
pp=args.pp,
tp=args.tp,
pp_engine=OneForwardOneBackwardPipelineEngine(),
tp_mode=TensorParallelLinearMode.ALL_REDUCE,
tp_linear_async_communication=False,
Expand Down

0 comments on commit f933736

Please sign in to comment.