From 4c645421b7911661740494e98f5e1bc0411b2123 Mon Sep 17 00:00:00 2001 From: Javier Ochoa Date: Wed, 27 Nov 2024 11:13:07 -0500 Subject: [PATCH 1/2] added userPrompt and exampleQuestions to assistants --- .../composer/model/mongo/AssistantEntity.java | 115 ++++++--- .../request/AssistantCreationRequest.java | 240 ++++++++++-------- .../services/AssistantInfoService.java | 41 +-- 3 files changed, 242 insertions(+), 154 deletions(-) diff --git a/src/main/java/com/redhat/composer/model/mongo/AssistantEntity.java b/src/main/java/com/redhat/composer/model/mongo/AssistantEntity.java index 0da484c..9ecdcba 100644 --- a/src/main/java/com/redhat/composer/model/mongo/AssistantEntity.java +++ b/src/main/java/com/redhat/composer/model/mongo/AssistantEntity.java @@ -1,4 +1,6 @@ package com.redhat.composer.model.mongo; + +import java.util.List; import java.util.Objects; import org.apache.commons.lang3.builder.EqualsBuilder; @@ -8,115 +10,162 @@ @MongoEntity(collection = "assistant") -public class AssistantEntity extends BaseEntity { +public class AssistantEntity extends BaseEntity +{ String name; String description; String displayName; + String userPrompt; - - ObjectId llmConnectionId; + List exampleQuestions; + ObjectId llmConnectionId; ObjectId retrieverConnectionId; - - - public AssistantEntity() { + public AssistantEntity() + { } - public AssistantEntity(String name, String description, String displayName, ObjectId llmConnectionId, ObjectId retrieverConnectionId) { + public AssistantEntity(String name, String description, String displayName, String userPrompt, + List exampleQuestions, ObjectId llmConnectionId, ObjectId retrieverConnectionId) + { this.name = name; this.description = description; this.displayName = displayName; + this.userPrompt = userPrompt; + this.exampleQuestions = exampleQuestions; this.llmConnectionId = llmConnectionId; this.retrieverConnectionId = retrieverConnectionId; } - public String getName() { + public String getName() + { return this.name; } - public void setName(String name) { + public void setName(String name) + { this.name = name; } - public String getDescription() { + public String getDescription() + { return this.description; } - public void setDescription(String description) { + public void setDescription(String description) + { this.description = description; } - public String getDisplayName() { + public String getDisplayName() + { return this.displayName; } - public void setDisplayName(String displayName) { + public void setDisplayName(String displayName) + { this.displayName = displayName; } - public ObjectId getLlmConnectionId() { + + public List getExampleQuestions() + { + return exampleQuestions; + } + + public void setExampleQuestions(List exampleQuestions) + { + this.exampleQuestions = exampleQuestions; + } + + public String getUserPrompt() + { + return userPrompt; + } + + public void setUserPrompt(String userPrompt) + { + this.userPrompt = userPrompt; + } + + public ObjectId getLlmConnectionId() + { return this.llmConnectionId; } - public void setLlmConnectionId(ObjectId llmConnectionId) { + public void setLlmConnectionId(ObjectId llmConnectionId) + { this.llmConnectionId = llmConnectionId; } - public ObjectId getRetrieverConnectionId() { + public ObjectId getRetrieverConnectionId() + { return this.retrieverConnectionId; } - public void setRetrieverConnectionId(ObjectId retrieverConnectionId) { + public void setRetrieverConnectionId(ObjectId retrieverConnectionId) + { this.retrieverConnectionId = retrieverConnectionId; } - public AssistantEntity name(String name) { + public AssistantEntity name(String name) + { setName(name); return this; } - public AssistantEntity description(String description) { + public AssistantEntity description(String description) + { setDescription(description); return this; } - public AssistantEntity displayName(String displayName) { + public AssistantEntity displayName(String displayName) + { setDisplayName(displayName); return this; } - public AssistantEntity llmConnectionId(ObjectId llmConnectionId) { + public AssistantEntity llmConnectionId(ObjectId llmConnectionId) + { setLlmConnectionId(llmConnectionId); return this; } - public AssistantEntity retrieverConnectionId(ObjectId retrieverConnectionId) { + public AssistantEntity retrieverConnectionId(ObjectId retrieverConnectionId) + { setRetrieverConnectionId(retrieverConnectionId); return this; } @Override - public boolean equals(Object o) { - return EqualsBuilder.reflectionEquals(this, o); + public boolean equals(Object o) + { + return EqualsBuilder.reflectionEquals(this, o); } @Override - public int hashCode() { - return Objects.hash(name, description, displayName, llmConnectionId, retrieverConnectionId); + public int hashCode() + { + return Objects.hash(name, description, displayName, userPrompt, exampleQuestions, llmConnectionId, + retrieverConnectionId); } @Override - public String toString() { + public String toString() + { return "{" + - " name='" + getName() + "'" + - ", description='" + getDescription() + "'" + - ", displayName='" + getDisplayName() + "'" + - ", llmConnectionId='" + getLlmConnectionId() + "'" + - ", retrieverConnectionId='" + getRetrieverConnectionId() + "'" + - "}"; + " name='" + getName() + "'" + + ", description='" + getDescription() + "'" + + ", displayName='" + getDisplayName() + "'" + + ", userPrompt='" + getUserPrompt() + "'" + + ", exampleQuestions='" + getExampleQuestions() + "'" + + ", llmConnectionId='" + getLlmConnectionId() + "'" + + ", retrieverConnectionId='" + getRetrieverConnectionId() + "'" + + "}"; } } diff --git a/src/main/java/com/redhat/composer/model/request/AssistantCreationRequest.java b/src/main/java/com/redhat/composer/model/request/AssistantCreationRequest.java index 63ed20e..a8e569e 100644 --- a/src/main/java/com/redhat/composer/model/request/AssistantCreationRequest.java +++ b/src/main/java/com/redhat/composer/model/request/AssistantCreationRequest.java @@ -1,5 +1,6 @@ package com.redhat.composer.model.request; +import java.util.List; import java.util.Objects; import org.apache.commons.lang3.builder.EqualsBuilder; @@ -7,110 +8,141 @@ @SuppressWarnings("all") public class AssistantCreationRequest { - String name; - String displayName; - String description; - - String llmConnectionId; - - String retrieverConnectionId; - - - public AssistantCreationRequest() { - } - - public AssistantCreationRequest(String name, String displayName, String description, String llmConnectionId, String retrieverConnectionId) { - this.name = name; - this.displayName = displayName; - this.description = description; - this.llmConnectionId = llmConnectionId; - this.retrieverConnectionId = retrieverConnectionId; - } - - public String getName() { - return this.name; - } - - public void setName(String name) { - this.name = name; - } - - public String getDisplayName() { - return this.displayName; - } - - public void setDisplayName(String displayName) { - this.displayName = displayName; - } - - public String getDescription() { - return this.description; - } - - public void setDescription(String description) { - this.description = description; - } - - public String getLlmConnectionId() { - return this.llmConnectionId; - } - - public void setLlmConnectionId(String llmConnectionId) { - this.llmConnectionId = llmConnectionId; - } - - public String getRetrieverConnectionId() { - return this.retrieverConnectionId; - } - - public void setRetrieverConnectionId(String retrieverConnectionId) { - this.retrieverConnectionId = retrieverConnectionId; - } - - public AssistantCreationRequest name(String name) { - setName(name); - return this; - } - - public AssistantCreationRequest displayName(String displayName) { - setDisplayName(displayName); - return this; - } - - public AssistantCreationRequest description(String description) { - setDescription(description); - return this; - } - - public AssistantCreationRequest llmConnectionId(String llmConnectionId) { - setLlmConnectionId(llmConnectionId); - return this; - } - - public AssistantCreationRequest retrieverConnectionId(String retrieverConnectionId) { - setRetrieverConnectionId(retrieverConnectionId); - return this; - } - - @Override + String name; + String displayName; + String description; + String userPrompt; + List exampleQuestions; + String llmConnectionId; + String retrieverConnectionId; + + public AssistantCreationRequest() { + } + + public AssistantCreationRequest(String name, String displayName, String description, String userPrompt, List exampleQuestions, String llmConnectionId, String retrieverConnectionId) { + this.name = name; + this.displayName = displayName; + this.description = description; + this.userPrompt = userPrompt; + this.exampleQuestions = exampleQuestions; + this.llmConnectionId = llmConnectionId; + this.retrieverConnectionId = retrieverConnectionId; + } + + public String getName() { + return this.name; + } + + public void setName(String name) { + this.name = name; + } + + public String getDisplayName() { + return this.displayName; + } + + public void setDisplayName(String displayName) { + this.displayName = displayName; + } + + public String getDescription() { + return this.description; + } + + public void setDescription(String description) { + this.description = description; + } + + + public String getUserPrompt() { + return userPrompt; + } + + public void setUserPrompt(String userPrompt) { + this.userPrompt = userPrompt; + } + + public List getExampleQuestions() { + return exampleQuestions; + } + + public void setExampleQuestions(List exampleQuestions) { + this.exampleQuestions = exampleQuestions; + } + + public String getLlmConnectionId() { + return this.llmConnectionId; + } + + public void setLlmConnectionId(String llmConnectionId) { + this.llmConnectionId = llmConnectionId; + } + + public String getRetrieverConnectionId() { + return this.retrieverConnectionId; + } + + public void setRetrieverConnectionId(String retrieverConnectionId) { + this.retrieverConnectionId = retrieverConnectionId; + } + + + public AssistantCreationRequest name(String name) { + setName(name); + return this; + } + + public AssistantCreationRequest displayName(String displayName) { + setDisplayName(displayName); + return this; + } + + public AssistantCreationRequest description(String description) { + setDescription(description); + return this; + } + + public AssistantCreationRequest userPrompt(String userPrompt) { + setUserPrompt(userPrompt); + return this; + } + + public AssistantCreationRequest exampleQuestions(List exampleQuestions) { + setExampleQuestions(exampleQuestions); + return this; + } + + public AssistantCreationRequest llmConnectionId(String llmConnectionId) { + setLlmConnectionId(llmConnectionId); + return this; + } + + public AssistantCreationRequest retrieverConnectionId(String retrieverConnectionId) { + setRetrieverConnectionId(retrieverConnectionId); + return this; + } + + @Override public boolean equals(Object o) { - return EqualsBuilder.reflectionEquals(this, o); - } - - @Override - public int hashCode() { - return Objects.hash(name, displayName, description, llmConnectionId, retrieverConnectionId); - } - - @Override - public String toString() { - return "{" + - " name='" + getName() + "'" + - ", displayName='" + getDisplayName() + "'" + - ", description='" + getDescription() + "'" + - ", llmConnectionId='" + getLlmConnectionId() + "'" + - ", retrieverConnectionId='" + getRetrieverConnectionId() + "'" + - "}"; - } + return EqualsBuilder.reflectionEquals(this, o); + } + + @Override + public int hashCode() { + return Objects.hash(name, displayName, description, userPrompt, exampleQuestions, llmConnectionId, retrieverConnectionId); + } + + @Override + public String toString() { + return "{" + + " name='" + getName() + "'" + + ", displayName='" + getDisplayName() + "'" + + ", description='" + getDescription() + "'" + + ", userPrompt='" + getUserPrompt() + "'" + + ", exampleQuestions='" + getExampleQuestions() + "'" + + ", llmConnectionId='" + getLlmConnectionId() + "'" + + ", retrieverConnectionId='" + getRetrieverConnectionId() + "'" + + "}"; + } } \ No newline at end of file diff --git a/src/main/java/com/redhat/composer/services/AssistantInfoService.java b/src/main/java/com/redhat/composer/services/AssistantInfoService.java index b4b08fa..d64410f 100644 --- a/src/main/java/com/redhat/composer/services/AssistantInfoService.java +++ b/src/main/java/com/redhat/composer/services/AssistantInfoService.java @@ -28,51 +28,56 @@ public class AssistantInfoService { /** * Create an Assistant. + * * @param request the AssistantCreationRequest * @return the AssistantEntity */ public AssistantEntity createAssistant(AssistantCreationRequest request) { AssistantEntity assistant = new AssistantEntity(); LlmConnectionEntity llm = (LlmConnectionEntity) LlmConnectionEntity.findByIdOptional( - new ObjectId(request.getLlmConnectionId())) - .orElseThrow(() -> new IllegalArgumentException("LLM Connection not found")); - + new ObjectId(request.getLlmConnectionId())) + .orElseThrow(() -> new IllegalArgumentException("LLM Connection not found")); + assistant.setLlmConnectionId(llm.id); - if (request.getRetrieverConnectionId() != null) { + if (request.getRetrieverConnectionId() != null) { RetrieverConnectionEntity retriever = (RetrieverConnectionEntity) RetrieverConnectionEntity - .findByIdOptional(new ObjectId(request.getRetrieverConnectionId())) - .orElseThrow(() -> new IllegalArgumentException("Retriever Connection not found")); + .findByIdOptional(new ObjectId(request.getRetrieverConnectionId())) + .orElseThrow(() -> new IllegalArgumentException("Retriever Connection not found")); assistant.setRetrieverConnectionId(retriever.id); } assistant.setName(request.getName()); assistant.setDisplayName(request.getDisplayName()); assistant.setDescription(request.getDescription()); + assistant.setUserPrompt(request.getUserPrompt()); + assistant.setExampleQuestions(request.getExampleQuestions()); assistant.persist(); return assistant; } /** * Get all Assistants. + * * @return a list of AssistantResponse */ public List getAssistant() { Stream stream = AssistantEntity.streamAll(); return stream.map(entity -> { - AssistantResponse response = new AssistantResponse(); - response.id = entity.id; - response.setName(entity.getName()); - response.setDisplayName(entity.getDisplayName()); - response.setDescription(entity.getDescription()); - response.setLlmConnection(LlmConnectionEntity.findById(entity.getLlmConnectionId())); - response.setRetrieverConnection(RetrieverConnectionEntity.findById(entity.getRetrieverConnectionId())); - return response; - } + AssistantResponse response = new AssistantResponse(); + response.id = entity.id; + response.setName(entity.getName()); + response.setDisplayName(entity.getDisplayName()); + response.setDescription(entity.getDescription()); + response.setLlmConnection(LlmConnectionEntity.findById(entity.getLlmConnectionId())); + response.setRetrieverConnection(RetrieverConnectionEntity.findById(entity.getRetrieverConnectionId())); + return response; + } ).toList(); } /** * Create a RetrieverConnectionEntity. + * * @param request the RetrieverRequest * @return the RetrieverConnectionEntity */ @@ -85,9 +90,10 @@ public RetrieverConnectionEntity createRetrieverConnectionEntity(RetrieverReques public List getRetrieverConnections() { return RetrieverConnectionEntity.listAll(); } - + /** * Create a LLMConnectionEntity. + * * @param request the LLMRequest * @return the LLMConnectionEntity */ @@ -104,10 +110,11 @@ public LlmConnectionEntity createLlmConnection(LLMRequest request) { /** * Get all LLMConnections. + * * @return a list of LlmConnectionEntity */ public List getLlmConnections() { return LlmConnectionEntity.listAll(); } - + } From 09b47d4cde6aa0a39ea922317d370d5e5f0d93e1 Mon Sep 17 00:00:00 2001 From: Javier Ochoa Date: Mon, 2 Dec 2024 12:56:49 -0500 Subject: [PATCH 2/2] Fix assistant retrieval --- .../composer/model/mongo/AssistantEntity.java | 13 +++++ .../services/AssistantInfoService.java | 2 + src/main/resources/db/changeLog.yml | 48 +++++++++---------- 3 files changed, 39 insertions(+), 24 deletions(-) diff --git a/src/main/java/com/redhat/composer/model/mongo/AssistantEntity.java b/src/main/java/com/redhat/composer/model/mongo/AssistantEntity.java index 9ecdcba..cd18485 100644 --- a/src/main/java/com/redhat/composer/model/mongo/AssistantEntity.java +++ b/src/main/java/com/redhat/composer/model/mongo/AssistantEntity.java @@ -129,6 +129,19 @@ public AssistantEntity displayName(String displayName) return this; } + public AssistantEntity userPrompt(String userPrompt) + { + setUserPrompt(userPrompt); + return this; + } + + public AssistantEntity exampleQuestions(List exampleQuestions) + { + setExampleQuestions(exampleQuestions); + return this; + } + + public AssistantEntity llmConnectionId(ObjectId llmConnectionId) { setLlmConnectionId(llmConnectionId); diff --git a/src/main/java/com/redhat/composer/services/AssistantInfoService.java b/src/main/java/com/redhat/composer/services/AssistantInfoService.java index d64410f..5b068d2 100644 --- a/src/main/java/com/redhat/composer/services/AssistantInfoService.java +++ b/src/main/java/com/redhat/composer/services/AssistantInfoService.java @@ -68,6 +68,8 @@ public List getAssistant() { response.setName(entity.getName()); response.setDisplayName(entity.getDisplayName()); response.setDescription(entity.getDescription()); + response.setUserPrompt(entity.getUserPrompt()); + response.setExampleQuestions(entity.getExampleQuestions()); response.setLlmConnection(LlmConnectionEntity.findById(entity.getLlmConnectionId())); response.setRetrieverConnection(RetrieverConnectionEntity.findById(entity.getRetrieverConnectionId())); return response; diff --git a/src/main/resources/db/changeLog.yml b/src/main/resources/db/changeLog.yml index 847489a..2475789 100644 --- a/src/main/resources/db/changeLog.yml +++ b/src/main/resources/db/changeLog.yml @@ -1,19 +1,19 @@ databaseChangeLog: - - property: - name: rc.ocp.id - value: "new ObjectId('66ec94c319c6d4b0c5a27cee')" - - property: - name: rc.rhel.id - value: "new ObjectId('66ed89833baf3b4ab30da4e2')" - - property: - name: rc.ansible.id - value: "new ObjectId('66ed8cb7bfed0c6c6d324dc6')" - - property: - name: rc.rhoai.id - value: "new ObjectId('66ed8cc3ad4df17295dc7a16')" - - property: - name: llm.default.id - value: "new ObjectId('66edae13e03073de9ef24204')" + - property: + name: rc.ocp.id + value: "new ObjectId('66ec94c319c6d4b0c5a27cee')" + - property: + name: rc.rhel.id + value: "new ObjectId('66ed89833baf3b4ab30da4e2')" + - property: + name: rc.ansible.id + value: "new ObjectId('66ed8cb7bfed0c6c6d324dc6')" + - property: + name: rc.rhoai.id + value: "new ObjectId('66ed8cc3ad4df17295dc7a16')" + - property: + name: llm.default.id + value: "new ObjectId('66edae13e03073de9ef24204')" # Create default assistants for weaviate - changeSet: id: 2 @@ -52,21 +52,21 @@ databaseChangeLog: collectionName: assistant - insertOne: collectionName: assistant - document: "{_id: new ObjectId('66edae0a18642fee8cb88587'), retrieverConnectionId: ${rc.ocp.id}, llmConnectionId: ${llm.default.id}, name: 'default_ocp', displayName: 'Default Openshift Container Platform Assistant'}" + document: "{_id: new ObjectId('66edae0a18642fee8cb88587'), retrieverConnectionId: ${rc.ocp.id}, llmConnectionId: ${llm.default.id}, name: 'default_ocp', displayName: 'Default Openshift Container Platform Assistant', userPrompt: 'you are an expert in Openshift platform', exampleQuestions: ['how can I allow external access to my pod?','what can I use to externalize secrets?']}" - insertOne: collectionName: assistant - document: "{_id: new ObjectId('66edae19c9a1bd1c8905b865'), retrieverConnectionId: ${rc.rhel.id}, llmConnectionId: ${llm.default.id}, name: 'default_rhel' , displayName: 'Default Red Hat Enterprise Linux Assistant'}" + document: "{_id: new ObjectId('66edae19c9a1bd1c8905b865'), retrieverConnectionId: ${rc.rhel.id}, llmConnectionId: ${llm.default.id}, name: 'default_rhel' , displayName: 'Default Red Hat Enterprise Linux Assistant', userPrompt:'user prompt text', exampleQuestions:['question1','question2']}" - insertOne: collectionName: assistant - document: "{_id: new ObjectId('66edae2255ce9f7058f2c472'), retrieverConnectionId: ${rc.ansible.id}, llmConnectionId: ${llm.default.id}, name: 'default_ansible' , displayName: 'Default Ansible Automation Platform Assistant'}" + document: "{_id: new ObjectId('66edae2255ce9f7058f2c472'), retrieverConnectionId: ${rc.ansible.id}, llmConnectionId: ${llm.default.id}, name: 'default_ansible' , displayName: 'Default Ansible Automation Platform Assistant', userPrompt:'user prompt text', exampleQuestions:['question1','question2']}" - insertOne: collectionName: assistant - document: "{_id: new ObjectId('66edae2738a7f2388fb02cd8'), retrieverConnectionId: ${rc.rhoai.id}, llmConnectionId: ${llm.default.id}, name: 'default_rhoai' , displayName: 'Default Red Hat Openshift AI Self Managed Assistant'}" + document: "{_id: new ObjectId('66edae2738a7f2388fb02cd8'), retrieverConnectionId: ${rc.rhoai.id}, llmConnectionId: ${llm.default.id}, name: 'default_rhoai' , displayName: 'Default Red Hat Openshift AI Self Managed Assistant', userPrompt:'user prompt text', exampleQuestions:['question1','question2']}" # Create default assistants for neo4j - - property: - name: neo4j.default.id - value: "new ObjectId('66f3fbffd7e04770c03ee123')" + - property: + name: neo4j.default.id + value: "new ObjectId('66f3fbffd7e04770c03ee123')" - changeSet: id: 3 author: quarkus @@ -87,7 +87,7 @@ databaseChangeLog: collectionName: assistant - insertOne: collectionName: assistant - document: "{_id: new ObjectId('66f3fc14de104310e21acd67'), retrieverConnectionId: ${neo4j.default.id}, llmConnectionId: ${llm.default.id}, name: 'default_neo4j', displayName: 'Default Neo4J Assistant'}" + document: "{_id: new ObjectId('66f3fc14de104310e21acd67'), retrieverConnectionId: ${neo4j.default.id}, llmConnectionId: ${llm.default.id}, name: 'default_neo4j', displayName: 'Default Neo4J Assistant', userPrompt:'user prompt text', exampleQuestions:['question1','question2']}" # Create default assistant just using the default LLM connection - changeSet: @@ -100,4 +100,4 @@ databaseChangeLog: # Create Assistants collection - insertOne: collectionName: assistant - document: "{_id: new ObjectId('672c1da8cc502c4c6ccad746'), llmConnectionId: ${llm.default.id}, name: 'default_assistant', displayName: 'Default Assistant'}" + document: "{_id: new ObjectId('672c1da8cc502c4c6ccad746'), llmConnectionId: ${llm.default.id}, name: 'default_assistant', displayName: 'Default Assistant', userPrompt:'user prompt text', exampleQuestions:['question1','question2']}"