diff --git a/pyVoIP/RTP.py b/pyVoIP/RTP.py index ac329d1..412a68c 100644 --- a/pyVoIP/RTP.py +++ b/pyVoIP/RTP.py @@ -170,11 +170,10 @@ def read(self, length: int = 160) -> bytes: # This acts functionally as a lock while the buffer is being rebuilt. while self.rebuilding: time.sleep(0.01) - self.bufferLock.acquire() - packet = self.buffer.read(length) - if len(packet) < length: - packet = packet + (b"\x80" * (length - len(packet))) - self.bufferLock.release() + with self.bufferLock: + packet = self.buffer.read(length) + if len(packet) < length: + packet = packet + (b"\x80" * (length - len(packet))) return packet def rebuild(self, reset: bool, offset: int = 0, data: bytes = b"") -> None: @@ -192,6 +191,7 @@ def rebuild(self, reset: bool, offset: int = 0, data: bytes = b"") -> None: self.rebuilding = False def write(self, offset: int, data: bytes) -> None: + # TODO: Can this safely be changed to use context manager syntax? self.bufferLock.acquire() self.log[offset] = data bufferloc = self.buffer.tell() diff --git a/pyVoIP/SIP.py b/pyVoIP/SIP.py index a737b0d..0337e3e 100644 --- a/pyVoIP/SIP.py +++ b/pyVoIP/SIP.py @@ -1,6 +1,7 @@ from enum import Enum, IntEnum from threading import Timer, Lock from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING +from pyVoIP.util import acquired_lock_and_unblocked_socket from pyVoIP.VoIP.status import PhoneStatus import pyVoIP import hashlib @@ -40,6 +41,10 @@ class SIPParseError(Exception): pass +class RetryRequiredError(Exception): + pass + + class Counter: def __init__(self, start: int = 1): self.x = start @@ -842,40 +847,41 @@ def __init__( self.registerFailures = 0 self.recvLock = Lock() - def recv(self) -> None: + def recv_loop(self) -> None: while self.NSD: - self.recvLock.acquire() - self.s.setblocking(False) try: - raw = self.s.recv(8192) - if raw != b"\x00\x00\x00\x00": - try: - message = SIPMessage(raw) - debug(message.summary()) - self.parseMessage(message) - except Exception as ex: - debug(f"Error on header parsing: {ex}") + with acquired_lock_and_unblocked_socket(self.recvLock, self.s): + self.recv() except BlockingIOError: - self.s.setblocking(True) - self.recvLock.release() time.sleep(0.01) continue - except SIPParseError as e: - if "SIP Version" in str(e): - request = self.genSIPVersionNotSupported(message) - self.out.sendto( - request.encode("utf8"), (self.server, self.port) - ) - else: - debug(f"SIPParseError in SIP.recv: {type(e)}, {e}") - except Exception as e: - debug(f"SIP.recv error: {type(e)}, {e}\n\n{str(raw, 'utf8')}") - if pyVoIP.DEBUG: - self.s.setblocking(True) - self.recvLock.release() - raise - self.s.setblocking(True) - self.recvLock.release() + + def recv(self) -> None: + try: + raw = self.s.recv(8192) + if raw != b"\x00\x00\x00\x00": + try: + message = SIPMessage(raw) + debug(message.summary()) + self.parseMessage(message) + except Exception as ex: + debug(f"Error on header parsing: {ex}") + except SIPParseError as e: + if "SIP Version" in str(e): + request = self.genSIPVersionNotSupported(message) + self.out.sendto( + request.encode("utf8"), (self.server, self.port) + ) + else: + debug(f"SIPParseError in SIP.recv: {type(e)}, {e}") + except BlockingIOError: + # Re-raise BlockingIOError so recv_loop() can release locks and + # continue + raise + except Exception as e: + debug(f"SIP.recv error: {type(e)}, {e}\n\n{str(raw, 'utf8')}") + if pyVoIP.DEBUG: + raise def parseMessage(self, message: SIPMessage) -> None: warnings.warn( @@ -955,7 +961,7 @@ def start(self) -> None: self.s.bind((self.myIP, self.myPort)) self.out = self.s self.register() - t = Timer(1, self.recv) + t = Timer(1, self.recv_loop) t.name = "SIP Recieve" t.start() @@ -1596,52 +1602,49 @@ def invite( invite = self.genInvite( number, str(sess_id), ms, sendtype, branch, call_id ) - self.recvLock.acquire() - self.out.sendto(invite.encode("utf8"), (self.server, self.port)) - debug("Invited") - response = SIPMessage(self.s.recv(8192)) - - while ( - response.status != SIPStatus(401) - and response.status != SIPStatus(100) - and response.status != SIPStatus(180) - ) or response.headers["Call-ID"] != call_id: - if not self.NSD: - break - self.parseMessage(response) + with self.recvLock: + self.out.sendto(invite.encode("utf8"), (self.server, self.port)) + debug("Invited") response = SIPMessage(self.s.recv(8192)) - if response.status == SIPStatus(100) or response.status == SIPStatus( - 180 - ): - self.recvLock.release() - return SIPMessage(invite.encode("utf8")), call_id, sess_id - debug(f"Received Response: {response.summary()}") - ack = self.genAck(response) - self.out.sendto(ack.encode("utf8"), (self.server, self.port)) - debug("Acknowledged") - authhash = self.genAuthorization(response) - nonce = response.authentication["nonce"] - realm = response.authentication["realm"] - auth = ( - f'Authorization: Digest username="{self.username}",realm=' - + f'"{realm}",nonce="{nonce}",uri="sip:{self.server};' - + f'transport=UDP",response="{str(authhash, "utf8")}",' - + "algorithm=MD5\r\n" - ) - - invite = self.genInvite( - number, str(sess_id), ms, sendtype, branch, call_id - ) - invite = invite.replace( - "\r\nContent-Length", f"\r\n{auth}Content-Length" - ) + while ( + response.status != SIPStatus(401) + and response.status != SIPStatus(100) + and response.status != SIPStatus(180) + ) or response.headers["Call-ID"] != call_id: + if not self.NSD: + break + self.parseMessage(response) + response = SIPMessage(self.s.recv(8192)) + + if response.status == SIPStatus( + 100 + ) or response.status == SIPStatus(180): + return SIPMessage(invite.encode("utf8")), call_id, sess_id + debug(f"Received Response: {response.summary()}") + ack = self.genAck(response) + self.out.sendto(ack.encode("utf8"), (self.server, self.port)) + debug("Acknowledged") + authhash = self.genAuthorization(response) + nonce = response.authentication["nonce"] + realm = response.authentication["realm"] + auth = ( + f'Authorization: Digest username="{self.username}",realm=' + + f'"{realm}",nonce="{nonce}",uri="sip:{self.server};' + + f'transport=UDP",response="{str(authhash, "utf8")}",' + + "algorithm=MD5\r\n" + ) - self.out.sendto(invite.encode("utf8"), (self.server, self.port)) + invite = self.genInvite( + number, str(sess_id), ms, sendtype, branch, call_id + ) + invite = invite.replace( + "\r\nContent-Length", f"\r\n{auth}Content-Length" + ) - self.recvLock.release() + self.out.sendto(invite.encode("utf8"), (self.server, self.port)) - return SIPMessage(invite.encode("utf8")), call_id, sess_id + return SIPMessage(invite.encode("utf8")), call_id, sess_id def bye(self, request: SIPMessage) -> None: message = self.genBye(request) @@ -1650,7 +1653,8 @@ def bye(self, request: SIPMessage) -> None: def deregister(self) -> bool: try: - deregistered = self.__deregister() + with self.recvLock: + deregistered = self.__deregister() if not deregistered: debug("DEREGISTERATION FAILED") return False @@ -1660,12 +1664,16 @@ def deregister(self) -> bool: return deregistered except BaseException as e: debug(f"DEREGISTERATION ERROR: {e}") + # TODO: a maximum tries check should be implemented otherwise a + # RecursionError will throw + if isinstance(e, RetryRequiredError): + time.sleep(5) + return self.deregister() if type(e) is OSError: raise return False def __deregister(self) -> bool: - self.recvLock.acquire() self.phone._status = PhoneStatus.DEREGISTERING firstRequest = self.genFirstRequest(deregister=True) self.out.sendto(firstRequest.encode("utf8"), (self.server, self.port)) @@ -1676,7 +1684,6 @@ def __deregister(self) -> bool: if ready[0]: resp = self.s.recv(8192) else: - self.recvLock.release() raise TimeoutError("Deregistering on SIP Server timed out") response = SIPMessage(resp) @@ -1696,7 +1703,6 @@ def __deregister(self) -> bool: # At this point, it's reasonable to assume that # this is caused by invalid credentials. debug("Unauthorized") - self.recvLock.release() raise InvalidAccountInfoError( "Invalid Username or " + "Password for SIP server " @@ -1710,23 +1716,20 @@ def __deregister(self) -> bool: # with new urn:uuid or reply with expire 0 self._handle_bad_request() else: - self.recvLock.release() raise TimeoutError("Deregistering on SIP Server timed out") if response.status == SIPStatus(500): - self.recvLock.release() - time.sleep(5) - return self.deregister() + # We raise so the calling function can sleep and try again + raise RetryRequiredError("Response SIP status of 500") if response.status == SIPStatus.OK: - self.recvLock.release() return True - self.recvLock.release() return False def register(self) -> bool: try: - registered = self.__register() + with self.recvLock: + registered = self.__register() if not registered: debug("REGISTERATION FAILED") self.registerFailures += 1 @@ -1749,6 +1752,9 @@ def register(self) -> bool: self.stop() self.fatalCallback() return False + if isinstance(e, RetryRequiredError): + time.sleep(5) + return self.register() self.__start_register_timer(delay=0) def __start_register_timer(self, delay: Optional[int] = None): @@ -1764,7 +1770,6 @@ def __start_register_timer(self, delay: Optional[int] = None): self.registerThread.start() def __register(self) -> bool: - self.recvLock.acquire() self.phone._status = PhoneStatus.REGISTERING firstRequest = self.genFirstRequest() self.out.sendto(firstRequest.encode("utf8"), (self.server, self.port)) @@ -1775,7 +1780,6 @@ def __register(self) -> bool: if ready[0]: resp = self.s.recv(8192) else: - self.recvLock.release() raise TimeoutError("Registering on SIP Server timed out") response = SIPMessage(resp) @@ -1814,7 +1818,6 @@ def __register(self) -> bool: debug("\nRECEIVED") debug(response.summary()) debug("=" * 50) - self.recvLock.release() raise InvalidAccountInfoError( "Invalid Username or " + "Password for SIP server " @@ -1828,7 +1831,6 @@ def __register(self) -> bool: # with new urn:uuid or reply with expire 0 self._handle_bad_request() else: - self.recvLock.release() raise TimeoutError("Registering on SIP Server timed out") if response.status == SIPStatus(407): @@ -1844,9 +1846,8 @@ def __register(self) -> bool: ]: # Unauthorized if response.status == SIPStatus(500): - self.recvLock.release() - time.sleep(5) - return self.register() + # We raise so the calling function can sleep and try again + raise RetryRequiredError("Response SIP status of 500") else: # TODO: determine if needed here self.parseMessage(response) @@ -1854,7 +1855,6 @@ def __register(self) -> bool: debug(response.summary()) debug(response.raw) - self.recvLock.release() if response.status == SIPStatus.OK: return True else: @@ -1873,16 +1873,17 @@ def _handle_bad_request(self) -> None: def subscribe(self, lastresponse: SIPMessage) -> None: # TODO: check if needed and maybe implement fully - self.recvLock.acquire() - - subRequest = self.genSubscribe(lastresponse) - self.out.sendto(subRequest.encode("utf8"), (self.server, self.port)) - - response = SIPMessage(self.s.recv(8192)) + with self.recvLock: + subRequest = self.genSubscribe(lastresponse) + self.out.sendto( + subRequest.encode("utf8"), (self.server, self.port) + ) - debug(f'Got response to subscribe: {str(response.heading, "utf8")}') + response = SIPMessage(self.s.recv(8192)) - self.recvLock.release() + debug( + f'Got response to subscribe: {str(response.heading, "utf8")}' + ) def trying_timeout_check(self, response: SIPMessage) -> SIPMessage: """ diff --git a/pyVoIP/VoIP/VoIP.py b/pyVoIP/VoIP/VoIP.py index 6174e98..f6132ad 100644 --- a/pyVoIP/VoIP/VoIP.py +++ b/pyVoIP/VoIP/VoIP.py @@ -235,12 +235,11 @@ def __del__(self): self.phone.release_ports(call=self) def dtmf_callback(self, code: str) -> None: - self.dtmfLock.acquire() - bufferloc = self.dtmf.tell() - self.dtmf.seek(0, 2) - self.dtmf.write(code) - self.dtmf.seek(bufferloc, 0) - self.dtmfLock.release() + with self.dtmfLock: + bufferloc = self.dtmf.tell() + self.dtmf.seek(0, 2) + self.dtmf.write(code) + self.dtmf.seek(bufferloc, 0) def getDTMF(self, length=1) -> str: warnings.warn( @@ -252,10 +251,9 @@ def getDTMF(self, length=1) -> str: return self.get_dtmf(length) def get_dtmf(self, length=1) -> str: - self.dtmfLock.acquire() - packet = self.dtmf.read(length) - self.dtmfLock.release() - return packet + with self.dtmfLock: + packet = self.dtmf.read(length) + return packet def genMs(self) -> Dict[int, Dict[int, RTP.PayloadType]]: warnings.warn( @@ -737,9 +735,8 @@ def request_port(self, blocking=True) -> int: return selection def release_ports(self, call: Optional[VoIPCall] = None) -> None: - self.portsLock.acquire() - self._cleanup_dead_calls() - try: + with self.portsLock: + self._cleanup_dead_calls() if isinstance(call, VoIPCall): ports = list(call.assignedPorts.keys()) else: @@ -753,8 +750,6 @@ def release_ports(self, call: Optional[VoIPCall] = None) -> None: for port in ports: self.assignedPorts.remove(port) - finally: - self.portsLock.release() def _cleanup_dead_calls(self) -> None: to_delete = [] diff --git a/pyVoIP/util.py b/pyVoIP/util.py new file mode 100644 index 0000000..fa7c6cb --- /dev/null +++ b/pyVoIP/util.py @@ -0,0 +1,20 @@ +from contextlib import contextmanager +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from socket import socket + from threading import Lock + + +@contextmanager +def acquired_lock_and_unblocked_socket(lock: "Lock", socket: "socket"): + """Alongside an acquired Lock, a corresponding socket will become + non-blocking, and then blocking once the Lock is released. + + Lock will release and socket will become blocking even during exceptions""" + try: + with lock: + socket.setblocking(False) + yield + finally: + socket.setblocking(True) diff --git a/tests/test_util.py b/tests/test_util.py new file mode 100644 index 0000000..74c016c --- /dev/null +++ b/tests/test_util.py @@ -0,0 +1,31 @@ +from pyVoIP.util import acquired_lock_and_unblocked_socket +from threading import Lock +from socket import socket +import pytest + + +def test_acquired_lock_and_unblocked_socket(): + l = Lock() + s = socket() + assert l.locked() is False + assert s.getblocking() is True + with acquired_lock_and_unblocked_socket(l, s): + assert l.locked() is True + assert s.getblocking() is False + assert l.locked() is False + assert s.getblocking() is True + + +def test_acquired_lock_and_unblocked_socket__with_exception(): + l = Lock() + s = socket() + assert l.locked() is False + assert s.getblocking() is True + with pytest.raises(Exception): + with acquired_lock_and_unblocked_socket(l, s): + assert l.locked() is True + assert s.getblocking() is False + raise Exception("Uh oh") + assert False, "Should never execute" + assert l.locked() is False + assert s.getblocking() is True