From eaba8ca52bb82e125a8baa059d3f552a02f391af Mon Sep 17 00:00:00 2001
From: Brad P
Date: Mon, 13 Jan 2025 11:29:01 -0600
Subject: [PATCH] update to LLMResponse for streaming

---
 core/ai_test.go          |  2 +-
 server/ai_http.go        | 12 ++++++------
 server/ai_mediaserver.go |  4 ++--
 server/ai_process.go     | 18 +++++++++---------
 server/ai_worker.go      |  6 +++---
 server/ai_worker_test.go |  2 +-
 6 files changed, 22 insertions(+), 22 deletions(-)

diff --git a/core/ai_test.go b/core/ai_test.go
index e9453879b..481b230de 100644
--- a/core/ai_test.go
+++ b/core/ai_test.go
@@ -653,7 +653,7 @@ func (a *stubAIWorker) SegmentAnything2(ctx context.Context, req worker.GenSegme
 
 func (a *stubAIWorker) LLM(ctx context.Context, req worker.GenLLMJSONRequestBody) (interface{}, error) {
 	var choices []worker.LLMChoice
-	choices = append(choices, worker.LLMChoice{Delta: worker.LLMMessage{Content: "choice1", Role: "assistant"}, Index: 0})
+	choices = append(choices, worker.LLMChoice{Delta: &worker.LLMMessage{Content: "choice1", Role: "assistant"}, Index: 0})
 	tokensUsed := worker.LLMTokenUsage{PromptTokens: 40, CompletionTokens: 10, TotalTokens: 50}
 	return &worker.LLMResponse{Choices: choices, Created: 1, Model: "llm_model", TokensUsed: tokensUsed}, nil
 }
diff --git a/server/ai_http.go b/server/ai_http.go
index 7537a274b..71110d1f2 100644
--- a/server/ai_http.go
+++ b/server/ai_http.go
@@ -586,7 +586,7 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request
 	}
 
 	// Check if the response is a streaming response
-	if streamChan, ok := resp.(<-chan worker.LlmStreamChunk); ok {
+	if streamChan, ok := resp.(<-chan *worker.LLMResponse); ok {
 		glog.Infof("Streaming response for request id=%v", requestID)
 
 		// Set headers for SSE
@@ -610,7 +610,7 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request
 			fmt.Fprintf(w, "data: %s\n\n", data)
 			flusher.Flush()
 
-			if chunk.Done {
+			if chunk.Choices[0].FinishReason != nil && *chunk.Choices[0].FinishReason != "" {
 				break
 			}
 		}
@@ -683,8 +683,8 @@ func (h *lphttp) AIResults() http.Handler {
 		case "text/event-stream":
 			resultType = "streaming"
 			glog.Infof("Received %s response from remote worker=%s taskId=%d", resultType, r.RemoteAddr, tid)
-			resChan := make(chan worker.LlmStreamChunk, 100)
-			workerResult.Results = (<-chan worker.LlmStreamChunk)(resChan)
+			resChan := make(chan *worker.LLMResponse, 100)
+			workerResult.Results = (<-chan *worker.LLMResponse)(resChan)
 
 			defer r.Body.Close()
 			defer close(resChan)
@@ -703,12 +703,12 @@ func (h *lphttp) AIResults() http.Handler {
 				line := scanner.Text()
 				if strings.HasPrefix(line, "data: ") {
 					data := strings.TrimPrefix(line, "data: ")
-					var chunk worker.LlmStreamChunk
+					var chunk worker.LLMResponse
 					if err := json.Unmarshal([]byte(data), &chunk); err != nil {
 						clog.Errorf(ctx, "Error unmarshaling stream data: %v", err)
 						continue
 					}
-					resChan <- chunk
+					resChan <- &chunk
 				}
 			}
 		}
diff --git a/server/ai_mediaserver.go b/server/ai_mediaserver.go
index a810d7945..22721afd0 100644
--- a/server/ai_mediaserver.go
+++ b/server/ai_mediaserver.go
@@ -289,7 +289,7 @@ func (ls *LivepeerServer) LLM() http.Handler {
 		took := time.Since(start)
 		clog.V(common.VERBOSE).Infof(ctx, "Processed LLM request model_id=%v took=%v", *req.Model, took)
 
-		if streamChan, ok := resp.(chan worker.LlmStreamChunk); ok {
+		if streamChan, ok := resp.(chan *worker.LLMResponse); ok {
 			// Handle streaming response (SSE)
 			w.Header().Set("Content-Type", "text/event-stream")
 			w.Header().Set("Cache-Control", "no-cache")
@@ -299,7 +299,7 @@ func (ls *LivepeerServer) LLM() http.Handler {
 				data, _ := json.Marshal(chunk)
 				fmt.Fprintf(w, "data: %s\n\n", data)
 				w.(http.Flusher).Flush()
-				if chunk.Done {
+				if chunk.Choices[0].FinishReason != nil && *chunk.Choices[0].FinishReason != "" {
 					break
 				}
 			}
diff --git a/server/ai_process.go b/server/ai_process.go
index 7dc83735a..10f75629b 100644
--- a/server/ai_process.go
+++ b/server/ai_process.go
@@ -1106,7 +1106,7 @@ func processLLM(ctx context.Context, params aiRequestParams, req worker.GenLLMJS
 	}
 
 	if req.Stream != nil && *req.Stream {
-		streamChan, ok := resp.(chan worker.LlmStreamChunk)
+		streamChan, ok := resp.(chan *worker.LLMResponse)
 		if !ok {
 			return nil, errors.New("unexpected response type for streaming request")
 		}
@@ -1166,28 +1166,28 @@ func submitLLM(ctx context.Context, params aiRequestParams, sess *AISession, req
 	return handleNonStreamingResponse(ctx, resp.Body, sess, req, start)
 }
 
-func handleSSEStream(ctx context.Context, body io.ReadCloser, sess *AISession, req worker.GenLLMJSONRequestBody, start time.Time) (chan worker.LlmStreamChunk, error) {
-	streamChan := make(chan worker.LlmStreamChunk, 100)
+func handleSSEStream(ctx context.Context, body io.ReadCloser, sess *AISession, req worker.GenLLMJSONRequestBody, start time.Time) (chan *worker.LLMResponse, error) {
+	streamChan := make(chan *worker.LLMResponse, 100)
 	go func() {
 		defer close(streamChan)
 		defer body.Close()
 		scanner := bufio.NewScanner(body)
-		var totalTokens int
+		var totalTokens worker.LLMTokenUsage
 		for scanner.Scan() {
 			line := scanner.Text()
 			if strings.HasPrefix(line, "data: ") {
 				data := strings.TrimPrefix(line, "data: ")
 				if data == "[DONE]" {
-					streamChan <- worker.LlmStreamChunk{Done: true, TokensUsed: totalTokens}
+					//streamChan <- worker.LLMResponse{Done: true, TokensUsed: totalTokens}
 					break
 				}
-				var chunk worker.LlmStreamChunk
+				var chunk worker.LLMResponse
 				if err := json.Unmarshal([]byte(data), &chunk); err != nil {
 					clog.Errorf(ctx, "Error unmarshaling SSE data: %v", err)
 					continue
 				}
-				totalTokens += chunk.TokensUsed
-				streamChan <- chunk
+				totalTokens = chunk.TokensUsed
+				streamChan <- &chunk
 			}
 		}
 		if err := scanner.Err(); err != nil {
@@ -1195,7 +1195,7 @@ func handleSSEStream(ctx context.Context, body io.ReadCloser, sess *AISession, r
 		}
 
 		took := time.Since(start)
-		sess.LatencyScore = CalculateLLMLatencyScore(took, totalTokens)
+		sess.LatencyScore = CalculateLLMLatencyScore(took, totalTokens.TotalTokens)
 
 		if monitor.Enabled {
 			var pricePerAIUnit float64
diff --git a/server/ai_worker.go b/server/ai_worker.go
index d626d42f5..02ac6e2d1 100644
--- a/server/ai_worker.go
+++ b/server/ai_worker.go
@@ -354,7 +354,7 @@ func runAIJob(n *core.LivepeerNode, orchAddr string, httpc *http.Client, notify
 
 	if resp != nil {
 		if resultType == "text/event-stream" {
-			streamChan, ok := resp.(<-chan worker.LlmStreamChunk)
+			streamChan, ok := resp.(<-chan *worker.LLMResponse)
 			if ok {
 				sendStreamingAIResult(ctx, n, orchAddr, notify.AIJobData.Pipeline, httpc, resultType, streamChan)
 				return
@@ -530,7 +530,7 @@ func sendAIResult(ctx context.Context, n *core.LivepeerNode, orchAddr string, pi
 }
 
 func sendStreamingAIResult(ctx context.Context, n *core.LivepeerNode, orchAddr string, pipeline string, httpc *http.Client,
-	contentType string, streamChan <-chan worker.LlmStreamChunk,
+	contentType string, streamChan <-chan *worker.LLMResponse,
 ) {
 	clog.Infof(ctx, "sending streaming results back to Orchestrator")
 	taskId := clog.GetVal(ctx, "taskId")
@@ -571,7 +571,7 @@ func sendStreamingAIResult(ctx context.Context, n *core.LivepeerNode, orchAddr s
 			}
 			fmt.Fprintf(pWriter, "data: %s\n\n", data)
 
-			if chunk.Done {
+			if chunk.Choices[0].FinishReason != nil && *chunk.Choices[0].FinishReason != "" {
 				pWriter.Close()
 				clog.Infof(ctx, "streaming results finished")
 				return
diff --git a/server/ai_worker_test.go b/server/ai_worker_test.go
index e628f3cd8..fc728acfd 100644
--- a/server/ai_worker_test.go
+++ b/server/ai_worker_test.go
@@ -606,7 +606,7 @@ func (a *stubAIWorker) LLM(ctx context.Context, req worker.GenLLMJSONRequestBody
 		return nil, a.Err
 	} else {
 		var choices []worker.LLMChoice
-		choices = append(choices, worker.LLMChoice{Delta: worker.LLMMessage{Content: "choice1", Role: "assistant"}, Index: 0})
+		choices = append(choices, worker.LLMChoice{Delta: &worker.LLMMessage{Content: "choice1", Role: "assistant"}, Index: 0})
 		tokensUsed := worker.LLMTokenUsage{PromptTokens: 40, CompletionTokens: 10, TotalTokens: 50}
 		return &worker.LLMResponse{Choices: choices, Created: 1, Model: "llm_model", TokensUsed: tokensUsed}, nil
 	}