Skip to content

Commit

Permalink
Merge branch 'langchain4j:main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
Suhas-Koheda authored Dec 2, 2024
2 parents 7554b25 + 457bca4 commit 4afe606
Show file tree
Hide file tree
Showing 16 changed files with 364 additions and 25 deletions.
35 changes: 35 additions & 0 deletions .github/pull_request_template.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
<!--
Thank you so much for your contribution!
Please fill in all the sections below.
Please open the PR as a draft initially. Once it is reviewed and approved, we will ask you to add documentation and examples.
Please note that PRs with breaking changes or without tests will be rejected.
Please note that PRs will be reviewed based on the priority of the issues they address.
We ask for your patience. We are doing our best to review your PR as quickly as possible.
Please refrain from pinging and asking when it will be reviewed. Thank you for understanding!
-->

## Issue
<!-- Please specify the ID of the issue this PR is addressing. For example: "Closes #1234" or "Fixes #1234" -->
Closes #

## Change
<!-- Please describe the changes you made. -->


## General checklist
<!-- Please double-check the following points and mark them like this: [X] -->
- [ ] There are no breaking changes
- [ ] I have added unit and/or integration tests for my change
- [ ] The tests cover both positive and negative cases
- [ ] I have manually run all the unit and integration tests in the module I have added/changed, and they are all green
<!-- Before adding documentation and example(s) (below), please wait until the PR is reviewed and approved. -->
- [ ] I have added/updated the [documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs)
- [ ] I have added an example in the [examples repo](https://github.com/langchain4j/langchain4j-examples) (only for "big" features)


## Checklist for adding new Spring Boot starter
<!-- Please double-check the following points and mark them like this: [X] -->
- [ ] I have added my new starter in the root `pom.xml`
- [ ] I have added a `org.springframework.boot.autoconfigure.AutoConfiguration.imports` file in the `langchain4j-{integration}-spring-boot-starter/src/main/resources/META-INF/spring/` directory
17 changes: 17 additions & 0 deletions .github/workflows/add_new_pr_to_project.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
name: Add new PR to Project

on:
pull_request:
types:
- opened
- reopened

jobs:
add-to-project:
name: Add PR to Project
runs-on: ubuntu-latest
steps:
- uses: actions/[email protected]
with:
project-url: https://github.com/users/langchain4j/projects/2
github-token: ${{ secrets.GH_TOKEN_ADD_NEW_PRS_TO_PROJECT }}
2 changes: 1 addition & 1 deletion langchain4j-reactor/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
<dependency>
<groupId>org.assertj</groupId>
<artifactId>assertj-core</artifactId>
<version>3.24.2</version>
<version>3.26.3</version>
<scope>test</scope>
</dependency>

Expand Down
7 changes: 7 additions & 0 deletions langchain4j-spring-boot-starter/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,13 @@
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-aop</artifactId>
<version>${spring.boot.version}</version>
<scope>test</scope>
</dependency>

<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-core</artifactId>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package dev.langchain4j.service.spring;

import dev.langchain4j.agent.tool.Tool;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.memory.chat.ChatMemoryProvider;
import dev.langchain4j.model.chat.ChatLanguageModel;
Expand All @@ -8,11 +10,20 @@
import dev.langchain4j.rag.RetrievalAugmentor;
import dev.langchain4j.rag.content.retriever.ContentRetriever;
import dev.langchain4j.service.AiServices;
import dev.langchain4j.service.tool.DefaultToolExecutor;
import dev.langchain4j.service.tool.ToolExecutor;
import org.springframework.beans.factory.FactoryBean;

import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static dev.langchain4j.agent.tool.ToolSpecifications.toolSpecificationFrom;
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
import static org.springframework.aop.framework.AopProxyUtils.ultimateTargetClass;
import static org.springframework.aop.support.AopUtils.isAopProxy;

class AiServiceFactory implements FactoryBean<Object> {

Expand Down Expand Up @@ -94,7 +105,13 @@ public Object getObject() {
}

if (!isNullOrEmpty(tools)) {
builder = builder.tools(tools);
for (Object tool : tools) {
if (isAopProxy(tool)) {
builder = builder.tools(aopEnhancedTools(tool));
} else {
builder = builder.tools(tool);
}
}
}

return builder.build();
Expand All @@ -120,4 +137,21 @@ public boolean isSingleton() {
* (such as java.io.Closeable.close()) will not be called automatically.
* Instead, a FactoryBean should implement DisposableBean and delegate any such close call to the underlying object.
*/

private Map<ToolSpecification, ToolExecutor> aopEnhancedTools(Object enhancedTool) {
Map<ToolSpecification, ToolExecutor> toolExecutors = new HashMap<>();
Class<?> originalToolClass = ultimateTargetClass(enhancedTool);
for (Method originalToolMethod : originalToolClass.getDeclaredMethods()) {
if (originalToolMethod.isAnnotationPresent(Tool.class)) {
Arrays.stream(enhancedTool.getClass().getDeclaredMethods())
.filter(m -> m.getName().equals(originalToolMethod.getName()))
.findFirst()
.ifPresent(enhancedMethod -> {
ToolSpecification toolSpecification = toolSpecificationFrom(originalToolMethod);
toolExecutors.put(toolSpecification, new DefaultToolExecutor(enhancedTool, enhancedMethod));
});
}
}
return toolExecutors;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,25 +46,27 @@ private Set<String> getBasePackages(ConfigurableListableBeanFactory beanFactory)
}

private void addComponentScanPackages(ConfigurableListableBeanFactory beanFactory, Set<String> collectedBasePackages) {
beanFactory.getBeansWithAnnotation(ComponentScan.class).forEach((beanName, instance) -> {
Set<ComponentScan> componentScans = AnnotatedElementUtils.getMergedRepeatableAnnotations(instance.getClass(), ComponentScan.class);
for (ComponentScan componentScan : componentScans) {
Set<String> basePackages = new LinkedHashSet<>();
String[] basePackagesArray = componentScan.basePackages();
for (String pkg : basePackagesArray) {
String[] tokenized = StringUtils.tokenizeToStringArray(this.environment.resolvePlaceholders(pkg),
ConfigurableApplicationContext.CONFIG_LOCATION_DELIMITERS);
Collections.addAll(basePackages, tokenized);
}
for (Class<?> clazz : componentScan.basePackageClasses()) {
basePackages.add(ClassUtils.getPackageName(clazz));
}
if (basePackages.isEmpty()) {
basePackages.add(ClassUtils.getPackageName(instance.getClass()));
for (String beanName : beanFactory.getBeanNamesForAnnotation(ComponentScan.class)) {
Class<?> beanClass = beanFactory.getType(beanName);
if (beanClass != null) {
Set<ComponentScan> componentScans = AnnotatedElementUtils.getMergedRepeatableAnnotations(beanClass, ComponentScan.class);
for (ComponentScan componentScan : componentScans) {
Set<String> basePackages = new LinkedHashSet<>();
for (String pkg : componentScan.basePackages()) {
String[] tokenized = StringUtils.tokenizeToStringArray(this.environment.resolvePlaceholders(pkg),
ConfigurableApplicationContext.CONFIG_LOCATION_DELIMITERS);
Collections.addAll(basePackages, tokenized);
}
for (Class<?> clazz : componentScan.basePackageClasses()) {
basePackages.add(ClassUtils.getPackageName(clazz));
}
if (basePackages.isEmpty()) {
basePackages.add(ClassUtils.getPackageName(beanClass));
}
collectedBasePackages.addAll(basePackages);
}
collectedBasePackages.addAll(basePackages);
}
});
}
}

private void removeAiServicesWithInactiveProfiles(BeanDefinitionRegistry registry) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,11 @@ BeanFactoryPostProcessor aiServicesRegisteringBeanFactoryPostProcessor() {
Set<String> tools = new HashSet<>();
for (String beanName : beanFactory.getBeanDefinitionNames()) {
try {
Class<?> beanClass = Class.forName(beanFactory.getBeanDefinition(beanName).getBeanClassName());
String beanClassName = beanFactory.getBeanDefinition(beanName).getBeanClassName();
if (beanClassName == null) {
continue;
}
Class<?> beanClass = Class.forName(beanClassName);
for (Method beanMethod : beanClass.getDeclaredMethods()) {
if (beanMethod.isAnnotationPresent(Tool.class)) {
tools.add(beanName);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package dev.langchain4j.service.spring.mode.automatic.Issue2133;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;

/**
* @author: qing
* @Date: 2024/11/20
*/
@SpringBootApplication
public class TestAutowireAiServiceApplication {

@Autowired
TestAutowireConfiguration testAutowireConfiguration;

public static void main(String[] args) {
SpringApplication.run(TestAutowireAiServiceApplication.class, args);
}

TestAutowireConfiguration getConfiguration() {
return testAutowireConfiguration;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package dev.langchain4j.service.spring.mode.automatic.Issue2133;

import dev.langchain4j.service.spring.AiServicesAutoConfig;
import org.junit.jupiter.api.Test;
import org.springframework.boot.autoconfigure.AutoConfigurations;
import org.springframework.boot.test.context.runner.ApplicationContextRunner;

import static org.junit.jupiter.api.Assertions.assertNotNull;

class TestAutowireClassAiServiceIT {

ApplicationContextRunner contextRunner = new ApplicationContextRunner()
.withConfiguration(AutoConfigurations.of(AiServicesAutoConfig.class));

@Test
void should_get_configuration_class() {
contextRunner
.withUserConfiguration(TestAutowireAiServiceApplication.class)
.withBean(TestAutowireConfiguration.class)
.run(context -> {
// given
TestAutowireAiServiceApplication application = context.getBean(TestAutowireAiServiceApplication.class);

// should get the configuration class
assertNotNull(application.getConfiguration(), "TestConfiguration class should be not null");
});
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package dev.langchain4j.service.spring.mode.automatic.Issue2133;

import org.springframework.context.annotation.Configuration;

/**
* @author: qing
* @Date: 2024/11/20
*/
@Configuration
class TestAutowireConfiguration {
}
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
package dev.langchain4j.service.spring.mode.automatic.withTools;

import dev.langchain4j.service.spring.AiServicesAutoConfig;
import dev.langchain4j.service.spring.mode.automatic.withTools.aop.ToolObserverAspect;
import org.junit.jupiter.api.Test;
import org.springframework.boot.autoconfigure.AutoConfigurations;
import org.springframework.boot.test.context.runner.ApplicationContextRunner;

import static dev.langchain4j.service.spring.mode.ApiKeys.OPENAI_API_KEY;
import static dev.langchain4j.service.spring.mode.automatic.withTools.AopEnhancedTools.TOOL_OBSERVER_KEY;
import static dev.langchain4j.service.spring.mode.automatic.withTools.AopEnhancedTools.TOOL_OBSERVER_KEY_NAME_DESCRIPTION;
import static dev.langchain4j.service.spring.mode.automatic.withTools.AopEnhancedTools.TOOL_OBSERVER_PACKAGE_NAME;
import static dev.langchain4j.service.spring.mode.automatic.withTools.AopEnhancedTools.TOOL_OBSERVER_PACKAGE_NAME_DESCRIPTION;
import static dev.langchain4j.service.spring.mode.automatic.withTools.PackagePrivateTools.CURRENT_TIME;
import static dev.langchain4j.service.spring.mode.automatic.withTools.PublicTools.CURRENT_DATE;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;

class AiServicesAutoConfigIT {

Expand Down Expand Up @@ -61,6 +69,46 @@ void should_create_AI_service_with_tool_that_is_package_private_method_in_packag
});
}

@Test
void should_create_AI_service_with_tool_which_is_enhanced_by_spring_aop() {
contextRunner
.withPropertyValues(
"langchain4j.open-ai.chat-model.api-key=" + OPENAI_API_KEY,
"langchain4j.open-ai.chat-model.temperature=0.0",
"langchain4j.open-ai.chat-model.log-requests=true",
"langchain4j.open-ai.chat-model.log-responses=true"
)
.withUserConfiguration(AiServiceWithToolsApplication.class)
.run(context -> {

// given
AiServiceWithTools aiService = context.getBean(AiServiceWithTools.class);

// when
String answer = aiService.chat("Which package is the @ToolObserver annotation located in? " +
"And what is the key of the @ToolObserver annotation?" +
"And What is the current time?");

System.out.println("Answer: " + answer);

// then should use AopEnhancedTools.getAspectPackage()
// & AopEnhancedTools.getToolObserverKey()
// & PackagePrivateTools.getCurrentTime()
assertThat(answer).contains(TOOL_OBSERVER_PACKAGE_NAME);
assertThat(answer).contains(TOOL_OBSERVER_KEY);
assertThat(answer).contains(String.valueOf(CURRENT_TIME.getMinute()));

// and AOP aspect should be called
// & only for getToolObserverKey() which is annotated with @ToolObserver
ToolObserverAspect aspect = context.getBean(ToolObserverAspect.class);
assertTrue(aspect.aspectHasBeenCalled());

assertEquals(1, aspect.getObservedTools().size());
assertTrue(aspect.getObservedTools().contains(TOOL_OBSERVER_KEY_NAME_DESCRIPTION));
assertFalse(aspect.getObservedTools().contains(TOOL_OBSERVER_PACKAGE_NAME_DESCRIPTION));
});
}

// TODO tools which are not @Beans?
// TODO negative cases
// TODO no @AiServices in app, just models
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package dev.langchain4j.service.spring.mode.automatic.withTools;

import dev.langchain4j.agent.tool.Tool;
import dev.langchain4j.service.spring.mode.automatic.withTools.aop.ToolObserver;
import org.springframework.stereotype.Component;

@Component
public class AopEnhancedTools {

public static final String TOOL_OBSERVER_PACKAGE_NAME_DESCRIPTION =
"Find the package directory where @ToolObserver is located.";
public static final String TOOL_OBSERVER_PACKAGE_NAME = ToolObserver.class.getPackageName();

public static final String TOOL_OBSERVER_KEY_NAME_DESCRIPTION =
"Find the key name of @ToolObserver";
public static final String TOOL_OBSERVER_KEY = "AOP_ENHANCED_TOOLS_SUPPORT_@_1122";

@Tool(TOOL_OBSERVER_PACKAGE_NAME_DESCRIPTION)
public String getToolObserverPackageName() {
return TOOL_OBSERVER_PACKAGE_NAME;
}

@ToolObserver(key = TOOL_OBSERVER_KEY)
@Tool(TOOL_OBSERVER_KEY_NAME_DESCRIPTION)
public String getToolObserverKey() {
return TOOL_OBSERVER_KEY;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package dev.langchain4j.service.spring.mode.automatic.withTools.aop;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

@Target({ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
public @interface ToolObserver {

/**
* key just for example
*
* @return the key
*/
String key();
}
Loading

0 comments on commit 4afe606

Please sign in to comment.