From 60ffbf5089278cfeffff90e7a5834801fb97d009 Mon Sep 17 00:00:00 2001
From: filou
Date: Mon, 16 Dec 2024 22:35:01 +0100
Subject: [PATCH] [Keyword search] Core node endpoint

Description
---
Fixes https://github.com/dust-tt/tasks/issues/1613

Risks
---
na (endpoint not used)

Deploy
---
core
---
 core/bin/core_api.rs                   | 46 ++++++++++++-
 core/src/data_sources/node.rs          |  6 ++
 core/src/search_stores/search_store.rs | 89 +++++++++++++++++++++++++-
 3 files changed, 139 insertions(+), 2 deletions(-)

diff --git a/core/bin/core_api.rs b/core/bin/core_api.rs
index c0efead2d2123..cca8f3419ed36 100644
--- a/core/bin/core_api.rs
+++ b/core/bin/core_api.rs
@@ -48,7 +48,9 @@ use dust::{
     providers::provider::{provider, ProviderID},
     run,
     search_filter::{Filterable, SearchFilter},
-    search_stores::search_store::{ElasticsearchSearchStore, SearchStore},
+    search_stores::search_store::{
+        DatasourceViewFilter, ElasticsearchSearchStore, NodesSearchOptions, SearchStore,
+    },
     sqlite_workers::client::{self, HEARTBEAT_INTERVAL_MS},
     stores::{
         postgres,
@@ -3074,6 +3076,45 @@ async fn folders_delete(
     }
 }
 
+#[derive(serde::Deserialize)]
+struct NodesSearchPayload {
+    query: String,
+    // filter: { datasource_id: string, view_filter: string[] }[]
+    filter: Vec<DatasourceViewFilter>,
+    options: Option<NodesSearchOptions>,
+}
+
+async fn nodes_search(
+    State(state): State<Arc<APIState>>,
+    Json(payload): Json<NodesSearchPayload>,
+) -> (StatusCode, Json<APIResponse>) {
+    let nodes = match state
+        .search_store
+        .search_nodes(payload.query, payload.filter, payload.options)
+        .await
+    {
+        Ok(nodes) => nodes,
+        Err(e) => {
+            return error_response(
+                StatusCode::INTERNAL_SERVER_ERROR,
+                "internal_server_error",
+                "Failed to search nodes",
+                Some(e),
+            );
+        }
+    };
+
+    (
+        StatusCode::OK,
+        Json(APIResponse {
+            error: None,
+            response: Some(json!({
+                "nodes": nodes,
+            })),
+        }),
+    )
+}
+
 #[derive(serde::Deserialize)]
 struct DatabaseQueryRunPayload {
     query: String,
@@ -3551,6 +3592,9 @@ fn main() {
             delete(folders_delete),
         )
 
+        // Search
+        .route("/nodes/search", post(nodes_search))
+
         // Misc
.route("/tokenize", post(tokenize)) .route("/tokenize/batch", post(tokenize_batch)) diff --git a/core/src/data_sources/node.rs b/core/src/data_sources/node.rs index 41e42b588b600..77bc7c20b587b 100644 --- a/core/src/data_sources/node.rs +++ b/core/src/data_sources/node.rs @@ -78,3 +78,9 @@ impl Node { ) } } + +impl From for Node { + fn from(value: serde_json::Value) -> Self { + serde_json::from_value(value).expect("Failed to deserialize Node from JSON value") + } +} diff --git a/core/src/search_stores/search_store.rs b/core/src/search_stores/search_store.rs index 554a2ad54f7ed..c6548e8ad43f5 100644 --- a/core/src/search_stores/search_store.rs +++ b/core/src/search_stores/search_store.rs @@ -3,20 +3,49 @@ use async_trait::async_trait; use elasticsearch::{ auth::Credentials, http::transport::{SingleNodeConnectionPool, TransportBuilder}, - Elasticsearch, IndexParts, + Elasticsearch, IndexParts, SearchParts, }; use rand::Rng; +use serde_json::json; use url::Url; use crate::data_sources::node::Node; use crate::{data_sources::data_source::Document, utils}; use tracing::{error, info}; + +#[derive(serde::Deserialize)] +pub struct NodesSearchOptions { + limit: Option, + offset: Option, +} + +#[derive(serde::Deserialize)] +pub struct DatasourceViewFilter { + datasource_id: String, + view_filter: Vec, +} + #[async_trait] pub trait SearchStore { + async fn search_nodes( + &self, + query: String, + filter: Vec, + options: Option, + ) -> Result>; async fn index_document(&self, document: &Document) -> Result<()>; fn clone_box(&self) -> Box; } +impl Default for NodesSearchOptions { + fn default() -> Self { + NodesSearchOptions { + limit: Some(10), + offset: Some(0), + } + } +} + impl Clone for Box { fn clone(&self) -> Self { self.clone_box() @@ -88,6 +117,64 @@ impl SearchStore for ElasticsearchSearchStore { } } + async fn search_nodes( + &self, + query: String, + filter: Vec, + options: Option, + ) -> Result> { + // First, collect all datasource_ids and their corresponding 
view_filters + let mut filter_conditions = Vec::new(); + for f in filter { + filter_conditions.push(json!({ + "bool": { + "must": [ + { "term": { "data_source_id": f.datasource_id } }, + { "terms": { "parents": f.view_filter } } + ] + } + })); + } + + let options = options.unwrap_or_default(); + + // then, search + match self + .client + .search(SearchParts::Index(&[NODES_INDEX_NAME])) + .from(options.offset.unwrap_or(0) as i64) + .size(options.limit.unwrap_or(100) as i64) + .body(json!({ + "query": { + "bool": { + "must": { + "match": { + "title.edge": query + } + }, + "should": filter_conditions, + "minimum_should_match": 1 + } + } + })) + .send() + .await + { + Ok(response) => { + // get nodes from elasticsearch response in hits.hits + let response_body = response.json::().await?; + let nodes: Vec = response_body["hits"]["hits"] + .as_array() + .unwrap() + .iter() + .map(|h| Node::from(h.get("source").unwrap().clone())) + .collect(); + Ok(nodes) + } + Err(e) => Err(e.into()), + } + } + fn clone_box(&self) -> Box { Box::new(self.clone()) }