From 974f93955e47ec074e8ac04dff9c354879c3e13c Mon Sep 17 00:00:00 2001
From: Jan Kaniecki
Date: Wed, 11 Dec 2024 13:30:26 +0200
Subject: [PATCH 1/5] Assign intermediate states to pre-allocated tensor to
 avoid memory copies

---
 vllm/model_executor/models/mllama.py | 18 +++++++++++-------
 1 file changed, 11 insertions(+), 7 deletions(-)

diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py
index 57c6bbc7c494d..5512449ae2628 100644
--- a/vllm/model_executor/models/mllama.py
+++ b/vllm/model_executor/models/mllama.py
@@ -509,21 +509,27 @@ def forward(
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
     ) -> Union[Tuple, BaseModelOutput]:
-        encoder_states = ()
+        encoder_states = torch.empty((len(self.output_hidden_states),
+                                      hidden_states.size(0),
+                                      hidden_states.size(1),
+                                      hidden_states.size(2)),
+                                     dtype=hidden_states.dtype,
+                                     device=hidden_states.device)
+        hidden_states_idx = 0

         for i, encoder_layer in enumerate(self.layers):
             if i in self.output_hidden_states:
-                encoder_states = encoder_states + (hidden_states, )
+                encoder_states[hidden_states_idx] = hidden_states
+                hidden_states_idx += 1
             hidden_states = encoder_layer(
                 hidden_states,
                 attention_mask,
             )

         if len(self.layers) - 1 in self.output_hidden_states:
-            encoder_states = encoder_states + (hidden_states, )
-
-        return hidden_states, encoder_states
+            encoder_states[hidden_states_idx] = hidden_states

+        return hidden_states, encoder_states.permute(1, 2, 3, 0)

 class MllamaVisionModel(nn.Module):
@@ -658,8 +664,6 @@ def forward(self, pixel_values: torch.Tensor,
             attention_mask=attention_mask,
         )
         hidden_state, intermediate_hidden_states = output[0], output[1]
-        intermediate_hidden_states = torch.stack(intermediate_hidden_states,
-                                                 dim=-1)

         # apply global encoder
         hidden_state = self.layernorm_post(hidden_state)

From 007bc06941cd1ce03e55986f01ecdd7f6a6dddb5 Mon Sep 17 00:00:00 2001
From: Jan Kaniecki
Date: Wed, 11 Dec 2024 12:43:23 +0100
Subject: [PATCH 2/5] Yapf formatting

---
 vllm/model_executor/models/mllama.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py
index 5512449ae2628..7c6371674e155 100644
--- a/vllm/model_executor/models/mllama.py
+++ b/vllm/model_executor/models/mllama.py
@@ -509,12 +509,11 @@ def forward(
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
     ) -> Union[Tuple, BaseModelOutput]:
-        encoder_states = torch.empty((len(self.output_hidden_states),
-                                      hidden_states.size(0),
-                                      hidden_states.size(1),
-                                      hidden_states.size(2)),
-                                     dtype=hidden_states.dtype,
-                                     device=hidden_states.device)
+        encoder_states = torch.empty(
+            (len(self.output_hidden_states), hidden_states.size(0),
+             hidden_states.size(1), hidden_states.size(2)),
+            dtype=hidden_states.dtype,
+            device=hidden_states.device)
         hidden_states_idx = 0

         for i, encoder_layer in enumerate(self.layers):
@@ -531,6 +530,7 @@ def forward(

         return hidden_states, encoder_states.permute(1, 2, 3, 0)

+
 class MllamaVisionModel(nn.Module):

     def __init__(

From f4f9322a772d171129ce0429cecfbb97f146144a Mon Sep 17 00:00:00 2001
From: Jan Kaniecki
Date: Tue, 17 Dec 2024 14:03:27 +0200
Subject: [PATCH 3/5] Use tensor indexing to avoid possible recompilations

---
 vllm/model_executor/models/mllama.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py
index 7c6371674e155..3d1b1acb39d2c 100644
--- a/vllm/model_executor/models/mllama.py
+++ b/vllm/model_executor/models/mllama.py
@@ -514,19 +514,19 @@ def forward(
              hidden_states.size(1), hidden_states.size(2)),
             dtype=hidden_states.dtype,
             device=hidden_states.device)
-        hidden_states_idx = 0
+        hidden_states_idx = torch.tensor([0], device=hidden_states.device)

         for i, encoder_layer in enumerate(self.layers):
             if i in self.output_hidden_states:
-                encoder_states[hidden_states_idx] = hidden_states
-                hidden_states_idx += 1
+                encoder_states.index_copy_(0, hidden_states_idx, hidden_states)
+                hidden_states_idx.add_(1)
             hidden_states = encoder_layer(
                 hidden_states,
                 attention_mask,
             )

         if len(self.layers) - 1 in self.output_hidden_states:
-            encoder_states[hidden_states_idx] = hidden_states
+            encoder_states.index_copy_(0, hidden_states_idx, hidden_states)

         return hidden_states, encoder_states.permute(1, 2, 3, 0)

From f6dff07291c1967540f7dd84ea523007c2e19097 Mon Sep 17 00:00:00 2001
From: Jan Kaniecki
Date: Tue, 17 Dec 2024 15:22:11 +0100
Subject: [PATCH 4/5] Update mllama.py

---
 vllm/model_executor/models/mllama.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py
index 3d1b1acb39d2c..bf9f733edd845 100644
--- a/vllm/model_executor/models/mllama.py
+++ b/vllm/model_executor/models/mllama.py
@@ -518,7 +518,7 @@ def forward(

         for i, encoder_layer in enumerate(self.layers):
             if i in self.output_hidden_states:
-                encoder_states.index_copy_(0, hidden_states_idx, hidden_states)
+                encoder_states.index_copy_(0, hidden_states_idx, hidden_states.unsqueeze(0))
                 hidden_states_idx.add_(1)
             hidden_states = encoder_layer(
@@ -526,7 +526,7 @@ def forward(
             )

         if len(self.layers) - 1 in self.output_hidden_states:
-            encoder_states.index_copy_(0, hidden_states_idx, hidden_states)
+            encoder_states.index_copy_(0, hidden_states_idx, hidden_states.unsqueeze(0))

         return hidden_states, encoder_states.permute(1, 2, 3, 0)

From b99a3fe5454325263eabaf9c0f2a2862dc7ec7c5 Mon Sep 17 00:00:00 2001
From: Jan Kaniecki
Date: Tue, 17 Dec 2024 15:50:42 +0100
Subject: [PATCH 5/5] Yapf formatting

---
 vllm/model_executor/models/mllama.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py
index bf9f733edd845..12f0a147e7809 100644
--- a/vllm/model_executor/models/mllama.py
+++ b/vllm/model_executor/models/mllama.py
@@ -518,7 +518,8 @@ def forward(

         for i, encoder_layer in enumerate(self.layers):
             if i in self.output_hidden_states:
-                encoder_states.index_copy_(0, hidden_states_idx, hidden_states.unsqueeze(0))
+                encoder_states.index_copy_(0, hidden_states_idx,
+                                           hidden_states.unsqueeze(0))
                 hidden_states_idx.add_(1)
             hidden_states = encoder_layer(
@@ -526,7 +527,8 @@ def forward(
             )

         if len(self.layers) - 1 in self.output_hidden_states:
-            encoder_states.index_copy_(0, hidden_states_idx, hidden_states.unsqueeze(0))
+            encoder_states.index_copy_(0, hidden_states_idx,
+                                       hidden_states.unsqueeze(0))

         return hidden_states, encoder_states.permute(1, 2, 3, 0)
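
Note on [PATCH 1/5]: the old code collected the saved hidden states in a Python tuple and later materialized them with torch.stack, paying one extra copy of every saved state; the patch instead writes each state into a buffer pre-allocated with torch.empty and exposes the stacked layout through permute(1, 2, 3, 0), which returns a view rather than a copy. A minimal standalone sketch of the same idea, outside vLLM; all sizes, save_at, and the layer stand-in below are illustrative, not the model's real ones:

    import torch

    batch, seq, hidden = 2, 16, 32
    num_layers = 8
    save_at = {0, 2, 4, 6}          # plays the role of output_hidden_states

    def layer(x):
        return x + 1.0              # stand-in for an encoder layer

    x0 = torch.randn(batch, seq, hidden)

    # Old approach: tuple accumulation, then an extra copy inside torch.stack.
    x, states = x0, ()
    for i in range(num_layers):
        if i in save_at:
            states = states + (x, )
        x = layer(x)
    stacked_old = torch.stack(states, dim=-1)  # (batch, seq, hidden, len(save_at))

    # New approach: write into a pre-allocated buffer as we go, no stack at the end.
    x = x0
    buf = torch.empty((len(save_at), batch, seq, hidden),
                      dtype=x.dtype, device=x.device)
    idx = 0
    for i in range(num_layers):
        if i in save_at:
            buf[idx] = x            # one copy per saved state, into the buffer
            idx += 1
        x = layer(x)
    stacked_new = buf.permute(1, 2, 3, 0)  # a view with the same layout as stack(dim=-1)

    assert torch.equal(stacked_old, stacked_new)

Because permute only creates a view, the return itself costs nothing; the trade-off is that downstream consumers now receive a non-contiguous tensor in the (batch, seq, hidden, num_saved) layout instead of a freshly stacked contiguous one.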
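Note on [PATCH 3/5] through [PATCH 5/5]: per the commit message, a plain Python int index is a host-side constant whose value changes every iteration, which can trigger recompilations on graph-capturing or lazy backends (for example torch.compile or HPU graph capture; the diff itself does not name the backend). Keeping the index as a device tensor that is mutated in place with index_copy_ and add_ makes the recorded ops identical across iterations. Tensor.index_copy_(dim, index, source) requires source to match the destination's shape except along dim, where its size must equal len(index), which is presumably why [PATCH 4/5] adds hidden_states.unsqueeze(0). A small sketch with illustrative shapes:

    import torch

    num_saved, batch, seq, hidden = 4, 2, 16, 32
    buf = torch.empty(num_saved, batch, seq, hidden)
    state = torch.randn(batch, seq, hidden)

    # Int index: the value 0 is baked into any captured graph as a constant,
    # so the next iteration (index 1) looks like a different program.
    idx = 0
    buf[idx] = state
    idx += 1

    # Tensor index: the same ops are recorded every iteration; only the
    # contents of t_idx change, in place, on the device.
    t_idx = torch.tensor([0], device=buf.device)
    buf.index_copy_(0, t_idx, state.unsqueeze(0))  # source shape: (1, batch, seq, hidden)
    t_idx.add_(1)

    assert torch.equal(buf[0], state)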