From da54df5ddff779bc288d73c9b67d4a3a68756a08 Mon Sep 17 00:00:00 2001 From: alvinlee518 Date: Tue, 31 Dec 2024 17:39:01 +0800 Subject: [PATCH] add spring boot starter of xinference (#38) * add spring boot starter of xinference * add spring boot starter of xinference * switch testcontainers * Update spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java Co-authored-by: Martin7-1 * Update spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java Co-authored-by: Martin7-1 * Update spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java Co-authored-by: Martin7-1 * Update spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java Co-authored-by: Martin7-1 * Update spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java Co-authored-by: Martin7-1 * Update spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java Co-authored-by: Martin7-1 * Update spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java Co-authored-by: Martin7-1 * Update spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java Co-authored-by: Martin7-1 * Update spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java Co-authored-by: Martin7-1 * add spring boot starter of xinference --------- 
Co-authored-by: lixw <> Co-authored-by: Martin7-1 --- langchain4j-community-bom/pom.xml | 9 +- .../pom.xml | 99 ++++++++ .../spring/ChatModelProperties.java | 186 +++++++++++++++ .../spring/EmbeddingModelProperties.java | 105 +++++++++ .../spring/ImageModelProperties.java | 142 +++++++++++ .../spring/LanguageModelProperties.java | 178 ++++++++++++++ .../xinference/spring/ProxyProperties.java | 42 ++++ .../spring/ScoringModelProperties.java | 123 ++++++++++ .../spring/StreamingChatModelProperties.java | 178 ++++++++++++++ .../StreamingLanguageModelProperties.java | 169 ++++++++++++++ .../spring/XinferenceAutoConfiguration.java | 192 +++++++++++++++ ...ot.autoconfigure.AutoConfiguration.imports | 1 + .../xinference/spring/AutoConfigIT.java | 220 ++++++++++++++++++ .../spring/XinferenceContainer.java | 95 ++++++++ .../xinference/spring/XinferenceUtils.java | 69 ++++++ spring-boot-starters/pom.xml | 1 + 16 files changed, 1808 insertions(+), 1 deletion(-) create mode 100644 spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/pom.xml create mode 100644 spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/ChatModelProperties.java create mode 100644 spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/EmbeddingModelProperties.java create mode 100644 spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/ImageModelProperties.java create mode 100644 spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/LanguageModelProperties.java create mode 100644 spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/ProxyProperties.java create mode 100644 
spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/ScoringModelProperties.java create mode 100644 spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/StreamingChatModelProperties.java create mode 100644 spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/StreamingLanguageModelProperties.java create mode 100644 spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/XinferenceAutoConfiguration.java create mode 100644 spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports create mode 100644 spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/test/java/dev/langchain4j/community/xinference/spring/AutoConfigIT.java create mode 100644 spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/test/java/dev/langchain4j/community/xinference/spring/XinferenceContainer.java create mode 100644 spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/test/java/dev/langchain4j/community/xinference/spring/XinferenceUtils.java diff --git a/langchain4j-community-bom/pom.xml b/langchain4j-community-bom/pom.xml index f29aa2a..ba0b0ad 100644 --- a/langchain4j-community-bom/pom.xml +++ b/langchain4j-community-bom/pom.xml @@ -11,7 +11,8 @@ pom LangChain4j :: Community :: BOM - Bill of Materials POM for getting full, complete set of compatible versions of LangChain4j Community modules + Bill of Materials POM for getting full, complete set of compatible versions of LangChain4j Community + modules @@ -82,6 +83,12 @@ ${project.version} + + dev.langchain4j + 
langchain4j-community-xinference-spring-boot-starter + ${project.version} + + diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/pom.xml b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/pom.xml new file mode 100644 index 0000000..21e5b09 --- /dev/null +++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/pom.xml @@ -0,0 +1,99 @@ + + + 4.0.0 + + dev.langchain4j + langchain4j-community-spring-boot-starters + 1.0.0-alpha1 + ../pom.xml + + + langchain4j-community-xinference-spring-boot-starter + LangChain4j :: Community :: Spring Boot starter :: Xinference + + + + Apache-2.0 + https://www.apache.org/licenses/LICENSE-2.0.txt + repo + A business-friendly OSS license + + + + + + dev.langchain4j + langchain4j-community-xinference + ${project.version} + + + + org.springframework.boot + spring-boot-starter + + + ch.qos.logback + logback-classic + + + + + + ch.qos.logback + logback-classic + + + + org.springframework.boot + spring-boot-autoconfigure-processor + true + + + + + org.springframework.boot + spring-boot-configuration-processor + true + + + + org.springframework.boot + spring-boot-starter-test + test + + + + org.testcontainers + testcontainers + test + + + + org.testcontainers + junit-jupiter + test + + + + + + + org.honton.chas + license-maven-plugin + + + + + Eclipse Public License + http://www.eclipse.org/legal/epl-v10.html + + + GNU Lesser General Public License + http://www.gnu.org/licenses/old-licenses/lgpl-2.1.html + + + + + + + diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/ChatModelProperties.java b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/ChatModelProperties.java new file mode 100644 index 0000000..d86d41d --- /dev/null +++ 
b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/ChatModelProperties.java @@ -0,0 +1,186 @@ +package dev.langchain4j.community.xinference.spring; + +import java.time.Duration; +import java.util.List; +import java.util.Map; +import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.boot.context.properties.NestedConfigurationProperty; + +@ConfigurationProperties(prefix = ChatModelProperties.PREFIX) +public class ChatModelProperties { + static final String PREFIX = "langchain4j.community.xinference.chat-model"; + private String baseUrl; + private String apiKey; + private String modelName; + private Double temperature; + private Double topP; + private List stop; + private Integer maxTokens; + private Double presencePenalty; + private Double frequencyPenalty; + private Integer seed; + private String user; + private Object toolChoice; + private Boolean parallelToolCalls; + private Integer maxRetries; + private Duration timeout; + + @NestedConfigurationProperty + private ProxyProperties proxy; + + private Boolean logRequests; + private Boolean logResponses; + private Map customHeaders; + + public String getBaseUrl() { + return baseUrl; + } + + public void setBaseUrl(final String baseUrl) { + this.baseUrl = baseUrl; + } + + public String getApiKey() { + return apiKey; + } + + public void setApiKey(final String apiKey) { + this.apiKey = apiKey; + } + + public String getModelName() { + return modelName; + } + + public void setModelName(final String modelName) { + this.modelName = modelName; + } + + public Double getTemperature() { + return temperature; + } + + public void setTemperature(final Double temperature) { + this.temperature = temperature; + } + + public Double getTopP() { + return topP; + } + + public void setTopP(final Double topP) { + this.topP = topP; + } + + public List getStop() { + return stop; + } + + public void setStop(final 
List stop) { + this.stop = stop; + } + + public Integer getMaxTokens() { + return maxTokens; + } + + public void setMaxTokens(final Integer maxTokens) { + this.maxTokens = maxTokens; + } + + public Double getPresencePenalty() { + return presencePenalty; + } + + public void setPresencePenalty(final Double presencePenalty) { + this.presencePenalty = presencePenalty; + } + + public Double getFrequencyPenalty() { + return frequencyPenalty; + } + + public void setFrequencyPenalty(final Double frequencyPenalty) { + this.frequencyPenalty = frequencyPenalty; + } + + public Integer getSeed() { + return seed; + } + + public void setSeed(final Integer seed) { + this.seed = seed; + } + + public String getUser() { + return user; + } + + public void setUser(final String user) { + this.user = user; + } + + public Object getToolChoice() { + return toolChoice; + } + + public void setToolChoice(final Object toolChoice) { + this.toolChoice = toolChoice; + } + + public Boolean getParallelToolCalls() { + return parallelToolCalls; + } + + public void setParallelToolCalls(final Boolean parallelToolCalls) { + this.parallelToolCalls = parallelToolCalls; + } + + public Integer getMaxRetries() { + return maxRetries; + } + + public void setMaxRetries(final Integer maxRetries) { + this.maxRetries = maxRetries; + } + + public Duration getTimeout() { + return timeout; + } + + public void setTimeout(final Duration timeout) { + this.timeout = timeout; + } + + public ProxyProperties getProxy() { + return proxy; + } + + public void setProxy(final ProxyProperties proxy) { + this.proxy = proxy; + } + + public Boolean getLogRequests() { + return logRequests; + } + + public void setLogRequests(final Boolean logRequests) { + this.logRequests = logRequests; + } + + public Boolean getLogResponses() { + return logResponses; + } + + public void setLogResponses(final Boolean logResponses) { + this.logResponses = logResponses; + } + + public Map getCustomHeaders() { + return customHeaders; + } + + public void 
setCustomHeaders(final Map customHeaders) { + this.customHeaders = customHeaders; + } +} diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/EmbeddingModelProperties.java b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/EmbeddingModelProperties.java new file mode 100644 index 0000000..d628324 --- /dev/null +++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/EmbeddingModelProperties.java @@ -0,0 +1,105 @@ +package dev.langchain4j.community.xinference.spring; + +import java.time.Duration; +import java.util.Map; +import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.boot.context.properties.NestedConfigurationProperty; + +@ConfigurationProperties(prefix = EmbeddingModelProperties.PREFIX) +public class EmbeddingModelProperties { + static final String PREFIX = "langchain4j.community.xinference.embedding-model"; + + private String baseUrl; + private String apiKey; + private String modelName; + private String user; + private Integer maxRetries; + private Duration timeout; + + @NestedConfigurationProperty + private ProxyProperties proxy; + + private Boolean logRequests; + private Boolean logResponses; + private Map customHeaders; + + public String getBaseUrl() { + return baseUrl; + } + + public void setBaseUrl(final String baseUrl) { + this.baseUrl = baseUrl; + } + + public String getApiKey() { + return apiKey; + } + + public void setApiKey(final String apiKey) { + this.apiKey = apiKey; + } + + public String getModelName() { + return modelName; + } + + public void setModelName(final String modelName) { + this.modelName = modelName; + } + + public String getUser() { + return user; + } + + public void setUser(final String user) { + this.user = user; + } + + public 
Integer getMaxRetries() { + return maxRetries; + } + + public void setMaxRetries(final Integer maxRetries) { + this.maxRetries = maxRetries; + } + + public Duration getTimeout() { + return timeout; + } + + public void setTimeout(final Duration timeout) { + this.timeout = timeout; + } + + public ProxyProperties getProxy() { + return proxy; + } + + public void setProxy(final ProxyProperties proxy) { + this.proxy = proxy; + } + + public Boolean getLogRequests() { + return logRequests; + } + + public void setLogRequests(final Boolean logRequests) { + this.logRequests = logRequests; + } + + public Boolean getLogResponses() { + return logResponses; + } + + public void setLogResponses(final Boolean logResponses) { + this.logResponses = logResponses; + } + + public Map getCustomHeaders() { + return customHeaders; + } + + public void setCustomHeaders(final Map customHeaders) { + this.customHeaders = customHeaders; + } +} diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/ImageModelProperties.java b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/ImageModelProperties.java new file mode 100644 index 0000000..e59ecfb --- /dev/null +++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/ImageModelProperties.java @@ -0,0 +1,142 @@ +package dev.langchain4j.community.xinference.spring; + +import dev.langchain4j.community.model.xinference.client.image.ResponseFormat; +import java.time.Duration; +import java.util.Map; +import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.boot.context.properties.NestedConfigurationProperty; + +@ConfigurationProperties(prefix = ImageModelProperties.PREFIX) +public class ImageModelProperties { + static final String PREFIX = 
"langchain4j.community.xinference.image-model"; + + private String baseUrl; + private String apiKey; + private String modelName; + private String negativePrompt; + private ResponseFormat responseFormat; + private String size; + private String kwargs; + private String user; + private Integer maxRetries; + private Duration timeout; + + @NestedConfigurationProperty + private ProxyProperties proxy; + + private Boolean logRequests; + private Boolean logResponses; + private Map customHeaders; + + public String getBaseUrl() { + return baseUrl; + } + + public void setBaseUrl(final String baseUrl) { + this.baseUrl = baseUrl; + } + + public String getApiKey() { + return apiKey; + } + + public void setApiKey(final String apiKey) { + this.apiKey = apiKey; + } + + public String getModelName() { + return modelName; + } + + public void setModelName(final String modelName) { + this.modelName = modelName; + } + + public String getNegativePrompt() { + return negativePrompt; + } + + public void setNegativePrompt(final String negativePrompt) { + this.negativePrompt = negativePrompt; + } + + public ResponseFormat getResponseFormat() { + return responseFormat; + } + + public void setResponseFormat(final ResponseFormat responseFormat) { + this.responseFormat = responseFormat; + } + + public String getSize() { + return size; + } + + public void setSize(final String size) { + this.size = size; + } + + public String getKwargs() { + return kwargs; + } + + public void setKwargs(final String kwargs) { + this.kwargs = kwargs; + } + + public String getUser() { + return user; + } + + public void setUser(final String user) { + this.user = user; + } + + public Integer getMaxRetries() { + return maxRetries; + } + + public void setMaxRetries(final Integer maxRetries) { + this.maxRetries = maxRetries; + } + + public Duration getTimeout() { + return timeout; + } + + public void setTimeout(final Duration timeout) { + this.timeout = timeout; + } + + public ProxyProperties getProxy() { + return proxy; + } 
+ + public void setProxy(final ProxyProperties proxy) { + this.proxy = proxy; + } + + public Boolean getLogRequests() { + return logRequests; + } + + public void setLogRequests(final Boolean logRequests) { + this.logRequests = logRequests; + } + + public Boolean getLogResponses() { + return logResponses; + } + + public void setLogResponses(final Boolean logResponses) { + this.logResponses = logResponses; + } + + public Map getCustomHeaders() { + return customHeaders; + } + + public void setCustomHeaders(final Map customHeaders) { + this.customHeaders = customHeaders; + } +} diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/LanguageModelProperties.java b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/LanguageModelProperties.java new file mode 100644 index 0000000..43bfdd0 --- /dev/null +++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/LanguageModelProperties.java @@ -0,0 +1,178 @@ +package dev.langchain4j.community.xinference.spring; + +import java.time.Duration; +import java.util.List; +import java.util.Map; +import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.boot.context.properties.NestedConfigurationProperty; + +@ConfigurationProperties(prefix = LanguageModelProperties.PREFIX) +public class LanguageModelProperties { + static final String PREFIX = "langchain4j.community.xinference.language-model"; + + private String baseUrl; + private String apiKey; + private String modelName; + private Integer maxTokens; + private Double temperature; + private Double topP; + private Integer logprobs; + private Boolean echo; + private List stop; + private Double presencePenalty; + private Double frequencyPenalty; + private String user; + private Integer maxRetries; 
+ private Duration timeout; + + @NestedConfigurationProperty + private ProxyProperties proxy; + + private Boolean logRequests; + private Boolean logResponses; + private Map customHeaders; + + public String getBaseUrl() { + return baseUrl; + } + + public void setBaseUrl(final String baseUrl) { + this.baseUrl = baseUrl; + } + + public String getApiKey() { + return apiKey; + } + + public void setApiKey(final String apiKey) { + this.apiKey = apiKey; + } + + public String getModelName() { + return modelName; + } + + public void setModelName(final String modelName) { + this.modelName = modelName; + } + + public Integer getMaxTokens() { + return maxTokens; + } + + public void setMaxTokens(final Integer maxTokens) { + this.maxTokens = maxTokens; + } + + public Double getTemperature() { + return temperature; + } + + public void setTemperature(final Double temperature) { + this.temperature = temperature; + } + + public Double getTopP() { + return topP; + } + + public void setTopP(final Double topP) { + this.topP = topP; + } + + public Integer getLogprobs() { + return logprobs; + } + + public void setLogprobs(final Integer logprobs) { + this.logprobs = logprobs; + } + + public Boolean getEcho() { + return echo; + } + + public void setEcho(final Boolean echo) { + this.echo = echo; + } + + public List getStop() { + return stop; + } + + public void setStop(final List stop) { + this.stop = stop; + } + + public Double getPresencePenalty() { + return presencePenalty; + } + + public void setPresencePenalty(final Double presencePenalty) { + this.presencePenalty = presencePenalty; + } + + public Double getFrequencyPenalty() { + return frequencyPenalty; + } + + public void setFrequencyPenalty(final Double frequencyPenalty) { + this.frequencyPenalty = frequencyPenalty; + } + + public String getUser() { + return user; + } + + public void setUser(final String user) { + this.user = user; + } + + public Integer getMaxRetries() { + return maxRetries; + } + + public void setMaxRetries(final 
Integer maxRetries) { + this.maxRetries = maxRetries; + } + + public Duration getTimeout() { + return timeout; + } + + public void setTimeout(final Duration timeout) { + this.timeout = timeout; + } + + public ProxyProperties getProxy() { + return proxy; + } + + public void setProxy(final ProxyProperties proxy) { + this.proxy = proxy; + } + + public Boolean getLogRequests() { + return logRequests; + } + + public void setLogRequests(final Boolean logRequests) { + this.logRequests = logRequests; + } + + public Boolean getLogResponses() { + return logResponses; + } + + public void setLogResponses(final Boolean logResponses) { + this.logResponses = logResponses; + } + + public Map getCustomHeaders() { + return customHeaders; + } + + public void setCustomHeaders(final Map customHeaders) { + this.customHeaders = customHeaders; + } +} diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/ProxyProperties.java b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/ProxyProperties.java new file mode 100644 index 0000000..f110310 --- /dev/null +++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/ProxyProperties.java @@ -0,0 +1,42 @@ +package dev.langchain4j.community.xinference.spring; + +import java.net.InetSocketAddress; +import java.net.Proxy; +import java.util.Objects; + +public class ProxyProperties { + private Proxy.Type type; + private String host; + private Integer port; + + public Proxy.Type getType() { + return type; + } + + public void setType(final Proxy.Type type) { + this.type = type; + } + + public String getHost() { + return host; + } + + public void setHost(final String host) { + this.host = host; + } + + public Integer getPort() { + return port; + } + + public void setPort(final Integer port) { + this.port = 
port; + } + + public static Proxy convert(ProxyProperties properties) { + if (Objects.isNull(properties)) { + return null; + } + return new Proxy(properties.getType(), new InetSocketAddress(properties.getHost(), properties.getPort())); + } +} diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/ScoringModelProperties.java b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/ScoringModelProperties.java new file mode 100644 index 0000000..a53a280 --- /dev/null +++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/ScoringModelProperties.java @@ -0,0 +1,123 @@ +package dev.langchain4j.community.xinference.spring; + +import java.time.Duration; +import java.util.Map; +import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.boot.context.properties.NestedConfigurationProperty; + +@ConfigurationProperties(prefix = ScoringModelProperties.PREFIX) +public class ScoringModelProperties { + static final String PREFIX = "langchain4j.community.xinference.scoring-model"; + + private String baseUrl; + private String apiKey; + private String modelName; + private Integer topN; + private Boolean returnDocuments; + private Boolean returnLen; + private Integer maxRetries; + private Duration timeout; + + @NestedConfigurationProperty + private ProxyProperties proxy; + + private Boolean logRequests; + private Boolean logResponses; + private Map customHeaders; + + public String getBaseUrl() { + return baseUrl; + } + + public void setBaseUrl(final String baseUrl) { + this.baseUrl = baseUrl; + } + + public String getApiKey() { + return apiKey; + } + + public void setApiKey(final String apiKey) { + this.apiKey = apiKey; + } + + public String getModelName() { + return modelName; + } + + public 
void setModelName(final String modelName) { + this.modelName = modelName; + } + + public Integer getTopN() { + return topN; + } + + public void setTopN(final Integer topN) { + this.topN = topN; + } + + public Boolean getReturnDocuments() { + return returnDocuments; + } + + public void setReturnDocuments(final Boolean returnDocuments) { + this.returnDocuments = returnDocuments; + } + + public Boolean getReturnLen() { + return returnLen; + } + + public void setReturnLen(final Boolean returnLen) { + this.returnLen = returnLen; + } + + public Integer getMaxRetries() { + return maxRetries; + } + + public void setMaxRetries(final Integer maxRetries) { + this.maxRetries = maxRetries; + } + + public Duration getTimeout() { + return timeout; + } + + public void setTimeout(final Duration timeout) { + this.timeout = timeout; + } + + public ProxyProperties getProxy() { + return proxy; + } + + public void setProxy(final ProxyProperties proxy) { + this.proxy = proxy; + } + + public Boolean getLogRequests() { + return logRequests; + } + + public void setLogRequests(final Boolean logRequests) { + this.logRequests = logRequests; + } + + public Boolean getLogResponses() { + return logResponses; + } + + public void setLogResponses(final Boolean logResponses) { + this.logResponses = logResponses; + } + + public Map getCustomHeaders() { + return customHeaders; + } + + public void setCustomHeaders(final Map customHeaders) { + this.customHeaders = customHeaders; + } +} diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/StreamingChatModelProperties.java b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/StreamingChatModelProperties.java new file mode 100644 index 0000000..6f1121f --- /dev/null +++ 
b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/StreamingChatModelProperties.java @@ -0,0 +1,178 @@ +package dev.langchain4j.community.xinference.spring; + +import java.time.Duration; +import java.util.List; +import java.util.Map; +import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.boot.context.properties.NestedConfigurationProperty; + +@ConfigurationProperties(prefix = StreamingChatModelProperties.PREFIX) +public class StreamingChatModelProperties { + static final String PREFIX = "langchain4j.community.xinference.streaming-chat-model"; + + private String baseUrl; + private String apiKey; + private String modelName; + private Double temperature; + private Double topP; + private List stop; + private Integer maxTokens; + private Double presencePenalty; + private Double frequencyPenalty; + private Integer seed; + private String user; + private Object toolChoice; + private Boolean parallelToolCalls; + private Duration timeout; + + @NestedConfigurationProperty + private ProxyProperties proxy; + + private Boolean logRequests; + private Boolean logResponses; + private Map customHeaders; + + public String getBaseUrl() { + return baseUrl; + } + + public void setBaseUrl(final String baseUrl) { + this.baseUrl = baseUrl; + } + + public String getApiKey() { + return apiKey; + } + + public void setApiKey(final String apiKey) { + this.apiKey = apiKey; + } + + public String getModelName() { + return modelName; + } + + public void setModelName(final String modelName) { + this.modelName = modelName; + } + + public Double getTemperature() { + return temperature; + } + + public void setTemperature(final Double temperature) { + this.temperature = temperature; + } + + public Double getTopP() { + return topP; + } + + public void setTopP(final Double topP) { + this.topP = topP; + } + + public List getStop() { + return stop; + } + + public void 
setStop(final List stop) { + this.stop = stop; + } + + public Integer getMaxTokens() { + return maxTokens; + } + + public void setMaxTokens(final Integer maxTokens) { + this.maxTokens = maxTokens; + } + + public Double getPresencePenalty() { + return presencePenalty; + } + + public void setPresencePenalty(final Double presencePenalty) { + this.presencePenalty = presencePenalty; + } + + public Double getFrequencyPenalty() { + return frequencyPenalty; + } + + public void setFrequencyPenalty(final Double frequencyPenalty) { + this.frequencyPenalty = frequencyPenalty; + } + + public Integer getSeed() { + return seed; + } + + public void setSeed(final Integer seed) { + this.seed = seed; + } + + public String getUser() { + return user; + } + + public void setUser(final String user) { + this.user = user; + } + + public Object getToolChoice() { + return toolChoice; + } + + public void setToolChoice(final Object toolChoice) { + this.toolChoice = toolChoice; + } + + public Boolean getParallelToolCalls() { + return parallelToolCalls; + } + + public void setParallelToolCalls(final Boolean parallelToolCalls) { + this.parallelToolCalls = parallelToolCalls; + } + + public Duration getTimeout() { + return timeout; + } + + public void setTimeout(final Duration timeout) { + this.timeout = timeout; + } + + public ProxyProperties getProxy() { + return proxy; + } + + public void setProxy(final ProxyProperties proxy) { + this.proxy = proxy; + } + + public Boolean getLogRequests() { + return logRequests; + } + + public void setLogRequests(final Boolean logRequests) { + this.logRequests = logRequests; + } + + public Boolean getLogResponses() { + return logResponses; + } + + public void setLogResponses(final Boolean logResponses) { + this.logResponses = logResponses; + } + + public Map getCustomHeaders() { + return customHeaders; + } + + public void setCustomHeaders(final Map customHeaders) { + this.customHeaders = customHeaders; + } +} diff --git 
a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/StreamingLanguageModelProperties.java b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/StreamingLanguageModelProperties.java new file mode 100644 index 0000000..2937006 --- /dev/null +++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/StreamingLanguageModelProperties.java @@ -0,0 +1,169 @@ +package dev.langchain4j.community.xinference.spring; + +import java.time.Duration; +import java.util.List; +import java.util.Map; +import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.boot.context.properties.NestedConfigurationProperty; + +@ConfigurationProperties(prefix = StreamingLanguageModelProperties.PREFIX) +public class StreamingLanguageModelProperties { + static final String PREFIX = "langchain4j.community.xinference.streaming-language-model"; + + private String baseUrl; + private String apiKey; + private String modelName; + private Integer maxTokens; + private Double temperature; + private Double topP; + private Integer logprobs; + private Boolean echo; + private List stop; + private Double presencePenalty; + private Double frequencyPenalty; + private String user; + private Duration timeout; + + @NestedConfigurationProperty + private ProxyProperties proxy; + + private Boolean logRequests; + private Boolean logResponses; + private Map customHeaders; + + public String getBaseUrl() { + return baseUrl; + } + + public void setBaseUrl(final String baseUrl) { + this.baseUrl = baseUrl; + } + + public String getApiKey() { + return apiKey; + } + + public void setApiKey(final String apiKey) { + this.apiKey = apiKey; + } + + public String getModelName() { + return modelName; + } + + public void setModelName(final String modelName) { 
+ this.modelName = modelName; + } + + public Integer getMaxTokens() { + return maxTokens; + } + + public void setMaxTokens(final Integer maxTokens) { + this.maxTokens = maxTokens; + } + + public Double getTemperature() { + return temperature; + } + + public void setTemperature(final Double temperature) { + this.temperature = temperature; + } + + public Double getTopP() { + return topP; + } + + public void setTopP(final Double topP) { + this.topP = topP; + } + + public Integer getLogprobs() { + return logprobs; + } + + public void setLogprobs(final Integer logprobs) { + this.logprobs = logprobs; + } + + public Boolean getEcho() { + return echo; + } + + public void setEcho(final Boolean echo) { + this.echo = echo; + } + + public List getStop() { + return stop; + } + + public void setStop(final List stop) { + this.stop = stop; + } + + public Double getPresencePenalty() { + return presencePenalty; + } + + public void setPresencePenalty(final Double presencePenalty) { + this.presencePenalty = presencePenalty; + } + + public Double getFrequencyPenalty() { + return frequencyPenalty; + } + + public void setFrequencyPenalty(final Double frequencyPenalty) { + this.frequencyPenalty = frequencyPenalty; + } + + public String getUser() { + return user; + } + + public void setUser(final String user) { + this.user = user; + } + + public Duration getTimeout() { + return timeout; + } + + public void setTimeout(final Duration timeout) { + this.timeout = timeout; + } + + public ProxyProperties getProxy() { + return proxy; + } + + public void setProxy(final ProxyProperties proxy) { + this.proxy = proxy; + } + + public Boolean getLogRequests() { + return logRequests; + } + + public void setLogRequests(final Boolean logRequests) { + this.logRequests = logRequests; + } + + public Boolean getLogResponses() { + return logResponses; + } + + public void setLogResponses(final Boolean logResponses) { + this.logResponses = logResponses; + } + + public Map getCustomHeaders() { + return 
customHeaders; + } + + public void setCustomHeaders(final Map customHeaders) { + this.customHeaders = customHeaders; + } +} diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/XinferenceAutoConfiguration.java b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/XinferenceAutoConfiguration.java new file mode 100644 index 0000000..3f59135 --- /dev/null +++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/XinferenceAutoConfiguration.java @@ -0,0 +1,192 @@ +package dev.langchain4j.community.xinference.spring; + +import dev.langchain4j.community.model.xinference.XinferenceChatModel; +import dev.langchain4j.community.model.xinference.XinferenceEmbeddingModel; +import dev.langchain4j.community.model.xinference.XinferenceImageModel; +import dev.langchain4j.community.model.xinference.XinferenceLanguageModel; +import dev.langchain4j.community.model.xinference.XinferenceScoringModel; +import dev.langchain4j.community.model.xinference.XinferenceStreamingChatModel; +import dev.langchain4j.community.model.xinference.XinferenceStreamingLanguageModel; +import org.springframework.boot.autoconfigure.AutoConfiguration; +import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; +import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; +import org.springframework.boot.context.properties.EnableConfigurationProperties; +import org.springframework.context.annotation.Bean; + +@AutoConfiguration +@EnableConfigurationProperties({ + ChatModelProperties.class, + StreamingChatModelProperties.class, + LanguageModelProperties.class, + StreamingLanguageModelProperties.class, + EmbeddingModelProperties.class, + ImageModelProperties.class, + ScoringModelProperties.class +}) +public class 
XinferenceAutoConfiguration { + @Bean + @ConditionalOnMissingBean + @ConditionalOnProperty(ChatModelProperties.PREFIX + ".base-url") + public XinferenceChatModel xinferenceChatModel(ChatModelProperties chatModelProperties) { + return XinferenceChatModel.builder() + .baseUrl(chatModelProperties.getBaseUrl()) + .apiKey(chatModelProperties.getApiKey()) + .modelName(chatModelProperties.getModelName()) + .temperature(chatModelProperties.getTemperature()) + .topP(chatModelProperties.getTopP()) + .stop(chatModelProperties.getStop()) + .maxTokens(chatModelProperties.getMaxTokens()) + .presencePenalty(chatModelProperties.getPresencePenalty()) + .frequencyPenalty(chatModelProperties.getFrequencyPenalty()) + .seed(chatModelProperties.getSeed()) + .user(chatModelProperties.getUser()) + .toolChoice(chatModelProperties.getToolChoice()) + .parallelToolCalls(chatModelProperties.getParallelToolCalls()) + .maxRetries(chatModelProperties.getMaxRetries()) + .timeout(chatModelProperties.getTimeout()) + .proxy(ProxyProperties.convert(chatModelProperties.getProxy())) + .logRequests(chatModelProperties.getLogRequests()) + .logResponses(chatModelProperties.getLogResponses()) + .customHeaders(chatModelProperties.getCustomHeaders()) + .build(); + } + + @Bean + @ConditionalOnMissingBean + @ConditionalOnProperty(StreamingChatModelProperties.PREFIX + ".base-url") + public XinferenceStreamingChatModel xinferenceStreamingChatModel( + StreamingChatModelProperties streamingChatModelProperties) { + return XinferenceStreamingChatModel.builder() + .baseUrl(streamingChatModelProperties.getBaseUrl()) + .apiKey(streamingChatModelProperties.getApiKey()) + .modelName(streamingChatModelProperties.getModelName()) + .temperature(streamingChatModelProperties.getTemperature()) + .topP(streamingChatModelProperties.getTopP()) + .stop(streamingChatModelProperties.getStop()) + .maxTokens(streamingChatModelProperties.getMaxTokens()) + .presencePenalty(streamingChatModelProperties.getPresencePenalty()) + 
.frequencyPenalty(streamingChatModelProperties.getFrequencyPenalty()) + .seed(streamingChatModelProperties.getSeed()) + .user(streamingChatModelProperties.getUser()) + .toolChoice(streamingChatModelProperties.getToolChoice()) + .parallelToolCalls(streamingChatModelProperties.getParallelToolCalls()) + .timeout(streamingChatModelProperties.getTimeout()) + .proxy(ProxyProperties.convert(streamingChatModelProperties.getProxy())) + .logRequests(streamingChatModelProperties.getLogRequests()) + .logResponses(streamingChatModelProperties.getLogResponses()) + .customHeaders(streamingChatModelProperties.getCustomHeaders()) + .build(); + } + + @Bean + @ConditionalOnMissingBean + @ConditionalOnProperty(LanguageModelProperties.PREFIX + ".base-url") + public XinferenceLanguageModel xinferenceLanguageModel(LanguageModelProperties languageModelProperties) { + return XinferenceLanguageModel.builder() + .baseUrl(languageModelProperties.getBaseUrl()) + .apiKey(languageModelProperties.getApiKey()) + .modelName(languageModelProperties.getModelName()) + .maxTokens(languageModelProperties.getMaxTokens()) + .temperature(languageModelProperties.getTemperature()) + .topP(languageModelProperties.getTopP()) + .logprobs(languageModelProperties.getLogprobs()) + .echo(languageModelProperties.getEcho()) + .stop(languageModelProperties.getStop()) + .presencePenalty(languageModelProperties.getPresencePenalty()) + .frequencyPenalty(languageModelProperties.getFrequencyPenalty()) + .user(languageModelProperties.getUser()) + .maxRetries(languageModelProperties.getMaxRetries()) + .timeout(languageModelProperties.getTimeout()) + .proxy(ProxyProperties.convert(languageModelProperties.getProxy())) + .logRequests(languageModelProperties.getLogRequests()) + .logResponses(languageModelProperties.getLogResponses()) + .customHeaders(languageModelProperties.getCustomHeaders()) + .build(); + } + + @Bean + @ConditionalOnMissingBean + @ConditionalOnProperty(StreamingLanguageModelProperties.PREFIX + ".base-url") + 
public XinferenceStreamingLanguageModel xinferenceStreamingLanguageModel( + StreamingLanguageModelProperties streamingLanguageModelProperties) { + return XinferenceStreamingLanguageModel.builder() + .baseUrl(streamingLanguageModelProperties.getBaseUrl()) + .apiKey(streamingLanguageModelProperties.getApiKey()) + .modelName(streamingLanguageModelProperties.getModelName()) + .maxTokens(streamingLanguageModelProperties.getMaxTokens()) + .temperature(streamingLanguageModelProperties.getTemperature()) + .topP(streamingLanguageModelProperties.getTopP()) + .logprobs(streamingLanguageModelProperties.getLogprobs()) + .echo(streamingLanguageModelProperties.getEcho()) + .stop(streamingLanguageModelProperties.getStop()) + .presencePenalty(streamingLanguageModelProperties.getPresencePenalty()) + .frequencyPenalty(streamingLanguageModelProperties.getFrequencyPenalty()) + .user(streamingLanguageModelProperties.getUser()) + .timeout(streamingLanguageModelProperties.getTimeout()) + .proxy(ProxyProperties.convert(streamingLanguageModelProperties.getProxy())) + .logRequests(streamingLanguageModelProperties.getLogRequests()) + .logResponses(streamingLanguageModelProperties.getLogResponses()) + .customHeaders(streamingLanguageModelProperties.getCustomHeaders()) + .build(); + } + + @Bean + @ConditionalOnMissingBean + @ConditionalOnProperty(EmbeddingModelProperties.PREFIX + ".base-url") + public XinferenceEmbeddingModel xinferenceEmbeddingModel(EmbeddingModelProperties embeddingModelProperties) { + return XinferenceEmbeddingModel.builder() + .baseUrl(embeddingModelProperties.getBaseUrl()) + .apiKey(embeddingModelProperties.getApiKey()) + .modelName(embeddingModelProperties.getModelName()) + .user(embeddingModelProperties.getUser()) + .maxRetries(embeddingModelProperties.getMaxRetries()) + .timeout(embeddingModelProperties.getTimeout()) + .proxy(ProxyProperties.convert(embeddingModelProperties.getProxy())) + .logRequests(embeddingModelProperties.getLogRequests()) + 
.logResponses(embeddingModelProperties.getLogResponses()) + .customHeaders(embeddingModelProperties.getCustomHeaders()) + .build(); + } + + @Bean + @ConditionalOnMissingBean + @ConditionalOnProperty(ImageModelProperties.PREFIX + ".base-url") + public XinferenceImageModel xinferenceImageModel(ImageModelProperties imageModelProperties) { + return XinferenceImageModel.builder() + .baseUrl(imageModelProperties.getBaseUrl()) + .apiKey(imageModelProperties.getApiKey()) + .modelName(imageModelProperties.getModelName()) + .negativePrompt(imageModelProperties.getNegativePrompt()) + .responseFormat(imageModelProperties.getResponseFormat()) + .size(imageModelProperties.getSize()) + .kwargs(imageModelProperties.getKwargs()) + .user(imageModelProperties.getUser()) + .maxRetries(imageModelProperties.getMaxRetries()) + .timeout(imageModelProperties.getTimeout()) + .proxy(ProxyProperties.convert(imageModelProperties.getProxy())) + .logRequests(imageModelProperties.getLogRequests()) + .logResponses(imageModelProperties.getLogResponses()) + .customHeaders(imageModelProperties.getCustomHeaders()) + .build(); + } + + @Bean + @ConditionalOnMissingBean + @ConditionalOnProperty(ScoringModelProperties.PREFIX + ".base-url") + public XinferenceScoringModel xinferenceScoringModel(ScoringModelProperties scoringModelProperties) { + return XinferenceScoringModel.builder() + .baseUrl(scoringModelProperties.getBaseUrl()) + .apiKey(scoringModelProperties.getApiKey()) + .modelName(scoringModelProperties.getModelName()) + .topN(scoringModelProperties.getTopN()) + .returnDocuments(scoringModelProperties.getReturnDocuments()) + .returnLen(scoringModelProperties.getReturnLen()) + .maxRetries(scoringModelProperties.getMaxRetries()) + .timeout(scoringModelProperties.getTimeout()) + .proxy(ProxyProperties.convert(scoringModelProperties.getProxy())) + .logRequests(scoringModelProperties.getLogRequests()) + .logResponses(scoringModelProperties.getLogResponses()) + 
.customHeaders(scoringModelProperties.getCustomHeaders()) + .build(); + } +} diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports new file mode 100644 index 0000000..4de934a --- /dev/null +++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports @@ -0,0 +1 @@ +dev.langchain4j.community.xinference.spring.XinferenceAutoConfiguration diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/test/java/dev/langchain4j/community/xinference/spring/AutoConfigIT.java b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/test/java/dev/langchain4j/community/xinference/spring/AutoConfigIT.java new file mode 100644 index 0000000..5198026 --- /dev/null +++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/test/java/dev/langchain4j/community/xinference/spring/AutoConfigIT.java @@ -0,0 +1,220 @@ +package dev.langchain4j.community.xinference.spring; + +import static dev.langchain4j.community.xinference.spring.XinferenceUtils.CHAT_MODEL_NAME; +import static dev.langchain4j.community.xinference.spring.XinferenceUtils.EMBEDDING_MODEL_NAME; +import static dev.langchain4j.community.xinference.spring.XinferenceUtils.GENERATE_MODEL_NAME; +import static dev.langchain4j.community.xinference.spring.XinferenceUtils.IMAGE_MODEL_NAME; +import static dev.langchain4j.community.xinference.spring.XinferenceUtils.RERANK_MODEL_NAME; +import static dev.langchain4j.community.xinference.spring.XinferenceUtils.XINFERENCE_IMAGE; +import static 
dev.langchain4j.community.xinference.spring.XinferenceUtils.launchCmd; +import static java.util.concurrent.TimeUnit.SECONDS; +import static org.assertj.core.api.Assertions.assertThat; + +import dev.langchain4j.community.model.xinference.XinferenceChatModel; +import dev.langchain4j.community.model.xinference.XinferenceEmbeddingModel; +import dev.langchain4j.community.model.xinference.XinferenceImageModel; +import dev.langchain4j.community.model.xinference.XinferenceLanguageModel; +import dev.langchain4j.community.model.xinference.XinferenceScoringModel; +import dev.langchain4j.community.model.xinference.XinferenceStreamingChatModel; +import dev.langchain4j.community.model.xinference.XinferenceStreamingLanguageModel; +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.model.StreamingResponseHandler; +import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.chat.StreamingChatLanguageModel; +import dev.langchain4j.model.embedding.EmbeddingModel; +import dev.langchain4j.model.image.ImageModel; +import dev.langchain4j.model.language.LanguageModel; +import dev.langchain4j.model.language.StreamingLanguageModel; +import dev.langchain4j.model.output.Response; +import dev.langchain4j.model.scoring.ScoringModel; +import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + +/** + * + */ +@Testcontainers +class AutoConfigIT { + ApplicationContextRunner contextRunner = + new ApplicationContextRunner().withConfiguration(AutoConfigurations.of(XinferenceAutoConfiguration.class)); + + @Container + 
XinferenceContainer chatModelContainer = new XinferenceContainer(XINFERENCE_IMAGE); + + @Test + void should_provide_chat_model() throws IOException, InterruptedException { + chatModelContainer.execInContainer("bash", "-c", launchCmd(CHAT_MODEL_NAME)); + contextRunner + .withPropertyValues( + ChatModelProperties.PREFIX + ".base-url=" + chatModelContainer.getEndpoint(), + ChatModelProperties.PREFIX + ".model-name=" + CHAT_MODEL_NAME, + ChatModelProperties.PREFIX + ".logRequests=true", + ChatModelProperties.PREFIX + ".logResponses=true") + .run(context -> { + ChatLanguageModel chatLanguageModel = context.getBean(ChatLanguageModel.class); + assertThat(chatLanguageModel).isInstanceOf(XinferenceChatModel.class); + assertThat(chatLanguageModel.generate("What is the capital of Germany?")) + .contains("Berlin"); + assertThat(context.getBean(XinferenceChatModel.class)).isSameAs(chatLanguageModel); + }); + } + + @Test + void should_provide_streaming_chat_model() throws IOException, InterruptedException { + chatModelContainer.execInContainer("bash", "-c", launchCmd(CHAT_MODEL_NAME)); + contextRunner + .withPropertyValues( + StreamingChatModelProperties.PREFIX + ".base-url=" + chatModelContainer.getEndpoint(), + StreamingChatModelProperties.PREFIX + ".model-name=" + CHAT_MODEL_NAME, + StreamingChatModelProperties.PREFIX + ".logRequests=true", + StreamingChatModelProperties.PREFIX + ".logResponses=true") + .run(context -> { + StreamingChatLanguageModel streamingChatLanguageModel = + context.getBean(StreamingChatLanguageModel.class); + assertThat(streamingChatLanguageModel).isInstanceOf(XinferenceStreamingChatModel.class); + CompletableFuture> future = new CompletableFuture<>(); + streamingChatLanguageModel.generate( + "What is the capital of Germany?", new StreamingResponseHandler() { + @Override + public void onNext(String token) {} + + @Override + public void onComplete(Response response) { + future.complete(response); + } + + @Override + public void onError(Throwable error) 
{} + }); + Response response = future.get(60, SECONDS); + assertThat(response.content().text()).contains("Berlin"); + assertThat(context.getBean(XinferenceStreamingChatModel.class)) + .isSameAs(streamingChatLanguageModel); + }); + } + + @Test + void should_provide_language_model() throws IOException, InterruptedException { + chatModelContainer.execInContainer("bash", "-c", launchCmd(GENERATE_MODEL_NAME)); + contextRunner + .withPropertyValues( + LanguageModelProperties.PREFIX + ".base-url=" + chatModelContainer.getEndpoint(), + LanguageModelProperties.PREFIX + ".model-name=" + GENERATE_MODEL_NAME, + LanguageModelProperties.PREFIX + ".logRequests=true", + LanguageModelProperties.PREFIX + ".logResponses=true") + .run(context -> { + LanguageModel languageModel = context.getBean(LanguageModel.class); + assertThat(languageModel).isInstanceOf(XinferenceLanguageModel.class); + assertThat(languageModel + .generate("What is the capital of Germany?") + .content()) + .contains("Berlin"); + assertThat(context.getBean(XinferenceLanguageModel.class)).isSameAs(languageModel); + }); + } + + @Test + void should_provide_streaming_language_model() throws IOException, InterruptedException { + chatModelContainer.execInContainer("bash", "-c", launchCmd(GENERATE_MODEL_NAME)); + contextRunner + .withPropertyValues( + StreamingLanguageModelProperties.PREFIX + ".base-url=" + chatModelContainer.getEndpoint(), + StreamingLanguageModelProperties.PREFIX + ".model-name=" + GENERATE_MODEL_NAME, + StreamingLanguageModelProperties.PREFIX + ".logRequests=true", + StreamingLanguageModelProperties.PREFIX + ".logResponses=true") + .run(context -> { + StreamingLanguageModel streamingLanguageModel = context.getBean(StreamingLanguageModel.class); + assertThat(streamingLanguageModel).isInstanceOf(XinferenceStreamingLanguageModel.class); + CompletableFuture> future = new CompletableFuture<>(); + streamingLanguageModel.generate( + "What is the capital of Germany?", new StreamingResponseHandler() { + 
@Override + public void onNext(String token) {} + + @Override + public void onComplete(Response response) { + future.complete(response); + } + + @Override + public void onError(Throwable error) {} + }); + Response response = future.get(60, SECONDS); + assertThat(response.content()).contains("Berlin"); + + assertThat(context.getBean(XinferenceStreamingLanguageModel.class)) + .isSameAs(streamingLanguageModel); + }); + } + + @Test + void should_provide_embedding_model() throws IOException, InterruptedException { + chatModelContainer.execInContainer("bash", "-c", launchCmd(EMBEDDING_MODEL_NAME)); + contextRunner + .withPropertyValues( + EmbeddingModelProperties.PREFIX + ".base-url=" + chatModelContainer.getEndpoint(), + EmbeddingModelProperties.PREFIX + ".modelName=" + EMBEDDING_MODEL_NAME, + EmbeddingModelProperties.PREFIX + ".logRequests=true", + EmbeddingModelProperties.PREFIX + ".logResponses=true") + .run(context -> { + EmbeddingModel embeddingModel = context.getBean(EmbeddingModel.class); + assertThat(embeddingModel).isInstanceOf(XinferenceEmbeddingModel.class); + assertThat(embeddingModel.embed("hello world").content().dimension()) + .isEqualTo(768); + assertThat(context.getBean(XinferenceEmbeddingModel.class)).isSameAs(embeddingModel); + }); + } + + @Test + void should_provide_sc_model() throws IOException, InterruptedException { + chatModelContainer.execInContainer("bash", "-c", launchCmd(RERANK_MODEL_NAME)); + contextRunner + .withPropertyValues( + ScoringModelProperties.PREFIX + ".base-url=" + chatModelContainer.getEndpoint(), + ScoringModelProperties.PREFIX + ".modelName=" + RERANK_MODEL_NAME, + ScoringModelProperties.PREFIX + ".logRequests=true", + ScoringModelProperties.PREFIX + ".logResponses=true") + .run(context -> { + ScoringModel scoringModel = context.getBean(ScoringModel.class); + assertThat(scoringModel).isInstanceOf(XinferenceScoringModel.class); + TextSegment catSegment = TextSegment.from("The Maine Coon is a large domesticated cat breed."); + 
TextSegment dogSegment = TextSegment.from( + "The sweet-faced, lovable Labrador Retriever is one of America's most popular dog breeds, year after year."); + List segments = Arrays.asList(catSegment, dogSegment); + String query = "tell me about dogs"; + Response> response = scoringModel.scoreAll(segments, query); + List scores = response.content(); + assertThat(scores).hasSize(2); + assertThat(scores.get(0)).isGreaterThan(scores.get(1)); + assertThat(context.getBean(XinferenceScoringModel.class)).isSameAs(scoringModel); + }); + } + + @Test + @Disabled("Not supported to run in a Docker environment without GPU .") + void should_provide_image_model() throws IOException, InterruptedException { + chatModelContainer.execInContainer("bash", "-c", launchCmd(IMAGE_MODEL_NAME)); + contextRunner + .withPropertyValues( + ImageModelProperties.PREFIX + ".base-url=" + chatModelContainer.getEndpoint(), + ImageModelProperties.PREFIX + ".modelName=" + IMAGE_MODEL_NAME, + ImageModelProperties.PREFIX + ".logRequests=true", + ImageModelProperties.PREFIX + ".logResponses=true") + .run(context -> { + ImageModel imageModel = context.getBean(ImageModel.class); + assertThat(imageModel).isInstanceOf(XinferenceImageModel.class); + assertThat(imageModel.generate("banana").content().base64Data()) + .isNotNull(); + assertThat(context.getBean(XinferenceImageModel.class)).isSameAs(imageModel); + }); + } +} diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/test/java/dev/langchain4j/community/xinference/spring/XinferenceContainer.java b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/test/java/dev/langchain4j/community/xinference/spring/XinferenceContainer.java new file mode 100644 index 0000000..7ed4fe5 --- /dev/null +++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/test/java/dev/langchain4j/community/xinference/spring/XinferenceContainer.java @@ -0,0 +1,95 @@ +package 
dev.langchain4j.community.xinference.spring; + +import static dev.langchain4j.community.xinference.spring.XinferenceUtils.launchCmd; + +import com.github.dockerjava.api.DockerClient; +import com.github.dockerjava.api.command.InspectContainerResponse; +import com.github.dockerjava.api.model.DeviceRequest; +import com.github.dockerjava.api.model.Image; +import com.github.dockerjava.api.model.Info; +import com.github.dockerjava.api.model.RuntimeInfo; +import java.io.IOException; +import java.time.Duration; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.testcontainers.DockerClientFactory; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.wait.strategy.Wait; +import org.testcontainers.utility.DockerImageName; + +class XinferenceContainer extends GenericContainer { + private static final Logger log = LoggerFactory.getLogger(XinferenceContainer.class); + private static final DockerImageName DOCKER_IMAGE_NAME = DockerImageName.parse("xprobe/xinference"); + private static final Integer EXPOSED_PORT = 9997; + private String modelName; + + public XinferenceContainer(String image) { + this(DockerImageName.parse(image)); + } + + public XinferenceContainer(DockerImageName image) { + super(image); + image.assertCompatibleWith(DOCKER_IMAGE_NAME); + Info info = this.dockerClient.infoCmd().exec(); + Map runtimes = info.getRuntimes(); + if (runtimes != null && runtimes.containsKey("nvidia")) { + this.withCreateContainerCmdModifier((cmd) -> { + Objects.requireNonNull(cmd.getHostConfig()) + .withDeviceRequests(Collections.singletonList((new DeviceRequest()) + .withCapabilities(Collections.singletonList(Collections.singletonList("gpu"))) + .withCount(-1))); + }); + } + this.withExposedPorts(EXPOSED_PORT); + // https://github.com/xorbitsai/inference/issues/2573 + this.withCommand("bash", "-c", "xinference-local -H 
0.0.0.0"); + this.waitingFor(Wait.forListeningPort().withStartupTimeout(Duration.ofMinutes(10))); + } + + @Override + protected void containerIsStarted(InspectContainerResponse containerInfo) { + if (this.modelName != null) { + try { + log.info("Start pulling the '{}' model ... would take several minutes ...", this.modelName); + ExecResult r = execInContainer("bash", "-c", launchCmd(this.modelName)); + if (r.getExitCode() != 0) { + throw new RuntimeException(r.getStderr()); + } + log.info("Model pulling competed! {}", r); + } catch (IOException | InterruptedException e) { + throw new RuntimeException("Error pulling model", e); + } + } + } + + public XinferenceContainer withModel(String modelName) { + this.modelName = modelName; + return this; + } + + public void commitToImage(String imageName) { + DockerImageName dockerImageName = DockerImageName.parse(this.getDockerImageName()); + if (!dockerImageName.equals(DockerImageName.parse(imageName))) { + DockerClient dockerClient = DockerClientFactory.instance().client(); + List images = + dockerClient.listImagesCmd().withReferenceFilter(imageName).exec(); + if (images.isEmpty()) { + DockerImageName imageModel = DockerImageName.parse(imageName); + dockerClient + .commitCmd(this.getContainerId()) + .withRepository(imageModel.getUnversionedPart()) + .withLabels(Collections.singletonMap("org.testcontainers.sessionId", "")) + .withTag(imageModel.getVersionPart()) + .exec(); + } + } + } + + public String getEndpoint() { + return "http://" + this.getHost() + ":" + this.getMappedPort(EXPOSED_PORT); + } +} diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/test/java/dev/langchain4j/community/xinference/spring/XinferenceUtils.java b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/test/java/dev/langchain4j/community/xinference/spring/XinferenceUtils.java new file mode 100644 index 0000000..edb94be --- /dev/null +++ 
b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/test/java/dev/langchain4j/community/xinference/spring/XinferenceUtils.java @@ -0,0 +1,69 @@ +package dev.langchain4j.community.xinference.spring; + +import com.github.dockerjava.api.DockerClient; +import com.github.dockerjava.api.model.Image; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.testcontainers.DockerClientFactory; +import org.testcontainers.utility.DockerImageName; + +class XinferenceUtils { + public static final String XINFERENCE_BASE_URL = System.getenv("XINFERENCE_BASE_URL"); + public static final String XINFERENCE_API_KEY = System.getenv("XINFERENCE_API_KEY"); + // CPU + public static final String XINFERENCE_IMAGE = "xprobe/xinference:latest-cpu"; + // GPU + // public static final String XINFERENCE_IMAGE = "xprobe/xinference:latest"; + + public static final String CHAT_MODEL_NAME = "qwen2.5-instruct"; + public static final String GENERATE_MODEL_NAME = "qwen2.5"; + public static final String VISION_MODEL_NAME = "qwen2-vl-instruct"; + public static final String IMAGE_MODEL_NAME = "sd3-medium"; + public static final String EMBEDDING_MODEL_NAME = "text2vec-base-chinese"; + public static final String RERANK_MODEL_NAME = "bge-reranker-base"; + + private static final Map MODEL_LAUNCH_MAP = new HashMap<>() { + { + put( + CHAT_MODEL_NAME, + String.format( + "xinference launch --model-engine Transformers --model-name %s --size-in-billions 0_5 --model-format pytorch --quantization none", + CHAT_MODEL_NAME)); + put( + GENERATE_MODEL_NAME, + String.format( + "xinference launch --model-engine Transformers --model-name %s --size-in-billions 0_5 --model-format pytorch --quantization none", + GENERATE_MODEL_NAME)); + put( + VISION_MODEL_NAME, + String.format( + "xinference launch --model-engine Transformers --model-name %s --size-in-billions 2 --model-format pytorch --quantization none", + VISION_MODEL_NAME)); + put( + RERANK_MODEL_NAME, + 
String.format("xinference launch --model-name %s --model-type rerank", RERANK_MODEL_NAME)); + put( + IMAGE_MODEL_NAME, + String.format("xinference launch --model-name %s --model-type image", IMAGE_MODEL_NAME)); + put( + EMBEDDING_MODEL_NAME, + String.format("xinference launch --model-name %s --model-type embedding", EMBEDDING_MODEL_NAME)); + } + }; + + public static DockerImageName resolve(String baseImage, String localImageName) { + DockerImageName dockerImageName = DockerImageName.parse(baseImage); + DockerClient dockerClient = DockerClientFactory.instance().client(); + List images = + dockerClient.listImagesCmd().withReferenceFilter(localImageName).exec(); + if (images.isEmpty()) { + return dockerImageName; + } + return DockerImageName.parse(localImageName).asCompatibleSubstituteFor(baseImage); + } + + public static String launchCmd(String modelName) { + return MODEL_LAUNCH_MAP.get(modelName); + } +} diff --git a/spring-boot-starters/pom.xml b/spring-boot-starters/pom.xml index cc72675..899fb76 100644 --- a/spring-boot-starters/pom.xml +++ b/spring-boot-starters/pom.xml @@ -26,6 +26,7 @@ langchain4j-community-dashscope-spring-boot-starter langchain4j-community-qianfan-spring-boot-starter + langchain4j-community-xinference-spring-boot-starter