-
Notifications
You must be signed in to change notification settings - Fork 304
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* feat: surrealdb integration * docs: update README * refactor: simplify integration * readme(rig-surrealdb): add link * refactor: enable using mem as db * refactor: only use mem db from local * fix(rig-surrealdb): crate info * chore: satisfy ci * chore: ci * refactor: amendments * refactor: amendments * chore: satisfy ci
- Loading branch information
1 parent
44c3971
commit 8229c08
Showing
8 changed files
with
2,260 additions
and
525 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,4 +10,5 @@ members = [ | |
"rig-core/rig-core-derive", | ||
"rig-sqlite", | ||
"rig-eternalai", "rig-fastembed", | ||
"rig-surrealdb", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
[package] | ||
name = "rig-surrealdb" | ||
version = "0.1.0" | ||
edition = "2021" | ||
|
||
[dependencies] | ||
surrealdb = { version = "2.1.4", features = ["protocol-ws", "kv-mem"] } | ||
rig-core = { path = "../rig-core", version = "0.9.0", features = ["derive"] } | ||
serde = { version = "1.0", features = ["derive"] } | ||
serde_json = "1.0" | ||
tracing = "0.1" | ||
uuid = { version = "1.13.1", features = ["v4"] } | ||
|
||
[dev-dependencies] | ||
anyhow = "1.0.86" | ||
tokio = { version = "1.38.0", features = ["macros", "rt-multi-thread"] } | ||
tracing-subscriber = { version = "0.3", features = ["env-filter"] } | ||
|
||
[[example]] | ||
name = "vector_search_surreal" | ||
required-features = ["rig-core/derive"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
# Rig SurrealDB integration | ||
This crate integrates SurrealDB into Rig, allowing you to easily use RAG with this database. | ||
|
||
## Installation | ||
To install this crate, run the following command in a Rust project directory. It adds `rig-surrealdb` as a dependency (`rig-core` must also be added for the intended usage):
```bash | ||
cargo add rig-surrealdb | ||
``` | ||
|
||
There are a few different ways you can run SurrealDB:
- [Install it locally and run it](https://surrealdb.com/docs/surrealdb/installation/linux) | ||
- [Through a Docker container, either locally or on Docker-compatible architecture](https://surrealdb.com/docs/surrealdb/installation/running/docker) | ||
- `docker run --rm --pull always -p 8000:8000 surrealdb/surrealdb:latest start --username root --password root` starts up a SurrealDB instance at port 8000 with the username and password as "root". | ||
- [Using SurrealDB's cloud offering](https://surrealdb.com/cloud) | ||
- Using the cloud offering you can manage your SurrealDB instance through their web UI. | ||
|
||
## How to run the example | ||
To run the example, add your OpenAI API key as an environment variable: | ||
```bash | ||
export OPENAI_API_KEY=my_key | ||
``` | ||
|
||
Then run the example with the following command:
```bash | ||
cargo run --example vector_search_surreal --features rig-core/derive | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
-- Define the table & fields used by the vector store.
DEFINE TABLE documents SCHEMAFULL;
-- The store writes the document as a JSON-encoded *string* and parses it back on read,
-- so the field must be a string; an `object` type would not match the stored value.
DEFINE FIELD document ON TABLE documents TYPE string;
DEFINE FIELD embedding ON TABLE documents TYPE array<float>;
DEFINE FIELD embedded_text ON TABLE documents TYPE string;

-- Define the vector index on the embedding field.
-- DIMENSION must equal the length of the stored embedding vectors.
DEFINE INDEX IF NOT EXISTS words_embedding_vector_index ON documents
    FIELDS embedding
    MTREE DIMENSION 1536
    DIST COSINE;
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
use rig::{embeddings::EmbeddingsBuilder, vector_store::VectorStoreIndex, Embed}; | ||
use rig_surrealdb::{Mem, SurrealVectorStore}; | ||
use serde::{Deserialize, Serialize}; | ||
use surrealdb::Surreal; | ||
|
||
// A vector search needs to be performed on the `definitions` field, so we derive the `Embed` trait for `WordDefinition` | ||
// and tag that field with `#[embed]`. | ||
// We are not going to store the definitions on our database so we skip the `Serialize` trait | ||
#[derive(Embed, Serialize, Deserialize, Clone, Debug, Eq, PartialEq, Default)] | ||
struct WordDefinition { | ||
word: String, | ||
#[serde(skip)] // we don't want to serialize this field, we use only to create embeddings | ||
#[embed] | ||
definition: String, | ||
} | ||
|
||
impl std::fmt::Display for WordDefinition { | ||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | ||
write!(f, "{}", self.word) | ||
} | ||
} | ||
|
||
#[tokio::main] | ||
async fn main() -> Result<(), anyhow::Error> { | ||
// Create OpenAI client | ||
let openai_client = rig::providers::openai::Client::from_env(); | ||
let model = openai_client.embedding_model(rig::providers::openai::TEXT_EMBEDDING_3_SMALL); | ||
|
||
let surreal = Surreal::new::<Mem>(()).await?; | ||
|
||
surreal.use_ns("example").use_db("example").await?; | ||
|
||
// create test documents with mocked embeddings | ||
let words = vec![ | ||
WordDefinition { | ||
word: "flurbo".to_string(), | ||
definition: "1. *flurbo* (name): A fictional digital currency that originated in the animated series Rick and Morty.".to_string() | ||
}, | ||
WordDefinition { | ||
word: "glarb-glarb".to_string(), | ||
definition: "1. *glarb-glarb* (noun): A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string() | ||
}, | ||
WordDefinition { | ||
word: "linglingdong".to_string(), | ||
definition: "1. *linglingdong* (noun): A term used by inhabitants of the far side of the moon to describe humans.".to_string(), | ||
}]; | ||
|
||
let documents = EmbeddingsBuilder::new(model.clone()) | ||
.documents(words) | ||
.unwrap() | ||
.build() | ||
.await | ||
.expect("Failed to create embeddings"); | ||
|
||
// init vector store | ||
let vector_store = SurrealVectorStore::with_defaults(model, surreal); | ||
|
||
vector_store.insert_documents(documents).await?; | ||
|
||
// query vector | ||
let query = "What does \"glarb-glarb\" mean?"; | ||
|
||
let results = vector_store.top_n::<WordDefinition>(query, 2).await?; | ||
|
||
println!("#{} results for query: {}", results.len(), query); | ||
for (distance, _id, doc) in results.iter() { | ||
println!("Result distance {} for word: {}", distance, doc); | ||
|
||
// expected output | ||
// Result distance 0.693218142100547 for word: glarb-glarb | ||
// Result distance 0.2529120980283861 for word: linglingdong | ||
} | ||
|
||
Ok(()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,210 @@ | ||
use std::fmt::Display; | ||
|
||
use rig::{ | ||
embeddings::{Embedding, EmbeddingModel}, | ||
vector_store::{VectorStoreError, VectorStoreIndex}, | ||
Embed, OneOrMany, | ||
}; | ||
use serde::{de::DeserializeOwned, Deserialize, Serialize}; | ||
use surrealdb::{sql::Thing, Connection, Surreal}; | ||
|
||
pub use surrealdb::engine::local::Mem; | ||
pub use surrealdb::engine::remote::ws::{Ws, Wss}; | ||
|
||
pub struct SurrealVectorStore<Model: EmbeddingModel, C: Connection> { | ||
model: Model, | ||
surreal: Surreal<C>, | ||
documents_table: String, | ||
distance_function: SurrealDistanceFunction, | ||
} | ||
|
||
/// SurrealDB functions that can be used to score a stored embedding against a query vector.
///
/// The [`Display`] implementation renders each variant as the fully qualified SurrealQL
/// function name so it can be spliced into a query. Note that `Cosine` and `Jaccard` map to
/// `vector::similarity::*` functions while `Knn`, `Hamming` and `Euclidean` map to
/// `vector::distance::*` functions.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SurrealDistanceFunction {
    Knn,
    Hamming,
    Euclidean,
    Cosine,
    Jaccard,
}

impl Display for SurrealDistanceFunction {
    /// Writes the SurrealQL function name for this variant.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let function_name = match self {
            SurrealDistanceFunction::Cosine => "vector::similarity::cosine",
            SurrealDistanceFunction::Knn => "vector::distance::knn",
            SurrealDistanceFunction::Euclidean => "vector::distance::euclidean",
            SurrealDistanceFunction::Hamming => "vector::distance::hamming",
            SurrealDistanceFunction::Jaccard => "vector::similarity::jaccard",
        };
        f.write_str(function_name)
    }
}
|
||
#[derive(Debug, Deserialize)] | ||
struct SearchResult { | ||
id: Thing, | ||
document: String, | ||
distance: f64, | ||
} | ||
|
||
#[derive(Debug, Serialize, Deserialize)] | ||
pub struct CreateRecord { | ||
document: String, | ||
embedded_text: String, | ||
embedding: Vec<f64>, | ||
} | ||
|
||
#[derive(Debug, Deserialize)] | ||
pub struct SearchResultOnlyId { | ||
id: Thing, | ||
distance: f64, | ||
} | ||
|
||
impl SearchResult { | ||
pub fn into_result<T: DeserializeOwned>(self) -> Result<(f64, String, T), VectorStoreError> { | ||
let document: T = | ||
serde_json::from_str(&self.document).map_err(VectorStoreError::JsonError)?; | ||
|
||
Ok((self.distance, self.id.id.to_string(), document)) | ||
} | ||
} | ||
|
||
impl<Model: EmbeddingModel, C: Connection> SurrealVectorStore<Model, C> { | ||
pub fn new( | ||
model: Model, | ||
surreal: Surreal<C>, | ||
documents_table: Option<String>, | ||
distance_function: SurrealDistanceFunction, | ||
) -> Self { | ||
Self { | ||
model, | ||
surreal, | ||
documents_table: documents_table.unwrap_or(String::from("documents")), | ||
distance_function, | ||
} | ||
} | ||
|
||
pub fn inner_client(&self) -> &Surreal<C> { | ||
&self.surreal | ||
} | ||
|
||
pub fn with_defaults(model: Model, surreal: Surreal<C>) -> Self { | ||
Self::new(model, surreal, None, SurrealDistanceFunction::Cosine) | ||
} | ||
|
||
fn search_query_full(&self) -> String { | ||
self.search_query(true) | ||
} | ||
|
||
fn search_query_only_ids(&self) -> String { | ||
self.search_query(false) | ||
} | ||
|
||
fn search_query(&self, with_document: bool) -> String { | ||
let document = if with_document { ", document" } else { "" }; | ||
let embedded_text = if with_document { ", embedded_text" } else { "" }; | ||
let Self { | ||
distance_function, .. | ||
} = self; | ||
format!( | ||
" | ||
SELECT id {document} {embedded_text}, {distance_function}($vec, embedding) as distance \ | ||
from type::table($tablename) order by distance desc \ | ||
LIMIT $limit", | ||
) | ||
} | ||
|
||
pub async fn insert_documents<Doc: Serialize + Embed + Send>( | ||
&self, | ||
documents: Vec<(Doc, OneOrMany<Embedding>)>, | ||
) -> Result<(), VectorStoreError> { | ||
for (document, embeddings) in documents { | ||
let json_document: serde_json::Value = serde_json::to_value(&document).unwrap(); | ||
let json_document_as_string = serde_json::to_string(&json_document).unwrap(); | ||
|
||
for embedding in embeddings { | ||
let embedded_text = embedding.document; | ||
let embedding: Vec<f64> = embedding.vec; | ||
|
||
let record = CreateRecord { | ||
document: json_document_as_string.clone(), | ||
embedded_text, | ||
embedding, | ||
}; | ||
|
||
self.surreal | ||
.create::<Option<CreateRecord>>(self.documents_table.clone()) | ||
.content(record) | ||
.await | ||
.map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?; | ||
} | ||
} | ||
|
||
Ok(()) | ||
} | ||
} | ||
|
||
impl<Model: EmbeddingModel, C: Connection> VectorStoreIndex for SurrealVectorStore<Model, C> { | ||
/// Get the top n documents based on the distance to the given query. | ||
/// The result is a list of tuples of the form (score, id, document) | ||
async fn top_n<T: for<'a> Deserialize<'a> + Send>( | ||
&self, | ||
query: &str, | ||
n: usize, | ||
) -> Result<Vec<(f64, String, T)>, VectorStoreError> { | ||
let embedded_query: Vec<f64> = self.model.embed_text(query).await?.vec; | ||
|
||
let mut response = self | ||
.surreal | ||
.query(self.search_query_full().as_str()) | ||
.bind(("vec", embedded_query)) | ||
.bind(("tablename", self.documents_table.clone())) | ||
.bind(("limit", n)) | ||
.await | ||
.map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?; | ||
|
||
let rows: Vec<SearchResult> = response | ||
.take(0) | ||
.map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?; | ||
|
||
let rows: Vec<(f64, String, T)> = rows | ||
.into_iter() | ||
.flat_map(SearchResult::into_result) | ||
.collect(); | ||
|
||
Ok(rows) | ||
} | ||
|
||
/// Same as `top_n` but returns the document ids only. | ||
async fn top_n_ids( | ||
&self, | ||
query: &str, | ||
n: usize, | ||
) -> Result<Vec<(f64, String)>, VectorStoreError> { | ||
let embedded_query: Vec<f32> = self | ||
.model | ||
.embed_text(query) | ||
.await? | ||
.vec | ||
.iter() | ||
.map(|&x| x as f32) | ||
.collect(); | ||
|
||
let mut response = self | ||
.surreal | ||
.query(self.search_query_only_ids().as_str()) | ||
.bind(("vec", embedded_query)) | ||
.bind(("tablename", self.documents_table.clone())) | ||
.bind(("limit", n)) | ||
.await | ||
.map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?; | ||
|
||
let rows: Vec<(f64, String)> = response | ||
.take::<Vec<SearchResultOnlyId>>(0) | ||
.unwrap() | ||
.into_iter() | ||
.map(|row| (row.distance, row.id.id.to_string())) | ||
.collect(); | ||
|
||
Ok(rows) | ||
} | ||
} |