Skip to content

Commit

Permalink
add check_neural_compressor_min_version for 4 bit behavior
Browse files Browse the repository at this point in the history
Signed-off-by: Xin <[email protected]>
  • Loading branch information
xin3he committed Nov 20, 2024
1 parent 5debcdf commit 5c5d38c
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 3 deletions.
6 changes: 3 additions & 3 deletions examples/text-generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
check_habana_frameworks_version,
check_optimum_habana_min_version,
get_habana_frameworks_version,
check_neural_compressor_min_version,
set_seed,
)

Expand Down Expand Up @@ -269,9 +270,8 @@ def setup_model(args, model_dtype, model_kwargs, logger):
original_model=org_model,
**model_kwargs,
)
# TODO: This will be removed in v1.19 Synapse release
# the loaded model should have the same dtype as original_model
model = model.to(model_kwargs["torch_dtype"])
if not check_neural_compressor_min_version("3.2"):
model = model.to(model_kwargs["torch_dtype"])
else:
if args.assistant_model is not None:
assistant_model = AutoModelForCausalLM.from_pretrained(
Expand Down
8 changes: 8 additions & 0 deletions optimum/habana/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,14 @@ def check_habana_frameworks_version(req_version):
)


def check_neural_compressor_min_version(req_version):
    """
    Tells whether the installed `neural_compressor` satisfies a minimum version.

    Args:
        req_version: the minimum acceptable version, given as a version string (e.g. `"3.2"`).

    Returns:
        `True` if `neural_compressor.__version__` is greater than or equal to `req_version`,
        `False` otherwise.
    """
    # Imported at call time, so `neural_compressor` is only needed when this check runs.
    import neural_compressor

    # Parse both sides so the comparison is done on version objects rather than raw strings.
    installed = version.Version(neural_compressor.__version__)
    required = version.Version(req_version)
    return installed >= required


def get_device_name():
"""
Returns the name of the current device: Gaudi or Gaudi2.
Expand Down

0 comments on commit 5c5d38c

Please sign in to comment.