diff --git a/database/qdrant_client.go b/database/qdrant_client.go index 3d2ab13..2ad691e 100644 --- a/database/qdrant_client.go +++ b/database/qdrant_client.go @@ -4,6 +4,7 @@ import ( "context" // understand this and usage in file "fmt" "os" + "sync" "time" "github.com/google/uuid" @@ -12,7 +13,12 @@ import ( "github.com/rs/zerolog/log" ) -var collectionName = os.Getenv("QDRANT_COLLECTION") +var ( + qdrantClientInstance *qdrant.Client + qdrantClientOnce sync.Once + collectionName = os.Getenv("QDRANT_COLLECTION") + qdrantHost = os.Getenv("QDRANT_HOST") +) type ScoredPoint struct { Id *qdrant.PointId `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"` // Point id @@ -32,35 +38,33 @@ type GetOutputJSON struct { ModelResponse string `json:"model_response"` } -func InitializeQdrant() *qdrant.Client { +func initializeQdrant() (*qdrant.Client, error) { client, err := qdrant.NewClient(&qdrant.Config{ - Host: os.Getenv("QDRANT_HOST"), + Host: qdrantHost, Port: 6334, UseTLS: false, }) if err != nil { - panic(err) + return nil, err } - // Get a context for a minute ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() - // Execute health check healthCheckResult, err := client.HealthCheck(ctx) if err != nil { - log.Fatal().Msgf("Could not get health: %v", err) + return nil, err } - log.Printf("Qdrant version: %s", healthCheckResult.GetVersion()) + log.Info().Msgf("Qdrant version: %s", healthCheckResult.GetVersion()) - // check if collection exists exists, err := client.CollectionExists(context.Background(), collectionName) if err != nil { - log.Fatal().Msgf("Could not check if collection exists: %v", err) + return nil, err } + if exists { log.Info().Msgf("Collection %s exists", collectionName) - return client + return client, nil } err = client.CreateCollection(ctx, &qdrant.CreateCollection{ @@ -74,14 +78,28 @@ func InitializeQdrant() *qdrant.Client { Type: qdrant.QuantizationType_Int8, AlwaysRam: qdrant.PtrOf(true), }), + 
OptimizersConfig: &qdrant.OptimizersConfigDiff{ + DefaultSegmentNumber: qdrant.PtrOf(uint64(16)), // used to minimize latency; set to 2 to maximize throughput + }, }) if err != nil { - log.Fatal().Msgf("Could not create collection: %v", err) - } else { - log.Info().Msgf("Collection %s created", collectionName) + return nil, err } - return client + log.Info().Msgf("Collection %s created", collectionName) + return client, nil +} + +// GetQdrantClient returns a singleton instance of the Qdrant client +func GetQdrantClient() *qdrant.Client { + qdrantClientOnce.Do(func() { + var err error + qdrantClientInstance, err = initializeQdrant() + if err != nil { + log.Fatal().Err(err).Msg("Failed to initialize Qdrant client") + } + }) + return qdrantClientInstance } func GetQdrant(client *qdrant.Client, vectors []float32) ([]GetOutputJSON, error) { @@ -91,13 +109,16 @@ func GetQdrant(client *qdrant.Client, vectors []float32) ([]GetOutputJSON, error Query: qdrant.NewQueryDense(vectors), WithPayload: qdrant.NewWithPayloadInclude("model_response", "user_message"), ScoreThreshold: qdrant.PtrOf(float32(0.7)), // TODO: make this configurable + Params: &qdrant.SearchParams{ + Quantization: &qdrant.QuantizationSearchParams{ + Rescore: qdrant.PtrOf(true), // remove if results are inaccurate + }, + }, }) if err != nil { log.Fatal().Msgf("Could not search points: %v", err) } - client.Close() - log.Info().Msg("Searched points") var outputData []GetOutputJSON @@ -141,7 +162,5 @@ func PutQdrant(client *qdrant.Client, vectors []float32, message string, modelRe } fmt.Println("Upsert", len(upsertPoints), "points") - client.Close() - return operationInfo } diff --git a/handlers/handlers.go b/handlers/handlers.go index 3250335..bf3355a 100644 --- a/handlers/handlers.go +++ b/handlers/handlers.go @@ -30,7 +30,6 @@ type PutResponseBody struct { func HandleGetRequest(c *fiber.Ctx) error { c.Accepts("text/plain", "application/json") - c.Accepts("json", "text") log.Info().Msg("Handling GET request") // 
Parse the JSON body using Sonic @@ -60,7 +59,7 @@ func HandleGetRequest(c *fiber.Ctx) error { // query qdrant for response // initialize databases - qdrantClient := database.InitializeQdrant() + qdrantClient := database.GetQdrantClient() log.Info().Msg("Initialized Qdrant client") @@ -143,7 +142,7 @@ func HandlePutRequest(c *fiber.Ctx) error { // query qdrant for response // initialize databases - qdrantClient := database.InitializeQdrant() + qdrantClient := database.GetQdrantClient() log.Info().Msg("Initialized Qdrant client")