@@ -193,19 +193,33 @@ def setup_caches(self, max_batch_size, max_seq_length, training: bool=False, kv_
 
 
     def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
+        """Forward pass of the model.
+
+        Args:
+            idx (`torch.LongTensor` of shape `(batch_size, seq_length)`):
+                Indices of input sequence tokens in the vocabulary.
+            input_pos (`torch.LongTensor` of shape `(batch_size, seq_length)`, *optional*):
+                Indices of the positions of each input sequence token in the position embeddings.
+                This argument is optional in training mode but required in
+                inference mode (when `model.setup_caches(training=False)` is used).
+
+        Returns:
+            Tensor: The output logits tensor.
+        """
         assert self.freqs_cis is not None, "Caches must be initialized first"
 
         if input_pos is None:
             mask = None
             freqs_cis = self.freqs_cis[:idx.shape[1]]
-        elif not self.linear_causal_mask:
-            mask = self.causal_mask[None, None, input_pos]
-        elif len(input_pos)>1 and self.linear_causal_mask: # prefill for linear causal mask
-            mask = torch.tril(torch.ones(len(input_pos), self.max_seq_length, dtype=torch.bool, device=input_pos.device)).unsqueeze(0).unsqueeze(0)
-        else: # decode_one_token for linear causal mask
-            self.causal_mask[0,0,0,input_pos] = 1
-            mask = self.causal_mask
-        freqs_cis = self.freqs_cis[input_pos]
+        else:
+            if not self.linear_causal_mask:
+                mask = self.causal_mask[None, None, input_pos]
+            elif len(input_pos)>1 and self.linear_causal_mask: # prefill for linear causal mask
+                mask = torch.tril(torch.ones(len(input_pos), self.max_seq_length, dtype=torch.bool, device=input_pos.device)).unsqueeze(0).unsqueeze(0)
+            else: # decode_one_token for linear causal mask
+                self.causal_mask[0,0,0,input_pos] = 1
+                mask = self.causal_mask
+            freqs_cis = self.freqs_cis[input_pos]
 
         x = self.tok_embeddings(idx)
 
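For context, a minimal sketch of the two modes the new docstring describes. The `setup_caches(..., training=...)` and `forward(idx, input_pos)` calls come from this diff; the `Transformer`, `ModelArgs.from_name`, and `model.config.vocab_size` names are assumed from the usual gpt-fast style model file and are illustrative only.

```python
import torch

# Assumed constructors (Transformer, ModelArgs.from_name); adjust to the
# actual class/config names defined in this model file.
model = Transformer(ModelArgs.from_name("Llama-2-7b-chat-hf"))

idx = torch.randint(0, model.config.vocab_size, (1, 16))  # (batch_size, seq_length)

# Training mode: input_pos may be omitted, so forward() takes the
# `input_pos is None` branch (mask=None, freqs_cis sliced to seq_length).
model.setup_caches(max_batch_size=1, max_seq_length=2048, training=True)
logits = model(idx)

# Inference mode: caches are set up and input_pos is required, so forward()
# takes the new `else:` branch and indexes causal_mask / freqs_cis by position.
model.setup_caches(max_batch_size=1, max_seq_length=2048, training=False)
input_pos = torch.arange(16)  # prefill positions
logits = model(idx, input_pos=input_pos)
```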