
Add embeddings endpoint #9

Open

kyriediculous opened this issue Dec 19, 2024 · 0 comments

kyriediculous (Contributor) commented:

Example implementation (minus endpoint)

// embeddings/model.go
package embeddings

import (
    "context"
    "encoding/json"
    "net/http"
    "sync"

    "github.com/sugarme/gotch"
    "github.com/sugarme/tokenizer"
)

type EmbeddingModel struct {
    model     *gotch.Module
    tokenizer *tokenizer.Tokenizer
    device    gotch.Device
    batchSize int
    maxLength int
    mu        sync.RWMutex
}

type Config struct {
    ModelPath  string
    BatchSize  int
    MaxLength  int
    DeviceType gotch.DeviceType
}

func NewEmbeddingModel(cfg Config) (*EmbeddingModel, error) {
    // Load model in inference mode
    model, err := gotch.ModuleLoad(cfg.ModelPath)
    if err != nil {
        return nil, err
    }
    model.SetEvalMode()

    // Initialize tokenizer
    tokenizer, err := tokenizer.NewFromFile(cfg.ModelPath)
    if err != nil {
        return nil, err
    }

    // Setup device (CPU/CUDA)
    device := gotch.NewDevice(cfg.DeviceType)

    return &EmbeddingModel{
        model:     model,
        tokenizer: tokenizer,
        device:    device,
        batchSize: cfg.BatchSize,
        maxLength: cfg.MaxLength,
    }, nil
}

// Batch processing for better performance
func (m *EmbeddingModel) GetEmbeddingsBatch(ctx context.Context, inputs []string) ([][]float32, error) {
    m.mu.RLock()
    defer m.mu.RUnlock()

    // Tokenize with padding and truncation
    encoded, err := m.tokenizer.EncodeBatch(inputs,
        tokenizer.WithMaxLength(m.maxLength),
        tokenizer.WithTruncation(true),
        tokenizer.WithPadding(true),
    )
    if err != nil {
        return nil, err
    }

    // Convert to tensors
    inputIds := gotch.TensorFromSlice(encoded.InputIds).To(m.device)
    attentionMask := gotch.TensorFromSlice(encoded.AttentionMask).To(m.device)
    
    defer inputIds.MustDrop()
    defer attentionMask.MustDrop()

    // Forward pass with no_grad for inference
    var embeddings gotch.Tensor
    gotch.NoGrad(func() {
        embeddings = m.model.Forward(gotch.NewTensorList([]gotch.Tensor{
            inputIds,
            attentionMask,
        }))
    })
    defer embeddings.MustDrop()

    // Mean pooling
    maskedEmbeddings := embeddings.MulT(attentionMask.Unsqueeze(-1))
    sumEmbeddings := maskedEmbeddings.Sum(1)
    sumMask := attentionMask.Sum(1).Unsqueeze(-1)
    pooledEmbeddings := sumEmbeddings.DivT(sumMask)

    // Convert to [][]float32
    return pooledEmbeddings.Float32Values2D(), nil
}

// Handler for concurrent requests
type EmbeddingHandler struct {
    model *EmbeddingModel
    pool  chan struct{}
}

func NewEmbeddingHandler(model *EmbeddingModel, maxConcurrent int) *EmbeddingHandler {
    return &EmbeddingHandler{
        model: model,
        pool:  make(chan struct{}, maxConcurrent),
    }
}

func (h *EmbeddingHandler) GetEmbeddings(ctx context.Context, input []string) ([][]float32, error) {
    select {
    case h.pool <- struct{}{}:
        defer func() { <-h.pool }()
        return h.model.GetEmbeddingsBatch(ctx, input)
    case <-ctx.Done():
        return nil, ctx.Err()
    }
}

// API types
type EmbeddingRequest struct {
    Input []string `json:"input"`
    Model string   `json:"model,omitempty"`
}

type EmbeddingResponse struct {
    Object string             `json:"object"`
    Data   []EmbeddingVector `json:"data"`
    Model  string            `json:"model"`
}

type EmbeddingVector struct {
    Object    string    `json:"object"`
    Embedding []float32 `json:"embedding"`
    Index     int       `json:"index"`
}

// HTTP handler
func (h *EmbeddingHandler) HandleEmbeddings(w http.ResponseWriter, r *http.Request) {
    var req EmbeddingRequest
    if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
        http.Error(w, err.Error(), http.StatusBadRequest)
        return
    }

    embeddings, err := h.GetEmbeddings(r.Context(), req.Input)
    if err != nil {
        http.Error(w, err.Error(), http.StatusInternalServerError)
        return
    }

    w.Header().Set("Content-Type", "application/json")

    response := EmbeddingResponse{
        Object: "list",
        Data:   make([]EmbeddingVector, len(embeddings)),
        Model:  h.model.model.Name(),
    }

    for i, emb := range embeddings {
        response.Data[i] = EmbeddingVector{
            Object:    "embedding",
            Embedding: emb,
            Index:     i,
        }
    }

    json.NewEncoder(w).Encode(response)
}
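
For completeness, here is a rough sketch of how the endpoint could be wired up. The module import path, model path, route name, concurrency limit, and port are placeholders rather than part of the example above:

// cmd/server/main.go (hypothetical wiring for the embeddings endpoint)
package main

import (
    "log"
    "net/http"

    "your/module/embeddings" // hypothetical module path
)

func main() {
    model, err := embeddings.NewEmbeddingModel(embeddings.Config{
        ModelPath: "./models/embedding-model", // placeholder; DeviceType left at its zero value here
        BatchSize: 32,
        MaxLength: 512,
    })
    if err != nil {
        log.Fatal(err)
    }

    // Cap concurrent embedding requests (the limit of 4 is arbitrary for this sketch)
    handler := embeddings.NewEmbeddingHandler(model, 4)

    // OpenAI-style route; rename to match whatever the gateway expects
    http.HandleFunc("/embeddings", handler.HandleEmbeddings)

    log.Fatal(http.ListenAndServe(":8080", nil))
}

A request against that route would then follow the EmbeddingRequest/EmbeddingResponse shapes defined above, e.g. curl -X POST "http://localhost:8080/embeddings" -d '{"input": ["some text"]}'.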

Let me clarify the workflow of how embeddings and LLMs work together in a RAG system:

  1. First, you create embeddings for your documents:
// embeddings/service.go
type Document struct {
    ID      string `json:"id"`
    Content string `json:"content"`
    Source  string `json:"source"`
}

type EmbeddedDocument struct {
    Document
    Embedding []float32 `json:"embedding"`
}

// Store embeddings in vector DB (like Qdrant, Milvus, etc.)
func (h *EmbeddingHandler) StoreDocument(doc Document) (*EmbeddedDocument, error) {
    embedding, err := h.GetEmbeddings(context.Background(), []string{doc.Content})
    if err != nil {
        return nil, err
    }

    embeddedDoc := &EmbeddedDocument{
        Document:  doc,
        Embedding: embedding[0],
    }

    // Store in vector DB
    return embeddedDoc, nil
}
  2. When a query comes in, you:
    a. Generate embedding for the query
    b. Find similar documents
    c. Format them into context for the LLM
// rag/service.go
type RAGService struct {
    embeddings *embeddings.EmbeddingHandler
    vectorDB   *vectorstore.Client
    llmClient  *llm.Client
}

func (s *RAGService) Query(ctx context.Context, query string) (*llm.Response, error) {
    // 1. Get query embedding
    queryEmbedding, err := s.embeddings.GetEmbeddings(ctx, []string{query})
    if err != nil {
        return nil, err
    }

    // 2. Find similar documents
    similar, err := s.vectorDB.Search(ctx, queryEmbedding[0], 3) // top 3 matches
    if err != nil {
        return nil, err
    }

    // 3. Format retrieved documents into a context string for the prompt
    promptContext := formatContext(similar)

    // 4. Create messages for the LLM, with the retrieved context as the system message
    messages := []llm.Message{
        {
            Role:    "system",
            Content: promptContext,
        },
        {
            Role:    "user",
            Content: query,
        },
    }

    // 5. Send to LLM
    return s.llmClient.Generate(ctx, messages)
}

func formatContext(docs []embeddings.EmbeddedDocument) string {
    var sb strings.Builder
    sb.WriteString("Use the following information to answer the question:\n\n")

    for _, doc := range docs {
        sb.WriteString(fmt.Sprintf("Source [%s]: %s\n\n", doc.Source, doc.Content))
    }

    return sb.String()
}
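
The vectorstore.Client referenced above is assumed rather than defined. Purely as an illustration of what its Search is expected to do, here is a minimal in-memory stand-in that ranks stored documents by cosine similarity (the module import path is hypothetical; a real deployment would use Qdrant, Milvus, pgvector, etc.):

// vectorstore/memory.go (illustrative in-memory stand-in, not a production store)
package vectorstore

import (
    "context"
    "math"
    "sort"
    "sync"

    "your/module/embeddings" // hypothetical module path
)

type Client struct {
    mu   sync.RWMutex
    docs []embeddings.EmbeddedDocument
}

func NewClient() *Client { return &Client{} }

// Upsert stores an embedded document in memory.
func (c *Client) Upsert(_ context.Context, doc embeddings.EmbeddedDocument) error {
    c.mu.Lock()
    defer c.mu.Unlock()
    c.docs = append(c.docs, doc)
    return nil
}

// cosine returns the cosine similarity of two equal-length vectors.
func cosine(a, b []float32) float64 {
    if len(a) != len(b) {
        return 0
    }
    var dot, na, nb float64
    for i := range a {
        dot += float64(a[i]) * float64(b[i])
        na += float64(a[i]) * float64(a[i])
        nb += float64(b[i]) * float64(b[i])
    }
    if na == 0 || nb == 0 {
        return 0
    }
    return dot / (math.Sqrt(na) * math.Sqrt(nb))
}

// Search returns the topK stored documents most similar to the query embedding.
func (c *Client) Search(_ context.Context, query []float32, topK int) ([]embeddings.EmbeddedDocument, error) {
    c.mu.RLock()
    defer c.mu.RUnlock()

    ranked := make([]embeddings.EmbeddedDocument, len(c.docs))
    copy(ranked, c.docs)
    sort.Slice(ranked, func(i, j int) bool {
        return cosine(ranked[i].Embedding, query) > cosine(ranked[j].Embedding, query)
    })
    if topK < len(ranked) {
        ranked = ranked[:topK]
    }
    return ranked, nil
}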

Example usage:

func main() {
    // Initialize services
    embeddingHandler := embeddings.NewEmbeddingHandler(...)
    vectorDB := vectorstore.NewClient(...)
    llmClient := llm.NewClient("http://localhost:8005/llm")

    ragService := rag.NewRAGService(embeddingHandler, vectorDB, llmClient)

    // API endpoint
    http.HandleFunc("/rag", func(w http.ResponseWriter, r *http.Request) {
        var req struct {
            Query string `json:"query"`
        }
        if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
            http.Error(w, err.Error(), http.StatusBadRequest)
            return
        }

        response, err := ragService.Query(r.Context(), req.Query)
        if err != nil {
            http.Error(w, err.Error(), http.StatusInternalServerError)
            return
        }

        json.NewEncoder(w).Encode(response)
    })

    http.ListenAndServe(":8080", nil)
}

So the flow is:

  1. Documents → Embedding Service → Vector DB
  2. Query → Embedding Service → Vector Search → Similar Documents
  3. Similar Documents → Format as Context → LLM Pipeline
  4. LLM Pipeline uses context in system message to generate informed response
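
The curl example below POSTs to a /documents endpoint that isn't registered in the main() above. Step 1 of the flow could be exposed with a handler along these lines, registered alongside /rag, assuming the StoreDocument helper from earlier and an Upsert method on the vector store (both part of this sketch rather than a finished API):

    // Inside main(), next to the /rag route; covers "Documents → Embedding Service → Vector DB"
    http.HandleFunc("/documents", func(w http.ResponseWriter, r *http.Request) {
        var doc embeddings.Document
        if err := json.NewDecoder(r.Body).Decode(&doc); err != nil {
            http.Error(w, err.Error(), http.StatusBadRequest)
            return
        }

        // 1. Embed the document content
        embedded, err := embeddingHandler.StoreDocument(doc)
        if err != nil {
            http.Error(w, err.Error(), http.StatusInternalServerError)
            return
        }

        // 2. Persist the embedded document in the vector DB
        if err := vectorDB.Upsert(r.Context(), *embedded); err != nil {
            http.Error(w, err.Error(), http.StatusInternalServerError)
            return
        }

        w.WriteHeader(http.StatusCreated)
        json.NewEncoder(w).Encode(embedded)
    })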

You'd use it like:

# First, store documents
curl -X POST "http://localhost:8080/documents" \
-d '{
    "content": "Dubai was founded in 1833...",
    "source": "history.txt"
}'

# Then query
curl -X POST "http://localhost:8080/rag" \
-d '{
    "query": "When was Dubai founded?"
}'

The LLM would receive messages like:

{
    "messages": [
        {
            "role": "system",
            "content": "Use the following information to answer the question:\n\nSource [history.txt]: Dubai was founded in 1833..."
        },
        {
            "role": "user",
            "content": "When was Dubai founded?"
        }
    ]
}
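
Finally, the llm.Client used by the RAG service is also assumed. A minimal sketch of it, POSTing the messages payload above to the existing LLM pipeline endpoint (the Response field name is a placeholder; the pipeline's actual response type should be used instead):

// llm/client.go (minimal sketch of the assumed llm.Client)
package llm

import (
    "bytes"
    "context"
    "encoding/json"
    "fmt"
    "net/http"
)

type Message struct {
    Role    string `json:"role"`
    Content string `json:"content"`
}

// Response is a placeholder; swap in the LLM pipeline's real response type.
type Response struct {
    Response string `json:"response"`
}

type Client struct {
    endpoint   string
    httpClient *http.Client
}

// NewClient points at the LLM pipeline, e.g. "http://localhost:8005/llm".
func NewClient(endpoint string) *Client {
    return &Client{endpoint: endpoint, httpClient: http.DefaultClient}
}

// Generate POSTs the chat messages to the pipeline and decodes its reply.
func (c *Client) Generate(ctx context.Context, messages []Message) (*Response, error) {
    payload, err := json.Marshal(map[string][]Message{"messages": messages})
    if err != nil {
        return nil, err
    }

    req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.endpoint, bytes.NewReader(payload))
    if err != nil {
        return nil, err
    }
    req.Header.Set("Content-Type", "application/json")

    resp, err := c.httpClient.Do(req)
    if err != nil {
        return nil, err
    }
    defer resp.Body.Close()

    if resp.StatusCode != http.StatusOK {
        return nil, fmt.Errorf("llm request failed: %s", resp.Status)
    }

    var out Response
    if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
        return nil, err
    }
    return &out, nil
}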