Skip to content

Commit

Permalink
Resolving issues
Browse files Browse the repository at this point in the history
  • Loading branch information
Suhas-Koheda committed Dec 7, 2024
1 parent 69660a4 commit 3107cb9
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 233 deletions.
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
package dev.langchain4j.googleaigemini.spring;

import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.googleai.GoogleAiEmbeddingModel;
import dev.langchain4j.model.googleai.GoogleAiGeminiChatModel;
import dev.langchain4j.model.googleai.GoogleAiGeminiStreamingChatModel;
Expand All @@ -22,60 +19,60 @@ public class AutoConfig {

@Bean
@ConditionalOnProperty(name = PREFIX + ".chatModel.enabled", havingValue = "true")
ChatLanguageModel googleAiGeminiChatModel(Properties properties) {
GoogleAiGeminiChatModel googleAiGeminiChatModel(Properties properties) {
ChatModelProperties chatModelProperties = properties.getChatModel();
return GoogleAiGeminiChatModel.builder()
.apiKey(properties.getApiKey())
.modelName(chatModelProperties.getModelName())
.temperature(chatModelProperties.getTemperature())
.topP(chatModelProperties.getTopP())
.topK(chatModelProperties.getTopK())
.maxOutputTokens(chatModelProperties.getMaxOutputTokens())
.responseFormat(chatModelProperties.getResponseFormat())
.logRequestsAndResponses(chatModelProperties.getLogRequestsAndResponses())
.modelName(chatModelProperties.modelName())
.temperature(chatModelProperties.temperature())
.topP(chatModelProperties.topP())
.topK(chatModelProperties.topK())
.maxOutputTokens(chatModelProperties.maxOutputTokens())
.responseFormat(chatModelProperties.responseFormat())
.logRequestsAndResponses(chatModelProperties.logRequestsAndResponses())
.safetySettings(
Map.of(chatModelProperties.getSafetySetting().getGeminiHarmCategory(),
chatModelProperties.getSafetySetting().getGeminiHarmBlockThreshold()))
Map.of(chatModelProperties.safetySetting().geminiHarmCategory(),
chatModelProperties.safetySetting().geminiHarmBlockThreshold()))
.toolConfig(
chatModelProperties.getFunctionCallingConfig().getGeminiMode(),
chatModelProperties.getFunctionCallingConfig().getAllowedFunctionNames().toArray(new String[0]))
chatModelProperties.functionCallingConfig().getGeminiMode(),
chatModelProperties.functionCallingConfig().getAllowedFunctionNames().toArray(new String[0]))
.build();
}

@Bean
@ConditionalOnProperty(name = PREFIX + ".streamingChatModel.enabled", havingValue = "true")
StreamingChatLanguageModel googleAiGeminiStreamingChatModel(Properties properties) {
GoogleAiGeminiStreamingChatModel googleAiGeminiStreamingChatModel(Properties properties) {
ChatModelProperties chatModelProperties = properties.getStreamingChatModel();
return GoogleAiGeminiStreamingChatModel.builder()
.apiKey(properties.getApiKey())
.modelName(chatModelProperties.getModelName())
.temperature(chatModelProperties.getTemperature())
.topP(chatModelProperties.getTopP())
.topK(chatModelProperties.getTopK())
.responseFormat(chatModelProperties.getResponseFormat())
.logRequestsAndResponses(chatModelProperties.getLogRequestsAndResponses())
.modelName(chatModelProperties.modelName())
.temperature(chatModelProperties.temperature())
.topP(chatModelProperties.topP())
.topK(chatModelProperties.topK())
.responseFormat(chatModelProperties.responseFormat())
.logRequestsAndResponses(chatModelProperties.logRequestsAndResponses())
.safetySettings(
Map.of(chatModelProperties.getSafetySetting().getGeminiHarmCategory(),
chatModelProperties.getSafetySetting().getGeminiHarmBlockThreshold()))
Map.of(chatModelProperties.safetySetting().geminiHarmCategory(),
chatModelProperties.safetySetting().geminiHarmBlockThreshold()))
.toolConfig(
chatModelProperties.getFunctionCallingConfig().getGeminiMode(),
chatModelProperties.getFunctionCallingConfig().getAllowedFunctionNames().toArray(new String[0]))
chatModelProperties.functionCallingConfig().getGeminiMode(),
chatModelProperties.functionCallingConfig().getAllowedFunctionNames().toArray(new String[0]))
.build();
}

@Bean
@ConditionalOnProperty(name = PREFIX + ".embeddingModel.enabled", havingValue = "true")
EmbeddingModel googleAiGeminiEmbeddingModel(Properties properties) {
GoogleAiEmbeddingModel googleAiGeminiEmbeddingModel(Properties properties) {
EmbeddingModelProperties embeddingModelProperties = properties.getEmbeddingModel();
return GoogleAiEmbeddingModel.builder()
.apiKey(properties.getApiKey())
.modelName(embeddingModelProperties.getModelName())
.logRequestsAndResponses(embeddingModelProperties.isLogRequestsAndResponses())
.maxRetries(embeddingModelProperties.getMaxRetries())
.outputDimensionality(embeddingModelProperties.getOutputDimensionality())
.taskType(embeddingModelProperties.getTaskType())
.timeout(embeddingModelProperties.getTimeout())
.titleMetadataKey(embeddingModelProperties.getTitleMetadataKey())
.modelName(embeddingModelProperties.modelName())
.logRequestsAndResponses(embeddingModelProperties.logRequestsAndResponses())
.maxRetries(embeddingModelProperties.maxRetries())
.outputDimensionality(embeddingModelProperties.outputDimensionality())
.taskType(embeddingModelProperties.taskType())
.timeout(embeddingModelProperties.timeout())
.titleMetadataKey(embeddingModelProperties.titleMetadataKey())
.build();
}
}
Original file line number Diff line number Diff line change
@@ -1,112 +1,21 @@
package dev.langchain4j.googleaigemini.spring;

import dev.langchain4j.model.chat.request.ResponseFormat;
import org.springframework.boot.context.properties.NestedConfigurationProperty;

import java.time.Duration;

public class ChatModelProperties {

private String modelName;
private Double temperature;
private Double topP;
private Integer topK;
private Integer maxOutputTokens;
private ResponseFormat responseFormat;
private Boolean logRequestsAndResponses;
private Integer maxRetries;
private Duration timeout;

private GeminiSafetySetting safetySetting;

private GeminiFunctionCallingConfig functionCallingConfig;

public Integer getMaxRetries() {
return maxRetries;
}

public void setMaxRetries(Integer maxRetries) {
this.maxRetries = maxRetries;
}

public Duration getTimeout() {
return timeout;
}

public void setTimeout(Duration timeout) {
this.timeout = timeout;
}

public ResponseFormat getResponseFormat() {
return responseFormat;
}

public void setResponseFormat(ResponseFormat responseFormat) {
this.responseFormat = responseFormat;
}

public String getModelName() {
return modelName;
}

public void setModelName(String modelName) {
this.modelName = modelName;
}

public Double getTemperature() {
return temperature;
}

public void setTemperature(Double temperature) {
this.temperature = temperature;
}

public Double getTopP() {
return topP;
}

public void setTopP(Double topP) {
this.topP = topP;
}

public Integer getTopK() {
return topK;
}

public void setTopK(Integer topK) {
this.topK = topK;
}

public Boolean getLogRequestsAndResponses() {
return logRequestsAndResponses;
}

public void setLogRequestsAndResponses(Boolean logRequestsAndResponses) {
this.logRequestsAndResponses = logRequestsAndResponses;
}

public GeminiSafetySetting getSafetySetting() {
return safetySetting;
}

public void setSafetySetting(GeminiSafetySetting safetySetting) {
this.safetySetting = safetySetting;
}

public GeminiFunctionCallingConfig getFunctionCallingConfig() {
return functionCallingConfig;
}

public void setFunctionCallingConfig(GeminiFunctionCallingConfig functionCallingConfig) {
this.functionCallingConfig = functionCallingConfig;
}

public Integer getMaxOutputTokens() {
return maxOutputTokens;
}

public void setMaxOutputTokens(Integer maxOutputTokens) {
this.maxOutputTokens = maxOutputTokens;
}

}
package dev.langchain4j.googleaigemini.spring;

import dev.langchain4j.model.chat.request.ResponseFormat;

import java.time.Duration;


public record ChatModelProperties(
String modelName,
Double temperature,
Double topP,
Integer topK,
Integer maxOutputTokens,
ResponseFormat responseFormat,
Boolean logRequestsAndResponses,
Integer maxRetries,
Duration timeout,
GeminiSafetySetting safetySetting,
GeminiFunctionCallingConfig functionCallingConfig
) {
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,69 +5,10 @@

import java.time.Duration;

public class EmbeddingModelProperties {

private String titleMetadataKey;
private String modelName;
private Boolean logRequestsAndResponses;
private Integer maxRetries;
private Integer outputDimensionality;
private TaskType taskType;
private Duration timeout;

public String getTitleMetadataKey() {
return titleMetadataKey;
}

public void setTitleMetadataKey(String titleMetadataKey) {
this.titleMetadataKey = titleMetadataKey;
}

public String getModelName() {
return modelName;
}

public void setModelName(String modelName) {
this.modelName = modelName;
}

public Boolean isLogRequestsAndResponses() {
return logRequestsAndResponses;
}

public void setLogRequestsAndResponses(boolean logRequestsAndResponses) {
this.logRequestsAndResponses = logRequestsAndResponses;
}

public Integer getMaxRetries() {
return maxRetries;
}

public void setMaxRetries(Integer maxRetries) {
this.maxRetries = maxRetries;
}

public Integer getOutputDimensionality() {
return outputDimensionality;
}

public void setOutputDimensionality(Integer outputDimensionality) {
this.outputDimensionality = outputDimensionality;
}

public GoogleAiEmbeddingModel.TaskType getTaskType() {
return taskType;
}

public void setTaskType(GoogleAiEmbeddingModel.TaskType taskType) {
this.taskType = taskType;
}

public Duration getTimeout() {
return timeout;
}

public void setTimeout(Duration timeout) {
this.timeout = timeout;
}
}
public record EmbeddingModelProperties( String titleMetadataKey,
String modelName,
Boolean logRequestsAndResponses,
Integer maxRetries,
Integer outputDimensionality,
TaskType taskType,
Duration timeout){}
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,5 @@
import dev.langchain4j.model.googleai.GeminiHarmBlockThreshold;
import dev.langchain4j.model.googleai.GeminiHarmCategory;

public class GeminiSafetySetting {

private GeminiHarmCategory geminiHarmCategory;
private GeminiHarmBlockThreshold geminiHarmBlockThreshold;

public GeminiHarmCategory getGeminiHarmCategory() {
return geminiHarmCategory;
}

public void setGeminiHarmCategory(GeminiHarmCategory geminiHarmCategory) {
this.geminiHarmCategory = geminiHarmCategory;
}

public GeminiHarmBlockThreshold getGeminiHarmBlockThreshold() {
return geminiHarmBlockThreshold;
}

public void setGeminiHarmBlockThreshold(GeminiHarmBlockThreshold geminiHarmBlockThreshold) {
this.geminiHarmBlockThreshold = geminiHarmBlockThreshold;
}
}
public record GeminiSafetySetting(GeminiHarmCategory geminiHarmCategory,
GeminiHarmBlockThreshold geminiHarmBlockThreshold){}

0 comments on commit 3107cb9

Please sign in to comment.