@@ -41,7 +41,7 @@ func New(model string, opts ...ModelOption) (*LLama, error) {
41
41
C .bool (mo .F16Memory ), C .bool (mo .MLock ), C .bool (mo .Embeddings ), C .bool (mo .MMap ), C .bool (mo .LowVRAM ),
42
42
C .int (mo .NGPULayers ), C .int (mo .NBatch ), C .CString (mo .MainGPU ), C .CString (mo .TensorSplit ), C .bool (mo .NUMA ),
43
43
C .float (mo .FreqRopeBase ), C .float (mo .FreqRopeScale ),
44
- C .bool (MulMatQ ), loraAdapter , loraBase ,
44
+ C .bool (MulMatQ ), loraAdapter , loraBase , C . bool ( mo . Perplexity ),
45
45
)
46
46
47
47
if result == nil {
@@ -123,6 +123,7 @@ func (l *LLama) TokenEmbeddings(tokens []int, opts ...PredictOption) ([]float32,
123
123
C .bool (po .PromptCacheRO ),
124
124
C .CString (po .Grammar ),
125
125
C .float (po .RopeFreqBase ), C .float (po .RopeFreqScale ), C .float (po .NegativePromptScale ), C .CString (po .NegativePrompt ),
126
+ C .int (po .NDraft ),
126
127
)
127
128
ret := C .get_token_embeddings (params , l .state , myArray , C .int (len (tokens )), (* C .float )(& floats [0 ]))
128
129
if ret != 0 {
@@ -164,6 +165,7 @@ func (l *LLama) Embeddings(text string, opts ...PredictOption) ([]float32, error
164
165
C .bool (po .PromptCacheRO ),
165
166
C .CString (po .Grammar ),
166
167
C .float (po .RopeFreqBase ), C .float (po .RopeFreqScale ), C .float (po .NegativePromptScale ), C .CString (po .NegativePrompt ),
168
+ C .int (po .NDraft ),
167
169
)
168
170
169
171
ret := C .get_embeddings (params , l .state , (* C .float )(& floats [0 ]))
@@ -202,6 +204,7 @@ func (l *LLama) Eval(text string, opts ...PredictOption) error {
202
204
C .bool (po .PromptCacheRO ),
203
205
C .CString (po .Grammar ),
204
206
C .float (po .RopeFreqBase ), C .float (po .RopeFreqScale ), C .float (po .NegativePromptScale ), C .CString (po .NegativePrompt ),
207
+ C .int (po .NDraft ),
205
208
)
206
209
ret := C .eval (params , l .state , input )
207
210
if ret != 0 {
@@ -213,6 +216,64 @@ func (l *LLama) Eval(text string, opts ...PredictOption) error {
213
216
return nil
214
217
}
215
218
219
+ func (l * LLama ) SpeculativeSampling (ll * LLama , text string , opts ... PredictOption ) (string , error ) {
220
+ po := NewPredictOptions (opts ... )
221
+
222
+ if po .TokenCallback != nil {
223
+ setCallback (l .state , po .TokenCallback )
224
+ }
225
+
226
+ input := C .CString (text )
227
+ if po .Tokens == 0 {
228
+ po .Tokens = 99999999
229
+ }
230
+ out := make ([]byte , po .Tokens )
231
+
232
+ reverseCount := len (po .StopPrompts )
233
+ reversePrompt := make ([]* C.char , reverseCount )
234
+ var pass * * C.char
235
+ for i , s := range po .StopPrompts {
236
+ cs := C .CString (s )
237
+ reversePrompt [i ] = cs
238
+ pass = & reversePrompt [0 ]
239
+ }
240
+
241
+ params := C .llama_allocate_params (input , C .int (po .Seed ), C .int (po .Threads ), C .int (po .Tokens ), C .int (po .TopK ),
242
+ C .float (po .TopP ), C .float (po .Temperature ), C .float (po .Penalty ), C .int (po .Repeat ),
243
+ C .bool (po .IgnoreEOS ), C .bool (po .F16KV ),
244
+ C .int (po .Batch ), C .int (po .NKeep ), pass , C .int (reverseCount ),
245
+ C .float (po .TailFreeSamplingZ ), C .float (po .TypicalP ), C .float (po .FrequencyPenalty ), C .float (po .PresencePenalty ),
246
+ C .int (po .Mirostat ), C .float (po .MirostatETA ), C .float (po .MirostatTAU ), C .bool (po .PenalizeNL ), C .CString (po .LogitBias ),
247
+ C .CString (po .PathPromptCache ), C .bool (po .PromptCacheAll ), C .bool (po .MLock ), C .bool (po .MMap ),
248
+ C .CString (po .MainGPU ), C .CString (po .TensorSplit ),
249
+ C .bool (po .PromptCacheRO ),
250
+ C .CString (po .Grammar ),
251
+ C .float (po .RopeFreqBase ), C .float (po .RopeFreqScale ), C .float (po .NegativePromptScale ), C .CString (po .NegativePrompt ),
252
+ C .int (po .NDraft ),
253
+ )
254
+ ret := C .speculative_sampling (params , l .state , ll .state , (* C .char )(unsafe .Pointer (& out [0 ])), C .bool (po .DebugMode ))
255
+ if ret != 0 {
256
+ return "" , fmt .Errorf ("inference failed" )
257
+ }
258
+ res := C .GoString ((* C .char )(unsafe .Pointer (& out [0 ])))
259
+
260
+ res = strings .TrimPrefix (res , " " )
261
+ res = strings .TrimPrefix (res , text )
262
+ res = strings .TrimPrefix (res , "\n " )
263
+
264
+ for _ , s := range po .StopPrompts {
265
+ res = strings .TrimRight (res , s )
266
+ }
267
+
268
+ C .llama_free_params (params )
269
+
270
+ if po .TokenCallback != nil {
271
+ setCallback (l .state , nil )
272
+ }
273
+
274
+ return res , nil
275
+ }
276
+
216
277
func (l * LLama ) Predict (text string , opts ... PredictOption ) (string , error ) {
217
278
po := NewPredictOptions (opts ... )
218
279
@@ -246,6 +307,7 @@ func (l *LLama) Predict(text string, opts ...PredictOption) (string, error) {
246
307
C .bool (po .PromptCacheRO ),
247
308
C .CString (po .Grammar ),
248
309
C .float (po .RopeFreqBase ), C .float (po .RopeFreqScale ), C .float (po .NegativePromptScale ), C .CString (po .NegativePrompt ),
310
+ C .int (po .NDraft ),
249
311
)
250
312
ret := C .llama_predict (params , l .state , (* C .char )(unsafe .Pointer (& out [0 ])), C .bool (po .DebugMode ))
251
313
if ret != 0 {
@@ -294,6 +356,7 @@ func (l *LLama) TokenizeString(text string, opts ...PredictOption) (int32, []int
294
356
C .bool (po .PromptCacheRO ),
295
357
C .CString (po .Grammar ),
296
358
C .float (po .RopeFreqBase ), C .float (po .RopeFreqScale ), C .float (po .NegativePromptScale ), C .CString (po .NegativePrompt ),
359
+ C .int (po .NDraft ),
297
360
)
298
361
299
362
tokRet := C .llama_tokenize_string (params , l .state , (* C .int )(unsafe .Pointer (& out [0 ]))) //, C.int(po.Tokens), true)
0 commit comments