Skip to content

Commit bf70fd4

Browse files
committed
Add unit tests
1 parent 75f487f commit bf70fd4

File tree

6 files changed

+245
-38
lines changed

6 files changed

+245
-38
lines changed

examples/src/main/java/io/dapr/examples/conversation/DemoConversationAI.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ public class DemoConversationAI {
1616
* @param args Input arguments (unused).
1717
*/
1818
public static void main(String[] args) {
19-
try (DaprConversationClient client = new DaprConversationClient(null)) {
19+
try (DaprConversationClient client = new DaprConversationClient()) {
2020
DaprConversationInput daprConversationInput = new DaprConversationInput("11");
2121

2222
// Component name is the name provided in the metadata block of the conversation.yaml file.

sdk-ai/src/main/java/io/dapr/ai/client/DaprAiClient.java

+12-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,18 @@ interface DaprAiClient {
2323
Mono<DaprConversationResponse> converse(
2424
String conversationComponentName,
2525
List<DaprConversationInput> daprConversationInputs,
26-
@Nullable String contextId,
26+
String contextId,
2727
boolean scrubPii,
2828
double temperature);
29+
30+
/**
31+
* Method to call the Dapr Converse API.
32+
*
33+
* @param conversationComponentName name for the conversation component.
34+
* @param daprConversationInputs prompts that are part of the conversation.
35+
* @return @ConversationResponse.
36+
*/
37+
Mono<DaprConversationResponse> converse(
38+
String conversationComponentName,
39+
List<DaprConversationInput> daprConversationInputs);
2940
}

sdk-ai/src/main/java/io/dapr/ai/client/DaprConversationClient.java

+47-13
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@
1212
import io.dapr.utils.NetworkUtils;
1313
import io.dapr.v1.DaprGrpc;
1414
import io.dapr.v1.DaprProtos;
15+
import io.grpc.Channel;
1516
import io.grpc.ManagedChannel;
1617
import io.grpc.stub.StreamObserver;
17-
import org.jetbrains.annotations.Nullable;
1818
import reactor.core.publisher.Mono;
1919
import reactor.core.publisher.MonoSink;
2020
import reactor.util.context.ContextView;
@@ -32,11 +32,6 @@ public class DaprConversationClient implements AutoCloseable, DaprAiClient {
3232
*/
3333
private final DaprGrpc.DaprStub asyncStub;
3434

35-
/**
36-
* The GRPC managed channel to be used.
37-
*/
38-
private final ManagedChannel channel;
39-
4035
/**
4136
* The retry policy.
4237
*/
@@ -47,24 +42,50 @@ public class DaprConversationClient implements AutoCloseable, DaprAiClient {
4742
*/
4843
private final TimeoutPolicy timeoutPolicy;
4944

45+
/**
46+
* Constructor to create conversation client.
47+
*/
48+
public DaprConversationClient() {
49+
this(DaprGrpc.newStub(NetworkUtils.buildGrpcManagedChannel(new Properties())), null);
50+
}
51+
52+
/**
53+
* Constructor.
54+
*
55+
* @param properties with client configuration options.
56+
* @param resiliencyOptions retry options.
57+
*/
58+
public DaprConversationClient(
59+
Properties properties,
60+
ResiliencyOptions resiliencyOptions) {
61+
this(DaprGrpc.newStub(NetworkUtils.buildGrpcManagedChannel(properties)), resiliencyOptions);
62+
}
63+
5064
/**
5165
* ConversationClient constructor.
5266
*
5367
* @param resiliencyOptions timeout and retry policies.
5468
*/
55-
public DaprConversationClient(
56-
@Nullable ResiliencyOptions resiliencyOptions) {
57-
this.channel = NetworkUtils.buildGrpcManagedChannel(new Properties());
58-
this.asyncStub = DaprGrpc.newStub(this.channel);
69+
protected DaprConversationClient(
70+
DaprGrpc.DaprStub asyncStub,
71+
ResiliencyOptions resiliencyOptions) {
72+
this.asyncStub = asyncStub;
5973
this.retryPolicy = new RetryPolicy(resiliencyOptions == null ? null : resiliencyOptions.getMaxRetries());
6074
this.timeoutPolicy = new TimeoutPolicy(resiliencyOptions == null ? null : resiliencyOptions.getTimeout());
6175
}
6276

77+
@Override
78+
public Mono<DaprConversationResponse> converse(
79+
String conversationComponentName,
80+
List<DaprConversationInput> daprConversationInputs) {
81+
return converse(conversationComponentName, daprConversationInputs, null, false, 0.0d);
82+
}
83+
6384
@Override
6485
public Mono<DaprConversationResponse> converse(
6586
String conversationComponentName,
6687
List<DaprConversationInput> daprConversationInputs,
67-
@Nullable String contextId,
88+
String contextId,
6889
boolean scrubPii,
6990
double temperature) {
7091

@@ -87,8 +108,19 @@ public Mono<DaprConversationResponse> converse(
87108
}
88109

89110
for (DaprConversationInput input : daprConversationInputs) {
90-
conversationRequest.addInputs(DaprProtos.ConversationInput.newBuilder()
91-
.setContent(input.getContent()).build());
111+
if (input.getContent() == null || input.getContent().isEmpty()) {
112+
throw new IllegalArgumentException("Conversation input content cannot be null or empty.");
113+
}
114+
115+
DaprProtos.ConversationInput.Builder conversationInputOrBuilder = DaprProtos.ConversationInput.newBuilder()
116+
.setContent(input.getContent())
117+
.setScrubPII(input.isScrubPii());
118+
119+
if (input.getRole() != null) {
120+
conversationInputOrBuilder.setRole(input.getRole().toString());
121+
}
122+
123+
conversationRequest.addInputs(conversationInputOrBuilder.build());
92124
}
93125

94126
Mono<DaprProtos.ConversationResponse> conversationResponseMono = Mono.deferContextual(
@@ -121,6 +153,8 @@ public Mono<DaprConversationResponse> converse(
121153

122154
@Override
123155
public void close() throws Exception {
156+
ManagedChannel channel = (ManagedChannel) this.asyncStub.getChannel();
157+
124158
DaprException.wrap(() -> {
125159
if (channel != null && !channel.isShutdown()) {
126160
channel.shutdown();

sdk-ai/src/main/java/io/dapr/ai/client/DaprConversationResponse.java

-9
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,6 @@ public class DaprConversationResponse {
1212

1313
private final List<DaprConversationOutput> daprConversationOutputs;
1414

15-
/**
16-
* Constructor.
17-
*
18-
* @param daprConversationOutputs outputs from the LLM.
19-
*/
20-
public DaprConversationResponse(List<DaprConversationOutput> daprConversationOutputs) {
21-
this.daprConversationOutputs = daprConversationOutputs;
22-
}
23-
2415
/**
2516
* Constructor.
2617
*

sdk-ai/src/test/java/io/dapr/ai/AITest.java

-14
This file was deleted.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
package io.dapr.ai.client;
2+
3+
import io.dapr.client.resiliency.ResiliencyOptions;
4+
import io.dapr.config.Properties;
5+
import io.dapr.v1.DaprGrpc;
6+
import io.dapr.v1.DaprProtos;
7+
import io.grpc.ManagedChannel;
8+
import io.grpc.stub.StreamObserver;
9+
import org.junit.Assert;
10+
import org.junit.Before;
11+
import org.junit.Test;
12+
import org.mockito.ArgumentCaptor;
13+
import org.mockito.Mockito;
14+
15+
import java.util.ArrayList;
16+
import java.util.List;
17+
18+
import static org.mockito.Mockito.*;
19+
20+
public class DaprConversationClientTest {
21+
22+
private DaprGrpc.DaprStub daprStub;
23+
24+
@Before
25+
public void initialize() {
26+
27+
ManagedChannel channel = mock(ManagedChannel.class);
28+
daprStub = mock(DaprGrpc.DaprStub.class);
29+
when(daprStub.getChannel()).thenReturn(channel);
30+
when(daprStub.withInterceptors(Mockito.any(), Mockito.any())).thenReturn(daprStub);
31+
}
32+
33+
@Test
34+
public void converseShouldThrowIllegalArgumentExceptionWhenComponentNameIsNull() throws Exception {
35+
try (DaprConversationClient daprConversationClient = new DaprConversationClient()) {
36+
List<DaprConversationInput> daprConversationInputs = new ArrayList<>();
37+
daprConversationInputs.add(new DaprConversationInput("Hello there !"));
38+
39+
IllegalArgumentException exception =
40+
Assert.assertThrows(IllegalArgumentException.class, () ->
41+
daprConversationClient.converse(null, daprConversationInputs).block());
42+
Assert.assertEquals("Conversation component name cannot be null or empty.", exception.getMessage());
43+
}
44+
}
45+
46+
@Test
47+
public void converseShouldThrowIllegalArgumentExceptionWhenConversationComponentIsEmpty() throws Exception {
48+
try (DaprConversationClient daprConversationClient = new DaprConversationClient()) {
49+
List<DaprConversationInput> daprConversationInputs = new ArrayList<>();
50+
daprConversationInputs.add(new DaprConversationInput("Hello there !"));
51+
52+
IllegalArgumentException exception =
53+
Assert.assertThrows(IllegalArgumentException.class, () ->
54+
daprConversationClient.converse("", daprConversationInputs).block());
55+
Assert.assertEquals("Conversation component name cannot be null or empty.", exception.getMessage());
56+
}
57+
}
58+
59+
@Test
60+
public void converseShouldThrowIllegalArgumentExceptionWhenConversationInputIsEmpty() throws Exception {
61+
try (DaprConversationClient daprConversationClient = new DaprConversationClient()) {
62+
List<DaprConversationInput> daprConversationInputs = new ArrayList<>();
63+
64+
IllegalArgumentException exception =
65+
Assert.assertThrows(IllegalArgumentException.class, () ->
66+
daprConversationClient.converse("openai", daprConversationInputs).block());
67+
Assert.assertEquals("Conversation inputs cannot be null or empty.", exception.getMessage());
68+
}
69+
}
70+
71+
@Test
72+
public void converseShouldThrowIllegalArgumentExceptionWhenConversationInputIsNull() throws Exception {
73+
try (DaprConversationClient daprConversationClient =
74+
new DaprConversationClient(new Properties(), null)) {
75+
76+
IllegalArgumentException exception =
77+
Assert.assertThrows(IllegalArgumentException.class, () ->
78+
daprConversationClient.converse("openai", null).block());
79+
Assert.assertEquals("Conversation inputs cannot be null or empty.", exception.getMessage());
80+
}
81+
}
82+
83+
@Test
84+
public void converseShouldThrowIllegalArgumentExceptionWhenConversationInputContentIsNull() throws Exception {
85+
try (DaprConversationClient daprConversationClient = new DaprConversationClient()) {
86+
List<DaprConversationInput> daprConversationInputs = new ArrayList<>();
87+
daprConversationInputs.add(new DaprConversationInput(null));
88+
89+
IllegalArgumentException exception =
90+
Assert.assertThrows(IllegalArgumentException.class, () ->
91+
daprConversationClient.converse("openai", daprConversationInputs).block());
92+
Assert.assertEquals("Conversation input content cannot be null or empty.", exception.getMessage());
93+
}
94+
}
95+
96+
@Test
97+
public void converseShouldThrowIllegalArgumentExceptionWhenConversationInputContentIsEmpty() throws Exception {
98+
try (DaprConversationClient daprConversationClient = new DaprConversationClient()) {
99+
List<DaprConversationInput> daprConversationInputs = new ArrayList<>();
100+
daprConversationInputs.add(new DaprConversationInput(""));
101+
102+
IllegalArgumentException exception =
103+
Assert.assertThrows(IllegalArgumentException.class, () ->
104+
daprConversationClient.converse("openai", daprConversationInputs).block());
105+
Assert.assertEquals("Conversation input content cannot be null or empty.", exception.getMessage());
106+
}
107+
}
108+
109+
@Test
110+
public void converseShouldReturnConversationResponseWhenRequiredInputsAreValid() throws Exception {
111+
DaprProtos.ConversationResponse conversationResponse = DaprProtos.ConversationResponse.newBuilder()
112+
.addOutputs(DaprProtos.ConversationResult.newBuilder().setResult("Hello How are you").build()).build();
113+
114+
doAnswer(invocation -> {
115+
StreamObserver<DaprProtos.ConversationResponse> observer = invocation.getArgument(1);
116+
observer.onNext(conversationResponse);
117+
observer.onCompleted();
118+
return null;
119+
}).when(daprStub).converseAlpha1(any(DaprProtos.ConversationRequest.class), any());
120+
121+
try (DaprConversationClient daprConversationClient =
122+
new DaprConversationClient(daprStub, new ResiliencyOptions())) {
123+
List<DaprConversationInput> daprConversationInputs = new ArrayList<>();
124+
daprConversationInputs.add(new DaprConversationInput("Hello there"));
125+
126+
DaprConversationResponse daprConversationResponse =
127+
daprConversationClient.converse("openai", daprConversationInputs).block();
128+
129+
ArgumentCaptor<DaprProtos.ConversationRequest> captor =
130+
ArgumentCaptor.forClass(DaprProtos.ConversationRequest.class);
131+
verify(daprStub, times(1)).converseAlpha1(captor.capture(), Mockito.any());
132+
133+
DaprProtos.ConversationRequest conversationRequest = captor.getValue();
134+
135+
Assert.assertEquals("openai", conversationRequest.getName());
136+
Assert.assertEquals("Hello there", conversationRequest.getInputs(0).getContent());
137+
Assert.assertEquals("Hello How are you",
138+
daprConversationResponse.getDaprConversationOutputs().get(0).getResult());
139+
}
140+
}
141+
142+
@Test
143+
public void converseShouldReturnConversationResponseWhenRequiredAndOptionalInputsAreValid() throws Exception {
144+
DaprProtos.ConversationResponse conversationResponse = DaprProtos.ConversationResponse.newBuilder()
145+
.setContextID("contextId")
146+
.addOutputs(DaprProtos.ConversationResult.newBuilder().setResult("Hello How are you").build()).build();
147+
148+
doAnswer(invocation -> {
149+
StreamObserver<DaprProtos.ConversationResponse> observer = invocation.getArgument(1);
150+
observer.onNext(conversationResponse);
151+
observer.onCompleted();
152+
return null;
153+
}).when(daprStub).converseAlpha1(any(DaprProtos.ConversationRequest.class), any());
154+
155+
try (DaprConversationClient daprConversationClient = new DaprConversationClient(daprStub, null)) {
156+
DaprConversationInput daprConversationInput = new DaprConversationInput("Hello there")
157+
.setRole(DaprConversationRole.ASSISSTANT)
158+
.setScrubPii(true);
159+
160+
List<DaprConversationInput> daprConversationInputs = new ArrayList<>();
161+
daprConversationInputs.add(daprConversationInput);
162+
163+
DaprConversationResponse daprConversationResponse =
164+
daprConversationClient.converse("openai", daprConversationInputs,
165+
"contextId", true, 1.1d).block();
166+
167+
ArgumentCaptor<DaprProtos.ConversationRequest> captor =
168+
ArgumentCaptor.forClass(DaprProtos.ConversationRequest.class);
169+
verify(daprStub, times(1)).converseAlpha1(captor.capture(), Mockito.any());
170+
171+
DaprProtos.ConversationRequest conversationRequest = captor.getValue();
172+
173+
Assert.assertEquals("openai", conversationRequest.getName());
174+
Assert.assertEquals("contextId", conversationRequest.getContextID());
175+
Assert.assertTrue(conversationRequest.getScrubPII());
176+
Assert.assertEquals(1.1d, conversationRequest.getTemperature(), 0d);
177+
Assert.assertEquals("Hello there", conversationRequest.getInputs(0).getContent());
178+
Assert.assertTrue(conversationRequest.getInputs(0).getScrubPII());
179+
Assert.assertEquals(DaprConversationRole.ASSISSTANT.toString(), conversationRequest.getInputs(0).getRole());
180+
Assert.assertEquals("contextId", daprConversationResponse.getContextId());
181+
Assert.assertEquals("Hello How are you",
182+
daprConversationResponse.getDaprConversationOutputs().get(0).getResult());
183+
}
184+
}
185+
}

0 commit comments

Comments
 (0)