
Commit 37c2e83

Add bf16 support in LMCorrector and update README files for Qwen2 compatibility

Jacob-Zhou committed Jan 3, 2025 · 1 parent db5c08b
Showing 4 changed files with 23 additions and 16 deletions.
12 changes: 7 additions & 5 deletions README.md
@@ -64,11 +64,11 @@ pip install flash-attn --no-build-isolation
<!-- Add a warning about Qwen2.5 -->

> [!WARNING]
- > Reported by a user, using Qwen2.5 family models without flash-attn will lead unexpected errors.
+ > As reported by a user, using Qwen2 or Qwen2.5 family models without flash-attn leads to unexpected errors.
>
- > Please install flash-attn to avoid this issue.
- >
- > We are working on a fix for this issue, making it compatible with Qwen2.5 family without flash-attn.
+ > Please install flash-attn to avoid this issue. Alternatively, you can set `torch_dtype=torch.bfloat16` in the `LMCorrector` class.
+ >
+ > That said, we strongly recommend using flash-attn, which significantly reduces memory usage and speeds up inference.
## Usage

@@ -84,6 +84,7 @@ from lmcsc import LMCorrector
corrector = LMCorrector(
    model="Qwen/Qwen2.5-0.5B",
    config_path="configs/default_config.yaml",
+     torch_dtype=torch.bfloat16,  # requires `import torch`; the default torch_dtype is torch.float16, which leads to unexpected errors with Qwen2 or Qwen2.5 family models when flash-attn is not installed.
)

outputs = corrector("完善农产品上行发展机智。")
@@ -108,7 +109,8 @@ python api_server.py \
    --model "Qwen/Qwen2.5-0.5B" \
    --host 127.0.0.1 \
    --port 8000 \
-     --workers 1
+     --workers 1 \
+     --bf16  # use bf16 to avoid unexpected errors when using Qwen2 or Qwen2.5 family models without flash-attn.
```

You can use `curl` to test the RESTful API server.
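For example, a minimal Python sketch using `requests` as an alternative to `curl` (the `/correction` route and the `input` field below are illustrative assumptions, not confirmed by this commit; check the routes defined in api_server.py for the actual schema):

```python
# Hypothetical client for the API server started above.
# NOTE: the "/correction" path and the "input" field are placeholders;
# consult api_server.py for the real route and payload schema.
import requests

resp = requests.post(
    "http://127.0.0.1:8000/correction",
    json={"input": "完善农产品上行发展机智。"},
    timeout=30,
)
resp.raise_for_status()
print(resp.json())
```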
12 changes: 7 additions & 5 deletions README.zh.md
@@ -64,11 +64,11 @@ pip install flash-attn --no-build-isolation
<!-- Add a warning about Qwen2.5 -->

> [!WARNING]
- > A user reported that using Qwen2.5 family models without flash-attn installed causes the code to behave unexpectedly.
+ > A user reported that using Qwen2 or Qwen2.5 family models without flash-attn installed causes the code to behave unexpectedly.
>
- > Therefore, if you use Qwen2.5 family models with this codebase, be sure to install flash-attn.
- >
- > Meanwhile, we are working on this issue so that users without flash-attn can also use Qwen2.5 family models normally.
+ > Please install flash-attn to avoid this issue. Alternatively, you can set `torch_dtype=torch.bfloat16` in the `LMCorrector` class.
+ >
+ > That said, we strongly recommend using flash-attn, which significantly reduces GPU memory usage and speeds up inference.
## Usage

@@ -84,6 +84,7 @@ from lmcsc import LMCorrector
corrector = LMCorrector(
    model="Qwen/Qwen2.5-0.5B",
    config_path="configs/default_config.yaml",
+     torch_dtype=torch.bfloat16,  # use bfloat16 to avoid unexpected behavior with Qwen2 or Qwen2.5 family models when flash-attn is not installed.
)

outputs = corrector("完善农产品上行发展机智。")
@@ -108,7 +109,8 @@ python api_server.py \
    --model "Qwen/Qwen2.5-0.5B" \
    --host 127.0.0.1 \
    --port 8000 \
-     --workers 1
+     --workers 1 \
+     --bf16  # use bf16 to avoid unexpected behavior with Qwen2 or Qwen2.5 family models when flash-attn is not installed.
```

You can use `curl` to test the RESTful API server.
2 changes: 2 additions & 0 deletions api_server.py
@@ -207,13 +207,15 @@ def predict_stream(gen_params):
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--workers", type=int, default=1)
parser.add_argument("--debug", action="store_true")
+ parser.add_argument("--bf16", action="store_true")
args = parser.parse_args()

# Load LLM
logger.info(f"Loading model {args.model} from {args.config_path}")
corrector = LMCorrector(
    model=args.model,
    config_path=args.config_path,
+     torch_dtype=torch.bfloat16 if args.bf16 else torch.float16,
)
logger.info(f"Model {args.model} loaded successfully")
uvicorn.run(app, host=args.host, port=args.port, workers=args.workers, reload=args.debug)
13 changes: 7 additions & 6 deletions lmcsc/corrector.py
@@ -67,14 +67,15 @@ def __init__(
self.config = yaml.safe_load(config_file)

# Set parameters, using either the provided ones or those from the configuration
- self.n_beam = n_beam or self.config['n_beam']
- self.n_beam_hyps_to_keep = n_beam_hyps_to_keep or self.config['n_beam_hyps_to_keep']
- self.n_observed_chars = n_observed_chars or self.config['n_observed_chars']
+ # Set parameters with provided values or defaults from config
+ self.n_beam = n_beam if n_beam is not None else self.config['n_beam']
+ self.n_beam_hyps_to_keep = n_beam_hyps_to_keep if n_beam_hyps_to_keep is not None else self.config['n_beam_hyps_to_keep']
+ self.n_observed_chars = n_observed_chars if n_observed_chars is not None else self.config['n_observed_chars']
self.alpha = alpha if alpha is not None else self.config['alpha']
- self.distortion_model_smoothing = distortion_model_smoothing or self.config['distortion_model_smoothing']
+ self.distortion_model_smoothing = distortion_model_smoothing if distortion_model_smoothing is not None else self.config['distortion_model_smoothing']
self.use_faithfulness_reward = use_faithfulness_reward if use_faithfulness_reward is not None else self.config['use_faithfulness_reward']
- self.distortion_probs = customized_distortion_probs or self.config['distortion_probs']
- self.max_length = max_length or self.config['max_length']
+ self.distortion_probs = customized_distortion_probs if customized_distortion_probs is not None else self.config['distortion_probs']
+ self.max_length = max_length if max_length is not None else self.config['max_length']

# Load the language model
if isinstance(model, str):
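The switch from `or` to explicit `is not None` checks is not just stylistic: `x or default` falls back to the config default whenever the caller passes any falsy value (`0`, `0.0`, `False`, `""`), not only when the argument is omitted. A minimal sketch of the pitfall (names are illustrative, not taken from the codebase):

```python
config = {"n_beam": 8}

def with_or(n_beam=None):
    # `or` falls back whenever n_beam is falsy, so an explicit 0 is silently ignored.
    return n_beam or config["n_beam"]

def with_is_not_none(n_beam=None):
    # Falls back only when the caller passed nothing at all.
    return n_beam if n_beam is not None else config["n_beam"]

print(with_or(0))           # 8  -- the caller's 0 is lost
print(with_is_not_none(0))  # 0  -- the caller's 0 is respected
```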
