-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathprovider.go
98 lines (74 loc) · 2.26 KB
/
provider.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
package openai
import (
"context"
"fmt"
"github.com/agent-api/core"
"github.com/agent-api/openai/client"
"github.com/go-logr/logr"
)
// Provider implements the LLMProvider interface for OpenAI
type Provider struct {
	// host and port identify the OpenAI-compatible endpoint.
	// NOTE(review): neither is set or read by the code visible here — confirm they are wired elsewhere.
	host string
	port int
	// model is the currently selected model, assigned via UseModel and read by Generate/GenerateStream.
	model *core.Model
	// client is the internal OpenAI HTTP client
	client *client.OpenAIClient
	// logger receives structured log output for all provider operations.
	logger logr.Logger
}
// ProviderOpts configures a new Provider (see NewProvider).
type ProviderOpts struct {
	// BaseURL is the endpoint base URL.
	// NOTE(review): not consumed by NewProvider as written — confirm intended wiring.
	BaseURL string
	// Port is the endpoint port. NOTE(review): also unused by the visible code.
	Port int
	// APIKey authenticates against the OpenAI API; currently disabled in NewProvider
	// (see the TODO about env-variable support there).
	APIKey string
	// Logger is used for all provider logging.
	Logger *logr.Logger
}
// NewProvider creates a new OpenAI provider from the given options.
//
// If opts.Logger is nil, a no-op logger (logr.Discard) is used so that
// construction never panics. BaseURL, Port and APIKey are currently not
// forwarded to the underlying client (see the TODO below).
func NewProvider(opts *ProviderOpts) *Provider {
	logger := opts.Logger
	if logger == nil {
		// Fall back to a discard logger rather than dereferencing nil.
		l := logr.Discard()
		logger = &l
	}
	logger.Info("Creating new OpenAI provider")

	c := client.NewClient(
		logger,
		// TODO - need to enable local env variable, not just through opt
		//client.WithAPIKey(opts.APIKey),
	)

	return &Provider{
		client: c,
		logger: *logger,
	}
}
// GetCapabilities reports the capabilities of the configured provider.
//
// It is not implemented yet; it returns a non-nil error rather than the
// ambiguous (nil, nil), so callers cannot mistake the placeholder for a
// valid (but empty) capability set.
func (p *Provider) GetCapabilities(ctx context.Context) (*core.Capabilities, error) {
	p.logger.Info("Fetching capabilities")
	return nil, fmt.Errorf("openai: GetCapabilities not implemented")
}
// UseModel selects the model used by subsequent Generate/GenerateStream calls.
//
// A nil model is rejected with an error instead of being stored, which would
// otherwise cause a nil-pointer dereference later (e.g. p.model.ID in Generate).
func (p *Provider) UseModel(ctx context.Context, model *core.Model) error {
	if model == nil {
		return fmt.Errorf("openai: UseModel called with nil model")
	}
	p.logger.Info("Setting model", "modelID", model.ID)
	p.model = model
	return nil
}
// Generate implements the LLMProvider interface for basic responses.
//
// It forwards opts.Messages and opts.Tools to the client's Chat call for the
// currently selected model and converts the response into a *core.Message
// with the assistant role. UseModel must have been called first.
func (p *Provider) Generate(ctx context.Context, opts *core.GenerateOptions) (*core.Message, error) {
	// Guard against a nil model: the original would panic on p.model.ID.
	if p.model == nil {
		return nil, fmt.Errorf("openai: no model selected, call UseModel before Generate")
	}
	p.logger.Info("Generate request received", "modelID", p.model.ID)

	resp, err := p.client.Chat(ctx, &client.ChatRequest{
		Model:    p.model.ID,
		Messages: opts.Messages,
		Tools:    opts.Tools,
	})
	if err != nil {
		// logr's Error takes (err, msg, key/value pairs...); the original passed
		// err again as a dangling odd argument, producing malformed log output.
		p.logger.Error(err, "error calling client chat method")
		return nil, fmt.Errorf("error calling client chat method: %w", err)
	}

	return &core.Message{
		Role:      core.AssistantMessageRole,
		Content:   resp.Message.Content,
		ToolCalls: resp.Message.ToolCalls,
	}, nil
}
// GenerateStream streams the response token by token.
//
// It returns the client's message, delta and error channels unchanged. If no
// model has been selected, it cannot report the failure via a return error
// (the signature has none), so it delivers one error on a buffered channel
// and closes all three channels so receivers terminate cleanly.
func (p *Provider) GenerateStream(ctx context.Context, opts *core.GenerateOptions) (<-chan *core.Message, <-chan string, <-chan error) {
	if p.model == nil {
		msgCh := make(chan *core.Message)
		deltaCh := make(chan string)
		errCh := make(chan error, 1)
		errCh <- fmt.Errorf("openai: no model selected, call UseModel before GenerateStream")
		close(msgCh)
		close(deltaCh)
		close(errCh)
		return msgCh, deltaCh, errCh
	}
	p.logger.Info("Starting stream generation", "modelID", p.model.ID)

	return p.client.ChatStream(ctx, &client.ChatRequest{
		Model:    p.model.ID,
		Messages: opts.Messages,
		Tools:    opts.Tools,
	})
}