Skip to content

Commit

Permalink
[ENH] Add rust grpc server (chroma-core#1548)
Browse files Browse the repository at this point in the history
## Description of changes

*Summarize the changes made by this PR.*
 - Improvements & Bug fixes
	 - Adds the gRPC server and its connections for the Rust worker
 - New functionality
	 - /

## Test plan
*How are these changes tested?*

- [x]  Tests pass locally with `cargo test`

## Documentation Changes
*Are all docstrings for user-facing APIs updated if required? Do we need to make documentation changes in the [docs repository](https://github.com/chroma-core/docs)?*
  • Loading branch information
HammadB authored Jan 16, 2024
1 parent 68d806c commit a669624
Show file tree
Hide file tree
Showing 17 changed files with 489 additions and 34 deletions.
2 changes: 1 addition & 1 deletion go/coordinator/internal/utils/pulsar_admin.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func CreateTopics(pulsarAdminURL string, tenant string, namespace string, topics
log.Info("Topic already exists", zap.String("topic", topic), zap.Any("metadata", metadata))
continue
}
err = admin.Topics().Create(*topicName, 1)
err = admin.Topics().Create(*topicName, 0)
if err != nil {
log.Error("Failed to create topic", zap.Error(err))
return err
Expand Down
2 changes: 1 addition & 1 deletion idl/chromadb/proto/chroma.proto
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ message VectorEmbeddingRecord {
message VectorQueryResult {
string id = 1;
bytes seq_id = 2;
double distance = 3;
float distance = 3;
optional Vector vector = 4;
}

Expand Down
24 changes: 5 additions & 19 deletions k8s/deployment/segment-server.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,32 +32,18 @@ spec:
spec:
containers:
- name: segment-server
image: server
image: worker
imagePullPolicy: IfNotPresent
command: ["python", "-m", "chromadb.segment.impl.distributed.server"]
command: ["cargo", "run"]
ports:
- containerPort: 50051
volumeMounts:
- name: chroma
mountPath: /index_data
env:
- name: IS_PERSISTENT
value: "TRUE"
- name: CHROMA_PRODUCER_IMPL
value: "chromadb.ingest.impl.pulsar.PulsarProducer"
- name: CHROMA_CONSUMER_IMPL
value: "chromadb.ingest.impl.pulsar.PulsarConsumer"
- name: PULSAR_BROKER_URL
value: "pulsar.chroma"
- name: PULSAR_BROKER_PORT
value: "6650"
- name: PULSAR_ADMIN_PORT
value: "8080"
- name: CHROMA_SERVER_GRPC_PORT
value: "50051"
- name: CHROMA_COLLECTION_ASSIGNMENT_POLICY_IMPL
value: "chromadb.ingest.impl.simple_policy.RendezvousHashingAssignmentPolicy"
- name: MY_POD_IP
- name: CHROMA_WORKER__PULSAR_URL
value: pulsar://pulsar.chroma:6650
- name: CHROMA_WORKER__MY_IP
valueFrom:
fieldRef:
fieldPath: status.podIP
Expand Down
2 changes: 1 addition & 1 deletion k8s/test/coordinator_service.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
apiVersion: v1
kind: Service
metadata:
name: coordinator
name: coordinator-lb
namespace: chroma
spec:
ports:
Expand Down
52 changes: 52 additions & 0 deletions k8s/test/minio.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# MinIO Deployment for the test Kubernetes setup (file lives under k8s/test).
# Runs a single `minio` container in the `chroma` namespace with the
# arguments `server /storage`, backed by an emptyDir volume.
apiVersion: apps/v1
kind: Deployment
metadata:
name: minio-deployment
namespace: chroma
spec:
selector:
matchLabels:
# Must match the pod template label below.
app: minio
strategy:
type: Recreate
template:
metadata:
labels:
app: minio
spec:
volumes:
# Volume named `minio`; mounted into the container at /storage further down.
- name: minio
emptyDir: {}
containers:
- name: minio
# NOTE(review): image uses the unpinned `latest` tag.
image: minio/minio:latest
args:
- server
- /storage
env:
# NOTE(review): credentials are hard-coded in plain text; presumably
# acceptable only because this manifest is test-only. Confirm.
- name: MINIO_ACCESS_KEY
value: "minio"
- name: MINIO_SECRET_KEY
value: "minio123"
ports:
# Port 9000 is exposed both on the container and on the host.
- containerPort: 9000
hostPort: 9000
volumeMounts:
- name: minio
mountPath: /storage

---

# LoadBalancer Service `minio-lb` in the `chroma` namespace.
# Forwards port 9000 to port 9000 on pods labelled `app: minio`.
apiVersion: v1
kind: Service
metadata:
name: minio-lb
namespace: chroma
spec:
ports:
- name: http
port: 9000
targetPort: 9000
selector:
# Selects the pods created by the MinIO Deployment in this same file.
app: minio
type: LoadBalancer
2 changes: 1 addition & 1 deletion k8s/test/pulsar_service.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
apiVersion: v1
kind: Service
metadata:
name: pulsar
name: pulsar-lb
namespace: chroma
spec:
ports:
Expand Down
13 changes: 13 additions & 0 deletions k8s/test/segment_server_service.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# LoadBalancer Service `segment-server-lb` in the `chroma` namespace.
# Forwards Service port 50052 to port 50051 on pods labelled
# `app: segment-server`.
apiVersion: v1
kind: Service
metadata:
name: segment-server-lb
namespace: chroma
spec:
ports:
- name: segment-server-port
# External port differs from the target port: 50052 -> 50051.
port: 50052
targetPort: 50051
selector:
app: segment-server
type: LoadBalancer
8 changes: 6 additions & 2 deletions rust/worker/Dockerfile
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
FROM rust:1.74.1 as builder

WORKDIR /
RUN git clone https://github.com/chroma-core/hnswlib.git

WORKDIR /chroma/
COPY . .

Expand All @@ -11,5 +14,6 @@ RUN curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v25.1

RUN cargo build

# For now this runs cargo test since we have no main binary
CMD ["cargo", "test"]
WORKDIR /chroma/rust/worker

CMD ["cargo", "run"]
7 changes: 4 additions & 3 deletions rust/worker/chroma_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
# for now we nest it in the worker directory

worker:
my_ip: "10.244.0.90"
my_ip: "10.244.0.9"
my_port: 50051
num_indexing_threads: 4
pulsar_url: "pulsar://127.0.0.1:6650"
pulsar_tenant: "public"
Expand All @@ -18,10 +19,10 @@ worker:
memberlist_name: "worker-memberlist"
queue_size: 100
ingest:
queue_size: 100
queue_size: 10000
sysdb:
Grpc:
host: "localhost"
host: "coordinator.chroma"
port: 50051
segment_manager:
storage_path: "./tmp/segment_manager/"
8 changes: 7 additions & 1 deletion rust/worker/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use serde::Deserialize;

use crate::errors::ChromaError;

const DEFAULT_CONFIG_PATH: &str = "chroma_config.yaml";
const DEFAULT_CONFIG_PATH: &str = "./chroma_config.yaml";
const ENV_PREFIX: &str = "CHROMA_";

#[derive(Deserialize)]
Expand Down Expand Up @@ -97,6 +97,7 @@ impl RootConfig {
/// have its own field in this struct for its Config struct.
pub(crate) struct WorkerConfig {
pub(crate) my_ip: String,
pub(crate) my_port: u16,
pub(crate) num_indexing_threads: u32,
pub(crate) pulsar_tenant: String,
pub(crate) pulsar_namespace: String,
Expand Down Expand Up @@ -134,6 +135,7 @@ mod tests {
r#"
worker:
my_ip: "192.0.0.1"
my_port: 50051
num_indexing_threads: 4
pulsar_tenant: "public"
pulsar_namespace: "default"
Expand Down Expand Up @@ -175,6 +177,7 @@ mod tests {
r#"
worker:
my_ip: "192.0.0.1"
my_port: 50051
num_indexing_threads: 4
pulsar_tenant: "public"
pulsar_namespace: "default"
Expand Down Expand Up @@ -232,6 +235,7 @@ mod tests {
r#"
worker:
my_ip: "192.0.0.1"
my_port: 50051
pulsar_tenant: "public"
pulsar_namespace: "default"
kube_namespace: "chroma"
Expand Down Expand Up @@ -265,6 +269,7 @@ mod tests {
fn test_config_with_env_override() {
Jail::expect_with(|jail| {
let _ = jail.set_env("CHROMA_WORKER__MY_IP", "192.0.0.1");
let _ = jail.set_env("CHROMA_WORKER__MY_PORT", 50051);
let _ = jail.set_env("CHROMA_WORKER__PULSAR_TENANT", "A");
let _ = jail.set_env("CHROMA_WORKER__PULSAR_NAMESPACE", "B");
let _ = jail.set_env("CHROMA_WORKER__KUBE_NAMESPACE", "C");
Expand Down Expand Up @@ -292,6 +297,7 @@ mod tests {
);
let config = RootConfig::load();
assert_eq!(config.worker.my_ip, "192.0.0.1");
assert_eq!(config.worker.my_port, 50051);
assert_eq!(config.worker.num_indexing_threads, num_cpus::get() as u32);
assert_eq!(config.worker.pulsar_tenant, "A");
assert_eq!(config.worker.pulsar_namespace, "B");
Expand Down
1 change: 1 addition & 0 deletions rust/worker/src/ingest/ingest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ impl Configurable for Ingest {
worker_config.pulsar_namespace.clone(),
);

println!("Pulsar connection url: {}", worker_config.pulsar_url);
let pulsar = match Pulsar::builder(worker_config.pulsar_url.clone(), TokioExecutor)
.build()
.await
Expand Down
21 changes: 19 additions & 2 deletions rust/worker/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,16 @@ mod index;
mod ingest;
mod memberlist;
mod segment;
mod server;
mod sysdb;
mod system;
mod types;

use config::Configurable;
use memberlist::MemberlistProvider;

use crate::sysdb::sysdb::SysDb;

/// Module holding the code that `tonic::include_proto!` pulls in for the
/// `chroma` protobuf package.
// NOTE(review): presumably generated at build time from the project's
// chroma.proto definitions — confirm against the build script.
mod chroma_proto {
tonic::include_proto!("chroma");
}
Expand Down Expand Up @@ -61,8 +64,18 @@ pub async fn worker_entrypoint() {
segment_ingestor_receivers.push(recv);
}

let mut worker_server = match server::WorkerServer::try_from_config(&config.worker).await {
Ok(worker_server) => worker_server,
Err(err) => {
println!("Failed to create worker server component: {:?}", err);
return;
}
};
worker_server.set_segment_manager(segment_manager.clone());

// Boot the system
// memberlist -> ingest -> scheduler -> NUM_THREADS x segment_ingestor
// memberlist -> ingest -> scheduler -> NUM_THREADS x segment_ingestor -> segment_manager
// server <- segment_manager

for recv in segment_ingestor_receivers {
scheduler.subscribe(recv);
Expand All @@ -76,10 +89,14 @@ pub async fn worker_entrypoint() {
memberlist.subscribe(recv);
let mut memberlist_handle = system.start_component(memberlist);

let server_join_handle = tokio::spawn(async move {
crate::server::WorkerServer::run(worker_server).await;
});

// Join on all handles
let _ = tokio::join!(
ingest_handle.join(),
memberlist_handle.join(),
scheduler_handler.join()
scheduler_handler.join(),
);
}
56 changes: 54 additions & 2 deletions rust/worker/src/segment/distributed_hnsw_segment.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
use num_bigint::BigInt;
use parking_lot::{Mutex, RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard};
use std::collections::HashMap;
use std::sync::atomic::AtomicUsize;
use std::sync::Arc;

use crate::errors::ChromaError;
use crate::index::{HnswIndex, HnswIndexConfig, Index, IndexConfig};
use crate::types::{EmbeddingRecord, Operation, Segment};
use crate::types::{EmbeddingRecord, Operation, Segment, VectorEmbeddingRecord};

/// Vector segment backed by an HNSW index.
///
/// Keeps a two-way mapping between user-supplied string ids and the integer
/// ids used inside the index, so records can be fetched by user id and query
/// results can be translated back to user ids.
pub(crate) struct DistributedHNSWSegment {
/// The HNSW index, shared behind a read/write lock.
index: Arc<RwLock<HnswIndex>>,
/// Atomic counter, initialised to 0.
// NOTE(review): presumably the source of the next internal id; its use is
// not visible in this view — confirm.
id: AtomicUsize,
/// Maps a user-supplied id to the internal index id.
user_id_to_id: Arc<RwLock<HashMap<String, usize>>>,
/// Reverse of `user_id_to_id`: maps an internal index id to the user id.
id_to_user_id: Arc<RwLock<HashMap<usize, String>>>,
/// Index configuration stored on the segment at construction.
index_config: IndexConfig,
/// HNSW-specific configuration stored on the segment at construction.
hnsw_config: HnswIndexConfig,
}
Expand All @@ -33,6 +35,7 @@ impl DistributedHNSWSegment {
index: index,
id: AtomicUsize::new(0),
user_id_to_id: Arc::new(RwLock::new(HashMap::new())),
id_to_user_id: Arc::new(RwLock::new(HashMap::new())),
index_config: index_config,
hnsw_config,
});
Expand Down Expand Up @@ -63,7 +66,10 @@ impl DistributedHNSWSegment {
self.user_id_to_id
.write()
.insert(record.id.clone(), next_id);
println!("DIS SEGMENT Adding item: {}", next_id);
self.id_to_user_id
.write()
.insert(next_id, record.id.clone());
println!("Segment adding item: {}", next_id);
self.index.read().add(next_id, &vector);
}
None => {
Expand All @@ -81,4 +87,50 @@ impl DistributedHNSWSegment {
}
}
}

/// Looks up the stored vectors for the given user-supplied ids.
///
/// Ids are processed in order. Each returned record carries the user id, a
/// `seq_id` of 0 and the vector read from the index.
///
/// Current (unfinished) error behaviour, kept as-is:
/// - the first id with no internal mapping stops the lookup, and whatever
///   has been collected so far is returned;
/// - an id whose vector is absent from the index is skipped.
pub(crate) fn get_records(&self, ids: Vec<String>) -> Vec<Box<VectorEmbeddingRecord>> {
    // Take the id map lock first, then the index lock.
    let id_map = self.user_id_to_id.read();
    let hnsw = self.index.read();
    let mut found = Vec::new();

    for user_id in ids {
        // TODO: Error — an unknown id currently ends the whole lookup.
        let Some(&internal_id) = id_map.get(&user_id) else {
            return found;
        };

        // TODO: error — ids without a stored vector are silently skipped.
        if let Some(vector) = hnsw.get(internal_id) {
            found.push(Box::new(VectorEmbeddingRecord {
                id: user_id,
                seq_id: BigInt::from(0),
                vector,
            }));
        }
    }

    found
}

/// Runs a nearest-neighbour query against the index and translates the
/// internal ids of the hits back to user-supplied ids.
///
/// Returns `(user_ids, distances)`, where `distances[i]` belongs to
/// `user_ids[i]`.
///
/// A hit whose internal id has no entry in `id_to_user_id` is dropped
/// together with its distance. Dropping only the id (as before) left the two
/// vectors with different lengths, pairing later distances with the wrong
/// user ids.
// NOTE(review): assumes `index.query` returns ids and distances in matching
// order and of equal length — confirm against `HnswIndex::query`.
pub(crate) fn query(&self, vector: &[f32], k: usize) -> (Vec<String>, Vec<f32>) {
    let index = self.index.read();
    let (ids, distances) = index.query(vector, k);
    let user_ids = self.id_to_user_id.read();

    let mut return_user_ids = Vec::new();
    let mut return_distances = Vec::new();
    // Walk ids and distances in lockstep so a skipped id also skips its
    // distance, keeping the two result vectors aligned.
    for (id, distance) in ids.into_iter().zip(distances) {
        match user_ids.get(&id) {
            Some(user_id) => {
                return_user_ids.push(user_id.clone());
                return_distances.push(distance);
            }
            None => {
                // TODO: error
            }
        }
    }
    (return_user_ids, return_distances)
}
}
Loading

0 comments on commit a669624

Please sign in to comment.