Skip to content

Commit

Permalink
Updating Reciever Connection data structure
Browse files Browse the repository at this point in the history
  • Loading branch information
Jaland committed Oct 30, 2024
1 parent 275b403 commit 23f5af0
Show file tree
Hide file tree
Showing 29 changed files with 639 additions and 430 deletions.
24 changes: 24 additions & 0 deletions documentation/admin_workflow.MD
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Admin Workflow

Backend includes the ability to dynamically create Assistants, Rag Connections, and LLM Connections

## Content Retriever(RAG) Connction

### Weaviate Example

```json
{
"baseRetrieverRequest": {
"contentRetrieverType": "weaviate",
"textKey": "source",
"metadata": ["source","page_number", "title"],
"index": "my_custom_index",
"scheme": "http",
"host": "my-weaviate-instance.svc:8080",
"apiKey" : "ABCDE12345"
},
"embeddingType": "noimc",
"name": "Example Rag Connection",
"description": "Example of a rag connection"
}
```
21 changes: 20 additions & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
<skipITs>true</skipITs>
<surefire-plugin.version>3.3.1</surefire-plugin.version>
<langchain4j.version>0.35.0</langchain4j.version>
<quarkus.langchain4j.version>0.21.0.CR1</quarkus.langchain4j.version>
<org.mapstruct.version>1.6.2</org.mapstruct.version>
</properties>

<dependencyManagement>
Expand Down Expand Up @@ -63,6 +65,11 @@
<artifactId>langchain4j-weaviate</artifactId>
<version>${langchain4j.version}</version>
</dependency>
<!-- <dependency>
<groupId>io.quarkiverse.langchain4j</groupId>
<artifactId>quarkus-langchain4j-openai</artifactId>
<version>${quarkus.langchain4j.version}</version>
</dependency> -->
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-open-ai</artifactId>
Expand All @@ -76,7 +83,7 @@
<dependency>
<groupId>io.quarkiverse.langchain4j</groupId>
<artifactId>quarkus-langchain4j-hugging-face</artifactId>
<version>0.18.0</version>
<version>${quarkus.langchain4j.version}</version>
</dependency>
<dependency>
<groupId>io.quarkus</groupId>
Expand Down Expand Up @@ -120,6 +127,18 @@
<artifactId>rest-assured</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mapstruct</groupId>
<artifactId>mapstruct</artifactId>
<version>${org.mapstruct.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.mapstruct</groupId>
<artifactId>mapstruct-processor</artifactId>
<version>${org.mapstruct.version}</version>
<scope>provided</scope>
</dependency>
</dependencies>

<build>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public List<LLMConnectionEntity> getLLMs() {
@POST
@Path("retrieverConnection")
public RetrieverConnectionEntity createRetrieverConnection(RetrieverRequest request) {
return assistantService.creaRetrieverConnectionEntity(request);
return assistantService.createRetrieverConnectionEntity(request);
}

@GET
Expand Down
1 change: 0 additions & 1 deletion src/main/java/com/redhat/composer/api/EmbeddingAPI.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import io.quarkus.security.Authenticated;
import jakarta.inject.Inject;
import jakarta.ws.rs.Consumes;
import jakarta.ws.rs.GET;
import jakarta.ws.rs.POST;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.PathParam;
Expand Down
16 changes: 12 additions & 4 deletions src/main/java/com/redhat/composer/api/VectorRetriverAPI.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
import com.redhat.composer.model.request.RetrieverRequest;
import com.redhat.composer.model.response.SourceResponse;
import com.redhat.composer.services.RetrieveService;
import com.redhat.composer.util.MapperUtil;
import com.redhat.composer.util.mappers.MapperUtil;

import dev.langchain4j.rag.content.Content;
import io.quarkus.security.Authenticated;
import jakarta.inject.Inject;
import jakarta.ws.rs.POST;
Expand All @@ -24,10 +25,13 @@ public class VectorRetriverAPI {
@Inject
RetrieveService retrieveService;

@Inject
MapperUtil mapperUtil;

@POST
@Path("/sources")
public List<SourceResponse> retrieveSources(RetrieverRequest request, @QueryParam("message") String message) {
return retrieveService.retrieveContent(request, message).stream().map(MapperUtil::toSourceResponse).toList();
return retrieveService.retrieveContent(request, message).stream().map(VectorRetriverAPI::toSourceResponse).toList();
}

@POST
Expand All @@ -39,7 +43,11 @@ public List<Map<String,Object>> retrieveSourceMetadata(RetrieverRequest request,
.toList();
}



public static SourceResponse toSourceResponse(Content content) {
SourceResponse response = new SourceResponse();
response.setContent(content.textSegment().text());
response.setMetadata(content.textSegment().metadata().toMap());
return response;
}

}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,6 @@
@ApplicationScoped
public class StreamingModelFactory {

@Inject
MistralStreamingModel mistralModel;
public static final String MISTRAL_MODEL = "mistral";

@Inject
OpenAIStreamingModel openAIModel;
public static final String OPENAI_MODEL = "openai";
Expand All @@ -23,8 +19,6 @@ public StreamingBaseModel getModel(String modelType) {
}

switch (modelType) {
case MISTRAL_MODEL:
return mistralModel;
case OPENAI_MODEL:
return openAIModel;
default:
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,6 @@
@ApplicationScoped
public class SynchronousModelFactory {

@Inject
MistralModel mistralModel;
public static final String MISTRAL_MODEL = "mistral";

@Inject
OpenAIModel openAIModel;
public static final String OPENAI_MODEL = "openai";
Expand All @@ -23,8 +19,6 @@ public SynchronousBaseModel getModel(String modelType) {
}

switch (modelType) {
case MISTRAL_MODEL:
return mistralModel;
case OPENAI_MODEL:
return openAIModel;
default:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package com.redhat.composer.config.retriever.contentRetriever;

import com.redhat.composer.model.enums.ContentRetrieverType;

import jakarta.enterprise.context.ApplicationScoped;
import jakarta.inject.Inject;

Expand All @@ -8,26 +10,24 @@ public class ContentRetrieverClientFactory {

@Inject
WeaviateContentRetrieverClient weaviateEmbeddingStoreClient;
public static final String WEAVIATE_CONTENT_RETRIEVER = "weaviate";


@Inject
Neo4jContentRetrieverClient neo4jContentRetrieverClient;
public static final String NEO4J_CONTENT_RETRIEVER = "neo4j";

final static String DEFAULT_CONTENT_RETRIEVER = WEAVIATE_CONTENT_RETRIEVER;
final static ContentRetrieverType DEFAULT_CONTENT_RETRIEVER = ContentRetrieverType.WEAVIATE;

public BaseContentRetrieverClient getContentRetrieverClient(String contentRetrieverType) {
public BaseContentRetrieverClient getContentRetrieverClient(ContentRetrieverType contentRetrieverType) {
if(contentRetrieverType == null) {
contentRetrieverType = DEFAULT_CONTENT_RETRIEVER;
}

switch (contentRetrieverType) {
case WEAVIATE_CONTENT_RETRIEVER:
case ContentRetrieverType.WEAVIATE:
return weaviateEmbeddingStoreClient;
case NEO4J_CONTENT_RETRIEVER:
case ContentRetrieverType.NEO4J:
return neo4jContentRetrieverClient;
default:
throw new RuntimeException("Embedding type not found: " + contentRetrieverType);
throw new RuntimeException("Content Retriever type not found: " + contentRetrieverType);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
import org.eclipse.microprofile.config.inject.ConfigProperty;
import org.jboss.logging.Logger;

import com.redhat.composer.api.ChatBotAPI;
import com.redhat.composer.config.retriever.contentRetriever.custom.WeaviateEmbeddingStoreCustom;
import com.redhat.composer.model.request.RetrieverRequest;
import com.redhat.composer.model.request.retriever.WeaviateRequest;

import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.rag.content.retriever.ContentRetriever;
Expand Down Expand Up @@ -33,11 +33,15 @@ public class WeaviateContentRetrieverClient extends BaseContentRetrieverClient {
private String weaviateTextKey;

public ContentRetriever getContentRetriever(RetrieverRequest request) {
String scheme = request.getScheme() != null ? request.getScheme() : weaviateScheme;
String host = request.getHost() != null ? request.getHost() : weaviateHost;
String apiKey = request.getApiKey() != null ? request.getApiKey() : weaviateApiKey;
String index = request.getIndex() != null ? request.getIndex() : weaviateIndex;
String textKey = request.getTextKey() != null ? request.getTextKey() : weaviateTextKey;
WeaviateRequest weaviateRequest = (WeaviateRequest) request.getBaseRetrieverRequest();
if(weaviateRequest == null) {
weaviateRequest = new WeaviateRequest();
}
String scheme = weaviateRequest.getScheme() != null ? weaviateRequest.getScheme() : weaviateScheme;
String host = weaviateRequest.getHost() != null ? weaviateRequest.getHost() : weaviateHost;
String apiKey = weaviateRequest.getApiKey() != null ? weaviateRequest.getApiKey() : weaviateApiKey;
String index = weaviateRequest.getIndex() != null ? weaviateRequest.getIndex() : weaviateIndex;
String textKey = weaviateRequest.getTextKey() != null ? weaviateRequest.getTextKey() : weaviateTextKey;


log.info("Attempting to connect to Weaviate at " + scheme + "://" + host + " with index " + index);
Expand All @@ -47,7 +51,7 @@ public ContentRetriever getContentRetriever(RetrieverRequest request) {
.host(host)
.apiKey(apiKey)
.metadataParentKey("")
.metadataKeys(request.getMetadataFields())
.metadataKeys(weaviateRequest.getMetadataFields())
.objectClass(index)
.avoidDups(true)
.textKey(textKey)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package com.redhat.composer.model.enums;

public enum ContentRetrieverType {

WEAVIATE("weaviate"),
NEO4J("neo4j");

private final String type;

ContentRetrieverType(String type) {
this.type = type;
}

public String getType() {
return type;
}

public static ContentRetrieverType fromString(String type) {
for (ContentRetrieverType retrieverType : ContentRetrieverType.values()) {
if (retrieverType.getType().equalsIgnoreCase(type)) {
return retrieverType;
}
}
throw new IllegalArgumentException("No constant with type " + type + " found");
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@

import org.apache.commons.lang3.builder.EqualsBuilder;

import com.redhat.composer.config.llm.models.streaming.StreamingModelFactory;

import io.quarkus.mongodb.panache.common.MongoEntity;

@MongoEntity(collection = "llm_connection")
Expand Down
Loading

0 comments on commit 23f5af0

Please sign in to comment.