-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[FEATURE] Add DuckDB EmbeddingStore implementation (#2183) (#24)
* [FEATURE] Add DuckDB EmbeddingStore implementation (#2183) * formatting --------- Co-authored-by: Martin7-1 <[email protected]>
- Loading branch information
Showing
8 changed files
with
568 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
<?xml version="1.0" encoding="UTF-8"?> | ||
<project xmlns="http://maven.apache.org/POM/4.0.0" | ||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" | ||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> | ||
<modelVersion>4.0.0</modelVersion> | ||
<parent> | ||
<groupId>dev.langchain4j</groupId> | ||
<artifactId>langchain4j-community</artifactId> | ||
<version>0.37.0-SNAPSHOT</version> | ||
<relativePath>../../pom.xml</relativePath> | ||
</parent> | ||
|
||
<artifactId>langchain4j-community-duckdb</artifactId> | ||
<name>LangChain4j :: Community :: Integration :: DuckDB</name> | ||
|
||
<properties> | ||
<duckdb.version>1.1.3</duckdb.version> | ||
</properties> | ||
|
||
<dependencies> | ||
<dependency> | ||
<groupId>dev.langchain4j</groupId> | ||
<artifactId>langchain4j-core</artifactId> | ||
<version>${project.version}</version> | ||
</dependency> | ||
|
||
<dependency> | ||
<groupId>org.duckdb</groupId> | ||
<artifactId>duckdb_jdbc</artifactId> | ||
<version>${duckdb.version}</version> | ||
</dependency> | ||
|
||
<dependency> | ||
<groupId>com.fasterxml.jackson.core</groupId> | ||
<artifactId>jackson-databind</artifactId> | ||
</dependency> | ||
|
||
<dependency> | ||
<groupId>dev.langchain4j</groupId> | ||
<artifactId>langchain4j-core</artifactId> | ||
<version>${project.version}</version> | ||
<classifier>tests</classifier> | ||
<type>test-jar</type> | ||
<scope>test</scope> | ||
</dependency> | ||
|
||
<dependency> | ||
<groupId>org.junit.jupiter</groupId> | ||
<artifactId>junit-jupiter-engine</artifactId> | ||
<scope>test</scope> | ||
</dependency> | ||
|
||
<!-- junit-jupiter-params should be declared explicitly | ||
to run parameterized tests inherited from EmbeddingStore*IT--> | ||
<dependency> | ||
<groupId>org.junit.jupiter</groupId> | ||
<artifactId>junit-jupiter-params</artifactId> | ||
<scope>test</scope> | ||
</dependency> | ||
|
||
<dependency> | ||
<groupId>org.assertj</groupId> | ||
<artifactId>assertj-core</artifactId> | ||
<scope>test</scope> | ||
</dependency> | ||
|
||
<dependency> | ||
<groupId>org.awaitility</groupId> | ||
<artifactId>awaitility</artifactId> | ||
<scope>test</scope> | ||
</dependency> | ||
|
||
<dependency> | ||
<groupId>dev.langchain4j</groupId> | ||
<artifactId>langchain4j-embeddings-all-minilm-l6-v2-q</artifactId> | ||
<scope>test</scope> | ||
</dependency> | ||
</dependencies> | ||
</project> |
293 changes: 293 additions & 0 deletions
293
.../src/main/java/dev/langchain4j/community/store/embedding/duckdb/DuckDBEmbeddingStore.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,293 @@ | ||
package dev.langchain4j.community.store.embedding.duckdb; | ||
|
||
import com.fasterxml.jackson.core.JsonProcessingException; | ||
import com.fasterxml.jackson.core.type.TypeReference; | ||
import com.fasterxml.jackson.databind.ObjectMapper; | ||
import dev.langchain4j.data.document.Metadata; | ||
import dev.langchain4j.data.embedding.Embedding; | ||
import dev.langchain4j.data.segment.TextSegment; | ||
import dev.langchain4j.store.embedding.EmbeddingMatch; | ||
import dev.langchain4j.store.embedding.EmbeddingSearchRequest; | ||
import dev.langchain4j.store.embedding.EmbeddingSearchResult; | ||
import dev.langchain4j.store.embedding.EmbeddingStore; | ||
import dev.langchain4j.store.embedding.filter.Filter; | ||
import org.duckdb.DuckDBConnection; | ||
import org.slf4j.Logger; | ||
import org.slf4j.LoggerFactory; | ||
|
||
import java.sql.DriverManager; | ||
import java.sql.SQLException; | ||
import java.util.ArrayList; | ||
import java.util.Collection; | ||
import java.util.Collections; | ||
import java.util.HashMap; | ||
import java.util.List; | ||
import java.util.Map; | ||
import java.util.stream.Collectors; | ||
|
||
import static dev.langchain4j.internal.Utils.getOrDefault; | ||
import static dev.langchain4j.internal.Utils.isNullOrEmpty; | ||
import static dev.langchain4j.internal.Utils.randomUUID; | ||
import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty; | ||
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; | ||
import static dev.langchain4j.internal.ValidationUtils.ensureTrue; | ||
import static java.lang.String.format; | ||
import static java.util.Collections.singletonList; | ||
|
||
/** | ||
* Implementation of {@link EmbeddingStore} using <a href="https://duckdb.org/">DuckDB</a> | ||
* This implementation uses cosine distance and supports storing {@link Metadata} | ||
*/ | ||
public class DuckDBEmbeddingStore implements EmbeddingStore<TextSegment> { | ||
|
||
private static final Logger log = LoggerFactory.getLogger(DuckDBEmbeddingStore.class); | ||
|
||
private static final String CREATE_TABLE_TEMPLATE = """ | ||
create table if not exists %s (id UUID, embedding FLOAT[], text TEXT NULL, metadata JSON NULL); | ||
"""; | ||
|
||
private static final String SEARCH_QUERY_TEMPLATE = """ | ||
select *, (list_cosine_similarity(embedding,%s)+1.0)/2.0 as score | ||
from %s | ||
where score >= %s %s | ||
order by score DESC | ||
limit(%d) | ||
"""; | ||
|
||
private static final String INSERT_QUERY_TEMPLATE = """ | ||
insert into %s (id, embedding, text, metadata) values %s | ||
"""; | ||
|
||
private static final String DELETE_BY_IDS_QUERY_TEMPLATE = """ | ||
delete from %s where id in (%s) | ||
"""; | ||
|
||
private static final String DELETE_QUERY_TEMPLATE = """ | ||
delete from %s where %s | ||
"""; | ||
|
||
private static final String TRUNCATE_QUERY_TEMPLATE = """ | ||
truncate table %s | ||
"""; | ||
|
||
private final String tableName; | ||
private final DuckDBConnection duckDBConnection; | ||
private final DuckDBMetadataFilterMapper jsonFilterMapper = new DuckDBMetadataFilterMapper(); | ||
private final ObjectMapper jsonMetadataSerializer = new ObjectMapper(); | ||
|
||
|
||
/** | ||
* Initializes a new instance of DuckDBEmbeddingStore with the specified parameters. | ||
* | ||
* @param filePath File used to persist DuckDB database. If not specified, the database will be stored in-memory. | ||
* @param tableName The database table name to use. If not specified, "embeddings" will be used | ||
*/ | ||
public DuckDBEmbeddingStore(String filePath, String tableName) { | ||
try { | ||
var dbUrl = filePath != null ? "jdbc:duckdb:" + filePath : "jdbc:duckdb:"; | ||
this.tableName = getOrDefault(tableName, "embeddings"); | ||
this.duckDBConnection = (DuckDBConnection) DriverManager.getConnection(dbUrl); | ||
initTable(); | ||
} catch (SQLException e) { | ||
throw new DuckDBSQLException("Unable to load duckdb connection", e); | ||
} | ||
} | ||
|
||
|
||
/** | ||
* @return a new instance of DuckDBEmbeddingStore with the default configuration and database stored in-memory | ||
*/ | ||
public static DuckDBEmbeddingStore inMemory() { | ||
return new DuckDBEmbeddingStore(null, null); | ||
} | ||
|
||
public static Builder builder() { | ||
return new Builder(); | ||
} | ||
|
||
public static class Builder { | ||
private String filePath; | ||
private String tableName; | ||
|
||
/** | ||
* @param filePath File used to persist DuckDB database. If not specified, the database will be stored in-memory. | ||
* @return builder | ||
*/ | ||
public Builder filePath(String filePath) { | ||
this.filePath = filePath; | ||
return this; | ||
} | ||
|
||
/** | ||
* @param tableName The database table name to use. If not specified, "embeddings" will be used | ||
* @return builder | ||
*/ | ||
public Builder tableName(String tableName) { | ||
this.tableName = tableName; | ||
return this; | ||
} | ||
|
||
/** | ||
* @param tableName The database table name to use. If not specified, "embeddings" will be used | ||
* @return builder | ||
*/ | ||
public Builder inMemory(String tableName) { | ||
return filePath(null); | ||
} | ||
|
||
|
||
public DuckDBEmbeddingStore build() { | ||
return new DuckDBEmbeddingStore(filePath, tableName); | ||
} | ||
} | ||
|
||
public String add(final Embedding embedding) { | ||
String id = randomUUID(); | ||
add(id, embedding); | ||
return id; | ||
} | ||
|
||
public void add(final String id, final Embedding embedding) { | ||
addInternal(id, embedding, null); | ||
} | ||
|
||
public String add(final Embedding embedding, final TextSegment textSegment) { | ||
String id = randomUUID(); | ||
addInternal(id, embedding, textSegment); | ||
return id; | ||
} | ||
|
||
public List<String> addAll(final List<Embedding> embeddings) { | ||
return addAll(embeddings, null); | ||
} | ||
|
||
public List<String> addAll(final List<Embedding> embeddings, final List<TextSegment> embedded) { | ||
List<String> ids = embeddings.stream() | ||
.map(ignored -> randomUUID()) | ||
.toList(); | ||
addAllInternal(ids, embeddings, embedded); | ||
return ids; | ||
} | ||
|
||
@Override | ||
public void removeAll(final Collection<String> ids) { | ||
ensureNotEmpty(ids, "ids"); | ||
var idsParam = ids.stream().map(id -> "'" + id + "'").collect(Collectors.joining(",")); | ||
String sql = format(DELETE_BY_IDS_QUERY_TEMPLATE, tableName, idsParam); | ||
try (var connection = duckDBConnection.duplicate(); var statement = connection.createStatement()) { | ||
log.debug(sql); | ||
statement.execute(sql); | ||
} catch (SQLException e) { | ||
throw new DuckDBSQLException("Unable to remove embeddings by ids", e); | ||
} | ||
} | ||
|
||
@Override | ||
public void removeAll(final Filter filter) { | ||
ensureNotNull(filter, "filter"); | ||
var whereClause = jsonFilterMapper.map(filter); | ||
String sql = format(DELETE_QUERY_TEMPLATE, tableName, whereClause); | ||
try (var connection = duckDBConnection.duplicate(); var statement = connection.createStatement()) { | ||
log.debug(sql); | ||
statement.execute(sql); | ||
} catch (SQLException e) { | ||
throw new DuckDBSQLException("Unable to remove embeddings with filter", e); | ||
} | ||
} | ||
|
||
@Override | ||
public void removeAll() { | ||
var sql = format(TRUNCATE_QUERY_TEMPLATE, tableName); | ||
try (var connection = duckDBConnection.duplicate(); var statement = connection.createStatement()) { | ||
statement.execute(sql); | ||
} catch (SQLException e) { | ||
throw new DuckDBSQLException("Unable to remove all embeddings", e); | ||
} | ||
} | ||
|
||
@Override | ||
public EmbeddingSearchResult<TextSegment> search(final EmbeddingSearchRequest request) { | ||
try (var connection = duckDBConnection.duplicate(); var statement = connection.createStatement()) { | ||
var matches = new ArrayList<EmbeddingMatch<TextSegment>>(); | ||
var param = embeddingToParam(request.queryEmbedding()); | ||
var filterClause = request.filter() != null ? "and " + jsonFilterMapper.map(request.filter()) : ""; | ||
var query = format(SEARCH_QUERY_TEMPLATE, param, tableName, request.minScore(), filterClause, request.maxResults()); | ||
log.debug(query); | ||
var resultSet = statement.executeQuery(query); | ||
while (resultSet.next()) { | ||
|
||
var id = resultSet.getString("id"); | ||
var text = resultSet.getString("text"); | ||
var score = resultSet.getDouble("score"); | ||
var sqlArray = resultSet.getArray("embedding"); | ||
var metadataJson = resultSet.getString("metadata"); | ||
|
||
var typeReference = new TypeReference<HashMap<String, Object>>() { | ||
}; | ||
Map<String, ?> metadataMap = metadataJson != null ? jsonMetadataSerializer.readValue(metadataJson, typeReference) : Collections.emptyMap(); | ||
|
||
var sqlList = (Object[]) sqlArray.getArray(); | ||
var vector = new float[sqlList.length]; | ||
for (int i = 0; i < sqlList.length; i++) { | ||
vector[i] = (float) sqlList[i]; | ||
} | ||
var ts = text != null ? TextSegment.from(text, Metadata.from(metadataMap)) : null; | ||
matches.add(new EmbeddingMatch<>(score, id, new Embedding(vector), ts)); | ||
} | ||
return new EmbeddingSearchResult<>(matches); | ||
} catch (SQLException | JsonProcessingException e) { | ||
throw new DuckDBSQLException("Error while searching embeddings", e); | ||
} | ||
} | ||
|
||
|
||
private void addInternal(final String id, final Embedding embedding, final TextSegment textSegment) { | ||
addAllInternal(singletonList(id), singletonList(embedding), embedding == null ? null : singletonList(textSegment)); | ||
} | ||
|
||
private void addAllInternal(List<String> ids, List<Embedding> embeddings, List<TextSegment> embedded) { | ||
if (isNullOrEmpty(ids) || isNullOrEmpty(embeddings)) { | ||
log.info("[no embeddings to add to DuckDB]"); | ||
return; | ||
} | ||
ensureTrue(ids.size() == embeddings.size(), "ids size is not equal to embeddings size"); | ||
ensureTrue(embedded == null || embeddings.size() == embedded.size(), "embeddings size is not equal to embedded size"); | ||
|
||
try (var connection = duckDBConnection.duplicate(); var statement = connection.createStatement()) { | ||
var values = new ArrayList<String>(ids.size()); | ||
for (int i = 0; i < ids.size(); i++) { | ||
var text = "NULL"; | ||
if (embedded != null && embedded.get(i) != null) { | ||
text = "'" + embedded.get(i).text() + "'"; | ||
} | ||
var metadata = embedded != null && embedded.get(i) != null ? embedded.get(i).metadata().toMap() : null; | ||
values.add(format("('%s',%s,%s,'%s')", ids.get(i), embeddingToParam(embeddings.get(i)), text, jsonMetadataSerializer.writeValueAsString(metadata))); | ||
} | ||
|
||
var sql = format(INSERT_QUERY_TEMPLATE, tableName, String.join(",", values)); | ||
log.debug(sql); | ||
statement.execute(sql); | ||
} catch (SQLException | JsonProcessingException e) { | ||
throw new DuckDBSQLException("Unable to add embeddings in DuckDB", e); | ||
} | ||
} | ||
|
||
private void initTable() { | ||
var sql = format(CREATE_TABLE_TEMPLATE, tableName); | ||
try (var connection = duckDBConnection.duplicate(); var statement = connection.createStatement()) { | ||
log.debug(sql); | ||
statement.execute(sql); | ||
} catch (SQLException e) { | ||
throw new DuckDBSQLException(format("Failed to init duckDB table: '%s'", sql), e); | ||
} | ||
} | ||
|
||
protected String embeddingToParam(Embedding embedding) { | ||
return embedding.vectorAsList() | ||
.stream() | ||
.map(Object::toString) | ||
.collect(Collectors.joining(",", "[", "]")) | ||
.concat("::float[]"); | ||
} | ||
} | ||
|
Oops, something went wrong.