Skip to content

Commit

Permalink
Implement append_to_messages
Browse files Browse the repository at this point in the history
  • Loading branch information
marcus-daily committed Jan 3, 2025
1 parent 68b6787 commit 05acdbc
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 16 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ The following RTVI transports are available in this repository:
Add the following dependency to your `build.gradle` file:

```
implementation "ai.pipecat:daily-transport:0.3.1"
implementation "ai.pipecat:daily-transport:0.3.2"
```

Instantiate from your code:
Expand Down Expand Up @@ -46,7 +46,7 @@ using Kotlin Coroutines (`client.start().await()`).
Add the following dependency to your `build.gradle` file:

```
implementation "ai.pipecat:gemini-live-websocket-transport:0.3.1"
implementation "ai.pipecat:gemini-live-websocket-transport:0.3.2"
```

Instantiate from your code:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,13 @@ private data class InlineData(
}
}

/**
 * A single text message to append to the conversation.
 *
 * @property role role sent as the turn's `role` (the initial message uses `"user"`).
 * @property content text sent as the turn's only text part.
 */
private data class AppendedMessage(
    val role: String,
    val content: String,
)

internal class GeminiClient private constructor(
private val onSendUserMessage: (String) -> Unit,
private val onSendUserMessage: (AppendedMessage) -> Unit,
private val onClose: () -> Unit,
private val setMicMuted: (Boolean) -> Unit,
private val isMicMuted: () -> Boolean,
Expand All @@ -118,13 +123,14 @@ internal class GeminiClient private constructor(

private sealed interface ClientThreadEvent {
class SendAudioData(val buf: ByteArray) : ClientThreadEvent
class SendUserMessage(val text: String) : ClientThreadEvent
class SendUserMessage(val msg: AppendedMessage) : ClientThreadEvent
data object Stop : ClientThreadEvent
data object WebsocketClosed : ClientThreadEvent
class WebsocketFailed(val t: Throwable) : ClientThreadEvent
class WebsocketMessage(val msg: IncomingMessage, val originalText: String) :
ClientThreadEvent
class SetMicMute(val muted: Boolean): ClientThreadEvent

class SetMicMute(val muted: Boolean) : ClientThreadEvent
}

companion object {
Expand Down Expand Up @@ -225,16 +231,16 @@ internal class GeminiClient private constructor(
}
})

fun doSendUserMessage(text: String) {
fun doSendUserMessage(msg: AppendedMessage) {
ws.send(
JSON.encodeToString(
ClientContentRequest.serializer(),
ClientContentRequest(
clientContent = ClientContentRequest.ClientContent(
turns = listOf(
ClientContentRequest.ClientContent.Turn(
role = "user",
parts = listOf(TurnPart(text = text))
role = msg.role,
parts = listOf(TurnPart(text = msg.content))
)
),
turnComplete = true
Expand Down Expand Up @@ -266,7 +272,7 @@ internal class GeminiClient private constructor(
}

is ClientThreadEvent.SendUserMessage -> {
doSendUserMessage(event.text)
doSendUserMessage(event.msg)
}

ClientThreadEvent.Stop -> {
Expand All @@ -291,7 +297,12 @@ internal class GeminiClient private constructor(
Log.i(TAG, "Setup complete")

if (config.initialMessage != null) {
doSendUserMessage(config.initialMessage)
doSendUserMessage(
AppendedMessage(
content = config.initialMessage,
role = "user"
)
)
}

getListener()?.onConnected()
Expand Down Expand Up @@ -370,7 +381,8 @@ internal class GeminiClient private constructor(
setMicMuted(muted)
}

fun sendUserMessage(text: String) = onSendUserMessage(text)
/**
 * Wraps [role] and [content] in an [AppendedMessage] and hands it to the
 * `onSendUserMessage` callback supplied at construction.
 */
fun sendUserMessage(role: String, content: String) {
    val appended = AppendedMessage(role = role, content = content)
    onSendUserMessage(appended)
}

/** Closes this client by invoking the `onClose` callback supplied at construction. */
fun close() {
    onClose()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import ai.pipecat.client.result.resolvedPromiseOk
import ai.pipecat.client.result.withPromise
import ai.pipecat.client.transport.AuthBundle
import ai.pipecat.client.transport.MsgClientToServer
import ai.pipecat.client.transport.MsgServerToClient
import ai.pipecat.client.transport.Transport
import ai.pipecat.client.transport.TransportContext
import ai.pipecat.client.transport.TransportFactory
Expand All @@ -24,8 +25,13 @@ import android.annotation.SuppressLint
import android.content.Context
import android.media.AudioManager
import android.util.Log
import kotlinx.serialization.json.Json
import kotlinx.serialization.json.decodeFromJsonElement
import kotlinx.serialization.json.encodeToJsonElement


// File-wide JSON codec; decoding tolerates keys that the target type does not declare.
private val JSON = Json {
    ignoreUnknownKeys = true
}

class GeminiLiveWebsocketTransport(
private val transportContext: TransportContext,
androidContext: Context
Expand All @@ -50,12 +56,15 @@ class GeminiLiveWebsocketTransport(
ServiceConfig(
SERVICE_LLM, listOf(
Option(OPTION_API_KEY, apiKey),
Option(OPTION_INITIAL_USER_MESSAGE, initialUserMessage?.let { Value.Str(it) } ?: Value.Null),
Option(
OPTION_INITIAL_USER_MESSAGE,
initialUserMessage?.let { Value.Str(it) } ?: Value.Null),
Option(
OPTION_MODEL_CONFIG, Value.Object(
"model" to Value.Str(model),
"generation_config" to generationConfig,
"system_instruction" to (systemInstruction?.let { Value.Str(it) } ?: Value.Null),
"system_instruction" to (systemInstruction?.let { Value.Str(it) }
?: Value.Null),
"tools" to tools,
)
)
Expand Down Expand Up @@ -114,7 +123,8 @@ class GeminiLiveWebsocketTransport(

val apiKey = (options?.getValueFor(OPTION_API_KEY) as? Value.Str)?.value
val modelConfig = options?.getValueFor(OPTION_MODEL_CONFIG)
val initialUserMessage = (options?.getValueFor(OPTION_INITIAL_USER_MESSAGE) as? Value.Str)?.value
val initialUserMessage =
(options?.getValueFor(OPTION_INITIAL_USER_MESSAGE) as? Value.Str)?.value

if (apiKey == null) {
return@chain resolvedPromiseErr(
Expand Down Expand Up @@ -143,6 +153,8 @@ class GeminiLiveWebsocketTransport(
thread.runOnThread {
setState(TransportState.Connected)
transportContext.callbacks.onConnected()
setState(TransportState.Ready)
transportContext.callbacks.onBotReady("local", emptyList())
promise.resolveOk(Unit)
}
}
Expand All @@ -167,7 +179,61 @@ class GeminiLiveWebsocketTransport(
resolvedPromiseOk(thread, Unit)
}

override fun sendMessage(message: MsgClientToServer) = operationNotSupported<Unit>()
/**
 * Handles a client-to-server message locally.
 *
 * Only messages of type `"action"` carrying the `"append_to_messages"` action are
 * supported; everything else resolves to [operationNotSupported].
 *
 * For `append_to_messages`, the whole `messages` argument is validated first, so a
 * malformed entry never leaves a partially sent batch. Each message is then
 * forwarded to the active client, and an `ActionResponse` carrying [Value.Null] is
 * reported back through `transportContext.onMessage`.
 *
 * @param message the message to handle.
 * @return a future resolved with `Unit` on success, or with
 * [RTVIError.ExceptionThrown] when the payload is missing or malformed, or when
 * there is no active client to deliver the messages to.
 */
override fun sendMessage(message: MsgClientToServer): Future<Unit, RTVIError> {

    if (message.type != "action") {
        return operationNotSupported()
    }

    try {
        val payload = requireNotNull(message.data) { "Action message has no data" }
        val data = JSON.decodeFromJsonElement<MsgClientToServer.Data.Action>(payload)

        if (data.action != "append_to_messages") {
            return operationNotSupported()
        }

        // Validate everything before sending anything, so that a bad entry
        // cannot leave the conversation half-updated.
        val messages = parseAppendedMessages(data)

        // Without a client nothing would be delivered; report that instead of
        // acknowledging an action that never happened.
        val activeClient = checkNotNull(client) {
            "Cannot append messages: client is not connected"
        }

        for ((role, content) in messages) {
            Log.i(TAG, "Sending message as $role: '$content'")
            activeClient.sendUserMessage(role = role, content = content)
        }

        transportContext.onMessage(
            MsgServerToClient(
                id = message.id,
                label = message.label,
                type = MsgServerToClient.Type.ActionResponse,
                data = JSON.encodeToJsonElement(
                    MsgServerToClient.Data.ActionResponse(Value.Null)
                )
            )
        )

        return resolvedPromiseOk(thread, Unit)

    } catch (e: Exception) {
        return resolvedPromiseErr(thread, RTVIError.ExceptionThrown(e))
    }
}

/**
 * Extracts the `(role, content)` pairs from the `messages` argument of an
 * `append_to_messages` action.
 *
 * @throws IllegalArgumentException if `messages` is absent or not an array, or if
 * any entry is not an object with string `role` and `content` fields.
 */
private fun parseAppendedMessages(
    data: MsgClientToServer.Data.Action
): List<Pair<String, String>> {
    val messagesArg = data.arguments.getValueFor("messages") as? Value.Array
        ?: throw IllegalArgumentException(
            "append_to_messages requires a 'messages' array argument"
        )

    return messagesArg.value.mapIndexed { index, entry ->
        val fields = (entry as? Value.Object)?.value
            ?: throw IllegalArgumentException("messages[$index] is not an object")
        val role = (fields["role"] as? Value.Str)?.value
            ?: throw IllegalArgumentException("messages[$index] has no string 'role'")
        val content = (fields["content"] as? Value.Str)?.value
            ?: throw IllegalArgumentException("messages[$index] has no string 'content'")
        role to content
    }
}

override fun state(): TransportState {
return state
Expand Down Expand Up @@ -196,7 +262,7 @@ class GeminiLiveWebsocketTransport(

override fun updateCam(camId: MediaDeviceId) = operationNotSupported<Unit>()

override fun selectedMic(): MediaDeviceInfo? {
override fun selectedMic(): MediaDeviceInfo {
val audioManager = appContext.getSystemService(Context.AUDIO_SERVICE) as AudioManager

return when (audioManager.isSpeakerphoneOn) {
Expand Down

0 comments on commit 05acdbc

Please sign in to comment.