forked from go-skynet/go-llama.cpp
-
Notifications
You must be signed in to change notification settings - Fork 0
/
options.go
148 lines (125 loc) · 3.03 KB
/
options.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
package llama
import "runtime"
// ModelOptions holds model-level settings. NewModelOptions builds one by
// copying DefaultModelOptions and applying ModelOption functions to the copy.
type ModelOptions struct {
	// ContextSize, Parts and Seed are written by SetContext, SetParts and
	// SetModelSeed respectively.
	ContextSize, Parts, Seed int

	// F16Memory and MLock are flags turned on by EnableF16Memory and
	// EnableMLock respectively.
	F16Memory, MLock bool
}
// PredictOptions holds per-prediction generation settings. NewPredictOptions
// builds one by copying DefaultOptions and applying PredictOption functions.
type PredictOptions struct {
	// Integer settings, written by SetSeed, SetThreads, SetTokens, SetTopK
	// and SetRepeat respectively.
	Seed    int
	Threads int
	Tokens  int
	TopK    int
	Repeat  int
	// Floating-point sampling settings, written by SetTopP, SetTemperature
	// and SetPenalty respectively.
	TopP        float64
	Temperature float64
	Penalty     float64
	// Flags turned on by EnableF16KV and IgnoreEOS respectively.
	F16KV, IgnoreEOS bool
}
type (
	// PredictOption mutates a PredictOptions value; NewPredictOptions applies
	// a list of them in order.
	PredictOption func(*PredictOptions)

	// ModelOption mutates a ModelOptions value; NewModelOptions applies a
	// list of them in order.
	ModelOption func(*ModelOptions)
)
// DefaultModelOptions is the baseline that NewModelOptions copies before any
// ModelOption is applied: a context size of 512, with Seed, Parts, F16Memory
// and MLock all left at their zero values (0 or false).
var DefaultModelOptions = ModelOptions{ContextSize: 512}
// DefaultOptions is the baseline that NewPredictOptions copies before any
// PredictOption is applied. Threads is taken from runtime.NumCPU when the
// package is initialised; F16KV and IgnoreEOS are left false.
//
// NOTE(review): Seed is -1 here, which presumably asks the backend to pick a
// random seed — confirm against the binding that consumes PredictOptions.
var DefaultOptions = PredictOptions{
	Seed:    -1,
	Threads: runtime.NumCPU(),
	Tokens:  128,
	TopK:    10000,
	Repeat:  64,

	TopP:        0.90,
	Temperature: 0.96,
	Penalty:     1,
}
// SetContext returns a ModelOption that overwrites ModelOptions.ContextSize
// (the context size) with c.
func SetContext(c int) ModelOption {
	var opt ModelOption = func(o *ModelOptions) {
		o.ContextSize = c
	}
	return opt
}
// SetModelSeed returns a ModelOption that stores c in ModelOptions.Seed.
func SetModelSeed(c int) ModelOption {
	apply := func(o *ModelOptions) { o.Seed = c }
	return apply
}
// SetParts returns a ModelOption that stores c in ModelOptions.Parts.
//
// NOTE(review): Parts is presumably the number of model file parts to load —
// confirm against the loader that consumes ModelOptions.
func SetParts(c int) ModelOption {
	return ModelOption(func(o *ModelOptions) { o.Parts = c })
}
// EnableF16Memory is a ModelOption that sets ModelOptions.F16Memory to true.
var EnableF16Memory = ModelOption(func(o *ModelOptions) { o.F16Memory = true })
// EnableF16KV is a PredictOption that sets PredictOptions.F16KV to true.
var EnableF16KV = PredictOption(func(o *PredictOptions) { o.F16KV = true })
// EnableMLock is a ModelOption that sets ModelOptions.MLock to true.
var EnableMLock = ModelOption(func(o *ModelOptions) { o.MLock = true })
// NewModelOptions returns a ModelOptions value that starts as a copy of
// DefaultModelOptions and then has each of opts applied to it in order, so a
// later option overrides an earlier one that writes the same field.
//
// Nil entries in opts are skipped rather than called (calling a nil func
// would panic), which lets callers pass conditionally-chosen options without
// guarding each one.
func NewModelOptions(opts ...ModelOption) ModelOptions {
	// Assigning the struct copies it, so options mutate only the local copy
	// and never the package-level DefaultModelOptions.
	p := DefaultModelOptions
	for _, opt := range opts {
		if opt == nil {
			continue
		}
		opt(&p)
	}
	return p
}
// IgnoreEOS is a PredictOption that sets PredictOptions.IgnoreEOS to true.
var IgnoreEOS = PredictOption(func(o *PredictOptions) { o.IgnoreEOS = true })
// SetSeed returns a PredictOption that stores seed in PredictOptions.Seed,
// the random seed used when sampling generated text.
func SetSeed(seed int) PredictOption {
	set := func(o *PredictOptions) { o.Seed = seed }
	return set
}
// SetThreads returns a PredictOption that stores threads in
// PredictOptions.Threads, the number of threads used for text generation.
func SetThreads(threads int) PredictOption {
	return PredictOption(func(o *PredictOptions) { o.Threads = threads })
}
// SetTokens returns a PredictOption that stores tokens in
// PredictOptions.Tokens, the number of tokens to generate.
func SetTokens(tokens int) PredictOption {
	var opt PredictOption = func(o *PredictOptions) {
		o.Tokens = tokens
	}
	return opt
}
// SetTopK returns a PredictOption that stores topk in PredictOptions.TopK,
// the value used for top-K sampling.
func SetTopK(topk int) PredictOption {
	apply := func(o *PredictOptions) { o.TopK = topk }
	return apply
}
// SetTopP returns a PredictOption that stores topp in PredictOptions.TopP,
// the value used for nucleus (top-p) sampling.
func SetTopP(topp float64) PredictOption {
	return PredictOption(func(o *PredictOptions) { o.TopP = topp })
}
// SetTemperature returns a PredictOption that stores temp in
// PredictOptions.Temperature, the sampling temperature for text generation.
func SetTemperature(temp float64) PredictOption {
	var opt PredictOption = func(o *PredictOptions) {
		o.Temperature = temp
	}
	return opt
}
// SetPenalty returns a PredictOption that stores penalty in
// PredictOptions.Penalty, the repetition penalty for text generation.
func SetPenalty(penalty float64) PredictOption {
	set := func(o *PredictOptions) { o.Penalty = penalty }
	return set
}
// SetRepeat returns a PredictOption that stores repeat in
// PredictOptions.Repeat.
//
// NOTE(review): despite the name, Repeat is presumably the size of the
// recent-token window the repetition penalty (see SetPenalty) considers, not
// a count of generation runs — confirm against the binding that consumes
// PredictOptions.
func SetRepeat(repeat int) PredictOption {
	return PredictOption(func(o *PredictOptions) { o.Repeat = repeat })
}
// NewPredictOptions returns a PredictOptions value that starts as a copy of
// DefaultOptions and then has each of opts applied to it in order, so a later
// option overrides an earlier one that writes the same field.
//
// Nil entries in opts are skipped rather than called (calling a nil func
// would panic), which lets callers pass conditionally-chosen options without
// guarding each one.
func NewPredictOptions(opts ...PredictOption) PredictOptions {
	// Assigning the struct copies it, so options mutate only the local copy
	// and never the package-level DefaultOptions.
	p := DefaultOptions
	for _, opt := range opts {
		if opt == nil {
			continue
		}
		opt(&p)
	}
	return p
}