Skip to content

Commit

Permalink
added elasticsearch vector db (#7)
Browse files Browse the repository at this point in the history
* added elasticsearch vector db

* fixed checkstyle errors
  • Loading branch information
aokugel authored Nov 22, 2024
1 parent 491735c commit a684b2f
Show file tree
Hide file tree
Showing 10 changed files with 420 additions and 5 deletions.
5 changes: 5 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,11 @@
<artifactId>langchain4j-weaviate</artifactId>
<version>${langchain4j.version}</version>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-elasticsearch</artifactId>
<version>${langchain4j.version}</version>
</dependency>
<!-- <dependency>
<groupId>io.quarkiverse.langchain4j</groupId>
<artifactId>quarkus-langchain4j-openai</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@ public class ContentRetrieverClientFactory {

@Inject
Neo4jContentRetrieverClient neo4jContentRetrieverClient;


@Inject
ElasticsearchContentRetrieverClient elasticsearchContentRetrieverClient;

static final ContentRetrieverType DEFAULT_CONTENT_RETRIEVER = ContentRetrieverType.WEAVIATE;

/**
Expand All @@ -35,6 +37,8 @@ public BaseContentRetrieverClient getContentRetrieverClient(ContentRetrieverType
return weaviateEmbeddingStoreClient;
case ContentRetrieverType.NEO4J:
return neo4jContentRetrieverClient;
case ContentRetrieverType.ELASTICSEARCH:
return elasticsearchContentRetrieverClient;
default:
throw new RuntimeException("Content Retriever type not found: " + contentRetrieverType);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package com.redhat.composer.config.retriever.contentretriever;

import org.eclipse.microprofile.config.inject.ConfigProperty;
import org.jboss.logging.Logger;

//import com.redhat.composer.config.retriever.contentretriever.custom.WeaviateEmbeddingStoreCustom;
import com.redhat.composer.model.request.RetrieverRequest;
//import com.redhat.composer.model.request.retriever.WeaviateRequest;
import com.redhat.composer.model.request.retriever.ElasticsearchRequest;

import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.rag.content.retriever.ContentRetriever;
import dev.langchain4j.rag.content.retriever.EmbeddingStoreContentRetriever;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.elasticsearch.ElasticsearchEmbeddingStore;
import org.apache.http.impl.client.BasicCredentialsProvider;
import org.apache.http.message.BasicHeader;
import org.apache.http.auth.UsernamePasswordCredentials;
import org.apache.http.client.CredentialsProvider;
import org.apache.http.auth.AuthScope;
import org.apache.http.Header;
import org.apache.http.HttpHost;
import org.elasticsearch.client.RestClient;
import jakarta.inject.Singleton;

import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Elasticsearch Content Retriever Client.
 *
 * <p>Builds a {@link ContentRetriever} on top of an Elasticsearch-backed embedding store.
 * Connection settings (host, user, password, index) are taken from the incoming request when it
 * supplies them and otherwise from the {@code elasticsearch.default.*} configuration properties.
 */
@Singleton
public class ElasticsearchContentRetrieverClient extends BaseContentRetrieverClient {

  private static final Logger LOG = Logger.getLogger(ElasticsearchContentRetrieverClient.class);

  /** Host used when the request does not supply one. */
  @ConfigProperty(name = "elasticsearch.default.host")
  private String elasticHost;

  /** User name used when the request does not supply one. */
  @ConfigProperty(name = "elasticsearch.default.user")
  private String elasticUser;

  /** Password used when the request does not supply one. */
  @ConfigProperty(name = "elasticsearch.default.password")
  private String elasticPassword;

  /** Index name used when the request does not supply one. */
  @ConfigProperty(name = "elasticsearch.default.index")
  private String elasticIndex;

  /**
   * REST clients already built by this bean, keyed by (host, user, password).
   *
   * <p>Reusing a client for identical connection settings avoids creating a new, never-closed
   * HTTP client on every call. Entries are kept for the lifetime of this bean; this class does
   * not close them.
   */
  private final Map<List<String>, RestClient> restClients = new ConcurrentHashMap<>();

  /**
   * Get the Content Retriever.
   *
   * @param request the RetrieverRequest
   * @return the Content Retriever
   */
  public ContentRetriever getContentRetriever(RetrieverRequest request) {
    ElasticsearchRequest elasticsearchRequest = (ElasticsearchRequest) request.getBaseRetrieverRequest();
    if (elasticsearchRequest == null) {
      // No Elasticsearch-specific settings on the request: every value falls back to configuration.
      elasticsearchRequest = new ElasticsearchRequest();
    }
    final String host = firstNonNull(elasticsearchRequest.getHost(), elasticHost);
    final String user = firstNonNull(elasticsearchRequest.getUser(), elasticUser);
    final String pass = firstNonNull(elasticsearchRequest.getPassword(), elasticPassword);
    final String index = firstNonNull(elasticsearchRequest.getIndex(), elasticIndex);

    // Arrays.asList gives a value-based key (equals/hashCode over the three settings).
    RestClient restClient = restClients.computeIfAbsent(
        Arrays.asList(host, user, pass),
        key -> buildRestClient(host, user, pass));

    // debugf defers message formatting until debug logging is actually enabled.
    LOG.debugf("Attempting to connect to Elasticsearch at %s with index %s", host, index);

    EmbeddingStore<TextSegment> store = ElasticsearchEmbeddingStore.builder()
        .indexName(index)
        .restClient(restClient)
        .build();

    // Retrieve the embedding model
    EmbeddingModel embeddingModel = getEmbeddingModel(request.getEmbeddingType());

    // Create the content retriever
    return EmbeddingStoreContentRetriever.builder()
        .embeddingStore(store)
        .embeddingModel(embeddingModel)
        .build();
  }

  /**
   * Creates a REST client for the given host that authenticates with a user name and password.
   *
   * @param host the Elasticsearch host, in a form accepted by {@link HttpHost#create(String)}
   * @param user the user name
   * @param pass the password
   * @return a new REST client
   */
  private static RestClient buildRestClient(String host, String user, String pass) {
    // TODO: Make this configurable for different authentication types
    final CredentialsProvider credentialsProvider = new BasicCredentialsProvider();
    credentialsProvider.setCredentials(AuthScope.ANY, new UsernamePasswordCredentials(user, pass));

    return RestClient
        .builder(HttpHost.create(host))
        .setHttpClientConfigCallback(
            httpClientBuilder -> httpClientBuilder.setDefaultCredentialsProvider(credentialsProvider))
        .build();
  }

  /**
   * Returns {@code value} unless it is {@code null}, in which case {@code fallback} is returned.
   *
   * @param value the preferred value, may be {@code null}
   * @param fallback the value to use when {@code value} is {@code null}
   * @return {@code value} if non-null, otherwise {@code fallback}
   */
  private static String firstNonNull(String value, String fallback) {
    return value != null ? value : fallback;
  }

}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
public enum ContentRetrieverType {

WEAVIATE("weaviate"),
NEO4J("neo4j");
NEO4J("neo4j"),
ELASTICSEARCH("elasticsearch");

private final String type;

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
package com.redhat.composer.model.mongo.contentretrieverentites;

import org.apache.commons.lang3.builder.EqualsBuilder;
import org.bson.codecs.pojo.annotations.BsonDiscriminator;

import com.redhat.composer.model.enums.ContentRetrieverType;

import java.util.List;
import java.util.Objects;

/**
 * Elasticsearch Connection Entity.
 *
 * <p>Holds the settings for an Elasticsearch-backed content retriever: index, host, credentials,
 * and the text/metadata field names. Each property has a getter, a setter, and a fluent
 * method of the same name that sets the value and returns this entity for chaining.
 */
@SuppressWarnings("all")
@BsonDiscriminator("Elasticsearch")
public class ElasticsearchConnectionEntity extends BaseRetrieverConnectionEntity {

  // Key of the value containing the text used for retrieval and passed into the LLM
  String textKey;

  // List of metadata fields to be retrieved as part of the content
  List<String> metadataFields = List.of("source");

  String index;

  String host;

  String user;

  String password;

  /**
   * Creates an entity whose content retriever type is {@link ContentRetrieverType#ELASTICSEARCH}.
   */
  public ElasticsearchConnectionEntity() {
    contentRetrieverType = ContentRetrieverType.ELASTICSEARCH;
  }

  public String getTextKey() {
    return this.textKey;
  }

  public void setTextKey(String textKey) {
    this.textKey = textKey;
  }

  public List<String> getMetadataFields() {
    return this.metadataFields;
  }

  public void setMetadataFields(List<String> metadataFields) {
    this.metadataFields = metadataFields;
  }

  public String getIndex() {
    return this.index;
  }

  public void setIndex(String index) {
    this.index = index;
  }

  public String getHost() {
    return this.host;
  }

  public void setHost(String host) {
    this.host = host;
  }

  public String getUser() {
    return this.user;
  }

  public void setUser(String user) {
    this.user = user;
  }

  public String getPassword() {
    return this.password;
  }

  public void setPassword(String password) {
    this.password = password;
  }

  public ElasticsearchConnectionEntity textKey(String textKey) {
    setTextKey(textKey);
    return this;
  }

  public ElasticsearchConnectionEntity metadataFields(List<String> metadataFields) {
    setMetadataFields(metadataFields);
    return this;
  }

  public ElasticsearchConnectionEntity index(String index) {
    setIndex(index);
    return this;
  }

  public ElasticsearchConnectionEntity host(String host) {
    setHost(host);
    return this;
  }

  public ElasticsearchConnectionEntity user(String user) {
    setUser(user);
    return this;
  }

  public ElasticsearchConnectionEntity password(String password) {
    setPassword(password);
    return this;
  }

  /**
   * Compares this entity with another object field by field, using reflection.
   *
   * @param o the object to compare with
   * @return the result of {@link EqualsBuilder#reflectionEquals(Object, Object, String...)}
   */
  @Override
  public boolean equals(Object o) {
    return EqualsBuilder.reflectionEquals(this, o);
  }

  /**
   * Hashes the fields declared by this class.
   *
   * @return the hash code
   */
  @Override
  public int hashCode() {
    return Objects.hash(textKey, metadataFields, index, host, user, password);
  }

  /**
   * Returns a diagnostic description of this entity.
   *
   * <p>The password is never printed: a non-null password is shown as {@code ****} so the text
   * can be logged without exposing the secret.
   *
   * @return the string form of this entity
   */
  @Override
  public String toString() {
    // Only reveal whether a password is set, not its value.
    String maskedPassword = getPassword() == null ? null : "****";
    return "{" +
        " textKey='" + getTextKey() + "'" +
        ", metadataFields='" + getMetadataFields() + "'" +
        ", index='" + getIndex() + "'" +
        ", host='" + getHost() + "'" +
        ", user='" + getUser() + "'" +
        ", password='" + maskedPassword + "'" +
        "}";
  }

}
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.EXISTING_PROPERTY, property = "contentRetrieverType", visible = true)
@JsonSubTypes({
@JsonSubTypes.Type(value = WeaviateRequest.class, name = "weaviate"),
@JsonSubTypes.Type(value = Neo4JRequest.class, name = "neo4j")
@JsonSubTypes.Type(value = Neo4JRequest.class, name = "neo4j"),
@JsonSubTypes.Type(value = ElasticsearchRequest.class, name = "elasticsearch")
})
public class BaseRetrieverRequest {

Expand Down
Loading

0 comments on commit a684b2f

Please sign in to comment.