diff --git a/Dockerfile b/Dockerfile index 6eb3c1f241a..f4430a339e2 100644 --- a/Dockerfile +++ b/Dockerfile @@ -61,6 +61,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins make \ curl \ git \ + python3.11-dev \ && rm -rf /var/lib/apt/lists/* # Install server @@ -96,5 +97,5 @@ FROM base COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh RUN chmod +x /tgi-entrypoint.sh -ENTRYPOINT ["/tgi-entrypoint.sh"] +#ENTRYPOINT ["/tgi-entrypoint.sh"] # CMD ["--json-output"] diff --git a/server/Makefile b/server/Makefile index 9338b299090..e906458a189 100644 --- a/server/Makefile +++ b/server/Makefile @@ -4,10 +4,6 @@ include Makefile-vllm include Makefile-awq include Makefile-eetq include Makefile-selective-scan -include Makefile-lorax-punica -include Makefile-fbgemm -include Makefile-exllamav2 -include Makefile-flashinfer unit-tests: pytest -s -vv -m "not private" tests @@ -21,25 +17,20 @@ gen-server: find text_generation_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \; touch text_generation_server/pb/__init__.py -install-server: gen-server +install: gen-server pip install pip --upgrade - pip install -r requirements_cuda.txt - pip install -e ".[accelerate, quantize, peft, outlines]" - - -install: install-cuda - echo "Installed server" - -install-cuda: install-server install-flash-attention-v2-cuda install-vllm-cuda install-flash-attention install-fbgemm - pip install -e ".[bnb]" - pip install nvidia-nccl-cu12==2.22.3 - -install-rocm: install-server install-flash-attention-v2-rocm install-vllm-rocm + pip install -r requirements.txt + pip install -e "." run-dev: SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded +install-poetry: + curl -sSL https://install.python-poetry.org | python3 - + +update-lock: + rm poetry.lock + poetry lock --no-update + export-requirements: - poetry export -o requirements_cuda.txt --without-hashes - poetry export -o requirements_rocm.txt --without-hashes - poetry export -o requirements_intel.txt --without-hashes + poetry export -o requirements.txt --without-hashes diff --git a/server/dill-0.3.7-patch.sh b/server/dill-0.3.7-patch.sh new file mode 100644 index 00000000000..ad8c8be589d --- /dev/null +++ b/server/dill-0.3.7-patch.sh @@ -0,0 +1,91 @@ +#!/bin/bash +git clone -b dill-0.3.7 https://github.com/uqfoundation/dill.git +pushd dill +cat < dill-0.3.7.patch +diff --git a/dill/_dill.py b/dill/_dill.py +index d0cf543..f6eb662 100644 +--- a/dill/_dill.py ++++ b/dill/_dill.py +@@ -69,7 +69,15 @@ TypeType = type # 'new-style' classes #XXX: unregistered + XRangeType = range + from types import MappingProxyType as DictProxyType, new_class + from pickle import DEFAULT_PROTOCOL, HIGHEST_PROTOCOL, PickleError, PicklingError, UnpicklingError +-import __main__ as _main_module ++class _LazyMainModule(object): ++ _module = None ++ @property ++ def module(self): ++ if self._module is None: ++ import __main__ as _m_module ++ self._module = _m_module ++ return self._module ++_main_module = _LazyMainModule() + import marshal + import gc + # import zlib +@@ -353,7 +361,7 @@ class Pickler(StockPickler): + _fmode = kwds.pop('fmode', None) + _recurse = kwds.pop('recurse', None) + StockPickler.__init__(self, file, *args, **kwds) +- self._main = _main_module ++ self._main = _main_module.module + self._diff_cache = {} + self._byref = settings['byref'] if _byref is None else _byref + self._strictio = False #_strictio +@@ -435,12 +443,12 @@ 
class Unpickler(StockUnpickler): + settings = Pickler.settings + _ignore = kwds.pop('ignore', None) + StockUnpickler.__init__(self, *args, **kwds) +- self._main = _main_module ++ self._main = _main_module.module + self._ignore = settings['ignore'] if _ignore is None else _ignore + + def load(self): #NOTE: if settings change, need to update attributes + obj = StockUnpickler.load(self) +- if type(obj).__module__ == getattr(_main_module, '__name__', '__main__'): ++ if type(obj).__module__ == getattr(self._main, '__name__', '__main__'): + if not self._ignore: + # point obj class to main + try: obj.__class__ = getattr(self._main, type(obj).__name__) +@@ -1194,11 +1202,11 @@ def save_module_dict(pickler, obj): + logger.trace(pickler, "D1: %s", _repr_dict(obj)) # obj + pickler.write(bytes('c__builtin__\n__main__\n', 'UTF-8')) + logger.trace(pickler, "# D1") +- elif (not is_dill(pickler, child=False)) and (obj == _main_module.__dict__): ++ elif (not is_dill(pickler, child=False)) and (obj == _main_module.module.__dict__): + logger.trace(pickler, "D3: %s", _repr_dict(obj)) # obj + pickler.write(bytes('c__main__\n__dict__\n', 'UTF-8')) #XXX: works in general? + logger.trace(pickler, "# D3") +- elif '__name__' in obj and obj != _main_module.__dict__ \\ ++ elif '__name__' in obj and obj != _main_module.module.__dict__ \\ + and type(obj['__name__']) is str \\ + and obj is getattr(_import_module(obj['__name__'],True), '__dict__', None): + logger.trace(pickler, "D4: %s", _repr_dict(obj)) # obj +diff --git a/dill/session.py b/dill/session.py +index 74234ab..1be8d89 100644 +--- a/dill/session.py ++++ b/dill/session.py +@@ -233,7 +233,7 @@ def dump_module( + protocol = settings['protocol'] + main = module + if main is None: +- main = _main_module ++ main = _main_module.module + elif isinstance(main, str): + main = _import_module(main) + if not isinstance(main, ModuleType): +@@ -501,7 +501,7 @@ def load_module( + pass + assert loaded is main + _restore_modules(unpickler, main) +- if main is _main_module or main is module: ++ if main is _main_module.module or main is module: + return None + else: + return main + +EOF +git apply dill-0.3.7.patch +python -m pip install . 
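Both dill patch scripts (the 0.3.7 one here and the 0.3.8 one below) make the same change: the module-level `import __main__ as _main_module` is replaced with a lazy proxy, so `__main__` is resolved at the moment dill actually pickles or unpickles, presumably to avoid capturing a stale `__main__` from whenever dill happened to be imported first. A minimal standalone sketch of that pattern (illustrative only, not part of the patch):

    # Lazy __main__ proxy: the real import happens on first attribute access.
    class LazyMainModule:
        _module = None

        @property
        def module(self):
            if self._module is None:
                import __main__ as _m   # resolved now, not at class definition time
                self._module = _m
            return self._module

    _main = LazyMainModule()
    # Accessing _main.module here returns whatever module is currently
    # installed as __main__ in this interpreter.
    print(_main.module.__name__)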
+popd +rm -fr dill diff --git a/server/dill-0.3.8-patch.sh b/server/dill-0.3.8-patch.sh new file mode 100644 index 00000000000..da263960f6b --- /dev/null +++ b/server/dill-0.3.8-patch.sh @@ -0,0 +1,91 @@ +#!/bin/bash +git clone -b 0.3.8 https://github.com/uqfoundation/dill.git +pushd dill +cat < dill-0.3.8.patch +diff --git a/dill/_dill.py b/dill/_dill.py +index d42432f..1d251e6 100644 +--- a/dill/_dill.py ++++ b/dill/_dill.py +@@ -69,7 +69,15 @@ TypeType = type # 'new-style' classes #XXX: unregistered + XRangeType = range + from types import MappingProxyType as DictProxyType, new_class + from pickle import DEFAULT_PROTOCOL, HIGHEST_PROTOCOL, PickleError, PicklingError, UnpicklingError +-import __main__ as _main_module ++class _LazyMainModule(object): ++ _module = None ++ @property ++ def module(self): ++ if self._module is None: ++ import __main__ as _m_module ++ self._module = _m_module ++ return self._module ++_main_module = _LazyMainModule() + import marshal + import gc + # import zlib +@@ -355,7 +363,7 @@ class Pickler(StockPickler): + _fmode = kwds.pop('fmode', None) + _recurse = kwds.pop('recurse', None) + StockPickler.__init__(self, file, *args, **kwds) +- self._main = _main_module ++ self._main = _main_module.module + self._diff_cache = {} + self._byref = settings['byref'] if _byref is None else _byref + self._strictio = False #_strictio +@@ -437,12 +445,12 @@ class Unpickler(StockUnpickler): + settings = Pickler.settings + _ignore = kwds.pop('ignore', None) + StockUnpickler.__init__(self, *args, **kwds) +- self._main = _main_module ++ self._main = _main_module.module + self._ignore = settings['ignore'] if _ignore is None else _ignore + + def load(self): #NOTE: if settings change, need to update attributes + obj = StockUnpickler.load(self) +- if type(obj).__module__ == getattr(_main_module, '__name__', '__main__'): ++ if type(obj).__module__ == getattr(self._main, '__name__', '__main__'): + if not self._ignore: + # point obj class to main + try: obj.__class__ = getattr(self._main, type(obj).__name__) +@@ -1199,11 +1207,11 @@ def save_module_dict(pickler, obj): + logger.trace(pickler, "D1: %s", _repr_dict(obj)) # obj + pickler.write(bytes('c__builtin__\n__main__\n', 'UTF-8')) + logger.trace(pickler, "# D1") +- elif (not is_dill(pickler, child=False)) and (obj == _main_module.__dict__): ++ elif (not is_dill(pickler, child=False)) and (obj == _main_module.module.__dict__): + logger.trace(pickler, "D3: %s", _repr_dict(obj)) # obj + pickler.write(bytes('c__main__\n__dict__\n', 'UTF-8')) #XXX: works in general? 
+ logger.trace(pickler, "# D3") +- elif '__name__' in obj and obj != _main_module.__dict__ \\ ++ elif '__name__' in obj and obj != _main_module.module.__dict__ \\ + and type(obj['__name__']) is str \\ + and obj is getattr(_import_module(obj['__name__'],True), '__dict__', None): + logger.trace(pickler, "D4: %s", _repr_dict(obj)) # obj +diff --git a/dill/session.py b/dill/session.py +index e91068a..a921b43 100644 +--- a/dill/session.py ++++ b/dill/session.py +@@ -233,7 +233,7 @@ def dump_module( + protocol = settings['protocol'] + main = module + if main is None: +- main = _main_module ++ main = _main_module.module + elif isinstance(main, str): + main = _import_module(main) + if not isinstance(main, ModuleType): +@@ -501,7 +501,7 @@ def load_module( + pass + assert loaded is main + _restore_modules(unpickler, main) +- if main is _main_module or main is module: ++ if main is _main_module.module or main is module: + return None + else: + return main + +EOF +git apply dill-0.3.8.patch +python -m pip install . +popd +rm -fr dill \ No newline at end of file diff --git a/server/pyproject.toml b/server/pyproject.toml index 6bdd238591c..c1809550f12 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "text-generation-server" -version = "2.0.5-dev0" +version = "2.0.4" description = "Text Generation Inference Python gRPC Server" authors = ["Olivier Dehaene "] @@ -9,76 +9,34 @@ text-generation-server = 'text_generation_server.cli:app' [tool.poetry.dependencies] python = ">=3.9,<3.13" -protobuf = "^4.25.3" +protobuf = "^3.20.3" grpcio = "^1.51.1" -grpcio-status = "^1.51.1" -grpcio-reflection = "^1.51.1" +grpcio-status = "*" +grpcio-reflection = "*" grpc-interceptor = "^0.15.0" -typer = "^0.6.1" -accelerate = { version = "^0.29.1", optional = true } -bitsandbytes = { version = "^0.43.0", optional = true } -safetensors = "^0.4" +typer = "^0.7.0" loguru = "^0.6.0" -opentelemetry-api = "^1.25.0" -opentelemetry-exporter-otlp = "^1.25.0" -opentelemetry-instrumentation-grpc = "^0.46b0" +opentelemetry-api = "^1.15.0" +opentelemetry-exporter-otlp = "^1.15.0" +opentelemetry-instrumentation-grpc = "^0.36b0" hf-transfer = "^0.1.2" sentencepiece = "^0.1.97" -tokenizers = "^0.19.1" -huggingface-hub = "^0.23" -transformers = "^4.43" -einops = "^0.6.1" -texttable = { version = "^1.6.7", optional = true } -datasets = { version = "^2.14.0", optional = true } -peft = { version = "^0.10", optional = true } -torch = { version = "^2.4.0", optional = true } -scipy = "^1.11.1" -pillow = "^10.0.0" -outlines= { version = "^0.0.34", optional = true } +peft = "^0.10" +optimum-habana = "1.13.2" +transformers = "4.43.4" +numpy = "1.26.4" +accelerate = "0.33.0" +outlines= { version = "^0.0.36", optional = true } prometheus-client = "^0.20.0" py-cpuinfo = "^9.0.0" -# Remove later, temporary workaround for outlines. 
-numpy = "^1.26" - -marlin-kernels = [ - { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true }, - { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true }, - { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true }, - { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true }, -] -moe-kernels = [ - { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.3.1/moe_kernels-0.3.1+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true }, - { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.3.1/moe_kernels-0.3.1+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true }, - { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.3.1/moe_kernels-0.3.1+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true }, - { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.3.1/moe_kernels-0.3.1+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true }, -] -rich = "^13.7.1" - -[tool.poetry.extras] -torch = ["torch"] -accelerate = ["accelerate"] -bnb = ["bitsandbytes"] -marlin = ["marlin-kernels"] -moe = ["moe-kernels"] -peft = ["peft"] -quantize = ["texttable", "datasets", "accelerate"] -outlines = ["outlines"] [tool.poetry.group.dev.dependencies] -grpcio-tools = "^1.51.1" +grpcio-tools = "*" pytest = "^7.3.0" - -[[tool.poetry.source]] -name = "pytorch-gpu-src" -url = "https://download.pytorch.org/whl/cu121" -priority = "explicit" - [tool.pytest.ini_options] markers = ["private: marks tests as requiring an admin hf token (deselect with '-m \"not private\"')"] [build-system] -requires = [ - "poetry-core>=1.0.0", -] +requires = ["poetry-core>=1.0.0"] build-backend = "poetry.core.masonry.api" diff --git a/server/requirements.txt b/server/requirements.txt new file mode 100644 index 00000000000..0a091a2fe85 --- /dev/null +++ b/server/requirements.txt @@ -0,0 +1,88 @@ +accelerate==0.33.0 ; python_version >= "3.9" and python_version < "3.13" +aiohappyeyeballs==2.4.0 ; python_version >= "3.9" and python_version < "3.13" +aiohttp==3.10.5 ; python_version >= "3.9" and python_version < "3.13" +aiosignal==1.3.1 ; python_version >= "3.9" and python_version < "3.13" +async-timeout==4.0.3 ; python_version >= "3.9" and python_version < "3.11" +attrs==24.2.0 ; python_version >= "3.9" and python_version < "3.13" +backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13" +certifi==2024.7.4 ; python_version >= "3.9" and python_version < "3.13" +charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13" +click==8.1.7 ; python_version >= "3.9" and python_version < "3.13" +colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows") +coloredlogs==15.0.1 ; python_version >= "3.9" and python_version < "3.13" +datasets==2.21.0 ; python_version >= "3.9" and python_version < "3.13" +deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13" +diffusers==0.29.2 ; python_version >= "3.9" and 
python_version < "3.13" +dill==0.3.8 ; python_version >= "3.9" and python_version < "3.13" +filelock==3.15.4 ; python_version >= "3.9" and python_version < "3.13" +frozenlist==1.4.1 ; python_version >= "3.9" and python_version < "3.13" +fsspec==2024.6.1 ; python_version >= "3.9" and python_version < "3.13" +fsspec[http]==2024.6.1 ; python_version >= "3.9" and python_version < "3.13" +googleapis-common-protos==1.63.2 ; python_version >= "3.9" and python_version < "3.13" +grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13" +grpcio-reflection==1.48.2 ; python_version >= "3.9" and python_version < "3.13" +grpcio-status==1.48.2 ; python_version >= "3.9" and python_version < "3.13" +grpcio==1.66.0 ; python_version >= "3.9" and python_version < "3.13" +hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13" +huggingface-hub==0.24.6 ; python_version >= "3.9" and python_version < "3.13" +humanfriendly==10.0 ; python_version >= "3.9" and python_version < "3.13" +idna==3.8 ; python_version >= "3.9" and python_version < "3.13" +importlib-metadata==8.4.0 ; python_version >= "3.9" and python_version < "3.13" +jinja2==3.1.4 ; python_version >= "3.9" and python_version < "3.13" +joblib==1.4.2 ; python_version >= "3.9" and python_version < "3.13" +loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13" +markupsafe==2.1.5 ; python_version >= "3.9" and python_version < "3.13" +mpmath==1.3.0 ; python_version >= "3.9" and python_version < "3.13" +multidict==6.0.5 ; python_version >= "3.9" and python_version < "3.13" +multiprocess==0.70.16 ; python_version >= "3.9" and python_version < "3.13" +networkx==3.2.1 ; python_version >= "3.9" and python_version < "3.13" +numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13" +optimum-habana==1.13.2 ; python_version >= "3.9" and python_version < "3.13" +optimum==1.21.4 ; python_version >= "3.9" and python_version < "3.13" +packaging==24.1 ; python_version >= "3.9" and python_version < "3.13" +pandas==2.2.2 ; python_version >= "3.9" and python_version < "3.13" +peft==0.10.0 ; python_version >= "3.9" and python_version < "3.13" +pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13" +prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13" +protobuf==3.20.3 ; python_version >= "3.9" and python_version < "3.13" +psutil==6.0.0 ; python_version >= "3.9" and python_version < "3.13" +py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13" +pyarrow==17.0.0 ; python_version >= "3.9" and python_version < "3.13" +pyreadline3==3.4.1 ; sys_platform == "win32" and python_version >= "3.9" and python_version < "3.13" +python-dateutil==2.9.0.post0 ; 
python_version >= "3.9" and python_version < "3.13" +pytz==2024.1 ; python_version >= "3.9" and python_version < "3.13" +pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13" +regex==2024.7.24 ; python_version >= "3.9" and python_version < "3.13" +requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13" +safetensors==0.4.4 ; python_version >= "3.9" and python_version < "3.13" +scikit-learn==1.5.1 ; python_version >= "3.9" and python_version < "3.13" +scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13" +sentence-transformers[train]==3.0.1 ; python_version >= "3.9" and python_version < "3.13" +sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13" +setuptools==73.0.1 ; python_version >= "3.9" and python_version < "3.13" +six==1.16.0 ; python_version >= "3.9" and python_version < "3.13" +sympy==1.12.1 ; python_version >= "3.9" and python_version < "3.13" +threadpoolctl==3.5.0 ; python_version >= "3.9" and python_version < "3.13" +tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13" +tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13" +transformers==4.43.4 ; python_version >= "3.9" and python_version < "3.13" +transformers[sentencepiece]==4.43.4 ; python_version >= "3.9" and python_version < "3.13" +triton==3.0.0 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version < "3.13" and python_version >= "3.9" +typer==0.7.0 ; python_version >= "3.9" and python_version < "3.13" +typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13" +tzdata==2024.1 ; python_version >= "3.9" and python_version < "3.13" +urllib3==2.2.2 ; python_version >= "3.9" and python_version < "3.13" +win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32" +wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13" +xxhash==3.5.0 ; python_version >= "3.9" and python_version < "3.13" +yarl==1.9.4 ; python_version >= "3.9" and python_version < "3.13" +zipp==3.20.0 ; python_version >= "3.9" and python_version < "3.13" diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py index 10aa3a3b2b4..baf94986b5a 100644 --- a/server/text_generation_server/cli.py +++ b/server/text_generation_server/cli.py @@ -47,9 +47,9 @@ def serve( max_input_tokens: Optional[int] = None, ): if sharded: - assert ( - os.getenv("RANK", None) is not None - ), "RANK must be set when sharded is True" + # assert ( + # os.getenv("RANK", None) is not None + # ), "RANK must be set when sharded is True" assert ( os.getenv("WORLD_SIZE", None) is not None ), "WORLD_SIZE must be set when sharded is True" @@ -96,7 +96,7 @@ def serve( # Downgrade enum into str for easier management later on quantize = None if quantize is None else quantize.value - dtype = None if dtype is None else dtype.value + dtype = "bfloat16" if dtype is None else dtype.value if dtype is not None and quantize not in { None, "bitsandbytes", @@ -106,18 +106,76 @@ def serve( raise RuntimeError( "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model." 
) - server.serve( - model_id, - lora_adapters, - revision, - sharded, - quantize, - speculate, - dtype, - trust_remote_code, - uds_path, - max_input_tokens, - ) + + logger.info("CLI SHARDED = {} DTYPE = {}".format(sharded, dtype)) + + if sharded: + tgi_file = Path(__file__).resolve().parent / "tgi_service.py" + num_shard = int(os.getenv("WORLD_SIZE", "1")) + logger.info("CLI SHARDED = {}".format(num_shard)) + import subprocess + + cmd = f"deepspeed --num_nodes 1 --num_gpus {num_shard} --no_local_rank {tgi_file}" + cmd += f" --model_id {model_id} --revision {revision} --sharded {sharded}" + cmd += f" --dtype {dtype} --trust_remote_code {trust_remote_code} --uds_path {uds_path}" + cmd += f" --quantize {quantize} --max_input_tokens {max_input_tokens}" + if speculate is not None: + cmd += f"--speculate {speculate}" + logger.info("CLI server start deepspeed ={} ".format(cmd)) + sys.stdout.flush() + sys.stderr.flush() + with subprocess.Popen(cmd, shell=True, executable="/bin/bash") as proc: + do_terminate = False + current_handler = signal.getsignal(signal.SIGTERM) + def terminate_handler(sig, frame): + nonlocal do_terminate + do_terminate = True + if callable(current_handler): + current_handler(sig, frame) + + signal.signal(signal.SIGTERM, terminate_handler) + + finished = False + while not finished: + try: + if do_terminate: + parent = psutil.Process(proc.pid) + all_procs = parent.children(recursive=True) + [parent] + for p in all_procs: + try: + p.terminate() + except psutil.NoSuchProcess: + pass + _, alive = psutil.wait_procs(all_procs, timeout=30) + for p in alive: + p.kill() + + do_terminate = False + + proc.wait(timeout=3) + except subprocess.TimeoutExpired: + pass + else: + finished = True + + sys.stdout.flush() + sys.stderr.flush() + if proc.returncode != 0: + logger.error(f"{cmd} exited with status = {proc.returncode}") + return proc.returncode + else: + server.serve( + model_id, + lora_adapters, + revision, + sharded, + quantize, + speculate, + dtype, + trust_remote_code, + uds_path, + max_input_tokens, + ) @app.command() diff --git a/server/text_generation_server/habana_quantization_env.py b/server/text_generation_server/habana_quantization_env.py new file mode 100644 index 00000000000..3c06fd098b1 --- /dev/null +++ b/server/text_generation_server/habana_quantization_env.py @@ -0,0 +1,27 @@ +# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. 
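The sharded branch added to `serve` above launches DeepSpeed in a shell subprocess and installs a SIGTERM handler that tears down the whole process tree with psutil, escalating to SIGKILL after a grace period. A condensed, standalone sketch of that shutdown pattern (hypothetical helper name; simplified in that the patch itself only sets a flag in the handler and does the teardown in its polling loop, while this sketch does it directly; assumes only `psutil` is installed):

    import signal
    import subprocess

    import psutil


    def run_and_forward_sigterm(cmd: str, grace: float = 30.0) -> int:
        """Run cmd in a shell and, on SIGTERM, terminate it and all its children."""
        with subprocess.Popen(cmd, shell=True, executable="/bin/bash") as proc:

            def _terminate(sig, frame):
                parent = psutil.Process(proc.pid)
                procs = parent.children(recursive=True) + [parent]
                for p in procs:
                    try:
                        p.terminate()          # polite SIGTERM first
                    except psutil.NoSuchProcess:
                        pass
                _, alive = psutil.wait_procs(procs, timeout=grace)
                for p in alive:
                    p.kill()                   # escalate for stragglers

            signal.signal(signal.SIGTERM, _terminate)
            proc.wait()
        return proc.returncode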
+ +import os + +quant_config = os.getenv("QUANT_CONFIG", "") +is_quantization_enabled = quant_config != "" + +if is_quantization_enabled: + os.environ.setdefault("ENABLE_EXPERIMENTAL_FLAGS", "true") + os.environ.setdefault("USE_DEFAULT_QUANT_PARAM", "true") + os.environ.setdefault("UPDATE_GRAPH_OUTPUT_MME", "false") + os.environ.setdefault("ENABLE_CALC_DYNAMIC_RANGE", "false") + os.environ.setdefault( + "UPDATE_MME_OUTPUT_PRECISION_FILTER", "v_proj,matmul_av") + os.environ.setdefault("EXPERIMENTAL_WEIGHT_SHARING", "FALSE") + + +def prepare_model_for_quantization(model): + if is_quantization_enabled: + if os.getenv("USE_INC", "1") != "0": + from neural_compressor.torch.quantization import FP8Config, convert + config = FP8Config.from_json_file(quant_config) + model = convert(model, config) + else: + import habana_quantization_toolkit + habana_quantization_toolkit.prep_model(model) + return model diff --git a/server/text_generation_server/interceptor.py b/server/text_generation_server/interceptor.py index 57df172575a..05339282b1b 100644 --- a/server/text_generation_server/interceptor.py +++ b/server/text_generation_server/interceptor.py @@ -1,3 +1,5 @@ +# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. + import torch import grpc @@ -6,6 +8,8 @@ from grpc_interceptor.server import AsyncServerInterceptor from loguru import logger from typing import Callable, Any +import traceback +import os class ExceptionInterceptor(AsyncServerInterceptor): @@ -20,6 +24,7 @@ async def intercept( response = method(request_or_iterator, context) return await response except Exception as err: + trace = " " + traceback.format_exc() if os.environ.get('DUMP_STACK') else '' method_name = method_name.split("/")[-1] logger.exception(f"Method {method_name} encountered an error.") @@ -30,8 +35,10 @@ async def intercept( if torch.cuda.is_available(): torch.cuda.empty_cache() + from .utils.debug import dbg_trace + dbg_trace('EXCEPTION', traceback.format_exc()) await context.abort_with_status( rpc_status.to_status( - status_pb2.Status(code=code_pb2.INTERNAL, message=str(err)) + status_pb2.Status(code=code_pb2.INTERNAL, message=str(err) + trace) ) ) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index fc530b38459..d3c7bd8f68e 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -1,319 +1,40 @@ -# ruff: noqa: F821 -# the above line disables the `undefined-name` rule for the model type variables - import torch -import enum import os from loguru import logger from transformers.configuration_utils import PretrainedConfig from transformers.models.auto import modeling_auto from huggingface_hub import hf_hub_download, HfApi -from typing import Optional, List, Dict +from typing import Optional from pathlib import Path +from typing import Optional, List, Dict +# Needed to properly setup habana_frameworks +import text_generation_server.habana_quantization_env as hq_env from text_generation_server.utils.speculate import get_speculate, set_speculate from text_generation_server.models.model import Model -from text_generation_server.models.causal_lm import CausalLM, CausalLMBatchKeysLast -from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM -from text_generation_server.models.custom_modeling.mpt_modeling import ( - MPTForCausalLM, -) -from text_generation_server.models.bloom import BloomCausalLMBatch -from text_generation_server.models.custom_modeling.bloom_modeling 
import ( - BloomForCausalLM, -) -from text_generation_server.models.seq2seq_lm import Seq2SeqLM -from text_generation_server.models.galactica import GalacticaCausalLMBatch -from text_generation_server.models.custom_modeling.neox_modeling import ( - GPTNeoxForCausalLM, -) -from text_generation_server.models.custom_modeling.phi_modeling import ( - PhiConfig, - PhiForCausalLM, -) -from text_generation_server.models.custom_modeling.t5_modeling import ( - T5ForConditionalGeneration, +from text_generation_server.models.causal_lm import CausalLM +#from text_generation_server.models.bloom import BLOOM +from text_generation_server.models.starcoder import StarCoder +from text_generation_server.models.vlm_causal_lm import VlmCausalLM +from text_generation_server.models.custom_modeling.llava_next import ( + LlavaNextForConditionalGeneration, ) - from text_generation_server.utils.adapter import ( AdapterParameters, build_layer_weight_lookup, load_and_merge_adapters, AdapterInfo, ) -from text_generation_server.adapters.lora import LoraWeights - -from text_generation_server.utils.import_utils import SYSTEM -from text_generation_server.utils.log import log_master -# The flag below controls whether to allow TF32 on matmul. This flag defaults to False -# in PyTorch 1.12 and later. -torch.backends.cuda.matmul.allow_tf32 = True +from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi -# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True. -torch.backends.cudnn.allow_tf32 = True # Disable gradients torch.set_grad_enabled(False) -__all__ = [ - "Model", - "CausalLM", - "Seq2SeqLM", - "get_model_with_lora_adapters", -] - -FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models." - -FLASH_ATTENTION = True - -try: - from text_generation_server.models.flash_causal_lm import FlashCausalLM - from text_generation_server.models.vlm_causal_lm import VlmCausalLM - from text_generation_server.models.custom_modeling.flash_deepseek_v2_modeling import ( - FlashDeepseekV2ForCausalLM, - DeepseekV2Config, - ) - from text_generation_server.models.custom_modeling.flash_llama_modeling import ( - FlashLlamaForCausalLM, - ) - from text_generation_server.models.custom_modeling.flash_cohere_modeling import ( - FlashCohereForCausalLM, - ) - from text_generation_server.models.custom_modeling.flash_gemma_modeling import ( - FlashGemmaForCausalLM, - ) - from text_generation_server.models.custom_modeling.flash_gemma2_modeling import ( - FlashGemma2ForCausalLM, - ) - from text_generation_server.models.custom_modeling.flash_dbrx_modeling import ( - FlashDbrxForCausalLM, - DbrxConfig, - ) - from text_generation_server.models.custom_modeling.flash_rw_modeling import ( - RWConfig, - FlashRWForCausalLM, - ) - from text_generation_server.models.custom_modeling.flash_neox_modeling import ( - FlashGPTNeoXForCausalLM, - ) - from text_generation_server.models.pali_gemma import ( - PaliGemmaBatch, - ) - from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import ( - PaliGemmaForConditionalGeneration, - ) - from text_generation_server.models.custom_modeling.flash_phi_modeling import ( - FlashPhiForCausalLM, - ) - from text_generation_server.models.idefics import IDEFICSSharded - from text_generation_server.models.custom_modeling.llava_next import ( - LlavaNextForConditionalGeneration, - ) - - from text_generation_server.models.custom_modeling.flash_santacoder_modeling import ( - FlashSantacoderForCausalLM, - ) - from 
text_generation_server.models.custom_modeling.flash_starcoder2_modeling import ( - FlashStarcoder2ForCausalLM, - ) - from text_generation_server.models.custom_modeling.flash_qwen2_modeling import ( - Qwen2ForCausalLM, - ) - from text_generation_server.models.custom_modeling.flash_mistral_modeling import ( - FlashMistralForCausalLM, - ) - from text_generation_server.models.custom_modeling.flash_mixtral_modeling import ( - FlashMixtralForCausalLM, - ) - from text_generation_server.models.custom_modeling.flash_gpt2_modeling import ( - FlashGPT2ForCausalLM, - ) - from text_generation_server.models.custom_modeling.flash_gptj_modeling import ( - FlashGPTJForCausalLM, - ) - from text_generation_server.models.custom_modeling.idefics2 import ( - Idefics2ForConditionalGeneration, - ) - from text_generation_server.layers.attention import SUPPORTS_WINDOWING -except ImportError as e: - log_master(logger.warning, f"Could not import Flash Attention enabled models: {e}") - SUPPORTS_WINDOWING = False - FLASH_ATTENTION = False - -if FLASH_ATTENTION: - __all__.append(FlashCausalLM) - __all__.append(IDEFICSSharded) - -MAMBA_AVAILABLE = True -try: - from text_generation_server.models.mamba import Mamba -except ImportError as e: - log_master(logger.warning, f"Could not import Mamba: {e}") - MAMBA_AVAILABLE = False - -if MAMBA_AVAILABLE: - __all__.append(Mamba) - - -class ModelType(enum.Enum): - DEEPSEEK_V2 = { - "type": "deepseek_v2", - "name": "Deepseek V2", - "url": "https://huggingface.co/deepseek-ai/DeepSeek-V2", - } - IDEFICS2 = { - "type": "idefics2", - "name": "Idefics 2", - "url": "https://huggingface.co/HuggingFaceM4/idefics2-8b", - "multimodal": True, - } - LLAVA_NEXT = { - "type": "llava_next", - "name": "Llava Next (1.6)", - "url": "https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf", - "multimodal": True, - } - LLAMA = { - "type": "llama", - "name": "Llama", - "url": "https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f", - } - PHI3 = { - "type": "phi3", - "name": "Phi 3", - "url": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct", - } - GEMMA = { - "type": "gemma", - "name": "Gemma", - "url": "https://huggingface.co/google/gemma-7b", - } - PALIGEMMA = { - "type": "paligemma", - "name": "PaliGemma", - "url": "https://huggingface.co/google/paligemma-3b-pt-224", - } - GEMMA2 = { - "type": "gemma2", - "name": "Gemma2", - "url": "https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315", - } - COHERE = { - "type": "cohere", - "name": "Cohere", - "url": "https://huggingface.co/CohereForAI/c4ai-command-r-plus", - } - DBRX = { - "type": "dbrx", - "name": "Dbrx", - "url": "https://huggingface.co/databricks/dbrx-instruct", - } - MAMBA = { - "type": "ssm", - "name": "Mamba", - "url": "https://huggingface.co/state-spaces/mamba-2.8b-slimpj", - } - MISTRAL = { - "type": "mistral", - "name": "Mistral", - "url": "https://huggingface.co/mistralai/Mistral-Nemo-Instruct-2407", - } - MIXTRAL = { - "type": "mixtral", - "name": "Mixtral", - "url": "https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1", - } - GPT_BIGCODE = { - "type": "gpt_bigcode", - "name": "Gpt Bigcode", - "url": "https://huggingface.co/bigcode/gpt_bigcode-santacoder", - } - PHI = { - "type": "phi", - "name": "Phi", - "url": "https://huggingface.co/microsoft/phi-1_5", - } - BAICHUAN = { - "type": "baichuan", - "name": "Baichuan", - "url": "https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat", - } - FALCON = { - "type": "falcon", - "name": "Falcon", - "url": 
"https://huggingface.co/tiiuae/falcon-7b-instruct", - } - STARCODER2 = { - "type": "starcoder2", - "name": "StarCoder 2", - "url": "https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1", - } - QWEN2 = { - "type": "qwen2", - "name": "Qwen 2", - "url": "https://huggingface.co/collections/Qwen/qwen2-6659360b33528ced941e557f", - } - OPT = { - "type": "opt", - "name": "Opt", - "url": "https://huggingface.co/facebook/opt-6.7b", - } - T5 = { - "type": "t5", - "name": "T5", - "url": "https://huggingface.co/google/flan-t5-xxl", - } - GALACTICA = { - "type": "galactica", - "name": "Galactica", - "url": "https://huggingface.co/facebook/galactica-120b", - } - SANTACODER = { - "type": "santacoder", - "name": "SantaCoder", - "url": "https://huggingface.co/bigcode/santacoder", - } - BLOOM = { - "type": "bloom", - "name": "Bloom", - "url": "https://huggingface.co/bigscience/bloom-560m", - } - MPT = { - "type": "mpt", - "name": "Mpt", - "url": "https://huggingface.co/mosaicml/mpt-7b-instruct", - } - GPT2 = { - "type": "gpt2", - "name": "Gpt2", - "url": "https://huggingface.co/openai-community/gpt2", - } - GPT_NEOX = { - "type": "gpt_neox", - "name": "Gpt Neox", - "url": "https://huggingface.co/EleutherAI/gpt-neox-20b", - } - GPTJ = { - "type": "gptj", - "name": "Gptj", - "url": "https://huggingface.co/EleutherAI/gpt-j-6b", - } - IDEFICS = { - "type": "idefics", - "name": "Idefics", - "url": "https://huggingface.co/HuggingFaceM4/idefics-9b", - "multimodal": True, - } - - -__GLOBALS = locals() -for data in ModelType: - __GLOBALS[data.name] = data.value["type"] - def get_model( model_id: str, @@ -322,55 +43,22 @@ def get_model( sharded: bool, quantize: Optional[str], speculate: Optional[int], - dtype: Optional[str], + dtype: Optional[torch.dtype], trust_remote_code: bool, max_input_tokens: int, ) -> Model: - global FLASH_ATTENTION - - config_dict, _ = PretrainedConfig.get_config_dict( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - model_type = config_dict.get("model_type", None) - - quantization_config = config_dict.get("quantization_config", None) - if quantization_config is not None and quantize is None: - method = quantization_config.get("quant_method", None) - if method in {"gptq", "awq", "exl2"}: - log_master(logger.info, f"Auto selecting quantization method {method}") - quantize = method - elif method == "fbgemm_fp8": - log_master(logger.info, "Auto selecting quantization method fp8") - quantize = "fp8" - else: - log_master(logger.warning, f"Unknown quantization method {method}") - - if dtype is None: - if quantize in ["awq", "exl2", "gptq", "marlin"]: - # These quantizers only work with float16 params. - dtype = torch.float16 - elif quantize == "fp8": - from text_generation_server.layers.fp8 import FBGEMM_DYN_AVAILABLE - - if FBGEMM_DYN_AVAILABLE: - # fbgemm kernels are fp8xfp8->bf16 - dtype = torch.bfloat16 - else: - # Keep it as default for now and let - # every model resolve their own default dtype. 
- dtype = None - elif dtype == "float16": - dtype = torch.float16 - elif dtype == "bfloat16": - dtype = torch.bfloat16 - else: - raise RuntimeError(f"Unknown dtype {dtype}") + adapt_transformers_to_gaudi() if speculate is not None: set_speculate(speculate) else: set_speculate(0) + config_dict, _ = PretrainedConfig.get_config_dict( + model_id, revision=revision, trust_remote_code=trust_remote_code + ) + model_type = config_dict.get("model_type", None) + speculator = None if "medusa_num_heads" in config_dict: medusa_model_id = model_id @@ -477,681 +165,35 @@ def get_model( speculate = get_speculate() if speculate > 0: - log_master( - logger.info, f"Using speculation {method} with {speculate} input ids." - ) - - if model_type is None: - # TODO: fix how we determine model type for Mamba - if "ssm_cfg" in config_dict: - # *only happens in Mamba case - model_type = "ssm" - else: - raise RuntimeError( - f"Could not determine model type for {model_id} revision {revision}" - ) + logger.info(f"Using speculation {method} with {speculate} input ids.") - if quantize == "exl2" and sharded: - raise RuntimeError( - "Sharding is currently not supported with `exl2` quantization" - ) + model_type = config_dict["model_type"] - sliding_window = ( - config_dict.get("sliding_window") - if config_dict.get("sliding_window") is not None - else -1 - ) + if model_type == "gpt_bigcode": + return StarCoder(model_id, revision, dtype) - use_sliding_window = sliding_window is not None and sliding_window != -1 - needs_sliding_window = ( - max_input_tokens is not None and max_input_tokens > sliding_window - ) - if use_sliding_window and needs_sliding_window and not SUPPORTS_WINDOWING: - raise ValueError( - f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})." - ) - - if model_type == DEEPSEEK_V2: - if FLASH_ATTENTION: - head_size = max( - config_dict.get("qk_nope_dim", 128) - + config_dict.get("qk_rope_dim", 64), - config_dict.get("v_head_dim", 128), - ) - return FlashCausalLM( - model_id=model_id, - model_class=FlashDeepseekV2ForCausalLM, - revision=revision, - quantize=quantize, - speculator=speculator, - default_dtype=torch.bfloat16, - dtype=dtype, - trust_remote_code=trust_remote_code, - lora_adapter_ids=lora_adapter_ids, - config_class=DeepseekV2Config, - head_size=head_size, - ) - elif sharded: - raise NotImplementedError( - FLASH_ATT_ERROR_MESSAGE.format("Sharded Deepseek V2") - ) - else: - return CausalLM.fallback( - model_id, - revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) - elif model_type == MAMBA: - return Mamba( + if model_type == "bloom": + return BLOOM( model_id, revision, - quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, ) - if model_id.startswith("facebook/galactica"): - return CausalLM( + if model_type == "llava_next": + return VlmCausalLM( + model_class=LlavaNextForConditionalGeneration, model_id=model_id, - # Yes galactica is just an OPT model. 
- model_class=OPTForCausalLM, revision=revision, - quantize=quantize, + quantize=None, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, - batch_class=GalacticaCausalLMBatch, ) - if ( - model_type == GPT_BIGCODE - or model_type == GPT2 - and model_id.startswith("bigcode/") - ): - if FLASH_ATTENTION: - return FlashCausalLM( - model_id=model_id, - model_class=FlashSantacoderForCausalLM, - revision=revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - lora_adapter_ids=lora_adapter_ids, - aliases={"transformer.wte.weight": ["lm_head.weight"]}, - num_kv_heads=1, - ) - elif sharded: - raise NotImplementedError( - FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder") - ) - else: - return CausalLM.fallback( - model_id=model_id, - revision=revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) - - if model_type == BLOOM: - return CausalLM( - model_id=model_id, - model_class=BloomForCausalLM, - revision=revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - batch_class=BloomCausalLMBatch, - ) - elif model_type == MPT: - return CausalLM( - model_id=model_id, - model_class=MPTForCausalLM, - revision=revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - batch_class=CausalLMBatchKeysLast, - ) - elif model_type == GPT2: - if FLASH_ATTENTION: - try: - return FlashCausalLM( - model_id=model_id, - model_class=FlashGPT2ForCausalLM, - revision=revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - lora_adapter_ids=lora_adapter_ids, - ) - except RuntimeError as e: - # Lots of legacy models with various weight names. - log_master(logger.warning, f"Couldn't load flash gpt2 variant: {e}") - return CausalLM.fallback( - model_id, - revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) - elif sharded: - raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-2")) - else: - return CausalLM.fallback( - model_id, - revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) - elif model_type == GPTJ: - if FLASH_ATTENTION: - try: - return FlashCausalLM( - model_id=model_id, - model_class=FlashGPTJForCausalLM, - revision=revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - lora_adapter_ids=lora_adapter_ids, - ) - except RuntimeError as e: - # Lots of legacy models with various weight names. 
- log_master(logger.warning, f"Couldn't load flash gptj variant: {e}") - return CausalLM.fallback( - model_id, - revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) - elif sharded: - raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-J")) - else: - return CausalLM.fallback( - model_id, - revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) - elif model_type == GPT_NEOX: - if FLASH_ATTENTION: - from text_generation_server.models.custom_modeling.flash_neox_modeling import ( - GPTNeoXConfig, - ) - - return FlashCausalLM( - model_id=model_id, - model_class=FlashGPTNeoXForCausalLM, - revision=revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - lora_adapter_ids=lora_adapter_ids, - config_class=GPTNeoXConfig, - ) - elif sharded: - return CausalLM( - model_id=model_id, - model_class=GPTNeoxForCausalLM, - revision=revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) - else: - return CausalLM.fallback( - model_id, - revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) - - elif model_type == PHI: - if FLASH_ATTENTION: - return FlashCausalLM( - model_id=model_id, - model_class=FlashPhiForCausalLM, - revision=revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - lora_adapter_ids=lora_adapter_ids, - ) - else: - return CausalLM.fallback( - model_id, - revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) - - elif model_type == "phi-msft": - if FLASH_ATTENTION: - raise NotImplementedError( - "Legacy phi-msft is not supported with Flash Attention" - ) - else: - return CausalLM( - model_id=model_id, - model_class=PhiForCausalLM, - config_class=PhiConfig, - revision=revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) - - elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3: - print(f">>> model_type: {model_type}") - if FLASH_ATTENTION: - return FlashCausalLM( - model_id=model_id, - model_class=FlashLlamaForCausalLM, - revision=revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - lora_adapter_ids=lora_adapter_ids, - ) - elif sharded: - raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama")) - else: - return CausalLM.fallback( - model_id, - revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) - if model_type == GEMMA: - if FLASH_ATTENTION: - return FlashCausalLM( - model_id=model_id, - model_class=FlashGemmaForCausalLM, - revision=revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - # Works better for these models - default_dtype=torch.bfloat16, - trust_remote_code=trust_remote_code, - lora_adapter_ids=lora_adapter_ids, - ) - elif sharded: - raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma")) - else: - return CausalLM.fallback( - model_id, - revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) - elif model_type == GEMMA2: - if FLASH_ATTENTION: - return FlashCausalLM( - model_id=model_id, - model_class=FlashGemma2ForCausalLM, - revision=revision, - 
quantize=quantize, - speculator=speculator, - dtype=dtype, - # Works better for these models - default_dtype=torch.bfloat16, - trust_remote_code=trust_remote_code, - lora_adapter_ids=lora_adapter_ids, - ) - elif sharded: - raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2")) - else: - return CausalLM.fallback( - model_id, - revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) - - if model_type == COHERE: - if FLASH_ATTENTION: - return FlashCausalLM( - model_id=model_id, - model_class=FlashCohereForCausalLM, - revision=revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - lora_adapter_ids=lora_adapter_ids, - ) - elif sharded: - raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere")) - else: - return CausalLM.fallback( - model_id, - revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) - - if model_type == DBRX: - if FLASH_ATTENTION: - return FlashCausalLM( - model_id=model_id, - model_class=FlashDbrxForCausalLM, - revision=revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - # Dbrx works better in bfloat16. - default_dtype=torch.bfloat16, - trust_remote_code=trust_remote_code, - lora_adapter_ids=lora_adapter_ids, - config_class=DbrxConfig, - ) - elif sharded: - raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX")) - else: - return CausalLM.fallback( - model_id, - revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) - - if model_type in ["RefinedWeb", "RefinedWebModel", FALCON]: - if sharded: - if FLASH_ATTENTION: - if config_dict.get("alibi", False): - raise NotImplementedError("sharded is not supported for this model") - return FlashCausalLM( - model_id=model_id, - model_class=FlashRWForCausalLM, - revision=revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - aliases={ - "lm_head.weight": ["transformer.word_embeddings.weight"], - "transformer.word_embeddings.weight": ["lm_head.weight"], - }, - trust_remote_code=trust_remote_code, - lora_adapter_ids=lora_adapter_ids, - config_class=RWConfig, - ) - raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Falcon")) - else: - if FLASH_ATTENTION and not config_dict.get("alibi", False): - return FlashCausalLM( - model_id=model_id, - model_class=FlashRWForCausalLM, - revision=revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - aliases={ - "lm_head.weight": ["transformer.word_embeddings.weight"], - "transformer.word_embeddings.weight": ["lm_head.weight"], - }, - trust_remote_code=trust_remote_code, - lora_adapter_ids=lora_adapter_ids, - config_class=RWConfig, - ) - else: - return CausalLM.fallback( - model_id, - revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) - - if model_type == MISTRAL: - if FLASH_ATTENTION: - return FlashCausalLM( - model_id=model_id, - model_class=FlashMistralForCausalLM, - revision=revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - lora_adapter_ids=lora_adapter_ids, - ) - elif sharded: - raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral")) - else: - return CausalLM.fallback( - model_id, - revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) - - if 
model_type == MIXTRAL: - if FLASH_ATTENTION: - return FlashCausalLM( - model_id=model_id, - model_class=FlashMixtralForCausalLM, - revision=revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - lora_adapter_ids=lora_adapter_ids, - ) - elif sharded: - raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral")) - else: - return CausalLM.fallback( - model_id, - revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) - - if model_type == STARCODER2: - if FLASH_ATTENTION: - return FlashCausalLM( - model_id=model_id, - model_class=FlashStarcoder2ForCausalLM, - revision=revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - lora_adapter_ids=lora_adapter_ids, - ) - elif sharded: - raise NotImplementedError( - FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2") - ) - else: - return CausalLM.fallback( - model_id, - revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) - - if model_type == QWEN2: - if FLASH_ATTENTION: - return FlashCausalLM( - model_id=model_id, - model_class=Qwen2ForCausalLM, - revision=revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - lora_adapter_ids=lora_adapter_ids, - ) - elif sharded: - raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2")) - else: - return CausalLM.fallback( - model_id, - revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) - - if model_type == OPT: - return CausalLM( - model_id=model_id, - model_class=OPTForCausalLM, - revision=revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) - - if model_type == T5: - return Seq2SeqLM( - model_id=model_id, - model_class=T5ForConditionalGeneration, - revision=revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - aliases={ - "shared.weight": [ - "encoder.embed_tokens.weight", - "decoder.embed_tokens.weight", - ] - }, - ) - if model_type == IDEFICS: - if FLASH_ATTENTION: - return IDEFICSSharded( - model_id, - revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) - else: - raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics")) - if model_type == IDEFICS2: - if FLASH_ATTENTION: - return VlmCausalLM( - model_id=model_id, - model_class=Idefics2ForConditionalGeneration, - revision=revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - lora_adapter_ids=lora_adapter_ids, - # XXX: Extremely important to cap resolution in order to limit - # VRAM usage. 
- processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}}, - ) - else: - raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics")) - if model_type == PALIGEMMA: - if FLASH_ATTENTION: - return VlmCausalLM( - model_id=model_id, - model_class=PaliGemmaForConditionalGeneration, - revision=revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - # Works better for these models - default_dtype=torch.bfloat16, - trust_remote_code=trust_remote_code, - lora_adapter_ids=lora_adapter_ids, - batch_class=PaliGemmaBatch, - ) - else: - raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics")) - - if model_type == LLAVA_NEXT: - if FLASH_ATTENTION: - return VlmCausalLM( - model_class=LlavaNextForConditionalGeneration, - model_id=model_id, - revision=revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) - else: - raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("LlavaNext")) - - if sharded: - raise NotImplementedError("sharded is not supported for AutoModel") - if quantize == "gptq": - raise NotImplementedError( - "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`" - ) - if quantize == "awq": - raise NotImplementedError("awq quantization is not supported for AutoModel") - elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"): - raise NotImplementedError("4bit quantization is not supported for AutoModel") - elif quantize == "eetq": - raise NotImplementedError("Eetq quantization is not supported for AutoModel") - elif quantize == "exl2": - raise NotImplementedError("exl2 quantization is not supported for AutoModel") if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: - return CausalLM.fallback( - model_id, - revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) - if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES: - return Seq2SeqLM.fallback( + return CausalLM( model_id, revision, quantize=quantize, @@ -1160,27 +202,6 @@ def get_model( trust_remote_code=trust_remote_code, ) - auto_map = config_dict.get("auto_map", None) - if trust_remote_code and auto_map is not None: - if "AutoModelForCausalLM" in auto_map.keys(): - return CausalLM.fallback( - model_id, - revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) - if "AutoModelForSeq2SeqLM" in auto_map.keys(): - return Seq2SeqLM.fallback( - model_id, - revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) - raise ValueError(f"Unsupported model type {model_type}") @@ -1193,7 +214,7 @@ def get_model_with_lora_adapters( sharded: bool, quantize: Optional[str], speculate: Optional[int], - dtype: Optional[str], + dtype: Optional[torch.dtype], trust_remote_code: bool, max_input_tokens: int, adapter_to_index: Dict[str, int], diff --git a/server/text_generation_server/models/bloom.py b/server/text_generation_server/models/bloom.py index 732b4c5394c..6fe64374835 100644 --- a/server/text_generation_server/models/bloom.py +++ b/server/text_generation_server/models/bloom.py @@ -1,11 +1,10 @@ +# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. 
+ import torch -import torch.distributed from typing import Optional, Type -from transformers import ( - PreTrainedTokenizerBase, -) +from transformers import PreTrainedTokenizerBase from text_generation_server.models import CausalLM from text_generation_server.models.causal_lm import CausalLMBatch @@ -21,26 +20,33 @@ def from_pb( dtype: torch.dtype, device: torch.device, ) -> "CausalLMBatch": - batch = super().from_pb(pb=pb, tokenizer=tokenizer, dtype=dtype, device=device) + batch = super().from_pb( + pb=pb, + tokenizer=tokenizer, + dtype=dtype, + device=device, + ) batch.keys_head_dim_last = False return batch -class BLOOMSharded(CausalLM): - @property - def batch_type(self) -> Type[CausalLMBatch]: - return BloomCausalLMBatch - - def forward( - self, input_ids, attention_mask, position_ids, past_key_values: Optional = None +class BLOOM(CausalLM): + def __init__( + self, + model_id: str, + revision: Optional[str] = None, + speculator: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + trust_remote_code: bool = False, ): - outputs, speculative_logits = self.model.forward( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - use_cache=True, + super(BLOOM, self).__init__( + model_id=model_id, + revision=revision, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, ) - logits = outputs.logits - return logits, speculative_logits, outputs.past_key_values + @property + def batch_type(self) -> Type[CausalLMBatch]: + return BloomCausalLMBatch diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index 28534d0f73b..9f88c4fc6ee 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -1,27 +1,40 @@ -import torch -import time -import torch.distributed +# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. 
+import bisect from dataclasses import dataclass +from functools import wraps +import itertools +import math +import os +import tempfile +import time +import copy +from typing import Dict, List, Optional, Tuple, Type + +import torch +import torch._dynamo +from loguru import logger from opentelemetry import trace + +import text_generation_server.habana_quantization_env as hq_env +import habana_frameworks.torch as htorch +from optimum.habana.utils import HabanaProfile +from optimum.habana.transformers.generation import MODELS_OPTIMIZED_WITH_STATIC_SHAPES +from text_generation_server.utils.chunks import concat_text_chunks +from optimum.habana.checkpoint_utils import ( + get_repo_root, + model_on_meta, + write_checkpoints_json, +) from transformers import ( - AutoConfig, AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase, + AutoConfig, ) -from typing import Optional, Tuple, List, Type, Dict -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) -from text_generation_server.models import Model -from text_generation_server.utils.chunks import concat_text_chunks -from text_generation_server.utils.import_utils import SYSTEM -from text_generation_server.utils.quantization import get_loader from text_generation_server.utils.tokens import batch_top_tokens +from text_generation_server.models import Model from text_generation_server.models.types import ( Batch, Tokens, @@ -29,55 +42,441 @@ GeneratedText, ) from text_generation_server.pb import generate_pb2 -from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling +from text_generation_server.utils import ( + HeterogeneousNextTokenChooser, + StoppingCriteria, + make_tokenizer_optional, + is_tokenizer_transparent, + pad_next_token_chooser_parameters, +) +from optimum.habana.utils import get_hpu_memory_stats +from text_generation_server.utils.debug import dbg_trace +from text_generation_server.utils.speculate import get_speculate tracer = trace.get_tracer(__name__) +# MAX_TOTAL_TOKENS = int(os.environ.get('MAX_TOTAL_TOKENS', 2048)) +# BATCH_BUCKET_SIZE = int(os.environ.get('BATCH_BUCKET_SIZE', 8)) +# PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get('PAD_SEQUENCE_TO_MULTIPLE_OF', 128)) +# PREFILL_BATCH_BUCKET_SIZE = int(os.environ.get('PREFILL_BATCH_BUCKET_SIZE', 4)) +# CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048] +# LAZY_MODE = int(os.environ.get('PT_HPU_LAZY_MODE', 1)) +MAX_TOTAL_TOKENS = int(os.environ.get('MAX_TOTAL_TOKENS', 8192)) +MAX_BATCH_TOTAL_TOKENS = int(os.environ.get('MAX_BATCH_TOTAL_TOKENS', 65536)) +PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get('PAD_SEQUENCE_TO_MULTIPLE_OF', 256)) +CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048] +LAZY_MODE = int(os.environ.get('PT_HPU_LAZY_MODE', 1)) + + +PREFILL_WARMUP_BATCH_SIZE_LIST = [] +PREFILL_WARMUP_SEQLEN_LIST = [] +DECODE_WARMUP_BATCH_SIZE_LIST = [] + + +def torch_compile_for_eager(func): + if LAZY_MODE == 1: + return func + return torch.compile(func, backend="hpu_backend", options={"keep_input_mutations": True}) + + +def round_up(warmup_list:list, num) : + i = 0 + for i in warmup_list: + if num <= i : + break + return i + + +def to_tensor_indices(indices, device): + return torch.tensor(indices, dtype=torch.long, device=device) + + +def calculate_chunks(offset): + result = [] + while offset != 0: + sign = 1 if offset > 0 else -1 + best_chunk = min((abs(offset - sign * c), sign * c) for c in CHUNK_SIZES)[1] + result.append(best_chunk) + offset = offset - best_chunk + 
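
The bucketing above is driven entirely by environment variables plus the round_up helper, which picks the first warmup bucket that fits and otherwise falls back to the largest one. A standalone sketch of that behavior, including its edge cases (assumed bucket values, not the served defaults):

import os

# Sketch of the env-driven bucket selection: pick the first warmup bucket
# >= num, falling back to the largest bucket when num exceeds them all.
MAX_TOTAL_TOKENS = int(os.environ.get("MAX_TOTAL_TOKENS", 8192))
PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get("PAD_SEQUENCE_TO_MULTIPLE_OF", 256))


def round_up(warmup_list: list, num):
    i = 0
    for i in warmup_list:
        if num <= i:
            break
    return i


decode_buckets = [1, 2, 4, 8]            # assumed warmup buckets
print(round_up(decode_buckets, 3))       # -> 4 (next bucket up)
print(round_up(decode_buckets, 9))       # -> 8 (capped at the largest bucket)
print(round_up([], 3))                   # -> 0 (empty list before warmup has run)
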
return result + + +def biggest_single_chunk(offset): + if offset != 0: + idx = bisect.bisect(CHUNK_SIZES, abs(offset)) + return int(math.copysign(CHUNK_SIZES[idx - 1], offset)) + else: + return 0 + + +@torch_compile_for_eager +def grouped_pad(tensor_groups, dims, values): + grouped_result = [] + for tensors, dim, value in zip(tensor_groups, dims, values): + padding = MAX_TOTAL_TOKENS - tensors[0].size(dim) if dim is not None else 0 + if padding > 0: + assert dim in [-1, -2], f'Only dims -1 and -2 are supported! {dim}' + pad_shape = (0, 0, 0, padding) if dim == -2 else (0, padding) + result = [torch.nn.functional.pad(t, pad_shape, value=value) for t in tensors] + else: + result = [t for t in tensors] + grouped_result.append(result) + htorch.core.mark_step() + return grouped_result + + +@torch_compile_for_eager +def roll(tensor, chunk, dim, merge_graphs): + if dim is None: + return tensor + tensor = torch.roll(tensor, chunk, dim) + if not merge_graphs: + htorch.core.mark_step() + return tensor + + +def grouped_roll(tensor_groups, chunk, dims, merge_graphs): + tensor_groups = [[roll(t, chunk, dim, merge_graphs) for t in tensors] for tensors, dim in zip(tensor_groups, dims)] + if merge_graphs: + htorch.core.mark_step() + return tensor_groups + + +@torch_compile_for_eager +def grouped_shift(tensor_groups, dims, offset, merge_graphs): + chunks = calculate_chunks(offset) + for c in chunks: + tensor_groups = grouped_roll(tensor_groups, c, dims, merge_graphs) + return tensor_groups + + +def move(dst_tensors, dst_indices, src_tensors): + bs_dim = 0 + num_indices = dst_indices.size(0) + for i, (dst_t, src_t) in enumerate(zip(dst_tensors, src_tensors)): + if src_t.size(bs_dim) != num_indices: + src_t = torch.narrow(src_t, bs_dim, 0, num_indices) + dst_t.index_copy_(bs_dim, dst_indices, src_t) + htorch.core.mark_step() + + +def grouped_move(dst_tensor_groups, dst_indices, src_tensor_groups): + for dst_tensors, src_tensors in zip(dst_tensor_groups, src_tensor_groups): + move(dst_tensors, dst_indices, src_tensors) + + +@torch_compile_for_eager +def extend_tensor(tensor, padding, dim): + result = torch.cat([tensor, padding], dim=dim) + htorch.core.mark_step() + return result + + +@torch_compile_for_eager +def extend_batch(tensors, target_bs, dim): + diff = target_bs - tensors[0].size(dim) + # TODO: add support for shrinking bs + if diff <= 0: + return tensors + shape = list(tensors[0].shape) + shape[dim] = diff + padding = torch.empty(shape, device=tensors[0].device, dtype=tensors[0].dtype) + tensors = [extend_tensor(t, padding, dim) for t in tensors] + return tensors + + +def grouped_extend_batch(tensor_groups, target_bs, bs_dims): + tensor_groups = [extend_batch(tensors, target_bs, dim) for tensors, dim in zip(tensor_groups, bs_dims)] + return tensor_groups + + +@torch_compile_for_eager +def merge(tensor_group): + tensor_group = [torch.stack(tensor_group)] + htorch.core.mark_step() + return tensor_group + + +@torch_compile_for_eager +def split(tensor_group, clone_data): + tensor_group = [t.squeeze(0) for t in torch.split(tensor_group[0], 1)] + if clone_data: + tensor_group = [t.clone() for t in tensor_group] + htorch.core.mark_step() + return tensor_group + + +def remove_kv_cache_from_output(module): + orig_fwd = module.forward + + @wraps(orig_fwd) + def forward(*args, **kwargs): + if kwargs["past_key_values"] is not None: + kwargs["return_dict"] = False + output = orig_fwd(*args, **kwargs) + first_value, second_value, *_ = output + if first_value.nelement() < 2: + return second_value + else: + return 
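
All shifting of the cached tensors happens in amounts drawn from CHUNK_SIZES so that every roll hits a pre-compiled graph shape. The decomposition of an arbitrary offset into such chunks can be exercised in isolation (same logic as calculate_chunks / biggest_single_chunk above, reproduced without the HPU dependencies):

import bisect
import math

# Sketch: decomposing a shift offset into graph-friendly chunk sizes.
CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]


def calculate_chunks(offset):
    result = []
    while offset != 0:
        sign = 1 if offset > 0 else -1
        best_chunk = min((abs(offset - sign * c), sign * c) for c in CHUNK_SIZES)[1]
        result.append(best_chunk)
        offset = offset - best_chunk
    return result


def biggest_single_chunk(offset):
    if offset != 0:
        idx = bisect.bisect(CHUNK_SIZES, abs(offset))
        return int(math.copysign(CHUNK_SIZES[idx - 1], offset))
    return 0


print(calculate_chunks(5))         # -> [4, 1]
print(calculate_chunks(300))       # -> [256, 32, 8, 4]
print(biggest_single_chunk(5))     # -> 4
print(biggest_single_chunk(-300))  # -> -256
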
first_value + else: + kwargs["return_dict"] = True + return orig_fwd(*args, **kwargs) + + module.forward = forward + return module + + +@dataclass +class CausalLMRequest: + idx: int + data: generate_pb2.Request + input_length: int + prefix_offset: int + read_offset: int + stopping_criteria: StoppingCriteria + + all_input_ids: torch.Tensor + + @classmethod + def from_pb(cls, idx: int, data: generate_pb2.Request, tokenizer: PreTrainedTokenizerBase): + return cls( + idx=idx, + data=data, + input_length=None, + prefix_offset=None, + read_offset=None, + stopping_criteria=StoppingCriteria.from_pb(data.stopping_parameters, tokenizer), + all_input_ids=None,) + + def update_idx(self, new_idx): + prev = self.idx + self.idx = new_idx + return (new_idx, prev) + @dataclass class CausalLMBatch(Batch): batch_id: int - requests: List[generate_pb2.Request] - requests_idx_mapping: Dict[int, int] + requests: List[CausalLMRequest] # Decoder values input_ids: torch.Tensor attention_mask: torch.Tensor position_ids: torch.Tensor past_key_values: Optional[List[Tuple]] - - # All tokens - all_input_ids: List[torch.Tensor] + merged_kv_cache: bool # Lengths of all generations present in the batch - input_lengths: List[int] - prefix_offsets: List[int] - read_offsets: List[int] + input_length: int # Generation helpers - next_token_choosers: List[NextTokenChooser] - stopping_criterias: List[StoppingCriteria] + next_token_chooser: HeterogeneousNextTokenChooser top_n_tokens: List[int] top_n_tokens_tensor: torch.Tensor - # Metadata used for padding - max_input_length: int - padding_right_offset: int - - # Maximum number of tokens this batch will grow to - max_tokens: int + input_length: int # Past metadata + logits = None + past = None + keys_head_dim_last: bool = True def to_pb(self) -> generate_pb2.CachedBatch: return generate_pb2.CachedBatch( id=self.batch_id, - request_ids=[r.id for r in self.requests], + request_ids=[r.data.id for r in self.requests], size=len(self), max_tokens=self.max_tokens, ) + def detach_kv_cache(self): + past_keys = [past[0] for past in self.past_key_values] + past_values = [past[1] for past in self.past_key_values] + del self.past_key_values + return past_keys, past_values + + def attach_kv_cache(self, past_keys, past_values): + # TODO: Add support for models that don't store kv_cache in a list + self.past_key_values = list(zip(past_keys, past_values)) + + def merge_kv_cache_if_needed(self, target_bs, offset): + pad_needed = self.seq_length < MAX_TOTAL_TOKENS + shift_needed = offset != 0 + expand_needed = target_bs > self.batch_size + # Very simple heuristic to determine whether we should merge tensors + # this needs tuning for other models/scenarios + small_bs = len(self.past_key_values) > self.batch_size + if not self.merged_kv_cache and small_bs and (pad_needed or shift_needed or expand_needed): + past_keys, past_values = self.detach_kv_cache() + past_keys = merge(past_keys) + past_values = merge(past_values) + self.attach_kv_cache(past_keys, past_values) + self.merged_kv_cache = True + + def split_kv_cache_if_needed(self, clone_data): + if self.merged_kv_cache: + past_keys, past_values = self.detach_kv_cache() + past_keys = split(past_keys, clone_data) + past_values = split(past_values, clone_data) + self.attach_kv_cache(past_keys, past_values) + self.merged_kv_cache = False + + def get_tensor_groups(self): + past_keys, past_values = self.detach_kv_cache() + seq_dim = -1 + key_dim = -2 if self.keys_head_dim_last else -1 + value_dim = -2 + tensors = [[self.input_ids], [self.attention_mask], 
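
remove_kv_cache_from_output monkey-patches a module's forward so that, once a KV cache already exists, only the logits cross the graph boundary instead of the full output. The wrapping pattern can be shown on a toy module (ToyLM below is hypothetical, and the sketch drops the extra nelement() check the real wrapper does on the first output):

from functools import wraps
import torch


class ToyLM(torch.nn.Module):
    """Hypothetical stand-in mimicking a (logits, past_key_values) contract."""

    def forward(self, input_ids, past_key_values=None, return_dict=True):
        logits = torch.zeros(input_ids.shape[0], 4)   # fake logits
        new_past = [torch.zeros(1)]                   # fake KV cache
        if return_dict:
            return {"logits": logits, "past_key_values": new_past}
        return logits, new_past


def keep_logits_only(module):
    orig_fwd = module.forward

    @wraps(orig_fwd)
    def forward(*args, **kwargs):
        if kwargs.get("past_key_values") is not None:
            kwargs["return_dict"] = False
            logits, _past = orig_fwd(*args, **kwargs)   # drop the cache here
            return logits
        kwargs["return_dict"] = True
        return orig_fwd(*args, **kwargs)                # prefill: keep everything

    module.forward = forward
    return module


model = keep_logits_only(ToyLM())
prefill_out = model(torch.ones(2, 3, dtype=torch.long), past_key_values=None)
decode_out = model(torch.ones(2, 1, dtype=torch.long), past_key_values=[torch.zeros(1)])
print(type(prefill_out), type(decode_out))   # dict vs. bare logits tensor
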
[self.position_ids], past_keys, past_values] + # We don't need to align position_ids + seq_dims = [seq_dim, seq_dim, None, key_dim, value_dim] + bs_dims = [0, 0, 0] + ([1, 1] if self.merged_kv_cache else [0, 0]) + return tensors, seq_dims, bs_dims + + def set_tensor_groups(self, tensors): + self.input_ids = tensors.pop(0)[0] + self.attention_mask = tensors.pop(0)[0] + self.position_ids = tensors.pop(0)[0] + past_keys = tensors.pop(0) + past_values = tensors.pop(0) + self.attach_kv_cache(past_keys, past_values) + + def realign(self, target_bs, offset, pad_token_id): + tensors, seq_dims, _ = self.get_tensor_groups() + tensors = grouped_pad(tensors, seq_dims, [pad_token_id, 0, 0, 0, 0]) + tensors = grouped_shift(tensors, seq_dims, offset, self.merged_kv_cache) + self.set_tensor_groups(tensors) + + def expand_bs(self, target_bs): + tensors, _, bs_dims = self.get_tensor_groups() + tensors = grouped_extend_batch(tensors, target_bs, bs_dims) + self.set_tensor_groups(tensors) + + def used_indices(self): + return [req.idx for req in self.requests] + + def update_indices(self, new_indices): + for req, new_idx in zip(self.requests, new_indices): + req.idx = new_idx + return self.used_indices() + + def free_indices_generator(self): + used = set(req.idx for req in self.requests) + return (i for i in range(self.batch_size) if i not in used) + + def move_data(self, src_batches): + dst_tensors, _, dst_dims = self.get_tensor_groups() + free_indices_gen = self.free_indices_generator() + for src_b in src_batches: + dst_indices = to_tensor_indices(src_b.update_indices(free_indices_gen), self.input_ids.device) + src_tensors, _, src_dims = src_b.get_tensor_groups() + grouped_move(dst_tensors, dst_indices, src_tensors) + self.set_tensor_groups(dst_tensors) + + @classmethod + def recombine(cls, batches: List["CausalLMBatch"], pad_token_id: int, is_warmup: bool =False) -> "CausalLMBatch": + if not all(b.past_key_values is not None for b in batches): + raise ValueError("KV cache not allocated! Cannot recombine before prefill!") + + total_requests = sum(len(b) for b in batches) + new_bs = total_requests + if is_warmup is False : + new_bs = round_up(DECODE_WARMUP_BATCH_SIZE_LIST, total_requests) + + batch_id = batches[0].batch_id + device = batches[0].input_ids.device + + input_lengths = [b.input_length for b in batches] + max_input_length = max(input_lengths) + offsets = [max_input_length - b.input_length for b in batches] + + cur_padding = [b.right_padding for b in batches] + # For prefill there is a space allocated only for first token + # Need to add padding to the max total tokens before first decode + + moves_needed = [total_requests - len(b) if b.batch_size == new_bs else total_requests for b in batches] + dst_batch_idx = min(enumerate(moves_needed), key=lambda idx_val: idx_val[1])[0] + reshape = (batches[dst_batch_idx].batch_size < new_bs) + + # TODO: Add support for changing max seq len, i.e. 
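
Because filtered requests keep their original row index, concatenation only has to copy the newcomers into whatever destination rows are currently free. A simplified, list-based stand-in for the used_indices / free_indices_generator / update_indices bookkeeping above:

# Sketch of the slot bookkeeping behind move_data: surviving requests keep
# their row in the destination batch, incoming requests are renumbered onto
# free rows, and only those rows need an index_copy_.
def free_indices(batch_size, used):
    used = set(used)
    return (i for i in range(batch_size) if i not in used)


def assign_free_slots(dst_batch_size, dst_used, incoming):
    """Return {old_idx: new_idx} for the incoming requests."""
    gen = free_indices(dst_batch_size, dst_used)
    return {old_idx: next(gen) for old_idx in incoming}


# Destination batch has 8 rows; rows 0, 2 and 5 still hold live requests.
# Three requests arrive from another batch with old row indices 0..2.
print(assign_free_slots(8, [0, 2, 5], [0, 1, 2]))
# -> {0: 1, 1: 3, 2: 4}: they land on the first free rows 1, 3 and 4.
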
due to output length bucketing + # FIXME: max_seq_len for non optimized code + if len(batches) > 1: + scenario = 'CONCAT' + elif reshape: + scenario = 'RESHAPE' + elif cur_padding[dst_batch_idx] <= 0: + scenario = 'SHIFT' + offsets = [biggest_single_chunk(b.max_input_length - max_input_length) for b in batches] + max_input_length = max_input_length + offsets[dst_batch_idx] + else: + # Nothing to do + return batches[0] + + dbg_trace( + scenario, f'bs:{[b.batch_size for b in batches]}->{new_bs}' + f' reqs:{[len(b) for b in batches]}' + f' offsets:{offsets}' + f' input_lengths:{input_lengths}' + f' cur_padding:{cur_padding}' + f' dst_batch:{dst_batch_idx}') + + grouped_requests = [[req for req in batch.requests] for batch in batches] + flat_requests = list(itertools.chain(*grouped_requests)) + + for i in range(len(batches)): + target_bs = new_bs if i == dst_batch_idx else batches[i].batch_size + batches[i].merge_kv_cache_if_needed(target_bs, offsets[i]) + batches[i].realign(target_bs, offsets[i], pad_token_id) + batches[i].split_kv_cache_if_needed(i == dst_batch_idx) + batches[dst_batch_idx].expand_bs(new_bs) + batches[dst_batch_idx].move_data([batches[i] for i in range(len(batches)) if i != dst_batch_idx]) + + top_n_tokens = [r.data.top_n_tokens for r in flat_requests] + top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64) + + parameters = [r.data.parameters for r in flat_requests] + # append the dummy parameters for dummy requests + batch_size = batches[dst_batch_idx].batch_size + parameters = pad_next_token_chooser_parameters(parameters, batch_size) + + # update past grammar states + fsm_grammar_states = [0] * batch_size + for batch in batches: + for i, req in enumerate(batch.requests): + fsm_grammar_states[req.idx] = batch.next_token_chooser.fsm_grammar_states[i] + + next_token_chooser = HeterogeneousNextTokenChooser.from_pb( + parameters, + batches[dst_batch_idx].next_token_chooser.dtype, + batches[dst_batch_idx].next_token_chooser.device, + batches[dst_batch_idx].next_token_chooser.tokenizer, + fsm_grammar_states, + quantization_enabled=hq_env.is_quantization_enabled, + ) + + input_ids = batches[dst_batch_idx].input_ids + attention_mask = batches[dst_batch_idx].attention_mask + position_ids = batches[dst_batch_idx].position_ids + past_key_values = batches[dst_batch_idx].past_key_values + input_length = max_input_length + + htorch.core.mark_step() + + return cls( + batch_id=batch_id, + requests=flat_requests, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + merged_kv_cache=False, + next_token_chooser=next_token_chooser, + top_n_tokens=top_n_tokens, + top_n_tokens_tensor=top_n_tokens_tensor, + input_length=input_length, + ) + @classmethod def from_pb( cls, @@ -85,415 +484,145 @@ def from_pb( tokenizer: PreTrainedTokenizerBase, dtype: torch.dtype, device: torch.device, + is_warmup: bool = False, ) -> "CausalLMBatch": + dbg_trace('FROM_PB', f'num_reqs:{len(pb.requests)}') + requests = [CausalLMRequest.from_pb(idx, req, tokenizer) for idx, req in enumerate(pb.requests)] inputs = [] - next_token_choosers = [] - stopping_criterias = [] top_n_tokens = [] - prefix_offsets = [] - read_offsets = [] - requests_idx_mapping = {} # Parse batch max_truncation = 0 - padding_right_offset = 0 - max_decode_tokens = 0 for i, r in enumerate(pb.requests): - requests_idx_mapping[r.id] = i inputs.append(concat_text_chunks(r.input_chunks.chunks)) - - next_token_choosers.append( - 
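
recombine folds three situations into one code path: merging several batches (CONCAT), growing a single batch to a larger decode bucket (RESHAPE), and reclaiming right padding when none is left (SHIFT). A minimal decision table over the same derived inputs, for reference:

# Sketch of the scenario selection in recombine; inputs mirror the values
# computed above (number of source batches, whether the destination batch
# must grow to the new bucket, how much right padding it has left).
def pick_scenario(num_batches, needs_reshape, right_padding):
    if num_batches > 1:
        return "CONCAT"   # several batches -> physically merge them
    if needs_reshape:
        return "RESHAPE"  # single batch, but the decode bucket grew
    if right_padding <= 0:
        return "SHIFT"    # no room for the next token -> shift left
    return "NOOP"         # nothing to do, reuse the batch as-is


print(pick_scenario(2, False, 10))  # CONCAT
print(pick_scenario(1, True, 10))   # RESHAPE
print(pick_scenario(1, False, 0))   # SHIFT
print(pick_scenario(1, False, 10))  # NOOP
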
NextTokenChooser.from_pb(r.parameters, device, tokenizer) - ) - stopping_criteria = StoppingCriteria.from_pb( - r.stopping_parameters, tokenizer - ) - stopping_criterias.append(stopping_criteria) top_n_tokens.append(r.top_n_tokens) max_truncation = max(max_truncation, r.truncate) - max_decode_tokens += stopping_criteria.max_new_tokens - padding_right_offset = max( - padding_right_offset, stopping_criteria.max_new_tokens - ) + + max_new_tokens = max(r.stopping_criteria.max_new_tokens for r in requests) + max_input_length = max_truncation + # TODO: by tokenizing all inputs at once we loose information on actual input lengths + # this means that we cannot shift inputs to the left after a long input sequence + # was filtered out + new_bs = round_up(PREFILL_WARMUP_BATCH_SIZE_LIST, len(requests)) + missing_inputs = new_bs - len(inputs) + dummy_inputs = ["?"] * missing_inputs + parameters = [r.parameters for r in pb.requests] + # append the dummy parameters for dummy request + parameters = pad_next_token_chooser_parameters(parameters, new_bs) + + next_token_chooser = HeterogeneousNextTokenChooser.from_pb( + pb=parameters, + dtype=dtype, + device=device, + tokenizer=tokenizer, + quantization_enabled=hq_env.is_quantization_enabled, + ) tokenized_inputs = tokenizer( - inputs, + inputs+dummy_inputs, return_tensors="pt", - padding=True, + padding="longest", return_token_type_ids=False, truncation=True, max_length=max_truncation, ).to(device) - for _ in pb.requests: - input_len = tokenized_inputs["input_ids"].shape[1] - prefix_offsets.append(input_len - 5) - read_offsets.append(input_len) - input_lengths = tokenized_inputs["attention_mask"].sum(1) - max_input_length = input_lengths.max() + input_len = tokenized_inputs["input_ids"].shape[1] + + # Round up sequence length + bucket_size = max_input_length + left_padding = max_input_length - input_len + if is_warmup is False: + if input_len < max_input_length and PAD_SEQUENCE_TO_MULTIPLE_OF != 0: + assert PAD_SEQUENCE_TO_MULTIPLE_OF <= max_input_length, "PAD_SEQUENCE_TO_MULTIPLE_OF cannot be higher than max_input_length" + rounded_seq_len = round_up(PREFILL_WARMUP_SEQLEN_LIST, input_len + 1) + if rounded_seq_len <= max_input_length: + bucket_size = rounded_seq_len - 1 + else: + bucket_size = max_input_length - 1 + left_padding = bucket_size - input_len input_ids = tokenized_inputs["input_ids"] - # Allocate maximum attention_mask - attention_mask = input_ids.new_zeros( - (pb.size, max_input_length + padding_right_offset) + attention_mask = tokenized_inputs["attention_mask"] + + # Allocate space for first token + input_ids = torch.nn.functional.pad( + input_ids, (left_padding, 1), value=tokenizer.pad_token_id ) - # Copy tokenizer attention_mask into fully allocated attention_mask - attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"] + attention_mask = torch.nn.functional.pad( + attention_mask, (left_padding, 1), value=0 + ) + all_input_ids = torch.nn.functional.pad( + input_ids, (0, max_new_tokens), value=tokenizer.pad_token_id + ).T.split(1, dim=1)[0:len(pb.requests)] + input_len = bucket_size + for r in requests: + r.input_length = input_len + r.prefix_offset = input_len - 5 + r.read_offset = input_len + r.all_input_ids = all_input_ids[r.idx] + + input_ids = input_ids.to(device) + attention_mask = attention_mask.to(device) + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + + htorch.core.mark_step() - position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1 - 
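
Prefill therefore pads in two directions at once: on the left up to the chosen sequence bucket, and by one extra column on the right to reserve room for the first generated token. A sketch of that shape arithmetic with hypothetical sizes (PREFILL_WARMUP_SEQLEN_LIST as it would be built during warmup):

import torch
import torch.nn.functional as F

# Sketch of the prefill padding in from_pb: left-pad the tokenized prompt up
# to the selected sequence bucket, then reserve one extra position on the
# right for the token produced by the prefill forward pass.
PREFILL_WARMUP_SEQLEN_LIST = [256, 512, 1024]   # assumed warmup buckets
max_input_length = 1024
pad_token_id = 0


def round_up(warmup_list, num):
    i = 0
    for i in warmup_list:
        if num <= i:
            break
    return i


input_ids = torch.randint(1, 100, (2, 300))      # batch of 2, 300 real tokens
input_len = input_ids.shape[1]

rounded = round_up(PREFILL_WARMUP_SEQLEN_LIST, input_len + 1)   # -> 512
bucket_size = rounded - 1 if rounded <= max_input_length else max_input_length - 1
left_padding = bucket_size - input_len                          # -> 211

padded = F.pad(input_ids, (left_padding, 1), value=pad_token_id)
print(bucket_size, left_padding, padded.shape)   # 511 211 torch.Size([2, 512])
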
position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1) - all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1) top_n_tokens_tensor = torch.tensor( top_n_tokens, device=device, dtype=torch.int64 ) - - max_tokens = len(inputs) * (max_input_length + max_decode_tokens) - + htorch.core.mark_step() return cls( batch_id=pb.id, - requests=pb.requests, - requests_idx_mapping=requests_idx_mapping, + requests=requests, input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=None, - all_input_ids=list(all_input_ids), - input_lengths=input_lengths.tolist(), - prefix_offsets=prefix_offsets, - read_offsets=read_offsets, - next_token_choosers=next_token_choosers, - stopping_criterias=stopping_criterias, + merged_kv_cache=False, + next_token_chooser=next_token_chooser, top_n_tokens=top_n_tokens, top_n_tokens_tensor=top_n_tokens_tensor, - max_input_length=max_input_length.item(), - padding_right_offset=padding_right_offset, - max_tokens=max_tokens, + input_length=input_len, ) @tracer.start_as_current_span("filter") def filter(self, request_ids: List[int]) -> Optional["CausalLMBatch"]: - if len(request_ids) == 0: - raise ValueError("Batch must have at least one request") - if len(request_ids) == len(self): - return self - - keep_indices = [] - - # New values after filtering - requests_idx_mapping = {} - requests = [] - input_lengths = [] - prefix_offsets = [] - read_offsets = [] - all_input_ids = [] - max_input_length = 0 - - next_token_choosers = [] - stopping_criterias = [] - top_n_tokens = [] - - total_remaining_decode_tokens = 0 - new_padding_right_offset = 0 - - for i, request_id in enumerate(request_ids): - idx = self.requests_idx_mapping[request_id] - requests_idx_mapping[request_id] = i - keep_indices.append(idx) - - requests.append(self.requests[idx]) - prefix_offsets.append(self.prefix_offsets[idx]) - read_offsets.append(self.read_offsets[idx]) - all_input_ids.append(self.all_input_ids[idx]) - - request_input_length = self.input_lengths[idx] - input_lengths.append(request_input_length) - max_input_length = max(max_input_length, request_input_length) - - next_token_choosers.append(self.next_token_choosers[idx]) - stopping_criteria = self.stopping_criterias[idx] - stopping_criterias.append(stopping_criteria) - top_n_tokens.append(self.top_n_tokens[idx]) - remaining_decode_tokens = ( - stopping_criteria.max_new_tokens - stopping_criteria.current_tokens - ) - total_remaining_decode_tokens += remaining_decode_tokens - new_padding_right_offset = max( - new_padding_right_offset, remaining_decode_tokens - ) - - # Apply indices to input_ids, attention mask, past key values and other items that need to be cached - input_ids = self.input_ids[keep_indices] - position_ids = self.position_ids[keep_indices] - self.attention_mask = self.attention_mask[ - keep_indices, - -(self.padding_right_offset + max_input_length) : ( - self.attention_mask.shape[1] - self.padding_right_offset - ) - + new_padding_right_offset, - ] - - # Ensure that past_key_values tensors can be updated in-place - if type(self.past_key_values[0]) is tuple: - self.past_key_values = [list(layer) for layer in self.past_key_values] - - # Update tensors in-place to allow incremental garbage collection - past_kv_length = max_input_length - 1 - for layer in self.past_key_values: - past_keys, past_values = layer - if len(past_keys.shape) == 3: - # Force past to be of dim [self_size, num_heads, ...] 
for easy indexing - past_keys = past_keys.view(len(self), -1, *past_keys.shape[-2:]) - past_values = past_values.view(len(self), -1, *past_values.shape[-2:]) - if self.keys_head_dim_last: - layer[0] = past_keys[keep_indices, :, -past_kv_length:, :] - else: - layer[0] = past_keys[keep_indices, :, :, -past_kv_length:] - del past_keys - layer[1] = past_values[keep_indices, :, -past_kv_length:, :] - del past_values - - top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices] - max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens - - self.requests = requests - self.requests_idx_mapping = requests_idx_mapping - self.input_ids = input_ids - self.position_ids = position_ids - self.all_input_ids = all_input_ids - self.input_lengths = input_lengths - self.prefix_offsets = prefix_offsets - self.read_offsets = read_offsets - self.next_token_choosers = next_token_choosers - self.stopping_criterias = stopping_criterias - self.top_n_tokens = top_n_tokens - self.top_n_tokens_tensor = top_n_tokens_tensor - self.max_input_length = max_input_length - self.padding_right_offset = new_padding_right_offset - self.max_tokens = max_tokens - + dbg_trace('FILTER', f'num_reqs:{len(self.requests)} -> {len(request_ids)}') + request_ids = set(request_ids) + self.requests = [req for req in self.requests if req.data.id in request_ids] return self @classmethod @tracer.start_as_current_span("concatenate") - def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch": - # Used for padding - total_batch_size = 0 - max_input_length = 0 - padding_right_offset = 0 - for batch in batches: - total_batch_size += len(batch) - max_input_length = max(max_input_length, batch.max_input_length) - padding_right_offset = max(padding_right_offset, batch.padding_right_offset) - - # Batch attributes - requests = [] - requests_idx_mapping = {} - input_lengths = [] - prefix_offsets = [] - read_offsets = [] - all_input_ids = [] - next_token_choosers = [] - stopping_criterias = [] - top_n_tokens = [] - max_tokens = 0 - - # Batch tensors - input_ids = None - attention_mask = None - position_ids = None - past_key_values = [] - top_n_tokens_tensor = None - - # Used for slicing correctly inside the tensors - # Equivalent to a cumsum on batch sizes - start_index = 0 - for i, batch in enumerate(batches): - requests.extend(batch.requests) - input_lengths.extend(batch.input_lengths) - prefix_offsets.extend(batch.prefix_offsets) - read_offsets.extend(batch.read_offsets) - all_input_ids.extend(batch.all_input_ids) - next_token_choosers.extend(batch.next_token_choosers) - stopping_criterias.extend(batch.stopping_criterias) - top_n_tokens.extend(batch.top_n_tokens) - - if i == 0: - requests_idx_mapping = batch.requests_idx_mapping - else: - # We need to offset the mapping for each batch by the cumulative batch size - for k, v in batch.requests_idx_mapping.items(): - requests_idx_mapping[k] = v + start_index - - # Slicing end index for this batch - end_index = start_index + len(batch) - - # We only concatenate batches that did at least one step - if batch.past_key_values is None: - raise ValueError("only concatenate prefilled batches") - - # Create empty tensor - # input_ids is always of shape [batch_size, 1] - # We do not need to pad it - if input_ids is None: - input_ids = batch.input_ids.new_empty((total_batch_size, 1)) - # Copy to correct indices - input_ids[start_index:end_index] = batch.input_ids - - # Create padded tensor - if attention_mask is None: - attention_mask = batch.attention_mask.new_zeros( - 
(total_batch_size, max_input_length + padding_right_offset), - ) - - if top_n_tokens_tensor is None: - top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros( - total_batch_size, - ) - top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor - - # We need to slice the attention mask to remove padding from previous steps - # and to remove unused allocated space - left_offset = max_input_length - batch.max_input_length - batch_left_offset = ( - batch.attention_mask.shape[1] - - batch.max_input_length - - batch.padding_right_offset - ) - attention_mask[ - start_index:end_index, - left_offset:-padding_right_offset, - ] = batch.attention_mask[ - :, - batch_left_offset : -batch.padding_right_offset, - ] - - # Create empty tensor - # position_ids is always of shape [batch_size, 1] - if position_ids is None: - position_ids = batch.position_ids.new_empty((total_batch_size, 1)) - position_ids[start_index:end_index] = batch.position_ids - - # Shenanigans to get dimensions because BLOOM outputs a past with a different shape - # BLOOM Keys: [batch_size * num_heads, head_dim, seq_length] - # BLOOM Values: [batch_size * num_heads, seq_length, head_dim] - # And ensure that we can update tensors in-place - if isinstance(batch.past_key_values[0], tuple): - batch.past_key_values = [ - [t.view(len(batch), -1, *t.shape[-2:]) for t in layer] - for layer in batch.past_key_values - ] - elif len(batch.past_key_values[0][0].shape) == 3: - for layer in batch.past_key_values: - for k, t in enumerate(layer): - layer[k] = t.view(len(batch), -1, *t.shape[-2:]) - - # Add eventual padding tokens that were added while concatenating - max_tokens += batch.max_tokens + ( - max_input_length - batch.max_input_length - ) * len(batch) - - start_index = end_index - - first_past_kvs = batches[0].past_key_values - _, num_heads, padded_sequence_length, head_dim = first_past_kvs[0][1].shape - - padded_past_values_shape = ( - total_batch_size, - num_heads, - max_input_length - 1, - head_dim, - ) - - if batches[0].keys_head_dim_last: - padded_past_keys_shape = padded_past_values_shape - else: - # seq_length is last for BLOOM - padded_past_keys_shape = ( - total_batch_size, - num_heads, - head_dim, - max_input_length - 1, - ) - - # Iterate over attention layers - # Concatenate past key values layer by layer to allow incremental garbage collection - for j in range(len(first_past_kvs)): - padded_past_keys = first_past_kvs[j][0].new_zeros(padded_past_keys_shape) - start_index = 0 - for batch in batches: - past_keys = batch.past_key_values[j][0] - # Clear reference to the original tensor - batch.past_key_values[j][0] = None - - # Slicing end index for this batch - end_index = start_index + len(batch) - # We slice the keys to remove the padding from previous batches - past_seq_len = batch.max_input_length - 1 - if batch.keys_head_dim_last: - padded_past_keys[start_index:end_index, :, -past_seq_len:, :] = ( - past_keys[:, :, -past_seq_len:, :] - ) - else: - # BLOOM case - padded_past_keys[start_index:end_index, :, :, -past_seq_len:] = ( - past_keys[:, :, :, -past_seq_len:] - ) - del past_keys - - start_index = end_index - - padded_past_values = first_past_kvs[j][1].new_zeros( - padded_past_values_shape - ) - start_index = 0 - for batch in batches: - past_values = batch.past_key_values[j][1] - # Clear reference to the original tensor - batch.past_key_values[j][1] = None - - # Slicing end index for this batch - end_index = start_index + len(batch) - # We slice the past values to remove the padding from previous batches - 
past_seq_len = batch.max_input_length - 1 - padded_past_values[start_index:end_index, :, -past_seq_len:, :] = ( - past_values[:, :, -past_seq_len:, :] - ) - del past_values + def concatenate(cls, batches: List["CausalLMBatch"], pad_token_id: int = 0, is_warmup: bool = False) -> "CausalLMBatch": + return cls.recombine(batches, pad_token_id, is_warmup) - # Update values - start_index = end_index + def __len__(self): + return len(self.requests) - past_key_values.append([padded_past_keys, padded_past_values]) + @property + def max_input_length(self): + return max(req.input_length for req in self.requests) - return cls( - batch_id=batches[0].batch_id, - requests=requests, - requests_idx_mapping=requests_idx_mapping, - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - all_input_ids=all_input_ids, - input_lengths=input_lengths, - prefix_offsets=prefix_offsets, - read_offsets=read_offsets, - next_token_choosers=next_token_choosers, - stopping_criterias=stopping_criterias, - top_n_tokens=top_n_tokens, - top_n_tokens_tensor=top_n_tokens_tensor, - max_input_length=max_input_length, - padding_right_offset=padding_right_offset, - keys_head_dim_last=batches[0].keys_head_dim_last, - max_tokens=max_tokens, - ) + @property + def batch_size(self): + return self.attention_mask.size(0) - def __len__(self): - return len(self.requests) + @property + def seq_length(self): + return self.attention_mask.size(1) + @property + def right_padding(self): + return self.seq_length - self.input_length -@dataclass -class CausalLMBatchKeysLast(CausalLMBatch): - keys_head_dim_last: bool = False + # Maximum number of tokens this batch will grow to + @property + def max_tokens(self): + max_total_tokens = self.attention_mask.size(1) + return len(self.requests) * max_total_tokens class CausalLM(Model): @@ -510,256 +639,462 @@ def __init__( tokenizer_class=AutoTokenizer, config_class=AutoConfig, batch_class=CausalLMBatch, + ): + + if speculator: + raise RuntimeError("Speculator decoding is not enabled for AutoModel") + + self.prev_bs = 0 self.quantize = quantize - self.batch_class = batch_class - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = default_dtype if dtype is None else dtype - elif SYSTEM == "ipex": - if hasattr(torch, "xpu") and torch.xpu.is_available(): - device = torch.device(f"xpu:{rank}") - dtype = default_dtype if dtype is None else dtype - else: - device = torch.device("cpu") - # Float16 doesn't exist on target. 
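
With every batch padded to a fixed bucket, the bookkeeping in the new right_padding / max_tokens properties above reduces to simple differences between the allocated width and the positions already consumed; generate_token later asserts that right_padding stays positive. As plain arithmetic with assumed sizes:

# Sketch of the fixed-shape bookkeeping exposed by the properties above:
# seq_length is the allocated width of attention_mask, input_length how many
# positions are already filled, and their difference is the room left.
num_requests = 4           # live requests (the padded batch may hold more rows)
seq_length = 2048          # allocated width, i.e. the chosen bucket
input_length = 1900        # prompt + tokens generated so far

right_padding = seq_length - input_length      # 148 decode steps still fit
max_tokens = num_requests * seq_length         # 8192-token budget for the batch
assert right_padding > 0, "No more room for next token!"
print(right_padding, max_tokens)               # -> 148 8192
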
- dtype = torch.bfloat16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.float32 if dtype is None else dtype - tokenizer = tokenizer_class.from_pretrained( + # Create tokenizer + tokenizer = AutoTokenizer.from_pretrained( model_id, revision=revision, padding_side="left", truncation_side="left", trust_remote_code=trust_remote_code, ) + make_tokenizer_optional(tokenizer) - config = config_class.from_pretrained( - model_id, - revision=revision, - trust_remote_code=trust_remote_code, - ) - config.quantize = quantize - config.speculator = speculator - if tokenizer.pad_token_id is None: - if config.pad_token_id is not None: - tokenizer.pad_token_id = config.pad_token_id - elif config.eos_token_id is not None: - tokenizer.pad_token_id = config.eos_token_id - elif tokenizer.eos_token_id is not None: - tokenizer.pad_token_id = tokenizer.eos_token_id - - torch.distributed.barrier(group=self.process_group) - weights_loader = get_loader( - quantize=quantize, model_id=model_id, revision=revision - ) - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights( - filenames, - device=device, - dtype=dtype, - process_group=self.process_group, - weights_loader=weights_loader, - ) + # Create model + world_size = int(os.getenv("WORLD_SIZE", "1")) + rank = int(os.getenv("RANK", "0")) + dtype = torch.bfloat16 if dtype is None else dtype + device = torch.device("hpu") - prefix = "" - model = model_class(prefix, config, weights) - - torch.distributed.barrier(group=self.process_group) - super().__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - requires_padding=True, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) + if hq_env.is_quantization_enabled: + htorch.core.hpu_set_env() - @classmethod - def fallback( - cls, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - if speculator: - raise RuntimeError("Speculator decoding is not enabled for AutoModel") + if world_size > 1: + model = self.get_deepspeed_model( + model_id, dtype, revision + ) + model = self.prepare_model_for_quantization(model) + else: + get_repo_root(model_id) - if torch.cuda.is_available(): - device = torch.device("cuda") - dtype = torch.float16 if dtype is None else dtype + # Check support for rope scaling + model_kwargs = {} + config = AutoConfig.from_pretrained( + model_id + ) + if hasattr(config, "rope_scaling"): + model_kwargs["rope_scaling"] = self.get_rope_scaling() + + model = AutoModelForCausalLM.from_pretrained( + model_id, + revision=revision, + torch_dtype=dtype, + trust_remote_code=trust_remote_code, + **model_kwargs + ) + model = self.prepare_model_for_quantization(model) + model = model.eval().to(device) + + self.enable_hpu_graph = os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true" and LAZY_MODE == 1 + self.limit_hpu_graph = os.getenv("LIMIT_HPU_GRAPH", "false").lower() == "true" + model = remove_kv_cache_from_output(model) + if self.enable_hpu_graph: + from habana_frameworks.torch.hpu import wrap_in_hpu_graph + model = wrap_in_hpu_graph(model, disable_tensor_cache=True) else: - if quantize: - raise ValueError("quantization is not available on CPU") + if LAZY_MODE == 0: + # It is said that "keep_input_mutations" is safe for inference to be done + dbg_trace( + "TORCH COMPILE", f'Torch compiling of model') + model.model = torch.compile(model.model, 
backend="hpu_backend", options={"keep_input_mutations": True}) - device = torch.device("cpu") - dtype = torch.float32 if dtype is None else dtype + model = self.setup_quantization(model) - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - model = AutoModelForCausalLM.from_pretrained( - model_id, - revision=revision, - torch_dtype=dtype, - device_map=( - "auto" - if torch.cuda.is_available() and torch.cuda.device_count() > 1 - else None - ), - load_in_8bit=quantize == "bitsandbytes", - trust_remote_code=trust_remote_code, - ) - if ( - torch.cuda.is_available() - and torch.cuda.device_count() == 1 - and quantize != "bitsandbytes" - ): - model = model.cuda() + if model.config.model_type not in MODELS_OPTIMIZED_WITH_STATIC_SHAPES: + raise ValueError(f"Model type {model.config.model_type} is not supported!") if tokenizer.pad_token_id is None: if model.config.pad_token_id is not None: tokenizer.pad_token_id = model.config.pad_token_id elif model.config.eos_token_id is not None: - tokenizer.pad_token_id = model.config.eos_token_id + if isinstance(model.config.eos_token_id, int): + tokenizer.pad_token_id = model.config.eos_token_id + elif isinstance(model.config.eos_token_id, list): + tokenizer.pad_token_id = model.config.eos_token_id[0] + else: + raise ValueError( + f"{type(model.config.eos_token_id)} type of eos_token_id in the model's config is not supported for tokenizer.pad_token_id" + ) elif tokenizer.eos_token_id is not None: tokenizer.pad_token_id = tokenizer.eos_token_id else: tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - self = cls.__new__( - cls, - ) - self.batch_class = CausalLMBatch - super().__init__( - self, + self.kwargs = { + "use_cache": True, + "return_dict": True, + } + + if model.config.model_type in ["llama", "mistral", "starcoder2", "qwen2"]: + + if model.config.model_type in ["llama", "mistral", "qwen2"]: + self.kwargs["attn_softmax_bf16"] = True + self.kwargs["trim_logits"] = True + + if os.getenv("USE_FLASH_ATTENTION", "false").lower() == "true": + self.kwargs["use_flash_attention"] = True + if os.getenv("FLASH_ATTENTION_RECOMPUTE", "false").lower() == "true": + self.kwargs["flash_attention_recompute"] = True + + self.speculate = get_speculate() + + super(CausalLM, self).__init__( model_id=model_id, model=model, tokenizer=tokenizer, requires_padding=True, dtype=dtype, device=device, + rank=rank, ) - self.quantize = quantize - return self + + # Create profiler + ranks_to_profile = [int(val) for val in os.getenv("PROF_RANKS", "0").split(',')] + record_shapes = os.getenv("PROF_RECORD_SHAPES", "false").lower() == "true" + output_dir = os.getenv("PROF_PATH", "/tmp/hpu_profile") + self.profiling_warmup_steps = int(os.getenv("PROF_WARMUPSTEP", "0")) if rank in ranks_to_profile else 0 + self.profiling_steps = int(os.getenv("PROF_STEP", "0")) if rank in ranks_to_profile else 0 + self.profiling_wait_steps = int(os.getenv("PROF_WAITSTEP", "0")) + if self.profiling_steps > 0: + self.hb_profiler = HabanaProfile( + wait=self.profiling_wait_steps, + warmup=self.profiling_warmup_steps, + active=self.profiling_steps, + output_dir=output_dir, + record_shapes=record_shapes + ) + self.hb_profiler.start() + else: + self.hb_profiler = None + self.step = 0 + + def get_deepspeed_model( + self, + model_id: str, + dtype: torch.dtype, + revision: Optional[str] = None + ) -> torch.nn.Module: + import deepspeed + from habana_frameworks.torch.distributed.hccl import 
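
The per-forward generation kwargs are likewise assembled from environment flags and the model type; the sketch below reflects my reading of the flag nesting above (flash-attention flags only apply to the listed model families) and is illustrative rather than the served code path:

import os

# Sketch of how the per-forward kwargs are assembled from the env flags above;
# the resulting dict is later merged into every self.model.forward(...) call.
def build_forward_kwargs(model_type: str) -> dict:
    kwargs = {"use_cache": True, "return_dict": True}
    if model_type in ["llama", "mistral", "starcoder2", "qwen2"]:
        if model_type in ["llama", "mistral", "qwen2"]:
            kwargs["attn_softmax_bf16"] = True
            kwargs["trim_logits"] = True
        if os.getenv("USE_FLASH_ATTENTION", "false").lower() == "true":
            kwargs["use_flash_attention"] = True
        if os.getenv("FLASH_ATTENTION_RECOMPUTE", "false").lower() == "true":
            kwargs["flash_attention_recompute"] = True
    return kwargs


os.environ["USE_FLASH_ATTENTION"] = "true"   # example value, not a default
print(build_forward_kwargs("llama"))
# -> {'use_cache': True, 'return_dict': True, 'attn_softmax_bf16': True,
#     'trim_logits': True, 'use_flash_attention': True}
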
initialize_distributed_hpu + + world_size, rank, local_rank = initialize_distributed_hpu() + model_kwargs = { + "revision": revision + } + + # Initialize process(es) for DeepSpeed + deepspeed.init_distributed(dist_backend="hccl") + logger.info( + "DeepSpeed is enabled. world_size {} rank {} local_rank {}".format(world_size, rank, local_rank) + ) + config = AutoConfig.from_pretrained(model_id, **model_kwargs) + load_to_meta = model_on_meta(config) + + # Check support for rope scaling + if hasattr(config, "rope_scaling"): + config.rope_scaling = self.get_rope_scaling() + model_kwargs["rope_scaling"] = self.get_rope_scaling() + + if load_to_meta: + # Construct model with fake meta tensors, later will be replaced on devices during ds-inference ckpt load + with deepspeed.OnDevice(dtype=dtype, device="meta"): + model = AutoModelForCausalLM.from_config(config, torch_dtype=dtype) + else: + get_repo_root(model_id, local_rank=os.getenv("LOCAL_RANK")) + # TODO: revisit placement on CPU when auto-injection is possible + with deepspeed.OnDevice(dtype=dtype, device="cpu"): + model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, **model_kwargs) + model = model.eval() + + # Initialize the model + ds_inference_kwargs = {"dtype": dtype} + ds_inference_kwargs["tensor_parallel"] = {"tp_size": world_size} + ds_inference_kwargs["enable_cuda_graph"] = False + + if load_to_meta: + # model loaded to meta is managed differently + checkpoints_json = tempfile.NamedTemporaryFile(suffix=".json", mode="+w") + write_checkpoints_json(model_id, local_rank, checkpoints_json) + ds_inference_kwargs["checkpoint"] = checkpoints_json.name + model = deepspeed.init_inference(model, **ds_inference_kwargs) + + return model.module + + def get_rope_scaling(self) -> Optional[Dict]: + rope_scaling = os.getenv("ROPE_SCALING", None) + if rope_scaling is None: + return None + + rope_factor = float(os.getenv("ROPE_FACTOR", 1.0)) + return { + 'type': rope_scaling, 'factor': float(rope_factor) + } + + def setup_quantization(self, model): + if hq_env.is_quantization_enabled: + htorch.core.quantization._mark_params_as_const(model) + htorch.core.quantization._check_params_as_const(model) + htorch.core.hpu_initialize(model) + return model + + def prepare_model_for_quantization(self, model): + if hq_env.is_quantization_enabled: + if model.config.model_type == "llama": + self.patch_scoped_linear_all_reduce(model) + model = hq_env.prepare_model_for_quantization(model) + return model + + def patch_scoped_linear_all_reduce(self, model): + from deepspeed.module_inject.layers import LinearAllreduce + from optimum.habana.transformers.models.modeling_all_models import ScopedLinearAllReduce + for name, module in model.named_children(): + if type(module) is LinearAllreduce: + SL = ScopedLinearAllReduce(mod=module) + setattr(model, name, SL) + self.patch_scoped_linear_all_reduce(module) @property def batch_type(self) -> Type[CausalLMBatch]: - return self.batch_class + return CausalLMBatch + + def decode(self, generated_ids: List[int]) -> str: + return self.tokenizer.decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False) + + def decode_token( + self, + all_input_ids: List[int], + prefix_offset: int = 0, + read_offset: int = 0, + ) -> Tuple[str, int, int]: + if is_tokenizer_transparent(self.tokenizer): + new_text = self.tokenizer.decode(all_input_ids[read_offset:], skip_special_tokens=False) + return new_text, read_offset, len(all_input_ids) + else: + return super().decode_token(all_input_ids, prefix_offset, 
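
Rope scaling is configured purely through environment variables: ROPE_SCALING selects the scaling type and ROPE_FACTOR the multiplier, and the resulting dict is forwarded as model_kwargs["rope_scaling"] / config.rope_scaling. A minimal standalone sketch of that translation (the values set below are examples, not defaults):

import os
from typing import Dict, Optional

# Sketch of the env-driven rope scaling plumbing in get_rope_scaling above.
def get_rope_scaling() -> Optional[Dict]:
    rope_scaling = os.getenv("ROPE_SCALING", None)
    if rope_scaling is None:
        return None
    rope_factor = float(os.getenv("ROPE_FACTOR", 1.0))
    return {"type": rope_scaling, "factor": rope_factor}


os.environ["ROPE_SCALING"] = "linear"   # example values, not defaults
os.environ["ROPE_FACTOR"] = "2.0"
print(get_rope_scaling())               # -> {'type': 'linear', 'factor': 2.0}
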
read_offset) def forward( - self, input_ids, attention_mask, position_ids, past_key_values: Optional = None - ) -> Tuple[ - torch.Tensor, Optional[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]] - ]: + self, + input_ids, + attention_mask, + position_ids, + token_idx, + past_key_values: Optional[List[Tuple]] = None, + bypass_hpu_graph: Optional[bool] = None, + ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: # Model Forward kwargs = { "input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values, - "use_cache": True, - "return_dict": True, + "token_idx": token_idx, } + + # Optimum Habana got "lazy_mode" key-val only supported for llama type of models + if self.model.config.model_type == "llama" : + kwargs["lazy_mode"] = LAZY_MODE == 1 + if self.has_position_ids: kwargs["position_ids"] = position_ids - outputs = self.model.forward(**kwargs) - if isinstance(outputs, tuple): - outputs, speculative_logits = outputs + if bypass_hpu_graph != None: + kwargs["bypass_hpu_graphs"] = bypass_hpu_graph + + kwargs.update(self.kwargs) + if past_key_values is not None: + return self.model.forward(**kwargs) else: - speculative_logits = None - return outputs.logits, speculative_logits, outputs.past_key_values + outputs = self.model.forward(**kwargs) + return outputs.logits, outputs.past_key_values @tracer.start_as_current_span("generate_token") def generate_token( - self, batch: CausalLMBatch + self, batches: List[CausalLMBatch], is_warmup: bool = False ) -> Tuple[List[Generation], Optional[CausalLMBatch], Tuple[int, int]]: start = time.time_ns() - # slice the attention mask to the correct shape - attention_mask = batch.attention_mask[:, : -batch.padding_right_offset] - - logits, speculative_logits, past = self.forward( - batch.input_ids, - attention_mask, - batch.position_ids, - batch.past_key_values, - ) - # Results generations: List[Generation] = [] - stopped = True - - # Speculation is not active for causal - accepted_ids = torch.ones_like(batch.input_ids)[:, 0] - batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens( - batch.top_n_tokens, - batch.top_n_tokens_tensor, - torch.log_softmax(logits[:, -1], -1), - accepted_ids, - ) + prev_batches = [] + requests_to_generate = [] + # In order to pipeline any actions on CPU we perform the operation in 3 main stages: + # Stage 1. 
Collect next token ids of any previously started generations + for batch_id, batch in enumerate(batches): + if batch.logits is not None: + logits = batch.logits + past = batch.past + prefill = batch.past_key_values is None + if prefill: + # no right padding for prefill + token_idx_scalar = batch.attention_mask.shape[-1] - 1 + token_idx = torch.tensor(token_idx_scalar).to(self.device) + else: + token_idx_scalar = batch.attention_mask.shape[-1] - batch.right_padding + token_idx = torch.tensor(token_idx_scalar).to(self.device) + + # Select next token + input_length = batch.input_length + if logits.shape[-2] > 1: + next_token_ids, next_token_logprobs, logprobs, _, _ = batch.next_token_chooser( + batch.input_ids, logits[:, input_length - 1: input_length, :].squeeze(-2), self.speculate + ) + else: + next_token_ids, next_token_logprobs, logprobs, _, _ = batch.next_token_chooser( + batch.input_ids, logits.squeeze(-2), self.speculate + ) + # Speculation is not active for causal + accepted_ids = torch.ones_like(batch.input_ids)[:, 0] + batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens( + batch.top_n_tokens, + batch.top_n_tokens_tensor, + logprobs, + accepted_ids, + ) - start_decode = time.time_ns() + prev_batches.append({ + 'next_token_ids': next_token_ids, + 'next_token_logprobs': next_token_logprobs, + }) - # Zipped iterator - iterator = zip( - batch.requests, - batch.input_lengths, - batch.prefix_offsets, - batch.read_offsets, - logits, - batch.next_token_choosers, - batch.stopping_criterias, - batch.all_input_ids, - batch.top_n_tokens, - batch_top_token_ids, - batch_top_token_logprobs, - ) + for req_idx, req in enumerate(batch.requests): + requests_to_generate.append({ + 'req': req, + 'prev_req_idx': req.idx, + 'batch_id': batch_id, + 'seed': batch.next_token_chooser.seeds[req_idx], + 'do_sample': batch.next_token_chooser.do_sample[req_idx], + 'top_n_tokens': batch.top_n_tokens[req_idx], + 'top_token_ids': batch_top_token_ids[req_idx], + 'top_token_logprobs': batch_top_token_logprobs[req_idx], + 'grammar_state': batch.next_token_chooser.fsm_grammar_states[req_idx], + + }) + + htorch.core.mark_step() + + # Add new token into input_ids + batch.input_ids.index_copy_(1, token_idx, next_token_ids.unsqueeze(1)) + + # Update attention_mask as we added a new token to input_ids + batch.attention_mask.index_fill_(1, token_idx, 1) + + # Adjust lengths + batch.input_length += 1 + + # Update position_ids + if prefill: + batch.position_ids = torch.index_select(batch.position_ids, 1, token_idx - 1) + 1 + else: + batch.position_ids += 1 + # Update past key values + if prefill: + batch.past_key_values = past + + htorch.core.mark_step() - # For each member of the batch - for i, ( - request, - input_length, - prefix_offset, - read_offset, - logits, - next_token_chooser, - stopping_criteria, - all_input_ids, - top_n_tokens, - top_token_ids, - top_token_logprobs, - ) in enumerate(iterator): - # Select next token - next_token_id, logprobs = next_token_chooser( - all_input_ids.view(1, -1), logits[-1:, :] + # Stage 2. 
Prepare new batch for speculative scheduling + if len(batches) > 1: + batch = self.batch_type.concatenate(batches, self.tokenizer.pad_token_id, is_warmup) + else: + batch = batches[0] + + prefill = batch.past_key_values is None + + # Check if we need to do any bookkeeping first + if not prefill: + batch = batch.__class__.recombine([batch], self.tokenizer.pad_token_id, is_warmup) + + scenario = 'PREFILL' if prefill else 'GENERATE' + if self.enable_hpu_graph and self.limit_hpu_graph and round_up(DECODE_WARMUP_BATCH_SIZE_LIST, batch.batch_size) != self.prev_bs: + self.model.clear_cache() + self.prev_bs = round_up(DECODE_WARMUP_BATCH_SIZE_LIST, batch.batch_size) + dbg_trace( + scenario, f'bs:{batch.batch_size} num_reqs:{len(batch.requests)} seq_len:{batch.seq_length} padding:{batch.right_padding}') + assert batch.right_padding > 0, 'No more room for next token!' + + # Execute batch + if prefill: + # no right padding for prefill + token_idx = torch.tensor(batch.attention_mask.shape[-1] - 1).to(self.device) + batch.logits, batch.past = self.forward( + batch.input_ids, + batch.attention_mask, + batch.position_ids, + token_idx, + batch.past_key_values, + bypass_hpu_graph=prefill and self.limit_hpu_graph if self.enable_hpu_graph else None, + ) + elif all([req.stopping_criteria.max_new_tokens == 1 for req in batch.requests]): + # Don't schedule next forward if max_new_tokens for all requests equals 1 + # - we've already generated the first and only needed token in the prefill phase + pass + else: + token_idx = torch.tensor(batch.attention_mask.shape[-1] - batch.right_padding).to(self.device) + input_ids = torch.index_select(batch.input_ids, 1, token_idx - 1) + batch.logits = self.forward( + input_ids, + batch.attention_mask, + batch.position_ids, + token_idx, + batch.past_key_values, + bypass_hpu_graph=prefill and self.limit_hpu_graph if self.enable_hpu_graph else None, ) + htorch.core.mark_step() + + start_decode = time.time_ns() + + # Stage 3. 
Finish and return previous generations + stopped = len(requests_to_generate) > 0 + for prev_batch in prev_batches: + prev_batch['next_token_logprobs'] = prev_batch['next_token_logprobs'].tolist() + prev_batch['next_token_ids_cpu'] = prev_batch['next_token_ids'].cpu() + htorch.core.mark_step() + + for req_data in requests_to_generate: + req = req_data['req'] + i = req_data['prev_req_idx'] + prev_batch_id = req_data['batch_id'] + assert len(prev_batches) > prev_batch_id + next_token_ids_cpu = prev_batches[prev_batch_id]['next_token_ids_cpu'] + next_token_logprobs = prev_batches[prev_batch_id]['next_token_logprobs'] + + request = req.data + input_length = req.input_length + prefix_offset = req.prefix_offset + read_offset = req.read_offset + do_sample = req_data['do_sample'] + seed = req_data['seed'] + stopping_criteria = req.stopping_criteria + all_input_ids = req.all_input_ids + next_token_id = next_token_ids_cpu[i] + next_token_logprob = next_token_logprobs[i] + top_n_tokens = req_data['top_n_tokens'] + top_token_ids = req_data['top_token_ids'] + top_token_logprobs = req_data['top_token_logprobs'] + grammar_state = req_data['grammar_state'] + # Append next token to all tokens - all_input_ids = torch.cat([all_input_ids, next_token_id]) + all_input_ids[input_length] = next_token_id new_input_length = input_length + 1 # Generated token - next_token_logprob = logprobs[-1, next_token_id] - next_token_id_squeezed = next_token_id.squeeze() - next_token_text, prefix_offset, read_offset = self.decode_token( - all_input_ids[:, 0], prefix_offset, read_offset - ) + if is_tokenizer_transparent(self.tokenizer) and len(stopping_criteria.stop_sequence_criterias) == 0: + next_token_text = '' + else: + next_token_text, prefix_offset, read_offset = self.decode_token( + all_input_ids[0:new_input_length, 0], prefix_offset, read_offset + ) # Evaluate stopping criteria stop, reason = stopping_criteria( - next_token_id_squeezed, + next_token_id, next_token_text, ) @@ -771,23 +1106,17 @@ def generate_token( if i % self.world_size == self.rank: if stop: # Decode generated tokens - output_text, _, _ = self.decode_token( - all_input_ids[:, 0], - prefix_offset=len(all_input_ids) - - stopping_criteria.current_tokens - - 1, - read_offset=len(all_input_ids) - - stopping_criteria.current_tokens, - skip_special_tokens=True, - ) - # Get seed - if isinstance(next_token_chooser.choice, Sampling): - seed = next_token_chooser.choice.seed + if is_tokenizer_transparent(self.tokenizer): + output_text = None else: - seed = None - + output_text = self.decode( + all_input_ids[new_input_length - stopping_criteria.current_tokens: new_input_length, 0] + ) generated_text = GeneratedText( - output_text, stopping_criteria.current_tokens, reason, seed + output_text, + stopping_criteria.current_tokens, + reason, + seed if do_sample else None, ) else: generated_text = None @@ -795,12 +1124,8 @@ def generate_token( # Prefill if stopping_criteria.current_tokens == 1 and request.prefill_logprobs: # Remove generated token to only have prefill and add nan for first prompt token - prefill_logprobs = [float("nan")] + torch.log_softmax( - logits, -1 - ).gather(1, all_input_ids[1:]).squeeze(1)[ - -new_input_length:-1 - ].tolist() - prefill_token_ids = all_input_ids[-new_input_length:-1] + prefill_logprobs = [float("nan")] + next_token_logprobs + prefill_token_ids = all_input_ids[0: new_input_length - 1] prefill_texts = self.tokenizer.batch_decode( prefill_token_ids, clean_up_tokenization_spaces=False, @@ -844,10 +1169,10 @@ def generate_token( 
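
The three-stage structure exists so the host never waits idly on the accelerator: token ids chosen after the previous forward pass are read back and post-processed while the next forward pass is already queued. Stripped of the HPU specifics, the scheduling idea looks roughly like the schematic below (launch_forward / finish_previous are hypothetical stand-ins, not real APIs):

# Schematic of the 3-stage loop in generate_token: postprocessing of step N
# runs on the host while the forward pass for step N+1 is already queued.
def generate_step(state, launch_forward, finish_previous):
    # Stage 1: read back what the previous forward pass produced (if any).
    previous = state.pop("pending", None)

    # Stage 2: queue the next forward pass before doing any host-side work.
    state["pending"] = launch_forward(state)

    # Stage 3: finish the previous step (detokenize, stopping criteria, ...)
    # while the device is busy with the forward launched in stage 2.
    generations = finish_previous(previous) if previous is not None else []
    return generations, state


state = {}
for step in range(3):
    gens, state = generate_step(
        state,
        launch_forward=lambda s: {"next_token_id": 42},             # stand-in
        finish_previous=lambda p: [f"token {p['next_token_id']}"],  # stand-in
    )
    print(step, gens)   # step 0 yields nothing yet; later steps emit step N-1's token
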
request.id, prefill_tokens, Tokens( - [next_token_id_squeezed], + [next_token_id], [next_token_logprob], [next_token_text], - [next_token_id_squeezed.item() in self.all_special_ids], + [next_token_id in self.all_special_ids], ), generated_text, top_tokens, @@ -855,37 +1180,150 @@ def generate_token( generations.append(generation) - # Update values - batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar( - next_token_id_squeezed.item() + batch.next_token_chooser = ( + batch.next_token_chooser.advance_grammar_single_with_past_state( + req.idx, next_token_id, grammar_state + ) ) - batch.input_ids[i, 0] = next_token_id - batch.all_input_ids[i] = all_input_ids - batch.input_lengths[i] = new_input_length - batch.prefix_offsets[i] = prefix_offset - batch.read_offsets[i] = read_offset - batch.max_input_length = max(batch.max_input_length, new_input_length) - # We finished all generations in the batch; there is no next batch - if stopped: - forward_ns = start_decode - start - decode_ns = time.time_ns() - start_decode - return generations, None, (forward_ns, decode_ns) + req.all_input_ids = all_input_ids + req.input_length = new_input_length + req.prefix_offset = prefix_offset + req.read_offset = read_offset - # Slice unused values from prefill - batch.input_ids = batch.input_ids[:, :1] + htorch.core.mark_step() + self.step = self.step + 1 + if self.hb_profiler is not None: + if self.step > self.profiling_wait_steps + self.profiling_warmup_steps + self.profiling_steps: + self.hb_profiler.stop() + else: + self.hb_profiler.step() - # Update attention_mask as we added a new token to input_ids - batch.attention_mask[:, -batch.padding_right_offset] = 1 - # Decrease right offset - batch.padding_right_offset -= 1 + forward_ns = start_decode - start + decode_ns = time.time_ns() - start_decode + return generations, batch if not stopped else None, (forward_ns, decode_ns) - # Update position_ids - batch.position_ids = batch.position_ids[:, -1:] + 1 + def generate_warmup_batch(self, request, seq_len, batch_size, is_warmup): + batch = copy.deepcopy(request.batch) + for req in batch.requests: + req.truncate = seq_len - # Update past key values - batch.past_key_values = past + for i in range(len(batch.requests) - batch_size): + batch.requests.pop() - forward_ns = start_decode - start - decode_ns = time.time_ns() - start_decode - return generations, batch, (forward_ns, decode_ns) + return CausalLMBatch.from_pb(batch, self.tokenizer, self.dtype, self.device, is_warmup) + + + def warmup(self, request) -> None: + is_warmup = True + batch = CausalLMBatch.from_pb(request.batch, self.tokenizer, self.dtype, self.device, is_warmup = is_warmup) + try: + # max prefill batch size warmup + _, prefill_batch, _ = self.generate_token([batch], is_warmup) + except: + raise RuntimeError( + f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. 
" + f"You need to decrease `--max-batch-prefill-tokens`" + ) + + global MAX_TOTAL_TOKENS, MAX_BATCH_TOTAL_TOKENS, PREFILL_WARMUP_BATCH_SIZE_LIST, PREFILL_WARMUP_SEQLEN_LIST, DECODE_WARMUP_BATCH_SIZE_LIST + max_input_length = batch.input_ids.shape[1] + max_prefill_batch_size = batch.input_ids.shape[0] + PREFILL_WARMUP_BATCH_SIZE_LIST = [] + batch_size = 1 + while batch_size <= max_prefill_batch_size: + PREFILL_WARMUP_BATCH_SIZE_LIST.append(batch_size) + batch_size = batch_size * 2 + if PREFILL_WARMUP_BATCH_SIZE_LIST[-1] < max_prefill_batch_size : + PREFILL_WARMUP_BATCH_SIZE_LIST.append(max_prefill_batch_size) + + seq_len = PAD_SEQUENCE_TO_MULTIPLE_OF + PREFILL_WARMUP_SEQLEN_LIST = [] + i = 0 + while seq_len <= max_input_length: + PREFILL_WARMUP_SEQLEN_LIST.append(seq_len) + seq_len += PAD_SEQUENCE_TO_MULTIPLE_OF*(2**i) + i += 1 + if PREFILL_WARMUP_SEQLEN_LIST[-1] < max_input_length: + PREFILL_WARMUP_SEQLEN_LIST.append(max_input_length) + + #Prefill and decode warmup + DECODE_WARMUP_BATCH_SIZE_LIST = [] + prefill_batch = None + decode_batch = None + try: + for batch_size in PREFILL_WARMUP_BATCH_SIZE_LIST : + for seq_len in PREFILL_WARMUP_SEQLEN_LIST : + batch = self.generate_warmup_batch(request, seq_len, batch_size, is_warmup) + _, prefill_batch, _ = self.generate_token([batch], is_warmup) + _, decode_batch, _ = self.generate_token([prefill_batch], is_warmup) + + DECODE_WARMUP_BATCH_SIZE_LIST.append(batch_size) + + except: + raise RuntimeError( + f"Not enough memory to handle following prefill and decode warmup." + f"Prefill batch size list:{PREFILL_WARMUP_BATCH_SIZE_LIST}" + f"Prefill sequence length list:{PREFILL_WARMUP_SEQLEN_LIST}" + f"Decode batch size list:{DECODE_WARMUP_BATCH_SIZE_LIST}" + f"You need to decrease `--max-batch-prefill-tokens`" + ) + + mem_stats = get_hpu_memory_stats(self.device) + logger.info( + f"\nFollowing prefill and decode warmup successfully.\n" + f"Prefill batch size list:{PREFILL_WARMUP_BATCH_SIZE_LIST}\n" + f"Prefill sequence length list:{PREFILL_WARMUP_SEQLEN_LIST}\n" + f"Decode batch size list:{DECODE_WARMUP_BATCH_SIZE_LIST}\n" + f"Memory stats: {mem_stats} " + ) + + max_decode_batch_size = math.floor(MAX_BATCH_TOTAL_TOKENS / MAX_TOTAL_TOKENS) + batch_size = max_prefill_batch_size * 2 + # Decode warmup with bigger batch_size + try: + if DECODE_WARMUP_BATCH_SIZE_LIST[-1] < max_decode_batch_size and batch_size <= max_decode_batch_size: + batches = [] + for i in range(int(batch_size/max_prefill_batch_size)) : + batch = self.generate_warmup_batch(request, PREFILL_WARMUP_SEQLEN_LIST[0], DECODE_WARMUP_BATCH_SIZE_LIST[-1], is_warmup) + _, prefill_batch, _ = self.generate_token([batch], is_warmup) + batches.append(prefill_batch) + while batch_size <= max_decode_batch_size: + _, decode_batch, _ = self.generate_token(batches, is_warmup) + DECODE_WARMUP_BATCH_SIZE_LIST.append(batch_size) + batch_size = batch_size * 2 + batches.clear() + + for i in range(int(batch_size/max_prefill_batch_size)) : + batch = self.generate_warmup_batch(request, PREFILL_WARMUP_SEQLEN_LIST[0], DECODE_WARMUP_BATCH_SIZE_LIST[-1], is_warmup) + _, prefill_batch, _ = self.generate_token([batch], is_warmup) + batches.append(prefill_batch) + + batches.clear() + if DECODE_WARMUP_BATCH_SIZE_LIST[-1] < max_decode_batch_size: + max_decode_batch_size = math.floor( max_decode_batch_size / 2) * 2 + batch_size = max_decode_batch_size + for i in range(int(max_decode_batch_size / 2)) : + batch = self.generate_warmup_batch(request, PREFILL_WARMUP_SEQLEN_LIST[0], 2, is_warmup) + _, prefill_batch, _ = 
self.generate_token([batch], is_warmup) + batches.append(prefill_batch) + _, decode_batch, _ = self.generate_token(batches, is_warmup) + DECODE_WARMUP_BATCH_SIZE_LIST.append(max_decode_batch_size) + max_batch_total_tokens = max_decode_batch_size * MAX_TOTAL_TOKENS + MAX_BATCH_TOTAL_TOKENS = max_batch_total_tokens + except : + raise RuntimeError( + f"Not enough memory to handle batch_size({batch_size}) decode warmup." + f"Decode batch size list:{DECODE_WARMUP_BATCH_SIZE_LIST}" + f"max_decode_batch_size is {max_decode_batch_size}" + f"You need to decrease env `MAX_BATCH_TOTAL_TOKENS` or '--max_batch_total_tokens'" + ) + + mem_stats = get_hpu_memory_stats(self.device) + logger.info( + f"\nFollowing decode warmup successfully.\n" + f"Decode batch size list:{DECODE_WARMUP_BATCH_SIZE_LIST}\n" + f"Memory stats: {mem_stats}" + ) + + return MAX_BATCH_TOTAL_TOKENS \ No newline at end of file diff --git a/server/text_generation_server/models/custom_modeling/llava_next.py b/server/text_generation_server/models/custom_modeling/llava_next.py index 32e9d3348b3..1ef55019066 100644 --- a/server/text_generation_server/models/custom_modeling/llava_next.py +++ b/server/text_generation_server/models/custom_modeling/llava_next.py @@ -14,25 +14,18 @@ # limitations under the License. """ PyTorch Llava-NeXT model.""" -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Union import torch import torch.utils.checkpoint from torch import nn from transformers.activations import ACT2FN -from transformers.image_processing_utils import select_best_resolution - -from text_generation_server.layers.attention import Seqlen -from text_generation_server.models.custom_modeling.vlm import ( - load_text_model, - load_vision_model, +from transformers.models.llava_next.modeling_llava_next import ( + unpad_image, ) -from text_generation_server.layers import ( - TensorParallelColumnLinear, - TensorParallelRowLinear, -) - +from optimum.habana.transformers.models import GaudiLlavaNextForConditionalGeneration +from transformers.image_processing_utils import select_best_resolution def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): """ @@ -40,7 +33,7 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): Args: image_size (`tuple`): - The size of the input image in the format (height, width). + The size of the input image in the format (width, height). grid_pinpoints (`List`): A list containing possible resolutions. Each item in the list should be a tuple or list of the form `(height, width)`. @@ -48,7 +41,7 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): The size of each image patch. Returns: - tuple: The shape of the image patch grid in the format (height, width). + tuple: The shape of the image patch grid in the format (width, height). """ if not isinstance(grid_pinpoints, list): raise ValueError("grid_pinpoints should be a list of tuples or lists") @@ -57,100 +50,13 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): return height // patch_size, width // patch_size -def unpad_image(tensor, original_size): - """ - Unpads a PyTorch tensor of a padded and resized image. - - Args: - tensor (`torch.Tensor`): - The image tensor, assumed to be of shape (num_channels, height, width). - original_size (`tuple`): - The original size of the image (height, width). - - Returns: - `torch.Tensor`: The unpadded image tensor. 
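A small self-contained sketch of the warmup bucket construction in warmup() above: prefill batch sizes double until the maximum is reached, and prefill sequence lengths grow by PAD_SEQUENCE_TO_MULTIPLE_OF * 2**i. The helper names and the example maxima are invented for illustration; the list-building logic mirrors the loops in the hunk.

PAD_SEQUENCE_TO_MULTIPLE_OF = 256

def build_prefill_batch_sizes(max_prefill_batch_size: int) -> list:
    # Powers of two up to the max, then the max itself if it is not a power of two.
    sizes, batch_size = [], 1
    while batch_size <= max_prefill_batch_size:
        sizes.append(batch_size)
        batch_size *= 2
    if sizes[-1] < max_prefill_batch_size:
        sizes.append(max_prefill_batch_size)
    return sizes

def build_prefill_seq_lens(max_input_length: int) -> list:
    # Step size doubles on every iteration, then the max itself if needed.
    lens, seq_len, i = [], PAD_SEQUENCE_TO_MULTIPLE_OF, 0
    while seq_len <= max_input_length:
        lens.append(seq_len)
        seq_len += PAD_SEQUENCE_TO_MULTIPLE_OF * (2 ** i)
        i += 1
    if lens[-1] < max_input_length:
        lens.append(max_input_length)
    return lens

print(build_prefill_batch_sizes(6))   # [1, 2, 4, 6]
print(build_prefill_seq_lens(1024))   # [256, 512, 1024]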
- """ - original_height, original_width = original_size - current_height, current_width = tensor.shape[1:] - - original_aspect_ratio = original_width / original_height - current_aspect_ratio = current_width / current_height - - if original_aspect_ratio > current_aspect_ratio: - scale_factor = current_width / original_width - new_height = int(original_height * scale_factor) - padding = (current_height - new_height) // 2 - unpadded_tensor = tensor[:, padding : current_height - padding, :] - else: - scale_factor = current_height / original_height - new_width = int(original_width * scale_factor) - padding = (current_width - new_width) // 2 - unpadded_tensor = tensor[:, :, padding : current_width - padding] - - return unpadded_tensor - - -# Copied from transformers.models.llava.modeling_llava.LlavaMultiModalProjector with Llava->LlavaNext -class LlavaNextMultiModalProjector(nn.Module): - def __init__(self, prefix, config, weights): - super().__init__() - - self.linear_1 = TensorParallelColumnLinear.load( - prefix=f"{prefix}.linear_1", config=config, weights=weights, bias=True - ) - self.act = ACT2FN[config.projector_hidden_act] - self.linear_2 = TensorParallelRowLinear.load( - prefix=f"{prefix}.linear_2", config=config, weights=weights, bias=True - ) - - def forward(self, image_features): - hidden_states = self.linear_1(image_features) - hidden_states = self.act(hidden_states) - hidden_states = self.linear_2(hidden_states) - return hidden_states - - -class LlavaNextForConditionalGeneration(nn.Module): - def __init__(self, prefix, config, weights): - super().__init__() - config.vision_config.quantize = config.quantize - vision_config = config.vision_config - # Instead of selecting in hidden_states[-2]. - # Instead compute only the n -2 + 1 layers and don't pool - if config.vision_feature_layer < 0: - vision_config.num_hidden_layers += config.vision_feature_layer + 1 - else: - vision_config.num_hidden_layers = config.vision_feature_layer + 1 - self.vision_tower = load_vision_model( - prefix="vision_tower" if not prefix else f"{prefix}.vision_tower", - config=config.vision_config, - weights=weights, - ) - - self.multi_modal_projector = LlavaNextMultiModalProjector( - prefix="multi_modal_projector", config=config, weights=weights - ) - - self.image_newline = weights.get_tensor("image_newline") - - self.vocab_size = config.text_config.vocab_size - self.config = config - config.text_config.quantize = config.quantize - config.text_config.speculator = config.speculator - self.text_model = load_text_model( - prefix="language_model" if not prefix else f"{prefix}.language_model", - config=config.text_config, - weights=weights, - ) - self.pad_token_id = ( - config.pad_token_id if config.pad_token_id is not None else -1 - ) - +class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration): + def _merge_input_ids_with_image_features( self, - input_ids: torch.Tensor, inputs_embeds: torch.Tensor, image_features: torch.Tensor, + input_ids: torch.Tensor, ): """In place merges in vision_embeddings with inputs_embeds.""" mask = input_ids == self.config.image_token_index @@ -165,125 +71,226 @@ def _merge_input_ids_with_image_features( def forward( self, - input_ids: torch.Tensor, - position_ids: torch.Tensor, - cu_seqlen_prefill: Optional[torch.Tensor], - kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, - slots: torch.Tensor, - seqlen: Seqlen, - max_s: int, - prefill_cache_indices: Optional[torch.Tensor], - lm_head_indices: Optional[torch.Tensor] = None, + input_ids: 
torch.LongTensor = None, pixel_values: torch.FloatTensor = None, - # Unused for this model - pixel_attention_mask=None, image_sizes: Optional[torch.LongTensor] = None, - adapter_data: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + vision_feature_layer: Optional[int] = None, + vision_feature_select_strategy: Optional[str] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + token_idx: Optional[torch.Tensor] = None, + use_flash_attention: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, ): - inputs_embeds = self.text_model.embed_tokens(input_ids) - if pixel_values is not None and len(pixel_values) > 0: - # num_special_image_tokens = (input_ids == self.config.image_token_index).sum() - # assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid" - # 1. Extract the input embeddings - - # 2. Merge text and images - num_images, num_patches, channels, height, width = pixel_values.shape - pixel_values = pixel_values.view( - num_images * num_patches, channels, height, width - ) - image_features = self.vision_tower(pixel_values) - # selected_image_feature = image_features.hidden_states[self.config.vision_feature_layer] - # Already done within the clip model - selected_image_feature = image_features.last_hidden_state - - if self.config.vision_feature_select_strategy == "default": - selected_image_feature = selected_image_feature[:, 1:] - elif self.config.vision_feature_select_strategy == "full": - selected_image_feature = selected_image_feature - else: - raise RuntimeError( - f"Strategy `{self.config.vision_feature_select_strategy}` is not supported/valid." 
- ) - - image_features = self.multi_modal_projector(selected_image_feature) - - # split up image_features for each of the individual images - # hence we get a list of image_features, each of shape (5, num_patches, hidden_size) - # if we assume each image has 5 image features (base image + 4 patches) - split_sizes = [num_patches] * num_images - image_features = torch.split(image_features, split_sizes, dim=0) - - # NOTE we only support multimodal_patch_merge_type == "spatial_unpad" - height = width = ( - self.config.vision_config.image_size - // self.config.vision_config.patch_size + if token_idx is not None: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + outputs = self.language_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + token_idx=token_idx, + use_flash_attention=use_flash_attention, + flash_attention_recompute=flash_attention_recompute, ) - new_image_features = [] - for image_idx, image_feature in enumerate(image_features): - if image_feature.shape[0] > 1: - base_image_feature = image_feature[0] - image_feature = image_feature[1:] - - if height * width != base_image_feature.shape[0]: - raise ValueError( - "The number of patches is not consistent with the image size." - ) - - # Dimensions are intentionally swapped to be bug-compatible with - # upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59 - num_patch_width, num_patch_height = get_anyres_image_grid_shape( - image_sizes[image_idx], - self.config.image_grid_pinpoints, - self.config.vision_config.image_size, + logits = outputs[0] + + if not return_dict: + output = (logits,) + outputs[1:] + return output + + return outputs + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + inputs_embeds=None, + pixel_values=None, + image_sizes=None, + attention_mask=None, + **kwargs, + ): + """ + Inherits from LlavaForConditionalGeneration: https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava_next/modeling_llava_next.py#L635 + The only differences are: + - add new args token_idx + - add the process of merging images into inputs_embeds + """ + token_idx = kwargs.get("token_idx", None) + if token_idx is None: + return super().prepare_inputs_for_generation( + input_ids=input_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + pixel_values=pixel_values, + image_sizes=image_sizes, + attention_mask=attention_mask, + **kwargs, + ) + else: + use_flash_attention = kwargs.get("use_flash_attention", False) + flash_attention_recompute = kwargs.get("flash_attention_recompute", False) + + position_ids = kwargs.get("position_ids", None) + labels = kwargs.get("labels", None) + if past_key_values is None and pixel_values is not None and input_ids.shape[1] != 1: + vision_feature_select_strategy = kwargs.get("vision_feature_select_strategy", None) + vision_feature_layer = kwargs.get("vision_feature_layer", None) + vision_feature_select_strategy = ( + vision_feature_select_strategy + if 
vision_feature_select_strategy is not None + else self.config.vision_feature_select_strategy ) - image_feature = image_feature.view( - num_patch_height, num_patch_width, height, width, -1 + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer ) - image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() - image_feature = image_feature.flatten(1, 2).flatten(2, 3) - image_feature = unpad_image(image_feature, image_sizes[image_idx]) - image_feature = torch.cat( - ( - image_feature, - self.image_newline[:, None, None].expand( - *image_feature.shape[:-1], 1 - ), - ), - dim=-1, + + # 1. Extract the input embeddings + inputs_embeds = self.get_input_embeddings()(input_ids) + # 2. Merge text and images + batch_size, num_patches, num_channels, height, width = pixel_values.shape + reshaped_pixel_values = pixel_values.view(batch_size * num_patches, num_channels, height, width) + image_features = self.vision_tower( + reshaped_pixel_values, + output_hidden_states=True, + use_flash_attention=use_flash_attention, + flash_attention_recompute=flash_attention_recompute, ) - image_feature = image_feature.flatten(1, 2).transpose(0, 1) - image_feature = torch.cat( - (base_image_feature, image_feature), dim=0 + + selected_image_feature = image_features.hidden_states[vision_feature_layer] + + if vision_feature_select_strategy == "default": + selected_image_feature = selected_image_feature[:, 1:] + elif vision_feature_select_strategy == "full": + selected_image_feature = selected_image_feature + + image_features = self.multi_modal_projector(selected_image_feature) + + # split up image_features for each of the individual images + # hence we get a list of image_features, each of shape (5, num_patches, hidden_size) + # if we assume each image has 5 image features (base image + 4 patches) + split_sizes = [image.shape[0] for image in pixel_values] + image_features = torch.split(image_features, split_sizes, dim=0) + + # NOTE we only support multimodal_patch_merge_type == "spatial_unpad" + height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size + + new_image_features = [] + for image_idx, image_feature in enumerate(image_features): + if image_feature.shape[0] > 1: + base_image_feature = image_feature[0] + image_feature = image_feature[1:] + + if height * width != base_image_feature.shape[0]: + raise ValueError("The number of patches is not consistent with the image size.") + + num_patch_height, num_patch_width = get_anyres_image_grid_shape( + image_sizes[image_idx].tolist(), + self.config.image_grid_pinpoints, + self.config.vision_config.image_size, + ) + + image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1) + image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() + image_feature = image_feature.flatten(1, 2).flatten(2, 3) + image_feature = unpad_image(image_feature, image_sizes[image_idx]) + image_feature = torch.cat( + ( + image_feature, + self.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1), + ), + dim=-1, + ) + image_feature = image_feature.flatten(1, 2).transpose(0, 1) + image_feature = torch.cat((base_image_feature, image_feature), dim=0) + else: + image_feature = image_feature[0] + image_feature = torch.cat((image_feature, self.image_newline[None]), dim=0) + new_image_features.append(image_feature) + image_features = torch.stack(new_image_features, dim=0) + inputs_embeds = self._merge_input_ids_with_image_features(inputs_embeds, 
image_features, input_ids) + self.image_offset = image_features.shape[1] - 1 # image_token has occupied 1 token position. + # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of + # generation with cache + elif past_key_values is not None: + seq_len = input_ids.shape[1] + pad_len = seq_len - token_idx.item() + input_ids = torch.index_select(input_ids, 1, token_idx - 1) + # Retrieve the first layer to inspect the logits and mask out the hidden states + # that are set to 0 + first_layer_past_key_value = past_key_values[0][0][:, :, :, 0] + + # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 + batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0) + + # Get the target length + past_length = first_layer_past_key_value.shape[-1] + extended_attention_mask = torch.ones( + (attention_mask.shape[0], past_length), + dtype=attention_mask.dtype, + device=attention_mask.device, ) + # Filter out only the tokens that can be un-attended, this can happen + # if one uses Llava + Fused modules where the cache on the + # first iteration is already big enough, or if one passes custom cache + valid_indices = non_attended_tokens < extended_attention_mask.size(-1) + new_batch_index = batch_index[valid_indices] + new_non_attended_tokens = non_attended_tokens[valid_indices] + + # Zero-out the places where we don't need to attend + extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 + + attention_mask = extended_attention_mask + attention_mask[:, -pad_len:] = 0 + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + if token_idx is not None: + position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 + else: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} else: - image_feature = image_feature[0] - image_feature = torch.cat( - (image_feature, self.image_newline[None]), dim=0 - ) - new_image_features.append(image_feature) - image_features = torch.stack(new_image_features, dim=0) - - inputs_embeds = self._merge_input_ids_with_image_features( - input_ids, inputs_embeds, image_features - ) + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + "token_idx": token_idx, + "labels": labels, + "use_flash_attention": use_flash_attention, + "flash_attention_recompute": flash_attention_recompute, + } + ) - hidden_states = self.text_model.model( - inputs_embeds=inputs_embeds, - position_ids=position_ids, - cu_seqlen_prefill=cu_seqlen_prefill, - kv_cache=kv_cache, - block_tables=block_tables, - slots=slots, - seqlen=seqlen, - max_s=max_s, - true_max_s=max_s, - prefill_cache_indices=None, - adapter_data=adapter_data, - ) - if lm_head_indices is not None: - hidden_states = hidden_states[lm_head_indices] - logits, speculative_logits = self.text_model.lm_head(hidden_states) - return logits, speculative_logits + return model_inputs diff --git 
a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py index 6c518c2caa5..92c3cf0de57 100644 --- a/server/text_generation_server/models/globals.py +++ b/server/text_generation_server/models/globals.py @@ -1,36 +1,9 @@ import torch import os -from loguru import logger from typing import Dict, Optional -from text_generation_server.utils.log import log_master - -PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING").lower() in {"1", "true"} -log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}") -ATTENTION = os.getenv("ATTENTION") -_expected = {"paged", "flashdecoding", "flashinfer"} -assert ( - ATTENTION in _expected -), f"Attention is not valid {ATTENTION}, expected {_expected}" -log_master(logger.info, f"Using Attention = {ATTENTION}") - -if PREFIX_CACHING and ATTENTION not in {"flashinfer", "flashdecoding"}: - raise RuntimeError("Prefix caching is only supported with flashinfer") - MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None -TGI_WIGGLE_ROOM = float(os.getenv("TGI_WIGGLE_ROOM", "0.95")) -assert TGI_WIGGLE_ROOM > 0 -assert TGI_WIGGLE_ROOM < 1 - # This is overridden by the cli -BLOCK_SIZE: int -if ATTENTION == "flashdecoding": - BLOCK_SIZE = 256 -elif ATTENTION == "flashinfer": - BLOCK_SIZE = 1 -else: - BLOCK_SIZE = 16 - cuda_graphs = os.getenv("CUDA_GRAPHS") if cuda_graphs is not None: try: @@ -41,13 +14,18 @@ ) else: cuda_graphs = None -# sorting the cuda graphs in descending order helps reduce the -# memory impact and results in less memory usage -if cuda_graphs is not None: - cuda_graphs.sort(reverse=True) CUDA_GRAPHS = cuda_graphs +# This is overridden at model loading. +global MODEL_ID +MODEL_ID = None + + +def set_model_id(model_id: str): + global MODEL_ID + MODEL_ID = model_id + # NOTE: eventually we should move this into the router and pass back the # index in all cases. 
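A minimal usage sketch of the MODEL_ID handling added to globals.py above. The module path is shortened here; per the comment in the hunk, the value is overridden once at model loading.

# Sketch only: in the server these live in text_generation_server.models.globals
# and set_model_id is called once when the model is loaded.
MODEL_ID = None


def set_model_id(model_id: str):
    global MODEL_ID
    MODEL_ID = model_id


set_model_id("bigscience/bloom-560m")
assert MODEL_ID == "bigscience/bloom-560m"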
ADAPTER_TO_INDEX: Optional[Dict[str, int]] = None diff --git a/server/text_generation_server/models/starcoder.py b/server/text_generation_server/models/starcoder.py new file mode 100644 index 00000000000..98e7939a651 --- /dev/null +++ b/server/text_generation_server/models/starcoder.py @@ -0,0 +1,47 @@ +from loguru import logger +import torch +from dataclasses import dataclass +import os +from typing import List, Optional, Type + +from text_generation_server.models import CausalLM +from text_generation_server.models.causal_lm import CausalLMBatch + + +@dataclass +class StarCoderCausalLMBatch(CausalLMBatch): + past_key_values: Optional[List[torch.Tensor]] + + def detach_kv_cache(self): + past_keys = [] + past_values = [] + last_dim = int(self.past_key_values[0].size(dim=-1)/2) + for key_value in self.past_key_values: + past_keys.append(key_value.split((last_dim, last_dim), dim=-1)[0]) + past_values.append(key_value.split((last_dim, last_dim), dim=-1)[1]) + del self.past_key_values + + return past_keys, past_values + + def attach_kv_cache(self, past_keys, past_values): + self.past_key_values = [ + torch.cat((key, value), dim=-1) for key, value in zip(past_keys, past_values)] + + +class StarCoder(CausalLM): + def __init__( + self, + model_id: str, + revision: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + ): + + super(StarCoder, self).__init__( + model_id=model_id, + revision=revision, + dtype=dtype, + ) + + @property + def batch_type(self) -> Type[CausalLMBatch]: + return StarCoderCausalLMBatch \ No newline at end of file diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index 7f7d2e4d9f4..a07dafd584e 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -1,22 +1,70 @@ +import re import torch +import os +import time +import math from PIL import Image from io import BytesIO - +import base64 +import numpy from opentelemetry import trace +from loguru import logger from typing import Iterable, Optional, Tuple, List, Type, Dict - +import itertools +import tempfile +import copy +from text_generation_server.models import Model from transformers import PreTrainedTokenizerBase from transformers.image_processing_utils import select_best_resolution +from text_generation_server.utils.tokens import batch_top_tokens from text_generation_server.pb import generate_pb2 -from text_generation_server.models.flash_causal_lm import ( - FlashCausalLMBatch, - FlashCausalLM, - block_tables_to_ragged, +from text_generation_server.models.causal_lm import ( + CausalLMBatch, + CausalLMRequest, + remove_kv_cache_from_output, + biggest_single_chunk, +) + +from transformers.models.llava_next.modeling_llava_next import ( + get_anyres_image_grid_shape, ) -from text_generation_server.models.globals import PREFIX_CACHING, ATTENTION -from text_generation_server.utils.log import log_master + from transformers import AutoProcessor -from text_generation_server.layers.attention import Seqlen +import text_generation_server.habana_quantization_env as hq_env +from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi +from text_generation_server.utils import ( + HeterogeneousNextTokenChooser, + StoppingCriteria, + make_tokenizer_optional, + is_tokenizer_transparent, + pad_next_token_chooser_parameters, +) +import habana_frameworks.torch as htorch +from optimum.habana.utils import HabanaProfile +from optimum.habana.transformers.generation import 
MODELS_OPTIMIZED_WITH_STATIC_SHAPES +from optimum.habana.utils import get_hpu_memory_stats +from optimum.habana.checkpoint_utils import get_ds_injection_policy + +from transformers import ( + AutoTokenizer, + AutoModel, + PreTrainedTokenizerBase, + AutoConfig, +) +from optimum.habana.checkpoint_utils import ( + get_repo_root, + model_on_meta, + write_checkpoints_json, +) + +from text_generation_server.utils.speculate import get_speculate +from text_generation_server.models.types import ( + Batch, + Tokens, + Generation, + GeneratedText, +) +from text_generation_server.utils.debug import dbg_trace tracer = trace.get_tracer(__name__) @@ -24,28 +72,39 @@ IDEFICS2_IMAGE_TOKEN = "" -def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): - """ - Calculate the shape of the image patch grid after the preprocessing for images of any resolution. +IMAGES = re.compile(r"!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)") +BASE_IMAGE_TOKENS = int(os.environ.get('BASE_IMAGE_TOKENS', 2048)) +MAX_TOTAL_TOKENS = int(os.environ.get('MAX_TOTAL_TOKENS', 8192)) +MAX_BATCH_TOTAL_TOKENS = int(os.environ.get('MAX_BATCH_TOTAL_TOKENS', 131072)) +PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get('PAD_SEQUENCE_TO_MULTIPLE_OF', 256)) +CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048] +LAZY_MODE = int(os.environ.get('PT_HPU_LAZY_MODE', 1)) + +PREFILL_WARMUP_BATCH_SIZE_LIST = [] +PREFILL_WARMUP_SEQLEN_LIST = [] +DECODE_WARMUP_BATCH_SIZE_LIST = [] +def round_up(warmup_list:list, num) : + i = 0 + for i in warmup_list: + if num <= i : + break + return i - Args: - image_size (`tuple`): - The size of the input image in the format (height, width). - grid_pinpoints (`List`): - A list containing possible resolutions. Each item in the list should be a tuple or list - of the form `(height, width)`. - patch_size (`int`): - The size of each image patch. +def split(string) -> List[Dict[str, str]]: + parts = [] + cursor = 0 + for pattern in IMAGES.finditer(string): + start = pattern.start() + if start != cursor: + parts.append({"type": "text", "content": string[cursor:start]}) - Returns: - tuple: The shape of the image patch grid in the format (width, height). 
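The round_up helper defined above drives the warmup bucketing used throughout this file; a short usage sketch (bucket values invented) showing that it returns the smallest warmup bucket that fits, and falls back to the largest bucket when the request exceeds everything that was warmed up.

# Copy of round_up from above, plus example calls with invented buckets.
def round_up(warmup_list: list, num):
    i = 0
    for i in warmup_list:
        if num <= i:
            break
    return i


warmup_buckets = [1, 2, 4, 8]
assert round_up(warmup_buckets, 3) == 4   # smallest bucket that fits
assert round_up(warmup_buckets, 8) == 8   # exact fit
assert round_up(warmup_buckets, 9) == 8   # larger than every bucket: last one is returned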
- """ - if not isinstance(grid_pinpoints, list): - raise ValueError("grid_pinpoints should be a list of tuples or lists") + parts.append({"type": "image", "content": pattern.group(1)}) + cursor = pattern.end() - height, width = select_best_resolution(image_size, grid_pinpoints) - return height // patch_size, width // patch_size + if cursor != len(string): + parts.append({"type": "text", "content": string[cursor:]}) + return parts def image_text_replacement(processor, image_input, config, image_id: int) -> str: if config.model_type == "idefics2": @@ -59,8 +118,7 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str num_features = get_number_of_features(height, width, config) from loguru import logger - log_master( - logger.info, + logger.info( f"Found {num_features} features in image of resolution {height}x{width}", ) return "" * num_features @@ -125,6 +183,7 @@ def get_number_of_features(height: int, width: int, config) -> int: image_grid_pinpoints, image_size, ) + unpadded_features, newline_features = get_unpadded_features( height, width, npatches, num_patch_height, num_patch_width ) @@ -133,31 +192,106 @@ def get_number_of_features(height: int, width: int, config) -> int: return unpadded_features + newline_features + base_features -class VlmCausalLMBatch(FlashCausalLMBatch): +class VlmCausalLMBatch(CausalLMBatch): pixel_values: Optional[List[torch.Tensor]] pixel_attention_mask: Optional[List[torch.Tensor]] image_sizes: Optional[List[Tuple[int, int]]] @classmethod - @tracer.start_as_current_span("concatenate") - def concatenate(cls, batches): - batch = super(VlmCausalLMBatch, cls).concatenate(batches) - batch.pixel_values = None - batch.pixel_attention_mask = None - batch.image_sizes = None - return batch + def from_tokenized( + cls, + pb: generate_pb2.Batch, + tokenizer: PreTrainedTokenizerBase, + batch_tokenized_inputs, + dtype: torch.dtype, + device: torch.device, + is_warmup: bool = False, + ) -> "VlmCausalLMBatch": + + dbg_trace('FROM_PB', f'num_reqs:{len(pb.requests)}') + requests = [CausalLMRequest.from_pb(idx, req, tokenizer) for idx, req in enumerate(pb.requests)] + + max_input_length = max(r.data.truncate for r in requests) + max_new_tokens = max(r.stopping_criteria.max_new_tokens for r in requests) + # TODO: Add support for sparse batches + top_n_tokens = [r.top_n_tokens for r in pb.requests] + top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64) + + # TODO: by tokenizing all inputs at once we loose information on actual input lengths + # this means that we cannot shift inputs to the left after a long input sequence + # was filtered out + new_bs = round_up(PREFILL_WARMUP_BATCH_SIZE_LIST, len(requests)) + parameters = [r.parameters for r in pb.requests] + # append the dummy parameters for dummy request + parameters = pad_next_token_chooser_parameters(parameters, new_bs) + + next_token_chooser = HeterogeneousNextTokenChooser.from_pb( + pb=parameters, + dtype=dtype, + device=device, + tokenizer=tokenizer, + quantization_enabled=hq_env.is_quantization_enabled, + ) + tokenized_inputs = batch_tokenized_inputs + input_len = tokenized_inputs["input_ids"].shape[1] + + bucket_size = max_input_length + left_padding = max_input_length - input_len + if is_warmup is False: + if input_len < max_input_length : + rounded_seq_len = round_up(PREFILL_WARMUP_SEQLEN_LIST, input_len + 1) + if rounded_seq_len <= max_input_length: + bucket_size = rounded_seq_len - 1 + else: + bucket_size = max_input_length - 1 + left_padding = bucket_size - 
input_len + + input_ids = tokenized_inputs["input_ids"] + attention_mask = tokenized_inputs["attention_mask"] + # Allocate space for first token + if left_padding > 0: + input_ids = torch.nn.functional.pad( + input_ids, (left_padding, 1), value=tokenizer.pad_token_id + ) + attention_mask = torch.nn.functional.pad( + attention_mask, (left_padding, 1), value=0 + ) + all_input_ids = torch.nn.functional.pad( + input_ids, (0, max_new_tokens), value=tokenizer.pad_token_id + ).T.split(1, dim=1) + + # New input length after left padding + input_len = bucket_size + for r in requests: + r.input_length = input_len + r.prefix_offset = input_len - 5 + r.read_offset = input_len + r.all_input_ids = all_input_ids[r.idx] + input_ids = input_ids.to(device) + attention_mask = attention_mask.to(device) + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + + htorch.core.mark_step() + + return cls( + batch_id=pb.id, + requests=requests, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=None, + merged_kv_cache=False, + next_token_chooser=next_token_chooser, + top_n_tokens=top_n_tokens, + top_n_tokens_tensor=top_n_tokens_tensor, + input_length=input_len, + ) - @tracer.start_as_current_span("filter") - def filter(self, request_ids: List[int]): - batch = super().filter(request_ids) - batch.pixel_values = None - batch.pixel_attention_mask = None - batch.image_sizes = None - return batch @classmethod def batch_tokenized_inputs( - cls, requests: Iterable[generate_pb2.Request], tokenizer, processor, config + cls, requests: Iterable[generate_pb2.Request], tokenizer, processor, config, is_warmup ): # Process images first. We need all of them so that the processor # can make the image splits the same size. 
And we need the final @@ -177,10 +311,9 @@ def batch_tokenized_inputs( else: raise RuntimeError(f"Invalid chunk type {chunk_type}") + image_inputs = None if images: image_inputs = processor.image_processor(images, return_tensors="pt") - else: - image_inputs = None batch_inputs = [] max_truncation = 0 @@ -196,18 +329,57 @@ def batch_tokenized_inputs( processor, image_inputs, config, image_id ) image_id += 1 - full_text = image_text_replacement_fixup(config, full_text) batch_inputs.append(full_text) max_truncation = max(max_truncation, r.truncate) + missing_inputs = 0 + dummy_images = None + if is_warmup is False: + new_bs = round_up(PREFILL_WARMUP_BATCH_SIZE_LIST, len(requests)) + missing_inputs = new_bs - len(requests) + if missing_inputs > 0: + dummy_inputs = [] + if len(batch_inputs) > 0: + dummy_inputs = [batch_inputs[0]] * missing_inputs + + batch_inputs += dummy_inputs + batch_tokenized_inputs = tokenizer( batch_inputs, truncation=True, max_length=max_truncation, add_special_tokens=not config.model_type == "paligemma", - )["input_ids"] + return_tensors="pt", + padding="longest", + return_token_type_ids=False, + ) + + if missing_inputs > 0 and image_inputs is not None: + dummy_shape = list(image_inputs['pixel_values'].shape) + dummy_shape[0] = missing_inputs + dummy_images = torch.rand(dummy_shape) + new_image_inputs = { + "pixel_values": torch.cat( + (image_inputs['pixel_values'], dummy_images), dim=0 + ), + } + if "pixel_attention_mask" in image_inputs: + dummy_shape = list(image_inputs['pixel_attention_mask'].shape) + dummy_shape[0] = missing_inputs + dummy_attention = torch.zeros(dummy_shape) + new_image_inputs["pixel_attention_mask"] = torch.cat( + (image_inputs["pixel_attention_mask"], dummy_attention), dim=0 + ) + if "image_sizes" in image_inputs: + dummy_shape = list(image_inputs['image_sizes'].shape) + dummy_shape[0] = missing_inputs + dummy_sizes = torch.randint(dummy_shape) + new_image_inputs["image_sizes"] = torch.cat( + (image_inputs["image_sizes"], dummy_sizes), dim=0 + ) + image_inputs = new_image_inputs return batch_tokenized_inputs, image_inputs @@ -220,9 +392,10 @@ def from_pb_processor( config, dtype: torch.dtype, device: torch.device, + is_warmup: bool = False, ) -> "VlmCausalLMBatch": batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs( - pb.requests, tokenizer, processor, config + pb.requests, tokenizer, processor, config, is_warmup ) batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device) if image_inputs is not None: @@ -243,21 +416,131 @@ def from_pb_processor( batch.image_sizes = None return batch + @classmethod + @tracer.start_as_current_span("concatenate") + def concatenate(cls, batches: List["CausalLMBatch"], pad_token_id: int = 0, is_warmup:bool = False) -> "CausalLMBatch": + return cls.recombine(batches, pad_token_id, is_warmup) + + -class VlmCausalLM(FlashCausalLM): + @classmethod + def recombine(cls, batches: List["VlmCausalLMBatch"], pad_token_id: int, is_warmup: bool =False) -> "VlmCausalLMBatch": + if not all(b.past_key_values is not None for b in batches): + raise ValueError("KV cache not allocated! 
Cannot recombine before prefill!") + + total_requests = sum(len(b) for b in batches) + new_bs = total_requests + if is_warmup is False : + new_bs = round_up(DECODE_WARMUP_BATCH_SIZE_LIST, total_requests) + batch_id = batches[0].batch_id + device = batches[0].input_ids.device + + input_lengths = [b.input_length for b in batches] + max_input_length = max(input_lengths) + offsets = [max_input_length - b.input_length for b in batches] + + cur_padding = [b.right_padding for b in batches] + # For prefill there is a space allocated only for first token + # Need to add padding to the max total tokens before first decode + + moves_needed = [total_requests - len(b) if b.batch_size == new_bs else total_requests for b in batches] + dst_batch_idx = min(enumerate(moves_needed), key=lambda idx_val: idx_val[1])[0] + reshape = (batches[dst_batch_idx].batch_size < new_bs) + + # TODO: Add support for changing max seq len, i.e. due to output length bucketing + # FIXME: max_seq_len for non optimized code + if len(batches) > 1: + scenario = 'CONCAT' + elif reshape: + scenario = 'RESHAPE' + elif cur_padding[dst_batch_idx] <= 0: + scenario = 'SHIFT' + offsets = [biggest_single_chunk(b.max_input_length - max_input_length) for b in batches] + max_input_length = max_input_length + offsets[dst_batch_idx] + else: + # Nothing to do + return batches[0] + + dbg_trace( + scenario, f'bs:{[b.batch_size for b in batches]}->{new_bs}' + f' reqs:{[len(b) for b in batches]}' + f' offsets:{offsets}' + f' input_lengths:{input_lengths}' + f' cur_padding:{cur_padding}' + f' dst_batch:{dst_batch_idx}') + + grouped_requests = [[req for req in batch.requests] for batch in batches] + flat_requests = list(itertools.chain(*grouped_requests)) + + for i in range(len(batches)): + target_bs = new_bs if i == dst_batch_idx else batches[i].batch_size + batches[i].merge_kv_cache_if_needed(target_bs, offsets[i]) + batches[i].realign(target_bs, offsets[i], pad_token_id) + batches[i].split_kv_cache_if_needed(i == dst_batch_idx) + batches[dst_batch_idx].expand_bs(new_bs) + batches[dst_batch_idx].move_data([batches[i] for i in range(len(batches)) if i != dst_batch_idx]) + + top_n_tokens = [r.data.top_n_tokens for r in flat_requests] + top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64) + + parameters = [r.data.parameters for r in flat_requests] + # append the dummy parameters for dummy requests + batch_size = batches[dst_batch_idx].batch_size + parameters = pad_next_token_chooser_parameters(parameters, batch_size) + + # update past grammar states + fsm_grammar_states = [0] * batch_size + for batch in batches: + for i, req in enumerate(batch.requests): + fsm_grammar_states[req.idx] = batch.next_token_chooser.fsm_grammar_states[i] + + next_token_chooser = HeterogeneousNextTokenChooser.from_pb( + parameters, + batches[dst_batch_idx].next_token_chooser.dtype, + batches[dst_batch_idx].next_token_chooser.device, + batches[dst_batch_idx].next_token_chooser.tokenizer, + fsm_grammar_states, + quantization_enabled=hq_env.is_quantization_enabled, + ) + + input_ids = batches[dst_batch_idx].input_ids + attention_mask = batches[dst_batch_idx].attention_mask + position_ids = batches[dst_batch_idx].position_ids + past_key_values = batches[dst_batch_idx].past_key_values + input_length = max_input_length + + htorch.core.mark_step() + + return cls( + batch_id=batch_id, + requests=flat_requests, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + merged_kv_cache=False, + 
next_token_chooser=next_token_chooser, + top_n_tokens=top_n_tokens, + top_n_tokens_tensor=top_n_tokens_tensor, + input_length=input_length, + ) + +class VlmCausalLM(Model): def __init__( self, + model_class, model_id: str, *, processor_class=AutoProcessor, processor_kwargs=None, batch_class=VlmCausalLMBatch, revision, + quantize: Optional[str] = None, + dtype, trust_remote_code: bool, **kwargs, ): - if PREFIX_CACHING: - raise NotImplementedError("Vlm do not work with prefix caching yet") + adapt_transformers_to_gaudi() if processor_kwargs is None: processor_kwargs = {} self.processor = processor_class.from_pretrained( @@ -267,12 +550,134 @@ def __init__( **processor_kwargs, ) self.batch_class = batch_class - super().__init__( - model_id=model_id, + self.prev_bs = 0 + self.quantize = quantize + + # Create tokenizer + tokenizer = AutoTokenizer.from_pretrained( + model_id, revision=revision, + padding_side="left", + truncation_side="left", trust_remote_code=trust_remote_code, - **kwargs, ) + make_tokenizer_optional(tokenizer) + + # Create model + world_size = int(os.getenv("WORLD_SIZE", "1")) + rank = int(os.getenv("RANK", "0")) + dtype = torch.bfloat16 if dtype is None else dtype + device = torch.device("hpu") + + if hq_env.is_quantization_enabled: + htorch.core.hpu_set_env() + + if world_size > 1: + model = self.get_deepspeed_model( + model_class, model_id, dtype, revision + ) + model = self.prepare_model_for_quantization(model) + else: + get_repo_root(model_id) + + # Check support for rope scaling + model_kwargs = {} + config = AutoConfig.from_pretrained( + model_id + ) + if hasattr(config, "rope_scaling"): + model_kwargs["rope_scaling"] = self.get_rope_scaling() + + model = model_class.from_pretrained( + model_id, + revision=revision, + torch_dtype=dtype, + trust_remote_code=trust_remote_code, + **model_kwargs + ) + model = self.prepare_model_for_quantization(model) + model = model.eval().to(device) + + self.enable_hpu_graph = os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true" and LAZY_MODE == 1 + self.limit_hpu_graph = os.getenv("LIMIT_HPU_GRAPH", "false").lower() == "true" + model = remove_kv_cache_from_output(model) + if self.enable_hpu_graph: + from habana_frameworks.torch.hpu import wrap_in_hpu_graph + model = wrap_in_hpu_graph(model, disable_tensor_cache=True) + else: + if LAZY_MODE == 0: + # It is said that "keep_input_mutations" is safe for inference to be done + dbg_trace( + "TORCH COMPILE", f'Torch compiling of model') + model.model = torch.compile(model.model, backend="hpu_backend", options={"keep_input_mutations": True}) + + model = self.setup_quantization(model) + + if model.config.model_type not in MODELS_OPTIMIZED_WITH_STATIC_SHAPES: + raise ValueError(f"Model type {model.config.model_type} is not supported!") + + if tokenizer.pad_token_id is None: + if model.config.pad_token_id is not None: + tokenizer.pad_token_id = model.config.pad_token_id + elif model.config.eos_token_id is not None: + if isinstance(model.config.eos_token_id, int): + tokenizer.pad_token_id = model.config.eos_token_id + elif isinstance(model.config.eos_token_id, list): + tokenizer.pad_token_id = model.config.eos_token_id[0] + else: + raise ValueError( + f"{type(model.config.eos_token_id)} type of eos_token_id in the model's config is not supported for tokenizer.pad_token_id" + ) + elif tokenizer.eos_token_id is not None: + tokenizer.pad_token_id = tokenizer.eos_token_id + else: + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + + self.kwargs = { + "use_cache": True, + "return_dict": True, + } 
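A condensed sketch of the graph-mode selection in __init__ above, pulled out into a hypothetical standalone helper so the environment-variable logic is easier to follow. ENABLE_HPU_GRAPH and PT_HPU_LAZY_MODE are the real knobs used by the patch; select_graph_mode is invented for illustration.

import os

# Mirrors the selection above: HPU graphs only in lazy mode (PT_HPU_LAZY_MODE=1),
# torch.compile as the fallback when lazy mode is disabled.
def select_graph_mode() -> str:
    lazy_mode = int(os.environ.get("PT_HPU_LAZY_MODE", "1"))
    enable_hpu_graph = (
        os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true" and lazy_mode == 1
    )
    if enable_hpu_graph:
        return "hpu_graphs"      # wrap_in_hpu_graph(model, disable_tensor_cache=True)
    if lazy_mode == 0:
        return "torch_compile"   # torch.compile(model.model, backend="hpu_backend")
    return "plain_lazy"


print(select_graph_mode())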
+ + if model.config.model_type in ["llava_next"]: + self.kwargs["attn_softmax_bf16"] = True + self.kwargs["trim_logits"] = True + + if os.getenv("USE_FLASH_ATTENTION", "false").lower() == "true": + self.kwargs["use_flash_attention"] = True + if os.getenv("FLASH_ATTENTION_RECOMPUTE", "false").lower() == "true": + self.kwargs["flash_attention_recompute"] = True + + self.speculate = get_speculate() + super(VlmCausalLM, self).__init__( + model_id=model_id, + model=model, + tokenizer=tokenizer, + requires_padding=True, + dtype=dtype, + device=device, + rank=rank, + ) + + # Create profiler + ranks_to_profile = [int(val) for val in os.getenv("PROF_RANKS", "0").split(',')] + record_shapes = os.getenv("PROF_RECORD_SHAPES", "false").lower() == "true" + output_dir = os.getenv("PROF_PATH", "/tmp/hpu_profile") + self.profiling_warmup_steps = int(os.getenv("PROF_WARMUPSTEP", "0")) if rank in ranks_to_profile else 0 + self.profiling_steps = int(os.getenv("PROF_STEP", "0")) if rank in ranks_to_profile else 0 + self.profiling_wait_steps = int(os.getenv("PROF_WAITSTEP", "0")) + if self.profiling_steps > 0: + self.hb_profiler = HabanaProfile( + wait=self.profiling_wait_steps, + warmup=self.profiling_warmup_steps, + active=self.profiling_steps, + output_dir=output_dir, + record_shapes=record_shapes + ) + self.hb_profiler.start() + else: + self.hb_profiler = None + self.step = 0 + @property def batch_type(self) -> Type[VlmCausalLMBatch]: @@ -281,158 +686,569 @@ def batch_type(self) -> Type[VlmCausalLMBatch]: def max_past(self) -> Optional[int]: return getattr(self.model.text_model, "max_past", None) + def get_deepspeed_model( + self, + model_class, + model_id: str, + dtype: torch.dtype, + revision: Optional[str] = None + ) -> torch.nn.Module: + import deepspeed + from habana_frameworks.torch.distributed.hccl import initialize_distributed_hpu + + world_size, rank, local_rank = initialize_distributed_hpu() + model_kwargs = { + "revision": revision + } + + # Initialize process(es) for DeepSpeed + deepspeed.init_distributed(dist_backend="hccl") + logger.info( + "DeepSpeed is enabled. 
world_size {} rank {} local_rank {}".format(world_size, rank, local_rank) + ) + config = AutoConfig.from_pretrained(model_id, **model_kwargs) + load_to_meta = model_on_meta(config) + + # Check support for rope scaling + if hasattr(config, "rope_scaling"): + config.rope_scaling = self.get_rope_scaling() + model_kwargs["rope_scaling"] = self.get_rope_scaling() + + if load_to_meta: + # Construct model with fake meta tensors, later will be replaced on devices during ds-inference ckpt load + with deepspeed.OnDevice(dtype=dtype, device="meta"): + model = model_class.from_config(config, torch_dtype=dtype) + else: + get_repo_root(model_id, local_rank=os.getenv("LOCAL_RANK")) + # TODO: revisit placement on CPU when auto-injection is possible + with deepspeed.OnDevice(dtype=dtype, device="cpu"): + model = model_class.from_pretrained(model_id, torch_dtype=dtype, **model_kwargs) + model = model.eval() + + # Initialize the model + ds_inference_kwargs = {"dtype": dtype} + ds_inference_kwargs["tensor_parallel"] = {"tp_size": world_size} + ds_inference_kwargs["enable_cuda_graph"] = False + ds_inference_kwargs["injection_policy"] = get_ds_injection_policy(model.language_model.config) + + if load_to_meta: + # model loaded to meta is managed differently + checkpoints_json = tempfile.NamedTemporaryFile(suffix=".json", mode="+w") + write_checkpoints_json(model_id, local_rank, checkpoints_json) + ds_inference_kwargs["checkpoint"] = checkpoints_json.name + model = deepspeed.init_inference(model, **ds_inference_kwargs) + + return model.module + + def get_rope_scaling(self) -> Optional[Dict]: + rope_scaling = os.getenv("ROPE_SCALING", None) + if rope_scaling is None: + return None + + rope_factor = float(os.getenv("ROPE_FACTOR", 1.0)) + return { + 'type': rope_scaling, 'factor': float(rope_factor) + } + + def setup_quantization(self, model): + if hq_env.is_quantization_enabled: + htorch.core.quantization._mark_params_as_const(model) + htorch.core.quantization._check_params_as_const(model) + htorch.core.hpu_initialize(model) + return model + + def prepare_model_for_quantization(self, model): + if hq_env.is_quantization_enabled: + if model.config.model_type == "llama": + self.patch_scoped_linear_all_reduce(model) + model = hq_env.prepare_model_for_quantization(model) + return model + + def patch_scoped_linear_all_reduce(self, model): + from deepspeed.module_inject.layers import LinearAllreduce + from optimum.habana.transformers.models.modeling_all_models import ScopedLinearAllReduce + for name, module in model.named_children(): + if type(module) is LinearAllreduce: + SL = ScopedLinearAllReduce(mod=module) + setattr(model, name, SL) + self.patch_scoped_linear_all_reduce(module) + + def decode(self, generated_ids: List[int]) -> str: + return self.tokenizer.decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False) + + def decode_token( + self, + all_input_ids: List[int], + prefix_offset: int = 0, + read_offset: int = 0, + ) -> Tuple[str, int, int]: + if is_tokenizer_transparent(self.tokenizer): + new_text = self.tokenizer.decode(all_input_ids[read_offset:], skip_special_tokens=False) + return new_text, read_offset, len(all_input_ids) + else: + return super().decode_token(all_input_ids, prefix_offset, read_offset) + def forward( self, - batch: VlmCausalLMBatch, - adapter_data: Optional[Dict[str, torch.Tensor]] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + input_ids, + attention_mask, + position_ids, + token_idx, + past_key_values: Optional[List[Tuple]] = None, + pixel_values: 
Optional[List[torch.Tensor]] = None, + image_sizes: Optional[List[Tuple[int, int]]] = None, + bypass_hpu_graph: Optional[bool] = None, + ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: # Model Forward - if batch.speculative_ids is not None: - input_ids = batch.input_ids - position_ids = batch.position_ids - cu_seqlen_prefill = batch.cu_seqlen_prefill - kv_cache = self.kv_cache - block_tables = batch.block_tables_tensor - slots = batch.slots[batch.slot_indices] - input_lengths = batch.input_lengths_tensor - max_s = batch.max_seqlen - lm_head_indices = batch.prefill_head_indices - - speculative_ids = batch.speculative_ids - - B, speculative_length = speculative_ids.shape - new_length = speculative_length + 1 - new_input_ids = torch.cat( - [input_ids.unsqueeze(-1), speculative_ids], dim=1 - ).reshape(-1) - arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0) - arange_int = arange.to(dtype=torch.int32) - new_position_ids = ( - position_ids.unsqueeze(-1).expand(B, new_length) + arange - ).view(-1) - slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1) - input_lengths = ( - input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int - ).view(-1) - prefix_lens_tensor = ( - batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length) - ).reshape(-1) - - # Add Copy the block tables for all members - block_tables = ( - block_tables.unsqueeze(1) - .expand(B, new_length, -1) - .reshape(B * new_length, -1) - .contiguous() - ) - max_s = max_s + speculative_length + kwargs = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "past_key_values": past_key_values, + "token_idx": token_idx, + "pixel_values": pixel_values, + "image_sizes": image_sizes, + } + + hpu_kwargs = {} + # Optimum Habana got "lazy_mode" key-val only supported for llama type of models + if self.model.config.model_type == "llama" : + hpu_kwargs["lazy_mode"] = LAZY_MODE == 1 + + if self.has_position_ids: + kwargs["position_ids"] = position_ids - input_ids = new_input_ids - position_ids = new_position_ids + if bypass_hpu_graph != None: + hpu_kwargs["bypass_hpu_graphs"] = bypass_hpu_graph + + kwargs.update(self.kwargs) + model_inputs = self.model.prepare_inputs_for_generation(**kwargs) + if past_key_values is not None: + return self.model.forward(**model_inputs, **hpu_kwargs) else: - input_ids = batch.input_ids - position_ids = batch.position_ids - cu_seqlen_prefill = batch.cu_seqlen_prefill - kv_cache = self.kv_cache - block_tables = batch.block_tables_tensor - slots = batch.slots[batch.slot_indices] - input_lengths = batch.input_lengths_tensor - prefix_lens_tensor = batch.prefix_lens_tensor - max_s = batch.max_seqlen - lm_head_indices = batch.prefill_head_indices - - if cu_seqlen_prefill is None and self.max_past() is not None: - # In decode, not prefill, we're actually overwriting the KV-cache - # in a circular buffer mode. - # This makes sure the max_s for the decode pass is correct. 
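The forward above and generate_token below keep input_ids and attention_mask at a fixed size and update them at position token_idx instead of concatenating; a minimal sketch of that update (buffer sizes and token ids are invented).

import torch

# Fixed-size buffers: 4 prompt tokens already filled, room left for new tokens.
batch_size, max_len = 2, 8
input_ids = torch.zeros(batch_size, max_len, dtype=torch.long)
attention_mask = torch.zeros(batch_size, max_len, dtype=torch.long)
attention_mask[:, :4] = 1

token_idx = torch.tensor([4])              # next free slot (1-element index tensor)
next_token_ids = torch.tensor([11, 22])    # one new token per sequence

# Write the new tokens in place and mark them attended, as generate_token does.
input_ids.index_copy_(1, token_idx, next_token_ids.unsqueeze(1))
attention_mask.index_fill_(1, token_idx, 1)

print(input_ids[:, token_idx.item()])      # tensor([11, 22])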
- max_s = min(self.max_past(), max_s) - - bs = input_ids.shape[0] - # Try to find an associated cuda graph - bs = input_ids.shape[0] - sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs]) - if sorted_padded_bs: - # Get associated cuda graph - cuda_graph = self.cuda_graphs[sorted_padded_bs[0]] + outputs = self.model.forward(**model_inputs, **hpu_kwargs) + return outputs.logits, outputs.past_key_values + + @tracer.start_as_current_span("generate_token") + def generate_token( + self, batches: List[VlmCausalLMBatch], is_warmup: bool = False + ) -> Tuple[List[Generation], Optional[CausalLMBatch], Tuple[int, int]]: + start = time.time_ns() + # Results + generations: List[Generation] = [] + prev_batches = [] + requests_to_generate = [] + # In order to pipeline any actions on CPU we perform the operation in 3 main stages: + # Stage 1. Collect next token ids of any previously started generations + for batch_id, batch in enumerate(batches): + if batch.logits is not None: + logits = batch.logits + past = batch.past + prefill = batch.past_key_values is None + if prefill: + # no right padding for prefill + token_idx_scalar = batch.attention_mask.shape[-1] - 1 + token_idx = torch.tensor(token_idx_scalar).to(self.device) + else: + token_idx_scalar = batch.attention_mask.shape[-1] - batch.right_padding + token_idx = torch.tensor(token_idx_scalar).to(self.device) + + # Select next token + input_length = batch.input_length + if logits.shape[-2] > 1: + next_token_ids, next_token_logprobs, logprobs, _, _ = batch.next_token_chooser( + batch.input_ids, logits[:, input_length - 1: input_length, :].squeeze(-2), self.speculate + ) + else: + next_token_ids, next_token_logprobs, logprobs, _, _ = batch.next_token_chooser( + batch.input_ids, logits.squeeze(-2), self.speculate + ) + # Speculation is not active for causal + accepted_ids = torch.ones_like(batch.input_ids)[:, 0] + batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens( + batch.top_n_tokens, + batch.top_n_tokens_tensor, + logprobs, + accepted_ids, + ) + + prev_batches.append({ + 'next_token_ids': next_token_ids, + 'next_token_logprobs': next_token_logprobs, + }) + + for req_idx, req in enumerate(batch.requests): + requests_to_generate.append({ + 'req': req, + 'prev_req_idx': req.idx, + 'batch_id': batch_id, + 'seed': batch.next_token_chooser.seeds[req_idx], + 'do_sample': batch.next_token_chooser.do_sample[req_idx], + 'top_n_tokens': batch.top_n_tokens[req_idx], + 'top_token_ids': batch_top_token_ids[req_idx], + 'top_token_logprobs': batch_top_token_logprobs[req_idx], + 'grammar_state': batch.next_token_chooser.fsm_grammar_states[req.idx], + }) + + htorch.core.mark_step() + + # Add new token into input_ids + batch.input_ids.index_copy_(1, token_idx, next_token_ids.unsqueeze(1)) + + # Update attention_mask as we added a new token to input_ids + batch.attention_mask.index_fill_(1, token_idx, 1) + + # Adjust lengths + batch.input_length += 1 + + # Update position_ids + if prefill: + batch.position_ids = torch.index_select(batch.position_ids, 1, token_idx - 1) + 1 + else: + batch.position_ids += 1 + # Update past key values + if prefill: + batch.past_key_values = past + + htorch.core.mark_step() + + # Stage 2. 
Prepare new batch for speculative scheduling + if len(batches) > 1: + batch = self.batch_type.concatenate(batches, self.tokenizer.pad_token_id, is_warmup) + else: + batch = batches[0] + + prefill = batch.past_key_values is None + + # Check if we need to do any bookkeeping first + if not prefill: + batch = batch.__class__.recombine([batch], self.tokenizer.pad_token_id, is_warmup) + + scenario = 'PREFILL' if prefill else 'GENERATE' + if self.enable_hpu_graph and self.limit_hpu_graph and round_up(DECODE_WARMUP_BATCH_SIZE_LIST, batch.batch_size) != self.prev_bs: + self.model.clear_cache() + self.prev_bs = round_up(DECODE_WARMUP_BATCH_SIZE_LIST, batch.batch_size) + dbg_trace( + scenario, f'bs:{batch.batch_size} num_reqs:{len(batch.requests)} seq_len:{batch.seq_length} padding:{batch.right_padding}') + #assert batch.right_padding > 0, 'No more room for next token!' + + # Execute batch + if prefill: + # no right padding for prefill + token_idx = torch.tensor(batch.attention_mask.shape[-1] - 1).to(self.device) + batch.logits, batch.past = self.forward( + batch.input_ids, + batch.attention_mask, + batch.position_ids, + token_idx, + batch.past_key_values, + batch.pixel_values, + batch.image_sizes, + bypass_hpu_graph=prefill and self.limit_hpu_graph if self.enable_hpu_graph else None, + ) + elif all([req.stopping_criteria.max_new_tokens == 1 for req in batch.requests]): + # Don't schedule next forward if max_new_tokens for all requests equals 1 + # - we've already generated the first and only needed token in the prefill phase + pass else: - cuda_graph = None - if cu_seqlen_prefill is not None or cuda_graph is None: - input_lengths = input_lengths + prefix_lens_tensor - if PREFIX_CACHING: - block_tables = block_tables_to_ragged( - block_tables=block_tables, - input_lengths=batch.input_lengths, - prefix_lens=batch.prefix_lens, + token_idx = torch.tensor(batch.attention_mask.shape[-1] - batch.right_padding).to(self.device) + batch.logits = self.forward( + batch.input_ids, + batch.attention_mask, + batch.position_ids, + token_idx, + batch.past_key_values, + bypass_hpu_graph=prefill and self.limit_hpu_graph if self.enable_hpu_graph else None, + ) + + htorch.core.mark_step() + + start_decode = time.time_ns() + + # Stage 3. 
Finish and return previous generations + stopped = len(requests_to_generate) > 0 + for prev_batch in prev_batches: + prev_batch['next_token_logprobs'] = prev_batch['next_token_logprobs'].tolist() + prev_batch['next_token_ids_cpu'] = prev_batch['next_token_ids'].cpu() + htorch.core.mark_step() + + for req_data in requests_to_generate: + req = req_data['req'] + i = req_data['prev_req_idx'] + prev_batch_id = req_data['batch_id'] + assert len(prev_batches) > prev_batch_id + next_token_ids_cpu = prev_batches[prev_batch_id]['next_token_ids_cpu'] + next_token_logprobs = prev_batches[prev_batch_id]['next_token_logprobs'] + + request = req.data + input_length = req.input_length + prefix_offset = req.prefix_offset + read_offset = req.read_offset + do_sample = req_data['do_sample'] + seed = req_data['seed'] + stopping_criteria = req.stopping_criteria + all_input_ids = req.all_input_ids + next_token_id = next_token_ids_cpu[i] + next_token_logprob = next_token_logprobs[i] + top_n_tokens = req_data['top_n_tokens'] + top_token_ids = req_data['top_token_ids'] + top_token_logprobs = req_data['top_token_logprobs'] + grammar_state = req_data['grammar_state'] + + # Append next token to all tokens + all_input_ids[input_length] = next_token_id + new_input_length = input_length + 1 + + # Generated token + if is_tokenizer_transparent(self.tokenizer) and len(stopping_criteria.stop_sequence_criterias) == 0: + next_token_text = '' + else: + next_token_text, prefix_offset, read_offset = self.decode_token( + all_input_ids[0:new_input_length, 0], prefix_offset, read_offset ) - with self._forward_context( - block_tables=block_tables, - cu_seqlen_prefill=cu_seqlen_prefill, - input_lengths_tensor=input_lengths, - prefix_lens_tensor=prefix_lens_tensor, - ): - max_k = (input_lengths + prefix_lens_tensor).max().item() - seqlen = Seqlen( - input_lengths=input_lengths, - prefix_lengths=prefix_lens_tensor, - cu_seqlen_q=cu_seqlen_prefill, - max_q=max_s, - max_k=max_k, + + # Evaluate stopping criteria + stop, reason = stopping_criteria( + next_token_id, + next_token_text, + ) + + if not stop: + stopped = False + + # Shard generations + # All generations will be appended in the rust sharded client + if i % self.world_size == self.rank: + if stop: + # Decode generated tokens + if is_tokenizer_transparent(self.tokenizer): + output_text = None + else: + output_text = self.decode( + all_input_ids[new_input_length - stopping_criteria.current_tokens: new_input_length, 0] + ) + generated_text = GeneratedText( + output_text, + stopping_criteria.current_tokens, + reason, + seed if do_sample else None, + ) + else: + generated_text = None + + # Prefill + if stopping_criteria.current_tokens == 1 and request.prefill_logprobs: + # Remove generated token to only have prefill and add nan for first prompt token + prefill_logprobs = [float("nan")] + next_token_logprobs + prefill_token_ids = all_input_ids[0: new_input_length - 1] + prefill_texts = self.tokenizer.batch_decode( + prefill_token_ids, + clean_up_tokenization_spaces=False, + skip_special_tokens=False, + ) + prefill_tokens = Tokens( + prefill_token_ids, + prefill_logprobs, + prefill_texts, + is_special=[], + ) + else: + prefill_tokens = None + + if top_n_tokens > 0: + all_top_tokens = [] + for top_token_ids, top_token_logprobs in zip( + top_token_ids, top_token_logprobs + ): + toptoken_texts = self.tokenizer.batch_decode( + top_token_ids, + clean_up_tokenization_spaces=False, + skip_special_tokens=False, + ) + special_toptokens = [ + token_id in self.all_special_ids + for token_id in 
top_token_ids + ] + top_tokens = Tokens( + top_token_ids, + top_token_logprobs, + toptoken_texts, + special_toptokens, + ) + all_top_tokens.append(top_tokens) + top_tokens = all_top_tokens + else: + top_tokens = None + + generation = Generation( + request.id, + prefill_tokens, + Tokens( + [next_token_id], + [next_token_logprob], + [next_token_text], + [next_token_id in self.all_special_ids], + ), + generated_text, + top_tokens, ) - logits, speculative_logits = self.model.forward( - input_ids=input_ids, - position_ids=position_ids, - cu_seqlen_prefill=cu_seqlen_prefill, - kv_cache=kv_cache, - block_tables=block_tables, - slots=slots, - seqlen=seqlen, - max_s=max_s, - prefill_cache_indices=batch.prefill_cache_indices, - lm_head_indices=lm_head_indices, - pixel_values=batch.pixel_values, - pixel_attention_mask=batch.pixel_attention_mask, - image_sizes=batch.image_sizes, + + generations.append(generation) + + batch.next_token_chooser = ( + batch.next_token_chooser.advance_grammar_single_with_past_state( + req.idx, next_token_id, grammar_state ) - if batch.prefill_cache_indices is not None: - batch.prefill_cache_indices = None - if batch.pixel_values is not None: - batch.pixel_values = None - if batch.pixel_attention_mask is not None: - batch.pixel_attention_mask = None - if batch.image_sizes is not None: - batch.image_sizes = None - return logits, speculative_logits - - # Copy inputs to the static inputs of the cuda graph - # Static inputs are potentially padded - cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids - cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids - if ATTENTION == "flashinfer": - block_tables = block_tables_to_ragged( - block_tables=block_tables, - input_lengths=batch.input_lengths, - prefix_lens=batch.prefix_lens, ) - cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables - else: - cuda_graph["block_tables"][ - : block_tables.shape[0], : block_tables.shape[1] - ] = block_tables - cuda_graph["slots"].fill_(-1) - cuda_graph["slots"][: slots.shape[0]] = slots - cuda_graph["input_lengths"].zero_() - cuda_graph["input_lengths"][: input_lengths.shape[0]] = ( - input_lengths + prefix_lens_tensor - ) - # Replay the graph - cuda_graph["graph"].replay() + req.all_input_ids = all_input_ids + req.input_length = new_input_length + req.prefix_offset = prefix_offset + req.read_offset = read_offset + + htorch.core.mark_step() + self.step = self.step + 1 + if self.hb_profiler is not None: + if self.step > self.profiling_wait_steps + self.profiling_warmup_steps + self.profiling_steps: + self.hb_profiler.stop() + else: + self.hb_profiler.step() + + forward_ns = start_decode - start + decode_ns = time.time_ns() - start_decode + return generations, batch if not stopped else None, (forward_ns, decode_ns) - # Slice output to the correct shape - speculative_logits = ( - cuda_graph["speculative_logits"][:bs] - if cuda_graph["speculative_logits"] is not None - else None + def batch_from_pb(self, batch, is_warmup): + return VlmCausalLMBatch.from_pb_processor( + batch, + self.tokenizer, + self.processor, + self.model.config, + self.dtype, + self.device, + is_warmup ) - logits = cuda_graph["logits"][:bs] - return logits, speculative_logits + + def generate_warmup_batch(self, request, seq_len, batch_size, is_warmup): + batch = copy.deepcopy(request.batch) + for req in batch.requests: + req.truncate = seq_len + + for i in range(len(batch.requests) - batch_size): + batch.requests.pop() + + return self.batch_from_pb(batch, is_warmup) + + def warmup(self, request) -> 
None:
+        is_warmup = True
+        batch = self.batch_from_pb(request.batch, is_warmup)
+
+        try:
+            # max prefill batch size warmup
+            _, prefill_batch, _ = self.generate_token([batch], is_warmup)
+        except Exception:
+            raise RuntimeError(
+                f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
+                f"You need to decrease `--max-batch-prefill-tokens`"
+            )
+
+        global BASE_IMAGE_TOKENS, MAX_TOTAL_TOKENS, MAX_BATCH_TOTAL_TOKENS, PREFILL_WARMUP_BATCH_SIZE_LIST, PREFILL_WARMUP_SEQLEN_LIST, DECODE_WARMUP_BATCH_SIZE_LIST
+        max_input_length = batch.input_ids.shape[1]
+        max_prefill_batch_size = batch.input_ids.shape[0]
+        PREFILL_WARMUP_BATCH_SIZE_LIST = []
+        batch_size = 1
+        while batch_size <= max_prefill_batch_size:
+            PREFILL_WARMUP_BATCH_SIZE_LIST.append(batch_size)
+            batch_size = batch_size * 2
+        if PREFILL_WARMUP_BATCH_SIZE_LIST[-1] < max_prefill_batch_size:
+            PREFILL_WARMUP_BATCH_SIZE_LIST.append(max_prefill_batch_size)
+
+        seq_len = BASE_IMAGE_TOKENS
+        PREFILL_WARMUP_SEQLEN_LIST = []
+        i = 0
+        while seq_len <= max_input_length:
+            PREFILL_WARMUP_SEQLEN_LIST.append(seq_len)
+            seq_len += PAD_SEQUENCE_TO_MULTIPLE_OF * (2**i)
+            i += 1
+        if PREFILL_WARMUP_SEQLEN_LIST[-1] < max_input_length:
+            PREFILL_WARMUP_SEQLEN_LIST.append(max_input_length)
+
+        # Prefill and decode warmup
+        DECODE_WARMUP_BATCH_SIZE_LIST = []
+        prefill_batch = None
+        decode_batch = None
+        try:
+            for batch_size in PREFILL_WARMUP_BATCH_SIZE_LIST:
+                for seq_len in PREFILL_WARMUP_SEQLEN_LIST:
+                    batch = self.generate_warmup_batch(request, seq_len, batch_size, is_warmup)
+                    _, prefill_batch, _ = self.generate_token([batch], is_warmup)
+                    _, decode_batch, _ = self.generate_token([prefill_batch], is_warmup)
+
+                DECODE_WARMUP_BATCH_SIZE_LIST.append(batch_size)
+
+        except Exception:
+            raise RuntimeError(
+                f"Not enough memory to handle the following prefill and decode warmup. "
+                f"Prefill batch size list: {PREFILL_WARMUP_BATCH_SIZE_LIST}. "
+                f"Prefill sequence length list: {PREFILL_WARMUP_SEQLEN_LIST}. "
+                f"Decode batch size list: {DECODE_WARMUP_BATCH_SIZE_LIST}. "
+                f"You need to decrease `--max-batch-prefill-tokens`"
+            )
+
+        mem_stats = get_hpu_memory_stats(self.device)
+        logger.info(
+            f"\nPrefill and decode warmup finished successfully.\n"
+            f"Prefill batch size list: {PREFILL_WARMUP_BATCH_SIZE_LIST}\n"
+            f"Prefill sequence length list: {PREFILL_WARMUP_SEQLEN_LIST}\n"
+            f"Decode batch size list: {DECODE_WARMUP_BATCH_SIZE_LIST}\n"
+            f"Memory stats: {mem_stats} "
+        )
+
+        max_decode_batch_size = math.floor(MAX_BATCH_TOTAL_TOKENS / MAX_TOTAL_TOKENS)
+        batch_size = max_prefill_batch_size * 2
+        # Decode warmup with larger batch sizes
+        try:
+            if DECODE_WARMUP_BATCH_SIZE_LIST[-1] < max_decode_batch_size and batch_size <= max_decode_batch_size:
+                batches = []
+                for i in range(int(batch_size / max_prefill_batch_size)):
+                    batch = self.generate_warmup_batch(request, PREFILL_WARMUP_SEQLEN_LIST[0], DECODE_WARMUP_BATCH_SIZE_LIST[-1], is_warmup)
+                    _, prefill_batch, _ = self.generate_token([batch], is_warmup)
+                    batches.append(prefill_batch)
+                while batch_size <= max_decode_batch_size:
+                    _, decode_batch, _ = self.generate_token(batches, is_warmup)
+                    DECODE_WARMUP_BATCH_SIZE_LIST.append(batch_size)
+                    batch_size = batch_size * 2
+                    batches.clear()
+
+                    for i in range(int(batch_size / max_prefill_batch_size)):
+                        batch = self.generate_warmup_batch(request, PREFILL_WARMUP_SEQLEN_LIST[0], DECODE_WARMUP_BATCH_SIZE_LIST[-1], is_warmup)
+                        _, prefill_batch, _ = self.generate_token([batch], is_warmup)
+                        batches.append(prefill_batch)
+
+                batches.clear()
+                if DECODE_WARMUP_BATCH_SIZE_LIST[-1] < max_decode_batch_size:
+                    max_decode_batch_size = math.floor(max_decode_batch_size / 2) * 2
+                    batch_size = max_decode_batch_size
+                    for i in range(int(max_decode_batch_size / 2)):
+                        batch = self.generate_warmup_batch(request, PREFILL_WARMUP_SEQLEN_LIST[0], 2, is_warmup)
+                        _, prefill_batch, _ = self.generate_token([batch], is_warmup)
+                        batches.append(prefill_batch)
+                    _, decode_batch, _ = self.generate_token(batches, is_warmup)
+                    DECODE_WARMUP_BATCH_SIZE_LIST.append(max_decode_batch_size)
+                    max_batch_total_tokens = max_decode_batch_size * MAX_TOTAL_TOKENS
+                    MAX_BATCH_TOTAL_TOKENS = max_batch_total_tokens
+        except Exception:
+            raise RuntimeError(
+                f"Not enough memory to handle batch_size({batch_size}) decode warmup. "
+                f"Decode batch size list: {DECODE_WARMUP_BATCH_SIZE_LIST}. "
+                f"max_decode_batch_size is {max_decode_batch_size}. "
+                f"You need to decrease env `MAX_BATCH_TOTAL_TOKENS` or `--max-batch-total-tokens`"
+            )
+
+        mem_stats = get_hpu_memory_stats(self.device)
+        logger.info(
+            f"\nDecode warmup finished successfully.\n"
+            f"Decode batch size list: {DECODE_WARMUP_BATCH_SIZE_LIST}\n"
+            f"Memory stats: {mem_stats}"
+        )
+
+        return MAX_BATCH_TOTAL_TOKENS
\ No newline at end of file
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index 22871ec5fb5..6f00f38a8b9 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -1,5 +1,8 @@
+# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
+
 import asyncio
 import os
+import sys
 import torch
 import time
 import signal
@@ -14,23 +17,24 @@ from text_generation_server.cache import Cache
 from text_generation_server.interceptor import ExceptionInterceptor
 from text_generation_server.models import Model, get_model_with_lora_adapters
+from text_generation_server.pb import generate_pb2_grpc, generate_pb2
+from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
+from text_generation_server.models.globals import set_model_id
+from text_generation_server.models.globals import set_adapter_to_index
 from text_generation_server.utils.adapter import AdapterInfo
 try:
-    from text_generation_server.models.pali_gemma import PaliGemmaBatch
+    #from text_generation_server.models.pali_gemma import PaliGemmaBatch
     from text_generation_server.models.vlm_causal_lm import (
         VlmCausalLMBatch,
     )
-    from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch
+    #from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch
-    VLM_BATCH_TYPES = {PaliGemmaBatch, VlmCausalLMBatch, IdeficsCausalLMBatch}
+    VLM_BATCH_TYPES = {VlmCausalLMBatch}
 except (ImportError, NotImplementedError):
     # These imports can fail on CPU/Non flash.
     VLM_BATCH_TYPES = set()
-
-from text_generation_server.pb import generate_pb2_grpc, generate_pb2
-from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
-from text_generation_server.models.globals import set_adapter_to_index
+from text_generation_server.utils.version import is_driver_compatible, MIN_TGI_GAUDI_SYNAPSE_VERSION
 class SignalHandler:
@@ -58,16 +62,19 @@ def __init__(
         self.quantize = model.quantize
         self.server_urls = server_urls
         # For some reason, inference_mode does not work well with GLOO which we use on CPU
-        if model.device.type == "cuda":
-            # Force inference mode for the lifetime of TextGenerationService
-            self._inference_mode_raii_guard = torch._C._InferenceMode(True)
+        # TODO: Enabling inference mode messes up the autograd op dispatch and results in the
+        # aten::matmul op not being optimized. Will investigate further.
+        # if model.device.type == "hpu":
+        # Force inference mode for the lifetime of TextGenerationService
+        # self._inference_mode_raii_guard = torch._C._InferenceMode(True)
+
+
     async def Info(self, request, context):
         return self.model.info
     async def Health(self, request, context):
-        if self.model.device.type == "cuda":
-            torch.zeros((2, 2)).cuda()
+        if self.model.device.type == "hpu":
+            torch.zeros((2, 2)).to("hpu")
         return generate_pb2.HealthResponse()
     async def ServiceDiscovery(self, request, context):
@@ -90,41 +97,17 @@ async def FilterBatch(self, request, context):
         return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())
     async def Warmup(self, request, context):
-        if self.quantize in {"exl2", "gptq"}:
-            try:
-                # When using GPTQ, Exllama kernels need some global kernels
-                # For which we have the finale shapes only after the model has loaded
-                # This will allocate those buffers.
-                from text_generation_server.layers.gptq import (
-                    create_exllama_buffers,
-                    set_device,
-                )
-
-                set_device(self.model.device)
-                create_exllama_buffers(request.max_prefill_tokens)
-            except ImportError:
-                pass
-        if (
-            self.model.batch_type in VLM_BATCH_TYPES
-        ):  # Hack, i would rather use kwargs in the `from_pb` call
-            batch = self.model.batch_type.from_pb_processor(
-                request.batch,
-                self.model.tokenizer,
-                self.model.processor,
-                self.model.model.config,
-                self.model.dtype,
-                self.model.device,
-            )
-        else:
-            batch = self.model.batch_type.from_pb(
-                request.batch, self.model.tokenizer, self.model.dtype, self.model.device
-            )
-        max_supported_total_tokens = self.model.warmup(batch)
+        max_supported_total_tokens = self.model.warmup(request)
+        return generate_pb2.WarmupResponse(max_supported_total_tokens=max_supported_total_tokens)
+        # else:
+        #     batch = self.model.batch_type.from_pb(
+        #         request.batch, self.model.tokenizer, self.model.dtype, self.model.device
+        #     )
+
+        #     max_supported_total_tokens = self.model.warmup(batch)
+        #     return generate_pb2.WarmupResponse(max_supported_total_tokens=max_supported_total_tokens)
-        return generate_pb2.WarmupResponse(
-            max_supported_total_tokens=max_supported_total_tokens
-        )
     async def Prefill(self, request, context):
         start = time.time_ns()
@@ -144,7 +127,7 @@ async def Prefill(self, request, context):
             request.batch, self.model.tokenizer, self.model.dtype, self.model.device
         )
-        generations, next_batch, timings = self.model.generate_token(batch)
+        generations, next_batch, timings = self.model.generate_token([batch])
         self.cache.set(next_batch)
         return generate_pb2.PrefillResponse(
@@ -170,21 +153,13 @@ async def Decode(self, request, context):
         if len(batches) == 0:
             raise ValueError("All batches are empty")
-        if len(batches) > 1:
-            start_concat = time.time_ns()
-            batch = self.model.batch_type.concatenate(batches)
-            concat_ns = time.time_ns() - start_concat
-        else:
-            batch = batches[0]
-            concat_ns = None
-
-        generations, next_batch, timings = self.model.generate_token(batch)
+        generations, next_batch, timings = self.model.generate_token(batches)
         self.cache.set(next_batch)
         return generate_pb2.DecodeResponse(
             generations=[generation.to_pb() for generation in generations],
             batch=next_batch.to_pb() if next_batch else None,
-            concat_ns=concat_ns,
+            concat_ns=None,
             forward_ns=timings[0],
             decode_ns=timings[1],
             total_ns=time.time_ns() - start,
@@ -213,18 +188,31 @@ async def serve_inner(
        dtype: Optional[str] = None,
        trust_remote_code: bool = False,
    ):
+        if not is_driver_compatible():
+            logger.warning(f"Current Synapse version is lower than the minimum version supported: {MIN_TGI_GAUDI_SYNAPSE_VERSION}; this could result in failures")
+
        unix_socket_template = "unix://{}-{}"
        adapter_to_index = {}
+        logger.info("Server:server_inner: sharded = {}".format(sharded))
+
        if sharded:
+            rank = int(os.environ["RANK"])
+            logger.info("Server:server_inner: rank = {}".format(rank))
            server_urls = [
-                unix_socket_template.format(uds_path, rank)
-                for rank in range(int(os.environ["WORLD_SIZE"]))
+                unix_socket_template.format(uds_path, rank) for rank in range(int(os.environ["WORLD_SIZE"]))
            ]
            local_url = server_urls[int(os.environ["RANK"])]
        else:
            local_url = unix_socket_template.format(uds_path, 0)
            server_urls = [local_url]
+        logger.info("Server:server_inner: data type = {}, local_url = {}".format(dtype, local_url))
+        if dtype == "bfloat16" or dtype is None:
+            data_type = torch.bfloat16
+        else:
+            data_type = torch.float
+        if revision == "None":
+            revision = None
        try:
            model =
get_model_with_lora_adapters( model_id, @@ -233,7 +221,7 @@ async def serve_inner( sharded, quantize, speculate, - dtype, + data_type, trust_remote_code, max_input_tokens, adapter_to_index, @@ -271,6 +259,7 @@ async def serve_inner( while signal_handler.KEEP_PROCESSING: await asyncio.sleep(0.5) + set_model_id(model_id) asyncio.run( serve_inner( model_id, diff --git a/server/text_generation_server/tgi_service.py b/server/text_generation_server/tgi_service.py new file mode 100644 index 00000000000..f0f131268bb --- /dev/null +++ b/server/text_generation_server/tgi_service.py @@ -0,0 +1,45 @@ +import os +from pathlib import Path +from loguru import logger +import sys +from text_generation_server import server +import argparse +from typing import List +from text_generation_server.utils.adapter import parse_lora_adapters + + +def main(args): + logger.info("TGIService: starting tgi service .... ") + logger.info( + "TGIService: --model_id {}, --revision {}, --sharded {}, --speculate {}, --dtype {}, --trust_remote_code {}, --uds_path {} ".format( + args.model_id, args.revision, args.sharded, args.speculate, args.dtype, args.trust_remote_code, args.uds_path + ) + ) + lora_adapters = parse_lora_adapters(os.getenv("LORA_ADAPTERS")) + server.serve( + model_id=args.model_id, + lora_adapters=lora_adapters, + revision=args.revision, + sharded=args.sharded, + quantize=args.quantize, + speculate=args.speculate, + dtype=args.dtype, + trust_remote_code=args.trust_remote_code, + uds_path=args.uds_path, + max_input_tokens=args.max_input_tokens + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model_id", type=str) + parser.add_argument("--revision", type=str) + parser.add_argument("--sharded", type=bool) + parser.add_argument("--speculate", type=int, default=None) + parser.add_argument("--dtype", type=str) + parser.add_argument("--trust_remote_code", type=bool) + parser.add_argument("--uds_path", type=Path) + parser.add_argument("--quantize", type=str) + parser.add_argument("--max_input_tokens", type=int) + args = parser.parse_args() + main(args) diff --git a/server/text_generation_server/utils/__init__.py b/server/text_generation_server/utils/__init__.py index 08ba808d13f..565a7c3ca64 100644 --- a/server/text_generation_server/utils/__init__.py +++ b/server/text_generation_server/utils/__init__.py @@ -1,3 +1,6 @@ +# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. + +import text_generation_server.habana_quantization_env from text_generation_server.utils.convert import convert_file, convert_files from text_generation_server.utils.dist import initialize_torch_distributed from text_generation_server.utils.weights import Weights @@ -18,6 +21,9 @@ FinishReason, Sampling, Greedy, + make_tokenizer_optional, + is_tokenizer_transparent, + pad_next_token_chooser_parameters, ) __all__ = [ diff --git a/server/text_generation_server/utils/debug.py b/server/text_generation_server/utils/debug.py new file mode 100644 index 00000000000..ef8d437b73a --- /dev/null +++ b/server/text_generation_server/utils/debug.py @@ -0,0 +1,31 @@ +# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. 
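# ---------------------------------------------------------------------------
# Editor's note (not part of the patch): tgi_service.py above is normally
# spawned once per rank by the TGI launcher, which passes the arguments shown
# in its argparse setup, and LORA_ADAPTERS is read from the environment.
# Under that assumption, a manual single-process invocation might look like
# the following (model id and paths are placeholders, not taken from the
# patch):
#
#   python -m text_generation_server.tgi_service \
#       --model_id bigscience/bloom-560m \
#       --dtype bfloat16 \
#       --uds_path /tmp/text-generation-server \
#       --max_input_tokens 1024
#
# Note that `--sharded` and `--trust_remote_code` are declared with
# argparse's type=bool, so any non-empty string (including "false") parses
# as True; omit them entirely to keep the falsy default of None.
# ---------------------------------------------------------------------------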
+ +import os +import glob +import time + +from optimum.habana.utils import to_gb_rounded +import habana_frameworks.torch as htorch + +START_TS = None +DBG_TRACE_FILENAME = os.environ.get('DBG_TRACE_FILENAME') +if 'GRAPH_VISUALIZATION' in os.environ: + for f in glob.glob('.graph_dumps/*'): + os.remove(f) + + +def count_hpu_graphs(): + return len(glob.glob('.graph_dumps/*PreGraph*')) + + +def dbg_trace(tag, txt): + global START_TS + if DBG_TRACE_FILENAME is not None and int(os.getenv("RANK", 0)) == 0: + if START_TS is None: + START_TS = time.perf_counter() + time_offset = time.perf_counter() - START_TS + mem_stats = htorch.hpu.memory.memory_stats() + mem_used = to_gb_rounded(mem_stats['InUse']) + max_mem_used = to_gb_rounded(mem_stats['MaxInUse']) + print(f'ts:{time_offset:.3f}s g:{count_hpu_graphs()} mu:{mem_used:.1f}GB ' + f'mmu:{max_mem_used:.1f}GB | {tag} | {txt}', flush=True, file=open(DBG_TRACE_FILENAME, 'a')) diff --git a/server/text_generation_server/utils/dist.py b/server/text_generation_server/utils/dist.py index 82aeba6ce9f..d370a3d5cea 100644 --- a/server/text_generation_server/utils/dist.py +++ b/server/text_generation_server/utils/dist.py @@ -3,7 +3,6 @@ from datetime import timedelta from loguru import logger -from text_generation_server.utils.import_utils import SYSTEM # Tensor Parallelism settings RANK = int(os.getenv("RANK", "0")) @@ -45,6 +44,12 @@ def rank(self): def initialize_torch_distributed(): + import habana_frameworks.torch.core as htcore + + rank = int(os.getenv("RANK", "0")) + world_size = int(os.getenv("WORLD_SIZE", "1")) + + options = None if torch.cuda.is_available(): from torch.distributed import ProcessGroupNCCL @@ -56,9 +61,21 @@ def initialize_torch_distributed(): backend = "nccl" options = ProcessGroupNCCL.Options() options.is_high_priority_stream = True - options._timeout = timedelta(seconds=120) + options._timeout = timedelta(seconds=60) + elif torch.hpu.is_available(): + backend = "hccl" + n_hpus = torch.hpu.device_count() + if world_size > n_hpus: + raise ValueError(f"WORLD_SIZE ({world_size}) is higher than the number of available HPUs ({n_hpus}).") else: - backend = "gloo" + try: + import oneccl_bindings_for_pytorch + + backend = "ccl" + if os.getenv("CCL_WORKER_COUNT", None) is None: + os.environ["CCL_WORKER_COUNT"] = str(1) + except ImportError: + backend = "gloo" options = None if WORLD_SIZE == 1: @@ -69,24 +86,13 @@ def initialize_torch_distributed(): if not torch.distributed.is_initialized(): # Call the init process. 
- if SYSTEM == "ipex": - import intel_extension_for_pytorch as ipex - - ipex.distributed.init_process_group( - backend="ccl", - world_size=WORLD_SIZE, - rank=RANK, - timeout=timedelta(seconds=120), - pg_options=options, - ) - else: - torch.distributed.init_process_group( - backend=backend, - world_size=WORLD_SIZE, - rank=RANK, - timeout=timedelta(seconds=120), - pg_options=options, - ) + torch.distributed.init_process_group( + backend=backend, + world_size=WORLD_SIZE, + rank=RANK, + timeout=timedelta(seconds=60), + pg_options=options, + ) else: logger.warning("torch.distributed is already initialized.") diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py index 9abd886f250..104fc2f098f 100644 --- a/server/text_generation_server/utils/logits_process.py +++ b/server/text_generation_server/utils/logits_process.py @@ -1,5 +1,6 @@ import math import torch +import habana_frameworks.torch.core as htcore from loguru import logger from typing import Dict, Union @@ -43,37 +44,31 @@ def __init__( if typical_p is not None and typical_p < 1.0: self.warpers.append(TypicalLogitsWarper(mass=typical_p)) - self.cuda_graph = None + self.hpu_graph = None self.static_scores = None self.static_warped_scores = None self.static_next_logprob = None def __call__(self, scores): - if torch.cuda.is_available(): - if self.cuda_graph is None: - self.static_scores = scores - self.cuda_graph = torch.cuda.CUDAGraph() + if self.hpu_graph is None: + self.static_scores = scores.clone().contiguous() + self.static_warped_scores = scores.clone().contiguous() + self.static_next_logprob = scores.clone().contiguous() + self.hpu_graph = htcore.hpu.HPUGraph() - with torch.cuda.graph(self.cuda_graph, pool=mempool): - local_scores = self.static_scores - for warper in self.warpers: - local_scores = warper(None, local_scores) + with htcore.hpu.graph(self.hpu_graph): + local_scores = self.static_scores + for warper in self.warpers: + local_scores = warper(None, local_scores) - self.static_warped_scores = local_scores - # Compute logprobs - self.static_next_logprob = torch.log_softmax( - self.static_warped_scores, -1 - ) + self.static_warped_scores.copy_(local_scores) + # Compute logprobs + self.static_next_logprob.copy_(torch.log_softmax(self.static_warped_scores, -1)) - self.static_scores.copy_(scores) - self.cuda_graph.replay() + self.static_scores.copy_(scores) + self.hpu_graph.replay() - return self.static_warped_scores, self.static_next_logprob - - # CPU branch - for warper in self.warpers: - scores = warper(None, scores) - return scores, torch.log_softmax(scores, -1) + return self.static_warped_scores, self.static_next_logprob @lru_cache(10) @@ -83,9 +78,7 @@ def static_warper( top_p: Optional[float], typical_p: Optional[float], ) -> StaticWarper: - return StaticWarper( - temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p - ) + return StaticWarper(temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p) class HeterogeneousRepetitionPenaltyLogitsProcessor(LogitsProcessor): @@ -102,17 +95,13 @@ class HeterogeneousRepetitionPenaltyLogitsProcessor(LogitsProcessor): def __init__(self, penalty: List[float], dtype: torch.dtype, device: torch.device): self.penalty = penalty - self.penalty_tensor = torch.tensor( - penalty, dtype=dtype, device=device - ).unsqueeze(1) + self.penalty_tensor = torch.tensor(penalty, dtype=dtype, device=device).unsqueeze(1) def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: score 
= torch.gather(scores, 1, input_ids) # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability - score = torch.where( - score < 0, score * self.penalty_tensor, score / self.penalty_tensor - ) + score = torch.where(score < 0, score * self.penalty_tensor, score / self.penalty_tensor) scores.scatter_(1, input_ids, score) return scores @@ -170,9 +159,11 @@ def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tenso vocab_size = scores.size(1) # Calculate the frequency for each token so far - token_freq = torch.zeros(batch_size, vocab_size, device=input_ids.device) + token_freq = torch.zeros( + batch_size, vocab_size, dtype=scores.dtype, device=scores.device + ) token_freq.scatter_add_( - 1, input_ids, torch.ones_like(input_ids, dtype=torch.float) + 1, input_ids, torch.ones_like(input_ids, dtype=scores.dtype, device=scores.device) ) token_freq /= input_size @@ -199,13 +190,9 @@ class HeterogeneousTemperatureLogitsWarper: The value used to module the logits distribution. """ - def __init__( - self, temperature: List[float], dtype: torch.dtype, device: torch.device - ): + def __init__(self, temperature: List[float], dtype: torch.dtype, device: torch.device): self.temperature = temperature - self.temperature_tensor = torch.tensor( - temperature, dtype=dtype, device=device - ).unsqueeze(1) + self.temperature_tensor = torch.tensor(temperature, dtype=dtype, device=device).unsqueeze(1) def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: scores.div_(self.temperature_tensor) @@ -244,9 +231,7 @@ def __init__( min_tokens_to_keep: int = 1, ): self.top_p = top_p - self.top_p_opposite = 1 - torch.tensor( - top_p, dtype=dtype, device=device - ).unsqueeze(1) + self.top_p_opposite = 1 - torch.tensor(top_p, dtype=dtype, device=device).unsqueeze(1) self.filter_value = filter_value self.min_tokens_to_keep = min_tokens_to_keep @@ -263,9 +248,7 @@ def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tenso sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0 # scatter sorted tensors to original indexing - indices_to_remove = sorted_indices_to_remove.scatter( - 1, sorted_indices, sorted_indices_to_remove - ) + indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) warped_scores = scores.masked_fill_(indices_to_remove, self.filter_value) return warped_scores @@ -313,9 +296,7 @@ def __init__( disabled = [x == 0 for x in top_k] if any(disabled): - self.top_k_disabled_mask = torch.tensor( - disabled, dtype=torch.bool, device=device - ).view(-1, 1) + self.top_k_disabled_mask = torch.tensor(disabled, dtype=torch.bool, device=device).view(-1, 1) else: self.top_k_disabled_mask = None @@ -351,9 +332,7 @@ def filter(self, indices): self.max_top_k = max(self.top_k) if self.top_k_disabled_mask is not None: - self.top_k_disabled_mask = ( - self.top_k_disabled_mask[indices] if any(disabled) else None - ) + self.top_k_disabled_mask = self.top_k_disabled_mask[indices] if any(disabled) else None return self return None @@ -419,15 +398,11 @@ def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tenso if self.disabled_mask is not None: last_ind.masked_fill_(self.disabled_mask, scores.shape[-1] - 1) - sorted_indices_to_remove = sorted_scores > sorted_scores.gather( - 1, last_ind.view(-1, 1) - ) + sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1)) if self.min_tokens_to_keep > 1: # Keep at least min_tokens_to_keep 
(set to min_tokens_to_keep-1 because we add the first one below) sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0 - indices_to_remove = sorted_indices_to_remove.scatter( - 1, sorted_indices, sorted_indices_to_remove - ) + indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) warped_scores = scores.masked_fill_(indices_to_remove, self.filter_value) @@ -441,9 +416,7 @@ def filter(self, indices): self.mass_tensor = self.mass_tensor[indices] if self.disabled_mask is not None: - self.disabled_mask = ( - self.disabled_mask[indices] if any(disabled) else None - ) + self.disabled_mask = self.disabled_mask[indices] if any(disabled) else None return self return None @@ -521,13 +494,7 @@ def _advance(next_token_id, fsm_grammar_state, fsm): def _cached_compile_fsm(grammar_type, schema, tokenizer): start_time = time.time() if grammar_type == GrammarType.GRAMMAR_TYPE_JSON: - try: - schema = build_regex_from_schema(schema) - # TODO: this is only here short term to avoid crashing the python server, mid term we want this in the rust/router layer - except Exception as e: - logger.error(f"Error compiling FSM, grammar won't be enforced \n{e}") - # allows everything - schema = "(.*?)" + schema = build_regex_from_schema(schema) elif grammar_type == GrammarType.GRAMMAR_TYPE_REGEX: pass # schema is already a regex just here for clarity fsm = RegexFSM(schema, tokenizer) @@ -586,7 +553,7 @@ def __call__( mask = torch.full_like(logits, -math.inf) for i in range(logits.shape[0]): fsm = self.fsms[i] - if fsm_grammar_states[i] == -1 or fsm is None: + if fsm is None: continue allowed_tokens = fsm.allowed_token_ids(fsm_grammar_states[i]) mask[i, allowed_tokens] = 0 diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py index 9ab49665a75..1136fa963de 100644 --- a/server/text_generation_server/utils/tokens.py +++ b/server/text_generation_server/utils/tokens.py @@ -247,10 +247,12 @@ def __init__( tokenizer: PreTrainedTokenizerBase, grammars: List[str], grammar_types: List[int], - fsm_grammar_states=List[int], + fsm_grammar_states:List[int], + quantization_enabled: bool, ): warpers = [] + # TODO: enable watermark with FP8 quantization self.watermark_processor = ( HeterogeneousProcessorWrapper( { @@ -259,7 +261,7 @@ def __init__( if do_watermark } ) - if any(watermark) + if any(watermark) and not quantization_enabled else None ) @@ -431,6 +433,18 @@ def advance_grammar_single(self, grammar_state_index: int, next_id: int): ) return self + def advance_grammar_single_with_past_state( + self, grammar_state_index: int, next_id: torch.Tensor, past_state: int + ): + if self.grammar_processor is not None: + next_id = next_id.item() + self.fsm_grammar_states[grammar_state_index] = ( + self.grammar_processor.advance_at_index( + next_id, past_state, grammar_state_index, + ) + ) + return self + def filter(self, indices): if self.watermark_processor is not None: self.watermark_processor = self.watermark_processor.filter(indices) @@ -481,6 +495,7 @@ def from_pb( device: torch.device, tokenizer: PreTrainedTokenizerBase, fsm_grammar_states: Optional[List[int]] = None, + quantization_enabled: bool = False, ) -> "HeterogeneousNextTokenChooser": return HeterogeneousNextTokenChooser( watermark=[pb_.watermark for pb_ in pb], @@ -500,12 +515,37 @@ def from_pb( fsm_grammar_states=( fsm_grammar_states if fsm_grammar_states else [0] * len(pb) ), + quantization_enabled=quantization_enabled, ) +def pad_next_token_chooser_parameters( + parameters: 
List[generate_pb2.NextTokenChooserParameters],
+    expected_size: int,
+) -> List[generate_pb2.NextTokenChooserParameters]:
+    # disable all logits processors to minimize padding overhead
+    empty_parameters = generate_pb2.NextTokenChooserParameters(
+        temperature=1.0,
+        top_k=0,
+        top_p=1.0,
+        typical_p=1.0,
+        do_sample=False,
+        seed=0,
+        repetition_penalty=1.0,
+        frequency_penalty=0.0,
+        watermark=False,
+        grammar="",
+        grammar_type=0,
+    )
+    parameters.extend(
+        [empty_parameters] * (expected_size - len(parameters))
+    )
+    return parameters
+
+
 class Sampling:
     def __init__(self, seed: int, device: str = "cpu"):
-        self.generator = torch.Generator(device)
+        self.generator = torch.Generator("cpu")
         self.generator.manual_seed(seed)
         self.seed = seed
@@ -541,7 +581,7 @@ def __init__(self, do_sample: List[bool], seeds: List[int], device: torch.device
         self.greedy = Greedy()
     def __call__(self, logits):
-        out = torch.empty(logits.shape[0], dtype=torch.int64, device=logits.device)
+        out = torch.zeros(logits.shape[0], dtype=torch.int64, device=logits.device)
         if self.greedy_indices:
             # Computing for all indices is faster than slicing
             torch.argmax(logits, -1, out=out)
@@ -643,3 +683,50 @@ def batch_top_tokens(
         batch_top_token_logprobs.append(row_top_token_logprobs)
     return batch_top_token_ids, batch_top_token_logprobs
+
+
+def make_tokenizer_optional(tokenizer):
+    class _(type(tokenizer)):
+        def __call__(
+            self,
+            text,
+            return_tensors,
+            padding,
+            return_token_type_ids,
+            truncation,
+            max_length
+        ):
+            assert return_tensors == "pt", "incorrect input arguments when calling TransparentTokenizer"
+            assert padding == "max_length" or padding == "longest", "incorrect input arguments when calling TransparentTokenizer"
+            assert return_token_type_ids == False, "incorrect input arguments when calling TransparentTokenizer"
+            assert truncation == True, "incorrect input arguments when calling TransparentTokenizer"
+
+            def str_token_to_int(i):
+                if i == '?':
+                    return tokenizer.pad_token_id
+                else:
+                    return int(i)
+            all_tokens = [[str_token_to_int(i.strip()) for i in inner_text.split(',')]
+                          for inner_text in text]
+            if padding == "longest":
+                max_length = max(len(tokens) for tokens in all_tokens)
+            return {"input_ids": torch.tensor([[tokenizer.pad_token_id] * (max_length - len(tokens)) + tokens for tokens in all_tokens]),
+                    "attention_mask": torch.tensor([[0] * (max_length - len(tokens)) + [1] * len(tokens) for tokens in all_tokens])}
+
+        def decode(
+            self,
+            token_ids,
+            skip_special_tokens: bool = False,
+            clean_up_tokenization_spaces: bool = None,
+            **kwargs,
+        ) -> str:
+            return ','.join(str(i) for i in to_py_obj(token_ids))
+
+    import os
+    if os.getenv("SKIP_TOKENIZER_IN_TGI", "false").lower() == "true":
+        tokenizer.__class__ = _
+        tokenizer.is_transparent = True
+
+
+def is_tokenizer_transparent(tokenizer):
+    return hasattr(tokenizer, "is_transparent") and tokenizer.is_transparent is True
diff --git a/server/text_generation_server/utils/version.py b/server/text_generation_server/utils/version.py
new file mode 100644
index 00000000000..a72a9ea7b48
--- /dev/null
+++ b/server/text_generation_server/utils/version.py
@@ -0,0 +1,12 @@
+from optimum.habana.utils import get_driver_version
+from packaging.version import Version
+
+MIN_TGI_GAUDI_SYNAPSE_VERSION = Version("1.16.0")
+
+
+def is_driver_compatible():
+    driver_version = get_driver_version()
+    if driver_version is not None:
+        if driver_version < MIN_TGI_GAUDI_SYNAPSE_VERSION:
+            return False
+    return True
\ No newline at end of file
diff --git a/server/text_generation_server/utils/watermark.py b/server/text_generation_server/utils/watermark.py
index 5d8f531234f..5092b076c33 100644
--- a/server/text_generation_server/utils/watermark.py
+++ b/server/text_generation_server/utils/watermark.py
@@ -34,7 +34,7 @@ def __init__(
         # watermarking parameters
         self.gamma = gamma
         self.delta = delta
-        self.rng = torch.Generator(device=device)
+        self.rng = torch.Generator(device="cpu")
         self.hash_key = hash_key
 
     def _seed_rng(self, input_ids: Union[List[int], torch.LongTensor]):