diff --git a/documentation/admin_workflow.MD b/documentation/admin_workflow.MD new file mode 100644 index 0000000..9ce84dc --- /dev/null +++ b/documentation/admin_workflow.MD @@ -0,0 +1,24 @@ +# Admin Workflow + +Backend includes the ability to dynamically create Assistants, Rag Connections, and LLM Connections + +## Content Retriever(RAG) Connction + +### Weaviate Example + +```json +{ + "baseRetrieverRequest": { + "contentRetrieverType": "weaviate", + "textKey": "source", + "metadata": ["source","page_number", "title"], + "index": "my_custom_index", + "scheme": "http", + "host": "my-weaviate-instance.svc:8080", + "apiKey" : "ABCDE12345" + }, + "embeddingType": "noimc", + "name": "Example Rag Connection", + "description": "Example of a rag connection" +} +``` diff --git a/pom.xml b/pom.xml index f7015d1..4601695 100644 --- a/pom.xml +++ b/pom.xml @@ -16,6 +16,8 @@ true 3.3.1 0.35.0 + 0.21.0.CR1 + 1.6.2 @@ -63,6 +65,11 @@ langchain4j-weaviate ${langchain4j.version} + dev.langchain4j langchain4j-open-ai @@ -76,7 +83,7 @@ io.quarkiverse.langchain4j quarkus-langchain4j-hugging-face - 0.18.0 + ${quarkus.langchain4j.version} io.quarkus @@ -120,6 +127,18 @@ rest-assured test + + org.mapstruct + mapstruct + ${org.mapstruct.version} + provided + + + org.mapstruct + mapstruct-processor + ${org.mapstruct.version} + provided + diff --git a/src/main/java/com/redhat/composer/api/AssistantAdminAPI.java b/src/main/java/com/redhat/composer/api/AssistantAdminAPI.java index f27df6b..68ec838 100644 --- a/src/main/java/com/redhat/composer/api/AssistantAdminAPI.java +++ b/src/main/java/com/redhat/composer/api/AssistantAdminAPI.java @@ -39,7 +39,7 @@ public List getLLMs() { @POST @Path("retrieverConnection") public RetrieverConnectionEntity createRetrieverConnection(RetrieverRequest request) { - return assistantService.creaRetrieverConnectionEntity(request); + return assistantService.createRetrieverConnectionEntity(request); } @GET diff --git a/src/main/java/com/redhat/composer/api/EmbeddingAPI.java b/src/main/java/com/redhat/composer/api/EmbeddingAPI.java index 6a1fceb..6fa0da2 100644 --- a/src/main/java/com/redhat/composer/api/EmbeddingAPI.java +++ b/src/main/java/com/redhat/composer/api/EmbeddingAPI.java @@ -5,7 +5,6 @@ import io.quarkus.security.Authenticated; import jakarta.inject.Inject; import jakarta.ws.rs.Consumes; -import jakarta.ws.rs.GET; import jakarta.ws.rs.POST; import jakarta.ws.rs.Path; import jakarta.ws.rs.PathParam; diff --git a/src/main/java/com/redhat/composer/api/VectorRetriverAPI.java b/src/main/java/com/redhat/composer/api/VectorRetriverAPI.java index 0d27115..9bfe3e7 100644 --- a/src/main/java/com/redhat/composer/api/VectorRetriverAPI.java +++ b/src/main/java/com/redhat/composer/api/VectorRetriverAPI.java @@ -6,8 +6,9 @@ import com.redhat.composer.model.request.RetrieverRequest; import com.redhat.composer.model.response.SourceResponse; import com.redhat.composer.services.RetrieveService; -import com.redhat.composer.util.MapperUtil; +import com.redhat.composer.util.mappers.MapperUtil; +import dev.langchain4j.rag.content.Content; import io.quarkus.security.Authenticated; import jakarta.inject.Inject; import jakarta.ws.rs.POST; @@ -24,10 +25,13 @@ public class VectorRetriverAPI { @Inject RetrieveService retrieveService; + @Inject + MapperUtil mapperUtil; + @POST @Path("/sources") public List retrieveSources(RetrieverRequest request, @QueryParam("message") String message) { - return retrieveService.retrieveContent(request, message).stream().map(MapperUtil::toSourceResponse).toList(); + return retrieveService.retrieveContent(request, message).stream().map(VectorRetriverAPI::toSourceResponse).toList(); } @POST @@ -39,7 +43,11 @@ public List> retrieveSourceMetadata(RetrieverRequest request, .toList(); } - - + public static SourceResponse toSourceResponse(Content content) { + SourceResponse response = new SourceResponse(); + response.setContent(content.textSegment().text()); + response.setMetadata(content.textSegment().metadata().toMap()); + return response; + } } diff --git a/src/main/java/com/redhat/composer/config/llm/models/streaming/MistralStreamingModel.java b/src/main/java/com/redhat/composer/config/llm/models/streaming/MistralStreamingModel.java deleted file mode 100644 index d4e22d6..0000000 --- a/src/main/java/com/redhat/composer/config/llm/models/streaming/MistralStreamingModel.java +++ /dev/null @@ -1,55 +0,0 @@ -package com.redhat.composer.config.llm.models.streaming; - -import org.eclipse.microprofile.config.inject.ConfigProperty; - -import com.redhat.composer.model.request.LLMRequest; - -import dev.langchain4j.model.chat.StreamingChatLanguageModel; -import dev.langchain4j.model.mistralai.MistralAiStreamingChatModel; -import dev.langchain4j.model.mistralai.MistralAiStreamingChatModel.MistralAiStreamingChatModelBuilder; -import jakarta.enterprise.context.ApplicationScoped; - - -@ApplicationScoped -public class MistralStreamingModel extends StreamingBaseModel { - - - @ConfigProperty( name = "mistral.default.url") - private String mistralDefaultUrl; - - @ConfigProperty(name = "mistral.default.apiKey") - private String mistralDefaultApiKey; - - @ConfigProperty(name = "openai.default.modelName") - private String mistralDefaultModelName; - - @ConfigProperty(name = "mistral.default.temp") - private double mistralDefaultTemp; - - - @Override - public StreamingChatLanguageModel getChatModel(LLMRequest request) { - MistralAiStreamingChatModelBuilder builder = MistralAiStreamingChatModel.builder(); - builder.baseUrl(request.getUrl() == null ? mistralDefaultUrl : request.getUrl()); - builder.apiKey(request.getApiKey() == null ? mistralDefaultApiKey : request.getApiKey()); - - builder.modelName(request.getModelName() == null ? mistralDefaultModelName : request.getModelName()); - - // TODO: Add all the following to the request - builder.temperature(mistralDefaultTemp); - - // Model names can be derived from MistralAiChatModelName enum - // if (modelName != null) { - // builder.modelName(modelName); - // } - // if (maxTokens != null) { - // builder.maxTokens(maxTokens); - // } - // if (safePrompt != null) { - // builder.safePrompt(safePrompt); - // } - - return builder.build(); - } - -} diff --git a/src/main/java/com/redhat/composer/config/llm/models/streaming/StreamingModelFactory.java b/src/main/java/com/redhat/composer/config/llm/models/streaming/StreamingModelFactory.java index 1616c05..e0b0091 100644 --- a/src/main/java/com/redhat/composer/config/llm/models/streaming/StreamingModelFactory.java +++ b/src/main/java/com/redhat/composer/config/llm/models/streaming/StreamingModelFactory.java @@ -6,10 +6,6 @@ @ApplicationScoped public class StreamingModelFactory { - @Inject - MistralStreamingModel mistralModel; - public static final String MISTRAL_MODEL = "mistral"; - @Inject OpenAIStreamingModel openAIModel; public static final String OPENAI_MODEL = "openai"; @@ -23,8 +19,6 @@ public StreamingBaseModel getModel(String modelType) { } switch (modelType) { - case MISTRAL_MODEL: - return mistralModel; case OPENAI_MODEL: return openAIModel; default: diff --git a/src/main/java/com/redhat/composer/config/llm/models/synchronous/MistralModel.java b/src/main/java/com/redhat/composer/config/llm/models/synchronous/MistralModel.java deleted file mode 100644 index a0a5ada..0000000 --- a/src/main/java/com/redhat/composer/config/llm/models/synchronous/MistralModel.java +++ /dev/null @@ -1,51 +0,0 @@ -package com.redhat.composer.config.llm.models.synchronous; - -import org.eclipse.microprofile.config.inject.ConfigProperty; - -import com.redhat.composer.model.request.LLMRequest; - -import dev.langchain4j.model.chat.ChatLanguageModel; -import dev.langchain4j.model.mistralai.MistralAiChatModel; -import dev.langchain4j.model.mistralai.MistralAiChatModel.MistralAiChatModelBuilder; -import jakarta.enterprise.context.ApplicationScoped; - - -@ApplicationScoped -public class MistralModel extends SynchronousBaseModel { - - - @ConfigProperty( name = "mistral.default.url") - private String mistralDefaultUrl; - - @ConfigProperty(name = "mistral.default.apiKey") - private String mistralDefaultApiKey; - - @ConfigProperty(name = "mistral.default.modelName") - private String mistralDefaultModelName; - - @ConfigProperty(name = "mistral.default.temp") - private double mistralDefaultTemp; - - - public ChatLanguageModel getChatModel(LLMRequest request) { - MistralAiChatModelBuilder builder = MistralAiChatModel.builder(); - builder.baseUrl(request.getUrl() == null ? mistralDefaultUrl : request.getUrl()); - builder.apiKey(request.getApiKey() == null ? mistralDefaultApiKey : request.getApiKey()); - - builder.modelName(request.getModelName() == null ? mistralDefaultModelName : request.getModelName()); - - // TODO: Add all the following to the request - builder.temperature(mistralDefaultTemp); - - // TODO: Add all the following to the request - // if (maxTokens != null) { - // builder.maxTokens(maxTokens); - // } - // if (safePrompt != null) { - // builder.safePrompt(safePrompt); - // } - - return builder.build(); - } - -} diff --git a/src/main/java/com/redhat/composer/config/llm/models/synchronous/SynchronousModelFactory.java b/src/main/java/com/redhat/composer/config/llm/models/synchronous/SynchronousModelFactory.java index f6e438b..804bee3 100644 --- a/src/main/java/com/redhat/composer/config/llm/models/synchronous/SynchronousModelFactory.java +++ b/src/main/java/com/redhat/composer/config/llm/models/synchronous/SynchronousModelFactory.java @@ -6,10 +6,6 @@ @ApplicationScoped public class SynchronousModelFactory { - @Inject - MistralModel mistralModel; - public static final String MISTRAL_MODEL = "mistral"; - @Inject OpenAIModel openAIModel; public static final String OPENAI_MODEL = "openai"; @@ -23,8 +19,6 @@ public SynchronousBaseModel getModel(String modelType) { } switch (modelType) { - case MISTRAL_MODEL: - return mistralModel; case OPENAI_MODEL: return openAIModel; default: diff --git a/src/main/java/com/redhat/composer/config/retriever/contentRetriever/ContentRetrieverClientFactory.java b/src/main/java/com/redhat/composer/config/retriever/contentRetriever/ContentRetrieverClientFactory.java index f3d5283..3cdde7a 100644 --- a/src/main/java/com/redhat/composer/config/retriever/contentRetriever/ContentRetrieverClientFactory.java +++ b/src/main/java/com/redhat/composer/config/retriever/contentRetriever/ContentRetrieverClientFactory.java @@ -1,5 +1,7 @@ package com.redhat.composer.config.retriever.contentRetriever; +import com.redhat.composer.model.enums.ContentRetrieverType; + import jakarta.enterprise.context.ApplicationScoped; import jakarta.inject.Inject; @@ -8,26 +10,24 @@ public class ContentRetrieverClientFactory { @Inject WeaviateContentRetrieverClient weaviateEmbeddingStoreClient; - public static final String WEAVIATE_CONTENT_RETRIEVER = "weaviate"; - + @Inject Neo4jContentRetrieverClient neo4jContentRetrieverClient; - public static final String NEO4J_CONTENT_RETRIEVER = "neo4j"; - final static String DEFAULT_CONTENT_RETRIEVER = WEAVIATE_CONTENT_RETRIEVER; + final static ContentRetrieverType DEFAULT_CONTENT_RETRIEVER = ContentRetrieverType.WEAVIATE; - public BaseContentRetrieverClient getContentRetrieverClient(String contentRetrieverType) { + public BaseContentRetrieverClient getContentRetrieverClient(ContentRetrieverType contentRetrieverType) { if(contentRetrieverType == null) { contentRetrieverType = DEFAULT_CONTENT_RETRIEVER; } switch (contentRetrieverType) { - case WEAVIATE_CONTENT_RETRIEVER: + case ContentRetrieverType.WEAVIATE: return weaviateEmbeddingStoreClient; - case NEO4J_CONTENT_RETRIEVER: + case ContentRetrieverType.NEO4J: return neo4jContentRetrieverClient; default: - throw new RuntimeException("Embedding type not found: " + contentRetrieverType); + throw new RuntimeException("Content Retriever type not found: " + contentRetrieverType); } } diff --git a/src/main/java/com/redhat/composer/config/retriever/contentRetriever/WeaviateContentRetrieverClient.java b/src/main/java/com/redhat/composer/config/retriever/contentRetriever/WeaviateContentRetrieverClient.java index 00ce851..50c1612 100644 --- a/src/main/java/com/redhat/composer/config/retriever/contentRetriever/WeaviateContentRetrieverClient.java +++ b/src/main/java/com/redhat/composer/config/retriever/contentRetriever/WeaviateContentRetrieverClient.java @@ -3,9 +3,9 @@ import org.eclipse.microprofile.config.inject.ConfigProperty; import org.jboss.logging.Logger; -import com.redhat.composer.api.ChatBotAPI; import com.redhat.composer.config.retriever.contentRetriever.custom.WeaviateEmbeddingStoreCustom; import com.redhat.composer.model.request.RetrieverRequest; +import com.redhat.composer.model.request.retriever.WeaviateRequest; import dev.langchain4j.model.embedding.EmbeddingModel; import dev.langchain4j.rag.content.retriever.ContentRetriever; @@ -33,11 +33,15 @@ public class WeaviateContentRetrieverClient extends BaseContentRetrieverClient { private String weaviateTextKey; public ContentRetriever getContentRetriever(RetrieverRequest request) { - String scheme = request.getScheme() != null ? request.getScheme() : weaviateScheme; - String host = request.getHost() != null ? request.getHost() : weaviateHost; - String apiKey = request.getApiKey() != null ? request.getApiKey() : weaviateApiKey; - String index = request.getIndex() != null ? request.getIndex() : weaviateIndex; - String textKey = request.getTextKey() != null ? request.getTextKey() : weaviateTextKey; + WeaviateRequest weaviateRequest = (WeaviateRequest) request.getBaseRetrieverRequest(); + if(weaviateRequest == null) { + weaviateRequest = new WeaviateRequest(); + } + String scheme = weaviateRequest.getScheme() != null ? weaviateRequest.getScheme() : weaviateScheme; + String host = weaviateRequest.getHost() != null ? weaviateRequest.getHost() : weaviateHost; + String apiKey = weaviateRequest.getApiKey() != null ? weaviateRequest.getApiKey() : weaviateApiKey; + String index = weaviateRequest.getIndex() != null ? weaviateRequest.getIndex() : weaviateIndex; + String textKey = weaviateRequest.getTextKey() != null ? weaviateRequest.getTextKey() : weaviateTextKey; log.info("Attempting to connect to Weaviate at " + scheme + "://" + host + " with index " + index); @@ -47,7 +51,7 @@ public ContentRetriever getContentRetriever(RetrieverRequest request) { .host(host) .apiKey(apiKey) .metadataParentKey("") - .metadataKeys(request.getMetadataFields()) + .metadataKeys(weaviateRequest.getMetadataFields()) .objectClass(index) .avoidDups(true) .textKey(textKey) diff --git a/src/main/java/com/redhat/composer/model/enums/ContentRetrieverType.java b/src/main/java/com/redhat/composer/model/enums/ContentRetrieverType.java new file mode 100644 index 0000000..c23f889 --- /dev/null +++ b/src/main/java/com/redhat/composer/model/enums/ContentRetrieverType.java @@ -0,0 +1,27 @@ +package com.redhat.composer.model.enums; + +public enum ContentRetrieverType { + + WEAVIATE("weaviate"), + NEO4J("neo4j"); + + private final String type; + + ContentRetrieverType(String type) { + this.type = type; + } + + public String getType() { + return type; + } + + public static ContentRetrieverType fromString(String type) { + for (ContentRetrieverType retrieverType : ContentRetrieverType.values()) { + if (retrieverType.getType().equalsIgnoreCase(type)) { + return retrieverType; + } + } + throw new IllegalArgumentException("No constant with type " + type + " found"); + } + +} diff --git a/src/main/java/com/redhat/composer/model/mongo/LLMConnectionEntity.java b/src/main/java/com/redhat/composer/model/mongo/LLMConnectionEntity.java index dc12aa6..8676fdd 100644 --- a/src/main/java/com/redhat/composer/model/mongo/LLMConnectionEntity.java +++ b/src/main/java/com/redhat/composer/model/mongo/LLMConnectionEntity.java @@ -4,8 +4,6 @@ import org.apache.commons.lang3.builder.EqualsBuilder; -import com.redhat.composer.config.llm.models.streaming.StreamingModelFactory; - import io.quarkus.mongodb.panache.common.MongoEntity; @MongoEntity(collection = "llm_connection") diff --git a/src/main/java/com/redhat/composer/model/mongo/RetrieverConnectionEntity.java b/src/main/java/com/redhat/composer/model/mongo/RetrieverConnectionEntity.java index d5df46c..d06c740 100644 --- a/src/main/java/com/redhat/composer/model/mongo/RetrieverConnectionEntity.java +++ b/src/main/java/com/redhat/composer/model/mongo/RetrieverConnectionEntity.java @@ -1,12 +1,10 @@ package com.redhat.composer.model.mongo; -import java.util.List; import java.util.Objects; import org.apache.commons.lang3.builder.EqualsBuilder; -import com.redhat.composer.config.retriever.contentRetriever.ContentRetrieverClientFactory; -import com.redhat.composer.config.retriever.embeddingModel.EmbeddingModelFactory; +import com.redhat.composer.model.mongo.contentRetrieverEntites.BaseRetrieverConnectionEntity; import io.quarkus.mongodb.panache.common.MongoEntity; @@ -14,48 +12,29 @@ @MongoEntity(collection = "retriever_connection") public class RetrieverConnectionEntity extends BaseEntity { - String contentRetrieverType; + + BaseRetrieverConnectionEntity connectionEntity; String embeddingType; String name; - private String description; - - // Key of the value containing the text used for retrieval and passed into the LLM - String textKey; - - // List of metadata fields to be retrieved as part of the content - List metadataFields = List.of("source"); - - String index; - - String scheme; - - String host; - - String apiKey; + String description; public RetrieverConnectionEntity() { } - public RetrieverConnectionEntity(String contentRetrieverType, String embeddingType, String name, String description, String textKey, List metadataFields, String index, String scheme, String host, String apiKey) { - this.contentRetrieverType = contentRetrieverType; + public RetrieverConnectionEntity(BaseRetrieverConnectionEntity connectionEntity, String embeddingType, String name, String description) { + this.connectionEntity = connectionEntity; this.embeddingType = embeddingType; this.name = name; this.description = description; - this.textKey = textKey; - this.metadataFields = metadataFields; - this.index = index; - this.scheme = scheme; - this.host = host; - this.apiKey = apiKey; } - public String getContentRetrieverType() { - return this.contentRetrieverType; + public BaseRetrieverConnectionEntity getConnectionEntity() { + return this.connectionEntity; } - public void setContentRetrieverType(String contentRetrieverType) { - this.contentRetrieverType = contentRetrieverType; + public void setConnectionEntity(BaseRetrieverConnectionEntity connectionEntity) { + this.connectionEntity = connectionEntity; } public String getEmbeddingType() { @@ -82,56 +61,8 @@ public void setDescription(String description) { this.description = description; } - public String getTextKey() { - return this.textKey; - } - - public void setTextKey(String textKey) { - this.textKey = textKey; - } - - public List getMetadataFields() { - return this.metadataFields; - } - - public void setMetadataFields(List metadataFields) { - this.metadataFields = metadataFields; - } - - public String getIndex() { - return this.index; - } - - public void setIndex(String index) { - this.index = index; - } - - public String getScheme() { - return this.scheme; - } - - public void setScheme(String scheme) { - this.scheme = scheme; - } - - public String getHost() { - return this.host; - } - - public void setHost(String host) { - this.host = host; - } - - public String getApiKey() { - return this.apiKey; - } - - public void setApiKey(String apiKey) { - this.apiKey = apiKey; - } - - public RetrieverConnectionEntity contentRetrieverType(String contentRetrieverType) { - setContentRetrieverType(contentRetrieverType); + public RetrieverConnectionEntity connectionEntity(BaseRetrieverConnectionEntity connectionEntity) { + setConnectionEntity(connectionEntity); return this; } @@ -150,36 +81,6 @@ public RetrieverConnectionEntity description(String description) { return this; } - public RetrieverConnectionEntity textKey(String textKey) { - setTextKey(textKey); - return this; - } - - public RetrieverConnectionEntity metadataFields(List metadataFields) { - setMetadataFields(metadataFields); - return this; - } - - public RetrieverConnectionEntity index(String index) { - setIndex(index); - return this; - } - - public RetrieverConnectionEntity scheme(String scheme) { - setScheme(scheme); - return this; - } - - public RetrieverConnectionEntity host(String host) { - setHost(host); - return this; - } - - public RetrieverConnectionEntity apiKey(String apiKey) { - setApiKey(apiKey); - return this; - } - @Override public boolean equals(Object o) { return EqualsBuilder.reflectionEquals(this, o); @@ -187,25 +88,17 @@ public boolean equals(Object o) { @Override public int hashCode() { - return Objects.hash(contentRetrieverType, embeddingType, name, description, textKey, metadataFields, index, scheme, host, apiKey); + return Objects.hash(connectionEntity, embeddingType, name, description); } @Override public String toString() { return "{" + - " contentRetrieverType='" + getContentRetrieverType() + "'" + + " connectionEntity='" + getConnectionEntity() + "'" + ", embeddingType='" + getEmbeddingType() + "'" + ", name='" + getName() + "'" + ", description='" + getDescription() + "'" + - ", textKey='" + getTextKey() + "'" + - ", metadataFields='" + getMetadataFields() + "'" + - ", index='" + getIndex() + "'" + - ", scheme='" + getScheme() + "'" + - ", host='" + getHost() + "'" + - ", apiKey='" + getApiKey() + "'" + "}"; } - - } diff --git a/src/main/java/com/redhat/composer/model/mongo/contentRetrieverEntites/BaseRetrieverConnectionEntity.java b/src/main/java/com/redhat/composer/model/mongo/contentRetrieverEntites/BaseRetrieverConnectionEntity.java new file mode 100644 index 0000000..3a0c0ba --- /dev/null +++ b/src/main/java/com/redhat/composer/model/mongo/contentRetrieverEntites/BaseRetrieverConnectionEntity.java @@ -0,0 +1,32 @@ +package com.redhat.composer.model.mongo.contentRetrieverEntites; +import org.bson.codecs.pojo.annotations.BsonDiscriminator; + +import com.redhat.composer.model.enums.ContentRetrieverType; + +@BsonDiscriminator +public class BaseRetrieverConnectionEntity { + + ContentRetrieverType contentRetrieverType; + + + public BaseRetrieverConnectionEntity() { + } + + public BaseRetrieverConnectionEntity(ContentRetrieverType contentRetrieverType) { + this.contentRetrieverType = contentRetrieverType; + } + + public ContentRetrieverType getContentRetrieverType() { + return this.contentRetrieverType; + } + + public void setContentRetrieverType(ContentRetrieverType contentRetrieverType) { + this.contentRetrieverType = contentRetrieverType; + } + + public BaseRetrieverConnectionEntity contentRetrieverType(ContentRetrieverType contentRetrieverType) { + setContentRetrieverType(contentRetrieverType); + return this; + } + +} diff --git a/src/main/java/com/redhat/composer/model/mongo/contentRetrieverEntites/Neo4JEntity.java b/src/main/java/com/redhat/composer/model/mongo/contentRetrieverEntites/Neo4JEntity.java new file mode 100644 index 0000000..b559c50 --- /dev/null +++ b/src/main/java/com/redhat/composer/model/mongo/contentRetrieverEntites/Neo4JEntity.java @@ -0,0 +1,8 @@ +package com.redhat.composer.model.mongo.contentRetrieverEntites; + +import org.bson.codecs.pojo.annotations.BsonDiscriminator; + +@BsonDiscriminator("neo4j") +public class Neo4JEntity extends BaseRetrieverConnectionEntity { + +} diff --git a/src/main/java/com/redhat/composer/model/mongo/contentRetrieverEntites/WeaviateConnectionEntity.java b/src/main/java/com/redhat/composer/model/mongo/contentRetrieverEntites/WeaviateConnectionEntity.java new file mode 100644 index 0000000..ee31ba5 --- /dev/null +++ b/src/main/java/com/redhat/composer/model/mongo/contentRetrieverEntites/WeaviateConnectionEntity.java @@ -0,0 +1,134 @@ +package com.redhat.composer.model.mongo.contentRetrieverEntites; + +import org.apache.commons.lang3.builder.EqualsBuilder; +import org.bson.codecs.pojo.annotations.BsonDiscriminator; + +import com.redhat.composer.model.enums.ContentRetrieverType; + +import java.util.List; +import java.util.Objects; + +@BsonDiscriminator("weaviate") +public class WeaviateConnectionEntity extends BaseRetrieverConnectionEntity { + + // Key of the value containing the text used for retrieval and passed into the LLM + String textKey; + + // List of metadata fields to be retrieved as part of the content + List metadataFields = List.of("source"); + + String index; + + String scheme; + + String host; + + String apiKey; + + + public WeaviateConnectionEntity() { + contentRetrieverType = ContentRetrieverType.WEAVIATE; + } + + public String getTextKey() { + return this.textKey; + } + + public void setTextKey(String textKey) { + this.textKey = textKey; + } + + public List getMetadataFields() { + return this.metadataFields; + } + + public void setMetadataFields(List metadataFields) { + this.metadataFields = metadataFields; + } + + public String getIndex() { + return this.index; + } + + public void setIndex(String index) { + this.index = index; + } + + public String getScheme() { + return this.scheme; + } + + public void setScheme(String scheme) { + this.scheme = scheme; + } + + public String getHost() { + return this.host; + } + + public void setHost(String host) { + this.host = host; + } + + public String getApiKey() { + return this.apiKey; + } + + public void setApiKey(String apiKey) { + this.apiKey = apiKey; + } + + public WeaviateConnectionEntity textKey(String textKey) { + setTextKey(textKey); + return this; + } + + public WeaviateConnectionEntity metadataFields(List metadataFields) { + setMetadataFields(metadataFields); + return this; + } + + public WeaviateConnectionEntity index(String index) { + setIndex(index); + return this; + } + + public WeaviateConnectionEntity scheme(String scheme) { + setScheme(scheme); + return this; + } + + public WeaviateConnectionEntity host(String host) { + setHost(host); + return this; + } + + public WeaviateConnectionEntity apiKey(String apiKey) { + setApiKey(apiKey); + return this; + } + + @Override + public boolean equals(Object o) { + return EqualsBuilder.reflectionEquals(this, o); + } + + @Override + public int hashCode() { + return Objects.hash(textKey, metadataFields, index, scheme, host, apiKey); + } + + @Override + public String toString() { + return "{" + + " textKey='" + getTextKey() + "'" + + ", metadataFields='" + getMetadataFields() + "'" + + ", index='" + getIndex() + "'" + + ", scheme='" + getScheme() + "'" + + ", host='" + getHost() + "'" + + ", apiKey='" + getApiKey() + "'" + + "}"; + } + + +} diff --git a/src/main/java/com/redhat/composer/model/request/RetrieverRequest.java b/src/main/java/com/redhat/composer/model/request/RetrieverRequest.java index 4d87a70..a16b710 100644 --- a/src/main/java/com/redhat/composer/model/request/RetrieverRequest.java +++ b/src/main/java/com/redhat/composer/model/request/RetrieverRequest.java @@ -1,52 +1,35 @@ package com.redhat.composer.model.request; -import java.util.List; +import com.redhat.composer.model.request.retriever.BaseRetrieverRequest; import java.util.Objects; import org.apache.commons.lang3.builder.EqualsBuilder; -import com.redhat.composer.config.retriever.contentRetriever.ContentRetrieverClientFactory; -import com.redhat.composer.config.retriever.embeddingModel.EmbeddingModelFactory; - public class RetrieverRequest { - String contentRetrieverType; + BaseRetrieverRequest baseRetrieverRequest; String embeddingType; - // Key of the value containing the text used for retrieval and passed into the LLM - String textKey; - - // List of metadata fields to be retrieved as part of the content - List metadataFields = List.of("source"); - - String index; - - String scheme; - - String host; + String name; + String description; - String apiKey; public RetrieverRequest() { } - public RetrieverRequest(String contentRetrieverType, String embeddingType, String textKey, List metadataFields, String index, String scheme, String host, String apiKey) { - this.contentRetrieverType = contentRetrieverType; + public RetrieverRequest(BaseRetrieverRequest baseRetrieverRequest, String embeddingType, String name, String description) { + this.baseRetrieverRequest = baseRetrieverRequest; this.embeddingType = embeddingType; - this.textKey = textKey; - this.metadataFields = metadataFields; - this.index = index; - this.scheme = scheme; - this.host = host; - this.apiKey = apiKey; + this.name = name; + this.description = description; } - public String getContentRetrieverType() { - return this.contentRetrieverType; + public BaseRetrieverRequest getBaseRetrieverRequest() { + return this.baseRetrieverRequest; } - public void setContentRetrieverType(String contentRetrieverType) { - this.contentRetrieverType = contentRetrieverType; + public void setBaseRetrieverRequest(BaseRetrieverRequest baseRetrieverRequest) { + this.baseRetrieverRequest = baseRetrieverRequest; } public String getEmbeddingType() { @@ -57,56 +40,24 @@ public void setEmbeddingType(String embeddingType) { this.embeddingType = embeddingType; } - public String getTextKey() { - return this.textKey; - } - - public void setTextKey(String textKey) { - this.textKey = textKey; - } - - public List getMetadataFields() { - return this.metadataFields; + public String getName() { + return this.name; } - public void setMetadataFields(List metadataFields) { - this.metadataFields = metadataFields; + public void setName(String name) { + this.name = name; } - public String getIndex() { - return this.index; + public String getDescription() { + return this.description; } - public void setIndex(String index) { - this.index = index; + public void setDescription(String description) { + this.description = description; } - public String getScheme() { - return this.scheme; - } - - public void setScheme(String scheme) { - this.scheme = scheme; - } - - public String getHost() { - return this.host; - } - - public void setHost(String host) { - this.host = host; - } - - public String getApiKey() { - return this.apiKey; - } - - public void setApiKey(String apiKey) { - this.apiKey = apiKey; - } - - public RetrieverRequest contentRetrieverType(String contentRetrieverType) { - setContentRetrieverType(contentRetrieverType); + public RetrieverRequest baseRetrieverRequest(BaseRetrieverRequest baseRetrieverRequest) { + setBaseRetrieverRequest(baseRetrieverRequest); return this; } @@ -115,33 +66,13 @@ public RetrieverRequest embeddingType(String embeddingType) { return this; } - public RetrieverRequest textKey(String textKey) { - setTextKey(textKey); - return this; - } - - public RetrieverRequest metadataFields(List metadataFields) { - setMetadataFields(metadataFields); - return this; - } - - public RetrieverRequest index(String index) { - setIndex(index); + public RetrieverRequest name(String name) { + setName(name); return this; } - public RetrieverRequest scheme(String scheme) { - setScheme(scheme); - return this; - } - - public RetrieverRequest host(String host) { - setHost(host); - return this; - } - - public RetrieverRequest apiKey(String apiKey) { - setApiKey(apiKey); + public RetrieverRequest description(String description) { + setDescription(description); return this; } @@ -152,22 +83,18 @@ public boolean equals(Object o) { @Override public int hashCode() { - return Objects.hash(contentRetrieverType, embeddingType, textKey, metadataFields, index, scheme, host, apiKey); + return Objects.hash(baseRetrieverRequest, embeddingType, name, description); } @Override public String toString() { return "{" + - " contentRetrieverType='" + getContentRetrieverType() + "'" + + " baseRetrieverRequest='" + getBaseRetrieverRequest() + "'" + ", embeddingType='" + getEmbeddingType() + "'" + - ", textKey='" + getTextKey() + "'" + - ", metadataFields='" + getMetadataFields() + "'" + - ", index='" + getIndex() + "'" + - ", scheme='" + getScheme() + "'" + - ", host='" + getHost() + "'" + - ", apiKey='" + getApiKey() + "'" + + ", name='" + getName() + "'" + + ", description='" + getDescription() + "'" + "}"; } - + } diff --git a/src/main/java/com/redhat/composer/model/request/retriever/BaseRetrieverRequest.java b/src/main/java/com/redhat/composer/model/request/retriever/BaseRetrieverRequest.java new file mode 100644 index 0000000..894af54 --- /dev/null +++ b/src/main/java/com/redhat/composer/model/request/retriever/BaseRetrieverRequest.java @@ -0,0 +1,42 @@ +package com.redhat.composer.model.request.retriever; + +import org.apache.commons.lang3.builder.EqualsBuilder; + +import com.fasterxml.jackson.annotation.JsonSubTypes; +import com.fasterxml.jackson.annotation.JsonTypeInfo; + +@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.EXISTING_PROPERTY, property = "contentRetrieverType", visible = true) +@JsonSubTypes({ + @JsonSubTypes.Type(value = WeaviateRequest.class, name = "weaviate"), + @JsonSubTypes.Type(value = Neo4JRequest.class, name = "neo4j") +}) +public class BaseRetrieverRequest { + + String contentRetrieverType; + + + public BaseRetrieverRequest() { + } + + public BaseRetrieverRequest(String contentRetrieverType) { + this.contentRetrieverType = contentRetrieverType; + } + + public String getContentRetrieverType() { + return this.contentRetrieverType; + } + + public void setContentRetrieverType(String contentRetrieverType) { + this.contentRetrieverType = contentRetrieverType; + } + + public BaseRetrieverRequest contentRetrieverType(String contentRetrieverType) { + setContentRetrieverType(contentRetrieverType); + return this; + } + + @Override + public boolean equals(Object o) { + return EqualsBuilder.reflectionEquals(this, o); + } +} \ No newline at end of file diff --git a/src/main/java/com/redhat/composer/model/request/retriever/Neo4JRequest.java b/src/main/java/com/redhat/composer/model/request/retriever/Neo4JRequest.java new file mode 100644 index 0000000..bdd57ad --- /dev/null +++ b/src/main/java/com/redhat/composer/model/request/retriever/Neo4JRequest.java @@ -0,0 +1,6 @@ +package com.redhat.composer.model.request.retriever; + +public class Neo4JRequest extends BaseRetrieverRequest { + + +} diff --git a/src/main/java/com/redhat/composer/model/request/retriever/WeaviateRequest.java b/src/main/java/com/redhat/composer/model/request/retriever/WeaviateRequest.java new file mode 100644 index 0000000..b62caab --- /dev/null +++ b/src/main/java/com/redhat/composer/model/request/retriever/WeaviateRequest.java @@ -0,0 +1,144 @@ +package com.redhat.composer.model.request.retriever; + +import java.util.List; +import java.util.Objects; + +import org.apache.commons.lang3.builder.EqualsBuilder; +import org.bson.codecs.pojo.annotations.BsonDiscriminator; + +import com.fasterxml.jackson.annotation.JsonSubTypes; +import com.redhat.composer.model.enums.ContentRetrieverType; + +public class WeaviateRequest extends BaseRetrieverRequest { + + // Key of the value containing the text used for retrieval and passed into the LLM + String textKey; + + // List of metadata fields to be retrieved as part of the content + List metadataFields = List.of("source"); + + String index; + + String scheme; + + String host; + + String apiKey; + + + public WeaviateRequest() { + contentRetrieverType = ContentRetrieverType.WEAVIATE.getType(); + } + + public WeaviateRequest(String textKey, List metadataFields, String index, String scheme, String host, String apiKey) { + this.textKey = textKey; + this.metadataFields = metadataFields; + this.index = index; + this.scheme = scheme; + this.host = host; + this.apiKey = apiKey; + } + + public String getTextKey() { + return this.textKey; + } + + public void setTextKey(String textKey) { + this.textKey = textKey; + } + + public List getMetadataFields() { + return this.metadataFields; + } + + public void setMetadataFields(List metadataFields) { + this.metadataFields = metadataFields; + } + + public String getIndex() { + return this.index; + } + + public void setIndex(String index) { + this.index = index; + } + + public String getScheme() { + return this.scheme; + } + + public void setScheme(String scheme) { + this.scheme = scheme; + } + + public String getHost() { + return this.host; + } + + public void setHost(String host) { + this.host = host; + } + + public String getApiKey() { + return this.apiKey; + } + + public void setApiKey(String apiKey) { + this.apiKey = apiKey; + } + + public WeaviateRequest textKey(String textKey) { + setTextKey(textKey); + return this; + } + + public WeaviateRequest metadataFields(List metadataFields) { + setMetadataFields(metadataFields); + return this; + } + + public WeaviateRequest index(String index) { + setIndex(index); + return this; + } + + public WeaviateRequest scheme(String scheme) { + setScheme(scheme); + return this; + } + + public WeaviateRequest host(String host) { + setHost(host); + return this; + } + + public WeaviateRequest apiKey(String apiKey) { + setApiKey(apiKey); + return this; + } + + @Override + public boolean equals(Object o) { + return EqualsBuilder.reflectionEquals(this, o); + } + + @Override + public int hashCode() { + return Objects.hash(textKey, metadataFields, index, scheme, host, apiKey); + } + + @Override + public String toString() { + return "{" + + " textKey='" + getTextKey() + "'" + + ", metadataFields='" + getMetadataFields() + "'" + + ", index='" + getIndex() + "'" + + ", scheme='" + getScheme() + "'" + + ", host='" + getHost() + "'" + + ", apiKey='" + getApiKey() + "'" + + "}"; + } + + + +} diff --git a/src/main/java/com/redhat/composer/services/AssistantInfoService.java b/src/main/java/com/redhat/composer/services/AssistantInfoService.java index e179b30..f946b8b 100644 --- a/src/main/java/com/redhat/composer/services/AssistantInfoService.java +++ b/src/main/java/com/redhat/composer/services/AssistantInfoService.java @@ -10,8 +10,10 @@ import com.redhat.composer.model.request.LLMRequest; import com.redhat.composer.model.request.RetrieverRequest; import com.redhat.composer.model.response.AssistantResponse; +import com.redhat.composer.util.mappers.MapperUtil; import jakarta.enterprise.context.ApplicationScoped; +import jakarta.inject.Inject; // TODO: Set up MapStruct service or something // https://mapstruct.org/ @@ -19,6 +21,9 @@ @ApplicationScoped public class AssistantInfoService { + @Inject + MapperUtil mapperUtil; + public AssistantEntity createAssistant(AssistantCreationRequest request) { LLMConnectionEntity llm = (LLMConnectionEntity) LLMConnectionEntity.findByIdOptional(request.getLlmConnectionId()).orElseThrow(() -> new IllegalArgumentException("LLM Connection not found")); @@ -49,16 +54,8 @@ public List getAssistant() { ).toList(); } - public RetrieverConnectionEntity creaRetrieverConnectionEntity(RetrieverRequest request) { - RetrieverConnectionEntity entity = new RetrieverConnectionEntity(); - entity.setContentRetrieverType(request.getContentRetrieverType()); - entity.setEmbeddingType(request.getEmbeddingType()); - entity.setHost(request.getHost()); - entity.setApiKey(request.getApiKey()); - entity.setIndex(request.getIndex()); - entity.setMetadataFields(request.getMetadataFields()); - entity.setTextKey(request.getTextKey()); - entity.setScheme(request.getScheme()); + public RetrieverConnectionEntity createRetrieverConnectionEntity(RetrieverRequest request) { + RetrieverConnectionEntity entity = mapperUtil.toEntity(request); entity.persist(); return entity; } diff --git a/src/main/java/com/redhat/composer/services/ChatBotService.java b/src/main/java/com/redhat/composer/services/ChatBotService.java index 7e725aa..f054f58 100644 --- a/src/main/java/com/redhat/composer/services/ChatBotService.java +++ b/src/main/java/com/redhat/composer/services/ChatBotService.java @@ -14,7 +14,7 @@ import com.redhat.composer.model.request.AssistantChatRequest; import com.redhat.composer.model.request.ChatBotRequest; import com.redhat.composer.repositories.AssistantRepository; -import com.redhat.composer.util.MapperUtil; +import com.redhat.composer.util.mappers.MapperUtil; import dev.langchain4j.model.chat.StreamingChatLanguageModel; import dev.langchain4j.rag.content.Content; @@ -42,6 +42,9 @@ public class ChatBotService { @Inject AssistantRepository assistantRepository; + @Inject + MapperUtil mapperUtil; + public Multi chat(AssistantChatRequest request) { @@ -59,8 +62,8 @@ public Multi chat(AssistantChatRequest request) { ChatBotRequest chatBotRequest = new ChatBotRequest(); chatBotRequest.setMessage(request.getMessage()); chatBotRequest.setContext(request.getContext()); - chatBotRequest.setRetrieverRequest(MapperUtil.toRetrieverRequest(retrieverConnection)); - chatBotRequest.setModelRequest(MapperUtil.toLLMRequest(llmConnection)); + chatBotRequest.setRetrieverRequest(mapperUtil.toRequest(retrieverConnection)); + chatBotRequest.setModelRequest(mapperUtil.toRequest(llmConnection)); return chat(chatBotRequest); } diff --git a/src/main/java/com/redhat/composer/services/RetrieveService.java b/src/main/java/com/redhat/composer/services/RetrieveService.java index d8d0728..e263645 100644 --- a/src/main/java/com/redhat/composer/services/RetrieveService.java +++ b/src/main/java/com/redhat/composer/services/RetrieveService.java @@ -4,6 +4,7 @@ import com.redhat.composer.config.retriever.contentRetriever.BaseContentRetrieverClient; import com.redhat.composer.config.retriever.contentRetriever.ContentRetrieverClientFactory; +import com.redhat.composer.model.enums.ContentRetrieverType; import com.redhat.composer.model.request.RetrieverRequest; import dev.langchain4j.rag.content.Content; @@ -19,7 +20,8 @@ public class RetrieveService { ContentRetrieverClientFactory contentRetrieverClientFactory; public ContentRetriever getContentRetriever(RetrieverRequest request) { - BaseContentRetrieverClient client = contentRetrieverClientFactory.getContentRetrieverClient(request.getContentRetrieverType()); + ContentRetrieverType contentRetrieverType = ContentRetrieverType.fromString(request.getBaseRetrieverRequest().getContentRetrieverType()); + BaseContentRetrieverClient client = contentRetrieverClientFactory.getContentRetrieverClient(contentRetrieverType); //TODO: Fix this return client.getContentRetriever(request); } diff --git a/src/main/java/com/redhat/composer/util/MapperUtil.java b/src/main/java/com/redhat/composer/util/MapperUtil.java deleted file mode 100644 index 3440814..0000000 --- a/src/main/java/com/redhat/composer/util/MapperUtil.java +++ /dev/null @@ -1,44 +0,0 @@ -package com.redhat.composer.util; - -import com.redhat.composer.model.mongo.LLMConnectionEntity; -import com.redhat.composer.model.mongo.RetrieverConnectionEntity; -import com.redhat.composer.model.request.LLMRequest; -import com.redhat.composer.model.request.RetrieverRequest; -import com.redhat.composer.model.response.SourceResponse; - -import dev.langchain4j.rag.content.Content; - -public class MapperUtil { - - public static SourceResponse toSourceResponse(Content content) { - SourceResponse response = new SourceResponse(); - response.setContent(content.textSegment().text()); - response.setMetadata(content.textSegment().metadata().toMap()); - return response; - } - - public static LLMRequest toLLMRequest(LLMConnectionEntity entity) { - LLMRequest request = new LLMRequest(); - request.setUrl(entity.getUrl()); - request.setApiKey(entity.getApiKey()); - request.setModelName(entity.getModelName()); - request.setModelType(entity.getModelType()); - return request; - } - - - public static RetrieverRequest toRetrieverRequest(RetrieverConnectionEntity entity) { - RetrieverRequest request = new RetrieverRequest(); - request.setContentRetrieverType(entity.getContentRetrieverType()); - request.setEmbeddingType(entity.getEmbeddingType()); - request.setHost(entity.getHost()); - request.setApiKey(entity.getApiKey()); - request.setIndex(entity.getIndex()); - request.setTextKey(entity.getTextKey()); - request.setScheme(entity.getScheme()); - request.setMetadataFields(entity.getMetadataFields()); - return request; - } - - -} diff --git a/src/main/java/com/redhat/composer/util/mappers/MapperUtil.java b/src/main/java/com/redhat/composer/util/mappers/MapperUtil.java new file mode 100644 index 0000000..ead1774 --- /dev/null +++ b/src/main/java/com/redhat/composer/util/mappers/MapperUtil.java @@ -0,0 +1,66 @@ +package com.redhat.composer.util.mappers; + +import org.mapstruct.Mapper; +import org.mapstruct.Mapping; +import org.mapstruct.factory.Mappers; + +import com.redhat.composer.model.enums.ContentRetrieverType; +import com.redhat.composer.model.mongo.LLMConnectionEntity; +import com.redhat.composer.model.mongo.RetrieverConnectionEntity; +import com.redhat.composer.model.mongo.contentRetrieverEntites.BaseRetrieverConnectionEntity; +import com.redhat.composer.model.mongo.contentRetrieverEntites.Neo4JEntity; +import com.redhat.composer.model.mongo.contentRetrieverEntites.WeaviateConnectionEntity; +import com.redhat.composer.model.request.LLMRequest; +import com.redhat.composer.model.request.RetrieverRequest; +import com.redhat.composer.model.request.retriever.BaseRetrieverRequest; +import com.redhat.composer.model.request.retriever.Neo4JRequest; +import com.redhat.composer.model.request.retriever.WeaviateRequest; + +import jakarta.enterprise.inject.Default; + +@Default +@Mapper(config = QuarkusMapperConfig.class) +public interface MapperUtil { + + RetrieverConnectionMapper retrieverConnectionMapper = Mappers.getMapper(RetrieverConnectionMapper.class); + + @Mapping(target = "connectionEntity", source = "baseRetrieverRequest") + RetrieverConnectionEntity toEntity(RetrieverRequest request); + + @Mapping(source = "connectionEntity", target = "baseRetrieverRequest") + RetrieverRequest toRequest(RetrieverConnectionEntity entity); + + LLMConnectionEntity toEntity(LLMRequest request); + + LLMRequest toRequest(LLMConnectionEntity entity); + + default BaseRetrieverConnectionEntity mapToBaseEntity(BaseRetrieverRequest request) { + if(request == null) { + return null; + } + switch (ContentRetrieverType.fromString(request.getContentRetrieverType())) { + case ContentRetrieverType.WEAVIATE: + return retrieverConnectionMapper.toEntity((WeaviateRequest) request); + case ContentRetrieverType.NEO4J: + return retrieverConnectionMapper.toEntity((Neo4JRequest) request); + default: + return null; + } + } + + default BaseRetrieverRequest mapToBaseRequest(BaseRetrieverConnectionEntity entity){ + if(entity == null || entity.getContentRetrieverType() == null) { + return null; + } + switch (entity.getContentRetrieverType()) { + case ContentRetrieverType.WEAVIATE: + return retrieverConnectionMapper.toRequest((WeaviateConnectionEntity) entity); + case ContentRetrieverType.NEO4J: + return retrieverConnectionMapper.toRequest((Neo4JEntity) entity); + default: + return null; + } + } + + +} diff --git a/src/main/java/com/redhat/composer/util/mappers/QuarkusMapperConfig.java b/src/main/java/com/redhat/composer/util/mappers/QuarkusMapperConfig.java new file mode 100644 index 0000000..0a0d054 --- /dev/null +++ b/src/main/java/com/redhat/composer/util/mappers/QuarkusMapperConfig.java @@ -0,0 +1,9 @@ +package com.redhat.composer.util.mappers; + +import org.mapstruct.MapperConfig; + +@MapperConfig(componentModel = "cdi") +public interface QuarkusMapperConfig { + + +} diff --git a/src/main/java/com/redhat/composer/util/mappers/RetrieverConnectionMapper.java b/src/main/java/com/redhat/composer/util/mappers/RetrieverConnectionMapper.java new file mode 100644 index 0000000..680ccfc --- /dev/null +++ b/src/main/java/com/redhat/composer/util/mappers/RetrieverConnectionMapper.java @@ -0,0 +1,29 @@ +package com.redhat.composer.util.mappers; + +import org.mapstruct.Mapper; + +import com.redhat.composer.model.enums.ContentRetrieverType; +import com.redhat.composer.model.mongo.contentRetrieverEntites.Neo4JEntity; +import com.redhat.composer.model.mongo.contentRetrieverEntites.WeaviateConnectionEntity; +import com.redhat.composer.model.request.retriever.Neo4JRequest; +import com.redhat.composer.model.request.retriever.WeaviateRequest; + +@Mapper(config = QuarkusMapperConfig.class) +public interface RetrieverConnectionMapper { + + WeaviateConnectionEntity toEntity(WeaviateRequest request); + + Neo4JEntity toEntity(Neo4JRequest request); + + WeaviateRequest toRequest(WeaviateConnectionEntity entity); + + Neo4JRequest toRequest(Neo4JEntity entity); + + default String toString(ContentRetrieverType contentRetrieverType) { + return contentRetrieverType.getType(); + } + + default ContentRetrieverType toContentRetrieverType(String value) { + return ContentRetrieverType.fromString(value); + } +} diff --git a/src/main/resources/db/changeLog.yml b/src/main/resources/db/changeLog.yml index acb8d3f..5f7b146 100644 --- a/src/main/resources/db/changeLog.yml +++ b/src/main/resources/db/changeLog.yml @@ -28,16 +28,16 @@ databaseChangeLog: collectionName: retriever_connection - insertOne: collectionName: retriever_connection - document: "{_id: ${rc.ocp.id}, index: 'Openshift_container_platform_en_US_4_17', name: 'ocp_4_17_default' , description: 'Openshift Container Platform Default Connection'}" + document: "{_id: ${rc.ocp.id}, name: 'ocp_4_17_default' , description: 'Openshift Container Platform Default Connection', 'connectionEntity': { '_t': 'weaviate', 'index': 'Openshift_container_platform_en_US_4_17'}}" - insertOne: collectionName: retriever_connection - document: "{_id: ${rc.rhel.id}, index: 'Red_hat_enterprise_linux_en_US_9', name: 'rhel_9_default' , description: 'Red Hat Enterprise Linux 9 Default Connection'}" + document: "{_id: ${rc.rhel.id}, name: 'rhel_9_default' , description: 'Red Hat Enterprise Linux 9 Default Connection', 'connectionEntity': { '_t': 'weaviate', 'index': 'Red_hat_enterprise_linux_en_US_9'}}" - insertOne: collectionName: retriever_connection - document: "{_id: ${rc.ansible.id}, index: 'Red_hat_ansible_automation_platform_en_US_2_5', name: 'ansible_2_5_default' , description: 'Ansible Automation Platform 2.5 Default Connection'}" + document: "{_id: ${rc.ansible.id}, name: 'ansible_2_5_default' , description: 'Ansible Automation Platform 2.5 Default Connection''connectionEntity': { '_t': 'weaviate', 'index': 'Red_hat_ansible_automation_platform_en_US_2_5'}}" - insertOne: collectionName: retriever_connection - document: "{_id: ${rc.rhoai.id}, index: 'Red_hat_openshift_ai_self_managed_en_US_2_14', name: 'rhoia_2_14_default' , description: 'Red Hat Openshift AI Self Managed 2.14 Default Connection'}" + document: "{_id: ${rc.rhoai.id}, name: 'rhoia_2_14_default' , description: 'Red Hat Openshift AI Self Managed 2.14 Default Connection''connectionEntity': { '_t': 'weaviate', 'index': 'Red_hat_openshift_ai_self_managed_en_US_2_14'}}" # Create LLM Connection collection - createCollection: @@ -80,7 +80,7 @@ databaseChangeLog: collectionName: retriever_connection - insertOne: collectionName: retriever_connection - document: "{_id: ${neo4j.default.id}, name: 'neo4j_default' , description: 'Neo4j Default Connection', contentRetrieverType: 'neo4j'}" + document: "{_id: ${neo4j.default.id}, name: 'neo4j_default' , description: 'Neo4j Default Connection', 'connectionEntity': { '_t': 'neo4j'}}" # Create Assistants collection - createCollection: