Skip to content

Commit

Permalink
Remove generics in Scalar (#2555)
Browse files Browse the repository at this point in the history
  • Loading branch information
danielnugraha authored Nov 6, 2023
1 parent e80ca33 commit c82c9e2
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 54 deletions.
47 changes: 23 additions & 24 deletions src/kotlin/flwr/src/main/java/dev/flower/android/Serde.kt
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
package dev.flower.android

import java.nio.ByteBuffer
import com.google.protobuf.ByteString
import flwr.proto.Transport.ServerMessage
import flwr.proto.Transport.ClientMessage
import flwr.proto.Transport.Parameters as ProtoParameters
import flwr.proto.Transport.Status as ProtoStatus
import flwr.proto.Transport.Reason
import flwr.proto.Transport.ServerMessage
import java.nio.ByteBuffer
import flwr.proto.Transport.Parameters as ProtoParameters
import flwr.proto.Transport.Scalar as ProtoScalar
import flwr.proto.Transport.Status as ProtoStatus

internal fun parametersToProto(parameters: Parameters): ProtoParameters {
val tensors: MutableList<ByteString> = ArrayList()
Expand All @@ -22,11 +22,11 @@ internal fun parametersFromProto(msg: ProtoParameters): Parameters {
}

internal fun statusToProto(status: Status): ProtoStatus {
return ProtoStatus.newBuilder().setCodeValue(status.code).setMessage(status.message).build()
return ProtoStatus.newBuilder().setCodeValue(status.code.value).setMessage(status.message).build()
}

internal fun statusFromProto(msg: ProtoStatus): Status {
return Status(msg.codeValue, msg.message)
return Status(Code.fromInt(msg.codeValue), msg.message)
}

internal fun reconnectInsToProto(ins: ReconnectIns): ServerMessage.ReconnectIns {
Expand All @@ -47,7 +47,7 @@ internal fun disconnectResToProto(res: DisconnectRes): ClientMessage.DisconnectR
"WIFI_UNAVAILABLE" -> Reason.WIFI_UNAVAILABLE
else -> Reason.UNKNOWN
}
return ClientMessage.DisconnectRes.newBuilder().setReason(reason).build()
return ClientMessage.DisconnectRes.newBuilder().setReason(reason).build()
}

internal fun disconnectResFromProto(msg: ClientMessage.DisconnectRes): DisconnectRes {
Expand Down Expand Up @@ -173,24 +173,23 @@ internal fun metricsFromProto(proto: Map<String, ProtoScalar>): Metrics {
return proto.mapValues { (_, value) -> scalarFromProto(value) }
}

internal inline fun<reified T> scalarToProto(scalar: Scalar<T>): ProtoScalar {
return when (T::class) {
Scalar.BoolValue::class -> ProtoScalar.newBuilder().setBool(scalar.value as Boolean).build()
Scalar.BytesValue::class -> ProtoScalar.newBuilder().setBytes(scalar.value as ByteString).build()
Scalar.DoubleValue::class -> ProtoScalar.newBuilder().setDouble(scalar.value as Double).build()
Scalar.SInt64Value::class -> ProtoScalar.newBuilder().setSint64(scalar.value as Long).build()
Scalar.StringValue::class -> ProtoScalar.newBuilder().setString(scalar.value as String).build()
else -> throw IllegalArgumentException("Accepted Types : Bool, Data, Float, Int, Str")
internal fun scalarToProto(scalar: Scalar): ProtoScalar {
return when (scalar) {
is Scalar.BoolValue -> ProtoScalar.newBuilder().setBool(scalar.value).build()
is Scalar.BytesValue -> ProtoScalar.newBuilder().setBytes(ByteString.copyFrom(scalar.value)).build()
is Scalar.SInt64Value -> ProtoScalar.newBuilder().setSint64(scalar.value).build()
is Scalar.DoubleValue -> ProtoScalar.newBuilder().setDouble(scalar.value).build()
is Scalar.StringValue -> ProtoScalar.newBuilder().setString(scalar.value).build()
}
}

internal inline fun <reified T> scalarFromProto(scalarMsg: ProtoScalar): Scalar<T> {
return when (scalarMsg.scalarCase) {
ProtoScalar.ScalarCase.BOOL -> Scalar.BoolValue(scalarMsg.bool)
ProtoScalar.ScalarCase.BYTES -> Scalar.BytesValue(scalarMsg.bytes)
ProtoScalar.ScalarCase.DOUBLE -> Scalar.DoubleValue(scalarMsg.double)
ProtoScalar.ScalarCase.SINT64 -> Scalar.SInt64Value(scalarMsg.sint64)
ProtoScalar.ScalarCase.STRING -> Scalar.StringValue(scalarMsg.string)
else -> throw IllegalArgumentException("Accepted Types : Bool, Data, Float, Int, Str")
} as Scalar<T>
internal fun scalarFromProto(scalarMsg: ProtoScalar): Scalar {
return when (scalarMsg.scalarCase) {
ProtoScalar.ScalarCase.BOOL -> Scalar.BoolValue(scalarMsg.bool)
ProtoScalar.ScalarCase.BYTES -> Scalar.BytesValue(ByteBuffer.wrap(scalarMsg.bytes.toByteArray()))
ProtoScalar.ScalarCase.DOUBLE -> Scalar.DoubleValue(scalarMsg.double)
ProtoScalar.ScalarCase.SINT64 -> Scalar.SInt64Value(scalarMsg.sint64)
ProtoScalar.ScalarCase.STRING -> Scalar.StringValue(scalarMsg.string)
else -> throw IllegalArgumentException("Accepted Types : Bool, Data, Float, Int, Str")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ internal fun wrapClientMessageInTaskRes(clientMessage: ClientMessage): TaskRes {
return TaskRes.newBuilder()
.setTaskId("")
.setGroupId("")
.setWorkloadId("")
.setWorkloadId(0)
.setTask(Task.newBuilder().addAllAncestry(emptyList()).setLegacyClientMessage(clientMessage))
.build()
}
50 changes: 21 additions & 29 deletions src/kotlin/flwr/src/main/java/dev/flower/android/Typing.kt
Original file line number Diff line number Diff line change
@@ -1,50 +1,43 @@
package dev.flower.android

import androidx.annotation.IntDef
import com.google.protobuf.ByteString
import java.nio.ByteBuffer

/**
* Represents a map of metric values.
*/
typealias Metrics = Map<String, Scalar<Any>>
typealias Metrics = Map<String, Scalar>

/**
* Represents a map of configuration values.
*/
typealias Config = Map<String, Scalar<Any>>
typealias Config = Map<String, Scalar>

/**
* Represents a map of properties.
*/
typealias Properties = Map<String, Scalar<Any>>
typealias Properties = Map<String, Scalar>


@IntDef(
Code.OK,
Code.GET_PROPERTIES_NOT_IMPLEMENTED,
Code.GET_PARAMETERS_NOT_IMPLEMENTED,
Code.FIT_NOT_IMPLEMENTED,
Code.EVALUATE_NOT_IMPLEMENTED
)
@Retention(AnnotationRetention.SOURCE)
annotation class CodeAnnotation

/**
* The `Code` class defines client status codes used in the application.
*/
object Code {
// Client status codes.
const val OK: Int = 0
const val GET_PROPERTIES_NOT_IMPLEMENTED: Int = 1
const val GET_PARAMETERS_NOT_IMPLEMENTED: Int = 2
const val FIT_NOT_IMPLEMENTED: Int = 3
const val EVALUATE_NOT_IMPLEMENTED: Int = 4
enum class Code(val value: Int) {
OK(1),
GET_PROPERTIES_NOT_IMPLEMENTED(2),
GET_PARAMETERS_NOT_IMPLEMENTED(3),
FIT_NOT_IMPLEMENTED(4),
EVALUATE_NOT_IMPLEMENTED(5);

companion object {
fun fromInt(value: Int): Code = values().first { it.value == value }
}
}

/**
* Client status.
*/
data class Status(val code: Int, val message: String)
data class Status(val code: Code, val message: String)

/**
* The `Scalar` class represents a scalar value that can have different data types.
Expand All @@ -54,13 +47,12 @@ data class Status(val code: Int, val message: String)
* some of them arguably do not conform to other definitions of what a scalar is. Source:
* https://developers.google.com/protocol-buffers/docs/overview#scalar
*/
sealed class Scalar<T> {
abstract val value: T
data class BoolValue(override val value: Boolean): Scalar<Boolean>()
data class BytesValue(override val value: ByteString): Scalar<ByteString>()
data class SInt64Value(override val value: Long): Scalar<Long>()
data class DoubleValue(override val value: Double): Scalar<Double>()
data class StringValue(override val value: String): Scalar<String>()
sealed class Scalar {
class BoolValue(val value: Boolean) : Scalar()
class BytesValue(val value: ByteBuffer) : Scalar()
class SInt64Value(val value: Long) : Scalar()
class DoubleValue(val value: Double) : Scalar()
class StringValue(val value: String) : Scalar()
}

/**
Expand Down

0 comments on commit c82c9e2

Please sign in to comment.