Skip to content

Commit

Permalink
fix(ktx): fix SrtSocket remote disconnection
Browse files Browse the repository at this point in the history
  • Loading branch information
ThibaultBee committed Sep 19, 2024
1 parent 0b15442 commit 230a83e
Showing 1 changed file with 75 additions and 65 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,11 @@ import io.github.thibaultbee.srtdroid.core.models.rejectreason.InternalRejectRea
import io.github.thibaultbee.srtdroid.core.models.rejectreason.PredefinedRejectReason
import io.github.thibaultbee.srtdroid.core.models.rejectreason.RejectReason
import io.github.thibaultbee.srtdroid.core.models.rejectreason.UserDefinedRejectReason
import kotlinx.coroutines.CancellationException
import kotlinx.coroutines.CompletableJob
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.Job
import kotlinx.coroutines.cancelChildren
import kotlinx.coroutines.channels.awaitClose
import kotlinx.coroutines.channels.trySendBlocking
import kotlinx.coroutines.flow.callbackFlow
Expand All @@ -38,7 +37,6 @@ import java.net.InetSocketAddress
import java.net.SocketException
import java.net.SocketTimeoutException
import java.nio.ByteBuffer
import kotlin.coroutines.CoroutineContext
import kotlin.coroutines.resumeWithException
import kotlin.math.min

Expand All @@ -52,32 +50,11 @@ private constructor(
ConfigurableSrtSocket, CoroutineScope {
constructor() : this(SrtSocket())

init {
socket.setSockFlag(SockOpt.RCVSYN, false)
socket.setSockFlag(SockOpt.SNDSYN, false)

socket.clientListener =
object : SrtSocket.ClientListener {
override fun onConnectionLost(
ns: SrtSocket,
error: ErrorType,
peerAddress: InetSocketAddress,
token: Int
) {
if (hasBeenConnected) {
socketContext.completeExceptionally(ConnectException(error.toString()))
coroutineContext.cancelChildren()
}
}
}
}

private var hasBeenConnected = false

val socketContext: CompletableJob = Job()

@OptIn(ExperimentalCoroutinesApi::class)
override val coroutineContext: CoroutineContext = Dispatchers.IO.limitedParallelism(1)
override val coroutineContext = socketContext

/**
* Flow of incoming sockets.
Expand Down Expand Up @@ -195,6 +172,36 @@ private constructor(
val localPort: Int
get() = socket.localPort

init {
socket.setSockFlag(SockOpt.RCVSYN, false)
socket.setSockFlag(SockOpt.SNDSYN, false)

socket.clientListener =
object : SrtSocket.ClientListener {
override fun onConnectionLost(
ns: SrtSocket,
error: ErrorType,
peerAddress: InetSocketAddress,
token: Int
) {
if (hasBeenConnected) {
socket.close()
complete(ConnectException(error.toString()))
}
}
}
}

private fun complete(t: Throwable? = null) {
if (!socketContext.isCompleted) {
if (t != null) {
socketContext.completeExceptionally(t)
} else {
socketContext.complete()
}
}
}

/**
* Closes the socket or group and frees all used resources.
*
Expand All @@ -203,9 +210,12 @@ private constructor(
* @throws SocketException if close failed
*/
fun close() {
coroutineContext.cancelChildren()
socket.close()
socketContext.complete()
try {
socket.close()
complete()
} catch (t: Throwable) {
complete(t)
}
}

/**
Expand All @@ -217,7 +227,7 @@ private constructor(
*
* @throws BindException if bind has failed
*/
suspend fun bind(address: InetSocketAddress) = withContext(coroutineContext) {
suspend fun bind(address: InetSocketAddress) = withContext(Dispatchers.IO) {
socket.bind(address)
hasBeenConnected = true
}
Expand Down Expand Up @@ -760,7 +770,7 @@ private constructor(
epoll.addUSock(socket, listOf(EpollOpt.ERR, epollOpt))

try {
return withContext(coroutineContext) {
return withContext(Dispatchers.IO) {
if (timeoutInMs == null) {
executeEpoll(epoll, onContinuation, block)
} else {
Expand Down Expand Up @@ -796,58 +806,58 @@ private constructor(
block: () -> T
): T {
return suspendCancellableCoroutine { continuation ->
continuation.invokeOnCancellation {
continuation.invokeOnCancellation { t ->
epoll.clearUSock()
}
onContinuation()
while (isActive) {
val epollEvents = try {
epoll.uWait(POLLING_TIMEOUT_IN_MS)
} catch (e: Exception) {
if (SrtError.lastError != ErrorType.EPOLLEMPTY) {
continuation.resumeWithException(SocketException(SrtError.lastErrorMessage))
try {
val epollEvents = epoll.uWait(POLLING_TIMEOUT_IN_MS)
if (epollEvents.isEmpty()) {
continue
}
return@suspendCancellableCoroutine
}
if (epollEvents.isEmpty()) {
continue
}
val socketEvents = epollEvents.filter { it.socket == socket }
if (socketEvents.isEmpty()) {
continue
}
epoll.addUSock(socket, null) // Unsubscribe from all events

if (socketEvents.any { it.events.contains(EpollOpt.ERR) }) {
if (sockState == SockStatus.BROKEN) {
continuation.resumeWithException(SocketException("Socket is broken. Maybe due to timeout?"))
} else {
if (SrtError.lastError != ErrorType.SUCCESS) {
continuation.resumeWithException(SocketException(SrtError.lastErrorMessage))
} else {
continuation.resumeWithException(SocketException("Epoll returned an unknown error"))
}
val socketEvents = epollEvents.filter { it.socket == socket }
if (socketEvents.isEmpty()) {
continue
}
} else {

try {
if (socketEvents.any {
it.events.contains(EpollOpt.IN) || it.events.contains(
EpollOpt.OUT
)
}) {
continuation.resumeWith(Result.success(block()))
epoll.addUSock(socket, null) // Unsubscribe to all events
} catch (_: Throwable) {
// Ignore
}

if (socketEvents.any { it.events.contains(EpollOpt.ERR) }) {
if ((SrtError.lastError != ErrorType.SUCCESS) && (SrtError.lastError != ErrorType.EPOLLEMPTY)) {
throw SocketException(SrtError.lastErrorMessage)
} else {
if ((sockState == SockStatus.BROKEN) || (sockState == SockStatus.CLOSED)) {
throw SocketException("Connection was broken")
} else {
throw SocketException("Epoll returned an unknown error (sockState = $sockState)")
}
}
} catch (e: Exception) {
continuation.resumeWithException(e)
} else if (socketEvents.any {
it.events.contains(EpollOpt.IN) || it.events.contains(
EpollOpt.OUT
)
}) {
continuation.resumeWith(Result.success(block()))
} else {
throw SocketException("Epoll returned an unknown event: $socketEvents")
}
} catch (e: Exception) {
continuation.resumeWithException(e)
}
return@suspendCancellableCoroutine
}
continuation.resumeWithException(CancellationException())
}
}

companion object {
private const val TAG = "CoroutineSocket"
private const val TAG = "CoroutineSrtSocket"

private const val POLLING_TIMEOUT_IN_MS = 1000L
}
Expand Down

0 comments on commit 230a83e

Please sign in to comment.