feat(idp-extraction-connector): added compatibility with llms lacking system message support (#3578)

* feat(idp-extraction-connector): added compatibility with llms lacking system message support

* feat(idp-extraction-connector): moved system_prompt_variable_template to llm model

* feat(idp-extraction-connector): added vendor to llm model
sahilbhatoacamunda authored and sbuettner committed Dec 5, 2024
1 parent 399408d commit 0fad67d
Showing 2 changed files with 171 additions and 59 deletions.
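At a glance: Amazon Titan models on Bedrock do not accept a dedicated system message, so the caller now resolves an LlmModel from the model ID and either attaches the system prompt as a SystemContentBlock or, for Titan, prepends it to the user message. The helper below is a minimal sketch of that decision, condensed from the diff that follows; the class and method names are invented for illustration, everything else is taken from the changed code.

// Minimal sketch only (class and method name invented); condensed from the diff below.
import io.camunda.connector.idp.extraction.model.LlmModel;
import software.amazon.awssdk.services.bedrockruntime.model.ConverseRequest;
import software.amazon.awssdk.services.bedrockruntime.model.SystemContentBlock;

final class SystemPromptFallbackSketch {

  /** Returns the user message to send, attaching a system message only when the model supports one. */
  static String applySystemPrompt(
      LlmModel llmModel, String userMessage, ConverseRequest.Builder request) {
    if (llmModel.isSystemPromptAllowed()) {
      // Model understands a dedicated system message: attach it to the Converse request.
      request.system(SystemContentBlock.builder().text(llmModel.getSystemPrompt()).build());
      return userMessage;
    }
    // No system message support (currently Amazon Titan): fold the system prompt into the user turn.
    return String.format("%s%n%s", llmModel.getSystemPrompt(), userMessage);
  }
}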
BedrockCaller.java
@@ -8,86 +8,59 @@

import io.camunda.connector.idp.extraction.model.ConverseData;
import io.camunda.connector.idp.extraction.model.ExtractionRequest;
-import java.util.stream.Collectors;
+import io.camunda.connector.idp.extraction.model.LlmModel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient;
import software.amazon.awssdk.services.bedrockruntime.model.ContentBlock;
import software.amazon.awssdk.services.bedrockruntime.model.ConversationRole;
import software.amazon.awssdk.services.bedrockruntime.model.ConverseResponse;
import software.amazon.awssdk.services.bedrockruntime.model.Message;
+import software.amazon.awssdk.services.bedrockruntime.model.SystemContentBlock;

public class BedrockCaller {

private static final Logger LOGGER = LoggerFactory.getLogger(BedrockCaller.class);

-private static final String EXTRACTED_TEXT_PLACEHOLDER_FOR_PROMPT = "{{extractedText}}";
-
-private static final String TAXONOMY_PLACEHOLDER_FOR_PROMPT = "{{taxonomy}}";
-
-private static final String SYSTEM_PROMPT_TEMPLATE =
-"""
-You will receive extracted text from a PDF document. This text will be between the <DOCUMENT_TEXT> tags.
-Your task is to extract certain variables from the text. The description how to extract the variables is
-between the <EXTRACTION> tags. Every variable is represented by a <VAR> tag. Every variable has a name,
-which is represented by the <NAME> tag, as well as instructions which data to extract, which is represented
-by the <PROMPT> tag.
-Respond in JSON format, without any preamble. Example response:
-{
-"name": "John Smith",
-"age": 32
-}
-Here is the document text as well as your instructions on which variables to extract:
-<DOCUMENT_TEXT>%s</DOCUMENT_TEXT>
-<EXTRACTION>%s</EXTRACTION>
-"""
-.formatted(EXTRACTED_TEXT_PLACEHOLDER_FOR_PROMPT, TAXONOMY_PLACEHOLDER_FOR_PROMPT);
-
-private static final String SYSTEM_PROMPT_VARIABLE_TEMPLATE =
-"""
-<VAR>
-<NAME>%s</NAME>
-<PROMPT>%s</PROMPT>
-</VAR>
-""";

public String call(
ExtractionRequest extractionRequest,
String extractedText,
BedrockRuntimeClient bedrockRuntimeClient) {
LOGGER.debug("Calling AWS Bedrock model with extraction request: {}", extractionRequest);

-String taxonomyItems =
-extractionRequest.input().taxonomyItems().stream()
-.map(item -> String.format(SYSTEM_PROMPT_VARIABLE_TEMPLATE, item.name(), item.prompt()))
-.collect(Collectors.joining());
-
-String prompt =
-SYSTEM_PROMPT_TEMPLATE
-.replace(EXTRACTED_TEXT_PLACEHOLDER_FOR_PROMPT, extractedText)
-.replace(TAXONOMY_PLACEHOLDER_FOR_PROMPT, taxonomyItems);
-
-Message message =
-Message.builder()
-.content(ContentBlock.fromText(prompt))
-.role(ConversationRole.USER)
-.build();

ConverseData converseData = extractionRequest.input().converseData();
+LlmModel llmModel = LlmModel.fromId(converseData.modelId());

ConverseResponse response =
bedrockRuntimeClient.converse(
-request ->
-request
-.modelId(converseData.modelId())
-.messages(message)
-.inferenceConfig(
-config ->
-config
-.maxTokens(converseData.maxTokens())
-.temperature(converseData.temperature())
-.topP(converseData.topP())));
+request -> {
+String userMessage =
+llmModel.getMessage(extractedText, extractionRequest.input().taxonomyItems());
+
+if (llmModel.isSystemPromptAllowed()) {
+SystemContentBlock prompt =
+SystemContentBlock.builder().text(llmModel.getSystemPrompt()).build();
+request.system(prompt);
+} else {
+userMessage = String.format("%s%n%s", llmModel.getSystemPrompt(), userMessage);
+}
+
+Message message =
+Message.builder()
+.content(ContentBlock.fromText(userMessage))
+.role(ConversationRole.USER)
+.build();
+
+request
+.modelId(converseData.modelId())
+.messages(message)
+.inferenceConfig(
+config ->
+config
+.maxTokens(converseData.maxTokens())
+.temperature(converseData.temperature())
+.topP(converseData.topP()));
+});

return response.output().message().content().getFirst().text();
}
LlmModel.java (new file)
@@ -0,0 +1,139 @@
/*
* Copyright Camunda Services GmbH and/or licensed to Camunda Services GmbH
* under one or more contributor license agreements. Licensed under a proprietary license.
* See the License.txt file for more information. You may not use this file
* except in compliance with the proprietary license.
*/
package io.camunda.connector.idp.extraction.model;

import java.util.List;
import java.util.stream.Collectors;

public enum LlmModel {
CLAUDE("anthropic", getCommonSystemPrompt(), getCommonMessageTemplate()),
LLAMA(
"meta",
"""
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
%s
You are a helpful assistant with tool calling capabilities.
"""
.formatted(getCommonSystemInstruction()),
"""
<|eot_id|><|start_header_id|>user<|end_header_id|>
Given the following functions, please respond with a JSON for a function call with its proper arguments
that best answers the given prompts.
Respond in JSON format, without any preamble. Example response:
{
"name": "John Smith",
"age": 32
}
{
"type": "function",
"function": {
"name": "extract text variables",
"description": "extract every variable based on the given prompt in the question",
"parameters": {
"type": "object"
}
}
}
Question: Given the following functions, please respond with a JSON for a function call with its proper arguments
that best answers the given prompt.
%s
"""
.formatted(getCommonMessageTemplate())),
TITAN("amazon", getCommonSystemPrompt(), getCommonMessageTemplate());

private final String vendor;
private final String systemPrompt;
private final String messageTemplate;

private static final String EXTRACTED_TEXT_PLACEHOLDER_FOR_MESSAGE = "{{extractedText}}";
private static final String TAXONOMY_PLACEHOLDER_FOR_MESSAGE = "{{taxonomy}}";
private static final String SYSTEM_PROMPT_VARIABLE_TEMPLATE =
"""
<VAR>
<NAME>%s</NAME>
<PROMPT>%s</PROMPT>
</VAR>
""";

LlmModel(String vendor, String systemPrompt, String messageTemplate) {
this.vendor = vendor;
this.systemPrompt = systemPrompt;
this.messageTemplate = messageTemplate;
}

public String getSystemPrompt() {
return systemPrompt;
}

public String getVendor() {
return vendor;
}

public String getMessage(String extractedText, List<TaxonomyItem> taxonomyItems) {
String taxonomies =
taxonomyItems.stream()
.map(item -> String.format(SYSTEM_PROMPT_VARIABLE_TEMPLATE, item.name(), item.prompt()))
.collect(Collectors.joining());

return messageTemplate
.replace(EXTRACTED_TEXT_PLACEHOLDER_FOR_MESSAGE, extractedText)
.replace(TAXONOMY_PLACEHOLDER_FOR_MESSAGE, taxonomies);
}

public boolean isSystemPromptAllowed() {
return this != TITAN;
}

public static LlmModel fromId(String id) {
String modelId = id.toLowerCase();
if (modelId.contains(CLAUDE.getVendor())) {
return CLAUDE;
} else if (modelId.contains(LLAMA.getVendor())) {
return LLAMA;
} else if (modelId.contains(TITAN.getVendor())) {
return TITAN;
} else {
return CLAUDE;
}
}

private static String getCommonSystemInstruction() {
return """
You will receive extracted text from a PDF document. This text will be between the <DOCUMENT_TEXT> tags.
Your task is to extract certain variables from the text. The description how to extract the variables is
between the <EXTRACTION> tags. Every variable is represented by a <VAR> tag. Every variable has a name,
which is represented by the <NAME> tag, as well as instructions which data to extract, which is represented
by the <PROMPT> tag.
""";
}

private static String getCommonSystemPrompt() {
return """
%s
Respond in JSON format, without any preamble. Example response:
{
"name": "John Smith",
"age": 32
}
"""
.formatted(getCommonSystemInstruction());
}

private static String getCommonMessageTemplate() {
return """
Here is the document text as well as your instructions on which variables to extract:
<DOCUMENT_TEXT>%s</DOCUMENT_TEXT>
<EXTRACTION>%s</EXTRACTION>
"""
.formatted(EXTRACTED_TEXT_PLACEHOLDER_FOR_MESSAGE, TAXONOMY_PLACEHOLDER_FOR_MESSAGE);
}
}
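For orientation, a small usage sketch of the new enum: only fromId, isSystemPromptAllowed, getSystemPrompt and getMessage come from the class above; the Bedrock model IDs and the TaxonomyItem constructor are assumptions made for the example.

import io.camunda.connector.idp.extraction.model.LlmModel;
import io.camunda.connector.idp.extraction.model.TaxonomyItem;
import java.util.List;

public class LlmModelUsageSketch {
  public static void main(String[] args) {
    // fromId matches the vendor as a substring of the Bedrock model ID; unknown vendors fall back to CLAUDE.
    LlmModel titan = LlmModel.fromId("amazon.titan-text-express-v1");             // example ID (assumption)
    LlmModel claude = LlmModel.fromId("anthropic.claude-3-sonnet-20240229-v1:0"); // example ID (assumption)

    System.out.println(titan.isSystemPromptAllowed());  // false -> system prompt is inlined into the user message
    System.out.println(claude.isSystemPromptAllowed()); // true  -> system prompt is sent as a real system message

    // TaxonomyItem is assumed to expose (name, prompt); the diff only shows its name() and prompt() accessors.
    List<TaxonomyItem> items =
        List.of(new TaxonomyItem("invoiceNumber", "Extract the invoice number printed in the document header"));

    // Builds the user turn: the extracted document text plus one <VAR> block per taxonomy item.
    String userMessage = titan.getMessage("...extracted PDF text...", items);
    System.out.println(userMessage);
  }
}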
