-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy path_54_CassandraVectorStore.java
53 lines (43 loc) · 1.99 KB
/
_54_CassandraVectorStore.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
package devoxx.demo._5_vectorsearch;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.cassandra.CassandraCassioEmbeddingStore;
import devoxx.demo.utils.AbstractDevoxxTestSupport;
import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Test;
import static dev.langchain4j.store.embedding.filter.MetadataFilterBuilder.metadataKey;
import static devoxx.demo.devoxx.Utilities.EMBEDDING_DIMENSION;
import static devoxx.demo.devoxx.Utilities.TABLE_NAME;
/**
 * Demo: similarity search over a Cassandra-backed LangChain4j {@link EmbeddingStore}.
 *
 * <p>Embeds a sample question with the model returned by {@code getEmbeddingModelGecko()}.
 * It then runs two searches against the {@code TABLE_NAME} table:
 * <ul>
 *   <li>one unfiltered;</li>
 *   <li>one restricted to rows whose {@code author} metadata equals {@code "nietzsche"}.</li>
 * </ul>
 */
@Slf4j
public class _54_CassandraVectorStore extends AbstractDevoxxTestSupport {

    /** Maximum number of matches returned by each query. */
    private static final int MAX_RESULTS = 3;

    /** Minimum score a match must reach to be returned by each query. */
    private static final double MIN_SCORE = 0.8d;

    @Test
    public void langchain4jEmbeddingStore() {
        // Embedding model provided by the test support base class.
        EmbeddingModel embeddingModel = getEmbeddingModelGecko();

        // Embed the question once; both queries below reuse the same vector.
        Embedding questionEmbedding = embeddingModel
                .embed("We struggle all our life for nothing")
                .content();

        // Store backed by the support-class Cassandra session and the shared table/dimension.
        EmbeddingStore<TextSegment> embeddingStore = new CassandraCassioEmbeddingStore(
                getCassandraSession(), TABLE_NAME, EMBEDDING_DIMENSION);

        // Query (1): plain similarity search. Uses search(EmbeddingSearchRequest) instead of
        // the deprecated findRelevant(...) so both queries go through the same API.
        log.info("Querying the store");
        logMatches(embeddingStore, EmbeddingSearchRequest.builder()
                .queryEmbedding(questionEmbedding)
                .maxResults(MAX_RESULTS)
                .minScore(MIN_SCORE)
                .build());

        // Query (2): same search, additionally restricted by the "author" metadata key.
        log.info("Querying with filter");
        logMatches(embeddingStore, EmbeddingSearchRequest.builder()
                .queryEmbedding(questionEmbedding)
                .filter(metadataKey("author").isEqualTo("nietzsche"))
                .maxResults(MAX_RESULTS)
                .minScore(MIN_SCORE)
                .build());
    }

    /**
     * Executes {@code request} against {@code store} and logs the score and text of every match.
     *
     * @param store   the embedding store to query
     * @param request the search request (query vector, limits and optional filter)
     */
    private static void logMatches(EmbeddingStore<TextSegment> store, EmbeddingSearchRequest request) {
        store.search(request)
                .matches()
                .forEach(match -> log.info("score={} text={}", match.score(), match.embedded().text()));
    }
}