Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Increased thread-safety by using Lock context manager syntax #192

Merged
merged 9 commits into from
Jan 3, 2024
10 changes: 5 additions & 5 deletions pyVoIP/RTP.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,11 +170,10 @@ def read(self, length: int = 160) -> bytes:
# This acts functionally as a lock while the buffer is being rebuilt.
while self.rebuilding:
time.sleep(0.01)
self.bufferLock.acquire()
packet = self.buffer.read(length)
if len(packet) < length:
packet = packet + (b"\x80" * (length - len(packet)))
self.bufferLock.release()
with self.bufferLock:
packet = self.buffer.read(length)
if len(packet) < length:
packet = packet + (b"\x80" * (length - len(packet)))
return packet

def rebuild(self, reset: bool, offset: int = 0, data: bytes = b"") -> None:
Expand All @@ -192,6 +191,7 @@ def rebuild(self, reset: bool, offset: int = 0, data: bytes = b"") -> None:
self.rebuilding = False

def write(self, offset: int, data: bytes) -> None:
# TODO: Can this safely be changed to use context manager syntax?
self.bufferLock.acquire()
self.log[offset] = data
bufferloc = self.buffer.tell()
Expand Down
197 changes: 99 additions & 98 deletions pyVoIP/SIP.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from enum import Enum, IntEnum
from threading import Timer, Lock
from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING
from pyVoIP.util import acquired_lock_and_unblocked_socket
from pyVoIP.VoIP.status import PhoneStatus
import pyVoIP
import hashlib
Expand Down Expand Up @@ -40,6 +41,10 @@ class SIPParseError(Exception):
pass


class RetryRequiredError(Exception):
pass


class Counter:
def __init__(self, start: int = 1):
self.x = start
Expand Down Expand Up @@ -842,40 +847,41 @@ def __init__(
self.registerFailures = 0
self.recvLock = Lock()

def recv(self) -> None:
def recv_loop(self) -> None:
while self.NSD:
self.recvLock.acquire()
self.s.setblocking(False)
try:
raw = self.s.recv(8192)
if raw != b"\x00\x00\x00\x00":
try:
message = SIPMessage(raw)
debug(message.summary())
self.parseMessage(message)
except Exception as ex:
debug(f"Error on header parsing: {ex}")
with acquired_lock_and_unblocked_socket(self.recvLock, self.s):
self.recv()
except BlockingIOError:
self.s.setblocking(True)
self.recvLock.release()
time.sleep(0.01)
continue
except SIPParseError as e:
if "SIP Version" in str(e):
request = self.genSIPVersionNotSupported(message)
self.out.sendto(
request.encode("utf8"), (self.server, self.port)
)
else:
debug(f"SIPParseError in SIP.recv: {type(e)}, {e}")
except Exception as e:
debug(f"SIP.recv error: {type(e)}, {e}\n\n{str(raw, 'utf8')}")
if pyVoIP.DEBUG:
self.s.setblocking(True)
self.recvLock.release()
raise
self.s.setblocking(True)
self.recvLock.release()

def recv(self) -> None:
try:
raw = self.s.recv(8192)
if raw != b"\x00\x00\x00\x00":
try:
message = SIPMessage(raw)
debug(message.summary())
self.parseMessage(message)
except Exception as ex:
debug(f"Error on header parsing: {ex}")
except SIPParseError as e:
if "SIP Version" in str(e):
request = self.genSIPVersionNotSupported(message)
self.out.sendto(
request.encode("utf8"), (self.server, self.port)
)
else:
debug(f"SIPParseError in SIP.recv: {type(e)}, {e}")
except BlockingIOError:
# Re-raise BlockingIOError so recv_loop() can release locks and
# continue
raise
except Exception as e:
debug(f"SIP.recv error: {type(e)}, {e}\n\n{str(raw, 'utf8')}")
if pyVoIP.DEBUG:
raise

def parseMessage(self, message: SIPMessage) -> None:
warnings.warn(
Expand Down Expand Up @@ -955,7 +961,7 @@ def start(self) -> None:
self.s.bind((self.myIP, self.myPort))
self.out = self.s
self.register()
t = Timer(1, self.recv)
t = Timer(1, self.recv_loop)
t.name = "SIP Recieve"
t.start()

Expand Down Expand Up @@ -1596,52 +1602,49 @@ def invite(
invite = self.genInvite(
number, str(sess_id), ms, sendtype, branch, call_id
)
self.recvLock.acquire()
self.out.sendto(invite.encode("utf8"), (self.server, self.port))
debug("Invited")
response = SIPMessage(self.s.recv(8192))

while (
response.status != SIPStatus(401)
and response.status != SIPStatus(100)
and response.status != SIPStatus(180)
) or response.headers["Call-ID"] != call_id:
if not self.NSD:
break
self.parseMessage(response)
with self.recvLock:
self.out.sendto(invite.encode("utf8"), (self.server, self.port))
debug("Invited")
response = SIPMessage(self.s.recv(8192))

if response.status == SIPStatus(100) or response.status == SIPStatus(
180
):
self.recvLock.release()
return SIPMessage(invite.encode("utf8")), call_id, sess_id
debug(f"Received Response: {response.summary()}")
ack = self.genAck(response)
self.out.sendto(ack.encode("utf8"), (self.server, self.port))
debug("Acknowledged")
authhash = self.genAuthorization(response)
nonce = response.authentication["nonce"]
realm = response.authentication["realm"]
auth = (
f'Authorization: Digest username="{self.username}",realm='
+ f'"{realm}",nonce="{nonce}",uri="sip:{self.server};'
+ f'transport=UDP",response="{str(authhash, "utf8")}",'
+ "algorithm=MD5\r\n"
)

invite = self.genInvite(
number, str(sess_id), ms, sendtype, branch, call_id
)
invite = invite.replace(
"\r\nContent-Length", f"\r\n{auth}Content-Length"
)
while (
response.status != SIPStatus(401)
and response.status != SIPStatus(100)
and response.status != SIPStatus(180)
) or response.headers["Call-ID"] != call_id:
if not self.NSD:
break
self.parseMessage(response)
response = SIPMessage(self.s.recv(8192))

if response.status == SIPStatus(
100
) or response.status == SIPStatus(180):
return SIPMessage(invite.encode("utf8")), call_id, sess_id
debug(f"Received Response: {response.summary()}")
ack = self.genAck(response)
self.out.sendto(ack.encode("utf8"), (self.server, self.port))
debug("Acknowledged")
authhash = self.genAuthorization(response)
nonce = response.authentication["nonce"]
realm = response.authentication["realm"]
auth = (
f'Authorization: Digest username="{self.username}",realm='
+ f'"{realm}",nonce="{nonce}",uri="sip:{self.server};'
+ f'transport=UDP",response="{str(authhash, "utf8")}",'
+ "algorithm=MD5\r\n"
)

self.out.sendto(invite.encode("utf8"), (self.server, self.port))
invite = self.genInvite(
number, str(sess_id), ms, sendtype, branch, call_id
)
invite = invite.replace(
"\r\nContent-Length", f"\r\n{auth}Content-Length"
)

self.recvLock.release()
self.out.sendto(invite.encode("utf8"), (self.server, self.port))

return SIPMessage(invite.encode("utf8")), call_id, sess_id
return SIPMessage(invite.encode("utf8")), call_id, sess_id

def bye(self, request: SIPMessage) -> None:
message = self.genBye(request)
Expand All @@ -1650,7 +1653,8 @@ def bye(self, request: SIPMessage) -> None:

def deregister(self) -> bool:
try:
deregistered = self.__deregister()
with self.recvLock:
deregistered = self.__deregister()
if not deregistered:
debug("DEREGISTERATION FAILED")
return False
Expand All @@ -1660,12 +1664,16 @@ def deregister(self) -> bool:
return deregistered
except BaseException as e:
debug(f"DEREGISTERATION ERROR: {e}")
# TODO: a maximum tries check should be implemented otherwise a
# RecursionError will throw
if isinstance(e, RetryRequiredError):
time.sleep(5)
return self.deregister()
if type(e) is OSError:
raise
return False

def __deregister(self) -> bool:
self.recvLock.acquire()
self.phone._status = PhoneStatus.DEREGISTERING
firstRequest = self.genFirstRequest(deregister=True)
self.out.sendto(firstRequest.encode("utf8"), (self.server, self.port))
Expand All @@ -1676,7 +1684,6 @@ def __deregister(self) -> bool:
if ready[0]:
resp = self.s.recv(8192)
else:
self.recvLock.release()
raise TimeoutError("Deregistering on SIP Server timed out")

response = SIPMessage(resp)
Expand All @@ -1696,7 +1703,6 @@ def __deregister(self) -> bool:
# At this point, it's reasonable to assume that
# this is caused by invalid credentials.
debug("Unauthorized")
self.recvLock.release()
raise InvalidAccountInfoError(
"Invalid Username or "
+ "Password for SIP server "
Expand All @@ -1710,23 +1716,20 @@ def __deregister(self) -> bool:
# with new urn:uuid or reply with expire 0
self._handle_bad_request()
else:
self.recvLock.release()
raise TimeoutError("Deregistering on SIP Server timed out")

if response.status == SIPStatus(500):
self.recvLock.release()
time.sleep(5)
return self.deregister()
# We raise so the calling function can sleep and try again
raise RetryRequiredError("Response SIP status of 500")

if response.status == SIPStatus.OK:
self.recvLock.release()
return True
self.recvLock.release()
return False

def register(self) -> bool:
try:
registered = self.__register()
with self.recvLock:
registered = self.__register()
if not registered:
debug("REGISTERATION FAILED")
self.registerFailures += 1
Expand All @@ -1749,6 +1752,9 @@ def register(self) -> bool:
self.stop()
self.fatalCallback()
return False
if isinstance(e, RetryRequiredError):
time.sleep(5)
return self.register()
self.__start_register_timer(delay=0)

def __start_register_timer(self, delay: Optional[int] = None):
Expand All @@ -1764,7 +1770,6 @@ def __start_register_timer(self, delay: Optional[int] = None):
self.registerThread.start()

def __register(self) -> bool:
self.recvLock.acquire()
self.phone._status = PhoneStatus.REGISTERING
firstRequest = self.genFirstRequest()
self.out.sendto(firstRequest.encode("utf8"), (self.server, self.port))
Expand All @@ -1775,7 +1780,6 @@ def __register(self) -> bool:
if ready[0]:
resp = self.s.recv(8192)
else:
self.recvLock.release()
raise TimeoutError("Registering on SIP Server timed out")

response = SIPMessage(resp)
Expand Down Expand Up @@ -1814,7 +1818,6 @@ def __register(self) -> bool:
debug("\nRECEIVED")
debug(response.summary())
debug("=" * 50)
self.recvLock.release()
raise InvalidAccountInfoError(
"Invalid Username or "
+ "Password for SIP server "
Expand All @@ -1828,7 +1831,6 @@ def __register(self) -> bool:
# with new urn:uuid or reply with expire 0
self._handle_bad_request()
else:
self.recvLock.release()
raise TimeoutError("Registering on SIP Server timed out")

if response.status == SIPStatus(407):
Expand All @@ -1844,17 +1846,15 @@ def __register(self) -> bool:
]:
# Unauthorized
if response.status == SIPStatus(500):
self.recvLock.release()
time.sleep(5)
return self.register()
# We raise so the calling function can sleep and try again
raise RetryRequiredError("Response SIP status of 500")
else:
# TODO: determine if needed here
self.parseMessage(response)

debug(response.summary())
debug(response.raw)

self.recvLock.release()
if response.status == SIPStatus.OK:
return True
else:
Expand All @@ -1873,16 +1873,17 @@ def _handle_bad_request(self) -> None:

def subscribe(self, lastresponse: SIPMessage) -> None:
# TODO: check if needed and maybe implement fully
self.recvLock.acquire()

subRequest = self.genSubscribe(lastresponse)
self.out.sendto(subRequest.encode("utf8"), (self.server, self.port))

response = SIPMessage(self.s.recv(8192))
with self.recvLock:
subRequest = self.genSubscribe(lastresponse)
self.out.sendto(
subRequest.encode("utf8"), (self.server, self.port)
)

debug(f'Got response to subscribe: {str(response.heading, "utf8")}')
response = SIPMessage(self.s.recv(8192))

self.recvLock.release()
debug(
f'Got response to subscribe: {str(response.heading, "utf8")}'
)

def trying_timeout_check(self, response: SIPMessage) -> SIPMessage:
"""
Expand Down
Loading
Loading