
Commit

block while
cyber-pioneer committed Feb 26, 2024
1 parent 8d49caf commit bdb7a3a
Showing 1 changed file with 34 additions and 34 deletions.
68 changes: 34 additions & 34 deletions paddlenlp/generation/utils.py
@@ -1442,40 +1442,40 @@ def _post_process_(outputs, input_ids, cur_len, origin_len, scores, unfinished_flag, model_kwargs):
             # make the shape of attention_mask = (-1, -1, -1, -1) in dy2static.
             model_kwargs["attention_mask"] = paddle.reshape(attn_mask, paddle.shape(attn_mask))
             model_kwargs["cache"] = outputs[1] if isinstance(outputs, tuple) else None
-        max_new_tokens = paddle.full([1], max_new_tokens + cur_len - 1, dtype="int64")
-
-        if hasattr(paddle.framework, "_no_check_dy2st_diff"):
-            # TODO(daisiming): _no_check_dy2st_diff is used to turn off the checking of behavior
-            # inconsistency between dynamic graph and static graph. _no_check_dy2st_diff should be
-            # removed after static graphs support inplace and stride.
-            with paddle.framework._no_check_dy2st_diff():
-                # llama infer while begin
-                while cur_len < max_new_tokens and paddle.any(unfinished_flag):
-                    input_ids, scores, unfinished_flag, model_kwargs = _post_process_(
-                        _forward_(**model_kwargs),
-                        input_ids,
-                        cur_len_gpu,
-                        origin_len_gpu,
-                        scores,
-                        unfinished_flag,
-                        model_kwargs,
-                    )
-                    paddle.increment(cur_len)
-                    paddle.increment(cur_len_gpu)
-                # llama infer while end
-        else:
-            while cur_len < max_new_tokens and paddle.any(unfinished_flag):
-                input_ids, scores, unfinished_flag, model_kwargs = _post_process_(
-                    _forward_(**model_kwargs),
-                    input_ids,
-                    cur_len_gpu,
-                    origin_len_gpu,
-                    scores,
-                    unfinished_flag,
-                    model_kwargs,
-                )
-                paddle.increment(cur_len)
-                paddle.increment(cur_len_gpu)
+        # max_new_tokens = paddle.full([1], max_new_tokens + cur_len - 1, dtype="int64")
+
+        # if hasattr(paddle.framework, "_no_check_dy2st_diff"):
+        #     # TODO(daisiming): _no_check_dy2st_diff is used to turn off the checking of behavior
+        #     # inconsistency between dynamic graph and static graph. _no_check_dy2st_diff should be
+        #     # removed after static graphs support inplace and stride.
+        #     with paddle.framework._no_check_dy2st_diff():
+        #         # llama infer while begin
+        #         while cur_len < max_new_tokens and paddle.any(unfinished_flag):
+        #             input_ids, scores, unfinished_flag, model_kwargs = _post_process_(
+        #                 _forward_(**model_kwargs),
+        #                 input_ids,
+        #                 cur_len_gpu,
+        #                 origin_len_gpu,
+        #                 scores,
+        #                 unfinished_flag,
+        #                 model_kwargs,
+        #             )
+        #             paddle.increment(cur_len)
+        #             paddle.increment(cur_len_gpu)
+        #         # llama infer while end
+        # else:
+        #     while cur_len < max_new_tokens and paddle.any(unfinished_flag):
+        #         input_ids, scores, unfinished_flag, model_kwargs = _post_process_(
+        #             _forward_(**model_kwargs),
+        #             input_ids,
+        #             cur_len_gpu,
+        #             origin_len_gpu,
+        #             scores,
+        #             unfinished_flag,
+        #             model_kwargs,
+        #         )
+        #         paddle.increment(cur_len)
+        #         paddle.increment(cur_len_gpu)

         return input_ids[:, origin_len:], scores
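For readers skimming the diff: the blocked-out lines are a standard autoregressive decoding loop. Each iteration runs the model once (_forward_), appends the sampled token and updates per-sequence finished flags (_post_process_), and bumps the step counters in place. The following is a minimal, self-contained sketch of that pattern, not code from this commit; greedy_decode_sketch, step_fn, and the EOS handling are illustrative assumptions, while the paddle.any(unfinished_flag) termination test and paddle.increment mirror the lines above.

import paddle

def greedy_decode_sketch(step_fn, input_ids, eos_token_id, max_new_tokens):
    # step_fn is a hypothetical stand-in for the _forward_/_post_process_ pair:
    # given the running ids, it returns the next token per sequence, shape [batch, 1].
    batch, origin_len = input_ids.shape
    cur_len = paddle.full([1], origin_len, dtype="int64")  # step counter, updated in place
    max_len = paddle.full([1], origin_len + max_new_tokens, dtype="int64")
    unfinished_flag = paddle.full([batch, 1], True, dtype="bool")

    # Same termination test as the loop above: stop at the length budget
    # or once every sequence has produced EOS.
    while cur_len < max_len and paddle.any(unfinished_flag):
        next_tokens = step_fn(input_ids)  # [batch, 1]
        # Freeze finished sequences by padding them with EOS.
        next_tokens = paddle.where(
            unfinished_flag, next_tokens, paddle.full_like(next_tokens, eos_token_id)
        )
        input_ids = paddle.concat([input_ids, next_tokens], axis=-1)
        unfinished_flag = paddle.logical_and(unfinished_flag, next_tokens != eos_token_id)
        paddle.increment(cur_len)  # in-place +1, as in the original loop

    return input_ids[:, origin_len:]

As a quick check, a toy step_fn such as lambda ids: paddle.full([ids.shape[0], 1], eos_token_id, dtype=ids.dtype) finishes every sequence on the first step, so the loop exits after one iteration.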

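One more note on the hasattr branch in the removed code: it is a feature-detection idiom for a private Paddle API that older releases do not expose, and it forces the loop body to be written twice. A compact alternative, sketched here as an assumption rather than anything from this repository, keeps a single copy of the loop by falling back to a no-op context manager:

import contextlib
import paddle

def run_decoding_loop():
    pass  # hypothetical placeholder for the while loop shown in the diff

# Use the private guard when present; otherwise fall back to a no-op context.
no_check = getattr(paddle.framework, "_no_check_dy2st_diff", contextlib.nullcontext)

with no_check():
    run_decoding_loop()

The guard only turns off the dynamic-to-static consistency check described in the TODO; commenting out the whole block, as this commit does, removes the generation loop itself rather than just the check.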
