Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Save node in datastore #2687

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 116 additions & 29 deletions src/kotlin/flwr/src/main/java/dev/flower/android/Grpc.kt
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
package dev.flower.android

import android.content.Context
import android.util.Log
import androidx.datastore.core.CorruptionException
import androidx.datastore.core.DataStore
import androidx.datastore.core.Serializer
import androidx.datastore.dataStore
import androidx.datastore.preferences.protobuf.InvalidProtocolBufferException
import flwr.proto.FleetGrpc
import flwr.proto.FleetOuterClass.CreateNodeRequest
import flwr.proto.FleetOuterClass.CreateNodeResponse
Expand All @@ -8,21 +15,47 @@ import flwr.proto.FleetOuterClass.DeleteNodeResponse
import flwr.proto.FleetOuterClass.PullTaskInsRequest
import flwr.proto.FleetOuterClass.PullTaskInsResponse
import flwr.proto.FleetOuterClass.PushTaskResRequest
import io.grpc.ManagedChannel
import io.grpc.ManagedChannelBuilder
import io.grpc.stub.StreamObserver
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext
import java.util.concurrent.CountDownLatch
import flwr.proto.FlowerServiceGrpc
import flwr.proto.NodeOuterClass.Node
import flwr.proto.TaskOuterClass.TaskIns
import flwr.proto.TaskOuterClass.TaskRes
import flwr.proto.Transport.ServerMessage
import io.grpc.ManagedChannel
import io.grpc.ManagedChannelBuilder
import io.grpc.stub.StreamObserver
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.cancel
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.withContext
import java.io.InputStream
import java.io.OutputStream
import java.util.concurrent.CountDownLatch

object NodeSerializer : Serializer<Node> {
override val defaultValue: Node = Node.getDefaultInstance()

override suspend fun readFrom(input: InputStream): Node {
try {
return Node.parseFrom(input)
} catch (exception: InvalidProtocolBufferException) {
throw CorruptionException("Cannot read proto.", exception)
}
}

override suspend fun writeTo(
t: Node,
output: OutputStream
) = t.writeTo(output)
}

val Context.nodeDataStore: DataStore<Node> by dataStore(
fileName = "fleet.pb",
serializer = NodeSerializer
)

internal class FlowerGrpc
@Throws constructor(
Expand All @@ -35,6 +68,7 @@ internal class FlowerGrpc

private val requestObserver = asyncStub.join(object : StreamObserver<ServerMessage> {
override fun onNext(msg: ServerMessage) {
Log.i("Flower Grpc", "Receive message: $msg.")
try {
sendResponse(msg)
} catch (e: Exception) {
Expand All @@ -54,6 +88,7 @@ internal class FlowerGrpc

fun sendResponse(msg: ServerMessage) {
val response = handleLegacyMessage(client, msg)
Log.i("Flower Grpc", "Send message: $response.")
requestObserver.onNext(response.first)
}
}
Expand Down Expand Up @@ -85,12 +120,29 @@ internal suspend fun createChannel(address: String, useTLS: Boolean = false): Ma
}
}

suspend fun startClient(host: String, port: Int, useTls: Boolean, client: Client) {
FlowerGrpc(createChannel(host, port, useTls), client)
}

internal suspend fun createChannel(host: String, port: Int, useTls: Boolean): ManagedChannel {
val channelBuilder =
ManagedChannelBuilder.forAddress(host, port).maxInboundMessageSize(HUNDRED_MEBIBYTE)
if (!useTls) {
channelBuilder.usePlaintext()
}
return withContext(Dispatchers.IO) {
channelBuilder.build()
}
}


const val HUNDRED_MEBIBYTE = 100 * 1024 * 1024

internal class FlwrRere
@Throws constructor(
channel: ManagedChannel,
private val client: Client,
private val context: Context
) {

private val KEYNODE = "node"
Expand All @@ -106,28 +158,50 @@ internal class FlwrRere
private fun createNode() {
val createNodeRequest = CreateNodeRequest.newBuilder().build()

asyncStub.createNode(createNodeRequest, object : StreamObserver<CreateNodeResponse> {
override fun onNext(value: CreateNodeResponse?) {
nodeStore[KEYNODE] = value?.node
}
try {
asyncStub.createNode(createNodeRequest, object : StreamObserver<CreateNodeResponse> {
override fun onNext(value: CreateNodeResponse?) {
value?.let { response ->
runBlocking {
context.nodeDataStore.updateData { node ->
node.toBuilder()
.setNodeId(response.node.nodeId)
.setAnonymous(response.node.anonymous)
.build()
}
}

override fun onError(t: Throwable?) {
t?.printStackTrace()
finishLatch.countDown()
}
}
}

override fun onCompleted() {
finishLatch.countDown()
}
})
override fun onError(t: Throwable?) {
t?.printStackTrace()
finishLatch.countDown()
}

override fun onCompleted() {
finishLatch.countDown()
}
})
}catch(e: Exception) {
e.printStackTrace()
Log.i("Flower Grpc", "Create node not implemented.")
}
}

private fun deleteNode() {
nodeStore[KEYNODE]?.let { node ->
val deleteNodeRequest = DeleteNodeRequest.newBuilder().setNode(node).build()
private suspend fun deleteNode() {
context.nodeDataStore.data.collect { data ->
val deleteNodeRequest = DeleteNodeRequest.newBuilder().setNode(data).build()
asyncStub.deleteNode(deleteNodeRequest, object : StreamObserver<DeleteNodeResponse> {
override fun onNext(value: DeleteNodeResponse?) {
nodeStore[KEYNODE] = null
runBlocking {
context.nodeDataStore.updateData { node ->
node.toBuilder()
.setNodeId(0)
.setAnonymous(false)
.build()
}
}
}

override fun onError(t: Throwable?) {
Expand All @@ -142,18 +216,26 @@ internal class FlwrRere
}
}

private suspend fun request(requestChannel: Channel<PullTaskInsRequest>, node: Node) {
val request = PullTaskInsRequest.newBuilder().setNode(node).build()
private suspend fun request(requestChannel: Channel<PullTaskInsRequest>, node: Node?) {
val request = if (node != null) {
PullTaskInsRequest.newBuilder().setNode(node).build()
} else {
PullTaskInsRequest.newBuilder().build()
}

Log.i("Flower Grpc", "Sending request $request")
requestChannel.send(request)
}

private suspend fun receive(requestChannel: Channel<PullTaskInsRequest>, node: Node) = flow {
private suspend fun receive(requestChannel: Channel<PullTaskInsRequest>, node: Node?, withTimeout: Boolean = false) = flow {
coroutineScope {
var numberOfTries = 0
val responses = Channel<TaskIns?>(1)
for (request in requestChannel)
asyncStub.pullTaskIns(request, object : StreamObserver<PullTaskInsResponse> {
override fun onNext(value: PullTaskInsResponse?) {
val taskIns = value?.let { getTaskIns(it) }
Log.i("Flower Grpc", "Receive $taskIns")
if (taskIns != null && validateTaskIns(taskIns, true)) {
state[KEYTASKINS] = taskIns
responses.trySend(taskIns).isSuccess
Expand All @@ -171,9 +253,14 @@ internal class FlwrRere

for (response in responses) {
if (response == null) {
if (numberOfTries >= 10) {
cancel("Timeout")
}
delay(3000)
numberOfTries++
request(requestChannel, node)
} else {
numberOfTries = 0
emit(response)
}
}
Expand All @@ -183,10 +270,9 @@ internal class FlwrRere
suspend fun startGrpcRere() {
createNode()

val node = nodeStore[KEYNODE]
val node: Node? = nodeStore[KEYNODE]
if (node == null) {
println("Node not available")
return
}

val requestChannel = Channel<PullTaskInsRequest>(1)
Expand Down Expand Up @@ -227,10 +313,11 @@ internal class FlwrRere
}
}

suspend fun createFlowerRere(
suspend fun startFlowerRere(
serverAddress: String,
useTLS: Boolean,
client: Client,
context: Context
) {
FlwrRere(createChannel(serverAddress, useTLS), client)
FlwrRere(createChannel(serverAddress, useTLS), client, context).startGrpcRere()
}