Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Builtin Cohere Command Processor #2165

Merged
merged 23 commits into from
Mar 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ require (
github.com/Masterminds/sprig/v3 v3.3.0
github.com/alexeyco/simpletable v1.0.0
github.com/bufbuild/buf v1.50.0
github.com/cohere-ai/cohere-go/v2 v2.12.4
github.com/conduitio/conduit-commons v0.5.1
github.com/conduitio/conduit-connector-file v0.10.0
github.com/conduitio/conduit-connector-generator v0.10.1
Expand All @@ -33,6 +34,7 @@ require (
github.com/hashicorp/go-hclog v1.6.3
github.com/hashicorp/go-plugin v1.6.3
github.com/jpillora/backoff v1.0.0
github.com/klauspost/compress v1.17.11
github.com/matryer/is v1.4.1
github.com/neilotoole/slogt v1.1.0
github.com/piotrkowalczuk/promgrpc/v4 v4.1.4
Expand Down Expand Up @@ -234,7 +236,6 @@ require (
github.com/karamaru-alpha/copyloopvar v1.2.1 // indirect
github.com/kisielk/errcheck v1.9.0 // indirect
github.com/kkHAIKE/contextcheck v1.1.6 // indirect
github.com/klauspost/compress v1.17.11 // indirect
github.com/klauspost/pgzip v1.2.6 // indirect
github.com/kulti/thelper v0.6.3 // indirect
github.com/kunwardeep/paralleltest v1.0.10 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,8 @@ github.com/ckaznocha/intrange v0.3.0 h1:VqnxtK32pxgkhJgYQEeOArVidIPg+ahLP7WBOXZd
github.com/ckaznocha/intrange v0.3.0/go.mod h1:+I/o2d2A1FBHgGELbGxzIcyd3/9l9DuwjM8FsbSS3Lo=
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc=
github.com/cohere-ai/cohere-go/v2 v2.12.4 h1:hWiOc7LkwJ21S3hh3Ogh9Fe5s9ZDsVu11qoaMGfYZRQ=
github.com/cohere-ai/cohere-go/v2 v2.12.4/go.mod h1:MuiJkCxlR18BDV2qQPbz2Yb/OCVphT1y6nD2zYaKeR0=
github.com/colinmarc/hdfs/v2 v2.1.1/go.mod h1:M3x+k8UKKmxtFu++uAZ0OtDU8jR3jnaZIAc6yK4Ue0c=
github.com/conduitio/conduit-commons v0.5.1 h1:o6Lh3LhBuorDfVVlWyj3tYoYbBUUPTMiJE2zefiBuSo=
github.com/conduitio/conduit-commons v0.5.1/go.mod h1:FmJI9zOxX+HDiqculivPmekaUASZdubOJnI8sOOLAXk=
Expand Down
278 changes: 278 additions & 0 deletions pkg/plugin/processor/builtin/impl/cohere/command.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,278 @@
// Copyright © 2025 Meroxa, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package cohere

import (
"context"
"fmt"
"time"

cohere "github.com/cohere-ai/cohere-go/v2"
cohereClient "github.com/cohere-ai/cohere-go/v2/client"
"github.com/conduitio/conduit-commons/config"
"github.com/conduitio/conduit-commons/opencdc"
sdk "github.com/conduitio/conduit-processor-sdk"
"github.com/conduitio/conduit/pkg/foundation/cerrors"
"github.com/conduitio/conduit/pkg/foundation/log"
"github.com/goccy/go-json"
"github.com/jpillora/backoff"
)

type commandModel interface {
Command(ctx context.Context, content string) (string, error)
}

type commandClient struct {
client *cohereClient.Client
config *commandProcessorConfig
}

//go:generate paramgen -output=paramgen_command.go commandProcessorConfig

type commandProcessor struct {
sdk.UnimplementedProcessor

requestBodyRef *sdk.ReferenceResolver
responseBodyRef *sdk.ReferenceResolver
logger log.CtxLogger
config commandProcessorConfig
backoffCfg *backoff.Backoff
client commandModel
}

type commandProcessorConfig struct {
// Model is one of the name of a compatible command model version.
Model string `json:"model" default:"command"`
// APIKey is the API key for Cohere api calls.
APIKey string `json:"apiKey" validate:"required"`
// Maximum number of retries for an individual record when backing off following an error.
BackoffRetryCount float64 `json:"backoffRetry.count" default:"0" validate:"gt=-1"`
// The multiplying factor for each increment step.
BackoffRetryFactor float64 `json:"backoffRetry.factor" default:"2" validate:"gt=0"`
// The minimum waiting time before retrying.
BackoffRetryMin time.Duration `json:"backoffRetry.min" default:"100ms"`
// The maximum waiting time before retrying.
BackoffRetryMax time.Duration `json:"backoffRetry.max" default:"5s"`
// Prompt is the preset prompt.
Prompt string `json:"prompt" validate:"required"`
// RequestBodyRef specifies the api request field.
RequestBodyRef string `json:"request.body" default:".Payload.After"`
// ResponseBodyRef specifies in which field should the response body be saved.
ResponseBodyRef string `json:"response.body" default:".Payload.After"`
}

func NewCommandProcessor(l log.CtxLogger) sdk.Processor {
return &commandProcessor{logger: l.WithComponent("cohere.command")}
}

func (p *commandProcessor) Configure(ctx context.Context, cfg config.Config) error {
// Configure is the first function to be called in a processor. It provides the processor
// with the configuration that needs to be validated and stored to be used in other methods.
// This method should not open connections or any other resources. It should solely focus
// on parsing and validating the configuration itself.

err := sdk.ParseConfig(ctx, cfg, &p.config, commandProcessorConfig{}.Parameters())
if err != nil {
return fmt.Errorf("failed to parse configuration: %w", err)
}

requestBodyRef, err := sdk.NewReferenceResolver(p.config.RequestBodyRef)
if err != nil {
return fmt.Errorf("failed parsing request.body %v: %w", p.config.RequestBodyRef, err)
}
p.requestBodyRef = &requestBodyRef

responseBodyRef, err := sdk.NewReferenceResolver(p.config.ResponseBodyRef)
if err != nil {
return fmt.Errorf("failed parsing response.body %v: %w", p.config.ResponseBodyRef, err)
}
p.responseBodyRef = &responseBodyRef

// Initialize the client only if it hasn't been injected
if p.client == nil {
p.client = &commandClient{
client: cohereClient.NewClient(),
config: &p.config,
}
}

p.backoffCfg = &backoff.Backoff{
Factor: p.config.BackoffRetryFactor,
Min: p.config.BackoffRetryMin,
Max: p.config.BackoffRetryMax,
}

return nil
}

func (p *commandProcessor) Specification() (sdk.Specification, error) {
// Specification contains the metadata for the processor, which can be used to define how
// to reference the processor, describe what the processor does and the configuration
// parameters it expects.

return sdk.Specification{
Name: "cohere.command",
Summary: "Conduit processor for Cohere's command model.",
Description: "Conduit processor for Cohere's command model.",
Version: "v0.1.0",
Author: "Meroxa, Inc.",
Parameters: p.config.Parameters(),
}, nil
}

func (p *commandProcessor) Process(ctx context.Context, records []opencdc.Record) []sdk.ProcessedRecord {
out := make([]sdk.ProcessedRecord, 0, len(records))
for _, record := range records {
var key []byte
if record.Key != nil {
key = record.Key.Bytes()
}
p.logger.Trace(ctx).Bytes("record_key", key).Msg("processing record")

requestRef, err := p.requestBodyRef.Resolve(&record)
if err != nil {
return append(out, sdk.ErrorRecord{Error: fmt.Errorf("failed to resolve reference %v: %w", p.config.RequestBodyRef, err)})
}

content := fmt.Sprintf(p.config.Prompt, p.getInput(requestRef.Get()))
for {
resp, err := p.client.Command(ctx, content)
attempt := p.backoffCfg.Attempt()
duration := p.backoffCfg.Duration()

if err != nil {
switch {
case cerrors.As(err, &cohere.GatewayTimeoutError{}),
cerrors.As(err, &cohere.InternalServerError{}),
cerrors.As(err, &cohere.ServiceUnavailableError{}):

if attempt < p.config.BackoffRetryCount {
sdk.Logger(ctx).Debug().Err(err).Float64("attempt", attempt).
Float64("backoffRetry.count", p.config.BackoffRetryCount).
Int64("backoffRetry.duration", duration.Milliseconds()).
Msg("retrying Cohere HTTP request")

select {
case <-ctx.Done():
return append(out, sdk.ErrorRecord{Error: ctx.Err()})
case <-time.After(duration):
continue
}
} else {
return append(out, sdk.ErrorRecord{Error: err})
}

default:
// BadRequestError, ClientClosedRequestError, ForbiddenError, InvalidTokenError,
// NotFoundError, NotImplementedError, TooManyRequestsError, UnauthorizedError, UnprocessableEntityError
return append(out, sdk.ErrorRecord{Error: err})
}
}

p.backoffCfg.Reset()

err = p.setField(&record, p.responseBodyRef, resp)
if err != nil {
return append(out, sdk.ErrorRecord{Error: fmt.Errorf("failed setting response body: %w", err)})
}
out = append(out, sdk.SingleRecord(record))

break
}
}
return out
}

func (cc *commandClient) Command(ctx context.Context, content string) (string, error) {
resp, err := cc.client.V2.Chat(
ctx,
&cohere.V2ChatRequest{
Model: cc.config.Model,
Messages: cohere.ChatMessages{
{
Role: "user",
User: &cohere.UserMessage{Content: &cohere.UserMessageContent{
String: content,
}},
},
},
},
cohereClient.WithToken(cc.config.APIKey),
)
if err != nil {
return "", err
}

chatResponse, err := unmarshalChatResponse([]byte(resp.String()))
if err != nil {
return "", fmt.Errorf("error unmarshalling chat response: %w", err)
}

if len(chatResponse.Message.Content) != 1 {
return "", fmt.Errorf("invalid chat content")
}

return chatResponse.Message.Content[0].Text, nil
}

type ChatResponse struct {
ID string `json:"id"`
FinishReason string `json:"finish_reason"`
Message struct {
Role string `json:"role"`
Content []struct {
Type string `json:"type"`
Text string `json:"text"`
} `json:"content"`
} `json:"message"`
}

func unmarshalChatResponse(res []byte) (*ChatResponse, error) {
response := &ChatResponse{}
err := json.Unmarshal(res, response)
if err != nil {
return nil, err
}
return response, nil
}

func (p *commandProcessor) getInput(val any) string {
switch v := val.(type) {
case opencdc.RawData:
return string(v)
case opencdc.StructuredData:
return string(v.Bytes())
default:
return fmt.Sprintf("%v", v)
}
}

func (p *commandProcessor) setField(r *opencdc.Record, refRes *sdk.ReferenceResolver, data any) error {
if refRes == nil {
return nil
}

ref, err := refRes.Resolve(r)
if err != nil {
return fmt.Errorf("error reference resolver: %w", err)
}

err = ref.Set(data)
if err != nil {
return fmt.Errorf("error reference set: %w", err)
}

return nil
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
// Copyright © 2025 Meroxa, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package cohere

import (
"context"
"fmt"

"github.com/conduitio/conduit-commons/config"
"github.com/conduitio/conduit-commons/opencdc"
sdk "github.com/conduitio/conduit-processor-sdk"
"github.com/conduitio/conduit/pkg/plugin/processor/builtin/internal/exampleutil"
)

//nolint:govet // a more descriptive example description
func ExampleCommandProcessor() {
p := &commandProcessor{}
p.client = &mockClient{}

exampleutil.RunExample(p, exampleutil.Example{
Summary: `Generate responses using Cohere's command model`,
Description: `
This example demonstrates how to use the Cohere command processor to generate responses for a record's ` + "`.Payload.After`" + ` field.
The processor sends the input text to the Cohere API and replaces it with the model's response.`,
Config: config.Config{
commandProcessorConfigApiKey: "apikey",
commandProcessorConfigPrompt: "hello",
},
Have: opencdc.Record{
Operation: opencdc.OperationUpdate,
Position: opencdc.Position("pos-1"),
Payload: opencdc.Change{
After: opencdc.RawData("who are you?"),
},
},
Want: sdk.SingleRecord{
Operation: opencdc.OperationUpdate,
Position: opencdc.Position("pos-1"),
Payload: opencdc.Change{
After: opencdc.RawData("cohere command response content"),
},
},
})

// Output:
// processor transformed record:
// --- before
// +++ after
// @@ -1,10 +1,10 @@
// {
// "position": "cG9zLTE=",
// "operation": "update",
// "metadata": null,
// "key": null,
// "payload": {
// "before": null,
// - "after": "who are you?"
// + "after": "cohere command response content"
// }
// }
}

type mockClient struct{}

func (m mockClient) Command(ctx context.Context, content string) (string, error) {
if content == "" {
return "", fmt.Errorf("mocked api error")
}
return "cohere command response content", nil
}
Loading
Loading