Skip to content

Commit

Permalink
Merge pull request #490 from devoxx/issue-442
Browse files Browse the repository at this point in the history
Feat #442 Support OpenAI o1 and o3 models
  • Loading branch information
stephanj authored Feb 6, 2025
2 parents 7d7f7e2 + 1198222 commit c82d450
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 32 deletions.
5 changes: 3 additions & 2 deletions build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,15 @@ tasks.named("buildPlugin") {
}

dependencies {
val lg4j_version = "0.36.2"
val lg4j_version = "1.0.0-beta1"

// Add the dependencies for the core module
implementation(project(":core"))

implementation("dev.langchain4j:langchain4j:$lg4j_version")
implementation("dev.langchain4j:langchain4j-ollama:$lg4j_version")
implementation("dev.langchain4j:langchain4j-local-ai:$lg4j_version")
implementation("dev.langchain4j:langchain4j-open-ai:$lg4j_version")
implementation("dev.langchain4j:langchain4j-open-ai:1.0.0-alpha2-SNAPSHOT")
implementation("dev.langchain4j:langchain4j-anthropic:$lg4j_version")
implementation("dev.langchain4j:langchain4j-bedrock:$lg4j_version")
implementation("dev.langchain4j:langchain4j-mistral-ai:$lg4j_version")
Expand All @@ -66,6 +66,7 @@ dependencies {
implementation("dev.langchain4j:langchain4j-azure-open-ai:$lg4j_version")
implementation("dev.langchain4j:langchain4j-chroma:$lg4j_version")

implementation("com.squareup.retrofit2:converter-gson:2.11.0")
implementation("org.xerial:sqlite-jdbc:3.48.0.0")

implementation("com.github.docker-java:docker-java:3.4.0")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import com.devoxx.genie.model.enumarations.ModelProvider;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.chat.request.ChatRequestParameters;
import dev.langchain4j.model.openai.OpenAiChatModel;
import dev.langchain4j.model.openai.OpenAiStreamingChatModel;
import org.jetbrains.annotations.NotNull;
Expand All @@ -19,34 +20,42 @@ public class OpenAIChatModelFactory implements ChatModelFactory {

@Override
public ChatLanguageModel createChatModel(@NotNull ChatModel chatModel) {
boolean isO1 = chatModel.getModelName().startsWith("o1-");

final var builder = OpenAiChatModel.builder()
return OpenAiChatModel.builder()
.apiKey(getApiKey(MODEL_PROVIDER))
.modelName(chatModel.getModelName())
.defaultRequestParameters(createChatContextParameters(chatModel))
.maxRetries(chatModel.getMaxRetries())
.temperature(isO1 ? 1.0 : chatModel.getTemperature())
.timeout(Duration.ofSeconds(chatModel.getTimeout()))
.topP(isO1 ? 1.0 : chatModel.getTopP());

return builder.build();
.build();
}

@Override
public StreamingChatLanguageModel createStreamingChatModel(@NotNull ChatModel chatModel) {
boolean isO1 = chatModel.getModelName().startsWith("o1-");
final var builder = OpenAiStreamingChatModel.builder()
return OpenAiStreamingChatModel.builder()
.apiKey(getApiKey(MODEL_PROVIDER))
.defaultRequestParameters(createChatContextParameters(chatModel))
.modelName(chatModel.getModelName())
.temperature(isO1 ? 1.0 : chatModel.getTemperature())
.topP(isO1 ? 1.0 : chatModel.getTopP())
.timeout(Duration.ofSeconds(chatModel.getTimeout()));

return builder.build();
.timeout(Duration.ofSeconds(chatModel.getTimeout()))
.build();
}

@Override
public List<LanguageModel> getModels() {
return getModels(MODEL_PROVIDER);
}

private ChatRequestParameters createChatContextParameters(@NotNull ChatModel chatModel) {
boolean isO1 = chatModel.getModelName().toLowerCase().startsWith("o1");
boolean isO3 = chatModel.getModelName().toLowerCase().startsWith("o3");

if (isO1 || isO3) {
// o1 and o3 models do not support temperature and topP
return ChatRequestParameters.builder().build();
} else {
return ChatRequestParameters.builder()
.temperature(chatModel.getTemperature())
.topP(chatModel.getTopP())
.build();
}
}
}
31 changes: 15 additions & 16 deletions src/main/java/com/devoxx/genie/service/LLMModelRegistryService.java
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,7 @@ private void addAnthropicModels() {

private void addOpenAiModels() {

// TODO Add o3 and o3-mini when available in around Feb 2025

// Not yet available via API
// String o3Model = OpenAIChatModelName.O3.toString();
// models.put(ModelProvider.OpenAI.getName() + ":" + o3Model,
// LanguageModel.builder()
Expand All @@ -110,19 +109,19 @@ private void addOpenAiModels() {
// .outputMaxTokens(100_000)
// .apiKeyUsed(true)
// .build());
//
// String o3MiniModel = OpenAIChatModelName.O3_MINI.toString();
// models.put(ModelProvider.OpenAI.getName() + ":" + o3Model,
// LanguageModel.builder()
// .provider(ModelProvider.OpenAI)
// .modelName(o3Model)
// .displayName("o3-mini")
// .inputCost(5)
// .outputCost(15)
// .inputMaxTokens(200_000)
// .outputMaxTokens(100_000)
// .apiKeyUsed(true)
// .build());

String o3MiniModel = OpenAIChatModelName.O3_MINI.toString();
models.put(ModelProvider.OpenAI.getName() + ":" + o3MiniModel,
LanguageModel.builder()
.provider(ModelProvider.OpenAI)
.modelName(o3MiniModel)
.displayName("o3-mini")
.inputCost(5)
.outputCost(15)
.inputMaxTokens(200_000)
.outputMaxTokens(100_000)
.apiKeyUsed(true)
.build());

String o1Model = OpenAIChatModelName.O1.toString();
models.put(ModelProvider.OpenAI.getName() + ":" + o1Model,
Expand All @@ -142,7 +141,7 @@ private void addOpenAiModels() {
LanguageModel.builder()
.provider(ModelProvider.OpenAI)
.modelName(o1Mini)
.displayName("o1 mini")
.displayName("o1-mini")
.inputCost(5)
.outputCost(15)
.inputMaxTokens(128_000)
Expand Down

0 comments on commit c82d450

Please sign in to comment.