From 3f74cc5b0540bf67adc5700408cfa60538d5ae09 Mon Sep 17 00:00:00 2001 From: Luc Date: Tue, 3 Dec 2024 18:54:43 +0100 Subject: [PATCH] Streamline api, Introduce item re-indexing, and introduce intelligence module --- engine/.env.example | 4 + engine/Cargo.lock | 38 ++++++++++ engine/Cargo.toml | 1 + engine/src/intelligence/mod.rs | 40 ++++++++++ engine/src/main.rs | 1 + engine/src/models/item/mod.rs | 6 ++ engine/src/models/item/search.rs | 28 +++++-- engine/src/models/media.rs | 21 ++++- engine/src/routes/instance.rs | 4 +- engine/src/routes/me.rs | 4 +- engine/src/routes/media/mod.rs | 101 +++++++++++++++++++++++++ engine/src/routes/mod.rs | 26 +++---- engine/src/routes/root.rs | 67 ---------------- engine/src/routes/search/mod.rs | 27 +++++-- engine/src/routes/sessions/mod.rs | 4 +- engine/src/routes/users/mod.rs | 4 +- engine/src/search/mod.rs | 82 ++++++++++++++++++-- engine/src/state.rs | 14 +++- web/src/routes/settings/index.lazy.tsx | 24 +++++- 19 files changed, 380 insertions(+), 116 deletions(-) create mode 100644 engine/src/intelligence/mod.rs create mode 100644 engine/src/routes/media/mod.rs delete mode 100644 engine/src/routes/root.rs diff --git a/engine/.env.example b/engine/.env.example index a6a0f78..29ea329 100644 --- a/engine/.env.example +++ b/engine/.env.example @@ -10,3 +10,7 @@ DATABASE_URL=postgres://postgres:postgres@localhost:5432/property # Meilisearch MEILISEARCH_URL=http://localhost:7700 MEILISEARCH_MASTER_KEY=master + +# Ollama +OLLAMA_URL=http://localhost +OLLAMA_PORT=11434 diff --git a/engine/Cargo.lock b/engine/Cargo.lock index 86fc92a..aa68377 100644 --- a/engine/Cargo.lock +++ b/engine/Cargo.lock @@ -249,6 +249,28 @@ dependencies = [ "wasm-bindgen-futures", ] +[[package]] +name = "async-stream" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b5a71a6f37880a80d1d7f19efd781e4b5de42c88f0722cc13bcb6cc2cfe8476" +dependencies = [ + "async-stream-impl", + "futures-core", + "pin-project-lite", +] + +[[package]] +name = "async-stream-impl" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.89", +] + [[package]] name = "async-task" version = "4.7.1" @@ -1839,6 +1861,21 @@ dependencies = [ "memchr", ] +[[package]] +name = "ollama-rs" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "46483ac9e1f9e93da045b5875837ca3c9cf014fd6ab89b4d9736580ddefc4759" +dependencies = [ + "async-stream", + "async-trait", + "log", + "reqwest", + "serde", + "serde_json", + "url", +] + [[package]] name = "once_cell" version = "1.19.0" @@ -3837,6 +3874,7 @@ dependencies = [ "hex", "hmac", "meilisearch-sdk", + "ollama-rs", "openid", "poem", "poem-openapi", diff --git a/engine/Cargo.toml b/engine/Cargo.toml index 139cec4..47589ed 100644 --- a/engine/Cargo.toml +++ b/engine/Cargo.toml @@ -16,6 +16,7 @@ dotenvy = "0.15.7" hex = "0.4.3" hmac = "0.12.1" meilisearch-sdk = "0.27.1" +ollama-rs = "0.2.1" openid = "0.14.0" poem = { version = "3.0.4", git = "https://github.com/poem-web/poem", branch = "master" } poem-openapi = { version = "5", git = "https://github.com/poem-web/poem", branch = "master", features = [ diff --git a/engine/src/intelligence/mod.rs b/engine/src/intelligence/mod.rs new file mode 100644 index 0000000..ac83f6b --- /dev/null +++ b/engine/src/intelligence/mod.rs @@ -0,0 +1,40 @@ +use std::env; + +use ollama_rs::Ollama; +use tracing::info; + +pub struct Intelligence { + pub ollama: Ollama, +} + +impl Intelligence { + pub async fn new(url: String, port: u16) -> Result { + let ollama = Ollama::new(url, port); + + let models: Vec = ollama + .list_local_models() + .await? + .iter() + .map(|m| m.name.clone()) + .collect(); + + info!("Ollama models detected: {:?}", models); + + Ok(Self { + ollama, + }) + } + + pub async fn guess() -> Result { + let url = env::var("OLLAMA_URL") + .map_err(|_| anyhow::anyhow!("OLLAMA_URL is not set"))?; + let port = env::var("OLLAMA_PORT") + .map_err(|_| anyhow::anyhow!("OLLAMA_PORT is not set"))? + .parse::() + .map_err(|_| anyhow::anyhow!("OLLAMA_PORT is not a valid u16"))?; + + Self::new(url, port) + .await + .map_err(anyhow::Error::from) + } +} diff --git a/engine/src/main.rs b/engine/src/main.rs index 4a8e843..2c1d3e4 100644 --- a/engine/src/main.rs +++ b/engine/src/main.rs @@ -8,6 +8,7 @@ mod models; mod routes; mod state; mod search; +mod intelligence; #[async_std::main] async fn main() { diff --git a/engine/src/models/item/mod.rs b/engine/src/models/item/mod.rs index 009f817..fa6e469 100644 --- a/engine/src/models/item/mod.rs +++ b/engine/src/models/item/mod.rs @@ -60,6 +60,12 @@ impl Item { .await } + pub async fn get_all(db: &Database) -> Result, sqlx::Error> { + query_as!(Item, "SELECT * FROM items") + .fetch_all(&db.pool) + .await + } + pub async fn get_by_owner_id( database: &Database, owner_id: i32 diff --git a/engine/src/models/item/search.rs b/engine/src/models/item/search.rs index a9a6d2a..11de5b3 100644 --- a/engine/src/models/item/search.rs +++ b/engine/src/models/item/search.rs @@ -13,18 +13,31 @@ pub struct SearchableItem { pub owner_id: Option, pub location_id: Option, // TODO: add more location info - pub fields: Vec, + pub fields: Option>, pub created_at: Option>, pub updated_at: Option>, + + #[serde(rename = "_vectors")] + pub vectors: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Object)] +pub struct SearchableItemVectors { + pub ollama: SearchableItemVectorsOllama, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Object)] +pub struct SearchableItemVectorsOllama { + pub regenerate: bool, } impl Item { pub async fn into_search(&self, db: &Database) -> Result { - let fields = ItemField::get_by_item_id(db, &self.item_id) - .await? - .iter() - .map(|field| field.into()) - .collect(); + let fields = Some(ItemField::get_by_item_id(db, &self.item_id) + .await? + .iter() + .map(|field| field.into()) + .collect()); Ok(SearchableItem { item_id: self.item_id.clone(), @@ -35,6 +48,9 @@ impl Item { fields, created_at: self.created_at, updated_at: self.updated_at, + vectors: Some(SearchableItemVectors { + ollama: SearchableItemVectorsOllama { regenerate: true }, + }), }) } } diff --git a/engine/src/models/media.rs b/engine/src/models/media.rs index c20d6c0..4388a44 100644 --- a/engine/src/models/media.rs +++ b/engine/src/models/media.rs @@ -1,11 +1,11 @@ use chrono::{DateTime, Utc}; use poem_openapi::Object; use serde::{Deserialize, Serialize}; -use sqlx::{query_as, FromRow}; +use sqlx::{query, query_as, FromRow}; use crate::database::Database; -#[derive(FromRow, Object, Debug, Clone, Serialize, Deserialize)] +#[derive(FromRow, Object, Debug, Clone, Serialize, Deserialize, Default)] pub struct Media { pub media_id: i32, pub description: Option, @@ -33,9 +33,22 @@ impl Media { .await } - pub async fn get_by_id(db: &Database, media_id: i32) -> Result { + pub async fn get_by_id(db: &Database, media_id: i32) -> Result, sqlx::Error> { query_as!(Media, "SELECT * FROM media WHERE media_id = $1", media_id) - .fetch_one(&db.pool) + .fetch_optional(&db.pool) .await } + + pub async fn get_all(db: &Database) -> Result, sqlx::Error> { + query_as!(Media, "SELECT * FROM media") + .fetch_all(&db.pool) + .await + } + + pub async fn delete(self, db: &Database) -> Result<(), sqlx::Error> { + query!("DELETE FROM media WHERE media_id = $1", self.media_id) + .execute(&db.pool) + .await + .map(|_| ()) + } } diff --git a/engine/src/routes/instance.rs b/engine/src/routes/instance.rs index d6248eb..1b8c951 100644 --- a/engine/src/routes/instance.rs +++ b/engine/src/routes/instance.rs @@ -6,7 +6,7 @@ use serde::{Deserialize, Serialize}; use crate::{auth::middleware::AuthToken, state::AppState}; -pub struct ApiInstance; +pub struct InstanceApi; #[derive(Serialize, Deserialize, Enum)] pub enum IdCasingPreference { @@ -32,7 +32,7 @@ impl Default for InstanceSettings { } #[OpenApi] -impl ApiInstance { +impl InstanceApi { #[oai(path = "/instance/settings", method = "get")] pub async fn settings( &self, diff --git a/engine/src/routes/me.rs b/engine/src/routes/me.rs index 718cc58..d208d4f 100644 --- a/engine/src/routes/me.rs +++ b/engine/src/routes/me.rs @@ -9,10 +9,10 @@ use crate::{ state::AppState, }; -pub struct ApiMe; +pub struct MeApi; #[OpenApi] -impl ApiMe { +impl MeApi { #[oai(path = "/me", method = "get")] pub async fn me(&self, state: Data<&Arc>, token: AuthToken) -> Json { match token { diff --git a/engine/src/routes/media/mod.rs b/engine/src/routes/media/mod.rs new file mode 100644 index 0000000..81939eb --- /dev/null +++ b/engine/src/routes/media/mod.rs @@ -0,0 +1,101 @@ +use std::sync::Arc; + +use poem::{ + web::{Data, Path}, + Result, +}; +use poem_openapi::{payload::Json, Object, OpenApi}; +use reqwest::StatusCode; +use serde::{Deserialize, Serialize}; + +use crate::{ + auth::middleware::AuthToken, + models::media::Media, + state::AppState, +}; + +pub struct MediaApi; + +#[derive(Deserialize, Debug, Serialize, Object)] +pub struct MediaIdResponse { + media_id: String, +} + +#[derive(Deserialize, Debug, Serialize, Object)] +pub struct CreateMediaRequest { + name: Option, + kind: Option, +} + +#[OpenApi] +impl MediaApi { + #[oai(path = "/media/:media_id", method = "get")] + async fn get_media( + &self, + state: Data<&Arc>, + auth: AuthToken, + media_id: Path, + ) -> Result> { + Media::get_by_id(&state.database, media_id.0) + .await + .or(Err(poem::Error::from_status( + StatusCode::INTERNAL_SERVER_ERROR, + )))? + .ok_or(poem::Error::from_status(StatusCode::NOT_FOUND)) + .map(|x| Json(x)) + } + + #[oai(path = "/media/:media_id", method = "delete")] + async fn delete_media( + &self, + auth: AuthToken, + state: Data<&Arc>, + media_id: Path, + ) -> Result<()> { + Media::get_by_id(&state.database, media_id.0) + .await + .unwrap() + .unwrap() + .delete(&state.database) + .await + .unwrap(); + + Ok(()) + } + + #[oai(path = "/media", method = "get")] + async fn get_all_media( + &self, + auth: AuthToken, + state: Data<&Arc>, + ) -> Result>> { + match auth.ok() { + Some(user) => Ok(Json( + Media::get_all(&state.database) + .await + .unwrap(), + )), + None => Err(StatusCode::UNAUTHORIZED.into()), + } + } + + // #[oai(path = "/media", method = "post")] + // async fn create_media( + // &self, + // auth: AuthToken, + // state: Data<&Arc>, + // request: Query, + // ) -> Json { + // Json( + // Media { + // ..Default::default() + // } + // .insert(&state.database) + // .await + // .unwrap() + // .index_search(&state.search, &state.database) + // .await + // .unwrap(), + // ) + // } +} diff --git a/engine/src/routes/mod.rs b/engine/src/routes/mod.rs index ac51d71..3f404b9 100644 --- a/engine/src/routes/mod.rs +++ b/engine/src/routes/mod.rs @@ -1,16 +1,16 @@ use std::sync::Arc; -use instance::ApiInstance; +use instance::InstanceApi; use items::ItemsApi; -use me::ApiMe; +use me::MeApi; +use media::MediaApi; use poem::{ get, handler, listener::TcpListener, middleware::Cors, web::Html, EndpointExt, Route, Server, }; use poem_openapi::{OpenApi, OpenApiService}; -use root::RootApi; -use search::{tasks::ApiSearchTask, ApiSearch}; -use sessions::ApiSessions; -use users::ApiUserById; +use search::{tasks::ApiSearchTask, SearchApi}; +use sessions::SessionsApi; +use users::UserApi; use crate::state::AppState; @@ -18,21 +18,21 @@ pub mod instance; pub mod items; pub mod me; pub mod oauth; -pub mod root; pub mod search; pub mod sessions; pub mod users; +pub mod media; fn get_api() -> impl OpenApi { ( - RootApi, - ApiMe, - ApiSessions, - ApiUserById, - ApiInstance, + MeApi, + SessionsApi, + UserApi, + InstanceApi, ItemsApi, - ApiSearch, + SearchApi, ApiSearchTask, + MediaApi ) } diff --git a/engine/src/routes/root.rs b/engine/src/routes/root.rs deleted file mode 100644 index e4335a6..0000000 --- a/engine/src/routes/root.rs +++ /dev/null @@ -1,67 +0,0 @@ -use std::sync::Arc; - -use poem::{ - web::{Data, Path}, -}; -use poem_openapi::{ - param::Query, - payload::{Json, PlainText}, - OpenApi, -}; - -use crate::{ - models::{item::Item, media::Media, products::Product}, - state::AppState, -}; - -pub struct RootApi; - -#[OpenApi] -impl RootApi { - /// Testing one two three - #[oai(path = "/hello", method = "get")] - async fn index(&self, name: Query>) -> PlainText { - match name.0 { - Some(name) => PlainText(format!("Hello, {}!", name)), - None => PlainText("Hello, World!".to_string()), - } - } - - #[oai(path = "/media/:media_id", method = "get")] - async fn get_media( - &self, - state: Data<&Arc>, - media_id: Path, - ) -> poem_openapi::payload::Json { - let media = Media::get_by_id(&state.database, media_id.0).await.unwrap(); - - poem_openapi::payload::Json(media) - } - - #[oai(path = "/product/:product_id", method = "get")] - async fn get_product( - &self, - state: Data<&Arc>, - product_id: Path, - ) -> poem_openapi::payload::Json { - let product = Product::get_by_id(&state.database, product_id.0) - .await - .unwrap(); - - poem_openapi::payload::Json(product.unwrap()) - } - - // #[oai(path = "/item/:item_id", method = "get")] - // async fn get_item( - // &self, - // state: Data<&Arc>, - // item_id: Path, - // ) -> Result> { - // let item = Item::get_by_id(&state.database, item_id.0).await.unwrap(); - - // match item { - // Some(item) => Ok(Json(item)), - // None => Err(StatusCode::NOT_FOUND.into()), - // } - // } -} diff --git a/engine/src/routes/search/mod.rs b/engine/src/routes/search/mod.rs index cebe2fa..cdfa954 100644 --- a/engine/src/routes/search/mod.rs +++ b/engine/src/routes/search/mod.rs @@ -1,21 +1,18 @@ use std::sync::Arc; -use meilisearch_sdk::search::SearchQuery; -use poem::web::{Data, Path, Query}; -use poem::{Error, Result}; +use poem::web::{Data, Query}; +use poem::Result; use poem_openapi::Object; use poem_openapi::{payload::Json, OpenApi}; -use reqwest::StatusCode; use serde::{Deserialize, Serialize}; use tracing::info; use crate::models::item::search::SearchableItem; -use crate::models::item::Item; -use crate::{models::search::SearchTask, state::AppState}; +use crate::state::AppState; pub mod tasks; -pub struct ApiSearch; +pub struct SearchApi; #[derive(Debug, Serialize, Deserialize, Object)] pub struct SearchQueryParams { @@ -30,7 +27,7 @@ pub struct SearchQueryParams { // } #[OpenApi] -impl ApiSearch { +impl SearchApi { #[oai(path = "/search", method = "get")] pub async fn search( &self, @@ -52,4 +49,18 @@ impl ApiSearch { Json(results) } + + #[oai(path = "/search/index", method = "post")] + pub async fn index_all_items(&self, state: Data<&Arc>) -> Result<()> { + info!("Indexing all items"); + state + .search + .as_ref() + .unwrap() + .index_all_items(&state.database) + .await + .unwrap(); + + Ok(()) + } } diff --git a/engine/src/routes/sessions/mod.rs b/engine/src/routes/sessions/mod.rs index e55d2b4..f41d15d 100644 --- a/engine/src/routes/sessions/mod.rs +++ b/engine/src/routes/sessions/mod.rs @@ -9,10 +9,10 @@ use crate::{auth::middleware::AuthToken, models::sessions::Session, state::AppSt pub mod delete; -pub struct ApiSessions; +pub struct SessionsApi; #[OpenApi] -impl ApiSessions { +impl SessionsApi { #[oai(path = "/sessions", method = "get")] async fn get_sessions( &self, diff --git a/engine/src/routes/users/mod.rs b/engine/src/routes/users/mod.rs index 26b96de..3b8ea88 100644 --- a/engine/src/routes/users/mod.rs +++ b/engine/src/routes/users/mod.rs @@ -8,10 +8,10 @@ use crate::{ state::AppState, }; -pub struct ApiUserById; +pub struct UserApi; #[OpenApi] -impl ApiUserById { +impl UserApi { #[oai(path = "/user/:id", method = "get")] pub async fn user(&self, state: Data<&Arc>, id: Path) -> Json { let user = UserEntry::find_by_user_id(id.0, &state.database) diff --git a/engine/src/search/mod.rs b/engine/src/search/mod.rs index 60b6271..d1f1497 100644 --- a/engine/src/search/mod.rs +++ b/engine/src/search/mod.rs @@ -1,12 +1,16 @@ use std::env; -use bigdecimal::ToPrimitive; -use meilisearch_sdk::{client::Client, tasks::Task}; -use tracing::{info, warn}; +use meilisearch_sdk::{client::Client, settings::Settings}; +use serde_json::json; +use tracing::info; use crate::{ database::Database, - models::{item::search::SearchableItem, search::SearchTask}, + intelligence::Intelligence, + models::{ + item::{search::SearchableItem, Item}, + search::SearchTask, + }, }; pub struct Search { @@ -17,22 +21,64 @@ impl Search { pub async fn new( url: String, master_key: Option, + intelligence: &Option, ) -> Result { - let client = Client::new(url, master_key)?; + let client = Client::new(url.clone(), master_key.clone())?; let health = client.health().await?; info!("Meilisearch is healthy: {:?}", health); + let skip_embeddings = env::var("MEILISEARCH_SKIP_EMBEDDINGS") + .unwrap_or("false".to_string()) + .to_lowercase() + == "true"; + + if !skip_embeddings { + if let Some(intelligence) = intelligence { + let x = intelligence + .ollama + .pull_model("all-minilm".to_string(), false) + .await + .unwrap(); + + info!("Model pulled: {:?}", x); + + // TODO: check if embeddings are already enabled on meilisearch + // you have to manually enable vectorStore (experimental feature) + let response = reqwest::Client::new() + .patch(format!("{}/indexes/items/settings", url)) + .json(&json!({ + "embedders": { + "ollama": { + "source": "ollama", + "url": format!("{}/api/embeddings", intelligence.ollama.url()), + "model": "all-minilm", + "documentTemplate": "{{doc.name}}", + "dimensions": 384 + } + } + })) + .header( + "Authorization", + format!("Bearer {}", &master_key.clone().unwrap_or("".to_string())), + ) + .send() + .await?; + + info!("Embeddings enabled: {:?}", response); + } + } + Ok(Self { client }) } - pub async fn guess() -> Result { + pub async fn guess(intelligence: &Option) -> Result { let url = env::var("MEILISEARCH_URL") .map_err(|_| anyhow::anyhow!("MEILISEARCH_URL is not set"))?; let master_key = env::var("MEILISEARCH_MASTER_KEY").ok(); - Self::new(url, master_key) + Self::new(url, master_key, intelligence) .await .map_err(anyhow::Error::from) } @@ -52,6 +98,28 @@ impl Search { Ok(()) } + pub async fn index_all_items(&self, db: &Database) -> Result<(), ()> { + let items = Item::get_all(db).await.unwrap(); + + // batch by 10 (artificial TODO: implement sql paging) + let batches = items.chunks(10); + + for batch in batches { + let x = self + .client + .index("items") + .add_documents(batch, Some("item_id")) + .await + .unwrap(); + + SearchTask::new(db, x.task_uid, x.status.into()) + .await + .unwrap(); + } + + Ok(()) + } + pub async fn refresh_task( &self, db: &Database, diff --git a/engine/src/state.rs b/engine/src/state.rs index e8c25ee..45dc0fc 100644 --- a/engine/src/state.rs +++ b/engine/src/state.rs @@ -4,12 +4,13 @@ use openid::DiscoveredClient; use reqwest::Url; use tracing::warn; -use crate::{auth::oauth::OpenIDClient, database::Database, search::Search}; +use crate::{auth::oauth::OpenIDClient, database::Database, intelligence::Intelligence, search::Search}; pub struct AppState { pub database: Database, // #[cfg(feature = "oauth")] pub openid: OpenIDClient, + pub intelligence: Option, pub search: Option, } @@ -39,7 +40,15 @@ impl AppState { .unwrap() }; - let search = match Search::guess().await { + let intelligence = match Intelligence::guess().await { + Ok(intelligence) => Some(intelligence), + Err(e) => { + warn!("Failed to initialize intelligence: {}", e); + None + } + }; + + let search = match Search::guess(&intelligence).await { Ok(search) => Some(search), Err(e) => { warn!("Failed to initialize search: {}", e); @@ -50,6 +59,7 @@ impl AppState { Self { database, openid, + intelligence, search, } } diff --git a/web/src/routes/settings/index.lazy.tsx b/web/src/routes/settings/index.lazy.tsx index be2ea58..d6701e7 100644 --- a/web/src/routes/settings/index.lazy.tsx +++ b/web/src/routes/settings/index.lazy.tsx @@ -1,16 +1,38 @@ +import { useMutation } from '@tanstack/react-query'; import { createLazyFileRoute } from '@tanstack/react-router'; +import { useAuth } from '@/api/auth'; +import { BASE_URL } from '@/api/core'; import { useInstanceSettings } from '@/api/instance_settings'; import { SearchTaskTable } from '@/components/search_tasks/SearchTaskTable'; +import { Button } from '@/components/ui/Button'; import { SCPage } from '@/layouts/SimpleCenterPage'; export const Route = createLazyFileRoute('/settings/')({ component: () => { const { data: instanceSettings } = useInstanceSettings(); + const { token } = useAuth(); + const { mutate: indexAllItems } = useMutation({ + mutationFn: async () => { + const response = await fetch(BASE_URL + '/api/search/index', { + method: 'POST', + headers: { + Authorization: `Bearer ${token}`, + }, + }); + + return response.json(); + }, + }); return ( -
Hello /settings!
+
+ Hello /settings! + +

Instance Settings