diff --git a/pyVoIP/sock/sock.py b/pyVoIP/sock/sock.py index 8937c90..d682c78 100644 --- a/pyVoIP/sock/sock.py +++ b/pyVoIP/sock/sock.py @@ -214,7 +214,7 @@ def __init__( conn.execute( """CREATE TABLE "listening" ( "call_id" TEXT NOT NULL, - "local_tag" TEXT NOT NULL, + "local_tag" TEXT, "remote_tag" TEXT, "connection" INTEGER NOT NULL UNIQUE, PRIMARY KEY("call_id", "local_tag", "remote_tag") @@ -231,7 +231,8 @@ def __init__( def gen_bad_request( self, connection=None, message=None, error=None, received=None ) -> bytes: - body = f"{error}{received}" + body = f"{error}" + body += f"{received}" bad_request = "SIP/2.0 400 Malformed Request\r\n" bad_request += ( f"Via: SIP/2.0/{self.mode} {self.bind_ip}:{self.bind_port}\r\n" @@ -250,58 +251,29 @@ def __get_connection( local_tag, remote_tag = self.determine_tags(message) call_id = message.headers["Call-ID"] conn = self.buffer.cursor() - result = conn.execute( - """SELECT "connection" FROM "listening" WHERE - "call_id" = ? - AND "local_tag" = ? - AND "remote_tag" = ?""", - (call_id, local_tag, remote_tag), - ) + sql = 'SELECT "connection" FROM "listening" WHERE "call_id" IS ?' + sql += ' AND "local_tag" IS ? AND "remote_tag" IS ?' + result = conn.execute(sql, (call_id, local_tag, remote_tag)) rows = result.fetchall() if rows: conn.close() return self.conns[rows[0][0]] + debug("New Connection Started") # If we didn't find one lets look for something that doesn't have - # a remote tag - result = conn.execute( - """SELECT "connection" FROM "listening" WHERE - "call_id" = ? - AND "local_tag" = ? - AND "remote_tag" is NULL""", - (call_id, local_tag), - ) + # one of the tags + sql = 'SELECT "connection" FROM "listening" WHERE "call_id" = ?' + sql += ' AND ("local_tag" IS NULL OR "local_tag" = ?)' + sql += ' AND ("remote_tag" IS NULL OR "remote_tag" = ?)' + result = conn.execute(sql, (call_id, local_tag, remote_tag)) rows = result.fetchall() if rows: - conn.execute( - """UPDATE "listening" SET - "remote_tag" = ? WHERE - "call_id" = ? - AND "local_tag" = ? - AND "remote_tag" is NULL""", - (remote_tag, call_id, local_tag), - ) + if local_tag and remote_tag: + sql = 'UPDATE "listening" SET "remote_tag" = ?, ' + sql += '"local_tag" = ? WHERE "connection" = ?' + conn.execute(sql, (remote_tag, local_tag, rows[0][0])) conn.close() return self.conns[rows[0][0]] - # If we still didn't find one, maybe we got the local and remote wrong? - result = conn.execute( - """SELECT "connection" FROM "listening" WHERE - "call_id" = ? - AND "local_tag" = ? - AND "remote_tag" is NULL""", - (call_id, remote_tag), - ) - rows = result.fetchall() conn.close() - if rows: - conn.execute( - """UPDATE "listening" SET - "remote_tag" = ? WHERE - "call_id" = ? - AND "local_tag" = ? - AND "remote_tag" is NULL""", - (local_tag, call_id, remote_tag), - ) - return self.conns[rows[0][0]] return None def __register_connection(self, connection: VoIPConnection) -> None: @@ -349,10 +321,11 @@ def deregister_connection(self, connection: VoIPConnection) -> None: debug(self.get_database_dump()) try: conn = self.buffer.cursor() + sql = 'SELECT "connection" FROM "listening" WHERE "call_id" = ?' + sql += ' AND ("local_tag" IS NULL OR "local_tag" = ?)' + sql += ' AND ("remote_tag" IS NULL OR "remote_tag" = ?)' result = conn.execute( - """SELECT "connection" FROM "listening" - WHERE "call_id" = ? AND "local_tag" = ? - AND "remote_tag" = ?""", + sql, ( connection.call_id, connection.local_tag, @@ -417,60 +390,64 @@ def bind(self, addr: Tuple[str, int]) -> None: def _listen(self, backlog=0) -> None: return self.s.listen(backlog) - def run(self) -> None: - self.bind((self.bind_ip, self.bind_port)) - if self.mode != TransportMode.UDP: - self._listen() + def _tcp_tls_run(self) -> None: + self._listen() while not self.SD: - if self.mode == TransportMode.UDP: - try: - data = self.s.recv(8192) - except OSError: - continue - try: - message = SIPMessage(data) - except SIPParseError: - continue - debug("\n\nReceived UDP Message:") - debug(message.summary()) - else: - try: - conn, addr = self.s.accept() - except OSError: - continue - debug(f"Received new {self.mode} connection from {addr}.") - data = conn.recv(8192) - try: - message = SIPMessage(data) - except SIPParseError: - continue - debug("\n\nReceived SIP Message:") - debug(message.summary()) - - if not self.__connection_exists(message): - if self.mode == TransportMode.UDP: - self.__register_connection( - VoIPConnection(self, None, message) - ) - else: - self.__register_connection( - VoIPConnection(self, conn, message) - ) + try: + conn, addr = self.s.accept() + except OSError: + continue + debug(f"Received new {self.mode} connection from {addr}.") + data = conn.recv(8192) + try: + message = SIPMessage(data) + except SIPParseError: + continue + debug("\n\nReceived SIP Message:") + debug(message.summary()) + self._handle_incoming_message(conn, message) - call_id = message.headers["Call-ID"] - local_tag, remote_tag = self.determine_tags(message) - raw_message = data.decode("utf8") - conn = self.buffer.cursor() - conn.execute( - "INSERT INTO msgs (call_id, local_tag, remote_tag, msg) " - + "VALUES (?, ?, ?, ?)", - (call_id, local_tag, remote_tag, raw_message), - ) + def _udp_run(self) -> None: + while not self.SD: try: - self.buffer.commit() - except sqlite3.OperationalError: - pass - conn.close() + data = self.s.recv(8192) + except OSError: + continue + try: + message = SIPMessage(data) + except SIPParseError: + continue + debug("\n\nReceived UDP Message:") + debug(message.summary()) + self._handle_incoming_message(None, message) + + def _handle_incoming_message( + self, conn: Optional[SOCKETS], message: SIPMessage + ): + if not self.__connection_exists(message): + self.__register_connection(VoIPConnection(self, conn, message)) + + call_id = message.headers["Call-ID"] + local_tag, remote_tag = self.determine_tags(message) + raw_message = message.raw.decode("utf8") + conn = self.buffer.cursor() + conn.execute( + "INSERT INTO msgs (call_id, local_tag, remote_tag, msg) " + + "VALUES (?, ?, ?, ?)", + (call_id, local_tag, remote_tag, raw_message), + ) + try: + self.buffer.commit() + except sqlite3.OperationalError: + pass + conn.close() + + def run(self) -> None: + self.bind((self.bind_ip, self.bind_port)) + if self.mode == TransportMode.UDP: + self._udp_run() + else: + self._tcp_tls_run() def close(self) -> None: self.SD = True