From e83aea2a46c7cdcd0095c11304e4e58ede220087 Mon Sep 17 00:00:00 2001 From: Siting Ren Date: Wed, 8 Nov 2023 23:09:08 +0800 Subject: [PATCH 1/5] p1 --- vertica_python/datatypes.py | 14 +++--- vertica_python/os_utils.py | 6 +-- vertica_python/vertica/connection.py | 48 +++++++++++-------- vertica_python/vertica/deserializer.py | 64 +++++++++++++------------- vertica_python/vertica/log.py | 5 +- 5 files changed, 76 insertions(+), 61 deletions(-) diff --git a/vertica_python/datatypes.py b/vertica_python/datatypes.py index e7ddfc61..20623105 100644 --- a/vertica_python/datatypes.py +++ b/vertica_python/datatypes.py @@ -297,7 +297,7 @@ def __ne__(self, other): VerticaType.SET_LONGVARBINARY: VerticaType.LONGVARBINARY, } -def getTypeName(data_type_oid, type_modifier): +def getTypeName(data_type_oid: int, type_modifier: int) -> str: """Returns the base type name according to data_type_oid and type_modifier""" if data_type_oid in TYPENAME: return TYPENAME[data_type_oid] @@ -310,11 +310,11 @@ def getTypeName(data_type_oid, type_modifier): else: return "Unknown" -def getComplexElementType(data_type_oid): +def getComplexElementType(data_type_oid: int): """For 1D ARRAY or SET, returns the type of its elements""" return COMPLEX_ELEMENT_TYPE.get(data_type_oid) -def getIntervalRange(data_type_oid, type_modifier): +def getIntervalRange(data_type_oid: int, type_modifier: int): """Extracts an interval's range from the bits set in its type_modifier""" if data_type_oid not in (VerticaType.INTERVAL, VerticaType.INTERVALYM): @@ -361,7 +361,7 @@ def getIntervalRange(data_type_oid, type_modifier): return "Day to Second" -def getIntervalLeadingPrecision(data_type_oid, type_modifier): +def getIntervalLeadingPrecision(data_type_oid: int, type_modifier: int): """ Returns the leading precision for an interval, which is the largest number of digits that can fit in the leading field of the interval. @@ -394,7 +394,7 @@ def getIntervalLeadingPrecision(data_type_oid, type_modifier): raise ValueError("Invalid interval range: {}".format(interval_range)) -def getPrecision(data_type_oid, type_modifier): +def getPrecision(data_type_oid: int, type_modifier: int): """ Returns the precision for the given Vertica type with consideration of the type modifier. @@ -423,7 +423,7 @@ def getPrecision(data_type_oid, type_modifier): return None # None if no meaningful values can be provided -def getScale(data_type_oid, type_modifier): +def getScale(data_type_oid: int, type_modifier: int): """ Returns the scale for the given Vertica type with consideration of the type modifier. @@ -435,7 +435,7 @@ def getScale(data_type_oid, type_modifier): return None # None if no meaningful values can be provided -def getDisplaySize(data_type_oid, type_modifier): +def getDisplaySize(data_type_oid: int, type_modifier: int): """ Returns the column display size for the given Vertica type with consideration of the type modifier. diff --git a/vertica_python/os_utils.py b/vertica_python/os_utils.py index 75bfbd70..14113efa 100644 --- a/vertica_python/os_utils.py +++ b/vertica_python/os_utils.py @@ -19,7 +19,7 @@ import os -def ensure_dir_exists(filepath): +def ensure_dir_exists(filepath: str) -> None: """Ensure that a directory exists If it doesn't exist, try to create it and protect against a race condition @@ -33,7 +33,7 @@ def ensure_dir_exists(filepath): if e.errno != errno.EEXIST: raise -def check_file_readable(filename): +def check_file_readable(filename: str) -> None: """Ensure this is a readable file""" if not os.path.exists(filename): raise OSError('{} does not exist'.format(filename)) @@ -42,7 +42,7 @@ def check_file_readable(filename): elif not os.access(filename, os.R_OK): raise OSError('{} is not readable'.format(filename)) -def check_file_writable(filename): +def check_file_writable(filename: str) -> None: """Ensure this is a writable file. If the file doesn't exist, ensure its directory is writable. """ diff --git a/vertica_python/vertica/connection.py b/vertica_python/vertica/connection.py index 704090f5..0b39070f 100644 --- a/vertica_python/vertica/connection.py +++ b/vertica_python/vertica/connection.py @@ -88,7 +88,7 @@ def connect(**kwargs): return Connection(kwargs) -def parse_dsn(dsn): +def parse_dsn(dsn: str): """Parse connection string into a dictionary of keywords and values. Connection string format: vertica://:@:/?k1=v1&k2=v2&... @@ -107,6 +107,7 @@ def parse_dsn(dsn): } for key, values in parse_qs(url.query, keep_blank_values=True).items(): # Try to get the last non-blank value in the list of values for each key + value = '' for i in reversed(range(len(values))): value = values[i] if value != '': @@ -242,7 +243,7 @@ def peek_host(self): return self.address_deque[0].host -def _generate_session_label(): +def _generate_session_label() -> str: return '{type}-{version}-{id}'.format( type='vertica-python', version=vertica_python.__version__, @@ -251,7 +252,7 @@ def _generate_session_label(): class Connection(object): - def __init__(self, options=None): + def __init__(self, options=None) -> None: # type: (Optional[Dict[str, Any]]) -> None self.parameters = {} self.session_id = None @@ -277,7 +278,7 @@ def __init__(self, options=None): self.options.setdefault('log_level', DEFAULT_LOG_LEVEL) self.options.setdefault('log_path', DEFAULT_LOG_PATH) VerticaLogging.setup_logging(logger_name, self.options['log_path'], - self.options['log_level'], id(self)) + self.options['log_level'], str(id(self))) self.options.setdefault('host', DEFAULT_HOST) self.options.setdefault('port', DEFAULT_PORT) @@ -353,21 +354,24 @@ def __exit__(self, type_, value, traceback): ############################################# # dbapi methods ############################################# - def close(self): + def close(self) -> None: + """Close the connection now.""" self._logger.info('Close the connection') try: self.write(messages.Terminate()) finally: self.close_socket() - def commit(self): + def commit(self) -> None: + """Commit any pending transaction to the database.""" if self.closed(): raise errors.ConnectionError('Connection is closed') cur = self.cursor() cur.execute('COMMIT;') - def rollback(self): + def rollback(self) -> None: + """Roll back to the start of any pending transaction.""" if self.closed(): raise errors.ConnectionError('Connection is closed') @@ -376,6 +380,10 @@ def rollback(self): def cursor(self, cursor_type=None): # type: (Self, Optional[Union[Literal['list', 'dict'], Type[list[Any]], Type[dict[Any, Any]]]]) -> Cursor + """Return the Cursor Object using the connection. + + vertica-python only support one cursor per connection. + """ if self.closed(): raise errors.ConnectionError('Connection is closed') @@ -390,14 +398,14 @@ def cursor(self, cursor_type=None): # non-dbapi methods ############################################# @property - def autocommit(self): - """Read the connection's AUTOCOMMIT setting from cache""" + def autocommit(self) -> bool: + """Read the connection's AUTOCOMMIT setting from cache.""" # For a new session, autocommit is off by default return self.parameters.get('auto_commit', 'off') == 'on' @autocommit.setter - def autocommit(self, value): - """Change the connection's AUTOCOMMIT setting""" + def autocommit(self, value: bool) -> None: + """Change the connection's AUTOCOMMIT setting.""" if self.autocommit is value: return val = 'on' if value else 'off' @@ -405,9 +413,11 @@ def autocommit(self, value): cur.execute('SET SESSION AUTOCOMMIT TO {}'.format(val), use_prepared_statements=False) cur.fetchall() # check for errors and update the cache - def cancel(self): - """Cancel the current database operation. This can be called from a - different thread than the one currently executing a database operation. + def cancel(self) -> None: + """Cancel the current database operation. + + This method can be called from a different thread than the one currently + executing a database operation. """ if self.closed(): raise errors.ConnectionError('Connection is closed') @@ -419,15 +429,17 @@ def cancel(self): self._logger.info('Cancel request issued') - def opened(self): + def opened(self) -> bool: + """Returns True if the connection is opened.""" return (self.socket is not None and self.backend_pid is not None and self.transaction_status is not None) - def closed(self): + def closed(self) -> bool: + """Returns True if the connection is closed.""" return not self.opened() - def __str__(self): + def __str__(self) -> str: safe_options = {key: value for key, value in self.options.items() if key != 'password'} s1 = " None: self._logger.debug("Close connection's socket") try: if self.socket is not None: diff --git a/vertica_python/vertica/deserializer.py b/vertica_python/vertica/deserializer.py index 3450ea9b..a8941c52 100644 --- a/vertica_python/vertica/deserializer.py +++ b/vertica_python/vertica/deserializer.py @@ -64,7 +64,7 @@ def deserializer(data): TZ_RE = re.compile(r"(?ix) ^([-+]) (\d+) (?: : (\d+) )? (?: : (\d+) )? $") SECONDS_PER_DAY = 86400 -def load_bool_binary(val, ctx): +def load_bool_binary(val:bytes, ctx) -> bool: """ Parses binary representation of a BOOLEAN type. :param val: a byte - b'\x01' for True, b'\x00' for False @@ -73,7 +73,7 @@ def load_bool_binary(val, ctx): """ return val == b'\x01' -def load_int8_binary(val, ctx): +def load_int8_binary(val: bytes, ctx) -> int: """ Parses binary representation of a INTEGER type. :param val: bytes - a 64-bit integer. @@ -82,7 +82,7 @@ def load_int8_binary(val, ctx): """ return unpack("!q", val)[0] -def load_float8_binary(val, ctx): +def load_float8_binary(val: bytes, ctx) -> float: """ Parses binary representation of a FLOAT type. :param val: bytes - a float encoded in IEEE-754 format. @@ -91,7 +91,7 @@ def load_float8_binary(val, ctx): """ return unpack("!d", val)[0] -def load_numeric_binary(val, ctx): +def load_numeric_binary(val: bytes, ctx) -> Decimal: """ Parses binary representation of a NUMERIC type. :param val: bytes @@ -106,7 +106,7 @@ def load_numeric_binary(val, ctx): # The numeric value is (unscaledVal * 10^(-scale)) return Decimal(unscaledVal).scaleb(-scale, context=Context(prec=precision)) -def load_varchar_text(val, ctx): +def load_varchar_text(val: bytes, ctx) -> str: """ Parses text/binary representation of a CHAR / VARCHAR / LONG VARCHAR type. :param val: bytes @@ -115,7 +115,7 @@ def load_varchar_text(val, ctx): """ return val.decode('utf-8', ctx['unicode_error']) -def load_date_text(val, ctx): +def load_date_text(val: bytes, ctx) -> date: """ Parses text representation of a DATE type. :param val: bytes @@ -131,7 +131,7 @@ def load_date_text(val, ctx): except ValueError: raise errors.NotSupportedError('Dates after year 9999 are not supported by datetime.date. Got: {0}'.format(s)) -def load_date_binary(val, ctx): +def load_date_binary(val: bytes, ctx) -> date: """ Parses binary representation of a DATE type. :param val: bytes @@ -149,7 +149,7 @@ def load_date_binary(val, ctx): raise errors.NotSupportedError('Dates after year 9999 are not supported by datetime.date. Got: Julian day number {0}'.format(jdn)) return date.fromordinal(days) -def load_time_text(val, ctx): +def load_time_text(val: bytes, ctx) -> time: """ Parses text representation of a TIME type. :param val: bytes @@ -161,7 +161,7 @@ def load_time_text(val, ctx): return datetime.strptime(val, '%H:%M:%S').time() return datetime.strptime(val, '%H:%M:%S.%f').time() -def load_time_binary(val, ctx): +def load_time_binary(val: bytes, ctx): """ Parses binary representation of a TIME type. :param val: bytes @@ -212,7 +212,7 @@ def load_timetz_text(val, ctx): return time(int(hr), int(mi), int(sec), us, tz.tzoffset(None, tz_offset)) -def load_timetz_binary(val, ctx): +def load_timetz_binary(val: bytes, ctx) -> time: """ Parses binary representation of a TIMETZ type. :param val: bytes @@ -222,9 +222,9 @@ def load_timetz_binary(val, ctx): # 8-byte value where # - Upper 40 bits contain the number of microseconds since midnight in the UTC time zone. # - Lower 24 bits contain time zone as the UTC offset in seconds. - val = load_int8_binary(val, ctx) - tz_offset = SECONDS_PER_DAY - (val & 0xffffff) # in seconds - msecs = val >> 24 + v = load_int8_binary(val, ctx) + tz_offset = SECONDS_PER_DAY - (v & 0xffffff) # in seconds + msecs = v >> 24 # shift to given time zone msecs += tz_offset * 1000000 msecs %= SECONDS_PER_DAY * 1000000 @@ -233,7 +233,7 @@ def load_timetz_binary(val, ctx): hour, minute = divmod(msecs, 60) return time(hour, minute, second, fraction, tz.tzoffset(None, tz_offset)) -def load_timestamp_text(val, ctx): +def load_timestamp_text(val: bytes, ctx) -> datetime: """ Parses text representation of a TIMESTAMP type. :param val: bytes @@ -249,7 +249,7 @@ def load_timestamp_text(val, ctx): except ValueError: raise errors.NotSupportedError('Timestamps after year 9999 are not supported by datetime.datetime. Got: {0}'.format(s)) -def load_timestamp_binary(val, ctx): +def load_timestamp_binary(val: bytes, ctx) -> datetime: """ Parses binary representation of a TIMESTAMP type. :param val: bytes @@ -267,7 +267,7 @@ def load_timestamp_binary(val, ctx): else: raise errors.NotSupportedError('Timestamps after year 9999 are not supported by datetime.datetime.') -def load_timestamptz_text(val, ctx): +def load_timestamptz_text(val: bytes, ctx) -> datetime: """ Parses text representation of a TIMESTAMPTZ type. :param val: bytes @@ -287,7 +287,7 @@ def load_timestamptz_text(val, ctx): t = load_timetz_text(dt[1], ctx) return datetime.combine(d, t) -def load_timestamptz_binary(val, ctx): +def load_timestamptz_binary(val: bytes, ctx) -> datetime: """ Parses binary representation of a TIMESTAMPTZ type. :param val: bytes @@ -312,7 +312,7 @@ def load_timestamptz_binary(val, ctx): else: # year might be over 9999 raise errors.NotSupportedError('TimestampTzs after year 9999 are not supported by datetime.datetime.') -def load_interval_text(val, ctx): +def load_interval_text(val: bytes, ctx) -> relativedelta: """ Parses text representation of a INTERVAL day-time type. :param val: bytes @@ -361,7 +361,7 @@ def load_interval_text(val, ctx): return relativedelta(days=parts[0], hours=parts[1], minutes=parts[2], seconds=parts[3], microseconds=parts[4]) -def load_interval_binary(val, ctx): +def load_interval_binary(val: bytes, ctx) -> relativedelta: """ Parses binary representation of a INTERVAL day-time type. :param val: bytes @@ -372,7 +372,7 @@ def load_interval_binary(val, ctx): msecs = load_int8_binary(val, ctx) return relativedelta(microseconds=msecs) -def load_intervalYM_text(val, ctx): +def load_intervalYM_text(val: bytes, ctx) -> relativedelta: """ Parses text representation of a INTERVAL YEAR TO MONTH / INTERVAL YEAR / INTERVAL MONTH type. :param val: bytes @@ -398,7 +398,7 @@ def load_intervalYM_text(val, ctx): else: # Interval Month return relativedelta(months=interval) -def load_intervalYM_binary(val, ctx): +def load_intervalYM_binary(val: bytes, ctx) -> relativedelta: """ Parses binary representation of a INTERVAL YEAR TO MONTH / INTERVAL YEAR / INTERVAL MONTH type. :param val: bytes @@ -409,7 +409,7 @@ def load_intervalYM_binary(val, ctx): months = load_int8_binary(val, ctx) return relativedelta(months=months) -def load_uuid_binary(val, ctx): +def load_uuid_binary(val: bytes, ctx) -> UUID: """ Parses binary representation of a UUID type. :param val: bytes @@ -419,7 +419,7 @@ def load_uuid_binary(val, ctx): # 16-byte value in big-endian order interpreted as UUID return UUID(bytes=bytes(val)) -def load_varbinary_text(s, ctx): +def load_varbinary_text(s: bytes, ctx) -> bytes: """ Parses text representation of a BINARY / VARBINARY / LONG VARBINARY type. :param s: bytes @@ -443,21 +443,21 @@ def load_varbinary_text(s, ctx): buf.append(c) return b''.join(buf) -def load_array_text(val, ctx): +def load_array_text(val: bytes, ctx): """ Parses text/binary representation of an ARRAY type. :param val: bytes :param ctx: dict :return: list """ - val = val.decode('utf-8', ctx['unicode_error']) + v = val.decode('utf-8', ctx['unicode_error']) # Some old servers have a bug of sending ARRAY oid without child metadata if not ctx['complex_types_enabled']: - return val - json_data = json.loads(val) + return v + json_data = json.loads(v) return parse_array(json_data, ctx) -def load_set_text(val, ctx): +def load_set_text(val: bytes, ctx): """ Parses text/binary representation of a SET type. :param val: bytes @@ -485,18 +485,18 @@ def parse_array(json_data, ctx): parsed_array[idx] = parse_json_element(element, child_ctx) return parsed_array -def load_row_text(val, ctx): +def load_row_text(val: bytes, ctx): """ Parses text/binary representation of a ROW type. :param val: bytes :param ctx: dict :return: dict """ - val = val.decode('utf-8', ctx['unicode_error']) + v = val.decode('utf-8', ctx['unicode_error']) # Some old servers have a bug of sending ROW oid without child metadata if not ctx['complex_types_enabled']: - return val - json_data = json.loads(val) + return v + json_data = json.loads(v) return parse_row(json_data, ctx) def parse_row(json_data, ctx): diff --git a/vertica_python/vertica/log.py b/vertica_python/vertica/log.py index 41e29d06..843eeecc 100644 --- a/vertica_python/vertica/log.py +++ b/vertica_python/vertica/log.py @@ -37,12 +37,15 @@ from __future__ import print_function, division, absolute_import import logging +from typing import Union from ..os_utils import ensure_dir_exists class VerticaLogging(object): @classmethod - def setup_logging(cls, logger_name, logfile, log_level=logging.INFO, context=''): + def setup_logging(cls, logger_name: str, logfile: str, + log_level: Union[int, str] = logging.INFO, + context: str = '') -> None: logger = logging.getLogger(logger_name) logger.setLevel(log_level) From 8bf9f8959bcb203d2450b3f222415a76657fc65d Mon Sep 17 00:00:00 2001 From: Siting Ren Date: Tue, 14 Nov 2023 00:02:05 +0800 Subject: [PATCH 2/5] p2 --- vertica_python/vertica/connection.py | 60 ++++++++++--------- vertica_python/vertica/cursor.py | 7 ++- .../backend_messages/load_balance_response.py | 4 +- 3 files changed, 39 insertions(+), 32 deletions(-) diff --git a/vertica_python/vertica/connection.py b/vertica_python/vertica/connection.py index 0b39070f..b821000c 100644 --- a/vertica_python/vertica/connection.py +++ b/vertica_python/vertica/connection.py @@ -34,7 +34,7 @@ # THE SOFTWARE. -from __future__ import print_function, division, absolute_import +from __future__ import print_function, division, absolute_import, annotations import base64 import logging @@ -44,14 +44,14 @@ import uuid import warnings from struct import unpack -from collections import deque, namedtuple +from collections import deque import random # noinspection PyCompatibility,PyUnresolvedReferences from urllib.parse import urlparse, parse_qs from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Any, Dict, Literal, Optional, Type, Union + from typing import Any, Dict, List, Optional, Type, Union, Deque, Tuple, NamedTuple from typing_extensions import Self import vertica_python @@ -82,8 +82,7 @@ warnings.warn(f"Cannot get the login user name: {str(e)}") -def connect(**kwargs): - # type: (Any) -> Connection +def connect(**kwargs: Any) -> Connection: """Opens a new connection to a Vertica database.""" return Connection(kwargs) @@ -135,10 +134,15 @@ def parse_dsn(dsn: str): return result -_AddressEntry = namedtuple('_AddressEntry', ['host', 'resolved', 'data']) +class _AddressEntry(NamedTuple): + host: str + resolved: bool + data: Any class _AddressList(object): - def __init__(self, host, port, backup_nodes, logger): + def __init__(self, host: str, port: Union[int, str], + backup_nodes: List[Union[str, Tuple[str, Union[int, str]]]], + logger: logging.Logger) -> None: """Creates a new deque with the primary host first, followed by any backup hosts""" self._logger = logger @@ -148,7 +152,7 @@ def __init__(self, host, port, backup_nodes, logger): # - when resolved is False, data is port # - when resolved is True, data is the 5-tuple from socket.getaddrinfo # This allows for lazy resolution. Seek peek() for more. - self.address_deque = deque() + self.address_deque: Deque['_AddressEntry'] = deque() # load primary host into address_deque self._append(host, port) @@ -174,7 +178,7 @@ def __init__(self, host, port, backup_nodes, logger): raise TypeError(err_msg) self._logger.debug('Address list: {0}'.format(list(self.address_deque))) - def _append(self, host, port): + def _append(self, host: str, port: Union[int, str]) -> None: if not isinstance(host, str): err_msg = 'Host must be a string: invalid value: {0}'.format(host) self._logger.error(err_msg) @@ -199,10 +203,10 @@ def _append(self, host, port): self.address_deque.append(_AddressEntry(host=host, resolved=False, data=port)) - def push(self, host, port): + def push(self, host: str, port: int) -> None: self.address_deque.appendleft(_AddressEntry(host=host, resolved=False, data=port)) - def pop(self): + def pop(self) -> None: self.address_deque.popleft() def peek(self): @@ -235,8 +239,8 @@ def peek(self): host=host, resolved=True, data=addrinfo)) return None - def peek_host(self): - # returning the leftmost host result + def peek_host(self) -> Optional[str]: + """Return the leftmost host result.""" self._logger.debug('Peek host at address list: {0}'.format(list(self.address_deque))) if len(self.address_deque) == 0: return None @@ -252,8 +256,7 @@ def _generate_session_label() -> str: class Connection(object): - def __init__(self, options=None) -> None: - # type: (Optional[Dict[str, Any]]) -> None + def __init__(self, options: Optional[Dict[str, Any]] = None) -> None: self.parameters = {} self.session_id = None self.backend_pid = None @@ -345,7 +348,6 @@ def __init__(self, options=None) -> None: # supporting `with` statements ############################################# def __enter__(self): - # type: () -> Self return self def __exit__(self, type_, value, traceback): @@ -378,8 +380,8 @@ def rollback(self) -> None: cur = self.cursor() cur.execute('ROLLBACK;') - def cursor(self, cursor_type=None): - # type: (Self, Optional[Union[Literal['list', 'dict'], Type[list[Any]], Type[dict[Any, Any]]]]) -> Cursor + def cursor(self, + cursor_type: Union[None, str, Type[List[Any]], Type[Dict[Any, Any]]] = None) -> Cursor: """Return the Cursor Object using the connection. vertica-python only support one cursor per connection. @@ -496,7 +498,7 @@ def _socket_as_file(self): self.socket_as_file = self._socket().makefile('rb') return self.socket_as_file - def create_socket(self, family): + def create_socket(self, family) -> socket.socket: """Create a TCP socket object""" raw_socket = socket.socket(family, socket.SOCK_STREAM) raw_socket.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) @@ -506,7 +508,7 @@ def create_socket(self, family): raw_socket.settimeout(connection_timeout) return raw_socket - def balance_load(self, raw_socket): + def balance_load(self, raw_socket: socket.socket) -> socket.socket: # Send load balance request and read server response self._logger.debug('=> %s', messages.LoadBalanceRequest()) raw_socket.sendall(messages.LoadBalanceRequest().get_message()) @@ -543,7 +545,9 @@ def balance_load(self, raw_socket): return raw_socket - def enable_ssl(self, raw_socket, ssl_options): + def enable_ssl(self, + raw_socket: socket.socket, + ssl_options: Union[ssl.SSLContext, bool]) -> ssl.SSLSocket: # Send SSL request and read server response self._logger.debug('=> %s', messages.SslRequest()) raw_socket.sendall(messages.SslRequest().get_message()) @@ -574,7 +578,7 @@ def enable_ssl(self, raw_socket, ssl_options): raise errors.SSLNotSupported(err_msg) return raw_socket - def establish_socket_connection(self, address_list): + def establish_socket_connection(self, address_list: _AddressList) -> socket.socket: """Given a list of database node addresses, establish the socket connection to the database server. Return a connected socket object. """ @@ -584,7 +588,7 @@ def establish_socket_connection(self, address_list): # Failover: loop to try all addresses while addrinfo: - (family, socktype, proto, canonname, sockaddr) = addrinfo + (family, _socktype, _proto, _canonname, sockaddr) = addrinfo last_exception = None # _AddressList filters all addrs to AF_INET and AF_INET6, which both @@ -613,7 +617,7 @@ def establish_socket_connection(self, address_list): return raw_socket - def ssl(self): + def ssl(self) -> bool: return self.socket is not None and isinstance(self.socket, ssl.SSLSocket) def write(self, message, vsocket=None): @@ -653,7 +657,7 @@ def reset_connection(self): self.close() self.startup_connection() - def is_asynchronous_message(self, message): + def is_asynchronous_message(self, message) -> bool: # Check if it is an asynchronous response message # Note: ErrorResponse is a subclass of NoticeResponse return (isinstance(message, messages.ParameterStatus) or @@ -674,7 +678,7 @@ def handle_asynchronous_message(self, message): warnings.warn(notice) self._logger.warning(message.error_message()) - def read_string(self): + def read_string(self) -> bytearray: s = bytearray() while True: char = self.read_bytes(1) @@ -747,7 +751,7 @@ def read_expected_message(self, expected_types, error_handler=None): self._logger.error(msg) raise errors.MessageError(msg) - def read_bytes(self, n): + def read_bytes(self, n: int) -> bytes: if n == 1: result = self._socket_as_file().read(1) if not result: @@ -765,7 +769,7 @@ def read_bytes(self, n): to_read -= received return buf - def send_GSS_response_and_receive_challenge(self, response): + def send_GSS_response_and_receive_challenge(self, response): # Send the GSS response data to the vertica server token = base64.b64decode(response) self.write(messages.Password(token, messages.Authentication.GSS)) diff --git a/vertica_python/vertica/cursor.py b/vertica_python/vertica/cursor.py index b5e358e5..a8556c9e 100644 --- a/vertica_python/vertica/cursor.py +++ b/vertica_python/vertica/cursor.py @@ -139,8 +139,11 @@ class Cursor(object): # NOTE: this is used in executemany and is here for pandas compatibility _insert_statement = re.compile(RE_BASIC_INSERT_STAT, re.U | re.I) - def __init__(self, connection, logger, cursor_type=None, unicode_error=None): - # type: (Connection, Logger, Optional[Union[Literal['list', 'dict'], Type[list[Any]], Type[dict[Any, Any]]]], Optional[str]) -> None + def __init__(self, + connection: Connection, + logger: Logger, + cursor_type: Union[None, str, Type[List[Any]], Type[Dict[Any, Any]]] = None, + unicode_error: Optional[str] = None) -> None: self.connection = connection self._logger = logger self.cursor_type = cursor_type diff --git a/vertica_python/vertica/messages/backend_messages/load_balance_response.py b/vertica_python/vertica/messages/backend_messages/load_balance_response.py index 4297b7be..ea3e988d 100644 --- a/vertica_python/vertica/messages/backend_messages/load_balance_response.py +++ b/vertica_python/vertica/messages/backend_messages/load_balance_response.py @@ -48,10 +48,10 @@ def __init__(self, data): self.port = unpacked[0] self.host = unpacked[1].decode('utf-8') - def get_port(self): + def get_port(self) -> int: return self.port - def get_host(self): + def get_host(self) -> str: return self.host def __str__(self): From e1bb4ad7d4a4dd9b003987aa55f580b35dbb1c72 Mon Sep 17 00:00:00 2001 From: Siting Ren Date: Tue, 14 Nov 2023 14:10:30 +0800 Subject: [PATCH 3/5] p3 --- vertica_python/vertica/connection.py | 33 ++++++++++++++++++---------- vertica_python/vertica/cursor.py | 2 +- 2 files changed, 23 insertions(+), 12 deletions(-) diff --git a/vertica_python/vertica/connection.py b/vertica_python/vertica/connection.py index b821000c..e37f8ca0 100644 --- a/vertica_python/vertica/connection.py +++ b/vertica_python/vertica/connection.py @@ -37,22 +37,21 @@ from __future__ import print_function, division, absolute_import, annotations import base64 +import getpass import logging +import random import socket import ssl -import getpass import uuid import warnings -from struct import unpack from collections import deque -import random +from struct import unpack # noinspection PyCompatibility,PyUnresolvedReferences from urllib.parse import urlparse, parse_qs -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, NamedTuple if TYPE_CHECKING: - from typing import Any, Dict, List, Optional, Type, Union, Deque, Tuple, NamedTuple - from typing_extensions import Self + from typing import Any, Dict, List, Optional, Type, Union, Deque, Tuple import vertica_python from .. import errors @@ -87,8 +86,8 @@ def connect(**kwargs: Any) -> Connection: return Connection(kwargs) -def parse_dsn(dsn: str): - """Parse connection string into a dictionary of keywords and values. +def parse_dsn(dsn: str) -> Dict[str, Union[str, int, bool, float]]: + """Parse connection string (DSN) into a dictionary of keywords and values. Connection string format: vertica://:@:/?k1=v1&k2=v2&... """ @@ -97,7 +96,7 @@ def parse_dsn(dsn: str): raise ValueError("Only vertica:// scheme is supported.") # Ignore blank/invalid values - result = {k: v for k, v in ( + result: Dict[str, Union[str, int, bool, float]] = {k: v for k, v in ( ('host', url.hostname), ('port', url.port), ('user', url.username), @@ -257,7 +256,7 @@ def _generate_session_label() -> str: class Connection(object): def __init__(self, options: Optional[Dict[str, Any]] = None) -> None: - self.parameters = {} + self.parameters: Dict[str, Union[str, int]] = {} self.session_id = None self.backend_pid = None self.backend_key = None @@ -385,6 +384,17 @@ def cursor(self, """Return the Cursor Object using the connection. vertica-python only support one cursor per connection. + + Argument cursor_type determines the type of query result rows. + The following cases return each row as a list. E.g. [ [1, 'foo'], [2, 'bar'] ] + - cursor() + - cursor(cursor_type=list) + - cursor(cursor_type='list') + + The following cases return each row as a dict with column names as keys. + E.g. [ {'id': 1, 'value': 'foo'}, {'id': 2, 'value': 'bar'} ] + - cursor(cursor_type=dict) + - cursor(cursor_type='dict') """ if self.closed(): raise errors.ConnectionError('Connection is closed') @@ -499,7 +509,7 @@ def _socket_as_file(self): return self.socket_as_file def create_socket(self, family) -> socket.socket: - """Create a TCP socket object""" + """Create a TCP socket object.""" raw_socket = socket.socket(family, socket.SOCK_STREAM) raw_socket.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) connection_timeout = self.options.get('connection_timeout') @@ -618,6 +628,7 @@ def establish_socket_connection(self, address_list: _AddressList) -> socket.sock return raw_socket def ssl(self) -> bool: + """Returns True if the TCP socket is a SSL socket.""" return self.socket is not None and isinstance(self.socket, ssl.SSLSocket) def write(self, message, vsocket=None): diff --git a/vertica_python/vertica/cursor.py b/vertica_python/vertica/cursor.py index a8556c9e..36e0d769 100644 --- a/vertica_python/vertica/cursor.py +++ b/vertica_python/vertica/cursor.py @@ -34,7 +34,7 @@ # THE SOFTWARE. -from __future__ import print_function, division, absolute_import +from __future__ import print_function, division, absolute_import, annotations import datetime import glob From 2d3c3b2b5ed0dfef91b6e054d9109a721c1134ff Mon Sep 17 00:00:00 2001 From: Siting Ren Date: Fri, 24 Nov 2023 14:42:12 +0800 Subject: [PATCH 4/5] format --- vertica_python/vertica/connection.py | 16 +++++++++------- vertica_python/vertica/cursor.py | 13 ++++++------- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/vertica_python/vertica/connection.py b/vertica_python/vertica/connection.py index e37f8ca0..24ab5840 100644 --- a/vertica_python/vertica/connection.py +++ b/vertica_python/vertica/connection.py @@ -133,11 +133,13 @@ def parse_dsn(dsn: str) -> Dict[str, Union[str, int, bool, float]]: return result + class _AddressEntry(NamedTuple): host: str resolved: bool data: Any + class _AddressList(object): def __init__(self, host: str, port: Union[int, str], backup_nodes: List[Union[str, Tuple[str, Union[int, str]]]], @@ -631,7 +633,7 @@ def ssl(self) -> bool: """Returns True if the TCP socket is a SSL socket.""" return self.socket is not None and isinstance(self.socket, ssl.SSLSocket) - def write(self, message, vsocket=None): + def write(self, message, vsocket=None) -> None: if not isinstance(message, FrontendMessage): raise TypeError("invalid message: ({0})".format(message)) if vsocket is None: @@ -639,7 +641,7 @@ def write(self, message, vsocket=None): self._logger.debug('=> %s', message) try: for data in message.fetch_message(): - size = 8192 # Max msg size, consistent with how the server works + size = 8192 # Max msg size, consistent with how the server works pos = 0 while pos < len(data): sent = vsocket.send(data[pos : pos + size]) @@ -664,7 +666,7 @@ def close_socket(self) -> None: finally: self.reset_values() - def reset_connection(self): + def reset_connection(self) -> None: self.close() self.startup_connection() @@ -675,13 +677,13 @@ def is_asynchronous_message(self, message) -> bool: (isinstance(message, messages.NoticeResponse) and not isinstance(message, messages.ErrorResponse))) - def handle_asynchronous_message(self, message): + def handle_asynchronous_message(self, message) -> None: if isinstance(message, messages.ParameterStatus): if message.name == 'protocol_version': message.value = int(message.value) self.parameters[message.name] = message.value elif (isinstance(message, messages.NoticeResponse) and - not isinstance(message, messages.ErrorResponse)): + not isinstance(message, messages.ErrorResponse)): if getattr(self, 'notice_handler', None) is not None: self.notice_handler(message) else: @@ -794,7 +796,7 @@ def send_GSS_response_and_receive_challenge(self, response): raise errors.MessageError(msg) return message.auth_data - def make_GSS_authentication(self): + def make_GSS_authentication(self) -> None: try: import kerberos except ImportError as e: @@ -850,7 +852,7 @@ def make_GSS_authentication(self): self._logger.error(msg) raise errors.KerberosError(msg) - def startup_connection(self): + def startup_connection(self) -> None: user = self.options['user'] database = self.options['database'] session_label = self.options['session_label'] diff --git a/vertica_python/vertica/cursor.py b/vertica_python/vertica/cursor.py index 36e0d769..808cb03f 100644 --- a/vertica_python/vertica/cursor.py +++ b/vertica_python/vertica/cursor.py @@ -60,7 +60,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import IO, Any, AnyStr, Callable, Dict, Generator, List, Literal, Optional, Sequence, Tuple, Type, TypeVar, Union + from typing import IO, Any, AnyStr, Callable, Dict, Generator, List, Optional, Sequence, Tuple, Type, TypeVar, Union from typing_extensions import Self from .connection import Connection from logging import Logger @@ -432,8 +432,7 @@ def setoutputsize(self, size, column=None): ############################################# # non-dbapi methods ############################################# - def closed(self): - # type: () -> bool + def closed(self) -> bool: return self._closed or self.connection.closed() def cancel(self): @@ -453,12 +452,12 @@ def iterate(self): def copy(self, sql, data, **kwargs): # type: (str, IO[AnyStr], Any) -> None """ - EXAMPLE: + ``` >> with open("/tmp/file.csv", "rb") as fs: >> cursor.copy("COPY table(field1,field2) FROM STDIN DELIMITER ',' ENCLOSED BY ''''", >> fs, buffer_size=65536) - + ``` """ sql = as_text(sql) @@ -890,8 +889,8 @@ def _send_copy_file_data(self): self._send_copy_data(f, self.buffer_size) self.connection.write(messages.EndOfBatchRequest()) - def _read_copy_data_response(self, is_stdin_copy=False): - """Return True if the server wants us to load more data, false if we are done""" + def _read_copy_data_response(self, is_stdin_copy: bool = False): + """Returns True if the server wants us to load more data, False if we are done.""" self._message = self.connection.read_expected_message(END_OF_BATCH_RESPONSES) # Check for rejections during this load while isinstance(self._message, messages.WriteFile): From 39f6614e3884e1a9ea209397b2b84726111ba258 Mon Sep 17 00:00:00 2001 From: Siting Ren Date: Fri, 24 Nov 2023 15:05:40 +0800 Subject: [PATCH 5/5] p4 --- vertica_python/vertica/deserializer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vertica_python/vertica/deserializer.py b/vertica_python/vertica/deserializer.py index a8941c52..2d16e2fd 100644 --- a/vertica_python/vertica/deserializer.py +++ b/vertica_python/vertica/deserializer.py @@ -161,7 +161,7 @@ def load_time_text(val: bytes, ctx) -> time: return datetime.strptime(val, '%H:%M:%S').time() return datetime.strptime(val, '%H:%M:%S.%f').time() -def load_time_binary(val: bytes, ctx): +def load_time_binary(val: bytes, ctx) -> time: """ Parses binary representation of a TIME type. :param val: bytes @@ -180,7 +180,7 @@ def load_time_binary(val: bytes, ctx): except ValueError: raise errors.NotSupportedError("Time not supported by datetime.time. Got: hour={}".format(hour)) -def load_timetz_text(val, ctx): +def load_timetz_text(val: bytes, ctx) -> time: """ Parses text representation of a TIMETZ type. :param val: bytes