Skip to content

Commit 7998bc7

Browse files
async conversation api support
Signed-off-by: Elena Kolevska <[email protected]>
1 parent 32099bb commit 7998bc7

File tree

3 files changed

+118
-3
lines changed

3 files changed

+118
-3
lines changed

dapr/aio/clients/grpc/client.py

+63
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929

3030
from google.protobuf.message import Message as GrpcMessage
3131
from google.protobuf.empty_pb2 import Empty as GrpcEmpty
32+
from google.protobuf.any_pb2 import Any as GrpcAny
3233

3334
import grpc.aio # type: ignore
3435
from grpc.aio import ( # type: ignore
@@ -75,9 +76,12 @@
7576
InvokeMethodRequest,
7677
BindingRequest,
7778
TransactionalStateOperation,
79+
ConversationInput,
7880
)
7981
from dapr.clients.grpc._response import (
8082
BindingResponse,
83+
ConversationResponse,
84+
ConversationResult,
8185
DaprResponse,
8286
GetSecretResponse,
8387
GetBulkSecretResponse,
@@ -1711,6 +1715,65 @@ async def purge_workflow(self, instance_id: str, workflow_component: str) -> Dap
17111715
except grpc.aio.AioRpcError as err:
17121716
raise DaprInternalError(err.details())
17131717

1718+
async def converse_alpha1(
1719+
self,
1720+
name: str,
1721+
inputs: List[ConversationInput],
1722+
*,
1723+
# Force remaining args to be keyword-only
1724+
context_id: Optional[str] = None,
1725+
parameters: Optional[Dict[str, GrpcAny]] = None,
1726+
metadata: Optional[Dict[str, str]] = None,
1727+
scrub_pii: Optional[bool] = None,
1728+
temperature: Optional[float] = None,
1729+
) -> ConversationResponse:
1730+
"""Invoke an LLM using the conversation API (Alpha).
1731+
1732+
Args:
1733+
name: Name of the LLM component to invoke
1734+
inputs: List of conversation inputs
1735+
context_id: Optional ID for continuing an existing chat
1736+
parameters: Optional custom parameters for the request
1737+
metadata: Optional metadata for the component
1738+
scrub_pii: Optional flag to scrub PII from outputs
1739+
temperature: Optional temperature setting (0.0 to 1.0) where
1740+
lower values give more deterministic responses and
1741+
higher values enable more creative responses
1742+
1743+
Returns:
1744+
ConversationResponse containing the conversation results
1745+
1746+
Raises:
1747+
DaprInternalError: If the Dapr runtime returns an error
1748+
"""
1749+
inputs_pb = [
1750+
api_v1.ConversationInput(message=inp.message, role=inp.role, scrubPII=inp.scrub_pii)
1751+
for inp in inputs
1752+
]
1753+
1754+
request = api_v1.ConversationRequest(
1755+
name=name,
1756+
inputs=inputs_pb,
1757+
contextID=context_id,
1758+
parameters=parameters or {},
1759+
metadata=metadata or {},
1760+
scrubPII=scrub_pii,
1761+
temperature=temperature,
1762+
)
1763+
1764+
try:
1765+
response = await self._stub.ConverseAlpha1(request)
1766+
1767+
outputs = [
1768+
ConversationResult(result=output.result, parameters=output.parameters)
1769+
for output in response.outputs
1770+
]
1771+
1772+
return ConversationResponse(context_id=response.contextID, outputs=outputs)
1773+
1774+
except Exception as ex:
1775+
raise DaprInternalError(f'Error invoking conversation API: {str(ex)}')
1776+
17141777
async def wait(self, timeout_s: float):
17151778
"""Waits for sidecar to be available within the timeout.
17161779

examples/conversation/conversation.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,4 @@
3030
)
3131

3232
for output in response.outputs:
33-
print(f'Result: {output.result}')
33+
print(f'Result: {output.result}')

tests/clients/test_dapr_grpc_client_async.py

+54-2
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,13 @@
2323

2424
from dapr.aio.clients.grpc.client import DaprGrpcClientAsync
2525
from dapr.aio.clients import DaprClient
26-
from dapr.clients.exceptions import DaprGrpcError
26+
from dapr.clients.exceptions import DaprGrpcError, DaprInternalError
2727
from dapr.common.pubsub.subscription import StreamInactiveError
2828
from dapr.proto import common_v1
2929
from .fake_dapr_server import FakeDaprSidecar
3030
from dapr.conf import settings
3131
from dapr.clients.grpc._helpers import to_bytes
32-
from dapr.clients.grpc._request import TransactionalStateOperation
32+
from dapr.clients.grpc._request import TransactionalStateOperation, ConversationInput
3333
from dapr.clients.grpc._state import StateOptions, Consistency, Concurrency, StateItem
3434
from dapr.clients.grpc._crypto import EncryptOptions, DecryptOptions
3535
from dapr.clients.grpc._response import (
@@ -1113,6 +1113,58 @@ async def test_decrypt_file_data_read_chunks(self):
11131113
self.assertEqual(await resp.read(5), b'hello')
11141114
self.assertEqual(await resp.read(5), b' dapr')
11151115

1116+
async def test_converse_alpha1_basic(self):
1117+
dapr = DaprGrpcClientAsync(f'{self.scheme}localhost:{self.grpc_port}')
1118+
1119+
inputs = [ConversationInput(message="Hello", role="user"),
1120+
ConversationInput(message="How are you?", role="user")]
1121+
1122+
response = await dapr.converse_alpha1(name="test-llm", inputs=inputs)
1123+
1124+
# Check response structure
1125+
self.assertIsNotNone(response)
1126+
self.assertEqual(len(response.outputs), 2)
1127+
self.assertEqual(response.outputs[0].result, "Response to: Hello")
1128+
self.assertEqual(response.outputs[1].result, "Response to: How are you?")
1129+
await dapr.close()
1130+
1131+
async def test_converse_alpha1_with_options(self):
1132+
dapr = DaprGrpcClientAsync(f'{self.scheme}localhost:{self.grpc_port}')
1133+
1134+
inputs = [ConversationInput(message="Hello", role="user", scrub_pii=True)]
1135+
1136+
response = await dapr.converse_alpha1(name="test-llm", inputs=inputs, context_id="chat-123",
1137+
temperature=0.7, scrub_pii=True, metadata={"key": "value"})
1138+
1139+
self.assertIsNotNone(response)
1140+
self.assertEqual(len(response.outputs), 1)
1141+
self.assertEqual(response.outputs[0].result, "Response to: Hello")
1142+
await dapr.close()
1143+
1144+
async def test_converse_alpha1_error_handling(self):
1145+
dapr = DaprGrpcClientAsync(f'{self.scheme}localhost:{self.grpc_port}')
1146+
1147+
# Setup server to raise an exception
1148+
self._fake_dapr_server.raise_exception_on_next_call(
1149+
status_pb2.Status(code=code_pb2.INVALID_ARGUMENT, message="Invalid argument"))
1150+
1151+
inputs = [ConversationInput(message="Hello", role="user")]
1152+
1153+
with self.assertRaises(DaprInternalError) as context:
1154+
await dapr.converse_alpha1(name="test-llm", inputs=inputs)
1155+
self.assertTrue("Invalid argument" in str(context.exception))
1156+
await dapr.close()
1157+
1158+
async def test_converse_alpha1_empty_inputs(self):
1159+
dapr = DaprGrpcClientAsync(f'{self.scheme}localhost:{self.grpc_port}')
1160+
1161+
# Test with empty inputs list
1162+
response = await dapr.converse_alpha1(name="test-llm", inputs=[])
1163+
1164+
self.assertIsNotNone(response)
1165+
self.assertEqual(len(response.outputs), 0)
1166+
await dapr.close()
1167+
11161168

11171169
if __name__ == '__main__':
11181170
unittest.main()

0 commit comments

Comments
 (0)