Skip to content

Commit

Permalink
[Keyword search] Core node endpoint
Browse files Browse the repository at this point in the history
Description
---
Fixes dust-tt/tasks#1613

Risks
---
na (endpoint not used)

Deploy
---
core
  • Loading branch information
philipperolet committed Dec 17, 2024
1 parent 64190f2 commit 60ffbf5
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 2 deletions.
46 changes: 45 additions & 1 deletion core/bin/core_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ use dust::{
providers::provider::{provider, ProviderID},
run,
search_filter::{Filterable, SearchFilter},
search_stores::search_store::{ElasticsearchSearchStore, SearchStore},
search_stores::search_store::{
DatasourceViewFilter, ElasticsearchSearchStore, NodesSearchOptions, SearchStore,
},
sqlite_workers::client::{self, HEARTBEAT_INTERVAL_MS},
stores::{
postgres,
Expand Down Expand Up @@ -3074,6 +3076,45 @@ async fn folders_delete(
}
}

/// Request payload for the `POST /nodes/search` endpoint.
#[derive(serde::Deserialize)]
struct NodesSearchPayload {
    // Keyword query matched against node titles.
    query: String,
    // filter: { datasource_id: string, view_filter: string[] }[]
    filter: Vec<DatasourceViewFilter>,
    // Pagination options; backend defaults are applied when omitted.
    options: Option<NodesSearchOptions>,
}

/// Handler for `POST /nodes/search`: runs a keyword search over nodes via the
/// configured search store and returns the matching nodes as JSON.
///
/// Returns 200 with `{ "nodes": [...] }` on success, or a 500
/// `internal_server_error` response if the search store fails.
async fn nodes_search(
    State(state): State<Arc<APIState>>,
    Json(payload): Json<NodesSearchPayload>,
) -> (StatusCode, Json<APIResponse>) {
    let search_result = state
        .search_store
        .search_nodes(payload.query, payload.filter, payload.options)
        .await;

    match search_result {
        Ok(nodes) => (
            StatusCode::OK,
            Json(APIResponse {
                error: None,
                response: Some(json!({
                    "nodes": nodes,
                })),
            }),
        ),
        Err(e) => error_response(
            StatusCode::INTERNAL_SERVER_ERROR,
            "internal_server_error",
            "Failed to search nodes",
            Some(e),
        ),
    }
}

#[derive(serde::Deserialize)]
struct DatabaseQueryRunPayload {
query: String,
Expand Down Expand Up @@ -3551,6 +3592,9 @@ fn main() {
delete(folders_delete),
)

// Search
.route("/nodes/search", post(nodes_search))

// Misc
.route("/tokenize", post(tokenize))
.route("/tokenize/batch", post(tokenize_batch))
Expand Down
6 changes: 6 additions & 0 deletions core/src/data_sources/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,9 @@ impl Node {
)
}
}

/// Converts a raw JSON value (presumably a search-index document — TODO
/// confirm against the search store's hit parsing) into a `Node`.
///
/// NOTE(review): `From` cannot return an error, so this panics if the JSON
/// does not match the `Node` schema. If inputs can be malformed (e.g. index
/// mapping drift), consider implementing `TryFrom` instead.
impl From<serde_json::Value> for Node {
    fn from(value: serde_json::Value) -> Self {
        serde_json::from_value(value).expect("Failed to deserialize Node from JSON value")
    }
}
89 changes: 88 additions & 1 deletion core/src/search_stores/search_store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,49 @@ use async_trait::async_trait;
use elasticsearch::{
auth::Credentials,
http::transport::{SingleNodeConnectionPool, TransportBuilder},
Elasticsearch, IndexParts,
Elasticsearch, IndexParts, SearchParts,
};
use rand::Rng;
use serde_json::json;
use url::Url;

use crate::data_sources::node::Node;
use crate::{data_sources::data_source::Document, utils};
use tracing::{error, info};

/// Pagination options for `SearchStore::search_nodes`. Both fields are
/// optional so callers can omit them in the JSON payload; see the `Default`
/// impl for fallback values.
#[derive(serde::Deserialize)]
pub struct NodesSearchOptions {
    // Maximum number of results to return.
    limit: Option<usize>,
    // Number of results to skip (pagination offset).
    offset: Option<usize>,
}

/// Restricts a node search to a single datasource and a set of allowed
/// parents (a node matches when any of its `parents` is in `view_filter`).
#[derive(serde::Deserialize)]
pub struct DatasourceViewFilter {
    // ID of the datasource to search within.
    datasource_id: String,
    // Parent node IDs defining the visible subset of the datasource.
    view_filter: Vec<String>,
}

/// Abstraction over the search backend (Elasticsearch in the provided
/// implementation). Boxed trait objects must be cloneable via `clone_box`.
#[async_trait]
pub trait SearchStore {
    /// Keyword-search nodes matching `query`, restricted by the per-datasource
    /// `filter` list, with optional pagination `options`.
    async fn search_nodes(
        &self,
        query: String,
        filter: Vec<DatasourceViewFilter>,
        options: Option<NodesSearchOptions>,
    ) -> Result<Vec<Node>>;
    /// Index (or re-index) a document in the search backend.
    async fn index_document(&self, document: &Document) -> Result<()>;
    /// Clone into a new boxed trait object (enables `Clone` for `Box<dyn SearchStore>`).
    fn clone_box(&self) -> Box<dyn SearchStore + Sync + Send>;
}

impl Default for NodesSearchOptions {
fn default() -> Self {
NodesSearchOptions {
limit: Some(10),
offset: Some(0),
}
}
}

impl Clone for Box<dyn SearchStore + Sync + Send> {
fn clone(&self) -> Self {
self.clone_box()
Expand Down Expand Up @@ -88,6 +117,64 @@ impl SearchStore for ElasticsearchSearchStore {
}
}

async fn search_nodes(
&self,
query: String,
filter: Vec<DatasourceViewFilter>,
options: Option<NodesSearchOptions>,
) -> Result<Vec<Node>> {
// First, collect all datasource_ids and their corresponding view_filters
let mut filter_conditions = Vec::new();
for f in filter {
filter_conditions.push(json!({
"bool": {
"must": [
{ "term": { "data_source_id": f.datasource_id } },
{ "terms": { "parents": f.view_filter } }
]
}
}));
}

let options = options.unwrap_or_default();

// then, search
match self
.client
.search(SearchParts::Index(&[NODES_INDEX_NAME]))
.from(options.offset.unwrap_or(0) as i64)
.size(options.limit.unwrap_or(100) as i64)
.body(json!({
"query": {
"bool": {
"must": {
"match": {
"title.edge": query
}
},
"should": filter_conditions,
"minimum_should_match": 1
}
}
}))
.send()
.await
{
Ok(response) => {
// get nodes from elasticsearch response in hits.hits
let response_body = response.json::<serde_json::Value>().await?;
let nodes: Vec<Node> = response_body["hits"]["hits"]
.as_array()
.unwrap()
.iter()
.map(|h| Node::from(h.get("source").unwrap().clone()))
.collect();
Ok(nodes)
}
Err(e) => Err(e.into()),
}
}

    /// Clone this store into a fresh boxed trait object; backs the
    /// `Clone` impl for `Box<dyn SearchStore + Sync + Send>`.
    fn clone_box(&self) -> Box<dyn SearchStore + Sync + Send> {
        Box::new(self.clone())
    }
Expand Down

0 comments on commit 60ffbf5

Please sign in to comment.