diff --git a/embedding-stores/langchain4j-community-vearch/pom.xml b/embedding-stores/langchain4j-community-vearch/pom.xml new file mode 100644 index 0000000..9df8532 --- /dev/null +++ b/embedding-stores/langchain4j-community-vearch/pom.xml @@ -0,0 +1,117 @@ + + + 4.0.0 + + dev.langchain4j + langchain4j-community + 0.37.0-SNAPSHOT + ../../pom.xml + + + langchain4j-community-vearch + LangChain4j :: Community :: Integration :: Vearch + + + + dev.langchain4j + langchain4j-core + ${project.version} + + + + com.squareup.retrofit2 + retrofit + + + + com.squareup.retrofit2 + converter-jackson + + + + + com.squareup.okhttp3 + okhttp + + + + org.jetbrains.kotlin + kotlin-stdlib-jdk8 + + + + + org.slf4j + slf4j-api + + + + org.junit.jupiter + junit-jupiter-engine + test + + + + + org.junit.jupiter + junit-jupiter-params + test + + + + org.assertj + assertj-core + test + + + + dev.langchain4j + langchain4j-embeddings-all-minilm-l6-v2-q + test + + + + dev.langchain4j + langchain4j-core + ${project.version} + tests + test-jar + test + + + + org.testcontainers + testcontainers + test + + + + org.testcontainers + junit-jupiter + test + + + + org.tinylog + tinylog-impl + test + + + + org.tinylog + slf4j-tinylog + test + + + + org.awaitility + awaitility + test + + + + + diff --git a/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/CreateDatabaseResponse.java b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/CreateDatabaseResponse.java new file mode 100644 index 0000000..58352f6 --- /dev/null +++ b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/CreateDatabaseResponse.java @@ -0,0 +1,41 @@ +package dev.langchain4j.community.store.embedding.vearch; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import 
com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonInclude(NON_NULL) +@JsonNaming(SnakeCaseStrategy.class) +class CreateDatabaseResponse { + + private Long id; + private String name; + + CreateDatabaseResponse() { + } + + CreateDatabaseResponse(Long id, String name) { + this.id = id; + this.name = name; + } + + public Long getId() { + return id; + } + + public void setId(Long id) { + this.id = id; + } + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } +} diff --git a/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/CreateSpaceRequest.java b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/CreateSpaceRequest.java new file mode 100644 index 0000000..df13cda --- /dev/null +++ b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/CreateSpaceRequest.java @@ -0,0 +1,84 @@ +package dev.langchain4j.community.store.embedding.vearch; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import dev.langchain4j.community.store.embedding.vearch.field.Field; + +import java.util.List; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonInclude(NON_NULL) +@JsonNaming(SnakeCaseStrategy.class) +class CreateSpaceRequest { + + private String name; + private Integer partitionNum; + private Integer replicaNum; + private List fields; + + 
CreateSpaceRequest() { + } + + CreateSpaceRequest(String name, Integer partitionNum, Integer replicaNum, List fields) { + this.name = name; + this.partitionNum = partitionNum; + this.replicaNum = replicaNum; + this.fields = fields; + } + + public String getName() { + return name; + } + + public Integer getPartitionNum() { + return partitionNum; + } + + public Integer getReplicaNum() { + return replicaNum; + } + + public List getFields() { + return fields; + } + + static Builder builder() { + return new Builder(); + } + + static class Builder { + + private String name; + private Integer partitionNum; + private Integer replicaNum; + private List fields; + + Builder name(String name) { + this.name = name; + return this; + } + + Builder partitionNum(Integer partitionNum) { + this.partitionNum = partitionNum; + return this; + } + + Builder replicaNum(Integer replicaNum) { + this.replicaNum = replicaNum; + return this; + } + + Builder fields(List fields) { + this.fields = fields; + return this; + } + + CreateSpaceRequest build() { + return new CreateSpaceRequest(name, partitionNum, replicaNum, fields); + } + } +} diff --git a/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/CreateSpaceResponse.java b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/CreateSpaceResponse.java new file mode 100644 index 0000000..d50d0dd --- /dev/null +++ b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/CreateSpaceResponse.java @@ -0,0 +1,41 @@ +package dev.langchain4j.community.store.embedding.vearch; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; + +import static 
com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonInclude(NON_NULL) +@JsonNaming(SnakeCaseStrategy.class) +class CreateSpaceResponse { + + private Integer id; + private String name; + + CreateSpaceResponse() { + } + + CreateSpaceResponse(Integer id, String name) { + this.id = id; + this.name = name; + } + + public Integer getId() { + return id; + } + + public void setId(Integer id) { + this.id = id; + } + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } +} diff --git a/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/ListDatabaseResponse.java b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/ListDatabaseResponse.java new file mode 100644 index 0000000..be8d80e --- /dev/null +++ b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/ListDatabaseResponse.java @@ -0,0 +1,41 @@ +package dev.langchain4j.community.store.embedding.vearch; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonInclude(NON_NULL) +@JsonNaming(SnakeCaseStrategy.class) +class ListDatabaseResponse { + + private Integer id; + private String name; + + ListDatabaseResponse() { + } + + ListDatabaseResponse(Integer id, String name) { + this.id = id; + this.name = name; + } + + public Integer getId() { + return id; + } + + public void setId(Integer id) { + this.id = id; + } + + public String getName() { + return name; + } + + public void setName(String name) { + 
this.name = name; + } +} diff --git a/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/ListSpaceResponse.java b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/ListSpaceResponse.java new file mode 100644 index 0000000..9c0ac21 --- /dev/null +++ b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/ListSpaceResponse.java @@ -0,0 +1,42 @@ +package dev.langchain4j.community.store.embedding.vearch; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonInclude(NON_NULL) +@JsonNaming(SnakeCaseStrategy.class) +public class ListSpaceResponse { + + private Integer id; + private String name; + + ListSpaceResponse() { + + } + + ListSpaceResponse(Integer id, String name) { + this.id = id; + this.name = name; + } + + public Integer getId() { + return id; + } + + public void setId(Integer id) { + this.id = id; + } + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } +} diff --git a/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/MetricType.java b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/MetricType.java new file mode 100644 index 0000000..0077b2a --- /dev/null +++ b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/MetricType.java @@ -0,0 +1,17 @@ +package dev.langchain4j.community.store.embedding.vearch; + +import 
com.fasterxml.jackson.annotation.JsonProperty; + +/** + * If the metric type is not set when searching, the metric type specified when the space was built is used. + * + *

LangChain4j currently supports only {@link MetricType#INNER_PRODUCT}.

+ */ +public enum MetricType { + + /** + * Inner Product + */ + @JsonProperty("InnerProduct") + INNER_PRODUCT +} diff --git a/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/ResponseWrapper.java b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/ResponseWrapper.java new file mode 100644 index 0000000..2511e08 --- /dev/null +++ b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/ResponseWrapper.java @@ -0,0 +1,51 @@ +package dev.langchain4j.community.store.embedding.vearch; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonInclude(NON_NULL) +@JsonNaming(SnakeCaseStrategy.class) +class ResponseWrapper { + + private Integer code; + private String msg; + private T data; + + ResponseWrapper() { + } + + ResponseWrapper(Integer code, String msg, T data) { + this.code = code; + this.msg = msg; + this.data = data; + } + + public Integer getCode() { + return code; + } + + public void setCode(Integer code) { + this.code = code; + } + + public String getMsg() { + return msg; + } + + public void setMsg(String msg) { + this.msg = msg; + } + + public T getData() { + return data; + } + + public void setData(T data) { + this.data = data; + } +} diff --git a/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/SearchRequest.java b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/SearchRequest.java new file mode 100644 index 0000000..a858b3e --- 
/dev/null +++ b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/SearchRequest.java @@ -0,0 +1,230 @@ +package dev.langchain4j.community.store.embedding.vearch; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import dev.langchain4j.community.store.embedding.vearch.index.search.SearchIndexParam; + +import java.util.List; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonInclude(NON_NULL) +@JsonNaming(SnakeCaseStrategy.class) +class SearchRequest { + + private String dbName; + private String spaceName; + private List vectors; + private List fields; + private Boolean vectorValue; + private Integer limit; + private SearchIndexParam indexParams; + + SearchRequest() { + } + + SearchRequest(Builder builder) { + this.dbName = builder.dbName; + this.spaceName = builder.spaceName; + this.vectors = builder.vectors; + this.fields = builder.fields; + this.vectorValue = builder.vectorValue; + this.limit = builder.limit; + this.indexParams = builder.indexParams; + } + + public String getDbName() { + return dbName; + } + + public String getSpaceName() { + return spaceName; + } + + public List getVectors() { + return vectors; + } + + public List getFields() { + return fields; + } + + public Boolean getVectorValue() { + return vectorValue; + } + + public Integer getLimit() { + return limit; + } + + public SearchIndexParam getIndexParams() { + return indexParams; + } + + static Builder builder() { + return new Builder(); + } + + static class Builder { + + private String dbName; + private String spaceName; + private List vectors; + private List fields; + private Boolean vectorValue; + private Integer limit; + private 
SearchIndexParam indexParams; + + Builder dbName(String dbName) { + this.dbName = dbName; + return this; + } + + Builder spaceName(String spaceName) { + this.spaceName = spaceName; + return this; + } + + Builder vectors(List vectors) { + this.vectors = vectors; + return this; + } + + Builder fields(List fields) { + this.fields = fields; + return this; + } + + Builder vectorValue(Boolean vectorValue) { + this.vectorValue = vectorValue; + return this; + } + + Builder limit(Integer limit) { + this.limit = limit; + return this; + } + + Builder indexParams(SearchIndexParam indexParams) { + this.indexParams = indexParams; + return this; + } + + SearchRequest build() { + return new SearchRequest(this); + } + } + + @JsonIgnoreProperties(ignoreUnknown = true) + @JsonInclude(NON_NULL) + @JsonNaming(SnakeCaseStrategy.class) + static class Vector { + + private String field; + private List feature; + private Double minScore; + + Vector() { + } + + Vector(String field, List feature, Double minScore) { + this.field = field; + this.feature = feature; + this.minScore = minScore; + } + + public String getField() { + return field; + } + + public List getFeature() { + return feature; + } + + public Double getMinScore() { + return minScore; + } + + static Builder builder() { + return new Builder(); + } + + static class Builder { + + private String field; + private List feature; + private Double minScore; + + Builder field(String field) { + this.field = field; + return this; + } + + Builder feature(List feature) { + this.feature = feature; + return this; + } + + Builder minScore(Double minScore) { + this.minScore = minScore; + return this; + } + + Vector build() { + return new Vector(field, feature, minScore); + } + } + } + + @JsonIgnoreProperties(ignoreUnknown = true) + @JsonInclude(NON_NULL) + @JsonNaming(SnakeCaseStrategy.class) + static class RankerParam { + + private String type; + private List params; + + RankerParam() { + } + + RankerParam(String type, List params) { + this.type 
= type; + this.params = params; + } + + public String getType() { + return type; + } + + public List getParams() { + return params; + } + + static Builder builder() { + return new Builder(); + } + + static class Builder { + + private String type; + private List params; + + Builder type(String type) { + this.type = type; + return this; + } + + Builder params(List params) { + this.params = params; + return this; + } + + RankerParam build() { + return new RankerParam(type, params); + } + } + } +} diff --git a/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/SearchResponse.java b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/SearchResponse.java new file mode 100644 index 0000000..37e5ccc --- /dev/null +++ b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/SearchResponse.java @@ -0,0 +1,34 @@ +package dev.langchain4j.community.store.embedding.vearch; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; + +import java.util.List; +import java.util.Map; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonInclude(NON_NULL) +@JsonNaming(SnakeCaseStrategy.class) +class SearchResponse { + + private List>> documents; + + SearchResponse() { + } + + SearchResponse(List>> documents) { + this.documents = documents; + } + + public List>> getDocuments() { + return documents; + } + + public void setDocuments(List>> documents) { + this.documents = documents; + } +} diff --git a/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/StoreParam.java 
b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/StoreParam.java new file mode 100644 index 0000000..881de9d --- /dev/null +++ b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/StoreParam.java @@ -0,0 +1,48 @@ +package dev.langchain4j.community.store.embedding.vearch; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonInclude(NON_NULL) +@JsonNaming(SnakeCaseStrategy.class) +public class StoreParam { + + /** + * It means you will use so much memory, the excess will be kept to disk. For MemoryOnly, this parameter is invalid. + */ + private Integer cacheSize; + + public StoreParam() { + } + + public StoreParam(Integer cacheSize) { + this.cacheSize = cacheSize; + } + + public Integer getCacheSize() { + return cacheSize; + } + + public void setCacheSize(Integer cacheSize) { + this.cacheSize = cacheSize; + } + + public static class Builder { + + private Integer cacheSize; + + public Builder cacheSize(Integer cacheSize) { + this.cacheSize = cacheSize; + return this; + } + + public StoreParam build() { + return new StoreParam(cacheSize); + } + } +} diff --git a/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/StoreType.java b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/StoreType.java new file mode 100644 index 0000000..75ec902 --- /dev/null +++ b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/StoreType.java @@ -0,0 +1,11 @@ 
+package dev.langchain4j.community.store.embedding.vearch; + +import com.fasterxml.jackson.annotation.JsonProperty; + +public enum StoreType { + + @JsonProperty("MemoryOnly") + MEMORY_ONLY, + @JsonProperty("RocksDB") + ROCKS_DB +} diff --git a/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/UpsertRequest.java b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/UpsertRequest.java new file mode 100644 index 0000000..2aeef70 --- /dev/null +++ b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/UpsertRequest.java @@ -0,0 +1,72 @@ +package dev.langchain4j.community.store.embedding.vearch; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies; +import com.fasterxml.jackson.databind.annotation.JsonNaming; + +import java.util.List; +import java.util.Map; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonInclude(NON_NULL) +@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) +public class UpsertRequest { + + private String dbName; + private String spaceName; + private List> documents; + + UpsertRequest() { + } + + UpsertRequest(String dbName, String spaceName, List> documents) { + this.dbName = dbName; + this.spaceName = spaceName; + this.documents = documents; + } + + public String getDbName() { + return dbName; + } + + public String getSpaceName() { + return spaceName; + } + + public List> getDocuments() { + return documents; + } + + static Builder builder() { + return new Builder(); + } + + static class Builder { + + private String dbName; + private String spaceName; + private List> documents; + + Builder dbName(String dbName) { + this.dbName = dbName; + return this; + } + + 
Builder spaceName(String spaceName) { + this.spaceName = spaceName; + return this; + } + + Builder documents(List> documents) { + this.documents = documents; + return this; + } + + UpsertRequest build() { + return new UpsertRequest(dbName, spaceName, documents); + } + } +} diff --git a/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/UpsertResponse.java b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/UpsertResponse.java new file mode 100644 index 0000000..545d26a --- /dev/null +++ b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/UpsertResponse.java @@ -0,0 +1,66 @@ +package dev.langchain4j.community.store.embedding.vearch; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; + +import java.util.List; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; +import static dev.langchain4j.community.store.embedding.vearch.VearchConfig.DEFAULT_ID_FIELD_NAME; + +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonInclude(NON_NULL) +@JsonNaming(SnakeCaseStrategy.class) +class UpsertResponse { + + private Integer total; + private List documentIds; + + UpsertResponse() { + } + + UpsertResponse(Integer total, List documentIds) { + this.total = total; + this.documentIds = documentIds; + } + + public Integer getTotal() { + return total; + } + + public void setTotal(Integer total) { + this.total = total; + } + + public List getDocumentIds() { + return documentIds; + } + + public void setDocumentIds(List documentIds) { + this.documentIds = documentIds; + } + + static class DocumentInfo { + + 
@JsonProperty(DEFAULT_ID_FIELD_NAME) + private String id; + + DocumentInfo() { + } + + DocumentInfo(String id) { + this.id = id; + } + + public String getId() { + return id; + } + + public void setId(String id) { + this.id = id; + } + } +} diff --git a/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/VearchApi.java b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/VearchApi.java new file mode 100644 index 0000000..5c5bcfe --- /dev/null +++ b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/VearchApi.java @@ -0,0 +1,50 @@ +package dev.langchain4j.community.store.embedding.vearch; + +import retrofit2.Call; +import retrofit2.http.Body; +import retrofit2.http.DELETE; +import retrofit2.http.GET; +import retrofit2.http.Headers; +import retrofit2.http.POST; +import retrofit2.http.Path; + +import java.util.List; + +interface VearchApi { + + int OK = 0; + + /* Database Operation */ + + @GET("dbs") + @Headers({"accept: application/json"}) + Call>> listDatabase(); + + @POST("dbs/{dbName}") + @Headers({"accept: application/json", "content-type: application/json"}) + Call> createDatabase(@Path("dbName") String dbName); + + @GET("dbs/{dbName}/spaces") + @Headers({"accept: application/json"}) + Call>> listSpaceOfDatabase(@Path("dbName") String dbName); + + /* Space (like a table in relational database) Operation */ + + @POST("dbs/{dbName}/spaces") + @Headers({"accept: application/json", "content-type: application/json"}) + Call> createSpace(@Path("dbName") String dbName, + @Body CreateSpaceRequest request); + + @DELETE("dbs/{dbName}/spaces/{spaceName}") + Call deleteSpace(@Path("dbName") String dbName, @Path("spaceName") String spaceName); + + /* Document Operation */ + + @POST("document/upsert") + @Headers({"accept: application/json", "content-type: application/json"}) + Call> upsert(@Body 
UpsertRequest request); + + @POST("document/search") + @Headers({"accept: application/json", "content-type: application/json"}) + Call> search(@Body SearchRequest request); +} diff --git a/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/VearchClient.java b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/VearchClient.java new file mode 100644 index 0000000..5f6370a --- /dev/null +++ b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/VearchClient.java @@ -0,0 +1,223 @@ +package dev.langchain4j.community.store.embedding.vearch; + +import com.fasterxml.jackson.databind.ObjectMapper; +import dev.langchain4j.internal.Utils; +import okhttp3.OkHttpClient; +import retrofit2.Response; +import retrofit2.Retrofit; +import retrofit2.converter.jackson.JacksonConverterFactory; + +import java.io.IOException; +import java.time.Duration; +import java.util.List; + +import static com.fasterxml.jackson.databind.SerializationFeature.INDENT_OUTPUT; +import static dev.langchain4j.community.store.embedding.vearch.VearchApi.OK; + +class VearchClient { + + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper() + .enable(INDENT_OUTPUT); + + private final VearchApi vearchApi; + + public VearchClient(String baseUrl, + Duration timeout, + boolean logRequests, + boolean logResponses) { + OkHttpClient.Builder okHttpClientBuilder = new OkHttpClient.Builder() + .callTimeout(timeout) + .connectTimeout(timeout) + .readTimeout(timeout) + .writeTimeout(timeout); + + if (logRequests) { + okHttpClientBuilder.addInterceptor(new VearchRequestLoggingInterceptor()); + } + if (logResponses) { + okHttpClientBuilder.addInterceptor(new VearchResponseLoggingInterceptor()); + } + + Retrofit retrofit = new Retrofit.Builder() + .baseUrl(Utils.ensureTrailingForwardSlash(baseUrl)) + .client(okHttpClientBuilder.build()) + 
.addConverterFactory(JacksonConverterFactory.create(OBJECT_MAPPER)) + .build(); + + vearchApi = retrofit.create(VearchApi.class); + } + + public List listDatabase() { + try { + Response>> response = vearchApi.listDatabase().execute(); + + if (response.isSuccessful() && response.body() != null) { + ResponseWrapper> wrapper = response.body(); + if (wrapper.getCode() != OK) { + throw toException(wrapper); + } + return wrapper.getData(); + } else { + throw toException(response); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + public CreateDatabaseResponse createDatabase(String databaseName) { + try { + Response> response = vearchApi.createDatabase(databaseName).execute(); + + if (response.isSuccessful() && response.body() != null) { + ResponseWrapper wrapper = response.body(); + if (wrapper.getCode() != OK) { + throw toException(wrapper); + } + return wrapper.getData(); + } else { + throw toException(response); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + public List listSpaceOfDatabase(String dbName) { + try { + Response>> response = vearchApi.listSpaceOfDatabase(dbName).execute(); + + if (response.isSuccessful() && response.body() != null) { + ResponseWrapper> wrapper = response.body(); + if (wrapper.getCode() != OK) { + throw toException(wrapper); + } + return wrapper.getData(); + } else { + throw toException(response); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + public CreateSpaceResponse createSpace(String dbName, CreateSpaceRequest request) { + try { + Response> response = vearchApi.createSpace(dbName, request).execute(); + + if (response.isSuccessful() && response.body() != null) { + ResponseWrapper wrapper = response.body(); + if (wrapper.getCode() != OK) { + throw toException(wrapper); + } + return wrapper.getData(); + } else { + throw toException(response); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + public void upsert(UpsertRequest 
request) { + try { + Response> response = vearchApi.upsert(request).execute(); + + if (response.isSuccessful() && response.body() != null) { + ResponseWrapper wrapper = response.body(); + if (wrapper.getCode() != OK) { + throw toException(wrapper); + } + } else { + throw toException(response); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + public SearchResponse search(SearchRequest request) { + try { + Response> response = vearchApi.search(request).execute(); + + if (response.isSuccessful() && response.body() != null) { + ResponseWrapper wrapper = response.body(); + if (wrapper.getCode() != OK) { + throw toException(wrapper); + } + + return wrapper.getData(); + } else { + throw toException(response); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + public void deleteSpace(String databaseName, String spaceName) { + try { + Response response = vearchApi.deleteSpace(databaseName, spaceName).execute(); + + if (!response.isSuccessful()) { + throw toException(response); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + static Builder builder() { + return new Builder(); + } + + static final class Builder { + + private String baseUrl; + private Duration timeout; + private boolean logRequests; + private boolean logResponses; + + Builder baseUrl(String baseUrl) { + this.baseUrl = baseUrl; + return this; + } + + Builder timeout(Duration timeout) { + this.timeout = timeout; + return this; + } + + Builder logRequests(boolean logRequests) { + this.logRequests = logRequests; + return this; + } + + Builder logResponses(boolean logResponses) { + this.logResponses = logResponses; + return this; + } + + VearchClient build() { + return new VearchClient(baseUrl, timeout, logRequests, logResponses); + } + } + + private RuntimeException toException(Response response) throws IOException { + int code = response.code(); + String body = response.errorBody().string(); + + String errorMessage = 
String.format("status code: %s; body: %s", code, body); + return new RuntimeException(errorMessage); + } + + private RuntimeException toException(ResponseWrapper responseWrapper) { + return toException(responseWrapper.getCode(), responseWrapper.getMsg()); + } + + private RuntimeException toException(int code, String msg) { + String errorMessage = String.format("code: %s; message: %s", code, msg); + + return new RuntimeException(errorMessage); + } +} diff --git a/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/VearchConfig.java b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/VearchConfig.java new file mode 100644 index 0000000..cdfe3ac --- /dev/null +++ b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/VearchConfig.java @@ -0,0 +1,181 @@ +package dev.langchain4j.community.store.embedding.vearch; + +import dev.langchain4j.community.store.embedding.vearch.field.Field; +import dev.langchain4j.community.store.embedding.vearch.field.FieldType; +import dev.langchain4j.community.store.embedding.vearch.field.StringField; +import dev.langchain4j.community.store.embedding.vearch.field.VectorField; +import dev.langchain4j.community.store.embedding.vearch.index.HNSWParam; +import dev.langchain4j.community.store.embedding.vearch.index.Index; +import dev.langchain4j.community.store.embedding.vearch.index.IndexType; +import dev.langchain4j.community.store.embedding.vearch.index.search.SearchIndexParam; + +import java.util.List; + +import static dev.langchain4j.internal.Utils.getOrDefault; +import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; + +public class VearchConfig { + + static final String DEFAULT_ID_FIELD_NAME = "_id"; + static final String DEFAULT_EMBEDDING_FIELD_NAME = "embedding"; + static final String DEFAULT_TEXT_FIELD_NAME = "text"; + static final String 
DEFAULT_SCORE_FILED_NAME = "_score"; + + static final List DEFAULT_FIELDS = List.of( + VectorField.builder() + .name(DEFAULT_EMBEDDING_FIELD_NAME) + .dimension(384) + .index(Index.builder() + .name("gamma") + .type(IndexType.HNSW) + .params(HNSWParam.builder() + .metricType(MetricType.INNER_PRODUCT) + .efSearch(64) + .build()) + .build()).build(), + StringField.builder() + .fieldType(FieldType.STRING) + .name(DEFAULT_TEXT_FIELD_NAME) + .build() + ); + + private String databaseName; + private String spaceName; + /** + * Index param when searching, if not set, will use {@link Index}. + * + * @see Index + */ + private SearchIndexParam searchIndexParam; + /** + * This attribute's key set should contain + * {@link VearchConfig#embeddingFieldName}, {@link VearchConfig#textFieldName} and {@link VearchConfig#metadataFieldNames} + */ + private List fields; + private String embeddingFieldName; + private String textFieldName; + /** + * This attribute should be the subset of {@link VearchConfig#fields}'s key set + */ + private List metadataFieldNames; + + public VearchConfig(Builder builder) { + this.databaseName = ensureNotNull(builder.databaseName, "databaseName"); + this.spaceName = ensureNotNull(builder.spaceName, "spaceName"); + this.searchIndexParam = builder.searchIndexParam; + this.fields = getOrDefault(builder.fields, DEFAULT_FIELDS); + this.embeddingFieldName = getOrDefault(builder.embeddingFieldName, DEFAULT_EMBEDDING_FIELD_NAME); + this.textFieldName = getOrDefault(builder.textFieldName, DEFAULT_TEXT_FIELD_NAME); + this.metadataFieldNames = builder.metadataFieldNames; + } + + public String getDatabaseName() { + return databaseName; + } + + public void setDatabaseName(String databaseName) { + this.databaseName = databaseName; + } + + public String getSpaceName() { + return spaceName; + } + + public void setSpaceName(String spaceName) { + this.spaceName = spaceName; + } + + public SearchIndexParam getSearchIndexParam() { + return searchIndexParam; + } + + public void 
setSearchIndexParam(SearchIndexParam searchIndexParam) { + this.searchIndexParam = searchIndexParam; + } + + public List getFields() { + return fields; + } + + public void setFields(List fields) { + this.fields = fields; + } + + public String getEmbeddingFieldName() { + return embeddingFieldName; + } + + public void setEmbeddingFieldName(String embeddingFieldName) { + this.embeddingFieldName = embeddingFieldName; + } + + public String getTextFieldName() { + return textFieldName; + } + + public void setTextFieldName(String textFieldName) { + this.textFieldName = textFieldName; + } + + public List getMetadataFieldNames() { + return metadataFieldNames; + } + + public void setMetadataFieldNames(List metadataFieldNames) { + this.metadataFieldNames = metadataFieldNames; + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private String databaseName; + private String spaceName; + private SearchIndexParam searchIndexParam; + private List fields; + private String embeddingFieldName; + private String textFieldName; + private List metadataFieldNames; + + public Builder databaseName(String databaseName) { + this.databaseName = databaseName; + return this; + } + + public Builder spaceName(String spaceName) { + this.spaceName = spaceName; + return this; + } + + public Builder searchIndexParam(SearchIndexParam searchIndexParam) { + this.searchIndexParam = searchIndexParam; + return this; + } + + public Builder fields(List fields) { + this.fields = fields; + return this; + } + + public Builder embeddingFieldName(String embeddingFieldName) { + this.embeddingFieldName = embeddingFieldName; + return this; + } + + public Builder textFieldName(String textFieldName) { + this.textFieldName = textFieldName; + return this; + } + + public Builder metadataFieldNames(List metadataFieldNames) { + this.metadataFieldNames = metadataFieldNames; + return this; + } + + public VearchConfig build() { + return new VearchConfig(this); + } + } +} 
diff --git a/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/VearchEmbeddingStore.java b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/VearchEmbeddingStore.java new file mode 100644 index 0000000..6611462 --- /dev/null +++ b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/VearchEmbeddingStore.java @@ -0,0 +1,294 @@ +package dev.langchain4j.community.store.embedding.vearch; + +import dev.langchain4j.data.document.Metadata; +import dev.langchain4j.data.embedding.Embedding; +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.store.embedding.EmbeddingMatch; +import dev.langchain4j.store.embedding.EmbeddingSearchRequest; +import dev.langchain4j.store.embedding.EmbeddingSearchResult; +import dev.langchain4j.store.embedding.EmbeddingStore; +import dev.langchain4j.store.embedding.RelevanceScore; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static dev.langchain4j.community.store.embedding.vearch.VearchConfig.DEFAULT_ID_FIELD_NAME; +import static dev.langchain4j.community.store.embedding.vearch.VearchConfig.DEFAULT_SCORE_FILED_NAME; +import static dev.langchain4j.internal.Utils.getOrDefault; +import static dev.langchain4j.internal.Utils.isNullOrBlank; +import static dev.langchain4j.internal.Utils.isNullOrEmpty; +import static dev.langchain4j.internal.Utils.randomUUID; +import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty; +import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; +import static dev.langchain4j.internal.ValidationUtils.ensureTrue; +import static dev.langchain4j.store.embedding.CosineSimilarity.fromRelevanceScore; +import static java.time.Duration.ofSeconds; +import static 
java.util.Collections.singletonList; +import static java.util.stream.Collectors.toList; + +/** + * Represent a Vearch index as an {@link EmbeddingStore}. + * + *

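The current implementation interprets the scores returned by Vearch as cosine similarities, so the index metric should yield cosine-compatible scores (for example, inner product over normalized embeddings).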

+ * + *

Supported Vearch versions: 3.4.x and 3.5.x.

+ */ +public class VearchEmbeddingStore implements EmbeddingStore { + + private final VearchConfig vearchConfig; + private final VearchClient vearchClient; + /** + * whether to normalize embedding when add to embedding store + */ + private final boolean normalizeEmbeddings; + + public VearchEmbeddingStore(String baseUrl, + Duration timeout, + VearchConfig vearchConfig, + boolean normalizeEmbeddings, + boolean logRequests, + boolean logResponses) { + // Step 0: initialize some attribute + baseUrl = ensureNotNull(baseUrl, "baseUrl"); + this.vearchConfig = ensureNotNull(vearchConfig, "vearchConfig"); + this.normalizeEmbeddings = normalizeEmbeddings; + + vearchClient = VearchClient.builder() + .baseUrl(baseUrl) + .timeout(getOrDefault(timeout, ofSeconds(60))) + .logRequests(logRequests) + .logResponses(logResponses) + .build(); + + // Step 1: check whether db exist, if not, create it + if (!isDatabaseExist(this.vearchConfig.getDatabaseName())) { + createDatabase(this.vearchConfig.getDatabaseName()); + } + + // Step 2: check whether space exist, if not, create it + if (!isSpaceExist(this.vearchConfig.getDatabaseName(), this.vearchConfig.getSpaceName())) { + createSpace(this.vearchConfig.getDatabaseName(), this.vearchConfig.getSpaceName()); + } + } + + public static Builder builder() { + return new Builder(); + } + + @Override + public String add(Embedding embedding) { + String id = randomUUID(); + add(id, embedding); + return id; + } + + @Override + public void add(String id, Embedding embedding) { + addInternal(id, embedding, null); + } + + @Override + public String add(Embedding embedding, TextSegment textSegment) { + String id = randomUUID(); + addInternal(id, embedding, textSegment); + return id; + } + + @Override + public List addAll(List embeddings) { + List ids = embeddings.stream() + .map(ignored -> randomUUID()) + .collect(toList()); + addAllInternal(ids, embeddings, null); + return ids; + } + + @Override + public List addAll(List embeddings, List embedded) { + 
List ids = embeddings.stream() + .map(ignored -> randomUUID()) + .collect(toList()); + addAllInternal(ids, embeddings, embedded); + return ids; + } + + @Override + public EmbeddingSearchResult search(EmbeddingSearchRequest request) { + double minSimilarity = fromRelevanceScore(request.minScore()); + List fields = new ArrayList<>(Arrays.asList(vearchConfig.getTextFieldName(), vearchConfig.getEmbeddingFieldName())); + if (!isNullOrEmpty(vearchConfig.getMetadataFieldNames())) { + fields.addAll(vearchConfig.getMetadataFieldNames()); + } + SearchRequest vearchRequest = SearchRequest.builder() + .dbName(vearchConfig.getDatabaseName()) + .spaceName(vearchConfig.getSpaceName()) + .vectors(singletonList(SearchRequest.Vector.builder() + .field(vearchConfig.getEmbeddingFieldName()) + .feature(request.queryEmbedding().vectorAsList()) + .minScore(minSimilarity) + .build())) + .fields(fields) + .vectorValue(true) + .limit(request.maxResults()) + .indexParams(vearchConfig.getSearchIndexParam()) + .build(); + + SearchResponse response = vearchClient.search(vearchRequest); + List> matches = toEmbeddingMatch(response.getDocuments().get(0)); + return new EmbeddingSearchResult<>(matches); + } + + public void deleteSpace() { + vearchClient.deleteSpace(vearchConfig.getDatabaseName(), vearchConfig.getSpaceName()); + } + + private void addInternal(String id, Embedding embedding, TextSegment embedded) { + addAllInternal(singletonList(id), singletonList(embedding), embedded == null ? 
null : singletonList(embedded)); + } + + private void addAllInternal(List ids, List embeddings, List embedded) { + ids = ensureNotEmpty(ids, "ids"); + embeddings = ensureNotEmpty(embeddings, "embeddings"); + ensureTrue(ids.size() == embeddings.size(), "ids size is not equal to embeddings size"); + ensureTrue(embedded == null || embeddings.size() == embedded.size(), "embeddings size is not equal to embedded size"); + + List> documents = new ArrayList<>(ids.size()); + for (int i = 0; i < ids.size(); i++) { + TextSegment textSegment = embedded == null ? null : embedded.get(i); + Map document = new HashMap<>(4); + document.put(DEFAULT_ID_FIELD_NAME, ids.get(i)); + Embedding embedding = embeddings.get(i); + if (normalizeEmbeddings) { + embedding.normalize(); + } + document.put(vearchConfig.getEmbeddingFieldName(), embedding.vector()); + + if (textSegment != null) { + String text = textSegment.text(); + Map metadata = textSegment.metadata().toMap(); + document.put(vearchConfig.getTextFieldName(), text); + if (metadata != null && !metadata.isEmpty()) { + document.putAll(metadata); + } + } + + documents.add(document); + } + + UpsertRequest request = UpsertRequest.builder() + .dbName(vearchConfig.getDatabaseName()) + .spaceName(vearchConfig.getSpaceName()) + .documents(documents) + .build(); + vearchClient.upsert(request); + } + + private boolean isDatabaseExist(String databaseName) { + List databases = vearchClient.listDatabase(); + return databases.stream().anyMatch(database -> databaseName.equals(database.getName())); + } + + private void createDatabase(String databaseName) { + vearchClient.createDatabase(databaseName); + } + + private boolean isSpaceExist(String databaseName, String spaceName) { + List spaces = vearchClient.listSpaceOfDatabase(databaseName); + return spaces.stream().anyMatch(space -> spaceName.equals(space.getName())); + } + + private void createSpace(String databaseName, String space) { + vearchClient.createSpace(databaseName, 
CreateSpaceRequest.builder() + .name(space) + .replicaNum(1) + .partitionNum(1) + .fields(vearchConfig.getFields()) + .build()); + } + + @SuppressWarnings("unchecked") + private List> toEmbeddingMatch(List> documents) { + if (isNullOrEmpty(documents)) { + return new ArrayList<>(); + } + + return documents.stream().map(document -> { + String id = (String) document.get(DEFAULT_ID_FIELD_NAME); + List vector = (List) document.get(vearchConfig.getEmbeddingFieldName()); + Embedding embedding = Embedding.from(vector.stream().map(Double::floatValue).collect(toList())); + + TextSegment textSegment = null; + String text = (String) document.get(vearchConfig.getTextFieldName()); + if (!isNullOrBlank(text)) { + Map metadataMap = convertMetadataMap(document); + textSegment = TextSegment.from(text, Metadata.from(metadataMap)); + } + + return new EmbeddingMatch<>(RelevanceScore.fromCosineSimilarity(((Number) document.get(DEFAULT_SCORE_FILED_NAME)).doubleValue()), id, embedding, textSegment); + }).collect(toList()); + } + + private Map convertMetadataMap(Map source) { + Map metadataMap = new HashMap<>(source); + // remove id, score, embedded text and embedding + metadataMap.remove(DEFAULT_ID_FIELD_NAME); + metadataMap.remove(DEFAULT_SCORE_FILED_NAME); + metadataMap.remove(vearchConfig.getTextFieldName()); + metadataMap.remove(vearchConfig.getEmbeddingFieldName()); + return metadataMap; + } + + public static class Builder { + + private VearchConfig vearchConfig; + private String baseUrl; + private Duration timeout; + private boolean normalizeEmbeddings; + private boolean logRequests; + private boolean logResponses; + + public Builder vearchConfig(VearchConfig vearchConfig) { + this.vearchConfig = vearchConfig; + return this; + } + + public Builder baseUrl(String baseUrl) { + this.baseUrl = baseUrl; + return this; + } + + public Builder timeout(Duration timeout) { + this.timeout = timeout; + return this; + } + + /** + * Set whether to normalize embedding when add to embedding store + 
* + * @param normalizeEmbeddings whether to normalize embedding when add to embedding store + * @return builder + */ + public Builder normalizeEmbeddings(boolean normalizeEmbeddings) { + this.normalizeEmbeddings = normalizeEmbeddings; + return this; + } + + public Builder logRequests(boolean logRequests) { + this.logRequests = logRequests; + return this; + } + + public Builder logResponses(boolean logResponses) { + this.logResponses = logResponses; + return this; + } + + public VearchEmbeddingStore build() { + return new VearchEmbeddingStore(baseUrl, timeout, vearchConfig, normalizeEmbeddings, logRequests, logResponses); + } + } +} diff --git a/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/VearchRequestLoggingInterceptor.java b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/VearchRequestLoggingInterceptor.java new file mode 100644 index 0000000..aad1e83 --- /dev/null +++ b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/VearchRequestLoggingInterceptor.java @@ -0,0 +1,78 @@ +package dev.langchain4j.community.store.embedding.vearch; + +import okhttp3.Headers; +import okhttp3.Interceptor; +import okhttp3.Request; +import okhttp3.Response; +import okio.Buffer; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.Set; +import java.util.stream.StreamSupport; + +import static dev.langchain4j.internal.Utils.isNullOrBlank; +import static java.util.stream.Collectors.joining; + +public class VearchRequestLoggingInterceptor implements Interceptor { + + private static final Logger log = LoggerFactory.getLogger(VearchRequestLoggingInterceptor.class); + + private static final Set COMMON_SECRET_HEADERS = Set.of("Authorization"); + + private static String getBody(Request request) { + try { + Buffer buffer = new Buffer(); + if (request.body() == 
null) { + return ""; + } + request.body().writeTo(buffer); + return buffer.readUtf8(); + } catch (Exception e) { + log.warn("Exception while getting body", e); + return "Exception while getting body: " + e.getMessage(); + } + } + + private static String getHeaders(Headers headers) { + return StreamSupport.stream(headers.spliterator(), false) + .map(header -> formatHeader(header.component1(), header.component2())) + .collect(joining(", ")); + } + + private static String formatHeader(String headerKey, String headerValue) { + if (COMMON_SECRET_HEADERS.contains(headerKey.toLowerCase())) { + headerValue = maskSecretKey(headerValue); + } + return String.format("[%s: %s]", headerKey, headerValue); + } + + private static String maskSecretKey(String key) { + if (isNullOrBlank(key)) { + return key; + } + + if (key.length() >= 7) { + return key.substring(0, 5) + "..." + key.substring(key.length() - 2); + } else { + return "..."; // to short to be masked + } + } + + @Override + public Response intercept(Chain chain) throws IOException { + Request request = chain.request(); + this.log(request); + return chain.proceed(request); + } + + private void log(Request request) { + try { + log.debug("Request:\n- method: {}\n- url: {}\n- headers: {}\n- body: {}", + request.method(), request.url(), getHeaders(request.headers()), getBody(request)); + } catch (Exception e) { + log.warn("Error while logging request: {}", e.getMessage()); + } + } +} diff --git a/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/VearchResponseLoggingInterceptor.java b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/VearchResponseLoggingInterceptor.java new file mode 100644 index 0000000..c5d7db6 --- /dev/null +++ b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/VearchResponseLoggingInterceptor.java @@ -0,0 +1,35 @@ +package 
dev.langchain4j.community.store.embedding.vearch; + +import okhttp3.Interceptor; +import okhttp3.Request; +import okhttp3.Response; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; + +public class VearchResponseLoggingInterceptor implements Interceptor { + + private static final Logger log = LoggerFactory.getLogger(VearchResponseLoggingInterceptor.class); + + @Override + public Response intercept(Chain chain) throws IOException { + Request request = chain.request(); + Response response = chain.proceed(request); + this.log(response); + return response; + } + + private void log(Response response) { + try { + log.debug("Response:\n- status code: {}\n- headers: {}\n- body: {}", + response.code(), response.headers(), this.getBody(response)); + } catch (Exception e) { + log.warn("Error while logging response: {}", e.getMessage()); + } + } + + private String getBody(Response response) throws IOException { + return response.peekBody(Long.MAX_VALUE).string(); + } +} diff --git a/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/field/Field.java b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/field/Field.java new file mode 100644 index 0000000..ca948ac --- /dev/null +++ b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/field/Field.java @@ -0,0 +1,64 @@ +package dev.langchain4j.community.store.embedding.vearch.field; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import dev.langchain4j.community.store.embedding.vearch.index.Index; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; +import static 
dev.langchain4j.internal.ValidationUtils.ensureNotNull; + +/** + * As a constraint type of all Space property only + **/ +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonInclude(NON_NULL) +@JsonNaming(SnakeCaseStrategy.class) +public abstract class Field { + + protected String name; + protected FieldType type; + protected Index index; + + protected Field() { + } + + protected Field(String name, FieldType type, Index index) { + this.name = ensureNotNull(name, "name"); + this.type = ensureNotNull(type, "type"); + this.index = index; + } + + public String getName() { + return name; + } + + public FieldType getType() { + return type; + } + + public Index getIndex() { + return index; + } + + protected abstract static class FieldParamBuilder> { + + protected String name; + protected Index index; + + public B name(String name) { + this.name = name; + return self(); + } + + public B index(Index index) { + this.index = index; + return self(); + } + + protected abstract B self(); + + public abstract C build(); + } +} diff --git a/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/field/FieldType.java b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/field/FieldType.java new file mode 100644 index 0000000..106f12a --- /dev/null +++ b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/field/FieldType.java @@ -0,0 +1,39 @@ +package dev.langchain4j.community.store.embedding.vearch.field; + +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.Set; + +public enum FieldType { + + /** + * keyword is equivalent to string + */ + @JsonProperty("string") + STRING(StringField.class), + @JsonProperty("stringArray") + STRING_ARRAY(StringField.class), + @JsonProperty("integer") + INTEGER(NumericField.class), + @JsonProperty("long") + LONG(NumericField.class), + @JsonProperty("float") + 
FLOAT(NumericField.class), + @JsonProperty("double") + DOUBLE(NumericField.class), + @JsonProperty("vector") + VECTOR(VectorField.class); + + static final Set NUMERIC_TYPES = Set.of(INTEGER, LONG, FLOAT, DOUBLE); + static final Set STRING_TYPES = Set.of(STRING, STRING_ARRAY); + + private final Class paramClass; + + FieldType(Class paramClass) { + this.paramClass = paramClass; + } + + public Class getParamClass() { + return paramClass; + } +} diff --git a/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/field/NumericField.java b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/field/NumericField.java new file mode 100644 index 0000000..ba06327 --- /dev/null +++ b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/field/NumericField.java @@ -0,0 +1,53 @@ +package dev.langchain4j.community.store.embedding.vearch.field; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import dev.langchain4j.community.store.embedding.vearch.index.Index; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; +import static dev.langchain4j.community.store.embedding.vearch.field.FieldType.NUMERIC_TYPES; + +/** + * Support field type: INTEGER, LONG, FLOAT, DOUBLE + */ +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonInclude(NON_NULL) +@JsonNaming(SnakeCaseStrategy.class) +public class NumericField extends Field { + + public NumericField() { + } + + public NumericField(String name, FieldType fieldType, Index index) { + super(name, fieldType, index); + if (!NUMERIC_TYPES.contains(fieldType)) { + throw new IllegalArgumentException("Cannot use type " + fieldType + ", supported 
numeric types are: " + NUMERIC_TYPES); + } + } + + public static NumericParamBuilder builder() { + return new NumericParamBuilder(); + } + + public static class NumericParamBuilder extends FieldParamBuilder { + + private FieldType fieldType; + + public NumericParamBuilder fieldType(FieldType fieldType) { + this.fieldType = fieldType; + return this; + } + + @Override + protected NumericParamBuilder self() { + return this; + } + + @Override + public NumericField build() { + return new NumericField(name, fieldType, index); + } + } +} diff --git a/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/field/StringField.java b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/field/StringField.java new file mode 100644 index 0000000..c28ff0b --- /dev/null +++ b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/field/StringField.java @@ -0,0 +1,50 @@ +package dev.langchain4j.community.store.embedding.vearch.field; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import dev.langchain4j.community.store.embedding.vearch.index.Index; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; +import static dev.langchain4j.community.store.embedding.vearch.field.FieldType.STRING_TYPES; + +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonInclude(NON_NULL) +@JsonNaming(SnakeCaseStrategy.class) +public class StringField extends Field { + + public StringField() { + } + + public StringField(String name, FieldType fieldType, Index index) { + super(name, fieldType, index); + if (!STRING_TYPES.contains(fieldType)) { + throw new IllegalArgumentException("Cannot use type " + fieldType 
+ ", supported string types are: " + STRING_TYPES); + } + } + + public static StringParamBuilder builder() { + return new StringParamBuilder(); + } + + public static class StringParamBuilder extends FieldParamBuilder { + + private FieldType fieldType; + + public StringParamBuilder fieldType(FieldType fieldType) { + this.fieldType = fieldType; + return this; + } + + @Override + protected StringParamBuilder self() { + return this; + } + + @Override + public StringField build() { + return new StringField(name, fieldType, index); + } + } +} diff --git a/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/field/VectorField.java b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/field/VectorField.java new file mode 100644 index 0000000..e536f8e --- /dev/null +++ b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/field/VectorField.java @@ -0,0 +1,112 @@ +package dev.langchain4j.community.store.embedding.vearch.field; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import dev.langchain4j.community.store.embedding.vearch.StoreParam; +import dev.langchain4j.community.store.embedding.vearch.StoreType; +import dev.langchain4j.community.store.embedding.vearch.index.Index; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonInclude(NON_NULL) +@JsonNaming(SnakeCaseStrategy.class) +public class VectorField extends Field { + + private Integer dimension; + /** + * "RocksDB" or "MemoryOnly". For HNSW and IVFFLAT and FLAT, it can only be run in MemoryOnly mode. 
+ * + * @see StoreType + */ + private StoreType storeType; + private StoreParam storeParam; + private String modelId; + /** + * Not normalized by default; set to "normalization" or "normal" to have the vector normalized. + */ + private String format; + + public VectorField() { + } + + public VectorField(String name, Index index, Integer dimension, StoreType storeType, + StoreParam storeParam, String modelId, String format) { + super(name, FieldType.VECTOR, index); + this.dimension = dimension; + this.storeType = storeType; + this.storeParam = storeParam; + this.modelId = modelId; + this.format = format; + } + + public Integer getDimension() { + return dimension; + } + + public StoreType getStoreType() { + return storeType; + } + + public StoreParam getStoreParam() { + return storeParam; + } + + public String getModelId() { + return modelId; + } + + public String getFormat() { + return format; + } + + public static VectorParamBuilder builder() { + return new VectorParamBuilder(); + } + + public static class VectorParamBuilder extends FieldParamBuilder { + + private Integer dimension; + private StoreType storeType; + private StoreParam storeParam; + private String modelId; + private String format; + + public VectorParamBuilder dimension(Integer dimension) { + this.dimension = dimension; + return this; + } + + public VectorParamBuilder storeType(StoreType storeType) { + this.storeType = storeType; + return this; + } + + public VectorParamBuilder storeParam(StoreParam storeParam) { + this.storeParam = storeParam; + return this; + } + + public VectorParamBuilder modelId(String modelId) { + this.modelId = modelId; + return this; + } + + public VectorParamBuilder format(String format) { + this.format = format; + return this; + } + + @Override + protected VectorParamBuilder self() { + return this; + } + + @Override + public VectorField build() { + return new VectorField(name, index, dimension, storeType, storeParam, modelId, format); + } + } +} diff --git 
a/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/index/BINARYIVFParam.java b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/index/BINARYIVFParam.java new file mode 100644 index 0000000..b07896a --- /dev/null +++ b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/index/BINARYIVFParam.java @@ -0,0 +1,76 @@ +package dev.langchain4j.community.store.embedding.vearch.index; + +import com.fasterxml.jackson.annotation.JsonProperty; +import dev.langchain4j.community.store.embedding.vearch.MetricType; + +public class BINARYIVFParam extends IndexParam { + + /** + * coarse cluster center number + * + *

default 256

+ */ + @JsonProperty("ncentroids") + private Integer nCentroids; + private Integer trainingThreshold; + @JsonProperty("nprob") + private Integer nProb; + + public BINARYIVFParam() { + } + + public BINARYIVFParam(MetricType metricType, Integer nCentroids, Integer trainingThreshold, Integer nProb) { + super(metricType); + this.nCentroids = nCentroids; + this.trainingThreshold = trainingThreshold; + this.nProb = nProb; + } + + public Integer getNCentroids() { + return nCentroids; + } + + public Integer getTrainingThreshold() { + return trainingThreshold; + } + + public Integer getNProb() { + return nProb; + } + + public static BinaryIVFParamBuilder builder() { + return new BinaryIVFParamBuilder(); + } + + public static class BinaryIVFParamBuilder extends IndexParamBuilder { + + private Integer nCentroids; + private Integer trainingThreshold; + private Integer nProb; + + public BinaryIVFParamBuilder nCentroids(Integer nCentroids) { + this.nCentroids = nCentroids; + return this; + } + + public BinaryIVFParamBuilder trainingThreshold(Integer trainingThreshold) { + this.trainingThreshold = trainingThreshold; + return this; + } + + public BinaryIVFParamBuilder nProb(Integer nProb) { + this.nProb = nProb; + return this; + } + + @Override + protected BinaryIVFParamBuilder self() { + return this; + } + + @Override + public BINARYIVFParam build() { + return new BINARYIVFParam(metricType, nCentroids, trainingThreshold, nProb); + } + } +} diff --git a/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/index/FLATParam.java b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/index/FLATParam.java new file mode 100644 index 0000000..0849f1b --- /dev/null +++ b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/index/FLATParam.java @@ -0,0 +1,30 @@ +package dev.langchain4j.community.store.embedding.vearch.index; + +import 
dev.langchain4j.community.store.embedding.vearch.MetricType; + +public class FLATParam extends IndexParam { + + public FLATParam() { + } + + public FLATParam(MetricType metricType) { + super(metricType); + } + + public static FLATParamBuilder builder() { + return new FLATParamBuilder(); + } + + public static class FLATParamBuilder extends IndexParamBuilder { + + @Override + protected FLATParamBuilder self() { + return this; + } + + @Override + public FLATParam build() { + return new FLATParam(metricType); + } + } +} diff --git a/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/index/GPUParam.java b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/index/GPUParam.java new file mode 100644 index 0000000..4e25f78 --- /dev/null +++ b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/index/GPUParam.java @@ -0,0 +1,93 @@ +package dev.langchain4j.community.store.embedding.vearch.index; + +import com.fasterxml.jackson.annotation.JsonProperty; +import dev.langchain4j.community.store.embedding.vearch.MetricType; + +public class GPUParam extends IndexParam { + + /** number of buckets for indexing + * + *

default 2048

+ */ + @JsonProperty("ncentroids") + private Integer nCentroids; + /** + * the number of sub vector + * + *

default 64

+ */ + @JsonProperty("nsubvector") + private Integer nSubVector; + private Integer trainingThreshold; + @JsonProperty("nprob") + private Integer nProb; + + public GPUParam() { + } + + public GPUParam(MetricType metricType, Integer nCentroids, Integer nSubVector, Integer trainingThreshold, Integer nProb) { + super(metricType); + this.nCentroids = nCentroids; + this.nSubVector = nSubVector; + this.trainingThreshold = trainingThreshold; + this.nProb = nProb; + } + + public Integer getNCentroids() { + return nCentroids; + } + + public Integer getNSubVector() { + return nSubVector; + } + + public Integer getTrainingThreshold() { + return trainingThreshold; + } + + public Integer getNProb() { + return nProb; + } + + public static GPUParamBuilder builder() { + return new GPUParamBuilder(); + } + + public static class GPUParamBuilder extends IndexParamBuilder { + + private Integer nCentroids; + private Integer nSubVector; + private Integer trainingThreshold; + private Integer nProb; + + public GPUParamBuilder nCentroids(Integer nCentroids) { + this.nCentroids = nCentroids; + return this; + } + + public GPUParamBuilder nSubVector(Integer nSubVector) { + this.nSubVector = nSubVector; + return this; + } + + public GPUParamBuilder trainingThreshold(Integer trainingThreshold) { + this.trainingThreshold = trainingThreshold; + return this; + } + + public GPUParamBuilder nProb(Integer nProb) { + this.nProb = nProb; + return this; + } + + @Override + protected GPUParamBuilder self() { + return this; + } + + @Override + public GPUParam build() { + return new GPUParam(metricType, nCentroids, nSubVector, trainingThreshold, nProb); + } + } +} diff --git a/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/index/HNSWParam.java b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/index/HNSWParam.java new file mode 100644 index 0000000..3faaad4 --- /dev/null +++ 
b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/index/HNSWParam.java @@ -0,0 +1,84 @@ +package dev.langchain4j.community.store.embedding.vearch.index; + +import com.fasterxml.jackson.annotation.JsonProperty; +import dev.langchain4j.community.store.embedding.vearch.MetricType; + + +public class HNSWParam extends IndexParam { + + /** + * neighbors number of each node + * + *

default 32

+ */ + @JsonProperty("nlinks") + private Integer nLinks; + /** + * expansion factor at construction time + * + *

default 40

+ *

The higher the value, the better the construction effect, and the longer it takes

+ */ + @JsonProperty("efConstruction") + private Integer efConstruction; + @JsonProperty("efSearch") + private Integer efSearch; + + public HNSWParam() { + } + + public HNSWParam(MetricType metricType, Integer nLinks, Integer efConstruction, Integer efSearch) { + super(metricType); + this.nLinks = nLinks; + this.efConstruction = efConstruction; + this.efSearch = efSearch; + } + + public Integer getNLinks() { + return nLinks; + } + + public Integer getEfConstruction() { + return efConstruction; + } + + public Integer getEfSearch() { + return efSearch; + } + + public static HNSWParamBuilder builder() { + return new HNSWParamBuilder(); + } + + public static class HNSWParamBuilder extends IndexParamBuilder { + + private Integer nLinks; + private Integer efConstruction; + private Integer efSearch; + + public HNSWParamBuilder nLinks(Integer nLinks) { + this.nLinks = nLinks; + return this; + } + + public HNSWParamBuilder efConstruction(Integer efConstruction) { + this.efConstruction = efConstruction; + return this; + } + + public HNSWParamBuilder efSearch(Integer efSearch) { + this.efSearch = efSearch; + return this; + } + + @Override + protected HNSWParamBuilder self() { + return this; + } + + @Override + public HNSWParam build() { + return new HNSWParam(metricType, nLinks, efConstruction, efSearch); + } + } +} diff --git a/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/index/IVFFLATParam.java b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/index/IVFFLATParam.java new file mode 100644 index 0000000..6338a46 --- /dev/null +++ b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/index/IVFFLATParam.java @@ -0,0 +1,85 @@ +package dev.langchain4j.community.store.embedding.vearch.index; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import 
com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import dev.langchain4j.community.store.embedding.vearch.MetricType; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonInclude(NON_NULL) +@JsonNaming(SnakeCaseStrategy.class) +public class IVFFLATParam extends IndexParam { + + /** + * number of buckets for indexing + * + *

default 2048

+ */ + @JsonProperty("ncentroids") + private Integer nCentroids; + private Integer trainingThreshold; + @JsonProperty("nprob") + private Integer nProb; + + public IVFFLATParam() { + } + + public IVFFLATParam(MetricType metricType, Integer nCentroids, Integer trainingThreshold, Integer nProb) { + super(metricType); + this.nCentroids = nCentroids; + this.trainingThreshold = trainingThreshold; + this.nProb = nProb; + } + + public Integer getNCentroids() { + return nCentroids; + } + + public Integer getTrainingThreshold() { + return trainingThreshold; + } + + public Integer getNProb() { + return nProb; + } + + public static IVFFlatParamBuilder builder() { + return new IVFFlatParamBuilder(); + } + + public static class IVFFlatParamBuilder extends IndexParamBuilder { + + private Integer nCentroids; + private Integer trainingThreshold; + private Integer nProb; + + public IVFFlatParamBuilder nCentroids(Integer nCentroids) { + this.nCentroids = nCentroids; + return this; + } + + public IVFFlatParamBuilder trainingThreshold(Integer trainingThreshold) { + this.trainingThreshold = trainingThreshold; + return this; + } + + public IVFFlatParamBuilder nProb(Integer nProb) { + this.nProb = nProb; + return this; + } + + @Override + protected IVFFlatParamBuilder self() { + return this; + } + + @Override + public IVFFLATParam build() { + return new IVFFLATParam(metricType, nCentroids, trainingThreshold, nProb); + } + } +} diff --git a/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/index/IVFPQParam.java b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/index/IVFPQParam.java new file mode 100644 index 0000000..7a6de0c --- /dev/null +++ b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/index/IVFPQParam.java @@ -0,0 +1,142 @@ +package dev.langchain4j.community.store.embedding.vearch.index; + +import 
com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import dev.langchain4j.community.store.embedding.vearch.MetricType; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonInclude(NON_NULL) +@JsonNaming(SnakeCaseStrategy.class) +public class IVFPQParam extends IndexParam { + + /** + * number of buckets for indexing + * + *

default 2048

+ */ + @JsonProperty("ncentroids") + private Integer nCentroids; + /** + * PQ disassembler vector size + * + *

default 64, must be a multiple of 4

+ */ + @JsonProperty("nsubvector") + private Integer nSubVector; + /** + * bucket init size + */ + private Integer bucketInitSize; + /** + * max size for each bucket + */ + private Integer bucketMaxSize; + /** + * training data size + */ + private Integer trainingThreshold; + /** + * the number of cluster centers found during retrieval + * + *

default 80

+ */ + @JsonProperty("nprob") + private Integer nProb; + + public IVFPQParam() { + } + + public IVFPQParam(MetricType metricType, Integer nCentroids, Integer nSubVector, + Integer bucketInitSize, Integer bucketMaxSize, Integer trainingThreshold, Integer nProb) { + super(metricType); + this.nCentroids = nCentroids; + this.nSubVector = nSubVector; + this.bucketInitSize = bucketInitSize; + this.bucketMaxSize = bucketMaxSize; + this.trainingThreshold = trainingThreshold; + this.nProb = nProb; + } + + public Integer getNCentroids() { + return nCentroids; + } + + public Integer getNSubVector() { + return nSubVector; + } + + public Integer getBucketInitSize() { + return bucketInitSize; + } + + public Integer getBucketMaxSize() { + return bucketMaxSize; + } + + public Integer getTrainingThreshold() { + return trainingThreshold; + } + + public Integer getNProb() { + return nProb; + } + + public static IVFPQParamBuilder builder() { + return new IVFPQParamBuilder(); + } + + public static class IVFPQParamBuilder extends IndexParamBuilder { + + private Integer nCentroids; + private Integer nSubVector; + private Integer bucketInitSize; + private Integer bucketMaxSize; + private Integer trainingThreshold; + private Integer nProb; + + public IVFPQParamBuilder nCentroids(Integer nCentroids) { + this.nCentroids = nCentroids; + return this; + } + + public IVFPQParamBuilder nSubVector(Integer nSubVector) { + this.nSubVector = nSubVector; + return this; + } + + public IVFPQParamBuilder bucketInitSize(Integer bucketInitSize) { + this.bucketInitSize = bucketInitSize; + return this; + } + + public IVFPQParamBuilder bucketMaxSize(Integer bucketMaxSize) { + this.bucketMaxSize = bucketMaxSize; + return this; + } + + public IVFPQParamBuilder trainingThreshold(Integer trainingThreshold) { + this.trainingThreshold = trainingThreshold; + return this; + } + + public IVFPQParamBuilder nProb(Integer nProb) { + this.nProb = nProb; + return this; + } + + @Override + protected IVFPQParamBuilder self() 
{ + return this; + } + + @Override + public IVFPQParam build() { + return new IVFPQParam(metricType, nCentroids, nSubVector, bucketInitSize, bucketMaxSize, trainingThreshold, nProb); + } + } +} diff --git a/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/index/Index.java b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/index/Index.java new file mode 100644 index 0000000..3bbf11f --- /dev/null +++ b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/index/Index.java @@ -0,0 +1,88 @@ +package dev.langchain4j.community.store.embedding.vearch.index; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonInclude(NON_NULL) +@JsonNaming(SnakeCaseStrategy.class) +public class Index { + + private String name; + private IndexType type; + private IndexParam params; + + public Index() { + } + + public Index(String name, IndexType type, IndexParam params) { + setName(name); + setType(type); + setParams(params); + } + + public static Builder builder() { + return new Builder(); + } + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + + public IndexType getType() { + return type; + } + + public void setType(IndexType type) { + this.type = type; + } + + public IndexParam getParams() { + return params; + } + + public void setParams(IndexParam params) { + // constraint check against the param class of the index type; skipped while no type is set, so calling this setter before setType cannot throw a NullPointerException + Class clazz = type == null ? null : type.getCreateSpaceParamClass(); + if (clazz != null && !clazz.isInstance(params)) { + throw new 
UnsupportedOperationException( + String.format("can't assign unknown param of engine %s, please use class %s to assign engine param", + type.name(), clazz.getSimpleName())); + } + this.params = params; + } + + public static class Builder { + + private String name; + private IndexType type; + private IndexParam params; + + public Builder name(String name) { + this.name = name; + return this; + } + + public Builder type(IndexType type) { + this.type = type; + return this; + } + + public Builder params(IndexParam params) { + this.params = params; + return this; + } + + public Index build() { + return new Index(name, type, params); + } + } +} diff --git a/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/index/IndexParam.java b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/index/IndexParam.java new file mode 100644 index 0000000..7db0c5f --- /dev/null +++ b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/index/IndexParam.java @@ -0,0 +1,48 @@ +package dev.langchain4j.community.store.embedding.vearch.index; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import dev.langchain4j.community.store.embedding.vearch.MetricType; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +/** + * Index param to construct field and space. 
+ */ +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonInclude(NON_NULL) +@JsonNaming(SnakeCaseStrategy.class) +public abstract class IndexParam { + + /** + * compute type + */ + protected MetricType metricType; + + protected IndexParam() { + } + + protected IndexParam(MetricType metricType) { + this.metricType = metricType; + } + + public MetricType getMetricType() { + return metricType; + } + + protected abstract static class IndexParamBuilder> { + + protected MetricType metricType; + + public B metricType(MetricType metricType) { + this.metricType = metricType; + return self(); + } + + protected abstract B self(); + + public abstract C build(); + } +} diff --git a/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/index/IndexType.java b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/index/IndexType.java new file mode 100644 index 0000000..d3dee60 --- /dev/null +++ b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/index/IndexType.java @@ -0,0 +1,37 @@ +package dev.langchain4j.community.store.embedding.vearch.index; + +import dev.langchain4j.community.store.embedding.vearch.index.search.BINARYIVFSearchParam; +import dev.langchain4j.community.store.embedding.vearch.index.search.FLATSearchParam; +import dev.langchain4j.community.store.embedding.vearch.index.search.GPUSearchParam; +import dev.langchain4j.community.store.embedding.vearch.index.search.HNSWSearchParam; +import dev.langchain4j.community.store.embedding.vearch.index.search.IVFFLATSearchParam; +import dev.langchain4j.community.store.embedding.vearch.index.search.IVFPQSearchParam; +import dev.langchain4j.community.store.embedding.vearch.index.search.SearchIndexParam; + +public enum IndexType { + + SCALAR(null, null), + IVFPQ(IVFPQParam.class, IVFPQSearchParam.class), + HNSW(HNSWParam.class, HNSWSearchParam.class), + 
GPU(GPUParam.class, GPUSearchParam.class), + IVFFLAT(IVFFLATParam.class, IVFFLATSearchParam.class), + BINARYIVF(BINARYIVFParam.class, BINARYIVFSearchParam.class), + FLAT(FLATParam.class, FLATSearchParam.class); + + private final Class createSpaceParamClass; + private final Class searchParamClass; + + IndexType(Class createSpaceParamClass, + Class searchParamClass) { + this.createSpaceParamClass = createSpaceParamClass; + this.searchParamClass = searchParamClass; + } + + public Class getCreateSpaceParamClass() { + return createSpaceParamClass; + } + + public Class getSearchParamClass() { + return searchParamClass; + } +} diff --git a/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/index/search/BINARYIVFSearchParam.java b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/index/search/BINARYIVFSearchParam.java new file mode 100644 index 0000000..686defc --- /dev/null +++ b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/index/search/BINARYIVFSearchParam.java @@ -0,0 +1,67 @@ +package dev.langchain4j.community.store.embedding.vearch.index.search; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import dev.langchain4j.community.store.embedding.vearch.MetricType; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonInclude(NON_NULL) +@JsonNaming(SnakeCaseStrategy.class) +public class BINARYIVFSearchParam extends SearchIndexParam { + + private Integer parallelOnQueries; + @JsonProperty("nprob") + private Integer nProb; + + public 
BINARYIVFSearchParam() { + } + + public BINARYIVFSearchParam(MetricType metricType, Integer parallelOnQueries, Integer nProb) { + super(metricType); + this.parallelOnQueries = parallelOnQueries; + this.nProb = nProb; + } + + public Integer getParallelOnQueries() { + return parallelOnQueries; + } + + public Integer getNProb() { + return nProb; + } + + public static BINARYIVFSearchParamBuilder builder() { + return new BINARYIVFSearchParamBuilder(); + } + + public static class BINARYIVFSearchParamBuilder extends SearchIndexParamBuilder { + + private Integer parallelOnQueries; + private Integer nProb; + + public BINARYIVFSearchParamBuilder parallelOnQueries(Integer parallelOnQueries) { + this.parallelOnQueries = parallelOnQueries; + return this; + } + + public BINARYIVFSearchParamBuilder nProb(Integer nProb) { + this.nProb = nProb; + return this; + } + + @Override + protected BINARYIVFSearchParamBuilder self() { + return this; + } + + @Override + public BINARYIVFSearchParam build() { + return new BINARYIVFSearchParam(metricType, parallelOnQueries, nProb); + } + } +} diff --git a/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/index/search/FLATSearchParam.java b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/index/search/FLATSearchParam.java new file mode 100644 index 0000000..4f0ccbc --- /dev/null +++ b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/index/search/FLATSearchParam.java @@ -0,0 +1,39 @@ +package dev.langchain4j.community.store.embedding.vearch.index.search; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import 
dev.langchain4j.community.store.embedding.vearch.MetricType; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonInclude(NON_NULL) +@JsonNaming(SnakeCaseStrategy.class) +public class FLATSearchParam extends SearchIndexParam { + + public FLATSearchParam() { + } + + public FLATSearchParam(MetricType metricType) { + super(metricType); + } + + public static FLATSearchParamBuilder builder() { + return new FLATSearchParamBuilder(); + } + + public static class FLATSearchParamBuilder extends SearchIndexParamBuilder { + + @Override + protected FLATSearchParamBuilder self() { + return this; + } + + @Override + public FLATSearchParam build() { + return new FLATSearchParam(metricType); + } + } +} diff --git a/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/index/search/GPUSearchParam.java b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/index/search/GPUSearchParam.java new file mode 100644 index 0000000..da486a4 --- /dev/null +++ b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/index/search/GPUSearchParam.java @@ -0,0 +1,67 @@ +package dev.langchain4j.community.store.embedding.vearch.index.search; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import dev.langchain4j.community.store.embedding.vearch.MetricType; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonInclude(NON_NULL) +@JsonNaming(SnakeCaseStrategy.class) +public class GPUSearchParam extends SearchIndexParam 
{ + + private Integer recallNum; + @JsonProperty("nprob") + private Integer nProb; + + public GPUSearchParam() { + } + + public GPUSearchParam(MetricType metricType, Integer recallNum, Integer nProb) { + super(metricType); + this.recallNum = recallNum; + this.nProb = nProb; + } + + public Integer getRecallNum() { + return recallNum; + } + + public Integer getNProb() { + return nProb; + } + + public static GPUSearchParamBuilder builder() { + return new GPUSearchParamBuilder(); + } + + public static class GPUSearchParamBuilder extends SearchIndexParamBuilder { + + private Integer recallNum; + private Integer nProb; + + public GPUSearchParamBuilder recallNum(Integer recallNum) { + this.recallNum = recallNum; + return this; + } + + public GPUSearchParamBuilder nProb(Integer nProb) { + this.nProb = nProb; + return this; + } + + @Override + protected GPUSearchParamBuilder self() { + return this; + } + + @Override + public GPUSearchParam build() { + return new GPUSearchParam(metricType, recallNum, nProb); + } + } +} diff --git a/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/index/search/HNSWSearchParam.java b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/index/search/HNSWSearchParam.java new file mode 100644 index 0000000..4ecddad --- /dev/null +++ b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/index/search/HNSWSearchParam.java @@ -0,0 +1,55 @@ +package dev.langchain4j.community.store.embedding.vearch.index.search; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import 
dev.langchain4j.community.store.embedding.vearch.MetricType; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonInclude(NON_NULL) +@JsonNaming(SnakeCaseStrategy.class) +public class HNSWSearchParam extends SearchIndexParam { + + @JsonProperty("efSearch") + private Integer efSearch; + + public HNSWSearchParam() { + } + + public HNSWSearchParam(MetricType metricType, Integer efSearch) { + super(metricType); + this.efSearch = efSearch; + } + + public Integer getEfSearch() { + return efSearch; + } + + public static HNSWSearchParamBuilder builder() { + return new HNSWSearchParamBuilder(); + } + + public static class HNSWSearchParamBuilder extends SearchIndexParamBuilder { + + private Integer efSearch; + + public HNSWSearchParamBuilder efSearch(Integer efSearch) { + this.efSearch = efSearch; + return this; + } + + @Override + protected HNSWSearchParamBuilder self() { + return this; + } + + @Override + public HNSWSearchParam build() { + return new HNSWSearchParam(metricType, efSearch); + } + } +} diff --git a/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/index/search/IVFFLATSearchParam.java b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/index/search/IVFFLATSearchParam.java new file mode 100644 index 0000000..d9a8203 --- /dev/null +++ b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/index/search/IVFFLATSearchParam.java @@ -0,0 +1,67 @@ +package dev.langchain4j.community.store.embedding.vearch.index.search; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import 
com.fasterxml.jackson.databind.annotation.JsonNaming; +import dev.langchain4j.community.store.embedding.vearch.MetricType; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonInclude(NON_NULL) +@JsonNaming(SnakeCaseStrategy.class) +public class IVFFLATSearchParam extends SearchIndexParam { + + private Integer parallelOnQueries; + @JsonProperty("nprob") + private Integer nProb; + + public IVFFLATSearchParam() { + } + + public IVFFLATSearchParam(MetricType metricType, Integer parallelOnQueries, Integer nProb) { + super(metricType); + this.parallelOnQueries = parallelOnQueries; + this.nProb = nProb; + } + + public Integer getParallelOnQueries() { + return parallelOnQueries; + } + + public Integer getNProb() { + return nProb; + } + + public static IVFFLATSearchParamBuilder builder() { + return new IVFFLATSearchParamBuilder(); + } + + public static class IVFFLATSearchParamBuilder extends SearchIndexParamBuilder { + + private Integer parallelOnQueries; + private Integer nProb; + + public IVFFLATSearchParamBuilder parallelOnQueries(Integer parallelOnQueries) { + this.parallelOnQueries = parallelOnQueries; + return this; + } + + public IVFFLATSearchParamBuilder nProb(Integer nProb) { + this.nProb = nProb; + return this; + } + + @Override + protected IVFFLATSearchParamBuilder self() { + return this; + } + + @Override + public IVFFLATSearchParam build() { + return new IVFFLATSearchParam(metricType, parallelOnQueries, nProb); + } + } +} diff --git a/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/index/search/IVFPQSearchParam.java b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/index/search/IVFPQSearchParam.java new file mode 100644 index 0000000..31c49e4 --- /dev/null +++ 
b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/index/search/IVFPQSearchParam.java @@ -0,0 +1,79 @@ +package dev.langchain4j.community.store.embedding.vearch.index.search; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import dev.langchain4j.community.store.embedding.vearch.MetricType; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonInclude(NON_NULL) +@JsonNaming(SnakeCaseStrategy.class) +public class IVFPQSearchParam extends SearchIndexParam { + + private Integer parallelOnQueries; + private Integer recallNum; + @JsonProperty("nprob") + private Integer nProb; + + public IVFPQSearchParam() { + } + + public IVFPQSearchParam(MetricType metricType, Integer parallelOnQueries, Integer recallNum, Integer nProb) { + super(metricType); + this.parallelOnQueries = parallelOnQueries; + this.recallNum = recallNum; + this.nProb = nProb; + } + + public Integer getParallelOnQueries() { + return parallelOnQueries; + } + + public Integer getRecallNum() { + return recallNum; + } + + public Integer getNProb() { + return nProb; + } + + public static IVFPQSearchParamBuilder builder() { + return new IVFPQSearchParamBuilder(); + } + + public static class IVFPQSearchParamBuilder extends SearchIndexParamBuilder { + + private Integer parallelOnQueries; + private Integer recallNum; + private Integer nProb; + + public IVFPQSearchParamBuilder parallelOnQueries(Integer parallelOnQueries) { + this.parallelOnQueries = parallelOnQueries; + return this; + } + + public IVFPQSearchParamBuilder recallNum(Integer recallNum) { + this.recallNum = recallNum; + return this; + } + + public 
IVFPQSearchParamBuilder nProb(Integer nProb) { + this.nProb = nProb; + return this; + } + + @Override + protected IVFPQSearchParamBuilder self() { + return this; + } + + @Override + public IVFPQSearchParam build() { + return new IVFPQSearchParam(metricType, parallelOnQueries, recallNum, nProb); + } + } +} diff --git a/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/index/search/SearchIndexParam.java b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/index/search/SearchIndexParam.java new file mode 100644 index 0000000..fc94cc5 --- /dev/null +++ b/embedding-stores/langchain4j-community-vearch/src/main/java/dev/langchain4j/community/store/embedding/vearch/index/search/SearchIndexParam.java @@ -0,0 +1,42 @@ +package dev.langchain4j.community.store.embedding.vearch.index.search; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies.SnakeCaseStrategy; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import dev.langchain4j.community.store.embedding.vearch.MetricType; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonInclude(NON_NULL) +@JsonNaming(SnakeCaseStrategy.class) +public abstract class SearchIndexParam { + + protected MetricType metricType; + + protected SearchIndexParam() { + } + + protected SearchIndexParam(MetricType metricType) { + this.metricType = metricType; + } + + public MetricType getMetricType() { + return metricType; + } + + protected abstract static class SearchIndexParamBuilder> { + + protected MetricType metricType; + + public B metricType(MetricType metricType) { + this.metricType = metricType; + return self(); + } + + protected abstract B self(); + + public abstract C build(); + } +} diff --git 
a/embedding-stores/langchain4j-community-vearch/src/test/java/dev/langchain4j/community/store/embedding/vearch/TestUtils.java b/embedding-stores/langchain4j-community-vearch/src/test/java/dev/langchain4j/community/store/embedding/vearch/TestUtils.java new file mode 100644 index 0000000..ba7e418 --- /dev/null +++ b/embedding-stores/langchain4j-community-vearch/src/test/java/dev/langchain4j/community/store/embedding/vearch/TestUtils.java @@ -0,0 +1,26 @@ +package dev.langchain4j.community.store.embedding.vearch; + +import org.junit.jupiter.api.TestInfo; + +import java.lang.reflect.Method; +import java.util.Optional; + +class TestUtils { + + private TestUtils() { + + } + + static boolean isMethodFromClass(TestInfo testInfo, Class clazz) { + try { + Optional method = testInfo.getTestMethod(); + if (method.isPresent()) { + String methodName = method.get().getName(); + return clazz.getDeclaredMethod(methodName) != null; + } + return false; + } catch (NoSuchMethodException e) { + return false; + } + } +} diff --git a/embedding-stores/langchain4j-community-vearch/src/test/java/dev/langchain4j/community/store/embedding/vearch/VearchContainer.java b/embedding-stores/langchain4j-community-vearch/src/test/java/dev/langchain4j/community/store/embedding/vearch/VearchContainer.java new file mode 100644 index 0000000..0a89d46 --- /dev/null +++ b/embedding-stores/langchain4j-community-vearch/src/test/java/dev/langchain4j/community/store/embedding/vearch/VearchContainer.java @@ -0,0 +1,14 @@ +package dev.langchain4j.community.store.embedding.vearch; + +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.utility.MountableFile; + +public class VearchContainer extends GenericContainer { + + public VearchContainer() { + super("vearch/vearch:latest"); + withExposedPorts(9001, 8817); + withCommand("all"); + withCopyFileToContainer(MountableFile.forClasspathResource("config.toml"), "/vearch/config.toml"); + } +} diff --git 
a/embedding-stores/langchain4j-community-vearch/src/test/java/dev/langchain4j/community/store/embedding/vearch/VearchEmbeddingStoreIT.java b/embedding-stores/langchain4j-community-vearch/src/test/java/dev/langchain4j/community/store/embedding/vearch/VearchEmbeddingStoreIT.java new file mode 100644 index 0000000..a924178 --- /dev/null +++ b/embedding-stores/langchain4j-community-vearch/src/test/java/dev/langchain4j/community/store/embedding/vearch/VearchEmbeddingStoreIT.java @@ -0,0 +1,198 @@ +package dev.langchain4j.community.store.embedding.vearch; + +import dev.langchain4j.community.store.embedding.vearch.field.Field; +import dev.langchain4j.community.store.embedding.vearch.field.FieldType; +import dev.langchain4j.community.store.embedding.vearch.field.NumericField; +import dev.langchain4j.community.store.embedding.vearch.field.StringField; +import dev.langchain4j.community.store.embedding.vearch.field.VectorField; +import dev.langchain4j.community.store.embedding.vearch.index.HNSWParam; +import dev.langchain4j.community.store.embedding.vearch.index.Index; +import dev.langchain4j.community.store.embedding.vearch.index.IndexType; +import dev.langchain4j.community.store.embedding.vearch.index.search.HNSWSearchParam; +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.model.embedding.EmbeddingModel; +import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel; +import dev.langchain4j.store.embedding.EmbeddingStore; +import dev.langchain4j.store.embedding.EmbeddingStoreIT; +import dev.langchain4j.store.embedding.EmbeddingStoreWithoutMetadataIT; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInfo; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import 
java.util.concurrent.ThreadLocalRandom; + +import static dev.langchain4j.community.store.embedding.vearch.TestUtils.isMethodFromClass; +import static org.assertj.core.api.Assertions.assertThat; + +class VearchEmbeddingStoreIT extends EmbeddingStoreIT { + + static VearchContainer vearch = new VearchContainer(); + + VearchEmbeddingStore embeddingStore; + + EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel(); + + /** + * in order to clear embedding store + */ + static VearchClient vearchClient; + + static String databaseName; + + static String spaceName; + + static String baseUrl; + + @BeforeAll + static void start() { + vearch.start(); + + databaseName = "embedding_db"; + spaceName = "embedding_space_" + ThreadLocalRandom.current().nextInt(0, Integer.MAX_VALUE); + baseUrl = "http://" + vearch.getHost() + ":" + vearch.getMappedPort(9001); + vearchClient = VearchClient.builder() + .baseUrl(baseUrl) + .timeout(Duration.ofSeconds(60)) + .build(); + } + + @AfterAll + static void stop() { + vearch.stop(); + } + + @BeforeEach + void beforeEach(TestInfo testInfo) { + if (isMethodFromClass(testInfo, EmbeddingStoreIT.class)) { + buildEmbeddingStoreWithMetadata(); + } else if (isMethodFromClass(testInfo, EmbeddingStoreWithoutMetadataIT.class) || isMethodFromClass(testInfo, VearchEmbeddingStoreIT.class)) { + buildEmbeddingStoreWithoutMetadata(); + } + } + + private void buildEmbeddingStoreWithMetadata() { + buildEmbeddingStore(true); + } + + private void buildEmbeddingStoreWithoutMetadata() { + buildEmbeddingStore(false); + } + + private void buildEmbeddingStore(boolean withMetadata) { + String embeddingFieldName = "text_embedding"; + String textFieldName = "text"; + Map metadata = createMetadata().toMap(); + + // init fields + List fields = new ArrayList<>(4); + List metadataFieldNames = new ArrayList<>(); + fields.add(VectorField.builder() + .name(embeddingFieldName) + .dimension(embeddingModel.dimension()) + .index(Index.builder() + .name("gamma") + 
.type(IndexType.HNSW) + .params(HNSWParam.builder() + .metricType(MetricType.INNER_PRODUCT) + .efConstruction(100) + .nLinks(32) + .efSearch(64) + .build()) + .build()) + .build() + ); + fields.add(StringField.builder().name(textFieldName).fieldType(FieldType.STRING).build()); + if (withMetadata) { + // metadata + for (Map.Entry entry : metadata.entrySet()) { + String key = entry.getKey(); + Object value = entry.getValue(); + if (value instanceof String || value instanceof UUID) { + fields.add(StringField.builder().name(key).fieldType(FieldType.STRING).build()); + } else if (value instanceof Integer) { + fields.add(NumericField.builder().name(key).fieldType(FieldType.INTEGER).build()); + } else if (value instanceof Long) { + fields.add(NumericField.builder().name(key).fieldType(FieldType.LONG).build()); + } else if (value instanceof Float) { + fields.add(NumericField.builder().name(key).fieldType(FieldType.FLOAT).build()); + } else if (value instanceof Double) { + fields.add(NumericField.builder().name(key).fieldType(FieldType.DOUBLE).build()); + } + } + } + + // init vearch config + spaceName = "embedding_space_" + ThreadLocalRandom.current().nextInt(0, Integer.MAX_VALUE); + VearchConfig vearchConfig = VearchConfig.builder() + .databaseName(databaseName) + .spaceName(spaceName) + .textFieldName(textFieldName) + .embeddingFieldName(embeddingFieldName) + .fields(fields) + .metadataFieldNames(metadataFieldNames) + .searchIndexParam(HNSWSearchParam.builder() + .metricType(MetricType.INNER_PRODUCT) + .efSearch(64) + .build()) + .build(); + if (withMetadata) { + vearchConfig.setMetadataFieldNames(new ArrayList<>(metadata.keySet())); + } + + // init embedding store + embeddingStore = VearchEmbeddingStore.builder() + .vearchConfig(vearchConfig) + .baseUrl(baseUrl) + .logRequests(true) + .logResponses(true) + .build(); + } + + @Override + protected EmbeddingStore embeddingStore() { + return embeddingStore; + } + + @Override + protected EmbeddingModel embeddingModel() { + 
return embeddingModel; + } + + @Override + protected void clearStore() { + vearchClient.deleteSpace(databaseName, spaceName); + + buildEmbeddingStoreWithMetadata(); + } + + @Override + protected void ensureStoreIsEmpty() { + // This method should be skipped because the @BeforeEach method of the parent class is called before the @BeforeEach method of the child class + // This test manually create Space at @BeforeEach, so it's guaranteed that the EmbeddingStore is empty + } + + @Override + protected boolean testFloatExactly() { + return false; + } + + @Override + protected boolean testDoubleExactly() { + return false; + } + + @Test + void should_delete_space() { + embeddingStore.deleteSpace(); + List actual = vearchClient.listSpaceOfDatabase(databaseName); + assertThat(actual.stream().map(ListSpaceResponse::getName)).doesNotContain(spaceName); + } + +} diff --git a/embedding-stores/langchain4j-community-vearch/src/test/resources/config.toml b/embedding-stores/langchain4j-community-vearch/src/test/resources/config.toml new file mode 100644 index 0000000..5f83ac0 --- /dev/null +++ b/embedding-stores/langchain4j-community-vearch/src/test/resources/config.toml @@ -0,0 +1,86 @@ +[global] + # the name will validate join cluster by same name + name = "cbdb" + # specify which resources to use to create space + resource_name = "default" + # you data save to disk path ,If you are in a production environment, You'd better set absolute paths + data = ["datas/","datas1/"] + # log path , If you are in a production environment, You'd better set absolute paths + log = "logs/" + # default log type for any model + level = "debug" + # master <-> ps <-> router will use this key to send or receive data + signkey = "secret" + # skip auth for master and router + skip_auth = true + # tell Vearch whether it should manage it's own instance of etcd or not + self_manage_etcd = false + # automatically remove the failed node and recover when new nodes join + auto_recover_ps = false + # support 
access etcd basic auth,depend on self_manage_etcd = true + support_etcd_auth = false + # ensure leader-follow raft data synchronization is consistent + raft_consistent = false + +# self_manage_etcd = true,means manage etcd by yourself,need provide additional configuration +[etcd] + #etcd server ip or domain + address = ["127.0.0.1"] + # advertise_client_urls AND listen_client_urls + etcd_client_port = 2379 + # provider username and password,if you turn on auth + user_name = "root" + password = "" + +# if you are master you'd better set all config for router and ps and router and ps use default config it so cool +[[masters]] + #name machine name for cluster + name = "m1" + #ip or domain + address = "127.0.0.1" + # api port for http server + api_port = 8817 + # port for etcd server + etcd_port = 2378 + # listen_peer_urls List of comma separated URLs to listen on for peer traffic. + # advertise_peer_urls List of this member's peer URLs to advertise to the rest of the cluster. The URLs needed to be a comma-separated list. + etcd_peer_port = 2390 + # List of this member's client URLs to advertise to the public. + # The URLs needed to be a comma-separated list. 
+ # advertise_client_urls AND listen_client_urls + etcd_client_port = 2370 + # init cluster state + cluster_state = "new" + pprof_port = 6062 + # monitor + monitor_port = 8818 + +[router] + # port for server + port = 9001 + # rpc_port = 9002 + pprof_port = 6061 + plugin_path = "plugin" + +[ps] + # port for server + rpc_port = 8081 + ps_heartbeat_timeout = 5 #seconds + #raft config begin + raft_heartbeat_port = 8898 + raft_replicate_port = 8899 + heartbeat-interval = 200 #ms + raft_retain_logs = 20000000 + raft_replica_concurrency = 1 + raft_snap_concurrency = 1 + raft_truncate_count = 500000 + #when behind leader this value,will stop the server for search + raft_diff_count = 10000 + # engine config + engine_dwpt_num = 8 + pprof_port = 6060 + # if set true , this ps only use in db meta config + private = false + # seconds + flush_time_interval = 600 + flush_count_threshold = 200000 diff --git a/langchain4j-community-bom/pom.xml b/langchain4j-community-bom/pom.xml index 3475d05..506a8fb 100644 --- a/langchain4j-community-bom/pom.xml +++ b/langchain4j-community-bom/pom.xml @@ -34,6 +34,12 @@ ${project.version} + + dev.langchain4j + langchain4j-community-vearch + ${project.version} + + dev.langchain4j diff --git a/pom.xml b/pom.xml index 0827d1e..07480ce 100644 --- a/pom.xml +++ b/pom.xml @@ -32,6 +32,7 @@ embedding-stores/langchain4j-community-clickhouse + embedding-stores/langchain4j-community-vearch web-search-engines/langchain4j-community-web-search-engine-searxng