Skip to content

Commit c53d16a

Browse files
Update the chat/completions endpoint to accept a stream boolean and handle a streaming response
1 parent cbd6022 commit c53d16a

File tree

6 files changed

+204
-26
lines changed

6 files changed

+204
-26
lines changed

internal/pkg/assistants/assistant_chat_completions.go

+23-9
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ const (
1313
URL_ASSISTANT_CHAT_COMPLETIONS = "/assistant/chat/%s/chat/completions"
1414
)
1515

16-
func GetAssistantChatCompletions(asstName string, msg string) (*models.ChatCompletionModel, error) {
16+
func GetAssistantChatCompletions(asstName string, msg string, stream bool) (*models.ChatCompletionModel, error) {
1717
outgoingMsg := models.ChatCompletionMessage{
1818
Role: "user",
1919
Content: msg,
@@ -30,21 +30,35 @@ func GetAssistantChatCompletions(asstName string, msg string) (*models.ChatCompl
3030

3131
body := models.ChatCompletionRequest{
3232
Messages: chat.Messages,
33+
Stream: stream,
3334
}
3435

3536
assistantDataUrl, err := GetAssistantDataBaseUrl()
3637
if err != nil {
3738
return nil, err
3839
}
3940

40-
resp, err := network.PostAndDecode[models.ChatCompletionRequest, models.ChatCompletionModel](
41-
assistantDataUrl,
42-
fmt.Sprintf(URL_ASSISTANT_CHAT_COMPLETIONS, asstName),
43-
true,
44-
body,
45-
)
46-
if err != nil {
47-
return nil, err
41+
var resp *models.ChatCompletionModel
42+
if !stream {
43+
resp, err = network.PostAndDecode[models.ChatCompletionRequest, models.ChatCompletionModel](
44+
assistantDataUrl,
45+
fmt.Sprintf(URL_ASSISTANT_CHAT_COMPLETIONS, asstName),
46+
true,
47+
body,
48+
)
49+
if err != nil {
50+
return nil, err
51+
}
52+
} else {
53+
resp, err = network.PostAndStreamChatResponse[models.ChatCompletionRequest](
54+
assistantDataUrl,
55+
fmt.Sprintf(URL_ASSISTANT_CHAT_COMPLETIONS, asstName),
56+
true,
57+
body,
58+
)
59+
if err != nil {
60+
return nil, err
61+
}
4862
}
4963

5064
// If the request was successful, update the chat history

internal/pkg/assistants/assistant_file_upload.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ func UploadAssistantFile(name string, filePath string) (*AssistantFileModel, err
1616
return nil, err
1717
}
1818

19-
resp, err := network.PostAndDecodeMultipartFormData[AssistantFileModel](
19+
resp, err := network.PostMultipartFormDataAndDecode[AssistantFileModel](
2020
assistantDataUrl,
2121
fmt.Sprintf(URL_ASSISTANT_FILE_UPLOAD, name),
2222
true,

internal/pkg/cli/command/assistant/chat.go

+43-13
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import (
2323
// AssistantChatCmdOptions holds the values bound to the assistant chat
// command's flags.
type AssistantChatCmdOptions struct {
	// name is bound to --name / -n: the assistant to chat with.
	name string
	// message is bound to --message / -m: the message to send. When it is
	// empty the command starts interactive chat instead.
	message string
	// stream is bound to --stream / -s: stream chat message responses.
	stream bool
	// json is bound to --json: output as JSON.
	json bool
}
2829

@@ -43,19 +44,20 @@ func NewAssistantChatCmd() *cobra.Command {
4344
return
4445
}
4546

46-
// If no message is provided drop them into chat
47+
// If no message is provided drop them into interactive chat
4748
if options.message == "" {
4849
startChat(options.name)
4950
} else {
5051
// If message is provided, send it to the assistant
51-
sendMessage(options.name, options.message)
52+
sendMessage(options.name, options.message, options.stream)
5253
}
5354
},
5455
}
5556

5657
cmd.Flags().StringVarP(&options.name, "name", "n", "", "name of the assistant to chat with")
57-
cmd.Flags().BoolVar(&options.json, "json", false, "output as JSON")
5858
cmd.Flags().StringVarP(&options.message, "message", "m", "", "your message to the assistant")
59+
cmd.Flags().BoolVarP(&options.stream, "stream", "s", false, "stream chat message responses")
60+
cmd.Flags().BoolVar(&options.json, "json", false, "output as JSON")
5961
cmd.MarkFlagRequired("content")
6062

6163
cmd.AddCommand(NewAssistantChatClearCmd())
@@ -94,7 +96,8 @@ func startChat(asstName string) {
9496
checkForChatCommands(text)
9597

9698
if text != "" {
97-
_, err := sendMessage(asstName, text)
99+
// Stream here since we're in interactive chat mode
100+
_, err := sendMessage(asstName, text, true)
98101
if err != nil {
99102
pcio.Printf("Error sending message: %s\n", err)
100103
continue
@@ -103,27 +106,54 @@ func startChat(asstName string) {
103106
}
104107
}
105108

106-
func sendMessage(asstName string, message string) (*models.ChatCompletionModel, error) {
109+
func sendMessage(asstName string, message string, stream bool) (*models.ChatCompletionModel, error) {
107110
response := &models.ChatCompletionModel{}
108111

109-
err := style.Spinner("", func() error {
110-
chatResponse, err := assistants.GetAssistantChatCompletions(asstName, message)
112+
var chatGetter func() error
113+
if stream {
114+
chatGetter = streamChatResponse(response, asstName, message, stream)
115+
} else {
116+
chatGetter = getChatResponse(response, asstName, message, stream)
117+
}
118+
119+
err := chatGetter()
120+
if err != nil {
121+
return nil, err
122+
}
123+
124+
return response, nil
125+
}
126+
127+
func getChatResponse(resp *models.ChatCompletionModel, asstName string, message string, stream bool) func() error {
128+
return func() error {
129+
chatResponse, err := assistants.GetAssistantChatCompletions(asstName, message, stream)
111130
if err != nil {
112131
exit.Error(err)
113132
}
114133

115-
response = chatResponse
134+
resp = chatResponse
116135

117136
for _, choice := range chatResponse.Choices {
118137
presenters.PrintAssistantChatResponse(choice.Message.Content)
119138
}
120139
return nil
121-
})
122-
if err != nil {
123-
return nil, err
124140
}
141+
}
125142

126-
return response, nil
143+
func streamChatResponse(resp *models.ChatCompletionModel, asstName string, message string, stream bool) func() error {
144+
return func() error {
145+
chatResponse, err := assistants.GetAssistantChatCompletions(asstName, message, stream)
146+
if err != nil {
147+
exit.Error(err)
148+
}
149+
150+
resp = chatResponse
151+
152+
for _, choice := range chatResponse.Choices {
153+
presenters.PrintAssistantChatResponse(choice.Message.Content)
154+
}
155+
return nil
156+
}
127157
}
128158

129159
func displayChatHistory(asstName string, maxNoMsgs int) {
@@ -137,7 +167,7 @@ func displayChatHistory(asstName string, maxNoMsgs int) {
137167
presenters.PrintChatHistory(chat, maxNoMsgs)
138168
}
139169

140-
// This function checks the input for accepted chat commands
170+
// Checks the input for accepted chat commands
141171
func checkForChatCommands(text string) {
142172
switch text {
143173
case "exit()":

internal/pkg/utils/models/models.go

+27-2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package models
33
import "time"
44

55
type ChatCompletionRequest struct {
6+
Stream bool `json:"stream"`
67
Messages []ChatCompletionMessage `json:"messages"`
78
}
89

@@ -12,6 +13,12 @@ type ChatCompletionModel struct {
1213
Model string `json:"model"`
1314
}
1415

16+
type ChoiceModel struct {
17+
FinishReason ChatFinishReason `json:"finish_reason"`
18+
Index int32 `json:"index"`
19+
Message ChatCompletionMessage `json:"message"`
20+
}
21+
1522
type ChatCompletionMessage struct {
1623
Role string `json:"role"`
1724
Content string `json:"content"`
@@ -33,8 +40,26 @@ const (
3340
FunctionCall ChatFinishReason = "function_call"
3441
)
3542

36-
type ChoiceModel struct {
43+
type StreamChatCompletionModel struct {
44+
Id string `json:"id"`
45+
Choices []ChoiceChunkModel `json:"choices"`
46+
Model string `json:"model"`
47+
}
48+
49+
type StreamChunk struct {
50+
Data StreamChatCompletionModel `json:"data"`
51+
}
52+
53+
type ChoiceChunkModel struct {
3754
FinishReason ChatFinishReason `json:"finish_reason"`
3855
Index int32 `json:"index"`
39-
Message ChatCompletionMessage `json:"message"`
56+
Delta ChatCompletionMessage `json:"delta"`
57+
}
58+
59+
// ContextRefModel describes a context reference.
//
// NOTE(review): the field meanings below are inferred from the names and
// JSON tags only — confirm against the API that produces this payload.
type ContextRefModel struct {
	// Id is decoded from "id".
	Id string `json:"id"`
	// Source is decoded from "source".
	Source string `json:"source"`
	// Text is decoded from "text".
	Text string `json:"text"`
	// Score is decoded from "score".
	Score float64 `json:"score"`
	// Path is decoded from "path".
	Path []string `json:"path"`
}

internal/pkg/utils/network/request.go

+39
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,45 @@ func decodeResponse[T any](resp *http.Response, target *T) error {
8585
return nil
8686
}
8787

88+
func RequestWithBody[B any](baseUrl string, path string, method string, useApiKey bool, body B) (*http.Response, error) {
89+
url := baseUrl + path
90+
91+
var bodyJson []byte
92+
bodyJson, err := json.Marshal(body)
93+
if err != nil {
94+
log.Error().
95+
Err(err).
96+
Str("method", method).
97+
Str("url", url).
98+
Msg("Error marshalling JSON")
99+
return nil, pcio.Errorf("error marshalling JSON: %v", err)
100+
}
101+
102+
req, err := buildRequest(method, url, bodyJson)
103+
log.Info().
104+
Str("method", method).
105+
Str("url", url).
106+
Msg("Fetching data from dashboard")
107+
if err != nil {
108+
log.Error().
109+
Err(err).
110+
Str("url", url).
111+
Str("method", method).
112+
Msg("Error building request")
113+
return nil, pcio.Errorf("error building request: %v", err)
114+
}
115+
116+
resp, err := performRequest(req, useApiKey)
117+
if err != nil {
118+
log.Error().
119+
Err(err).
120+
Str("method", method).
121+
Str("url", url)
122+
return nil, pcio.Errorf("error performing request to %s: %v", url, err)
123+
}
124+
return resp, nil
125+
}
126+
88127
func RequestWithBodyAndDecode[B any, R any](baseUrl string, path string, method string, useApiKey bool, body B) (*R, error) {
89128
url := baseUrl + path
90129

internal/pkg/utils/network/request_post.go

+71-1
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,92 @@
11
package network
22

33
import (
4+
"bufio"
45
"bytes"
6+
"encoding/json"
7+
"fmt"
58
"io"
69
"mime/multipart"
710
"net/http"
811
"os"
912
"path/filepath"
13+
"strings"
1014

1115
"github.com/pinecone-io/cli/internal/pkg/utils/log"
16+
"github.com/pinecone-io/cli/internal/pkg/utils/models"
1217
"github.com/pinecone-io/cli/internal/pkg/utils/pcio"
1318
)
1419

1520
func PostAndDecode[B any, R any](baseUrl string, path string, useApiKey bool, body B) (*R, error) {
1621
return RequestWithBodyAndDecode[B, R](baseUrl, path, http.MethodPost, useApiKey, body)
1722
}
1823

19-
func PostAndDecodeMultipartFormData[R any](baseUrl string, path string, useApiKey bool, bodyPath string) (*R, error) {
24+
func PostAndStreamChatResponse[B any](baseUrl string, path string, useApiKey bool, body B) (*models.ChatCompletionModel, error) {
25+
resp, err := RequestWithBody[B](baseUrl, path, http.MethodPost, useApiKey, body)
26+
if err != nil {
27+
return nil, err
28+
}
29+
defer resp.Body.Close()
30+
31+
var completeResponse string
32+
var id string
33+
var model string
34+
35+
scanner := bufio.NewScanner(resp.Body)
36+
for scanner.Scan() {
37+
line := scanner.Text()
38+
if strings.HasPrefix(line, "data:") {
39+
dataStr := strings.TrimPrefix(line, "data:")
40+
dataStr = strings.TrimSpace(dataStr)
41+
42+
var chunkResp *models.StreamChatCompletionModel
43+
if err := json.Unmarshal([]byte(dataStr), &chunkResp); err != nil {
44+
return nil, pcio.Errorf("error unmarshaling chunk: %v", err)
45+
}
46+
47+
for _, choice := range chunkResp.Choices {
48+
fmt.Print(choice.Delta.Content)
49+
os.Stdout.Sync()
50+
completeResponse += choice.Delta.Content
51+
}
52+
id = chunkResp.Id
53+
model = chunkResp.Model
54+
}
55+
}
56+
57+
completionResp := &models.ChatCompletionModel{
58+
Id: id,
59+
Model: model,
60+
Choices: []models.ChoiceModel{
61+
{
62+
FinishReason: "stop",
63+
Index: 0,
64+
Message: models.ChatCompletionMessage{
65+
Content: completeResponse,
66+
Role: "assistant",
67+
},
68+
},
69+
},
70+
}
71+
72+
if err != nil {
73+
log.Error().
74+
Err(err).
75+
Str("method", http.MethodPost).
76+
Str("url", baseUrl+path).
77+
Str("status", resp.Status).
78+
Msg("Error decoding response")
79+
return nil, pcio.Errorf("error decoding JSON: %v", err)
80+
}
81+
82+
log.Info().
83+
Str("method", http.MethodPost).
84+
Str("url", baseUrl+path).
85+
Msg("Request completed successfully")
86+
return completionResp, nil
87+
}
88+
89+
func PostMultipartFormDataAndDecode[R any](baseUrl string, path string, useApiKey bool, bodyPath string) (*R, error) {
2090
url := baseUrl + path
2191

2292
var requestBody bytes.Buffer

0 commit comments

Comments
 (0)