Skip to content

apply builder pattern to OllamaApi #2634

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
import org.springframework.ai.ollama.api.OllamaApi.Message.ToolCallFunction;
import org.springframework.ai.ollama.api.OllamaModel;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.ai.ollama.api.common.OllamaApiConstants;
import org.springframework.ai.ollama.management.ModelManagementOptions;
import org.springframework.ai.ollama.management.OllamaModelManager;
import org.springframework.ai.ollama.management.PullModelStrategy;
Expand Down Expand Up @@ -207,7 +208,7 @@ private ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespon

ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
.prompt(prompt)
.provider(OllamaApi.PROVIDER_NAME)
.provider(OllamaApiConstants.PROVIDER_NAME)
.requestOptions(prompt.getOptions())
.build();

Expand Down Expand Up @@ -279,7 +280,7 @@ private Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCh

final ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
.prompt(prompt)
.provider(OllamaApi.PROVIDER_NAME)
.provider(OllamaApiConstants.PROVIDER_NAME)
.requestOptions(prompt.getOptions())
.build();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import org.springframework.ai.ollama.api.OllamaApi.EmbeddingsResponse;
import org.springframework.ai.ollama.api.OllamaModel;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.ai.ollama.api.common.OllamaApiConstants;
import org.springframework.ai.ollama.management.ModelManagementOptions;
import org.springframework.ai.ollama.management.OllamaModelManager;
import org.springframework.ai.ollama.management.PullModelStrategy;
Expand Down Expand Up @@ -112,7 +113,7 @@ public EmbeddingResponse call(EmbeddingRequest request) {

var observationContext = EmbeddingModelObservationContext.builder()
.embeddingRequest(request)
.provider(OllamaApi.PROVIDER_NAME)
.provider(OllamaApiConstants.PROVIDER_NAME)
.requestOptions(embeddingRequest.getOptions())
.build();

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023-2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -30,18 +30,17 @@
import com.fasterxml.jackson.annotation.JsonProperty;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.ai.ollama.api.common.OllamaApiConstants;
import org.springframework.ai.retry.RetryUtils;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.observation.conventions.AiProvider;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.http.client.ClientHttpResponse;
import org.springframework.util.Assert;
import org.springframework.util.StreamUtils;
import org.springframework.web.client.ResponseErrorHandler;
import org.springframework.web.client.RestClient;
import org.springframework.web.reactive.function.client.WebClient;
Expand All @@ -51,58 +50,74 @@
*
* @author Christian Tzolov
* @author Thomas Vitale
* @author Jonghoon Park
* @since 0.8.0
*/
// @formatter:off
public class OllamaApi {

public static final String PROVIDER_NAME = AiProvider.OLLAMA.value();
/**
 * Entry point for the fluent construction of an {@link OllamaApi}.
 * @return a fresh {@link Builder} instance
 */
public static Builder builder() {
	return new Builder();
}

public static final String REQUEST_BODY_NULL_ERROR = "The request body can not be null.";

private static final Log logger = LogFactory.getLog(OllamaApi.class);

private static final String DEFAULT_BASE_URL = "http://localhost:11434";

private final ResponseErrorHandler responseErrorHandler;

private final RestClient restClient;

private final WebClient webClient;

/**
* Default constructor that uses the default localhost url.
*/
@Deprecated(since = "1.0.0.M6")
public OllamaApi() {
this(DEFAULT_BASE_URL);
this(OllamaApiConstants.DEFAULT_BASE_URL);
}

/**
* Create a new OllamaApi instance with the given base url.
* @param baseUrl The base url of the Ollama server.
*/
@Deprecated(since = "1.0.0.M6")
public OllamaApi(String baseUrl) {
this(baseUrl, RestClient.builder(), WebClient.builder());
this(baseUrl, RestClient.builder(), WebClient.builder(), RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER);
}

/**
* Create a new OllamaApi instance with the given base url,
* {@link RestClient.Builder} and {@link WebClient.Builder}.
* @param baseUrl The base url of the Ollama server.
* @param restClientBuilder The {@link RestClient.Builder} to use.
* @param webClientBuilder The {@link WebClient.Builder} to use.
*/
@Deprecated(since = "1.0.0.M6")
public OllamaApi(String baseUrl, RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder) {
this(baseUrl, restClientBuilder, webClientBuilder, RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER);
}

this.responseErrorHandler = new OllamaResponseErrorHandler();
/**
* Create a new OllamaApi instance
* @param baseUrl The base url of the Ollama server.
* @param restClientBuilder The {@link RestClient.Builder} to use.
* @param webClientBuilder The {@link WebClient.Builder} to use.
* @param responseErrorHandler Response error handler.
*/
public OllamaApi(String baseUrl, RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder, ResponseErrorHandler responseErrorHandler) {

Consumer<HttpHeaders> defaultHeaders = headers -> {
headers.setContentType(MediaType.APPLICATION_JSON);
headers.setAccept(List.of(MediaType.APPLICATION_JSON));
};

this.restClient = restClientBuilder.baseUrl(baseUrl).defaultHeaders(defaultHeaders).build();
this.restClient = restClientBuilder.baseUrl(baseUrl)
.defaultHeaders(defaultHeaders)
.defaultStatusHandler(responseErrorHandler)
.build();

this.webClient = webClientBuilder.baseUrl(baseUrl).defaultHeaders(defaultHeaders).build();
this.webClient = webClientBuilder
.baseUrl(baseUrl)
.defaultHeaders(defaultHeaders)
.build();
}

/**
Expand All @@ -121,7 +136,6 @@ public ChatResponse chat(ChatRequest chatRequest) {
.uri("/api/chat")
.body(chatRequest)
.retrieve()
.onStatus(this.responseErrorHandler)
.body(ChatResponse.class);
}

Expand Down Expand Up @@ -188,7 +202,6 @@ public EmbeddingsResponse embed(EmbeddingsRequest embeddingsRequest) {
.uri("/api/embed")
.body(embeddingsRequest)
.retrieve()
.onStatus(this.responseErrorHandler)
.body(EmbeddingsResponse.class);
}

Expand All @@ -199,7 +212,6 @@ public ListModelResponse listModels() {
return this.restClient.get()
.uri("/api/tags")
.retrieve()
.onStatus(this.responseErrorHandler)
.body(ListModelResponse.class);
}

Expand All @@ -212,7 +224,6 @@ public ShowModelResponse showModel(ShowModelRequest showModelRequest) {
.uri("/api/show")
.body(showModelRequest)
.retrieve()
.onStatus(this.responseErrorHandler)
.body(ShowModelResponse.class);
}

Expand All @@ -225,7 +236,6 @@ public ResponseEntity<Void> copyModel(CopyModelRequest copyModelRequest) {
.uri("/api/copy")
.body(copyModelRequest)
.retrieve()
.onStatus(this.responseErrorHandler)
.toBodilessEntity();
}

Expand All @@ -238,7 +248,6 @@ public ResponseEntity<Void> deleteModel(DeleteModelRequest deleteModelRequest) {
.uri("/api/delete")
.body(deleteModelRequest)
.retrieve()
.onStatus(this.responseErrorHandler)
.toBodilessEntity();
}

Expand All @@ -261,26 +270,6 @@ public Flux<ProgressResponse> pullModel(PullModelRequest pullModelRequest) {
.bodyToFlux(ProgressResponse.class);
}

/**
 * {@link ResponseErrorHandler} that logs an error response at WARN level and
 * rethrows it as a {@link RuntimeException} whose message carries the status
 * code, the status text and the response body.
 */
private static class OllamaResponseErrorHandler implements ResponseErrorHandler {

	/**
	 * Reports whether the response status code is an error status.
	 */
	@Override
	public boolean hasError(ClientHttpResponse response) throws IOException {
		return response.getStatusCode().isError();
	}

	/**
	 * Logs and throws for an error status; returns silently otherwise.
	 */
	@Override
	public void handleError(ClientHttpResponse response) throws IOException {
		if (!response.getStatusCode().isError()) {
			return;
		}
		int code = response.getStatusCode().value();
		String reason = response.getStatusText();
		// The body is read as UTF-8 text so it can be included in the message.
		String payload = StreamUtils.copyToString(response.getBody(), java.nio.charset.StandardCharsets.UTF_8);
		String description = String.format("[%s] %s - %s", code, reason, payload);
		logger.warn(description);
		throw new RuntimeException(description);
	}

}

/**
* Chat message object.
*
Expand Down Expand Up @@ -736,5 +725,44 @@ public record ProgressResponse(
@JsonProperty("completed") Long completed
) { }

/**
 * Fluent builder for {@link OllamaApi}.
 * <p>
 * Every setting has a default, so {@code OllamaApi.builder().build()} is valid:
 * the base url defaults to {@link OllamaApiConstants#DEFAULT_BASE_URL}, the client
 * builders to {@code RestClient.builder()} and {@code WebClient.builder()}, and the
 * error handler to {@link RetryUtils#DEFAULT_RESPONSE_ERROR_HANDLER}.
 */
public static class Builder {

	private String baseUrl;

	private RestClient.Builder restClientBuilder;

	private WebClient.Builder webClientBuilder;

	private ResponseErrorHandler responseErrorHandler;

	/**
	 * Creates a builder pre-populated with the default settings.
	 */
	public Builder() {
		this.baseUrl = OllamaApiConstants.DEFAULT_BASE_URL;
		this.restClientBuilder = RestClient.builder();
		this.webClientBuilder = WebClient.builder();
		this.responseErrorHandler = RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER;
	}

	/**
	 * Sets the base url of the Ollama server.
	 * @param baseUrl the base url; checked with {@code Assert.hasText}
	 * @return this builder
	 */
	public Builder baseUrl(String baseUrl) {
		Assert.hasText(baseUrl, "baseUrl cannot be null or empty");
		this.baseUrl = baseUrl;
		return this;
	}

	/**
	 * Sets the {@link RestClient.Builder} handed to the {@link OllamaApi} constructor.
	 * @param restClientBuilder the builder to use; checked with {@code Assert.notNull}
	 * @return this builder
	 */
	public Builder restClientBuilder(RestClient.Builder restClientBuilder) {
		Assert.notNull(restClientBuilder, "restClientBuilder cannot be null");
		this.restClientBuilder = restClientBuilder;
		return this;
	}

	/**
	 * Sets the {@link WebClient.Builder} handed to the {@link OllamaApi} constructor.
	 * @param webClientBuilder the builder to use; checked with {@code Assert.notNull}
	 * @return this builder
	 */
	public Builder webClientBuilder(WebClient.Builder webClientBuilder) {
		Assert.notNull(webClientBuilder, "webClientBuilder cannot be null");
		this.webClientBuilder = webClientBuilder;
		return this;
	}

	/**
	 * Sets the {@link ResponseErrorHandler} handed to the {@link OllamaApi} constructor.
	 * @param responseErrorHandler the handler to use; checked with {@code Assert.notNull}
	 * @return this builder
	 */
	public Builder responseErrorHandler(ResponseErrorHandler responseErrorHandler) {
		Assert.notNull(responseErrorHandler, "responseErrorHandler cannot be null");
		this.responseErrorHandler = responseErrorHandler;
		return this;
	}

	/**
	 * Creates the {@link OllamaApi} from the values currently held by this builder.
	 * @return a new {@link OllamaApi}
	 */
	public OllamaApi build() {
		return new OllamaApi(this.baseUrl, this.restClientBuilder, this.webClientBuilder, this.responseErrorHandler);
	}

}
}
// @formatter:on
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/*
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.ollama.api.common;

import org.springframework.ai.observation.conventions.AiProvider;

/**
 * Holder for constant values shared across the Ollama API classes.
 * <p>
 * The class is final and has a private constructor, so it can be neither
 * subclassed nor instantiated from outside.
 *
 * @author Jonghoon Park
 */
public final class OllamaApiConstants {

	/** Provider name, taken from {@link AiProvider#OLLAMA}. */
	public static final String PROVIDER_NAME = AiProvider.OLLAMA.value();

	/** Default Ollama base url: localhost on port 11434. */
	public static final String DEFAULT_BASE_URL = "http://localhost:11434";

	private OllamaApiConstants() {
		// Constants holder only.
	}

}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023-2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -86,7 +86,7 @@ public static void tearDown() {

private static OllamaApi buildOllamaApiWithModel(final String model) {
final String baseUrl = SKIP_CONTAINER_CREATION ? OLLAMA_LOCAL_URL : ollamaContainer.getEndpoint();
final OllamaApi api = new OllamaApi(baseUrl);
final OllamaApi api = OllamaApi.builder().baseUrl(baseUrl).build();
ensureModelIsPresent(api, model);
return api;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
class OllamaChatRequestTests {

OllamaChatModel chatModel = OllamaChatModel.builder()
.ollamaApi(new OllamaApi())
.ollamaApi(OllamaApi.builder().build())
.defaultOptions(OllamaOptions.builder().model("MODEL_NAME").topK(99).temperature(66.6).numGPU(1).build())
.build();

Expand All @@ -52,7 +52,7 @@ void whenToolRuntimeOptionsThenMergeWithDefaults() {
.toolContext(Map.of("key1", "value1", "key2", "valueA"))
.build();
OllamaChatModel chatModel = OllamaChatModel.builder()
.ollamaApi(new OllamaApi())
.ollamaApi(OllamaApi.builder().build())
.defaultOptions(defaultOptions)
.build();

Expand Down Expand Up @@ -144,7 +144,7 @@ public void createRequestWithPromptOptionsModelOverride() {
@Test
public void createRequestWithDefaultOptionsModelOverride() {
OllamaChatModel chatModel = OllamaChatModel.builder()
.ollamaApi(new OllamaApi())
.ollamaApi(OllamaApi.builder().build())
.defaultOptions(OllamaOptions.builder().model("DEFAULT_OPTIONS_MODEL").build())
.build();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
public class OllamaEmbeddingRequestTests {

OllamaEmbeddingModel embeddingModel = OllamaEmbeddingModel.builder()
.ollamaApi(new OllamaApi())
.ollamaApi(OllamaApi.builder().build())
.defaultOptions(OllamaOptions.builder().model("DEFAULT_MODEL").mainGPU(11).useMMap(true).numGPU(1).build())
.build();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,7 @@ Next, create an `OllamaChatModel` instance and use it to send requests for text

[source,java]
----
var ollamaApi = new OllamaApi();
var ollamaApi = OllamaApi.builder().build();

var chatModel = OllamaChatModel.builder()
.ollamaApi(ollamaApi)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ Next, create an `OllamaEmbeddingModel` instance and use it to compute the embedd

[source,java]
----
var ollamaApi = new OllamaApi();
var ollamaApi = OllamaApi.builder().build();

var embeddingModel = new OllamaEmbeddingModel(this.ollamaApi,
OllamaOptions.builder()
Expand Down
Loading