diff --git a/src/kotlin/flwr/src/main/java/dev/flower/android/Serde.kt b/src/kotlin/flwr/src/main/java/dev/flower/android/Serde.kt index dc764295b18f..79038ddfe6ab 100644 --- a/src/kotlin/flwr/src/main/java/dev/flower/android/Serde.kt +++ b/src/kotlin/flwr/src/main/java/dev/flower/android/Serde.kt @@ -1,13 +1,13 @@ package dev.flower.android -import java.nio.ByteBuffer import com.google.protobuf.ByteString -import flwr.proto.Transport.ServerMessage import flwr.proto.Transport.ClientMessage -import flwr.proto.Transport.Parameters as ProtoParameters -import flwr.proto.Transport.Status as ProtoStatus import flwr.proto.Transport.Reason +import flwr.proto.Transport.ServerMessage +import java.nio.ByteBuffer +import flwr.proto.Transport.Parameters as ProtoParameters import flwr.proto.Transport.Scalar as ProtoScalar +import flwr.proto.Transport.Status as ProtoStatus internal fun parametersToProto(parameters: Parameters): ProtoParameters { val tensors: MutableList = ArrayList() @@ -22,11 +22,11 @@ internal fun parametersFromProto(msg: ProtoParameters): Parameters { } internal fun statusToProto(status: Status): ProtoStatus { - return ProtoStatus.newBuilder().setCodeValue(status.code).setMessage(status.message).build() + return ProtoStatus.newBuilder().setCodeValue(status.code.value).setMessage(status.message).build() } internal fun statusFromProto(msg: ProtoStatus): Status { - return Status(msg.codeValue, msg.message) + return Status(Code.fromInt(msg.codeValue), msg.message) } internal fun reconnectInsToProto(ins: ReconnectIns): ServerMessage.ReconnectIns { @@ -47,7 +47,7 @@ internal fun disconnectResToProto(res: DisconnectRes): ClientMessage.DisconnectR "WIFI_UNAVAILABLE" -> Reason.WIFI_UNAVAILABLE else -> Reason.UNKNOWN } - return ClientMessage.DisconnectRes.newBuilder().setReason(reason).build() + return ClientMessage.DisconnectRes.newBuilder().setReason(reason).build() } internal fun disconnectResFromProto(msg: ClientMessage.DisconnectRes): DisconnectRes { @@ -173,24 +173,23 @@ internal fun metricsFromProto(proto: Map): Metrics { return proto.mapValues { (_, value) -> scalarFromProto(value) } } -internal inline fun scalarToProto(scalar: Scalar): ProtoScalar { - return when (T::class) { - Scalar.BoolValue::class -> ProtoScalar.newBuilder().setBool(scalar.value as Boolean).build() - Scalar.BytesValue::class -> ProtoScalar.newBuilder().setBytes(scalar.value as ByteString).build() - Scalar.DoubleValue::class -> ProtoScalar.newBuilder().setDouble(scalar.value as Double).build() - Scalar.SInt64Value::class -> ProtoScalar.newBuilder().setSint64(scalar.value as Long).build() - Scalar.StringValue::class -> ProtoScalar.newBuilder().setString(scalar.value as String).build() - else -> throw IllegalArgumentException("Accepted Types : Bool, Data, Float, Int, Str") +internal fun scalarToProto(scalar: Scalar): ProtoScalar { + return when (scalar) { + is Scalar.BoolValue -> ProtoScalar.newBuilder().setBool(scalar.value).build() + is Scalar.BytesValue -> ProtoScalar.newBuilder().setBytes(ByteString.copyFrom(scalar.value)).build() + is Scalar.SInt64Value -> ProtoScalar.newBuilder().setSint64(scalar.value).build() + is Scalar.DoubleValue -> ProtoScalar.newBuilder().setDouble(scalar.value).build() + is Scalar.StringValue -> ProtoScalar.newBuilder().setString(scalar.value).build() } } -internal inline fun scalarFromProto(scalarMsg: ProtoScalar): Scalar { - return when (scalarMsg.scalarCase) { - ProtoScalar.ScalarCase.BOOL -> Scalar.BoolValue(scalarMsg.bool) - ProtoScalar.ScalarCase.BYTES -> Scalar.BytesValue(scalarMsg.bytes) - ProtoScalar.ScalarCase.DOUBLE -> Scalar.DoubleValue(scalarMsg.double) - ProtoScalar.ScalarCase.SINT64 -> Scalar.SInt64Value(scalarMsg.sint64) - ProtoScalar.ScalarCase.STRING -> Scalar.StringValue(scalarMsg.string) - else -> throw IllegalArgumentException("Accepted Types : Bool, Data, Float, Int, Str") - } as Scalar +internal fun scalarFromProto(scalarMsg: ProtoScalar): Scalar { + return when (scalarMsg.scalarCase) { + ProtoScalar.ScalarCase.BOOL -> Scalar.BoolValue(scalarMsg.bool) + ProtoScalar.ScalarCase.BYTES -> Scalar.BytesValue(ByteBuffer.wrap(scalarMsg.bytes.toByteArray())) + ProtoScalar.ScalarCase.DOUBLE -> Scalar.DoubleValue(scalarMsg.double) + ProtoScalar.ScalarCase.SINT64 -> Scalar.SInt64Value(scalarMsg.sint64) + ProtoScalar.ScalarCase.STRING -> Scalar.StringValue(scalarMsg.string) + else -> throw IllegalArgumentException("Accepted Types : Bool, Data, Float, Int, Str") + } } diff --git a/src/kotlin/flwr/src/main/java/dev/flower/android/TaskHandler.kt b/src/kotlin/flwr/src/main/java/dev/flower/android/TaskHandler.kt index 5c28c22d5378..6e2874d1e792 100644 --- a/src/kotlin/flwr/src/main/java/dev/flower/android/TaskHandler.kt +++ b/src/kotlin/flwr/src/main/java/dev/flower/android/TaskHandler.kt @@ -38,7 +38,7 @@ internal fun wrapClientMessageInTaskRes(clientMessage: ClientMessage): TaskRes { return TaskRes.newBuilder() .setTaskId("") .setGroupId("") - .setWorkloadId("") + .setWorkloadId(0) .setTask(Task.newBuilder().addAllAncestry(emptyList()).setLegacyClientMessage(clientMessage)) .build() } diff --git a/src/kotlin/flwr/src/main/java/dev/flower/android/Typing.kt b/src/kotlin/flwr/src/main/java/dev/flower/android/Typing.kt index 6aa912979573..a88af0e28974 100644 --- a/src/kotlin/flwr/src/main/java/dev/flower/android/Typing.kt +++ b/src/kotlin/flwr/src/main/java/dev/flower/android/Typing.kt @@ -1,50 +1,43 @@ package dev.flower.android -import androidx.annotation.IntDef -import com.google.protobuf.ByteString import java.nio.ByteBuffer /** * Represents a map of metric values. */ -typealias Metrics = Map> +typealias Metrics = Map /** * Represents a map of configuration values. */ -typealias Config = Map> +typealias Config = Map /** * Represents a map of properties. */ -typealias Properties = Map> +typealias Properties = Map + -@IntDef( - Code.OK, - Code.GET_PROPERTIES_NOT_IMPLEMENTED, - Code.GET_PARAMETERS_NOT_IMPLEMENTED, - Code.FIT_NOT_IMPLEMENTED, - Code.EVALUATE_NOT_IMPLEMENTED -) -@Retention(AnnotationRetention.SOURCE) -annotation class CodeAnnotation /** * The `Code` class defines client status codes used in the application. */ -object Code { - // Client status codes. - const val OK: Int = 0 - const val GET_PROPERTIES_NOT_IMPLEMENTED: Int = 1 - const val GET_PARAMETERS_NOT_IMPLEMENTED: Int = 2 - const val FIT_NOT_IMPLEMENTED: Int = 3 - const val EVALUATE_NOT_IMPLEMENTED: Int = 4 +enum class Code(val value: Int) { + OK(1), + GET_PROPERTIES_NOT_IMPLEMENTED(2), + GET_PARAMETERS_NOT_IMPLEMENTED(3), + FIT_NOT_IMPLEMENTED(4), + EVALUATE_NOT_IMPLEMENTED(5); + + companion object { + fun fromInt(value: Int): Code = values().first { it.value == value } + } } /** * Client status. */ -data class Status(val code: Int, val message: String) +data class Status(val code: Code, val message: String) /** * The `Scalar` class represents a scalar value that can have different data types. @@ -54,13 +47,12 @@ data class Status(val code: Int, val message: String) * some of them arguably do not conform to other definitions of what a scalar is. Source: * https://developers.google.com/protocol-buffers/docs/overview#scalar */ -sealed class Scalar { - abstract val value: T - data class BoolValue(override val value: Boolean): Scalar() - data class BytesValue(override val value: ByteString): Scalar() - data class SInt64Value(override val value: Long): Scalar() - data class DoubleValue(override val value: Double): Scalar() - data class StringValue(override val value: String): Scalar() +sealed class Scalar { + class BoolValue(val value: Boolean) : Scalar() + class BytesValue(val value: ByteBuffer) : Scalar() + class SInt64Value(val value: Long) : Scalar() + class DoubleValue(val value: Double) : Scalar() + class StringValue(val value: String) : Scalar() } /**