Skip to content

Commit

Permalink
Fix forward for chatglm
Browse files Browse the repository at this point in the history
  • Loading branch information
slyalin committed Dec 7, 2023
1 parent bcb7cac commit d4c165b
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions optimum/intel/openvino/modeling_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,11 +379,15 @@ def forward(
for input_name in self.key_value_input_names:
model_inputs = self.model.input(input_name)
shape = model_inputs.get_partial_shape()
shape[0] = batch_size
if shape[2].is_dynamic:
shape[2] = 0
if self.config.model_type == 'chatglm':
shape[0] = 0
shape[1] = batch_size
else:
shape[1] = 0
shape[0] = batch_size
if shape[2].is_dynamic:
shape[2] = 0
else:
shape[1] = 0
inputs[input_name] = Tensor(model_inputs.get_element_type(), shape.get_shape())
else:
# past_key_values are not used explicitly, instead they are handled inside the model
Expand Down

0 comments on commit d4c165b

Please sign in to comment.