@@ -30,6 +30,15 @@ import (
30
30
"github.com/jpillora/backoff"
31
31
)
32
32
33
+ type commandModel interface {
34
+ Command (ctx context.Context , content string ) (string , error )
35
+ }
36
+
37
+ type commandClient struct {
38
+ client * cohereClient.Client
39
+ config * commandProcessorConfig
40
+ }
41
+
33
42
//go:generate paramgen -output=paramgen_command.go commandProcessorConfig
34
43
35
44
type commandProcessor struct {
@@ -40,11 +49,11 @@ type commandProcessor struct {
40
49
logger log.CtxLogger
41
50
config commandProcessorConfig
42
51
backoffCfg * backoff.Backoff
43
- client * cohereClient. Client
52
+ client commandModel
44
53
}
45
54
46
55
type commandProcessorConfig struct {
47
- // Model is one of the Cohere model ( command,embed,rerank) .
56
+ // Model is one of the name of a compatible command model version .
48
57
Model string `json:"model" default:"command"`
49
58
// APIKey is the API key for Cohere api calls.
50
59
APIKey string `json:"apiKey" validate:"required"`
@@ -91,8 +100,10 @@ func (p *commandProcessor) Configure(ctx context.Context, cfg config.Config) err
91
100
}
92
101
p .responseBodyRef = & responseBodyRef
93
102
94
- // new cohere client
95
- p .client = cohereClient .NewClient ()
103
+ p .client = & commandClient {
104
+ client : cohereClient .NewClient (),
105
+ config : & p .config ,
106
+ }
96
107
97
108
p .backoffCfg = & backoff.Backoff {
98
109
Factor : p .config .BackoffRetryFactor ,
@@ -132,22 +143,9 @@ func (p *commandProcessor) Process(ctx context.Context, records []opencdc.Record
132
143
return append (out , sdk.ErrorRecord {Error : fmt .Errorf ("failed to resolve reference %v: %w" , p .config .RequestBodyRef , err )})
133
144
}
134
145
146
+ content := fmt .Sprintf (p .config .Prompt , p .getInput (requestRef .Get ()))
135
147
for {
136
- resp , err := p .client .V2 .Chat (
137
- ctx ,
138
- & cohere.V2ChatRequest {
139
- Model : p .config .Model ,
140
- Messages : cohere.ChatMessages {
141
- {
142
- Role : "user" ,
143
- User : & cohere.UserMessage {Content : & cohere.UserMessageContent {
144
- String : fmt .Sprintf (p .config .Prompt , p .getInput (requestRef .Get ())),
145
- }},
146
- },
147
- },
148
- },
149
- cohereClient .WithToken (p .config .APIKey ),
150
- )
148
+ resp , err := p .client .Command (ctx , content )
151
149
attempt := p .backoffCfg .Attempt ()
152
150
duration := p .backoffCfg .Duration ()
153
151
@@ -182,40 +180,48 @@ func (p *commandProcessor) Process(ctx context.Context, records []opencdc.Record
182
180
183
181
p .backoffCfg .Reset ()
184
182
185
- chatResponse , err := unmarshalChatResponse ([] byte ( resp . String ()) )
183
+ err = p . setField ( & record , p . responseBodyRef , resp )
186
184
if err != nil {
187
- return append (out , sdk.ErrorRecord {Error : err })
185
+ return append (out , sdk.ErrorRecord {Error : fmt . Errorf ( "failed setting response body: %w" , err ) })
188
186
}
187
+ out = append (out , sdk .SingleRecord (record ))
189
188
190
- if len (chatResponse .Message .Content ) == 1 {
191
- err = p .setField (& record , p .responseBodyRef , chatResponse .Message .Content [0 ].Text )
192
- if err != nil {
193
- return append (out , sdk.ErrorRecord {Error : fmt .Errorf ("failed setting response body: %w" , err )})
194
- }
195
- out = append (out , sdk .SingleRecord (record ))
196
- }
197
189
break
198
190
}
199
191
}
200
192
return out
201
193
}
202
194
203
- func (p * commandProcessor ) setField (r * opencdc.Record , refRes * sdk.ReferenceResolver , data any ) error {
204
- if refRes == nil {
205
- return nil
195
+ func (cc * commandClient ) Command (ctx context.Context , content string ) (string , error ) {
196
+ resp , err := cc .client .V2 .Chat (
197
+ ctx ,
198
+ & cohere.V2ChatRequest {
199
+ Model : cc .config .Model ,
200
+ Messages : cohere.ChatMessages {
201
+ {
202
+ Role : "user" ,
203
+ User : & cohere.UserMessage {Content : & cohere.UserMessageContent {
204
+ String : content ,
205
+ }},
206
+ },
207
+ },
208
+ },
209
+ cohereClient .WithToken (cc .config .APIKey ),
210
+ )
211
+ if err != nil {
212
+ return "" , err
206
213
}
207
214
208
- ref , err := refRes . Resolve ( r )
215
+ chatResponse , err := unmarshalChatResponse ([] byte ( resp . String ()) )
209
216
if err != nil {
210
- return fmt .Errorf ("error reference resolver : %w" , err )
217
+ return "" , fmt .Errorf ("error unmarshalling chat response : %w" , err )
211
218
}
212
219
213
- err = ref .Set (data )
214
- if err != nil {
215
- return fmt .Errorf ("error reference set: %w" , err )
220
+ if len (chatResponse .Message .Content ) != 1 {
221
+ return "" , fmt .Errorf ("invalid chat content" )
216
222
}
217
223
218
- return nil
224
+ return chatResponse . Message . Content [ 0 ]. Text , nil
219
225
}
220
226
221
227
type ChatResponse struct {
@@ -249,3 +255,21 @@ func (p *commandProcessor) getInput(val any) string {
249
255
return fmt .Sprintf ("%v" , v )
250
256
}
251
257
}
258
+
259
+ func (p * commandProcessor ) setField (r * opencdc.Record , refRes * sdk.ReferenceResolver , data any ) error {
260
+ if refRes == nil {
261
+ return nil
262
+ }
263
+
264
+ ref , err := refRes .Resolve (r )
265
+ if err != nil {
266
+ return fmt .Errorf ("error reference resolver: %w" , err )
267
+ }
268
+
269
+ err = ref .Set (data )
270
+ if err != nil {
271
+ return fmt .Errorf ("error reference set: %w" , err )
272
+ }
273
+
274
+ return nil
275
+ }
0 commit comments