Skip to content

Commit

Permalink
feat: surrealdb integration (#280)
Browse files Browse the repository at this point in the history
* feat: surrealdb integration

* docs: update README

* refactor: simplify integration

* readme(rig-surrealdb): add link

* refactor: enable using mem as db

* refactor: only use mem db from local

* fix(rig-surrealdb): crate info

* chore: satisfy ci

* chore: ci

* refactor: amendments

* refactor: amendments

* chore: satisfy ci
  • Loading branch information
joshua-mo-143 authored Feb 24, 2025
1 parent 44c3971 commit 8229c08
Show file tree
Hide file tree
Showing 8 changed files with 2,260 additions and 525 deletions.
2,440 changes: 1,915 additions & 525 deletions Cargo.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@ members = [
"rig-core/rig-core-derive",
"rig-sqlite",
"rig-eternalai", "rig-fastembed",
"rig-surrealdb",
]
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ Vector stores are available as separate companion-crates:
- Neo4j vector store: [`rig-neo4j`](https://github.com/0xPlaygrounds/rig/tree/main/rig-neo4j)
- Qdrant vector store: [`rig-qdrant`](https://github.com/0xPlaygrounds/rig/tree/main/rig-qdrant)
- SQLite vector store: [`rig-sqlite`](https://github.com/0xPlaygrounds/rig/tree/main/rig-sqlite)
- SurrealDB vector store: [`rig-surrealdb`](https://github.com/0xPlaygrounds/rig/tree/main/rig-surrealdb)

The following providers are available as separate companion-crates:
- Fastembed: [`rig-fastembed`](https://github.com/0xPlaygrounds/rig/tree/main/rig-fastembed)
Expand Down
21 changes: 21 additions & 0 deletions rig-surrealdb/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
[package]
name = "rig-surrealdb"
version = "0.1.0"
edition = "2021"
# NOTE(review): `description`, `license`, and `repository` are missing —
# `cargo publish` rejects crates without description/license; add before release.

[dependencies]
# "kv-mem" enables the embedded in-memory engine used by the example.
surrealdb = { version = "2.1.4", features = ["protocol-ws", "kv-mem"] }
rig-core = { path = "../rig-core", version = "0.9.0", features = ["derive"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
tracing = "0.1"
uuid = { version = "1.13.1", features = ["v4"] }

[dev-dependencies]
anyhow = "1.0.86"
tokio = { version = "1.38.0", features = ["macros", "rt-multi-thread"] }
tracing-subscriber = { version = "0.3", features = ["env-filter"] }

# The example derives `Embed`, which requires rig-core's "derive" feature.
[[example]]
name = "vector_search_surreal"
required-features = ["rig-core/derive"]
26 changes: 26 additions & 0 deletions rig-surrealdb/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Rig SurrealDB integration
This crate integrates SurrealDB into Rig, allowing you to easily use RAG with this database.

## Installation
To install this crate, run the following command in your Rust project directory; it adds `rig-surrealdb` as a dependency (`rig-core` is also required for intended usage):
```bash
cargo add rig-surrealdb
```

There are a few different ways you can run SurrealDB:
- [Install it locally and run it](https://surrealdb.com/docs/surrealdb/installation/linux)
- [Through a Docker container, either locally or on Docker-compatible architecture](https://surrealdb.com/docs/surrealdb/installation/running/docker)
- `docker run --rm --pull always -p 8000:8000 surrealdb/surrealdb:latest start --username root --password root` starts up a SurrealDB instance at port 8000 with the username and password as "root".
- [Using SurrealDB's cloud offering](https://surrealdb.com/cloud)
- Using the cloud offering you can manage your SurrealDB instance through their web UI.

## How to run the example
To run the example, add your OpenAI API key as an environment variable:
```bash
export OPENAI_API_KEY=my_key
```

Finally, use the following command to run the example:
```bash
cargo run --example vector_search_surreal --features rig-core/derive
```
11 changes: 11 additions & 0 deletions rig-surrealdb/examples/migrations.surql
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
-- define table & fields
-- (keywords normalized to uppercase for consistency with the index definition)
DEFINE TABLE documents SCHEMAFULL;
DEFINE FIELD document ON TABLE documents TYPE object;
DEFINE FIELD embedding ON TABLE documents TYPE array<float>;
DEFINE FIELD embedded_text ON TABLE documents TYPE string;

-- define MTREE index on the embedding field (1536 dims matches
-- OpenAI's text-embedding-3-small; adjust if you use another model)
DEFINE INDEX IF NOT EXISTS words_embedding_vector_index ON documents
    FIELDS embedding
    MTREE DIMENSION 1536
    DIST COSINE;
75 changes: 75 additions & 0 deletions rig-surrealdb/examples/vector_search_surreal.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
use rig::{embeddings::EmbeddingsBuilder, vector_store::VectorStoreIndex, Embed};
use rig_surrealdb::{Mem, SurrealVectorStore};
use serde::{Deserialize, Serialize};
use surrealdb::Surreal;

// A vector search needs to be performed on the `definition` field, so we derive the `Embed` trait for `WordDefinition`
// and tag that field with `#[embed]`.
// We still derive `Serialize` for the struct, but skip the `definition` field with
// `#[serde(skip)]` so the definition text itself is not stored in the database.
#[derive(Embed, Serialize, Deserialize, Clone, Debug, Eq, PartialEq, Default)]
struct WordDefinition {
word: String,
#[serde(skip)] // we don't want to serialize this field, we use it only to create embeddings
#[embed]
definition: String,
}

// Display a `WordDefinition` as just its word (used when printing search results).
impl std::fmt::Display for WordDefinition {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.word)
    }
}

#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
    // Set up the OpenAI client and select an embedding model.
    let client = rig::providers::openai::Client::from_env();
    let embedding_model = client.embedding_model(rig::providers::openai::TEXT_EMBEDDING_3_SMALL);

    // Start an embedded in-memory SurrealDB instance and pick a namespace + database.
    let db = Surreal::new::<Mem>(()).await?;
    db.use_ns("example").use_db("example").await?;

    // Test documents; their `definition` fields will be embedded.
    let words = vec![
        WordDefinition {
            word: "flurbo".to_string(),
            definition: "1. *flurbo* (name): A fictional digital currency that originated in the animated series Rick and Morty.".to_string()
        },
        WordDefinition {
            word: "glarb-glarb".to_string(),
            definition: "1. *glarb-glarb* (noun): A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string()
        },
        WordDefinition {
            word: "linglingdong".to_string(),
            definition: "1. *linglingdong* (noun): A term used by inhabitants of the far side of the moon to describe humans.".to_string(),
        },
    ];

    // Embed the documents with the chosen model.
    let embedded_documents = EmbeddingsBuilder::new(embedding_model.clone())
        .documents(words)
        .unwrap()
        .build()
        .await
        .expect("Failed to create embeddings");

    // Initialise the vector store and load the embedded documents into it.
    let store = SurrealVectorStore::with_defaults(embedding_model, db);
    store.insert_documents(embedded_documents).await?;

    // Run a vector search for the query.
    let query = "What does \"glarb-glarb\" mean?";
    let matches = store.top_n::<WordDefinition>(query, 2).await?;

    println!("#{} results for query: {}", matches.len(), query);
    for (distance, _id, doc) in &matches {
        println!("Result distance {} for word: {}", distance, doc);

        // expected output
        // Result distance 0.693218142100547 for word: glarb-glarb
        // Result distance 0.2529120980283861 for word: linglingdong
    }

    Ok(())
}
210 changes: 210 additions & 0 deletions rig-surrealdb/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
use std::fmt::Display;

use rig::{
embeddings::{Embedding, EmbeddingModel},
vector_store::{VectorStoreError, VectorStoreIndex},
Embed, OneOrMany,
};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use surrealdb::{sql::Thing, Connection, Surreal};

pub use surrealdb::engine::local::Mem;
pub use surrealdb::engine::remote::ws::{Ws, Wss};

/// A Rig vector store backed by a SurrealDB table.
pub struct SurrealVectorStore<Model: EmbeddingModel, C: Connection> {
// Embedding model used to embed queries before searching.
model: Model,
// Underlying SurrealDB client/connection.
surreal: Surreal<C>,
// Name of the table holding documents + embeddings (default "documents").
documents_table: String,
// SurrealDB function used to score query vs. stored embeddings.
distance_function: SurrealDistanceFunction,
}

/// SurrealDB supported distance/similarity functions for vector search.
///
/// Note: `Cosine` and `Jaccard` map to SurrealDB *similarity* functions
/// (`vector::similarity::*`, higher = closer), while the other variants map
/// to *distance* functions (`vector::distance::*`, lower = closer).
// Derives added: a public, fieldless enum should be `Debug`-printable,
// trivially copyable, and comparable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SurrealDistanceFunction {
    Knn,
    Hamming,
    Euclidean,
    Cosine,
    Jaccard,
}

// Render each variant as the fully-qualified SurrealQL function name it
// stands for; this string is interpolated directly into search queries.
impl Display for SurrealDistanceFunction {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let function_name = match self {
            SurrealDistanceFunction::Cosine => "vector::similarity::cosine",
            SurrealDistanceFunction::Knn => "vector::distance::knn",
            SurrealDistanceFunction::Euclidean => "vector::distance::euclidean",
            SurrealDistanceFunction::Hamming => "vector::distance::hamming",
            SurrealDistanceFunction::Jaccard => "vector::similarity::jaccard",
        };
        f.write_str(function_name)
    }
}

// Row shape returned by the full search query (id + serialized document + score).
#[derive(Debug, Deserialize)]
struct SearchResult {
// SurrealDB record id (`table:id`).
id: Thing,
// JSON-serialized source document, exactly as stored at insert time.
document: String,
// Score computed by the configured distance/similarity function.
distance: f64,
}

/// Row payload written by `insert_documents`: one row per (document, embedding) pair.
#[derive(Debug, Serialize, Deserialize)]
pub struct CreateRecord {
// JSON-serialized source document.
document: String,
// The text that was embedded.
embedded_text: String,
// Embedding vector for `embedded_text`.
embedding: Vec<f64>,
}

/// Row shape returned by the ids-only search query (`top_n_ids`).
#[derive(Debug, Deserialize)]
pub struct SearchResultOnlyId {
// SurrealDB record id (`table:id`).
id: Thing,
// Score computed by the configured distance/similarity function.
distance: f64,
}

impl SearchResult {
    /// Consume the row, deserializing the stored JSON document into `T`.
    ///
    /// Returns `(distance, record id, document)`; deserialization failures
    /// surface as `VectorStoreError::JsonError`.
    pub fn into_result<T: DeserializeOwned>(self) -> Result<(f64, String, T), VectorStoreError> {
        match serde_json::from_str::<T>(&self.document) {
            Ok(document) => Ok((self.distance, self.id.id.to_string(), document)),
            Err(e) => Err(VectorStoreError::JsonError(e)),
        }
    }
}

impl<Model: EmbeddingModel, C: Connection> SurrealVectorStore<Model, C> {
    /// Build a vector store over the given SurrealDB connection.
    ///
    /// * `documents_table` — table to read/write; defaults to `"documents"` when `None`.
    /// * `distance_function` — SurrealDB function used to score vectors.
    pub fn new(
        model: Model,
        surreal: Surreal<C>,
        documents_table: Option<String>,
        distance_function: SurrealDistanceFunction,
    ) -> Self {
        Self {
            model,
            surreal,
            // `unwrap_or_else` defers the default allocation to the `None` case.
            documents_table: documents_table.unwrap_or_else(|| String::from("documents")),
            distance_function,
        }
    }

    /// Borrow the underlying SurrealDB client.
    pub fn inner_client(&self) -> &Surreal<C> {
        &self.surreal
    }

    /// Convenience constructor: `"documents"` table + cosine similarity.
    pub fn with_defaults(model: Model, surreal: Surreal<C>) -> Self {
        Self::new(model, surreal, None, SurrealDistanceFunction::Cosine)
    }

    /// Search query that also selects the stored document and embedded text.
    fn search_query_full(&self) -> String {
        self.search_query(true)
    }

    /// Search query that selects ids and scores only.
    fn search_query_only_ids(&self) -> String {
        self.search_query(false)
    }

    /// Build the SurrealQL search statement. `$vec`, `$tablename` and `$limit`
    /// are bound at execution time.
    ///
    /// NOTE(review): results are ordered `desc`, which fits the similarity
    /// functions (`Cosine`/`Jaccard`, higher = closer); for the distance
    /// functions (`Knn`/`Euclidean`/`Hamming`) ascending order may be
    /// intended — confirm before relying on non-default distance functions.
    fn search_query(&self, with_document: bool) -> String {
        let document = if with_document { ", document" } else { "" };
        let embedded_text = if with_document { ", embedded_text" } else { "" };
        let Self {
            distance_function, ..
        } = self;
        format!(
            "
            SELECT id {document} {embedded_text}, {distance_function}($vec, embedding) as distance \
            from type::table($tablename) order by distance desc \
            LIMIT $limit",
        )
    }

    /// Insert pre-embedded documents into the store.
    ///
    /// Each embedding of a document becomes one row containing the serialized
    /// document, the embedded text, and the embedding vector.
    ///
    /// # Errors
    /// Returns `VectorStoreError::JsonError` when a document cannot be
    /// serialized, and `VectorStoreError::DatastoreError` on database failures.
    pub async fn insert_documents<Doc: Serialize + Embed + Send>(
        &self,
        documents: Vec<(Doc, OneOrMany<Embedding>)>,
    ) -> Result<(), VectorStoreError> {
        for (document, embeddings) in documents {
            // Serialize once per document; propagate failures instead of
            // panicking (the previous `unwrap`s would abort the caller).
            let json_document: serde_json::Value =
                serde_json::to_value(&document).map_err(VectorStoreError::JsonError)?;
            let json_document_as_string =
                serde_json::to_string(&json_document).map_err(VectorStoreError::JsonError)?;

            for embedding in embeddings {
                let embedded_text = embedding.document;
                let embedding: Vec<f64> = embedding.vec;

                let record = CreateRecord {
                    document: json_document_as_string.clone(),
                    embedded_text,
                    embedding,
                };

                self.surreal
                    .create::<Option<CreateRecord>>(self.documents_table.clone())
                    .content(record)
                    .await
                    .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
            }
        }

        Ok(())
    }
}

impl<Model: EmbeddingModel, C: Connection> VectorStoreIndex for SurrealVectorStore<Model, C> {
    /// Get the top n documents based on the distance to the given query.
    /// The result is a list of tuples of the form (score, id, document).
    ///
    /// # Errors
    /// Returns `VectorStoreError::DatastoreError` on query failure; rows whose
    /// stored document fails to deserialize into `T` are silently skipped.
    async fn top_n<T: for<'a> Deserialize<'a> + Send>(
        &self,
        query: &str,
        n: usize,
    ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
        // Embed the query text with the store's model.
        let embedded_query: Vec<f64> = self.model.embed_text(query).await?.vec;

        let mut response = self
            .surreal
            .query(self.search_query_full().as_str())
            .bind(("vec", embedded_query))
            .bind(("tablename", self.documents_table.clone()))
            .bind(("limit", n))
            .await
            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;

        let rows: Vec<SearchResult> = response
            .take(0)
            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;

        // `flat_map` drops rows whose JSON fails to parse rather than failing
        // the whole query.
        let rows: Vec<(f64, String, T)> = rows
            .into_iter()
            .flat_map(SearchResult::into_result)
            .collect();

        Ok(rows)
    }

    /// Same as `top_n` but returns the document ids only.
    ///
    /// # Errors
    /// Returns `VectorStoreError::DatastoreError` on query or row-decoding failure.
    async fn top_n_ids(
        &self,
        query: &str,
        n: usize,
    ) -> Result<Vec<(f64, String)>, VectorStoreError> {
        // Bind the embedding as f64 — consistent with `top_n`; the previous
        // implementation narrowed to f32, losing precision for no benefit.
        let embedded_query: Vec<f64> = self.model.embed_text(query).await?.vec;

        let mut response = self
            .surreal
            .query(self.search_query_only_ids().as_str())
            .bind(("vec", embedded_query))
            .bind(("tablename", self.documents_table.clone()))
            .bind(("limit", n))
            .await
            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;

        // Propagate decoding errors instead of panicking (previously `unwrap`).
        let rows: Vec<(f64, String)> = response
            .take::<Vec<SearchResultOnlyId>>(0)
            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?
            .into_iter()
            .map(|row| (row.distance, row.id.id.to_string()))
            .collect();

        Ok(rows)
    }
}

0 comments on commit 8229c08

Please sign in to comment.