Skip to content

Commit

Permalink
Fix server-cli communication and reduce overhead
Browse files Browse the repository at this point in the history
  • Loading branch information
toasterofbread committed Feb 20, 2024
1 parent 2a5919a commit 2bf5ccd
Show file tree
Hide file tree
Showing 8 changed files with 138 additions and 97 deletions.
2 changes: 1 addition & 1 deletion src/commonMain/kotlin/cinterop/zmq/ZmqSocket.kt
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ class ZmqSocket(mem_scope: MemScope, type: Int, val is_binder: Boolean) {
}

fun sendStringMultipart(parts: List<String>) =
sendMultipart(parts.map { it.cstr })
sendMultipart(SpMsSocketApi.encode(parts).map { it.cstr })

fun sendMultipart(parts: List<CValues<ByteVar>>) = memScoped {
if (parts.isEmpty()) {
Expand Down
12 changes: 3 additions & 9 deletions src/commonMain/kotlin/spms/Command.kt
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,6 @@ abstract class Command(
SpMs.log(message)
}

protected fun getVersionInfoText(): String =
localisation.versionInfoText(SPMS_API_VERSION)

override fun commandHelpEpilog(context: Context): String = context.loc.cli.bug_report_notice
override fun commandHelp(context: Context): String = help?.invoke(context.loc).orEmpty()

Expand All @@ -54,12 +51,9 @@ abstract class Command(
localization = localisation
}

if (output_version || !silent) {
println(getVersionInfoText())

if (output_version) {
exitProcess(0)
}
if (output_version) {
SpMs.printVersionInfo(localisation)
exitProcess(0)
}
}

Expand Down
14 changes: 8 additions & 6 deletions src/commonMain/kotlin/spms/client/cli/CommandLineClientMode.kt
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ abstract class CommandLineClientMode(
return context.socket
}

fun connectSocket() {
fun connectSocket(type: SpMsClientType = SpMsClientType.COMMAND_LINE_ACTION, handshake_actions: List<String>? = null): List<String> {
check(!socket_connected)

try {
Expand All @@ -41,9 +41,10 @@ abstract class CommandLineClientMode(

val handshake: SpMsClientHandshake = SpMsClientHandshake(
name = context.client_name,
type = SpMsClientType.COMMAND_LINE,
type = type,
machine_id = SpMs.getMachineId(),
language = currentContext.loc.language.name
language = currentContext.loc.language.name,
actions = handshake_actions
)
context.socket.sendStringMultipart(listOf(Json.encodeToString(handshake)))

Expand All @@ -53,15 +54,16 @@ abstract class CommandLineClientMode(
throw SpMsCommandLineClientError(currentContext.loc.cli.errServerDidNotRespond(SERVER_REPLY_TIMEOUT_MS))
}

log(currentContext.loc.cli.handshake_reply_received)
log(currentContext.loc.cli.handshake_reply_received + " " + reply.toString())
socket_connected = true

return reply
}
catch (e: Throwable) {
log(currentContext.loc.cli.releasing_socket)
context.release()
throw e
}

socket_connected = true
}

fun releaseSocket() {
Expand Down
86 changes: 49 additions & 37 deletions src/commonMain/kotlin/spms/client/cli/modes/Run.kt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import com.github.ajalt.clikt.parameters.options.option
import com.github.ajalt.clikt.parameters.types.float
import com.github.ajalt.clikt.parameters.types.int
import kotlinx.serialization.encodeToString
import kotlinx.serialization.serializer
import kotlinx.serialization.json.Json
import kotlinx.serialization.json.JsonPrimitive
import libzmq.ZMQ_NOBLOCK
Expand All @@ -24,6 +25,7 @@ import spms.socketapi.shared.SpMsSocketApi
import spms.client.cli.CommandLineClientMode
import spms.client.cli.SERVER_REPLY_TIMEOUT_MS
import spms.localisation.loc
import spms.socketapi.shared.SpMsServerHandshake
import spms.socketapi.shared.SpMsActionReply
import spms.socketapi.shared.SPMS_EXPECT_REPLY_CHAR
import spms.socketapi.player.PlayerAction
Expand Down Expand Up @@ -115,9 +117,13 @@ class ActionCommandLineClientMode(
}
}

connectSocket()
val reply: SpMsActionReply? = executeActionOnSocket(
action,
parameter_values,
SERVER_REPLY_TIMEOUT_MS,
currentContext, silent = silent
)

val reply: SpMsActionReply? = action.executeOnSocket(socket, parameter_values, SERVER_REPLY_TIMEOUT_MS, currentContext, silent = silent)
if (reply == null) {
throw CliktError(currentContext.loc.server_actions.replyNotReceived(SERVER_REPLY_TIMEOUT_MS).toRed())
}
Expand All @@ -141,45 +147,51 @@ class ActionCommandLineClientMode(

releaseSocket()
}
}

@OptIn(ExperimentalForeignApi::class)
private fun Action.executeOnSocket(
socket: ZmqSocket,
parameter_values: List<JsonPrimitive>,
reply_timeout_ms: Long?,
context: Context,
silent: Boolean = false
): SpMsActionReply? {
socket.recvMultipart(reply_timeout_ms) ?: return null

if (!silent) {
println(context.loc.server_actions.sendingActionToServer(identifier))
}
@OptIn(ExperimentalForeignApi::class)
private fun executeActionOnSocket(
action: Action,
parameter_values: List<JsonPrimitive>,
reply_timeout_ms: Long?,
context: Context,
silent: Boolean = false
): SpMsActionReply? {
if (!silent) {
println(context.loc.server_actions.sendingActionToServer(action.identifier))
}

socket.sendStringMultipart(
listOf(SPMS_EXPECT_REPLY_CHAR + identifier, Json.encodeToString(parameter_values))
)
val reply = connectSocket(
handshake_actions = listOf(SPMS_EXPECT_REPLY_CHAR + action.identifier, Json.encodeToString(parameter_values))
)

if (!silent) {
println(context.loc.server_actions.actionSentAndWaitingForReply(identifier))
}
if (!silent) {
println(context.loc.server_actions.actionSentAndWaitingForReply(action.identifier))
}

val timeout_end: Long? = reply_timeout_ms?.let { getTimeMillis() + it }
do {
val reply: List<String>? = socket.recvStringMultipart(timeout_end?.let { (it - getTimeMillis()).coerceAtLeast(ZMQ_NOBLOCK.toLong()) })
if (!reply.isNullOrEmpty()) {
if (!silent) {
println(context.loc.server_actions.receivedReplyFromServer(identifier))
}
val timeout_end: Long? = reply_timeout_ms?.let { getTimeMillis() + it }
do {
if (!reply.isNullOrEmpty()) {
if (!silent) {
println(context.loc.server_actions.receivedReplyFromServer(action.identifier))
}

val decoded: String = SpMsSocketApi.decode(reply).first()
return Json.decodeFromString<SpMsActionReply>(decoded)
}
else if (!silent) {
println(context.loc.server_actions.receivedEmptyReplyFromServer(identifier))
}
} while (timeout_end == null || getTimeMillis() < timeout_end)
val decoded: String = SpMsSocketApi.decode(reply).first()
try {
val parsed_reply: SpMsServerHandshake = Json.decodeFromString(decoded)
check(!parsed_reply.action_replies.isNullOrEmpty()) {
"Got empty reply from server"
}
return parsed_reply.action_replies.first()
}
catch (e: Throwable) {
throw RuntimeException("JSON decoding server reply failed $decoded", e)
}
}
else if (!silent) {
println(context.loc.server_actions.receivedEmptyReplyFromServer(action.identifier))
}
} while (timeout_end == null || getTimeMillis() < timeout_end)

return null
return null
}
}
98 changes: 63 additions & 35 deletions src/commonMain/kotlin/spms/server/SpMs.kt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import spms.socketapi.parseSocketMessage
import spms.socketapi.player.PlayerAction
import spms.socketapi.server.ServerAction
import spms.socketapi.shared.*
import spms.localisation.SpMsLocalisation
import kotlin.experimental.ExperimentalNativeApi
import kotlin.system.exitProcess
import kotlin.system.getTimeMillis
Expand Down Expand Up @@ -142,31 +143,7 @@ class SpMs(mem_scope: MemScope, val headless: Boolean = false, enable_gui: Boole
executing_client_id = client.id

try {
val reply: List<SpMsActionReply> =
parseSocketMessage(
client_reply.parts,
{
RuntimeException("Parse exception while processing reply from $client", it).printStackTrace()
}
) { action_name, action_params ->
val server_action: ServerAction? = ServerAction.getByName(action_name)
if (server_action != null) {
return@parseSocketMessage server_action.execute(this, client.id, action_params)
}

if (!headless && player is MpvClientImpl) {
val player_action: PlayerAction? = PlayerAction.getByName(action_name)
if (player_action != null) {
return@parseSocketMessage player_action.execute(player, action_params)
}
}

throw NotImplementedError("Unknown action '$action_name'")
}

if (reply.isNotEmpty()) {
sendMultipart(client.createMessage(listOf(Json.encodeToString(reply))))
}
processClientMessage(client_reply, client)
}
catch (e: Throwable) {
RuntimeException("Exception while processing reply from $client", e).printStackTrace()
Expand All @@ -176,6 +153,39 @@ class SpMs(mem_scope: MemScope, val headless: Boolean = false, enable_gui: Boole
}
}

private fun processClientActions(actions: List<String>, client: SpMsClient): List<SpMsActionReply> {
return parseSocketMessage(
actions,
{
RuntimeException("Parse exception while processing message from $client", it).printStackTrace()
}
) { action_name, action_params ->
val server_action: ServerAction? = ServerAction.getByName(action_name)
if (server_action != null) {
println("Performing server action $action_name with $action_params from $client")
return@parseSocketMessage server_action.execute(this, client.id, action_params)
}

if (!headless && player is MpvClientImpl) {
val player_action: PlayerAction? = PlayerAction.getByName(action_name)
if (player_action != null) {
println("Performing player action $action_name with $action_params from $client")
return@parseSocketMessage player_action.execute(player, action_params)
}
}

throw NotImplementedError("Unknown action '$action_name' from $client")
}
}

private fun processClientMessage(message: Message, client: SpMsClient) {
val reply: List<SpMsActionReply> = processClientActions(message.parts, client)

if (reply.isNotEmpty()) {
sendMultipart(client.createMessage(listOf(Json.encodeToString(reply))))
}
}

fun onClientReadyToPlay(client_id: SpMsClientID, item_index: Int, item_id: String, item_duration_ms: Long) {
if (!playback_waiting_for_clients) {
return
Expand Down Expand Up @@ -234,7 +244,7 @@ class SpMs(mem_scope: MemScope, val headless: Boolean = false, enable_gui: Boole
event.init(
event_id = player_event_inc++,
client_id = if (clientless) null else executing_client_id,
client_amount = clients.size
client_amount = clients.count { it.type.receivesEvents() }
)
player_events.add(event)
}
Expand All @@ -254,25 +264,24 @@ class SpMs(mem_scope: MemScope, val headless: Boolean = false, enable_gui: Boole
return numbered_name
}

private fun onClientMessage(handshake_message: Message) {
val id: Int = handshake_message.client_id.contentHashCode()
private fun onClientMessage(message: Message) {
val id: Int = message.client_id.contentHashCode()

// Return if client is already added
if (clients.any { it.id == id }) {
return
}

val client_handshake: SpMsClientHandshake
try {
client_handshake = handshake_message.parts.firstOrNull()?.let { Json.decodeFromString(it) } ?: return
client_handshake = message.parts.firstOrNull()?.let { Json.decodeFromString(it) } ?: return
}
catch (e: SerializationException) {
RuntimeException("Exception while parsing the following handshake message: ${handshake_message.parts}", e).printStackTrace()
RuntimeException("Exception while parsing the following handshake message: ${message.parts}", e).printStackTrace()
return
}

val client: SpMsClient = SpMsClient(
handshake_message.client_id,
message.client_id,
SpMsClientInfo(
name = getNewClientName(client_handshake.name),
type = client_handshake.type,
Expand All @@ -283,15 +292,22 @@ class SpMs(mem_scope: MemScope, val headless: Boolean = false, enable_gui: Boole
player_event_inc
)

clients.add(client)
val action_replies: List<SpMsActionReply>?
if (client_handshake.actions != null) {
action_replies = processClientActions(client_handshake.actions, client)
}
else {
action_replies = null
}

val server_handshake: SpMsServerHandshake =
SpMsServerHandshake(
name = SpMs.application_name,
device_name = getDeviceName(),
spms_api_version = SPMS_API_VERSION,
server_state = player.getCurrentStateJson(),
machine_id = SpMs.getMachineId()
machine_id = SpMs.getMachineId(),
action_replies = action_replies
)

sendMultipart(
Expand All @@ -302,7 +318,10 @@ class SpMs(mem_scope: MemScope, val headless: Boolean = false, enable_gui: Boole
)
println("Sent connection reply to $client: $server_handshake")

onClientConnected(client)
if (client.type.receivesEvents()) {
clients.add(client)
onClientConnected(client)
}
}

private fun onClientConnected(client: SpMsClient) {
Expand Down Expand Up @@ -350,6 +369,15 @@ class SpMs(mem_scope: MemScope, val headless: Boolean = false, enable_gui: Boole
}
}

private var version_printed: Boolean = false
fun printVersionInfo(localisation: SpMsLocalisation) {
if (version_printed) {
return
}
println(localisation.versionInfoText(SPMS_API_VERSION))
version_printed = true
}

fun getMachineId(): String {
val id_path: Path =
when (Platform.osFamily) {
Expand Down
3 changes: 1 addition & 2 deletions src/commonMain/kotlin/spms/socketapi/ParseMessage.kt
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ fun parseSocketMessage(
val result: JsonElement?
try {
result = executeAction(action_name, action_params)
// result = Action.executeByName(this@SpMs, client.id, action_name, action_params)
}
catch (e: Throwable) {
val message: String = "Executing action $action_name(${action_params.map { it.toString() }}) failed"
Expand Down Expand Up @@ -65,6 +64,6 @@ fun parseSocketMessage(
)
}
}

return reply
}
Loading

0 comments on commit 2bf5ccd

Please sign in to comment.