Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEATURE] Add DuckDB EmbeddingStore implementation (#2183) #24

Merged
merged 3 commits into from
Nov 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions embedding-stores/langchain4j-community-duckdb/pom.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-community</artifactId>
<version>0.37.0-SNAPSHOT</version>
<relativePath>../../pom.xml</relativePath>
</parent>

<artifactId>langchain4j-community-duckdb</artifactId>
<name>LangChain4j :: Community :: Integration :: DuckDB</name>

<properties>
<duckdb.version>1.1.3</duckdb.version>
</properties>

<dependencies>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-core</artifactId>
<version>${project.version}</version>
</dependency>

<dependency>
<groupId>org.duckdb</groupId>
<artifactId>duckdb_jdbc</artifactId>
<version>${duckdb.version}</version>
</dependency>

<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
</dependency>

<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-core</artifactId>
<version>${project.version}</version>
<classifier>tests</classifier>
<type>test-jar</type>
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-engine</artifactId>
<scope>test</scope>
</dependency>

<!-- junit-jupiter-params should be declared explicitly
to run parameterized tests inherited from EmbeddingStore*IT-->
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-params</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.assertj</groupId>
<artifactId>assertj-core</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.awaitility</groupId>
<artifactId>awaitility</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-embeddings-all-minilm-l6-v2-q</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
</project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,293 @@
package dev.langchain4j.community.store.embedding.duckdb;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.filter.Filter;
import org.duckdb.DuckDBConnection;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.sql.DriverManager;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
import static dev.langchain4j.internal.Utils.randomUUID;
import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty;
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
import static dev.langchain4j.internal.ValidationUtils.ensureTrue;
import static java.lang.String.format;
import static java.util.Collections.singletonList;

/**
* Implementation of {@link EmbeddingStore} using <a href="https://duckdb.org/">DuckDB</a>
* This implementation uses cosine distance and supports storing {@link Metadata}
*/
public class DuckDBEmbeddingStore implements EmbeddingStore<TextSegment> {

private static final Logger log = LoggerFactory.getLogger(DuckDBEmbeddingStore.class);

private static final String CREATE_TABLE_TEMPLATE = """
create table if not exists %s (id UUID, embedding FLOAT[], text TEXT NULL, metadata JSON NULL);
""";

private static final String SEARCH_QUERY_TEMPLATE = """
select *, (list_cosine_similarity(embedding,%s)+1.0)/2.0 as score
from %s
where score >= %s %s
order by score DESC
limit(%d)
""";

private static final String INSERT_QUERY_TEMPLATE = """
insert into %s (id, embedding, text, metadata) values %s
""";

private static final String DELETE_BY_IDS_QUERY_TEMPLATE = """
delete from %s where id in (%s)
""";

private static final String DELETE_QUERY_TEMPLATE = """
delete from %s where %s
""";

private static final String TRUNCATE_QUERY_TEMPLATE = """
truncate table %s
""";

private final String tableName;
private final DuckDBConnection duckDBConnection;
private final DuckDBMetadataFilterMapper jsonFilterMapper = new DuckDBMetadataFilterMapper();
private final ObjectMapper jsonMetadataSerializer = new ObjectMapper();


/**
* Initializes a new instance of DuckDBEmbeddingStore with the specified parameters.
*
* @param filePath File used to persist DuckDB database. If not specified, the database will be stored in-memory.
* @param tableName The database table name to use. If not specified, "embeddings" will be used
*/
public DuckDBEmbeddingStore(String filePath, String tableName) {
try {
var dbUrl = filePath != null ? "jdbc:duckdb:" + filePath : "jdbc:duckdb:";
this.tableName = getOrDefault(tableName, "embeddings");
this.duckDBConnection = (DuckDBConnection) DriverManager.getConnection(dbUrl);
initTable();
} catch (SQLException e) {
throw new DuckDBSQLException("Unable to load duckdb connection", e);
}
}


/**
* @return a new instance of DuckDBEmbeddingStore with the default configuration and database stored in-memory
*/
public static DuckDBEmbeddingStore inMemory() {
return new DuckDBEmbeddingStore(null, null);
}

public static Builder builder() {
return new Builder();
}

public static class Builder {
private String filePath;
private String tableName;

/**
* @param filePath File used to persist DuckDB database. If not specified, the database will be stored in-memory.
* @return builder
*/
public Builder filePath(String filePath) {
this.filePath = filePath;
return this;
}

/**
* @param tableName The database table name to use. If not specified, "embeddings" will be used
* @return builder
*/
public Builder tableName(String tableName) {
this.tableName = tableName;
return this;
}

/**
* @param tableName The database table name to use. If not specified, "embeddings" will be used
* @return builder
*/
public Builder inMemory(String tableName) {
return filePath(null);
}


public DuckDBEmbeddingStore build() {
return new DuckDBEmbeddingStore(filePath, tableName);
}
}

public String add(final Embedding embedding) {
String id = randomUUID();
add(id, embedding);
return id;
}

public void add(final String id, final Embedding embedding) {
addInternal(id, embedding, null);
}

public String add(final Embedding embedding, final TextSegment textSegment) {
String id = randomUUID();
addInternal(id, embedding, textSegment);
return id;
}

public List<String> addAll(final List<Embedding> embeddings) {
return addAll(embeddings, null);
}

public List<String> addAll(final List<Embedding> embeddings, final List<TextSegment> embedded) {
List<String> ids = embeddings.stream()
.map(ignored -> randomUUID())
.toList();
addAllInternal(ids, embeddings, embedded);
return ids;
}

@Override
public void removeAll(final Collection<String> ids) {
ensureNotEmpty(ids, "ids");
var idsParam = ids.stream().map(id -> "'" + id + "'").collect(Collectors.joining(","));
String sql = format(DELETE_BY_IDS_QUERY_TEMPLATE, tableName, idsParam);
try (var connection = duckDBConnection.duplicate(); var statement = connection.createStatement()) {
log.debug(sql);
statement.execute(sql);
} catch (SQLException e) {
throw new DuckDBSQLException("Unable to remove embeddings by ids", e);
}
}

@Override
public void removeAll(final Filter filter) {
ensureNotNull(filter, "filter");
var whereClause = jsonFilterMapper.map(filter);
String sql = format(DELETE_QUERY_TEMPLATE, tableName, whereClause);
try (var connection = duckDBConnection.duplicate(); var statement = connection.createStatement()) {
log.debug(sql);
statement.execute(sql);
} catch (SQLException e) {
throw new DuckDBSQLException("Unable to remove embeddings with filter", e);
}
}

@Override
public void removeAll() {
var sql = format(TRUNCATE_QUERY_TEMPLATE, tableName);
try (var connection = duckDBConnection.duplicate(); var statement = connection.createStatement()) {
statement.execute(sql);
} catch (SQLException e) {
throw new DuckDBSQLException("Unable to remove all embeddings", e);
}
}

@Override
public EmbeddingSearchResult<TextSegment> search(final EmbeddingSearchRequest request) {
try (var connection = duckDBConnection.duplicate(); var statement = connection.createStatement()) {
var matches = new ArrayList<EmbeddingMatch<TextSegment>>();
var param = embeddingToParam(request.queryEmbedding());
var filterClause = request.filter() != null ? "and " + jsonFilterMapper.map(request.filter()) : "";
var query = format(SEARCH_QUERY_TEMPLATE, param, tableName, request.minScore(), filterClause, request.maxResults());
log.debug(query);
var resultSet = statement.executeQuery(query);
while (resultSet.next()) {

var id = resultSet.getString("id");
var text = resultSet.getString("text");
var score = resultSet.getDouble("score");
var sqlArray = resultSet.getArray("embedding");
var metadataJson = resultSet.getString("metadata");

var typeReference = new TypeReference<HashMap<String, Object>>() {
};
Map<String, ?> metadataMap = metadataJson != null ? jsonMetadataSerializer.readValue(metadataJson, typeReference) : Collections.emptyMap();

var sqlList = (Object[]) sqlArray.getArray();
var vector = new float[sqlList.length];
for (int i = 0; i < sqlList.length; i++) {
vector[i] = (float) sqlList[i];
}
var ts = text != null ? TextSegment.from(text, Metadata.from(metadataMap)) : null;
matches.add(new EmbeddingMatch<>(score, id, new Embedding(vector), ts));
}
return new EmbeddingSearchResult<>(matches);
} catch (SQLException | JsonProcessingException e) {
throw new DuckDBSQLException("Error while searching embeddings", e);
}
}


private void addInternal(final String id, final Embedding embedding, final TextSegment textSegment) {
addAllInternal(singletonList(id), singletonList(embedding), embedding == null ? null : singletonList(textSegment));
}

private void addAllInternal(List<String> ids, List<Embedding> embeddings, List<TextSegment> embedded) {
if (isNullOrEmpty(ids) || isNullOrEmpty(embeddings)) {
log.info("[no embeddings to add to DuckDB]");
return;
}
ensureTrue(ids.size() == embeddings.size(), "ids size is not equal to embeddings size");
ensureTrue(embedded == null || embeddings.size() == embedded.size(), "embeddings size is not equal to embedded size");

try (var connection = duckDBConnection.duplicate(); var statement = connection.createStatement()) {
var values = new ArrayList<String>(ids.size());
for (int i = 0; i < ids.size(); i++) {
var text = "NULL";
if (embedded != null && embedded.get(i) != null) {
text = "'" + embedded.get(i).text() + "'";
}
var metadata = embedded != null && embedded.get(i) != null ? embedded.get(i).metadata().toMap() : null;
values.add(format("('%s',%s,%s,'%s')", ids.get(i), embeddingToParam(embeddings.get(i)), text, jsonMetadataSerializer.writeValueAsString(metadata)));
}

var sql = format(INSERT_QUERY_TEMPLATE, tableName, String.join(",", values));
log.debug(sql);
statement.execute(sql);
} catch (SQLException | JsonProcessingException e) {
throw new DuckDBSQLException("Unable to add embeddings in DuckDB", e);
}
}

private void initTable() {
var sql = format(CREATE_TABLE_TEMPLATE, tableName);
try (var connection = duckDBConnection.duplicate(); var statement = connection.createStatement()) {
log.debug(sql);
statement.execute(sql);
} catch (SQLException e) {
throw new DuckDBSQLException(format("Failed to init duckDB table: '%s'", sql), e);
}
}

protected String embeddingToParam(Embedding embedding) {
return embedding.vectorAsList()
.stream()
.map(Object::toString)
.collect(Collectors.joining(",", "[", "]"))
.concat("::float[]");
}
}

Loading