diff --git a/megatron/core/models/common/embeddings/rotary_pos_embedding.py b/megatron/core/models/common/embeddings/rotary_pos_embedding.py index e3923d016b..2d633962eb 100644 --- a/megatron/core/models/common/embeddings/rotary_pos_embedding.py +++ b/megatron/core/models/common/embeddings/rotary_pos_embedding.py @@ -198,7 +198,7 @@ def get_rotary_seq_len( if packed_seq_params is not None: # max_seqlen are the max sequence length in the packed sequence before being divived # by the tp and cp size. - return max(packed_seq_params.max_seqlen_q, packed_seq_params.max_seqlen_kv) + return max(packed_seq_params.max_seqlen_q, packed_seq_params.max_seqlen_kv).item() elif inference_params is not None: rotary_seq_len = inference_params.max_sequence_length else: