update to LLMResponse for streaming
ad-astra-video committed Jan 14, 2025
1 parent ddb3a7c commit eaba8ca
Showing 6 changed files with 22 additions and 22 deletions.
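
This commit switches the LLM streaming path from the dedicated worker.LlmStreamChunk type to *worker.LLMResponse, so streamed chunks and full responses now share one shape, and the end of a stream is signaled by a choice's FinishReason instead of a Done flag. The sketch below reconstructs the relevant worker types only from the fields this diff touches; exact field types, JSON tags, and any fields not shown in the diff are assumptions, not the ai-worker package's actual definitions.

// Sketch of the types implied by this diff (assumed, not authoritative).
package worker

// LLMMessage is one chat message.
type LLMMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

// LLMChoice holds one completion choice; Delta is now a pointer and
// FinishReason is expected to be non-empty on the final streamed chunk.
type LLMChoice struct {
	Index        int         `json:"index"`
	Delta        *LLMMessage `json:"delta,omitempty"`
	FinishReason *string     `json:"finish_reason,omitempty"`
}

// LLMTokenUsage reports token counts for the request.
type LLMTokenUsage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

// LLMResponse is used both for the non-streaming result and for each
// streamed SSE chunk after this change.
type LLMResponse struct {
	Choices    []LLMChoice   `json:"choices"`
	Created    int           `json:"created"`
	Model      string        `json:"model"`
	TokensUsed LLMTokenUsage `json:"tokens_used"`
}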
2 changes: 1 addition & 1 deletion core/ai_test.go
@@ -653,7 +653,7 @@ func (a *stubAIWorker) SegmentAnything2(ctx context.Context, req worker.GenSegme

func (a *stubAIWorker) LLM(ctx context.Context, req worker.GenLLMJSONRequestBody) (interface{}, error) {
var choices []worker.LLMChoice
-choices = append(choices, worker.LLMChoice{Delta: worker.LLMMessage{Content: "choice1", Role: "assistant"}, Index: 0})
+choices = append(choices, worker.LLMChoice{Delta: &worker.LLMMessage{Content: "choice1", Role: "assistant"}, Index: 0})
tokensUsed := worker.LLMTokenUsage{PromptTokens: 40, CompletionTokens: 10, TotalTokens: 50}
return &worker.LLMResponse{Choices: choices, Created: 1, Model: "llm_model", TokensUsed: tokensUsed}, nil
}
12 changes: 6 additions & 6 deletions server/ai_http.go
@@ -586,7 +586,7 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request
}

// Check if the response is a streaming response
-if streamChan, ok := resp.(<-chan worker.LlmStreamChunk); ok {
+if streamChan, ok := resp.(<-chan *worker.LLMResponse); ok {
glog.Infof("Streaming response for request id=%v", requestID)

// Set headers for SSE
@@ -610,7 +610,7 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request
fmt.Fprintf(w, "data: %s\n\n", data)
flusher.Flush()

-if chunk.Done {
+if chunk.Choices[0].FinishReason != nil && *chunk.Choices[0].FinishReason != "" {
break
}
}
@@ -683,8 +683,8 @@ func (h *lphttp) AIResults() http.Handler {
case "text/event-stream":
resultType = "streaming"
glog.Infof("Received %s response from remote worker=%s taskId=%d", resultType, r.RemoteAddr, tid)
-resChan := make(chan worker.LlmStreamChunk, 100)
-workerResult.Results = (<-chan worker.LlmStreamChunk)(resChan)
+resChan := make(chan *worker.LLMResponse, 100)
+workerResult.Results = (<-chan *worker.LLMResponse)(resChan)

defer r.Body.Close()
defer close(resChan)
@@ -703,12 +703,12 @@ func (h *lphttp) AIResults() http.Handler {
line := scanner.Text()
if strings.HasPrefix(line, "data: ") {
data := strings.TrimPrefix(line, "data: ")
-var chunk worker.LlmStreamChunk
+var chunk worker.LLMResponse
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
clog.Errorf(ctx, "Error unmarshaling stream data: %v", err)
continue
}
-resChan <- chunk
+resChan <- &chunk
}
}
}
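
Every streaming loop in this commit (handleAIRequest above, plus the mediaserver and worker loops below) now detects the end of the stream by checking FinishReason on the first choice rather than a Done flag. A minimal sketch of that check factored into a helper; the isFinalChunk name is hypothetical, and the commit itself keeps the check inline at each call site:

// isFinalChunk reports whether a streamed chunk terminates the stream,
// i.e. its first choice carries a non-empty FinishReason (such as "stop").
// Hypothetical helper, not part of this commit.
func isFinalChunk(chunk *worker.LLMResponse) bool {
	if chunk == nil || len(chunk.Choices) == 0 {
		return false
	}
	fr := chunk.Choices[0].FinishReason
	return fr != nil && *fr != ""
}

Note that the inlined check indexes chunk.Choices[0] directly, so it assumes every streamed chunk carries at least one choice.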
4 changes: 2 additions & 2 deletions server/ai_mediaserver.go
@@ -289,7 +289,7 @@ func (ls *LivepeerServer) LLM() http.Handler {
took := time.Since(start)
clog.V(common.VERBOSE).Infof(ctx, "Processed LLM request model_id=%v took=%v", *req.Model, took)

-if streamChan, ok := resp.(chan worker.LlmStreamChunk); ok {
+if streamChan, ok := resp.(chan *worker.LLMResponse); ok {
// Handle streaming response (SSE)
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
@@ -299,7 +299,7 @@ func (ls *LivepeerServer) LLM() http.Handler {
data, _ := json.Marshal(chunk)
fmt.Fprintf(w, "data: %s\n\n", data)
w.(http.Flusher).Flush()
-if chunk.Done {
+if chunk.Choices[0].FinishReason != nil && *chunk.Choices[0].FinishReason != "" {
break
}
}
18 changes: 9 additions & 9 deletions server/ai_process.go
@@ -1106,7 +1106,7 @@ func processLLM(ctx context.Context, params aiRequestParams, req worker.GenLLMJS
}

if req.Stream != nil && *req.Stream {
-streamChan, ok := resp.(chan worker.LlmStreamChunk)
+streamChan, ok := resp.(chan *worker.LLMResponse)
if !ok {
return nil, errors.New("unexpected response type for streaming request")
}
@@ -1166,36 +1166,36 @@ func submitLLM(ctx context.Context, params aiRequestParams, sess *AISession, req
return handleNonStreamingResponse(ctx, resp.Body, sess, req, start)
}

-func handleSSEStream(ctx context.Context, body io.ReadCloser, sess *AISession, req worker.GenLLMJSONRequestBody, start time.Time) (chan worker.LlmStreamChunk, error) {
-streamChan := make(chan worker.LlmStreamChunk, 100)
+func handleSSEStream(ctx context.Context, body io.ReadCloser, sess *AISession, req worker.GenLLMJSONRequestBody, start time.Time) (chan *worker.LLMResponse, error) {
+streamChan := make(chan *worker.LLMResponse, 100)
go func() {
defer close(streamChan)
defer body.Close()
scanner := bufio.NewScanner(body)
-var totalTokens int
+var totalTokens worker.LLMTokenUsage
for scanner.Scan() {
line := scanner.Text()
if strings.HasPrefix(line, "data: ") {
data := strings.TrimPrefix(line, "data: ")
if data == "[DONE]" {
-streamChan <- worker.LlmStreamChunk{Done: true, TokensUsed: totalTokens}
+//streamChan <- worker.LLMResponse{Done: true, TokensUsed: totalTokens}
break
}
-var chunk worker.LlmStreamChunk
+var chunk worker.LLMResponse
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
clog.Errorf(ctx, "Error unmarshaling SSE data: %v", err)
continue
}
-totalTokens += chunk.TokensUsed
-streamChan <- chunk
+totalTokens = chunk.TokensUsed
+streamChan <- &chunk
}
}
if err := scanner.Err(); err != nil {
clog.Errorf(ctx, "Error reading SSE stream: %v", err)
}

took := time.Since(start)
-sess.LatencyScore = CalculateLLMLatencyScore(took, totalTokens)
+sess.LatencyScore = CalculateLLMLatencyScore(took, totalTokens.TotalTokens)

if monitor.Enabled {
var pricePerAIUnit float64
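
handleSSEStream now emits *worker.LLMResponse values and tracks usage as a worker.LLMTokenUsage struct, taking the latest chunk's TokensUsed rather than summing an int, which implies each chunk reports usage for the stream so far. A rough sketch of how a caller might drain the returned channel, assuming Delta.Content carries the incremental text and TokensUsed is cumulative (both assumptions), with "strings" and the ai-worker "worker" package assumed imported:

// collectStream concatenates streamed deltas and keeps the last reported
// token usage; hypothetical consumer of handleSSEStream's channel.
func collectStream(streamChan <-chan *worker.LLMResponse) (string, worker.LLMTokenUsage) {
	var text strings.Builder
	var usage worker.LLMTokenUsage
	for chunk := range streamChan {
		usage = chunk.TokensUsed
		if len(chunk.Choices) == 0 {
			continue
		}
		if delta := chunk.Choices[0].Delta; delta != nil {
			text.WriteString(delta.Content)
		}
		if fr := chunk.Choices[0].FinishReason; fr != nil && *fr != "" {
			break
		}
	}
	return text.String(), usage
}

Also note the "[DONE]" sentinel no longer produces a synthetic chunk (the old Done chunk is commented out), so consumers must rely on FinishReason or on the channel being closed.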
6 changes: 3 additions & 3 deletions server/ai_worker.go
@@ -354,7 +354,7 @@ func runAIJob(n *core.LivepeerNode, orchAddr string, httpc *http.Client, notify

if resp != nil {
if resultType == "text/event-stream" {
-streamChan, ok := resp.(<-chan worker.LlmStreamChunk)
+streamChan, ok := resp.(<-chan *worker.LLMResponse)
if ok {
sendStreamingAIResult(ctx, n, orchAddr, notify.AIJobData.Pipeline, httpc, resultType, streamChan)
return
@@ -530,7 +530,7 @@ func sendAIResult(ctx context.Context, n *core.LivepeerNode, orchAddr string, pi
}

func sendStreamingAIResult(ctx context.Context, n *core.LivepeerNode, orchAddr string, pipeline string, httpc *http.Client,
-contentType string, streamChan <-chan worker.LlmStreamChunk,
+contentType string, streamChan <-chan *worker.LLMResponse,
) {
clog.Infof(ctx, "sending streaming results back to Orchestrator")
taskId := clog.GetVal(ctx, "taskId")
@@ -571,7 +571,7 @@ func sendStreamingAIResult(ctx context.Context, n *core.LivepeerNode, orchAddr s
}
fmt.Fprintf(pWriter, "data: %s\n\n", data)

-if chunk.Done {
+if chunk.Choices[0].FinishReason != nil && *chunk.Choices[0].FinishReason != "" {
pWriter.Close()
clog.Infof(ctx, "streaming results finished")
return
2 changes: 1 addition & 1 deletion server/ai_worker_test.go
@@ -606,7 +606,7 @@ func (a *stubAIWorker) LLM(ctx context.Context, req worker.GenLLMJSONRequestBody
return nil, a.Err
} else {
var choices []worker.LLMChoice
-choices = append(choices, worker.LLMChoice{Delta: worker.LLMMessage{Content: "choice1", Role: "assistant"}, Index: 0})
+choices = append(choices, worker.LLMChoice{Delta: &worker.LLMMessage{Content: "choice1", Role: "assistant"}, Index: 0})
tokensUsed := worker.LLMTokenUsage{PromptTokens: 40, CompletionTokens: 10, TotalTokens: 50}
return &worker.LLMResponse{Choices: choices, Created: 1, Model: "llm_model", TokensUsed: tokensUsed}, nil
}
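
For completeness, the streaming path is still opted into through the request body: processLLM only type-asserts the response to chan *worker.LLMResponse when *req.Stream is true. A minimal sketch of building such a request; only the Model and Stream fields appear in this diff (Model is assumed to be *string), so every other field of GenLLMJSONRequestBody is omitted and assumed to be set elsewhere:

// buildStreamingLLMRequest is a hypothetical helper; Model and Stream are
// pointer fields (both are dereferenced in server/ai_process.go and
// server/ai_mediaserver.go), and all other fields are left at zero values.
func buildStreamingLLMRequest(model string) worker.GenLLMJSONRequestBody {
	stream := true
	return worker.GenLLMJSONRequestBody{
		Model:  &model,
		Stream: &stream,
	}
}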
