Skip to content

Commit

Permalink
Adding default assistants for doing general c-suite items (#13)
Browse files Browse the repository at this point in the history
Co-authored-by: Jamie Land <[email protected]>
  • Loading branch information
jland-redhat and Jaland authored Dec 3, 2024
1 parent cafd79a commit 8ea67e9
Show file tree
Hide file tree
Showing 10 changed files with 311 additions and 84 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,7 @@ public class AiServicesFactory {

public static final String MISTRAL7B_AI_SERVICE = "mistral7b";

public static final String MISTRAL7B_QUARKUS_AI_SERVICE = "mistral7b_quarkus";

public static final String HEALTHCARE_SERVICE = "healthcare";
public static final String GRANITE_AI_SERVICE = "granite";

/**
* Get the AI service class.
Expand All @@ -23,8 +21,8 @@ public Class<? extends BaseAiService> getAiService(String aiServiceType) {
switch (aiServiceType) {
case MISTRAL7B_AI_SERVICE:
return Mistral7bAiService.class;
case HEALTHCARE_SERVICE:
return HealthCareService.class;
case GRANITE_AI_SERVICE:
return GraniteAiService.class;
default:
throw new RuntimeException("AI service type not found: " + aiServiceType);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@ public interface BaseAiService {
* @param input User Message
* @return the TokenStream
*/
TokenStream chatToken(String context, String input);
TokenStream chatToken(String context, String input, String systemMessage);

/**
* Returns a Multi of String given input.
* @param context Context information such as chat history and source information
* @param input User Message
* @return the Multi of String
*/
Multi<String> chatStream(String context, String input);
Multi<String> chatStream(String context, String input, String systemMessage);

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package com.redhat.composer.config.llm.aiservices;

import dev.langchain4j.service.SystemMessage;
import dev.langchain4j.service.TokenStream;
import dev.langchain4j.service.UserMessage;
import dev.langchain4j.service.V;
import io.smallrye.mutiny.Multi;


/**
 * GraniteAiService.
 *
 * <p>AI service interface for IBM Granite models: wraps the retrieval context,
 * user input, and system message in Granite's chat template
 * ({@code <|start_of_role|>...<|end_of_role|>} markers) before sending it to the model.</p>
 */
@SuppressWarnings("LineLengthCheck")
public interface GraniteAiService extends BaseAiService {

  // Note: Handling history is a little more complex than just passing it as a string
  // See: https://www.ibm.com/granite/docs/models/granite/#using-the-multi-round-chat-example-finance
  // All template variables use langchain4j's double-brace syntax ({{var}});
  // a single-brace {context} would be passed through literally instead of substituted.
  static final String userMessage = """
      <|start_of_role|>system<|end_of_role|>{{systemMessage}}<|end_of_text|>
      {{context}}
      <|start_of_role|>user<|end_of_role|>{{input}}<|end_of_text|>
      <|start_of_role|>assistant<|end_of_role|>
      """;

  /**
   * Returns TokenStream given input.
   * @param context Context information such as chat history and source information
   * @param input User Message
   * @param systemMessage System prompt injected into the Granite chat template
   * @return the TokenStream
   */
  @SystemMessage("{{systemMessage}}")
  @UserMessage(userMessage)
  TokenStream chatToken(@V("context") String context, @V("input") String input, @V("systemMessage") String systemMessage);


  /**
   * Returns a Multi of String given input.
   * @param context Context information such as chat history and source information
   * @param input User Message
   * @param systemMessage System prompt injected into the Granite chat template
   * @return the Multi of String
   */
  @SystemMessage("{{systemMessage}}")
  @UserMessage(userMessage)
  Multi<String> chatStream(@V("context") String context, @V("input") String input, @V("systemMessage") String systemMessage);

}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import dev.langchain4j.service.SystemMessage;
import dev.langchain4j.service.TokenStream;
import dev.langchain4j.service.UserMessage;
import dev.langchain4j.service.V;
import io.smallrye.mutiny.Multi;


Expand All @@ -12,38 +13,23 @@
@SuppressWarnings("LineLengthCheck")
public interface Mistral7bAiService extends BaseAiService {

static final String systemMessage = """
You are a helpful, respectful and honest assistant answering questions about products from the company called Red Hat.
You will be given a question you need to answer about this product.
If a question is about a specific product you will be given the product name and version, and references to provide you with additional information.
You must answer the question basing yourself as much as possible on the given references if any.
Always answer as helpfully as possible, while being safe.
Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content.
Please ensure that your responses are socially unbiased and positive in nature.
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct.
If you don't know the answer to a question, please don't share false information.
""";

static final String userMessage = """
<context>
{context}
</context>
<|eot_id|>
<|start_header_id|>user<|end_header_id|>
Question: {input}
<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>
""";
{{systemMessage}}
<<<
Context: {context}
User Message: {input}
>>>
""";

/**
* Returns TokenStream given input.
* @param context Context information such as chat history and source information
* @param input User Message
* @return the TokenStream
*/
@SystemMessage(systemMessage)
@SystemMessage("{{systemMessage}}")
@UserMessage(userMessage)
TokenStream chatToken(String context, String input);
TokenStream chatToken(@V("context") String context, @V("input") String input, @V("systemMessage") String systemMessage);


/**
Expand All @@ -52,8 +38,8 @@ public interface Mistral7bAiService extends BaseAiService {
* @param input User Message
* @return the Multi of String
*/
@SystemMessage(systemMessage)
@SystemMessage("{{systemMessage}}")
@UserMessage(userMessage)
Multi<String> chatStream(String context, String input);
Multi<String> chatStream(@V("context") String context, @V("input") String input, @V("systemMessage") String systemMessage);

}
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ public class OpenAiStreamingModel extends StreamingBaseModel {
@ConfigProperty(name = "openai.default.temp")
private double openaiDefaultTemp;

@ConfigProperty(name = "openai.default.maxTokens")
private Integer openaiDefaultMaxTokens;

/**
* Get the Chat Model.
* @param request the LLMRequest
Expand All @@ -49,6 +52,9 @@ public StreamingChatLanguageModel getChatModel(LLMRequest request) {

// TODO: Add all the following to the request
builder.temperature(openaiDefaultTemp);

builder.maxTokens(openaiDefaultMaxTokens);


// TODO: Fill all this out
// if (modelName != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,18 @@ public class ChatBotRequest{
private String message = "";
// This is where you can pass in chat history or other context
private String context = "";
private String systemMessage = "";
private RetrieverRequest retrieverRequest = new RetrieverRequest();
private LLMRequest modelRequest = new LLMRequest();


public ChatBotRequest() {
}

public ChatBotRequest(String message, String context, RetrieverRequest retrieverRequest, LLMRequest modelRequest) {
public ChatBotRequest(String message, String context, String systemMessage, RetrieverRequest retrieverRequest, LLMRequest modelRequest) {
this.message = message;
this.context = context;
this.systemMessage = systemMessage;
this.retrieverRequest = retrieverRequest;
this.modelRequest = modelRequest;
}
Expand All @@ -39,6 +41,14 @@ public void setContext(String context) {
this.context = context;
}

public String getSystemMessage() {
return this.systemMessage;
}

public void setSystemMessage(String systemMessage) {
this.systemMessage = systemMessage;
}

public RetrieverRequest getRetrieverRequest() {
return this.retrieverRequest;
}
Expand All @@ -65,6 +75,11 @@ public ChatBotRequest context(String context) {
return this;
}

public ChatBotRequest systemMessage(String systemMessage) {
setSystemMessage(systemMessage);
return this;
}

public ChatBotRequest retrieverRequest(RetrieverRequest retrieverRequest) {
setRetrieverRequest(retrieverRequest);
return this;
Expand All @@ -82,14 +97,15 @@ public boolean equals(Object o) {

@Override
public int hashCode() {
return Objects.hash(message, context, retrieverRequest, modelRequest);
return Objects.hash(message, context, systemMessage, retrieverRequest, modelRequest);
}

@Override
public String toString() {
return "{" +
" message='" + getMessage() + "'" +
", context='" + getContext() + "'" +
", systemMessage='" + getSystemMessage() + "'" +
", retrieverRequest='" + getRetrieverRequest() + "'" +
", modelRequest='" + getModelRequest() + "'" +
"}";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import java.util.ArrayList;
import java.util.List;

import org.eclipse.microprofile.config.inject.ConfigProperty;
import org.jboss.logging.Logger;

import com.fasterxml.jackson.core.JsonProcessingException;
Expand Down Expand Up @@ -35,6 +36,9 @@ public class ChatBotService {

Logger log = Logger.getLogger(ChatBotService.class);

@ConfigProperty(name = "prompt.default.system.message")
private String defaultSystemMessage;

@Inject
StreamingModelFactory modelTemplateFactory;

Expand Down Expand Up @@ -79,6 +83,7 @@ public Multi<String> chat(AssistantChatRequest request) {
chatBotRequest.setContext(request.getContext());
chatBotRequest.setRetrieverRequest(mapperUtil.toRequest(retrieverConnection));
chatBotRequest.setModelRequest(mapperUtil.toRequest(llmConnection));
chatBotRequest.setSystemMessage(assistant.getUserPrompt());

return chat(chatBotRequest);
}
Expand Down Expand Up @@ -112,8 +117,9 @@ public Multi<String> chat(ChatBotRequest request) {

try {
List<ContentResponse> contentSources = new ArrayList<ContentResponse>();
String systemMessage = request.getSystemMessage() == null ? defaultSystemMessage : request.getSystemMessage();
Multi<String> multi = Multi.createFrom().emitter(em -> {
aiService.chatToken(request.getContext(), request.getMessage())
aiService.chatToken(request.getContext(), request.getMessage(), systemMessage)
.onNext(em::emit)
.onRetrieved(sources -> {
contentSources.add(new ContentResponse(sources));
Expand Down
3 changes: 3 additions & 0 deletions src/main/resources/application.properties
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ openai.default.apiKey=abc123
# openai.default.apiKey=<REPLACE_ME>
openai.default.modelName=Mistral-7B-Instruct-v0.3
openai.default.temp=0.8
openai.default.maxTokens=500

prompt.default.system.message=I am a friendly AI assistant associated with the Red Hat Composer UI Project. And will answer any questions I am able.

# Phoenix configuration
# TODO: This is not implemented yet
Expand Down
Loading

0 comments on commit 8ea67e9

Please sign in to comment.