From eaef99e3c99fb5ab4f8ec56deccb9fc28514c946 Mon Sep 17 00:00:00 2001 From: KimmiShi Date: Tue, 11 Jun 2024 13:23:58 +0800 Subject: [PATCH] feat(inference): support generation using trainer (#230) --- internlm/apis/__init__.py | 6 + internlm/apis/inference.py | 603 +++++++++++------- internlm/apis/inference_utils.py | 69 ++ internlm/core/engine.py | 5 + internlm/core/scheduler/base_scheduler.py | 25 +- .../core/scheduler/no_pipeline_scheduler.py | 1 - internlm/core/scheduler/pipeline_scheduler.py | 2 +- internlm/utils/common.py | 4 +- tests/test_infer/test_generate.py | 133 ++++ tests/test_infer/test_trainer_generate.py | 201 ++++++ 10 files changed, 818 insertions(+), 231 deletions(-) create mode 100644 internlm/apis/inference_utils.py create mode 100644 tests/test_infer/test_generate.py create mode 100644 tests/test_infer/test_trainer_generate.py diff --git a/internlm/apis/__init__.py b/internlm/apis/__init__.py index e69de29bb..ba807b5e1 100644 --- a/internlm/apis/__init__.py +++ b/internlm/apis/__init__.py @@ -0,0 +1,6 @@ +from .inference_utils import InferenceParams, process_parallel_output + +__all__ = [ + "InferenceParams", + "process_parallel_output", +] diff --git a/internlm/apis/inference.py b/internlm/apis/inference.py index 7a51e34d3..d3b5de87f 100644 --- a/internlm/apis/inference.py +++ b/internlm/apis/inference.py @@ -1,48 +1,18 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -from typing import List, Tuple +from typing import Dict, List, Tuple, Union import torch import torch.nn.functional as F from torch import nn -__all__ = ["SequenceGenerator"] - - -class InferenceParams: - """ - Intermediate cache objects for inference - """ +from internlm.apis import InferenceParams, process_parallel_output +from internlm.core.context import ParallelMode # noqa: E402 +from internlm.core.context import global_context as gpc # noqa: E402 +from internlm.core.trainer import Trainer - def __init__( - self, - max_sequence_len, - max_batch_size, - sequence_len_offset=0, - batch_size_offset=0, - key_value_memory_dict: dict = None, - lengths_per_sample=None, - attention_mask=None, - ) -> None: - - self.max_sequence_len: int = max_sequence_len - self.max_batch_size: int = max_batch_size - self.sequence_len_offset: int = sequence_len_offset - self.batch_size_offset: int = batch_size_offset - if key_value_memory_dict is None: - key_value_memory_dict = {} - self.key_value_memory_dict: dict = key_value_memory_dict - self.fused_ft_kernel: bool = False - self.lengths_per_sample = lengths_per_sample - self.attention_mask = attention_mask - - def reorder_state(self, indices): - if self.lengths_per_sample is not None: - self.lengths_per_sample = self.lengths_per_sample.index_select(index=indices, dim=0) - for key, value in list(self.key_value_memory_dict.items()): - value = value.index_select(index=indices, dim=0) - self.key_value_memory_dict[key] = value +__all__ = ["SequenceGenerator"] def _get_model_device(model): @@ -357,17 +327,8 @@ def _streaming_no_beam_search_generate( eos_token_id = torch.LongTensor(eos_token_id).to(tokens.device) has_bos = torch.all(tokens[:, 0].eq(bos_token_id)) - if has_bos: - bos_pos = torch.where(tokens.eq(bos_token_id), 1, 0) - bos_sum = bos_pos.cumsum(dim=-1) - bos_pos = torch.where(bos_sum.eq(bos_sum[:, -1:]), 0, 1) - to_atten_x = bos_pos[:, :, None] - to_atten_y = bos_pos[:, None, :] - else: - bos_pos = torch.where(tokens.eq(bos_token_id), 1, 0) - to_atten_x = bos_pos[:, :, None] - to_atten_y = bos_pos[:, None, :] - attention_mask = torch.logical_or(to_atten_x, 
to_atten_y).eq(1) + attention_mask = get_attention_mask(tokens, has_bos, bos_token_id=bos_token_id) + if inference_params is None: inference_params = InferenceParams( max_sequence_len=max_length, @@ -379,7 +340,16 @@ def _streaming_no_beam_search_generate( attention_mask=attention_mask, ) - scores = decoder(**{"input_ids": tokens, "inference_params": inference_params}) + if isinstance(decoder, torch.nn.Module): + scores = decoder(**{"input_ids": tokens, "inference_params": inference_params}) + elif isinstance(decoder, Trainer): + data = {"input_ids": tokens, "inference_params": inference_params} + model_output, _, _ = decoder.execute_schedule( + (data, None), forward_only=True, return_loss=False, return_output_label=True + ) + scores = torch.cat(model_output, dim=0) + else: + raise NotImplementedError(f"Unsupported decoder type: {type(decoder)}") if isinstance(scores, (list, tuple)): scores = scores[0] @@ -401,19 +371,20 @@ def _streaming_no_beam_search_generate( while cur_len < real_max_length: # batch_size x vocab_size - if has_bos: - bos_pos = torch.where(token_ids.eq(bos_token_id), 1, 0) - bos_sum = bos_pos.cumsum(dim=-1) - bos_pos = torch.where(bos_sum.eq(bos_sum[:, -1:]), 0, 1) - to_atten_x = bos_pos[:, :, None] - to_atten_y = bos_pos[:, None, :] + attention_mask = get_attention_mask(token_ids, has_bos, bos_token_id=bos_token_id) + + if isinstance(decoder, torch.nn.Module): + inference_params.attention_mask = attention_mask + scores = decoder(**{"input_ids": token_ids[:, -1:], "inference_params": inference_params}) + elif isinstance(decoder, Trainer): + inference_params.set_attention_mask(attention_mask) + data = {"input_ids": token_ids[:, -1:], "inference_params": inference_params} + model_output, _, _ = decoder.execute_schedule( + (data, None), forward_only=True, return_loss=False, return_output_label=True + ) + scores = torch.cat(model_output, dim=0) else: - bos_pos = torch.where(token_ids.eq(bos_token_id), 1, 0) - to_atten_x = bos_pos[:, :, None] - to_atten_y = bos_pos[:, None, :] - attention_mask = torch.logical_or(to_atten_x, to_atten_y).eq(1) - inference_params.attention_mask = attention_mask - scores = decoder(**{"input_ids": token_ids[:, -1:], "inference_params": inference_params}) + raise NotImplementedError(f"Unsupported decoder type: {type(decoder)}") if isinstance(scores, (list, tuple)): scores = scores[0] @@ -502,17 +473,9 @@ def _no_beam_search_generate( eos_token_id = torch.LongTensor(eos_token_id).to(tokens.device) has_bos = torch.all(tokens[:, 0].eq(bos_token_id)) - if has_bos: - bos_pos = torch.where(tokens.eq(bos_token_id), 1, 0) - bos_sum = bos_pos.cumsum(dim=-1) - bos_pos = torch.where(bos_sum.eq(bos_sum[:, -1:]), 0, 1) - to_atten_x = bos_pos[:, :, None] - to_atten_y = bos_pos[:, None, :] - else: - bos_pos = torch.where(tokens.eq(bos_token_id), 1, 0) - to_atten_x = bos_pos[:, :, None] - to_atten_y = bos_pos[:, None, :] - attention_mask = torch.logical_or(to_atten_x, to_atten_y).eq(1) + + attention_mask = get_attention_mask(tokens, has_bos, bos_token_id) + if inference_params is None: inference_params = InferenceParams( max_sequence_len=max_length, @@ -524,75 +487,104 @@ def _no_beam_search_generate( attention_mask=attention_mask, ) - scores = decoder(**{"input_ids": tokens, "inference_params": inference_params}) + if isinstance(decoder, torch.nn.Module): + scores = decoder(**{"input_ids": tokens, "inference_params": inference_params}) + elif isinstance(decoder, Trainer): + data = {"input_ids": tokens, "inference_params": inference_params} + model_output, _, _ = 
decoder.execute_schedule( + (data, None), forward_only=True, return_loss=False, return_output_label=True + ) + scores = process_parallel_output(model_output) + else: + raise NotImplementedError(f"Unsupported decoder type: {type(decoder)}") - if isinstance(scores, (list, tuple)): - scores = scores[0] - scores = scores[:, -1].float() - inference_params.sequence_len_offset += tokens.size(1) - if eos_token_id is not None: - scores[:, eos_token_id] = -1e12 + if gpc.is_last_rank(ParallelMode.PIPELINE): + if isinstance(scores, (list, tuple)): + scores = scores[0] + scores = scores[:, -1].float() + if eos_token_id is not None: + scores[:, eos_token_id] = -1e12 - # The first token generated. - next_tokens = scores.argmax(dim=-1, keepdim=True) + # The first token generated. + next_tokens = scores.argmax(dim=-1, keepdim=True) + else: + next_tokens = tokens.new_zeros([batch_size, 1]) + if gpc.is_initialized(ParallelMode.PIPELINE): + # broadcast to other rank in PP group + torch.distributed.broadcast( + next_tokens, + src=gpc.get_ranks_in_group(ParallelMode.PIPELINE)[-1], + group=gpc.get_group(ParallelMode.PIPELINE), + ) token_ids = torch.cat([tokens, next_tokens], dim=1) cur_len = token_ids.size(1) dones = token_ids.new_zeros(batch_size).eq(1) + inference_params.sequence_len_offset += tokens.size(1) + real_max_length = max_length max_lengths = tokens.new_full((tokens.size(0),), fill_value=max_length, dtype=torch.long) while cur_len < real_max_length: # batch_size x vocab_size - if has_bos: - bos_pos = torch.where(token_ids.eq(bos_token_id), 1, 0) - bos_sum = bos_pos.cumsum(dim=-1) - bos_pos = torch.where(bos_sum.eq(bos_sum[:, -1:]), 0, 1) - to_atten_x = bos_pos[:, :, None] - to_atten_y = bos_pos[:, None, :] - # attention_mask = torch.einsum('bno,bom->bnm', to_atten_x, to_atten_y).eq(1) + attention_mask = get_attention_mask(token_ids, has_bos, bos_token_id=bos_token_id) + + if isinstance(decoder, torch.nn.Module): + inference_params.attention_mask = attention_mask + scores = decoder(**{"input_ids": token_ids[:, -1:], "inference_params": inference_params}) + elif isinstance(decoder, Trainer): + inference_params.set_attention_mask(attention_mask) + data = {"input_ids": token_ids[:, -1:], "inference_params": inference_params} + model_output, _, _ = decoder.execute_schedule( + (data, None), forward_only=True, return_loss=False, return_output_label=True + ) + scores = process_parallel_output(model_output) else: - bos_pos = torch.where(token_ids.eq(bos_token_id), 1, 0) - to_atten_x = bos_pos[:, :, None] - to_atten_y = bos_pos[:, None, :] - # attention_mask = torch.einsum('bno,bom->bnm', to_atten_x, to_atten_y).eq(1) - attention_mask = torch.logical_or(to_atten_x, to_atten_y).eq(1) - inference_params.attention_mask = attention_mask - scores = decoder(**{"input_ids": token_ids[:, -1:], "inference_params": inference_params}) + raise NotImplementedError(f"Unsupported decoder type: {type(decoder)}") - if isinstance(scores, (list, tuple)): - scores = scores[0] - scores = scores[:, -1].float() inference_params.sequence_len_offset += 1 - - if repetition_penalty != 1.0: - token_scores = scores.gather(dim=1, index=token_ids) - lt_zero_mask = token_scores.lt(0).float() - ge_zero_mask = lt_zero_mask.eq(0).float() - token_scores = ( - lt_zero_mask * repetition_penalty * token_scores + ge_zero_mask / repetition_penalty * token_scores - ) - scores.scatter_(dim=1, index=token_ids, src=token_scores) - # scores: [bsz, vocab_size] - if eos_token_id is not None and length_penalty != 1.0: - # batch_size x vocab_size - 
eos_token_scores = scores[:, eos_token_id].clone() - scores = scores / cur_len**length_penalty - scores[:, eos_token_id] = eos_token_scores - del eos_token_scores - - if do_sample: - if temperature > 0 and temperature != 1: - scores = scores / temperature - - scores = top_k_top_p_filtering(scores, top_k, top_p, min_tokens_to_keep=2) - # add 1e-12 to avoid https://github.com/pytorch/pytorch/pull/27523 - probs = F.softmax(scores, dim=-1) + 1e-12 - - next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) # batch_size + if gpc.is_last_rank(ParallelMode.PIPELINE): + if isinstance(scores, (list, tuple)): + scores = scores[0] + scores = scores[:, -1].float() + + if repetition_penalty != 1.0: + token_scores = scores.gather(dim=1, index=token_ids) + lt_zero_mask = token_scores.lt(0).float() + ge_zero_mask = lt_zero_mask.eq(0).float() + token_scores = ( + lt_zero_mask * repetition_penalty * token_scores + ge_zero_mask / repetition_penalty * token_scores + ) + scores.scatter_(dim=1, index=token_ids, src=token_scores) + # scores: [bsz, vocab_size] + if eos_token_id is not None and length_penalty != 1.0: + # batch_size x vocab_size + eos_token_scores = scores[:, eos_token_id].clone() + scores = scores / cur_len**length_penalty + scores[:, eos_token_id] = eos_token_scores + del eos_token_scores + + if do_sample: + if temperature > 0 and temperature != 1: + scores = scores / temperature + + scores = top_k_top_p_filtering(scores, top_k, top_p, min_tokens_to_keep=2) + # add 1e-12 to avoid https://github.com/pytorch/pytorch/pull/27523 + probs = F.softmax(scores, dim=-1) + 1e-12 + + next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) # batch_size + else: + next_tokens = torch.argmax(scores, dim=-1) # batch_size else: - next_tokens = torch.argmax(scores, dim=-1) # batch_size - + next_tokens = tokens.new_zeros(batch_size) + + if gpc.is_initialized(ParallelMode.PIPELINE): + # broadcast to other rank in PP group + torch.distributed.broadcast( + next_tokens, + src=gpc.get_ranks_in_group(ParallelMode.PIPELINE)[-1], + group=gpc.get_group(ParallelMode.PIPELINE), + ) if eos_token_id is not None: # When the generated result exceeds the length, its eos_token_id is set to the most basic terminator. 
next_tokens = next_tokens.masked_fill(max_lengths.eq(cur_len + 1), eos_token_id[0]) @@ -640,7 +632,7 @@ def _beam_search_generate( bos_token_id=1, ) -> torch.LongTensor: - device = _get_model_device(decoder) + device = tokens.device batch_size = tokens.size(0) if eos_token_id is not None: @@ -654,19 +646,7 @@ def _beam_search_generate( has_bos = torch.all(tokens[:, 0].eq(bos_token_id)) - if has_bos: - bos_pos = torch.where(tokens.eq(bos_token_id), 1, 0) - bos_sum = bos_pos.cumsum(dim=-1) - bos_pos = torch.where(bos_sum.eq(bos_sum[:, -1:]), 0, 1) - to_atten_x = bos_pos[:, :, None] - to_atten_y = bos_pos[:, None, :] - # attention_mask = torch.einsum('bno,bom->bnm', to_atten_x, to_atten_y).eq(1) - else: - bos_pos = torch.where(tokens.eq(bos_token_id), 1, 0) - to_atten_x = bos_pos[:, :, None] - to_atten_y = bos_pos[:, None, :] - # attention_mask = torch.einsum('bno,bom->bnm', to_atten_x, to_atten_y).eq(1) - attention_mask = torch.logical_or(to_atten_x, to_atten_y).eq(1) + attention_mask = get_attention_mask(tokens, has_bos, bos_token_id=bos_token_id) if inference_params is None: inference_params = InferenceParams( @@ -679,29 +659,56 @@ def _beam_search_generate( attention_mask=attention_mask, ) - scores = decoder(**{"input_ids": tokens, "inference_params": inference_params}) + if isinstance(decoder, torch.nn.Module): + scores = decoder(**{"input_ids": tokens, "inference_params": inference_params}) + elif isinstance(decoder, Trainer): + data = {"input_ids": tokens, "inference_params": inference_params} + model_output, _, _ = decoder.execute_schedule( + (data, None), forward_only=True, return_loss=False, return_output_label=True + ) + scores = process_parallel_output(model_output) + else: + raise NotImplementedError(f"Unsupported decoder type: {type(decoder)}") - if isinstance(scores, (list, tuple)): - scores = scores[0] - scores = scores[:, -1].float() inference_params.sequence_len_offset += tokens.size(1) - if eos_token_id is not None: - scores[:, eos_token_id] = -1e12 - vocab_size = scores.size(1) - assert vocab_size >= num_beams, "num_beams should be smaller than " "the number of vocabulary size." - # The first token generated. - if do_sample: - probs = F.softmax(scores, dim=-1) + 1e-12 - # (batch_size, num_beams) - next_tokens = torch.multinomial(probs, num_samples=num_beams) - logits = probs.log() - # (batch_size, num_beams) - next_scores = logits.gather(dim=1, index=next_tokens) + if gpc.is_last_rank(ParallelMode.PIPELINE): + if isinstance(scores, (list, tuple)): + scores = scores[0] + scores = scores[:, -1].float() + if eos_token_id is not None: + scores[:, eos_token_id] = -1e12 + vocab_size = scores.size(1) + assert vocab_size >= num_beams, "num_beams should be smaller than " "the number of vocabulary size." + + # The first token generated. 
+ if do_sample: + probs = F.softmax(scores, dim=-1) + 1e-12 + # (batch_size, num_beams) + next_tokens = torch.multinomial(probs, num_samples=num_beams) + logits = probs.log() + # (batch_size, num_beams) + next_scores = logits.gather(dim=1, index=next_tokens) + else: + scores = F.log_softmax(scores, dim=-1) # (batch_size, vocab_size) + # obtain (batch_size, num_beams), (batch_size, num_beams) + next_scores, next_tokens = torch.topk(scores, num_beams, dim=1, largest=True, sorted=True) else: - scores = F.log_softmax(scores, dim=-1) # (batch_size, vocab_size) - # obtain (batch_size, num_beams), (batch_size, num_beams) - next_scores, next_tokens = torch.topk(scores, num_beams, dim=1, largest=True, sorted=True) + next_tokens = tokens.new_zeros([batch_size, num_beams]) + next_scores = torch.zeros([batch_size, num_beams], dtype=torch.float32, device=next_tokens.device) + + if gpc.is_initialized(ParallelMode.PIPELINE): + # broadcast to other rank in PP group + torch.distributed.broadcast( + next_tokens, + src=gpc.get_ranks_in_group(ParallelMode.PIPELINE)[-1], + group=gpc.get_group(ParallelMode.PIPELINE), + ) + torch.distributed.broadcast( + next_scores, + src=gpc.get_ranks_in_group(ParallelMode.PIPELINE)[-1], + group=gpc.get_group(ParallelMode.PIPELINE), + ) indices = torch.arange(batch_size, dtype=torch.long).to(device) indices = indices.repeat_interleave(num_beams) @@ -726,79 +733,102 @@ def _beam_search_generate( batch_inds_with_numbeams_interval = (torch.arange(batch_size) * num_beams).view(-1, 1).to(token_ids) while cur_len < real_max_length: - if has_bos: - bos_pos = torch.where(token_ids.eq(bos_token_id), 1, 0) - bos_sum = bos_pos.cumsum(dim=-1) - bos_pos = torch.where(bos_sum.eq(bos_sum[:, -1:]), 0, 1) - to_atten_x = bos_pos[:, :, None] - to_atten_y = bos_pos[:, None, :] - # attention_mask = torch.einsum('bno,bom->bnm', to_atten_x, to_atten_y).eq(1) - else: - bos_pos = torch.where(token_ids.eq(bos_token_id), 1, 0) - to_atten_x = bos_pos[:, :, None] - to_atten_y = bos_pos[:, None, :] - # attention_mask = torch.einsum('bno,bom->bnm', to_atten_x, to_atten_y).eq(1) - attention_mask = torch.logical_or(to_atten_x, to_atten_y).eq(1) + attention_mask = get_attention_mask(token_ids, has_bos, bos_token_id=bos_token_id) - inference_params.attention_mask = attention_mask # (bsz x num_beams, vocab_size) - scores = decoder(**{"input_ids": token_ids[:, -1:], "inference_params": inference_params}) - - if isinstance(scores, (list, tuple)): - scores = scores[0] - scores = scores[:, -1].float() - inference_params.sequence_len_offset += 1 - if repetition_penalty != 1.0: - token_scores = scores.gather(dim=1, index=token_ids) - lt_zero_mask = token_scores.lt(0).float() - ge_zero_mask = lt_zero_mask.eq(0).float() - token_scores = ( - lt_zero_mask * repetition_penalty * token_scores + ge_zero_mask / repetition_penalty * token_scores + if isinstance(decoder, torch.nn.Module): + inference_params.attention_mask = attention_mask + scores = decoder(**{"input_ids": token_ids[:, -1:], "inference_params": inference_params}) + elif isinstance(decoder, Trainer): + inference_params.set_attention_mask(attention_mask) + data = {"input_ids": token_ids[:, -1:], "inference_params": inference_params} + model_output, _, _ = decoder.execute_schedule( + (data, None), forward_only=True, return_loss=False, return_output_label=True ) - scores.scatter_(dim=1, index=token_ids, src=token_scores) - - if eos_token_id is not None: - max_len_eos_mask = max_lengths.eq(cur_len + 1) - # When the generated result exceeds the length, its 
eos_token_id is set to the most basic terminator. - eos_scores = scores[:, eos_token_id[0]] - scores[:, eos_token_id[0]] = torch.where(max_len_eos_mask, eos_scores + 1e32, eos_scores) - - if do_sample: - if temperature > 0 and temperature != 1: - scores = scores / temperature - - scores = top_k_top_p_filtering(scores, top_k, top_p, min_tokens_to_keep=num_beams + 1) - # add 1e-12 to avoid https://github.com/pytorch/pytorch/pull/27523 - probs = F.softmax(scores, dim=-1) + 1e-12 + scores = process_parallel_output(model_output) + else: + raise NotImplementedError(f"Unsupported decoder type: {type(decoder)}") - # batch_size' x (num_beams+1) - _tokens = torch.multinomial(probs, num_samples=num_beams + 1) + inference_params.sequence_len_offset += 1 - logits = probs.log() - # batch_size' x (num_beams+1) - _scores = logits.gather(dim=1, index=_tokens) - # batch_size' x (num_beams+1) - _scores = _scores + beam_scores[:, None] - _scores = _scores.view(batch_size, num_beams * (num_beams + 1)) - next_scores, ids = _scores.topk(2 * num_beams, dim=1, largest=True, sorted=True) - _tokens = _tokens.view(batch_size, num_beams * (num_beams + 1)) - # (batch_size, 2*num_beams) - next_tokens = _tokens.gather(dim=1, index=ids) - # (batch_size, 2*num_beams) - from_which_beam = torch.floor(ids.float() / (num_beams + 1)).long() + if gpc.is_last_rank(ParallelMode.PIPELINE): + + if isinstance(scores, (list, tuple)): + scores = scores[0] + scores = scores[:, -1].float() + if repetition_penalty != 1.0: + token_scores = scores.gather(dim=1, index=token_ids) + lt_zero_mask = token_scores.lt(0).float() + ge_zero_mask = lt_zero_mask.eq(0).float() + token_scores = ( + lt_zero_mask * repetition_penalty * token_scores + ge_zero_mask / repetition_penalty * token_scores + ) + scores.scatter_(dim=1, index=token_ids, src=token_scores) + + if eos_token_id is not None: + max_len_eos_mask = max_lengths.eq(cur_len + 1) + # When the generated result exceeds the length, its eos_token_id is set to the most basic terminator. 
+ eos_scores = scores[:, eos_token_id[0]] + scores[:, eos_token_id[0]] = torch.where(max_len_eos_mask, eos_scores + 1e32, eos_scores) + + if do_sample: + if temperature > 0 and temperature != 1: + scores = scores / temperature + + scores = top_k_top_p_filtering(scores, top_k, top_p, min_tokens_to_keep=num_beams + 1) + # add 1e-12 to avoid https://github.com/pytorch/pytorch/pull/27523 + probs = F.softmax(scores, dim=-1) + 1e-12 + + # batch_size' x (num_beams+1) + _tokens = torch.multinomial(probs, num_samples=num_beams + 1) + + logits = probs.log() + # batch_size' x (num_beams+1) + _scores = logits.gather(dim=1, index=_tokens) + # batch_size' x (num_beams+1) + _scores = _scores + beam_scores[:, None] + _scores = _scores.view(batch_size, num_beams * (num_beams + 1)) + next_scores, ids = _scores.topk(2 * num_beams, dim=1, largest=True, sorted=True) + _tokens = _tokens.view(batch_size, num_beams * (num_beams + 1)) + # (batch_size, 2*num_beams) + next_tokens = _tokens.gather(dim=1, index=ids) + # (batch_size, 2*num_beams) + from_which_beam = torch.floor(ids.float() / (num_beams + 1)).long() + else: + # (batch_size * num_beams, vocab_size) + scores = F.log_softmax(scores, dim=-1) + # (batch_size * num_beams, vocab_size) + _scores = scores + beam_scores[:, None] + # (batch_size, num_beams*vocab_size) + _scores = _scores.view(batch_size, -1) + # (bsz, 2*num_beams) + next_scores, ids = torch.topk(_scores, 2 * num_beams, dim=1, largest=True, sorted=True) + # (batch_size, 2*num_beams) + from_which_beam = torch.floor(ids.float() / vocab_size).long() + next_tokens = ids % vocab_size # (batch_size, 2*num_beams) else: - # (batch_size * num_beams, vocab_size) - scores = F.log_softmax(scores, dim=-1) - # (batch_size * num_beams, vocab_size) - _scores = scores + beam_scores[:, None] - # (batch_size, num_beams*vocab_size) - _scores = _scores.view(batch_size, -1) - # (bsz, 2*num_beams) - next_scores, ids = torch.topk(_scores, 2 * num_beams, dim=1, largest=True, sorted=True) - # (batch_size, 2*num_beams) - from_which_beam = torch.floor(ids.float() / vocab_size).long() - next_tokens = ids % vocab_size # (batch_size, 2*num_beams) + next_tokens = tokens.new_zeros([batch_size, 2 * num_beams]) + next_scores = torch.zeros([batch_size, 2 * num_beams], dtype=torch.float32, device=next_tokens.device) + from_which_beam = torch.zeros([batch_size, 2 * num_beams], dtype=torch.int64, device=next_tokens.device) + + if gpc.is_initialized(ParallelMode.PIPELINE): + # broadcast to other rank in PP group + torch.distributed.broadcast( + next_tokens, + src=gpc.get_ranks_in_group(ParallelMode.PIPELINE)[-1], + group=gpc.get_group(ParallelMode.PIPELINE), + ) + torch.distributed.broadcast( + next_scores, + src=gpc.get_ranks_in_group(ParallelMode.PIPELINE)[-1], + group=gpc.get_group(ParallelMode.PIPELINE), + ) + torch.distributed.broadcast( + from_which_beam, + src=gpc.get_ranks_in_group(ParallelMode.PIPELINE)[-1], + group=gpc.get_group(ParallelMode.PIPELINE), + ) not_eos_mask = torch.all(next_tokens[..., None].ne(eos_token_id), dim=-1) keep_mask = not_eos_mask.cumsum(dim=1).le(num_beams) @@ -964,3 +994,128 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf") indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) logits[indices_to_remove] = filter_value return logits + + +@torch.no_grad() +def get_attention_mask(tokens, has_bos, bos_token_id=1): + if has_bos: + bos_pos = torch.where(tokens.eq(bos_token_id), 1, 0) + bos_sum = bos_pos.cumsum(dim=-1) + bos_pos = 
torch.where(bos_sum.eq(bos_sum[:, -1:]), 0, 1) + to_atten_x = bos_pos[:, :, None] + to_atten_y = bos_pos[:, None, :] + else: + bos_pos = torch.where(tokens.eq(bos_token_id), 1, 0) + to_atten_x = bos_pos[:, :, None] + to_atten_y = bos_pos[:, None, :] + # attention_mask = torch.logical_or(to_atten_x, to_atten_y).eq(1) + to_atten_y_new = to_atten_y.repeat(1, to_atten_x.shape[1], 1) + to_atten_x_new = to_atten_x.repeat(1, 1, to_atten_y.shape[2]) + attention_mask = torch.logical_or(to_atten_x_new, to_atten_y_new).eq(1) + + return attention_mask + + +def batch_tokenize_process_fn( + batch: Union[List[str], List[Dict], Dict], tokenizer, add_bos: bool = True, add_eos: bool = False +) -> Union[List, Dict]: + """Data post-processing function for tokenize. + + This function can be directly used in the map function of ``DatasetDict`` and supports batched=True. + + Args: + batch (Union[List[str], List[Dict], Dict]): Data used to tokenize which can be of the following + categories: + (a) A list whose content can be a string or a dictionary. If it is a dictionary, + it needs to contain the "content" field; + (b) A dictionary-like object, which should contain the "content" field. + tokenizer : Currently only sentencepiece is supported. + add_bos (bool, optional): Whether to add bos token. Defaults to True. + add_eos (bool, optional): Whether to add eos token. Defaults to False. + + Returns: + Union[List, Dict]: tokenized data. + """ + + def _tokenize(text): + tokens = [tokenizer.bos_id()] if add_bos else [] + tokens += tokenizer.encode(text) + if add_eos: + tokens.append(tokenizer.eos_id()) + return tokens + + if isinstance(batch, (List, Tuple)): + if len(batch) == 0: + return None + if isinstance(batch[0], str): + return [_tokenize(w) for w in batch] + if isinstance(batch[0], Dict): + for sample in batch: + sample["input_ids"] = _tokenize(sample["content"]) + return batch + elif isinstance(batch, str): + raise NotImplementedError("Do not support a single str as input.") + else: + try: + batch["input_ids"] = [_tokenize(w) for w in batch["content"]] + batch.pop("content") + return batch + except Exception as e: + print(f"The type of parameter ``batch`` is wrong, type:{type(batch)}, batch: {batch}.") + raise e + + +def pad_input_ids(batch: List[Dict], pad_token_id: int = 0, return_dict: bool = False) -> Union[Dict, torch.Tensor]: + """Tokenize a list of prompts with Left Padding. + + Args: + batch (List[Dict, List]): if batch[0] is a dict, then key 'input_ids' must exist, + and value must be a list of integers. + pad_token_id (int, optional): Defaults to 0. + return_dict (bool, optional): Defaults to False. 
+ + Returns: + Union[Dict, torch.Tensor]: input_ids or dict(input_ids=input_ids) + """ + assert isinstance(batch, list), "batch must be a list" + + input_ids = [] + max_length = max([len(w["input_ids"] if isinstance(w, Dict) else w) for w in batch]) + for sample in batch: + cur_input_ids = sample["input_ids"] if isinstance(sample, Dict) else sample + assert len(cur_input_ids) > 0, "got empty list" + assert isinstance(cur_input_ids[0], int), f"only support a list of integers, but got {type(cur_input_ids[0])}" + cur_input_ids = torch.LongTensor(cur_input_ids) + # left padding for generation + input_ids.append( + torch.cat( + [ + cur_input_ids.new_full((max_length - len(cur_input_ids),), fill_value=pad_token_id), + cur_input_ids, + ] + ) + ) + input_ids = torch.stack(input_ids) + return input_ids if not return_dict else {"input_ids": input_ids} + + +def batch_tokenize( + prompts: List[str], tokenizer, return_dict: bool = False, pad_token_id: int = 1 +) -> Union[Dict, torch.Tensor]: + """Tokenize a list of prompts with Left Padding. Return the tokens. + + Args: + prompts (List[str]): a list of prompts + tokenizer : Currently only sentencepiece is supported. + return_dict (bool, optional): Defaults to False. + pad_token_id (int, optional): Defaults to 1. + + Returns: + Union[Dict, torch.Tensor]: input_ids or dict(input_ids=input_ids) + """ + + tokenizer_out = batch_tokenize_process_fn(prompts, tokenizer) + + tokens = pad_input_ids(tokenizer_out, return_dict=return_dict, pad_token_id=pad_token_id) + + return tokens diff --git a/internlm/apis/inference_utils.py b/internlm/apis/inference_utils.py new file mode 100644 index 000000000..423e7aafe --- /dev/null +++ b/internlm/apis/inference_utils.py @@ -0,0 +1,69 @@ +import torch + +from internlm.core.context import ParallelMode # noqa: E402 +from internlm.core.context import global_context as gpc # noqa: E402 +from internlm.core.parallel.comm.utils import _gather as gather + + +class InferenceParams: + """ + Intermediate cache objects for inference + """ + + def __init__( + self, + max_sequence_len, + max_batch_size, + sequence_len_offset=0, + batch_size_offset=0, + key_value_memory_dict: dict = None, + lengths_per_sample=None, + attention_mask=None, + window_size=None, + ) -> None: + + self.max_sequence_len: int = max_sequence_len + self.max_batch_size: int = max_batch_size + self.sequence_len_offset: int = sequence_len_offset + self.batch_size_offset: int = batch_size_offset + if key_value_memory_dict is None: + key_value_memory_dict = {} + self.key_value_memory_dict: dict = key_value_memory_dict + self.fused_ft_kernel: bool = False + self.lengths_per_sample = lengths_per_sample + self.attention_mask = attention_mask + self.full_attention_mask = attention_mask + self.window_size = window_size + + def reorder_state(self, indices): + if self.lengths_per_sample is not None: + self.lengths_per_sample = self.lengths_per_sample.index_select(index=indices, dim=0) + for key, value in list(self.key_value_memory_dict.items()): + value = value.index_select(index=indices, dim=0) + self.key_value_memory_dict[key] = value + + def set_batch_offset(self, offset, bsz): + """Called by `BaseScheduler._load_micro_batch`. 
+ when micro-batch is enabled, the working attention mask is only a view of `full_attention_mask` + """ + self.batch_size_offset = offset + self.attention_mask = self.full_attention_mask[offset : offset + bsz] + + def set_attention_mask(self, mask): + """useful when generate using Engine/trainer rather than directly using model""" + self.full_attention_mask = mask + + +def process_parallel_output(model_output): + # 1. concat + if gpc.is_last_rank(ParallelMode.PIPELINE): + if not isinstance(model_output, torch.Tensor): + model_output = torch.cat(model_output, dim=0) + else: + return None + + # gather tp parallel output + if gpc.config.model.parallel_output and gpc.is_initialized(ParallelMode.TENSOR): + return gather(model_output, ParallelMode.TENSOR, -1) + else: + return model_output diff --git a/internlm/core/engine.py b/internlm/core/engine.py index b97393163..5989536dc 100644 --- a/internlm/core/engine.py +++ b/internlm/core/engine.py @@ -93,6 +93,11 @@ def criterion(self): """Returns the criterion (loss function) attached to the engine.""" return self._criterion + @criterion.setter + def criterion(self, criterion): + """Sets the criterion (loss function).""" + self._criterion = criterion + def _all_reduce_gradients(self): """Handles all-reduce operations of gradients across different parallel groups.""" for handler in self._gradient_handlers: diff --git a/internlm/core/scheduler/base_scheduler.py b/internlm/core/scheduler/base_scheduler.py index 1800ccc12..da060ade5 100644 --- a/internlm/core/scheduler/base_scheduler.py +++ b/internlm/core/scheduler/base_scheduler.py @@ -8,6 +8,7 @@ import torch +from internlm.apis import InferenceParams from internlm.core.engine import Engine @@ -44,10 +45,26 @@ def _load_micro_batch(self, data: Dict, label: torch.Tensor, offset: int, bsz_st so the data of batch is unpacked and 'bsz_stride' is equal to 'micro_bsz'. In all other cases 'bsz_stride' should be equal to 1. """ - assert isinstance(data, dict) and isinstance(label, torch.Tensor) - micro_batch_data = {k: v[offset : offset + bsz_stride] for k, v in data.items()} - micro_batch_label = label[offset : offset + bsz_stride] - + assert isinstance(data, dict) + + micro_batch_data = {} + for k, v in data.items(): + if isinstance(v, torch.Tensor): + micro_batch_data[k] = v[offset : offset + bsz_stride] + elif isinstance(v, InferenceParams): + v.set_batch_offset(offset, bsz_stride) + micro_batch_data[k] = v + elif isinstance(v, (list, tuple)): + micro_batch_data[k] = v[offset : offset + bsz_stride] + else: + raise NotImplementedError(f"value of type {type(v)} is not supported") + + if isinstance(label, torch.Tensor): + micro_batch_label = label[offset : offset + bsz_stride] + elif isinstance(label, Dict): + micro_batch_label = {k: v[offset : offset + bsz_stride] if v.dim() > 0 else v for k, v in label.items()} + else: + micro_batch_label = label return micro_batch_data, micro_batch_label @abstractmethod diff --git a/internlm/core/scheduler/no_pipeline_scheduler.py b/internlm/core/scheduler/no_pipeline_scheduler.py index 339a404e3..4040e8e1a 100644 --- a/internlm/core/scheduler/no_pipeline_scheduler.py +++ b/internlm/core/scheduler/no_pipeline_scheduler.py @@ -175,7 +175,6 @@ def forward_backward_step( If True, the model is run for the forward pass, else back propagation will be executed. return_loss (bool, optional): Loss will be returned if True. return_output_label (bool, optional): Output and label will be returned if True. 
- Returns: Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None. """ diff --git a/internlm/core/scheduler/pipeline_scheduler.py b/internlm/core/scheduler/pipeline_scheduler.py index df84b0661..269ddb966 100644 --- a/internlm/core/scheduler/pipeline_scheduler.py +++ b/internlm/core/scheduler/pipeline_scheduler.py @@ -79,7 +79,7 @@ def pack_return_tensors(return_tensors): raise TypeError("Output of model must be tensor or list/tuple of tensors") if isinstance(label[0], torch.Tensor): label = torch.cat(label, dim=0) - else: + elif isinstance(label[0], dict): merged_label = {k: [] for k in label[0].keys()} for d in label: for k, v in d.items(): diff --git a/internlm/utils/common.py b/internlm/utils/common.py index df4583d43..82161c7d0 100644 --- a/internlm/utils/common.py +++ b/internlm/utils/common.py @@ -53,7 +53,9 @@ def move_to_device(data): data = [move_to_device(x) for x in data] elif isinstance(data, dict): data = {k: move_to_device(v) for k, v in data.items()} - + else: + # other types like scalar, other params, return the value itself. + return data return data diff --git a/tests/test_infer/test_generate.py b/tests/test_infer/test_generate.py new file mode 100644 index 000000000..a169c96e7 --- /dev/null +++ b/tests/test_infer/test_generate.py @@ -0,0 +1,133 @@ +import os + +import pytest +import torch +from sentencepiece import SentencePieceProcessor + +from internlm.apis.inference import SequenceGenerator, batch_tokenize +from internlm.initialize import initialize_distributed_env # noqa: E402 +from internlm.train import initialize_model, initialize_parallel_communicator + + +def set_seed(seed: int = 1024): + import random + + import numpy as np + + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + +def load_and_generate(path, model_type="INTERNLM2_PUBLIC", tokenizer_path=""): + model_cfg = os.path.join(path, "model_config.pt") + model_wt = os.path.join(path, "model_tp0_pp0.pt") + model_config = torch.load(model_cfg) + model_config["apply_post_layer_norm"] = False + if model_config.get("adapt_hf") is not None: + model_config.pop("adapt_hf") + evo_cfg = dict( + model_type=model_type, + model=model_config, + parallel=dict( + zero1=dict(size=1, fsdp=False), + pipeline=dict(size=1, interleaved_overlap=True), + tensor=dict(size=1, mode="mtp"), + sequence_parallel=0, + ), + ) + initialize_distributed_env(evo_cfg, master_port=23574, args_check=False) + + tokenizer = SentencePieceProcessor(tokenizer_path) # pylint: disable=E1121 + + def convert_to_str(output_ids): + output_tokens = output_ids.tolist() + all_output_str = [] + for b in range(len(output_tokens)): + for sent_idx in range(len(output_tokens[b])): + cur_output_tokens = output_tokens[b][sent_idx] + cur_sent = tokenizer.decode(cur_output_tokens) + all_output_str.append(cur_sent) + return all_output_str + + model = initialize_model() + _ = initialize_parallel_communicator(model) + # Directly get the origin model without NativeAMP wrapper. + model = model.model + + state_dict = torch.load(model_wt) + load_info = model.load_state_dict(state_dict, strict=False) + print(load_info) + + sequenece_generator = SequenceGenerator( + decoder=model, + eos_token_id=tokenizer.eos_id(), + pad_token_id=tokenizer.bos_id(), + bos_token_id=tokenizer.bos_id(), + additional_eos_token_list=None, + ) + + test_prompt_0 = "Gold is considered to be a precious metal." + test_prompt_1 = "what is love? someone think it is a feeling, someone think it is a chemical reaction." 
+ test_prompt_2 = "kobe bryant is a basketball player." + + prompt_3 = [ + test_prompt_0, + test_prompt_1, + test_prompt_2, + ] + prompt_2 = [ + test_prompt_0, + test_prompt_1, + ] + + prompt_1 = [test_prompt_0] + + def generate(prompt): + input_ids = batch_tokenize(prompt, tokenizer, pad_token_id=tokenizer.bos_id()).cuda() + generate_kwargs = {} + set_seed() + output_ids = sequenece_generator.generate( + input_ids, + num_return_sequences=generate_kwargs.get("num_return_sequences", 1), + max_length=generate_kwargs.get("max_length", input_ids.shape[1] + 80), + num_beams=generate_kwargs.get("num_beams", 1), + do_sample=generate_kwargs.get("do_sample", False), + temperature=generate_kwargs.get("temperature", 1.0), + top_k=generate_kwargs.get("top_k", 50), + top_p=generate_kwargs.get("top_p", 1.0), + repetition_penalty=generate_kwargs.get("repetition_penalty", 1), + length_penalty=generate_kwargs.get("repetition_penalty", 1.0), + ) + + all_output_str = convert_to_str(output_ids) + return all_output_str + + output_3 = generate(prompt_3) + output_2 = generate(prompt_2) + output_1 = generate(prompt_1) + + assert output_3[0] == output_2[0] + assert output_3[1] == output_2[1] + assert ( + output_1[0] + == "Gold is considered to be a precious metal. It is a metal that is highly valued for its \ +rarity and beauty. Gold is often used in jewelry, coins, and other decorative items. It is also used in \ +the production of electronics and other high-tech products. Gold is a highly sought-after metal because \ +of its ability to resist corrosion and tarnish. It is also highly resistant to fire and is a good conductor \ +of heat and electricity.\n" + ) + print("test generate done!") + + +def test_internlm2_1_8B_generate(): + base_model_dir = os.environ.get("qa_data") + if base_model_dir is not None: + model_dir = os.path.join(base_model_dir, "internlm2_1_8B") + tokenizer_path = os.path.join(base_model_dir, "InternLM_CI_assets/v13.model") + if os.path.exists(model_dir) and os.path.exists(tokenizer_path): + load_and_generate(model_dir, tokenizer_path=tokenizer_path) + + +if __name__ == "__main__": + pytest.main(["-s", "-q", "-v", "test_generate.py"]) diff --git a/tests/test_infer/test_trainer_generate.py b/tests/test_infer/test_trainer_generate.py new file mode 100644 index 000000000..3b7fffd05 --- /dev/null +++ b/tests/test_infer/test_trainer_generate.py @@ -0,0 +1,201 @@ +import os + +import pytest +from sentencepiece import SentencePieceProcessor + +import internlm # noqa: E402 +from internlm.apis.inference import SequenceGenerator, batch_tokenize +from internlm.checkpoint import CheckpointManager # noqa: E402 +from internlm.core.context import global_context as gpc # noqa: E402 +from internlm.core.trainer import TrainState # noqa: E402 +from internlm.data import build_train_loader_with_data_type # noqa: E402 +from internlm.initialize import initialize_distributed_env # noqa: E402 +from internlm.model.losses import FlashGPTLMLoss # noqa: E402 +from internlm.train import ( # noqa: E402 + get_scheduler_hooks, + initialize_model, + initialize_optimizer, + initialize_parallel_communicator, +) + + +def setup_generator(config, tokenizer): + initialize_distributed_env(config=config) + + model = initialize_model() + isp_communicator = initialize_parallel_communicator(model) + + criterion = FlashGPTLMLoss() + + # initialize the train data loader + train_dl, _ = build_train_loader_with_data_type() + + # initialize and resume train state + train_state = TrainState(gpc.config, train_dl.batch_sampler) + + optimizer, 
beta2_scheduler, lr_scheduler = initialize_optimizer(model, isp_communicator) + + ckpt_manager = CheckpointManager( + ckpt_config=gpc.config.ckpt, + model=model, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + train_dl=train_dl, + model_config=gpc.config.model, + feishu_address=gpc.config.monitor.alert.feishu_alert_address, + ) + ckpt_manager.try_resume_training(train_state) + + # initialize trainer + trainer, train_dl, _, _ = internlm.initialize_trainer( + model=model, + optimizer=optimizer, + criterion=criterion, + train_dataloader=train_dl, + lr_scheduler=lr_scheduler, + beta2_scheduler=beta2_scheduler, + scheduler_hooks=get_scheduler_hooks(None, optimizer, isp_communicator), + ) + + trainer.schedule.data_process_func = None + + if isinstance(tokenizer, SentencePieceProcessor): + eos_token_id = tokenizer.eos_id() + pad_token_id = tokenizer.eos_id() + bos_token_id = tokenizer.bos_id() + else: + eos_token_id = tokenizer.eos_token_id + pad_token_id = tokenizer.pad_token_id + bos_token_id = tokenizer.bos_token_id + + sequenece_generator = SequenceGenerator( + decoder=trainer, + eos_token_id=eos_token_id, + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + additional_eos_token_list=None, + ) + + return sequenece_generator + + +def do_generate(config, tokenizer_path, prompt): + tokenizer = SentencePieceProcessor(tokenizer_path) # pylint: disable=E1121 + + sequenece_generator = setup_generator(config, tokenizer) + input_ids = batch_tokenize(prompt, tokenizer, pad_token_id=tokenizer.bos_id()).cuda() + + generate_kwargs = {} + output_ids = sequenece_generator.generate( + input_ids, + num_return_sequences=generate_kwargs.get("num_return_sequences", 1), + max_length=generate_kwargs.get("max_length", 100), + num_beams=generate_kwargs.get("num_beams", 1), + do_sample=generate_kwargs.get("do_sample", True), + temperature=generate_kwargs.get("temperature", 1.0), + top_k=generate_kwargs.get("top_k", 50), + top_p=generate_kwargs.get("top_p", 1.0), + repetition_penalty=generate_kwargs.get("repetition_penalty", 1), + length_penalty=generate_kwargs.get("repetition_penalty", 1.0), + ) + output_tokens = output_ids.tolist() + all_output_str = [] + for b in range(len(output_tokens)): + for sent_idx in range(len(output_tokens[b])): + cur_output_tokens = output_tokens[b][sent_idx] + cur_sent = tokenizer.decode(cur_output_tokens) + all_output_str.append(cur_sent) + return all_output_str + + +def test_luyou_2B_generate(): + prompt = [ + "user\nHow can I keep flys away from my house\nassistant\n", + "user\nHow can I keep flys away from my house\nassistant\nThe best way is to keep your house clean, " + "and sweep away from where your meals are prepared, since flys tend to seek out food particles.\n" + "user\nAny other advice?\nassistant\n", + ] + + base_model_dir = os.environ.get("qa_data") + if base_model_dir is not None: + config = os.path.join(base_model_dir, "model_configs/Luyou_1B_merged.py") + + tokenizer_path = os.path.join(base_model_dir, "InternLM_CI_assets/v13.model") + if os.path.exists(config) and os.path.exists(tokenizer_path): + all_output_str = do_generate(config, tokenizer_path, prompt) + print("out_str:\n", all_output_str) + assert ( + all_output_str[0][len(prompt[0]) :] + == "There are several things you can do to keep flies away from your house:\n\n\ +1. Keep your home clean: Flies are attracted to food and dirty surfaces. 
Make sure that your home \ +is well-maintained and" + ) + assert ( + all_output_str[1][len(prompt[1]) :] + == "You can also use plastic baggies to keep any food that is dropped on your porch, \ +patio, or windowsill from attracting flies.\n[UNUSED_TOKEN_145]\nNo[UNUSED_TOKEN_145]\nYou could also \ +use scented candles or diffusers" + ) + + +@pytest.mark.skip("requires 2 gpu") +def test_internlm2_pp2_generate(): + prompt = [ + "user\nHow can I keep flys away from my house\nassistant\n", + "user\nHow can I keep flys away from my house\nassistant\nThe best way is to keep your house clean, " + "and sweep away from where your meals are prepared, since flys tend to seek out food particles.\n" + "user\nAny other advice?\nassistant\n", + ] + + base_model_dir = os.environ.get("qa_data") + if base_model_dir is not None: + config = os.path.join(base_model_dir, "model_configs/Luyou_1B_PP2.py") + tokenizer_path = os.path.join(base_model_dir, "InternLM_CI_assets/v13.model") + if os.path.exists(config) and os.path.exists(tokenizer_path): + all_output_str = do_generate(config, tokenizer_path, prompt) + print("out_str:\n", all_output_str) + assert ( + all_output_str[0][len(prompt[0]) :] + == "There are several things you can do to keep flies away \ +from your house:\n\n1. Keep your home clean: Flies are attracted to food and dirty surfaces. Make sure that your \ +home is well-maintained and" + ) + assert ( + all_output_str[1][len(prompt[1]) :] + == "You can also use plastic baggies to keep any food that is dropped on your porch, patio, or \ +windowsill from attracting flies.\n[UNUSED_TOKEN_145]\nNo[UNUSED_TOKEN_145]\nYou could also use scented candles \ +or diffusers" + ) + + +@pytest.mark.skip("reduce timecost") +def test_internlm2_7B_tp2(): + prompt = [ + "user\nHow can I keep flys away from my house\nassistant\n", + "user\nHow can I keep flys away from my house\nassistant\nThe best way is to keep your house clean, " + "and sweep away from where your meals are prepared, since flys tend to seek out food particles.\n" + "user\nAny other advice?\nassistant\n", + ] + + base_model_dir = os.environ.get("qa_data") + if base_model_dir is not None: + config = os.path.join(base_model_dir, "model_configs/7B_internlm2.py") + + tokenizer_path = os.path.join(base_model_dir, "InternLM_CI_assets/v13.model") + if os.path.exists(config) and os.path.exists(tokenizer_path): + all_output_str = do_generate(config, tokenizer_path, prompt) + print("out_str:\n", all_output_str) + assert ( + all_output_str[0][len(prompt[0]) :] + == "You can use natural repellants like lavender, vanilla or lemongrass essential oils. \ +Or you can spray essential oil in a spray bottle around doors and windows. Also, using a white vinegar and" + ) + assert ( + all_output_str[1][len(prompt[1]) :] + == "You may want to consider using fly trapped to keep or get rid of the flys if need be. \ +Also wearing indoor protective clothing may be advised as well since they can be dangerous" + ) + + +if __name__ == "__main__": + pytest.main(["-s", "-q", "-v", "test_trainer_generate.py"])
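
The pattern that recurs in each of the three generate paths above is: only the last pipeline stage holds logits, so it selects the next token(s) and broadcasts the result to every other rank in the pipeline group before the next forward step. Below is a minimal, standalone sketch of that decode step; it uses plain torch.distributed in place of InternLM's gpc / ParallelMode helpers, and the function name and arguments are invented for illustration, not part of this patch.

# Standalone sketch (assumptions: torch.distributed is already initialized, `pp_group` is
# the pipeline-parallel process group, and `last_rank` is the global rank of its last stage,
# i.e. what the patch obtains via gpc.get_ranks_in_group(ParallelMode.PIPELINE)[-1]).
import torch
import torch.distributed as dist


def greedy_decode_step(scores, tokens, pp_group, last_rank):
    """One greedy decode step under pipeline parallelism.

    scores: [batch, seq, vocab] logits, available only on the last pipeline rank (None elsewhere).
    tokens: [batch, seq] tokens generated so far, replicated on every rank.
    """
    if dist.get_rank() == last_rank:
        # The last stage owns the logits: pick the next token from the final position.
        next_tokens = scores[:, -1].float().argmax(dim=-1, keepdim=True)
    else:
        # Other stages allocate a buffer of matching shape/dtype for the broadcast to fill.
        next_tokens = tokens.new_zeros([tokens.size(0), 1])
    # After this call every pipeline rank feeds the same token into the next forward pass.
    dist.broadcast(next_tokens, src=last_rank, group=pp_group)
    return torch.cat([tokens, next_tokens], dim=1)

The sampled (do_sample) and beam-search variants in the patch follow the same shape: the last rank additionally broadcasts next_scores and from_which_beam so that beam reordering stays identical on every stage.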