feat(inference): support generation using trainer (#230)
KimmiShi authored Jun 11, 2024
1 parent e96450a commit eaef99e
Showing 10 changed files with 818 additions and 231 deletions.
6 changes: 6 additions & 0 deletions internlm/apis/__init__.py
@@ -0,0 +1,6 @@
from .inference_utils import InferenceParams, process_parallel_output

__all__ = [
"InferenceParams",
"process_parallel_output",
]
603 changes: 379 additions & 224 deletions internlm/apis/inference.py

Large diffs are not rendered by default.

69 changes: 69 additions & 0 deletions internlm/apis/inference_utils.py
@@ -0,0 +1,69 @@
import torch

from internlm.core.context import ParallelMode # noqa: E402
from internlm.core.context import global_context as gpc # noqa: E402
from internlm.core.parallel.comm.utils import _gather as gather


class InferenceParams:
"""
    Intermediate cache object for inference (KV cache and decoding offsets)
"""

def __init__(
self,
max_sequence_len,
max_batch_size,
sequence_len_offset=0,
batch_size_offset=0,
key_value_memory_dict: dict = None,
lengths_per_sample=None,
attention_mask=None,
window_size=None,
) -> None:

self.max_sequence_len: int = max_sequence_len
self.max_batch_size: int = max_batch_size
self.sequence_len_offset: int = sequence_len_offset
self.batch_size_offset: int = batch_size_offset
if key_value_memory_dict is None:
key_value_memory_dict = {}
self.key_value_memory_dict: dict = key_value_memory_dict
self.fused_ft_kernel: bool = False
self.lengths_per_sample = lengths_per_sample
self.attention_mask = attention_mask
self.full_attention_mask = attention_mask
self.window_size = window_size

def reorder_state(self, indices):
if self.lengths_per_sample is not None:
self.lengths_per_sample = self.lengths_per_sample.index_select(index=indices, dim=0)
for key, value in list(self.key_value_memory_dict.items()):
value = value.index_select(index=indices, dim=0)
self.key_value_memory_dict[key] = value

def set_batch_offset(self, offset, bsz):
"""Called by `BaseScheduler._load_micro_batch`.
        When micro-batching is enabled, the working attention mask is only a view of `full_attention_mask`.
"""
self.batch_size_offset = offset
self.attention_mask = self.full_attention_mask[offset : offset + bsz]

def set_attention_mask(self, mask):
"""useful when generate using Engine/trainer rather than directly using model"""
self.full_attention_mask = mask


def process_parallel_output(model_output):
    # 1. concatenate the output chunks on the last pipeline stage
if gpc.is_last_rank(ParallelMode.PIPELINE):
if not isinstance(model_output, torch.Tensor):
model_output = torch.cat(model_output, dim=0)
else:
return None

    # 2. gather tensor-parallel output along the last dimension
if gpc.config.model.parallel_output and gpc.is_initialized(ParallelMode.TENSOR):
return gather(model_output, ParallelMode.TENSOR, -1)
else:
return model_output
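
To make the cache bookkeeping above concrete, here is a minimal, hypothetical decode-loop sketch. It exercises only the offset and cache-reordering logic of InferenceParams; the per-layer cache tensor, its (batch, seq, heads, head_dim) layout, and the step counts are made up, and no real model or process_parallel_output call is involved.

import torch

from internlm.apis import InferenceParams

batch_size, max_len, prompt_len = 2, 32, 8

params = InferenceParams(max_sequence_len=max_len, max_batch_size=batch_size)
params.set_attention_mask(torch.ones(batch_size, max_len, dtype=torch.bool))

# Hypothetical per-layer KV cache buffer: (batch, seq, heads, head_dim).
params.key_value_memory_dict[0] = torch.zeros(batch_size, max_len, 4, 16)

# Prefill: the whole prompt is consumed in one pass, so the offset jumps to prompt_len.
params.sequence_len_offset = prompt_len

# Decode: one token per step, the offset advances by one each time.
for _ in range(3):
    params.sequence_len_offset += 1

# Beam-search style reordering: keep the cache rows for beams [1, 0].
params.reorder_state(torch.tensor([1, 0]))
print(params.sequence_len_offset, params.key_value_memory_dict[0].shape)  # 11, (2, 32, 4, 16)
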
5 changes: 5 additions & 0 deletions internlm/core/engine.py
@@ -93,6 +93,11 @@ def criterion(self):
"""Returns the criterion (loss function) attached to the engine."""
return self._criterion

@criterion.setter
def criterion(self, criterion):
"""Sets the criterion (loss function)."""
self._criterion = criterion

def _all_reduce_gradients(self):
"""Handles all-reduce operations of gradients across different parallel groups."""
for handler in self._gradient_handlers:
25 changes: 21 additions & 4 deletions internlm/core/scheduler/base_scheduler.py
@@ -8,6 +8,7 @@

import torch

from internlm.apis import InferenceParams
from internlm.core.engine import Engine


@@ -44,10 +45,26 @@ def _load_micro_batch(self, data: Dict, label: torch.Tensor, offset: int, bsz_st
so the data of batch is unpacked and 'bsz_stride' is equal to 'micro_bsz'.
In all other cases 'bsz_stride' should be equal to 1.
"""
assert isinstance(data, dict) and isinstance(label, torch.Tensor)
micro_batch_data = {k: v[offset : offset + bsz_stride] for k, v in data.items()}
micro_batch_label = label[offset : offset + bsz_stride]

assert isinstance(data, dict)

micro_batch_data = {}
for k, v in data.items():
if isinstance(v, torch.Tensor):
micro_batch_data[k] = v[offset : offset + bsz_stride]
elif isinstance(v, InferenceParams):
v.set_batch_offset(offset, bsz_stride)
micro_batch_data[k] = v
elif isinstance(v, (list, tuple)):
micro_batch_data[k] = v[offset : offset + bsz_stride]
else:
raise NotImplementedError(f"value of type {type(v)} is not supported")

if isinstance(label, torch.Tensor):
micro_batch_label = label[offset : offset + bsz_stride]
elif isinstance(label, Dict):
micro_batch_label = {k: v[offset : offset + bsz_stride] if v.dim() > 0 else v for k, v in label.items()}
else:
micro_batch_label = label
return micro_batch_data, micro_batch_label

@abstractmethod
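
The dispatch above is what lets generation run through the scheduler: tensors and Python sequences are sliced per micro-batch, while a shared InferenceParams object is merely re-pointed at the right rows via set_batch_offset. The sketch below re-implements that dispatch as a standalone function purely for illustration (it does not call BaseScheduler, and the batch keys and shapes are made up).

import torch

from internlm.apis import InferenceParams


def split_micro_batch(data: dict, offset: int, bsz_stride: int) -> dict:
    # Standalone illustration of the dispatch in BaseScheduler._load_micro_batch.
    micro = {}
    for k, v in data.items():
        if isinstance(v, torch.Tensor):
            micro[k] = v[offset : offset + bsz_stride]  # slice tensors per micro-batch
        elif isinstance(v, InferenceParams):
            v.set_batch_offset(offset, bsz_stride)  # re-point the working attention mask
            micro[k] = v
        elif isinstance(v, (list, tuple)):
            micro[k] = v[offset : offset + bsz_stride]  # slice python sequences
        else:
            raise NotImplementedError(f"value of type {type(v)} is not supported")
    return micro


# Made-up batch of 4 samples split into micro-batches of 2.
params = InferenceParams(max_sequence_len=16, max_batch_size=4)
params.set_attention_mask(torch.ones(4, 16, dtype=torch.bool))
batch = {"input_ids": torch.arange(4 * 16).reshape(4, 16), "inference_params": params}

micro = split_micro_batch(batch, offset=2, bsz_stride=2)
print(micro["input_ids"].shape)                        # torch.Size([2, 16])
print(micro["inference_params"].attention_mask.shape)  # torch.Size([2, 16])
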
1 change: 0 additions & 1 deletion internlm/core/scheduler/no_pipeline_scheduler.py
@@ -175,7 +175,6 @@ def forward_backward_step(
If True, the model is run for the forward pass, else back propagation will be executed.
return_loss (bool, optional): Loss will be returned if True.
return_output_label (bool, optional): Output and label will be returned if True.
Returns:
Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None.
"""
2 changes: 1 addition & 1 deletion internlm/core/scheduler/pipeline_scheduler.py
@@ -79,7 +79,7 @@ def pack_return_tensors(return_tensors):
raise TypeError("Output of model must be tensor or list/tuple of tensors")
if isinstance(label[0], torch.Tensor):
label = torch.cat(label, dim=0)
else:
elif isinstance(label[0], dict):
merged_label = {k: [] for k in label[0].keys()}
for d in label:
for k, v in d.items():
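
The new elif lets the pipeline scheduler collect labels that are dictionaries (one per micro-batch) instead of plain tensors. The snippet below is a standalone illustration of that merge with made-up keys, not the scheduler's exact code path; the final per-key concatenation is an assumption, since the rest of the function is not shown in this diff.

import torch

# Labels returned by two micro-batches, one dict each (keys are made up).
label = [
    {"target_ids": torch.tensor([[1, 2]]), "weights": torch.tensor([1.0])},
    {"target_ids": torch.tensor([[3, 4]]), "weights": torch.tensor([0.5])},
]

# Group values per key, as in the new branch above.
merged_label = {k: [] for k in label[0].keys()}
for d in label:
    for k, v in d.items():
        merged_label[k].append(v)

# Presumably each key is then concatenated along the batch dimension (not shown in the diff).
merged_label = {k: torch.cat(v, dim=0) for k, v in merged_label.items()}
print(merged_label["target_ids"].shape)  # torch.Size([2, 2])
print(merged_label["weights"].shape)     # torch.Size([2])
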
4 changes: 3 additions & 1 deletion internlm/utils/common.py
@@ -53,7 +53,9 @@ def move_to_device(data):
data = [move_to_device(x) for x in data]
elif isinstance(data, dict):
data = {k: move_to_device(v) for k, v in data.items()}

else:
        # other types (scalars, parameter objects, etc.) are returned unchanged.
return data
return data


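
With the new else branch, move_to_device can recurse into a generation batch that mixes tensors with scalars, strings, or helper objects and pass the non-tensor values through untouched. A small hypothetical example of the intended behaviour (it assumes a GPU or other supported accelerator is available, since the helper moves tensors to the current device):

import torch

from internlm.utils.common import move_to_device

# A generation-style batch: tensors plus non-tensor bookkeeping values.
batch = {
    "input_ids": torch.randint(0, 100, (2, 8)),
    "max_length": 32,       # plain int: returned unchanged by the new branch
    "prompts": ["a", "b"],  # list of strings: each element passed through as-is
}

moved = move_to_device(batch)
print(moved["input_ids"].device, moved["max_length"], moved["prompts"])
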
133 changes: 133 additions & 0 deletions tests/test_infer/test_generate.py
@@ -0,0 +1,133 @@
import os

import pytest
import torch
from sentencepiece import SentencePieceProcessor

from internlm.apis.inference import SequenceGenerator, batch_tokenize
from internlm.initialize import initialize_distributed_env # noqa: E402
from internlm.train import initialize_model, initialize_parallel_communicator


def set_seed(seed: int = 1024):
import random

import numpy as np

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)


def load_and_generate(path, model_type="INTERNLM2_PUBLIC", tokenizer_path=""):
model_cfg = os.path.join(path, "model_config.pt")
model_wt = os.path.join(path, "model_tp0_pp0.pt")
model_config = torch.load(model_cfg)
model_config["apply_post_layer_norm"] = False
if model_config.get("adapt_hf") is not None:
model_config.pop("adapt_hf")
evo_cfg = dict(
model_type=model_type,
model=model_config,
parallel=dict(
zero1=dict(size=1, fsdp=False),
pipeline=dict(size=1, interleaved_overlap=True),
tensor=dict(size=1, mode="mtp"),
sequence_parallel=0,
),
)
initialize_distributed_env(evo_cfg, master_port=23574, args_check=False)

tokenizer = SentencePieceProcessor(tokenizer_path) # pylint: disable=E1121

def convert_to_str(output_ids):
output_tokens = output_ids.tolist()
all_output_str = []
for b in range(len(output_tokens)):
for sent_idx in range(len(output_tokens[b])):
cur_output_tokens = output_tokens[b][sent_idx]
cur_sent = tokenizer.decode(cur_output_tokens)
all_output_str.append(cur_sent)
return all_output_str

model = initialize_model()
_ = initialize_parallel_communicator(model)
# Directly get the origin model without NativeAMP wrapper.
model = model.model

state_dict = torch.load(model_wt)
load_info = model.load_state_dict(state_dict, strict=False)
print(load_info)

    sequence_generator = SequenceGenerator(
decoder=model,
eos_token_id=tokenizer.eos_id(),
pad_token_id=tokenizer.bos_id(),
bos_token_id=tokenizer.bos_id(),
additional_eos_token_list=None,
)

test_prompt_0 = "Gold is considered to be a precious metal."
test_prompt_1 = "what is love? someone think it is a feeling, someone think it is a chemical reaction."
test_prompt_2 = "kobe bryant is a basketball player."

prompt_3 = [
test_prompt_0,
test_prompt_1,
test_prompt_2,
]
prompt_2 = [
test_prompt_0,
test_prompt_1,
]

prompt_1 = [test_prompt_0]

def generate(prompt):
input_ids = batch_tokenize(prompt, tokenizer, pad_token_id=tokenizer.bos_id()).cuda()
generate_kwargs = {}
set_seed()
        output_ids = sequence_generator.generate(
input_ids,
num_return_sequences=generate_kwargs.get("num_return_sequences", 1),
max_length=generate_kwargs.get("max_length", input_ids.shape[1] + 80),
num_beams=generate_kwargs.get("num_beams", 1),
do_sample=generate_kwargs.get("do_sample", False),
temperature=generate_kwargs.get("temperature", 1.0),
top_k=generate_kwargs.get("top_k", 50),
top_p=generate_kwargs.get("top_p", 1.0),
repetition_penalty=generate_kwargs.get("repetition_penalty", 1),
            length_penalty=generate_kwargs.get("length_penalty", 1.0),
)

all_output_str = convert_to_str(output_ids)
return all_output_str

output_3 = generate(prompt_3)
output_2 = generate(prompt_2)
output_1 = generate(prompt_1)

assert output_3[0] == output_2[0]
assert output_3[1] == output_2[1]
assert (
output_1[0]
== "Gold is considered to be a precious metal. It is a metal that is highly valued for its \
rarity and beauty. Gold is often used in jewelry, coins, and other decorative items. It is also used in \
the production of electronics and other high-tech products. Gold is a highly sought-after metal because \
of its ability to resist corrosion and tarnish. It is also highly resistant to fire and is a good conductor \
of heat and electricity.\n"
)
print("test generate done!")


def test_internlm2_1_8B_generate():
base_model_dir = os.environ.get("qa_data")
if base_model_dir is not None:
model_dir = os.path.join(base_model_dir, "internlm2_1_8B")
tokenizer_path = os.path.join(base_model_dir, "InternLM_CI_assets/v13.model")
if os.path.exists(model_dir) and os.path.exists(tokenizer_path):
load_and_generate(model_dir, tokenizer_path=tokenizer_path)


if __name__ == "__main__":
pytest.main(["-s", "-q", "-v", "test_generate.py"])
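
For a local run, the new test only executes when the qa_data environment variable points at the checkpoint assets it expects (an internlm2_1_8B/ checkpoint directory and InternLM_CI_assets/v13.model); otherwise it is effectively skipped. A hypothetical invocation sketch, with a placeholder path:

import os

import pytest

os.environ["qa_data"] = "/path/to/qa_assets"  # placeholder: directory holding the model and tokenizer
pytest.main(["-s", "-q", "-v", "tests/test_infer/test_generate.py"])
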
