Skip to content

Commit

Permalink
Merge branch 'main' into aiservice_created_event
Browse files Browse the repository at this point in the history
  • Loading branch information
catofdestruction authored Nov 26, 2024
2 parents 4e5cc40 + f934f16 commit 0b1d869
Show file tree
Hide file tree
Showing 17 changed files with 335 additions and 26 deletions.
2 changes: 1 addition & 1 deletion .github/ISSUE_TEMPLATE/bug_report.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ Please provide a relevant code snippets to reproduce this bug.
A clear and concise description of what you expected to happen.

**Please complete the following information:**
- LangChain4j version: e.g. 0.36.0
- LangChain4j version: e.g. 0.36.1
- Java version: e.g. 17
- Spring Boot version: e.g. 3.3.1

Expand Down
17 changes: 17 additions & 0 deletions .github/workflows/add_new_pr_to_project.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
name: Add new PR to Project

on:
pull_request:
types:
- opened
- reopened

jobs:
add-to-project:
name: Add PR to Project
runs-on: ubuntu-latest
steps:
- uses: actions/[email protected]
with:
project-url: https://github.com/users/langchain4j/projects/2
github-token: ${{ secrets.GH_TOKEN_ADD_NEW_PRS_TO_PROJECT }}
8 changes: 8 additions & 0 deletions langchain4j-elasticsearch-spring-boot-starter/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,14 @@
<artifactId>slf4j-tinylog</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>commons-io</groupId>
<artifactId>commons-io</artifactId>
<version>2.17.0</version>
<scope>test</scope>
</dependency>

</dependencies>

<build>
Expand Down
2 changes: 1 addition & 1 deletion langchain4j-reactor/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
<dependency>
<groupId>org.assertj</groupId>
<artifactId>assertj-core</artifactId>
<version>3.24.2</version>
<version>3.26.3</version>
<scope>test</scope>
</dependency>

Expand Down
7 changes: 7 additions & 0 deletions langchain4j-spring-boot-starter/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,13 @@
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-aop</artifactId>
<version>${spring.boot.version}</version>
<scope>test</scope>
</dependency>

<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-core</artifactId>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package dev.langchain4j.service.spring;

import dev.langchain4j.agent.tool.Tool;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.memory.chat.ChatMemoryProvider;
import dev.langchain4j.model.chat.ChatLanguageModel;
Expand All @@ -8,11 +10,20 @@
import dev.langchain4j.rag.RetrievalAugmentor;
import dev.langchain4j.rag.content.retriever.ContentRetriever;
import dev.langchain4j.service.AiServices;
import dev.langchain4j.service.tool.DefaultToolExecutor;
import dev.langchain4j.service.tool.ToolExecutor;
import org.springframework.beans.factory.FactoryBean;

import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static dev.langchain4j.agent.tool.ToolSpecifications.toolSpecificationFrom;
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
import static org.springframework.aop.framework.AopProxyUtils.ultimateTargetClass;
import static org.springframework.aop.support.AopUtils.isAopProxy;

class AiServiceFactory implements FactoryBean<Object> {

Expand Down Expand Up @@ -94,7 +105,13 @@ public Object getObject() {
}

if (!isNullOrEmpty(tools)) {
builder = builder.tools(tools);
for (Object tool : tools) {
if (isAopProxy(tool)) {
builder = builder.tools(aopEnhancedTools(tool));
} else {
builder = builder.tools(tool);
}
}
}

return builder.build();
Expand All @@ -120,4 +137,21 @@ public boolean isSingleton() {
* (such as java.io.Closeable.close()) will not be called automatically.
* Instead, a FactoryBean should implement DisposableBean and delegate any such close call to the underlying object.
*/

private Map<ToolSpecification, ToolExecutor> aopEnhancedTools(Object enhancedTool) {
Map<ToolSpecification, ToolExecutor> toolExecutors = new HashMap<>();
Class<?> originalToolClass = ultimateTargetClass(enhancedTool);
for (Method originalToolMethod : originalToolClass.getDeclaredMethods()) {
if (originalToolMethod.isAnnotationPresent(Tool.class)) {
Arrays.stream(enhancedTool.getClass().getDeclaredMethods())
.filter(m -> m.getName().equals(originalToolMethod.getName()))
.findFirst()
.ifPresent(enhancedMethod -> {
ToolSpecification toolSpecification = toolSpecificationFrom(originalToolMethod);
toolExecutors.put(toolSpecification, new DefaultToolExecutor(enhancedTool, enhancedMethod));
});
}
}
return toolExecutors;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,25 +46,27 @@ private Set<String> getBasePackages(ConfigurableListableBeanFactory beanFactory)
}

private void addComponentScanPackages(ConfigurableListableBeanFactory beanFactory, Set<String> collectedBasePackages) {
beanFactory.getBeansWithAnnotation(ComponentScan.class).forEach((beanName, instance) -> {
Set<ComponentScan> componentScans = AnnotatedElementUtils.getMergedRepeatableAnnotations(instance.getClass(), ComponentScan.class);
for (ComponentScan componentScan : componentScans) {
Set<String> basePackages = new LinkedHashSet<>();
String[] basePackagesArray = componentScan.basePackages();
for (String pkg : basePackagesArray) {
String[] tokenized = StringUtils.tokenizeToStringArray(this.environment.resolvePlaceholders(pkg),
ConfigurableApplicationContext.CONFIG_LOCATION_DELIMITERS);
Collections.addAll(basePackages, tokenized);
}
for (Class<?> clazz : componentScan.basePackageClasses()) {
basePackages.add(ClassUtils.getPackageName(clazz));
}
if (basePackages.isEmpty()) {
basePackages.add(ClassUtils.getPackageName(instance.getClass()));
for (String beanName : beanFactory.getBeanNamesForAnnotation(ComponentScan.class)) {
Class<?> beanClass = beanFactory.getType(beanName);
if (beanClass != null) {
Set<ComponentScan> componentScans = AnnotatedElementUtils.getMergedRepeatableAnnotations(beanClass, ComponentScan.class);
for (ComponentScan componentScan : componentScans) {
Set<String> basePackages = new LinkedHashSet<>();
for (String pkg : componentScan.basePackages()) {
String[] tokenized = StringUtils.tokenizeToStringArray(this.environment.resolvePlaceholders(pkg),
ConfigurableApplicationContext.CONFIG_LOCATION_DELIMITERS);
Collections.addAll(basePackages, tokenized);
}
for (Class<?> clazz : componentScan.basePackageClasses()) {
basePackages.add(ClassUtils.getPackageName(clazz));
}
if (basePackages.isEmpty()) {
basePackages.add(ClassUtils.getPackageName(beanClass));
}
collectedBasePackages.addAll(basePackages);
}
collectedBasePackages.addAll(basePackages);
}
});
}
}

private void removeAiServicesWithInactiveProfiles(BeanDefinitionRegistry registry) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,11 @@ BeanFactoryPostProcessor aiServicesRegisteringBeanFactoryPostProcessor() {
Map<String, ToolSpecification> toolSpecifications = new HashMap<>();
for (String beanName : beanFactory.getBeanDefinitionNames()) {
try {
Class<?> beanClass = Class.forName(beanFactory.getBeanDefinition(beanName).getBeanClassName());
String beanClassName = beanFactory.getBeanDefinition(beanName).getBeanClassName();
if (beanClassName == null) {
continue;
}
Class<?> beanClass = Class.forName(beanClassName);
for (Method beanMethod : beanClass.getDeclaredMethods()) {
if (beanMethod.isAnnotationPresent(Tool.class)) {
if (tools.add(beanName)) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package dev.langchain4j.service.spring.mode.automatic.Issue2133;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;

/**
* @author: qing
* @Date: 2024/11/20
*/
@SpringBootApplication
public class TestAutowireAiServiceApplication {

@Autowired
TestAutowireConfiguration testAutowireConfiguration;

public static void main(String[] args) {
SpringApplication.run(TestAutowireAiServiceApplication.class, args);
}

TestAutowireConfiguration getConfiguration() {
return testAutowireConfiguration;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package dev.langchain4j.service.spring.mode.automatic.Issue2133;

import dev.langchain4j.service.spring.AiServicesAutoConfig;
import org.junit.jupiter.api.Test;
import org.springframework.boot.autoconfigure.AutoConfigurations;
import org.springframework.boot.test.context.runner.ApplicationContextRunner;

import static org.junit.jupiter.api.Assertions.assertNotNull;

class TestAutowireClassAiServiceIT {

ApplicationContextRunner contextRunner = new ApplicationContextRunner()
.withConfiguration(AutoConfigurations.of(AiServicesAutoConfig.class));

@Test
void should_get_configuration_class() {
contextRunner
.withUserConfiguration(TestAutowireAiServiceApplication.class)
.withBean(TestAutowireConfiguration.class)
.run(context -> {
// given
TestAutowireAiServiceApplication application = context.getBean(TestAutowireAiServiceApplication.class);

// should get the configuration class
assertNotNull(application.getConfiguration(), "TestConfiguration class should be not null");
});
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package dev.langchain4j.service.spring.mode.automatic.Issue2133;

import org.springframework.context.annotation.Configuration;

/**
* @author: qing
* @Date: 2024/11/20
*/
@Configuration
class TestAutowireConfiguration {
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,24 @@
import dev.langchain4j.service.spring.AiServicesAutoConfig;
import dev.langchain4j.service.spring.event.AiServiceRegisteredEvent;
import dev.langchain4j.service.spring.mode.automatic.withTools.listener.AiServiceRegisteredEventListener;
import dev.langchain4j.service.spring.mode.automatic.withTools.aop.ToolObserverAspect;
import org.junit.jupiter.api.Test;
import org.springframework.boot.autoconfigure.AutoConfigurations;
import org.springframework.boot.test.context.runner.ApplicationContextRunner;

import java.util.List;

import static dev.langchain4j.service.spring.mode.ApiKeys.OPENAI_API_KEY;
import static dev.langchain4j.service.spring.mode.automatic.withTools.AopEnhancedTools.TOOL_OBSERVER_KEY;
import static dev.langchain4j.service.spring.mode.automatic.withTools.AopEnhancedTools.TOOL_OBSERVER_KEY_NAME_DESCRIPTION;
import static dev.langchain4j.service.spring.mode.automatic.withTools.AopEnhancedTools.TOOL_OBSERVER_PACKAGE_NAME;
import static dev.langchain4j.service.spring.mode.automatic.withTools.AopEnhancedTools.TOOL_OBSERVER_PACKAGE_NAME_DESCRIPTION;
import static dev.langchain4j.service.spring.mode.automatic.withTools.PackagePrivateTools.CURRENT_TIME;
import static dev.langchain4j.service.spring.mode.automatic.withTools.PublicTools.CURRENT_DATE;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;

class AiServicesAutoConfigIT {
Expand Down Expand Up @@ -90,6 +96,45 @@ void should_receive_ai_service_registered_event() {
List<String> tools = event.getToolSpecifications().stream().map(ToolSpecification::name).toList();
assertTrue(tools.contains("getCurrentDate"));
assertTrue(tools.contains("getCurrentTime"));
});
}

void should_create_AI_service_with_tool_which_is_enhanced_by_spring_aop() {
contextRunner
.withPropertyValues(
"langchain4j.open-ai.chat-model.api-key=" + OPENAI_API_KEY,
"langchain4j.open-ai.chat-model.temperature=0.0",
"langchain4j.open-ai.chat-model.log-requests=true",
"langchain4j.open-ai.chat-model.log-responses=true"
)
.withUserConfiguration(AiServiceWithToolsApplication.class)
.run(context -> {

// given
AiServiceWithTools aiService = context.getBean(AiServiceWithTools.class);

// when
String answer = aiService.chat("Which package is the @ToolObserver annotation located in? " +
"And what is the key of the @ToolObserver annotation?" +
"And What is the current time?");

System.out.println("Answer: " + answer);

// then should use AopEnhancedTools.getAspectPackage()
// & AopEnhancedTools.getToolObserverKey()
// & PackagePrivateTools.getCurrentTime()
assertThat(answer).contains(TOOL_OBSERVER_PACKAGE_NAME);
assertThat(answer).contains(TOOL_OBSERVER_KEY);
assertThat(answer).contains(String.valueOf(CURRENT_TIME.getMinute()));

// and AOP aspect should be called
// & only for getToolObserverKey() which is annotated with @ToolObserver
ToolObserverAspect aspect = context.getBean(ToolObserverAspect.class);
assertTrue(aspect.aspectHasBeenCalled());

assertEquals(1, aspect.getObservedTools().size());
assertTrue(aspect.getObservedTools().contains(TOOL_OBSERVER_KEY_NAME_DESCRIPTION));
assertFalse(aspect.getObservedTools().contains(TOOL_OBSERVER_PACKAGE_NAME_DESCRIPTION));
});
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package dev.langchain4j.service.spring.mode.automatic.withTools;

import dev.langchain4j.agent.tool.Tool;
import dev.langchain4j.service.spring.mode.automatic.withTools.aop.ToolObserver;
import org.springframework.stereotype.Component;

@Component
public class AopEnhancedTools {

public static final String TOOL_OBSERVER_PACKAGE_NAME_DESCRIPTION =
"Find the package directory where @ToolObserver is located.";
public static final String TOOL_OBSERVER_PACKAGE_NAME = ToolObserver.class.getPackageName();

public static final String TOOL_OBSERVER_KEY_NAME_DESCRIPTION =
"Find the key name of @ToolObserver";
public static final String TOOL_OBSERVER_KEY = "AOP_ENHANCED_TOOLS_SUPPORT_@_1122";

@Tool(TOOL_OBSERVER_PACKAGE_NAME_DESCRIPTION)
public String getToolObserverPackageName() {
return TOOL_OBSERVER_PACKAGE_NAME;
}

@ToolObserver(key = TOOL_OBSERVER_KEY)
@Tool(TOOL_OBSERVER_KEY_NAME_DESCRIPTION)
public String getToolObserverKey() {
return TOOL_OBSERVER_KEY;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package dev.langchain4j.service.spring.mode.automatic.withTools.aop;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

@Target({ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
public @interface ToolObserver {

/**
* key just for example
*
* @return the key
*/
String key();
}
Loading

0 comments on commit 0b1d869

Please sign in to comment.