diff --git a/makefile b/makefile index e150de6..ef8b6ac 100644 --- a/makefile +++ b/makefile @@ -1,17 +1,17 @@ -IMAGENAME = zenoss/build-tools -VERSION = 0.0.14 +IMAGENAME = zenoss/zenpackbuild +VERSION = ubuntu2204-7 TAG = $(IMAGENAME):$(VERSION) UID := $(shell id -u) GID := $(shell id -g) -DOCKER_COMMAND = docker run --rm -v $(PWD):/mnt -w /mnt -u $(UID):$(GID) $(TAG) +DOCKER_COMMAND = docker run --rm -v $(PWD):/mnt -w /mnt $(TAG) .DEFAULT_GOAL := build .PHONY: bdist bdist: - @$(DOCKER_COMMAND) bash -c "python setup.py bdist_wheel" + $(DOCKER_COMMAND) bash -c "python setup.py bdist_wheel" .PHONY: sdist sdist: @@ -25,11 +25,14 @@ clean: rm -rf *.pyc dist build pynetsnmp.egg-info .PHONY: test -HOST ?= 127.0.0.1 test: - docker run --rm -v $(PWD):/mnt -w /mnt --user 0 $(TAG) \ - bash -c "python setup.py bdist_wheel \ - && pip install dist/pynetsnmp*py2-none-any.whl ipaddr Twisted==20.3.0 \ - && cd test \ - && python test_runner.py --host $(HOST) \ - && chown -R $(UID):$(GID) /mnt" ; + @$(DOCKER_COMMAND) bash -c "pip --no-python-version-warning install -q .; cd tests; python -m unittest discover" + +# HOST ?= 127.0.0.1 +# test: +# docker run --rm -v $(PWD):/mnt -w /mnt $(TAG) \ +# bash -c "python setup.py bdist_wheel \ +# && pip install dist/pynetsnmp*py2-none-any.whl ipaddr Twisted==20.3.0 \ +# && cd test \ +# && python test_runner.py --host $(HOST) \ +# && chown -R $(UID):$(GID) /mnt" ; diff --git a/pkg b/pkg deleted file mode 100755 index 6471d85..0000000 --- a/pkg +++ /dev/null @@ -1,37 +0,0 @@ -#! /bin/sh -# -# Script used by the Zenoss Dev team (the authors of pynetsnmp) to -# release and deploy updated versions. First, go edit version.py -# and update the VERSION value. -# -# Run this script. This will generate a tarball in the parent -# directory (..). Move this script to your zenoss inst/externallibs -# directory and remove the old tarball and add the new one: -# -# $ svn remove pynetsnmp-OLDVERSION.tar.gz -# $ svn add pynetsnmp-NEWVERSION.tar.gz -# -# Then you are set to re-release Zenoss. -# - -quit() { - echo $@ - exit 1 -} -PACKAGE=pynetsnmp -VERSION=`python -c 'import version; print version.VERSION'` -VPACKAGE=$PACKAGE-$VERSION -SVN=http://dev.zenoss.org/svnint -SVNTRUNK=$SVN/trunk/core/$PACKAGE -SVNTAG=$SVN/tags/core/$VPACKAGE -svn cp -m"making release $VERSION" $SVNTRUNK $SVNTAG || quit cannot create tag -svn export $SVNTAG /tmp/$VPACKAGE || quit cannot create export tree -OLD=`pwd` -( - cd /tmp - tar -czvf $OLD/../$VPACKAGE.tar.gz $VPACKAGE - rm -rf $VPACKAGE -) || quit cannot create tarball -echo "Remember to move ../$VPACKAGE.tar.gz the Zenoss " -echo "inst/externallibs directory and check it in." -exit 0 diff --git a/pynetsnmp/CONSTANTS.py b/pynetsnmp/CONSTANTS.py index 61d699f..2828148 100644 --- a/pynetsnmp/CONSTANTS.py +++ b/pynetsnmp/CONSTANTS.py @@ -1,5 +1,5 @@ +NULL = 0 USM_LENGTH_OID_TRANSFORM = 10 -NULL = None MAX_CALLBACK_IDS = 2 MAX_CALLBACK_SUBIDS = 16 SNMP_CALLBACK_LIBRARY = 0 @@ -306,7 +306,8 @@ NETSNMP_CALLBACK_OP_SEND_FAILED = 3 NETSNMP_CALLBACK_OP_CONNECT = 4 NETSNMP_CALLBACK_OP_DISCONNECT = 5 -snmp_init_statistics = () +NETSNMP_CALLBACK_OP_RESEND = 6 +NETSNMP_CALLBACK_OP_SEC_ERROR = 7 STAT_SNMPUNKNOWNSECURITYMODELS = 0 STAT_SNMPINVALIDMSGS = 1 STAT_SNMPUNKNOWNPDUHANDLERS = 2 @@ -377,7 +378,6 @@ MAX_STATS = NETSNMP_STAT_MAX_STATS COMMUNITY_MAX_LEN = 256 SPRINT_MAX_LEN = 2560 -NULL = 0 TRUE = 1 FALSE = 0 READ = 1 diff --git a/pynetsnmp/SnmpSession.py b/pynetsnmp/SnmpSession.py index 1e5b118..f9d969d 100644 --- a/pynetsnmp/SnmpSession.py +++ b/pynetsnmp/SnmpSession.py @@ -1,7 +1,7 @@ -from __future__ import absolute_import - """Backwards compatible API for SnmpSession""" +from __future__ import absolute_import + from . import netsnmp diff --git a/pynetsnmp/conversions.py b/pynetsnmp/conversions.py index beac056..0398a01 100644 --- a/pynetsnmp/conversions.py +++ b/pynetsnmp/conversions.py @@ -1,11 +1,36 @@ from __future__ import absolute_import +from ipaddr import IPAddress + def asOidStr(oid): """converts an oid int sequence to an oid string""" - return "." + ".".join([str(x) for x in oid]) + return "." + ".".join(str(x) for x in oid) def asOid(oidStr): """converts an OID string into a tuple of integers""" - return tuple([int(x) for x in oidStr.strip(".").split(".")]) + return tuple(int(x) for x in oidStr.strip(".").split(".")) + + +def asAgent(ip, port): + """take a google ipaddr object and port number and produce a net-snmp + agent specification (see the snmpcmd manpage)""" + ip, interface = ip.split("%") if "%" in ip else (ip, None) + address = IPAddress(ip) + + if address.version == 4: + return "udp:{}:{}".format(address.compressed, port) + + if address.version == 6: + if address.is_link_local: + if interface is None: + raise RuntimeError( + "Cannot create agent specification from link local " + "IPv6 address without an interface" + ) + else: + return "udp6:[{}%{}]:{}".format( + address.compressed, interface, port + ) + return "udp6:[{}]:{}".format(address.compressed, port) diff --git a/pynetsnmp/errors.py b/pynetsnmp/errors.py new file mode 100644 index 0000000..6cc73e2 --- /dev/null +++ b/pynetsnmp/errors.py @@ -0,0 +1,61 @@ +from __future__ import absolute_import + +from . import oids + + +class SnmpTimeoutError(Exception): + pass + + +class ArgumentParseError(Exception): + pass + + +class TransportError(Exception): + pass + + +class SnmpNameError(Exception): + def __init__(self, oid): + Exception.__init__(self, "Bad Name", oid) + + +class SnmpError(Exception): + def __init__(self, message, *args, **kwargs): + self.message = message + + def __str__(self): + return self.message + + def __repr__(self): + return self.message + + +class SnmpUsmError(SnmpError): + pass + + +class SnmpUsmStatsError(SnmpUsmError): + def __init__(self, mesg, oid): + super(SnmpUsmStatsError, self).__init__(mesg) + self.oid = oid + + +_stats_oid_error_map = { + oids.WrongDigest: SnmpUsmStatsError( + "unexpected authentication digest", oids.WrongDigest + ), + oids.UnknownUserName: SnmpUsmStatsError( + "unknown user", oids.UnknownUserName + ), + oids.UnknownSecurityLevel: SnmpUsmStatsError( + "unknown or unavailable security level", oids.UnknownSecurityLevel + ), + oids.DecryptionError: SnmpUsmStatsError( + "privacy decryption error", oids.DecryptionError + ), +} + + +def get_stats_error(oid): + return _stats_oid_error_map.get(oid) diff --git a/pynetsnmp/netsnmp.py b/pynetsnmp/netsnmp.py index afee683..f893ae5 100644 --- a/pynetsnmp/netsnmp.py +++ b/pynetsnmp/netsnmp.py @@ -12,7 +12,6 @@ Structure, Union, byref, - c_byte, c_char, c_char_p, c_double, @@ -60,6 +59,7 @@ MAX_OID_LEN, NETSNMP_CALLBACK_OP_RECEIVED_MESSAGE, NETSNMP_CALLBACK_OP_TIMED_OUT, + NETSNMP_CALLBACK_OP_SEC_ERROR, NETSNMP_DS_LIB_APPTYPE, NETSNMP_DS_LIBRARY_ID, NETSNMP_LOGHANDLER_CALLBACK, @@ -87,6 +87,7 @@ USM_AUTH_KU_LEN, USM_PRIV_KU_LEN, ) +from .errors import ArgumentParseError, SnmpTimeoutError def _getLogger(name): @@ -97,12 +98,12 @@ def _getLogger(name): find_library_orig = find_library def find_library(name): - for name in [ + for filename in [ "/usr/lib/lib%s.so" % name, "/usr/local/lib/lib%s.so" % name, ]: - if os.path.exists(name): - return name + if os.path.exists(filename): + return filename return find_library_orig(name) @@ -122,8 +123,13 @@ def find_library(name): return find_library_orig(name) -c_int_p = c_void_p -authenticator = CFUNCTYPE(c_char_p, c_int_p, c_char_p, c_int) +oid = c_long +size_t = c_size_t +u_char = c_ubyte +u_char_p = POINTER(c_ubyte) +u_int = c_uint +u_long = c_ulong +u_short = c_ushort try: # needed by newer netsnmp's @@ -131,18 +137,17 @@ def find_library(name): except Exception: import warnings - warnings.warn("Unable to load crypto library") + warnings.warn("Unable to load crypto library", stacklevel=1) lib = CDLL(find_library("netsnmp"), RTLD_GLOBAL) lib.netsnmp_get_version.restype = c_char_p -oid = c_long -u_long = c_ulong -u_short = c_ushort -u_char_p = c_char_p -u_int = c_uint -size_t = c_size_t -u_char = c_byte +version = lib.netsnmp_get_version() +float_version = float(".".join(version.split(".")[:2])) +_netsnmp_str_version = tuple(str(v) for v in version.split(".")) + +if float_version < 5.099: + raise ImportError("netsnmp version 5.1 or greater is required") class netsnmp_session(Structure): @@ -173,7 +178,13 @@ class netsnmp_trap_stats(Structure): ] -# include/net-snmp/types.h -> int (*netsnmp_callback) (int, netsnmp_session *, int, netsnmp_pdu *, void *); +authenticator = CFUNCTYPE( + u_char_p, u_char_p, POINTER(c_size_t), u_char_p, c_size_t +) + + +# include/net-snmp/types.h +# int (*netsnmp_callback) (int, netsnmp_session *, int, netsnmp_pdu *, void *); # the first argument is the return type in CFUNCTYPE notation. netsnmp_callback = CFUNCTYPE( c_int, @@ -187,9 +198,6 @@ class netsnmp_trap_stats(Structure): # int (*proc)(int, char * const *, int) arg_parse_proc = CFUNCTYPE(c_int, POINTER(c_char_p), c_int) -version = lib.netsnmp_get_version() -float_version = float(".".join(version.split(".")[:2])) -_netsnmp_str_version = tuple(str(v) for v in version.split(".")) localname = [] paramName = [] transportConfig = [] @@ -203,8 +211,6 @@ class netsnmp_trap_stats(Structure): identifier = [] fGetTaddr = [] -if float_version < 5.099: - raise ImportError("netsnmp version 5.1 or greater is required") if float_version > 5.199: localname = [("localname", c_char_p)] if float_version > 5.299: @@ -224,11 +230,13 @@ class netsnmp_container_s(Structure): transportConfig = [ ("transport_configuration", POINTER(netsnmp_container_s)) ] -if _netsnmp_str_version >= ('5','8'): - # Version >= 5.8 broke binary compatibility, adding the trap_stats member to the netsnmp_session struct - trapStats = [('trap_stats', POINTER(netsnmp_trap_stats))] - # Version >= 5.8 broke binary compatibility, adding the msgMaxSize member to the snmp_pdu struct - msgMaxSize = [('msgMaxSize', c_long)] +if _netsnmp_str_version >= ("5", "8"): + # Version >= 5.8 broke binary compatibility, adding the trap_stats + # member to the netsnmp_session struct + trapStats = [("trap_stats", POINTER(netsnmp_trap_stats))] + # Version >= 5.8 broke binary compatibility, adding the msgMaxSize + # member to the snmp_pdu struct + msgMaxSize = [("msgMaxSize", c_long)] baseTransport = [("base_transport", POINTER(netsnmp_transport))] fOpen = [("f_open", c_void_p)] fConfig = [("f_config", c_void_p)] @@ -236,7 +244,8 @@ class netsnmp_container_s(Structure): fSetupSession = [("f_setup_session", c_void_p)] identifier = [("identifier", POINTER(u_char_p))] fGetTaddr = [("f_get_taddr", c_void_p)] - # Version >= 5.8 broke binary compatibility, doubling the size of these constants used for struct sizes + # Version >= 5.8 broke binary compatibility, doubling the size of these + # constants used for struct sizes USM_AUTH_KU_LEN = 64 USM_PRIV_KU_LEN = 64 @@ -262,7 +271,7 @@ class netsnmp_container_s(Structure): ("subsession", POINTER(netsnmp_session)), ("next", POINTER(netsnmp_session)), ("peername", c_char_p), - ("remote_port", u_short), + ("remote_port", u_short), # deprecated ] + localname + [ @@ -296,14 +305,15 @@ class netsnmp_container_s(Structure): ("securityAuthLocalKeyLen", c_size_t), ("securityPrivProto", POINTER(oid)), ("securityPrivProtoLen", c_size_t), - ("securityPrivKey", c_char * USM_PRIV_KU_LEN), + ("securityPrivKey", u_char * USM_PRIV_KU_LEN), ("securityPrivKeyLen", c_size_t), ("securityPrivLocalKey", c_char_p), ("securityPrivLocalKeyLen", c_size_t), ("securityModel", c_int), ("securityLevel", c_int), ] - + paramName + trapStats + + paramName + + trapStats + [ ("securityInfo", c_void_p), ] @@ -323,6 +333,7 @@ class counter64(Structure): ("low", c_ulong), ] + # include/net-snmp/types.h class netsnmp_vardata(Union): _fields_ = [ @@ -339,6 +350,7 @@ class netsnmp_vardata(Union): class netsnmp_variable_list(Structure): pass + # include/net-snmp/types.h netsnmp_variable_list._fields_ = [ ("next_variable", POINTER(netsnmp_variable_list)), @@ -354,45 +366,49 @@ class netsnmp_variable_list(Structure): ("index", c_int), ] # include/net-snmp/types.h -netsnmp_pdu._fields_ = [ - ("version", c_long), - ("command", c_int), - ("reqid", c_long), - ("msgid", c_long), - ("transid", c_long), - ("sessid", c_long), - ("errstat", c_long), - ("errindex", c_long), - ("time", c_ulong), - ("flags", c_ulong), - ("securityModel", c_int), - ("securityLevel", c_int), - ("msgParseModel", c_int), - ] + msgMaxSize + [ - ("transport_data", c_void_p), - ("transport_data_length", c_int), - ("tDomain", POINTER(oid)), - ("tDomainLen", c_size_t), - ("variables", POINTER(netsnmp_variable_list)), - ("community", c_char_p), - ("community_len", c_size_t), - ("enterprise", POINTER(oid)), - ("enterprise_length", c_size_t), - ("trap_type", c_long), - ("specific_type", c_long), - ("agent_addr", c_ubyte * 4), - ("contextEngineID", c_char_p), - ("contextEngineIDLen", c_size_t), - ("contextName", c_char_p), - ("contextNameLen", c_size_t), - ("securityEngineID", c_char_p), - ("securityEngineIDLen", c_size_t), - ("securityName", c_char_p), - ("securityNameLen", c_size_t), - ("priority", c_int), - ("range_subid", c_int), - ("securityStateRef", c_void_p), -] +netsnmp_pdu._fields_ = ( + [ + ("version", c_long), + ("command", c_int), + ("reqid", c_long), + ("msgid", c_long), + ("transid", c_long), + ("sessid", c_long), + ("errstat", c_long), + ("errindex", c_long), + ("time", c_ulong), + ("flags", c_ulong), + ("securityModel", c_int), + ("securityLevel", c_int), + ("msgParseModel", c_int), + ] + + msgMaxSize + + [ + ("transport_data", c_void_p), + ("transport_data_length", c_int), + ("tDomain", POINTER(oid)), + ("tDomainLen", c_size_t), + ("variables", POINTER(netsnmp_variable_list)), + ("community", c_char_p), + ("community_len", c_size_t), + ("enterprise", POINTER(oid)), + ("enterprise_length", c_size_t), + ("trap_type", c_long), + ("specific_type", c_long), + ("agent_addr", c_ubyte * 4), + ("contextEngineID", c_char_p), + ("contextEngineIDLen", c_size_t), + ("contextName", c_char_p), + ("contextNameLen", c_size_t), + ("securityEngineID", c_char_p), + ("securityEngineIDLen", c_size_t), + ("securityName", c_char_p), + ("securityNameLen", c_size_t), + ("priority", c_int), + ("range_subid", c_int), + ("securityStateRef", c_void_p), + ] +) netsnmp_pdu_p = POINTER(netsnmp_pdu) @@ -404,7 +420,9 @@ class netsnmp_log_message(Structure): netsnmp_log_message_p = POINTER(netsnmp_log_message) -# callback.h typedef int (SNMPCallback) (int majorID, int minorID, void *serverarg, void *clientarg); +# callback.h +# typedef int (SNMPCallback) ( +# int majorID, int minorID, void *serverarg, void *clientarg); log_callback = CFUNCTYPE(c_int, c_int, netsnmp_log_message_p, c_void_p) # include/net-snmp/library/snmp_logging.h @@ -423,20 +441,23 @@ class netsnmp_log_message(Structure): LOG_DEBUG: logging.DEBUG, } + # snmplib/snmp_logging.c -> free(logh); -# include/net-snmp/output_api.h -> int snmp_log( int priority, const char *format, ...) +# include/net-snmp/output_api.h +# int snmp_log(int priority, const char *format, ...); # in net-snmp -> snmp_log(LOG_ERR|WARNING|INFO|DEBUG, msg) def netsnmp_logger(a, b, msg): msg = cast(msg, netsnmp_log_message_p) priority = PRIORITY_MAP.get(msg.contents.priority, logging.DEBUG) - _getLogger("netsnmp").log(priority, str(msg.contents.msg).strip()) + _getLogger("libnetsnmp").log(priority, str(msg.contents.msg).strip()) return 0 netsnmp_logger = log_callback(netsnmp_logger) -# include/net-snmp/library/callback.h -> -# int snmp_register_callback(int major, int minor, SNMPCallback * new_callback, void *arg); +# include/net-snmp/library/callback.h +# int snmp_register_callback( +# int major, int minor, SNMPCallback * new_callback, void *arg); lib.snmp_register_callback( SNMP_CALLBACK_LIBRARY, SNMP_CALLBACK_LOGGING, netsnmp_logger, 0 ) @@ -445,39 +466,55 @@ def netsnmp_logger(a, b, msg): lib.snmp_open.restype = POINTER(netsnmp_session) # include/net-snmp/library/snmp_transport.h -netsnmp_transport._fields_ = [ - ("domain", POINTER(oid)), - ("domain_length", c_int), - ("local", u_char_p), - ("local_length", c_int), - ("remote", u_char_p), - ("remote_length", c_int), - ("sock", c_int), - ("flags", u_int), - ("data", c_void_p), - ("data_length", c_int), - ("msgMaxSize", c_size_t), - ] + baseTransport + [ - ("f_recv", c_void_p), - ("f_send", c_void_p), - ("f_close", c_void_p), - ] + fOpen + [ - ("f_accept", c_void_p), - ("f_fmtaddr", c_void_p), -] + fCopy + fCopy + fSetupSession + identifier + fGetTaddr - -# include/net-snmp/library/snmp_transport.h -> -# netsnmp_transport *netsnmp_tdomain_transport( const char *str, int local, const char *default_domain); +netsnmp_transport._fields_ = ( + [ + ("domain", POINTER(oid)), + ("domain_length", c_int), + ("local", u_char_p), + ("local_length", c_int), + ("remote", u_char_p), + ("remote_length", c_int), + ("sock", c_int), + ("flags", u_int), + ("data", c_void_p), + ("data_length", c_int), + ("msgMaxSize", c_size_t), + ] + + baseTransport + + [ + ("f_recv", c_void_p), + ("f_send", c_void_p), + ("f_close", c_void_p), + ] + + fOpen + + [ + ("f_accept", c_void_p), + ("f_fmtaddr", c_void_p), + ] + + fCopy + + fCopy + + fSetupSession + + identifier + + fGetTaddr +) + +# include/net-snmp/library/snmp_transport.h +# netsnmp_transport *netsnmp_tdomain_transport( +# const char *str, int local, const char *default_domain); lib.netsnmp_tdomain_transport.restype = POINTER(netsnmp_transport) -# include/net-snmp/library/snmp_api.h -> netsnmp_session *snmp_add( -# netsnmp_session *, struct netsnmp_transport_s *, -# int (*fpre_parse) (netsnmp_session *, struct netsnmp_transport_s *, void *, int), -# int (*fpost_parse) (netsnmp_session *, netsnmp_pdu *, int) -# ); +# include/net-snmp/library/snmp_api.h +# netsnmp_session *snmp_add( +# netsnmp_session *, +# struct netsnmp_transport_s *, +# int (*fpre_parse) ( +# netsnmp_session *, struct netsnmp_transport_s *, void *, int), +# int (*fpost_parse) (netsnmp_session *, netsnmp_pdu *, int) +# ); lib.snmp_add.restype = POINTER(netsnmp_session) -# include/net-snmp/session_api.h -> int snmp_add_var(netsnmp_pdu *, const oid *, size_t, char, const char *); +# include/net-snmp/session_api.h +# int snmp_add_var(netsnmp_pdu *, const oid *, size_t, char, const char *); lib.snmp_add_var.argtypes = [ netsnmp_pdu_p, POINTER(oid), @@ -488,7 +525,8 @@ def netsnmp_logger(a, b, msg): lib.get_uptime.restype = c_long -# include/net-snmp/session_api.h -> int snmp_send(netsnmp_session *, netsnmp_pdu *); +# include/net-snmp/session_api.h +# int snmp_send(netsnmp_session *, netsnmp_pdu *); lib.snmp_send.argtypes = (POINTER(netsnmp_session), netsnmp_pdu_p) lib.snmp_send.restype = c_int @@ -551,15 +589,11 @@ def decodeString(pdu): return "" -_valueToConstant = dict( - [ - (chr(getattr(CONSTANTS, k)), k) - for k in CONSTANTS.__dict__.keys() - if isinstance(getattr(CONSTANTS, k), int) - and getattr(CONSTANTS, k) >= 0 - and getattr(CONSTANTS, k) < 256 - ] -) +_valueToConstant = { + chr(_v): _k + for _k, _v in CONSTANTS.__dict__.items() + if isinstance(_v, int) and (0 <= _v < 256) +} decoder = { @@ -605,16 +639,12 @@ def getResult(pdu, log): return result -class SnmpError(Exception): +class NetSnmpError(Exception): def __init__(self, why): lib.snmp_perror(why) Exception.__init__(self, why) -class SnmpTimeoutError(Exception): - pass - - sessionMap = {} @@ -626,6 +656,13 @@ def _callback(operation, sp, reqid, pdu, magic): sess.callback(pdu.contents) elif operation == NETSNMP_CALLBACK_OP_TIMED_OUT: sess.timeout(reqid) + elif operation == NETSNMP_CALLBACK_OP_SEC_ERROR: + _getLogger("callback").error( + "peer has rejected security credentials " + "peername=%s security-name=%s", + sp.contents.peername, + sp.contents.securityName, + ) else: _getLogger("callback").error("Unknown operation: %d", operation) except Exception as ex: @@ -636,14 +673,6 @@ def _callback(operation, sp, reqid, pdu, magic): _callback = netsnmp_callback(_callback) -class ArgumentParseError(Exception): - pass - - -class TransportError(Exception): - pass - - def _doNothingProc(argc, argv, arg): return 0 @@ -652,9 +681,7 @@ def _doNothingProc(argc, argv, arg): def parse_args(args, session): - args = [ - sys.argv[0], - ] + args + args = [sys.argv[0]] + args argc = len(args) argv = (c_char_p * argc)() for i in range(argc): @@ -662,7 +689,9 @@ def parse_args(args, session): argv[i] = create_string_buffer(args[i]).raw # WARNING: Usage of snmp_parse_args call causes memory leak. if lib.snmp_parse_args(argc, argv, session, "", _doNothingProc) < 0: - raise ArgumentParseError("Unable to parse arguments", " ".join(argv)) + raise ArgumentParseError( + "Unable to parse arguments arguments='{}'".format(" ".join(argv)) + ) # keep a reference to the args for as long as sess is alive return argv @@ -674,57 +703,88 @@ def initialize_session(sess, cmdLineArgs, kw): args = None kw = kw.copy() if cmdLineArgs: - cmdLine = [x for x in cmdLineArgs] - if isinstance(cmdLine[0], tuple): - result = [] - for opt, val in cmdLine: - result.append(opt) - result.append(val) - cmdLine = result - if kw.get("peername"): - cmdLine.append(kw["peername"]) - del kw["peername"] - args = parse_args(cmdLine, byref(sess)) + args = _init_from_args(sess, cmdLineArgs, kw) else: lib.snmp_sess_init(byref(sess)) for attr, value in kw.items(): pv = getattr(sess, attr, _NoAttribute) if pv is _NoAttribute: continue # Don't set invalid properties - if attr == "timeout": - # -1 means the property hasn't been set - if pv == -1: - # Converts seconds to microseconds - setattr(sess, attr, value * 1000000) - elif attr == "version": - # -1 means the property hasn't been set - if pv == -1: - setattr(sess, attr, value) - elif attr == "community": - # None means the property hasn't been set - if pv is None: - setattr(sess, attr, value) - setattr(sess, "community_len", len(value)) - elif attr == "community_len": - # Setting community_len on its own is a segfault waiting to happen - pass - else: - setattr(sess, attr, value) + _update_session(attr, value, pv, sess) return args -class Session(object): +def _init_from_args(sess, cmdLineArgs, kw): + cmdLine = list(cmdLineArgs) + if isinstance(cmdLine[0], tuple): + result = [] + for opt, val in cmdLine: + result.append(opt) + result.append(val) + cmdLine = result + if kw.get("peername"): + cmdLine.append(kw["peername"]) + del kw["peername"] + return parse_args(cmdLine, byref(sess)) + + +def _update_session(attr, value, pv, sess): + if attr == "timeout": + # -1 means 'timeout' hasn't been set + if pv == -1: + # Converts seconds to microseconds + setattr(sess, attr, value * 1000000) + elif attr == "version": + # -1 means 'version' hasn't been set + if pv == -1: + setattr(sess, attr, value) + elif attr == "community": + # None means 'community' hasn't been set + if pv is None: + setattr(sess, attr, value) + # Set 'community_len' at the same time because it's + # related to the value for the 'community' property. + sess.community_len = len(value) + elif attr == "community_len": + # Do nothing to avoid setting a 'community_len' value when no + # value has been set for 'community', otherwise, a segmentation + # fault can occur. + pass + else: + setattr(sess, attr, value) + +class Session(object): cb = None - def __init__(self, cmdLineArgs=(), **kw): + def __init__(self, cmdLineArgs=(), freeEtimelist=True, **kw): self.cmdLineArgs = cmdLineArgs + self.freeEtimelist = freeEtimelist self.kw = kw self.sess = None self.args = None self._data = None # ref to _CallbackData object self._log = _getLogger("session") + def _snmp_send(self, session, pdu): + """ + Allows execution of free_etimelist() after each snmp_send() call. + + Executes lib.free_etimelist() after each lib.snmp_send() call if the + `freeEtimelist` attribute is set, or re-calls lib.snmp_send() + otherwise. This frees all the memory used by entries in the + etimelist inside the net-snmp library, allowing the processing of + devices with duplicated engineID. + + Note: This feature is not supported by RFC. + """ + + try: + return lib.snmp_send(session, pdu) + finally: + if self.freeEtimelist: + lib.free_etimelist() + def open(self): sess = netsnmp_session() self.args = initialize_session(sess, self.cmdLineArgs, self.kw) @@ -736,7 +796,7 @@ def open(self): ref = byref(sess) self.sess = lib.snmp_open(ref) if not self.sess: - raise SnmpError("snmp_open") + raise NetSnmpError("snmp_open") def awaitTraps( self, peername, fileno=-1, pre_parse_callback=None, debug=False @@ -745,7 +805,6 @@ def awaitTraps( lib.netsnmp_ds_set_string( NETSNMP_DS_LIBRARY_ID, NETSNMP_DS_LIB_APPTYPE, "pynetsnmp" ) - lib.init_usm() if debug: lib.debug_register_tokens("snmp_parse") # or "ALL" for everything lib.snmp_set_do_debugging(1) @@ -764,15 +823,15 @@ def awaitTraps( lib.setup_engineID(None, None) transport = lib.netsnmp_tdomain_transport(peername, 1, "udp") if not transport: - raise SnmpError( + raise NetSnmpError( "Unable to create transport {peername}".format( peername=peername ) ) if fileno >= 0: os.dup2(fileno, transport.contents.sock) - sess = netsnmp_session() + sess = netsnmp_session() self.sess = pointer(sess) lib.snmp_sess_init(self.sess) sess.peername = SNMP_DEFAULT_PEERNAME @@ -791,34 +850,41 @@ def awaitTraps( sess.isAuthoritative = SNMP_SESS_UNKNOWNAUTH rc = lib.snmp_add(self.sess, transport, pre_parse_callback, None) if not rc: - raise SnmpError("snmp_add") + raise NetSnmpError("snmp_add") def create_users(self, users): self._log.debug("create_users: Creating %s users.", len(users)) for user in users: - if user.version == 3: - try: - line = "" - if user.engine_id: - line = "-e {} ".format(user.engine_id) - line += " ".join( - [ - user.username, - user.authentication_type, # MD5 or SHA - user.authentication_passphrase, - user.privacy_protocol, # DES or AES - user.privacy_passphrase, - ] - ) - lib.usm_parse_create_usmUser("createUser", line) - self._log.debug("create_users: created user: %s", user) - except StandardError as e: - self._log.debug( - "create_users: could not create user: %s: (%s: %s)", - user, - e.__class__.__name__, - e, + if str(user.version) != str(SNMP_VERSION_3): + self._log.info("create_users: user is not v3 %s", user) + continue + try: + line = "" + if user.engine: + line = "-e '{}'".format(user.engine) + if user.name: + line += " '{}'".format( + _escape_char("'", user.name), ) + if user.auth: + line += " '{}' '{}'".format( + _escape_char("'", user.auth.protocol.name), + _escape_char("'", user.auth.passphrase), + ) + if user.priv: + line += " '{}' '{}'".format( + _escape_char("'", user.priv.protocol.name), + _escape_char("'", user.priv.passphrase), + ) + lib.usm_parse_create_usmUser("createUser", line.strip()) + self._log.debug("create_users: created user: %s", user) + except Exception as e: + self._log.debug( + "create_users: could not create user: %s: (%s: %s)", + user, + e.__class__.__name__, + e, + ) def sendTrap(self, trapoid, varbinds=None): if "-v1" in self.cmdLineArgs: @@ -857,7 +923,7 @@ def sendTrap(self, trapoid, varbinds=None): n = strToOid(n) lib.snmp_add_var(pdu, n, len(n), t, v) - lib.snmp_send(self.sess, pdu) + self._snmp_send(self.sess, pdu) def close(self): if self.sess is not None: @@ -914,14 +980,14 @@ def _handle_send_status(self, req, send_status, send_type): lib.snmp_free_pdu(req) if snmperr.value == SNMPERR_TIMEOUT: raise SnmpTimeoutError() - raise SnmpError(msg_fmt % msg_args) + raise NetSnmpError(msg_fmt % msg_args) def get(self, oids): req = self._create_request(SNMP_MSG_GET) for oid in oids: oid = mkoid(oid) lib.snmp_add_null_var(req, oid, len(oid)) - send_status = lib.snmp_send(self.sess, req) + send_status = self._snmp_send(self.sess, req) self._handle_send_status(req, send_status, "get") return req.contents.reqid @@ -933,7 +999,7 @@ def getbulk(self, nonrepeaters, maxrepetitions, oids): for oid in oids: oid = mkoid(oid) lib.snmp_add_null_var(req, oid, len(oid)) - send_status = lib.snmp_send(self.sess, req) + send_status = self._snmp_send(self.sess, req) self._handle_send_status(req, send_status, "get") return req.contents.reqid @@ -941,12 +1007,16 @@ def walk(self, root): req = self._create_request(SNMP_MSG_GETNEXT) oid = mkoid(root) lib.snmp_add_null_var(req, oid, len(oid)) - send_status = lib.snmp_send(self.sess, req) + send_status = self._snmp_send(self.sess, req) self._log.debug("walk: send_status=%s", send_status) self._handle_send_status(req, send_status, "walk") return req.contents.reqid +def _escape_char(char, text): + return text.replace(char, r"\{}".format(char)) + + MAXFD = 1024 FD_SETSIZE = MAXFD fdset = c_int32 * (MAXFD / 32) @@ -969,6 +1039,7 @@ def fdset2list(rd, n): result.append(i * 32 + j) return result + class netsnmp_large_fd_set(Structure): # This structure must be initialized by calling netsnmp_large_fd_set_init() # and must be cleaned up via netsnmp_large_fd_set_cleanup(). If this last @@ -977,7 +1048,7 @@ class netsnmp_large_fd_set(Structure): _fields_ = [ ("lfs_setsize", c_uint), ("lfs_setptr", POINTER(fdset)), - ("lfs_set", fdset) + ("lfs_set", fdset), ] @@ -995,6 +1066,7 @@ def snmp_select_info(): t = timeout.tv_sec + timeout.tv_usec / 1e6 return fdset2list(rd, maxfd.value), t + def snmp_select_info2(): rd = netsnmp_large_fd_set() lib.netsnmp_large_fd_set_init(byref(rd), FD_SETSIZE) @@ -1004,7 +1076,9 @@ def snmp_select_info2(): timeout.tv_usec = 0 block = c_int(0) maxfd = c_int(MAXFD) - lib.snmp_select_info2(byref(maxfd), byref(rd), byref(timeout), byref(block)) + lib.snmp_select_info2( + byref(maxfd), byref(rd), byref(timeout), byref(block) + ) t = None if not block: t = timeout.tv_sec + timeout.tv_usec / 1e6 @@ -1017,11 +1091,13 @@ def snmp_select_info2(): lib.netsnmp_large_fd_set_cleanup(byref(rd)) return result, t + def snmp_read(fd): rd = fdset() rd[fd / 32] |= 1 << (fd % 32) lib.snmp_read(byref(rd)) + def snmp_read2(fd): rd = netsnmp_large_fd_set() lib.netsnmp_large_fd_set_init(byref(rd), FD_SETSIZE) @@ -1029,6 +1105,7 @@ def snmp_read2(fd): lib.snmp_read2(byref(rd)) lib.netsnmp_large_fd_set_cleanup(byref(rd)) + done = False diff --git a/pynetsnmp/oids.py b/pynetsnmp/oids.py new file mode 100644 index 0000000..873991b --- /dev/null +++ b/pynetsnmp/oids.py @@ -0,0 +1,94 @@ +from __future__ import absolute_import + +from .conversions import asOidStr + + +class OID(object): + __slots__ = ("oid",) + + def __init__(self, oid): + super(OID, self).__setattr__("oid", oid) + + def __setattr__(self, key, value): + if key in OID.__slots__: + raise AttributeError( + "can't set attribute '{}' on 'OID' object".format(key) + ) + super(OID, self).__setattr__(key, value) + + def __eq__(this, that): + if isinstance(that, (tuple, list)): + return this.oid == that + if isinstance(that, OID): + return this.oid == that.oid + return NotImplemented + + def __ne__(this, that): + if isinstance(that, (tuple, list)): + return this.oid != that + if isinstance(that, OID): + return this.oid != that.oid + return NotImplemented + + def __hash__(self): + return hash(self.oid) + + def __repr__(self): + return "<{0.__module__}.{0.__class__.__name__} {1}>".format( + self, asOidStr(self.oid) + ) + + def __str__(self): + return asOidStr(self.oid) + + +_base_status_oid = (1, 3, 6, 1, 6, 3, 15, 1, 1) + + +class UnknownSecurityLevel(OID): + __slots__ = () + + +UnknownSecurityLevel = UnknownSecurityLevel(_base_status_oid + (1, 0)) + + +class NotInTimeWindow(OID): + __slots__ = () + + +NotInTimeWindow = NotInTimeWindow(_base_status_oid + (2, 0)) + + +class UnknownUserName(OID): + __slots__ = () + + +UnknownUserName = UnknownUserName(_base_status_oid + (3, 0)) + + +class UnknownEngineId(OID): + __slots__ = () + + +UnknownEngineId = UnknownEngineId(_base_status_oid + (4, 0)) + + +class WrongDigest(OID): + __slots__ = () + + +WrongDigest = WrongDigest(_base_status_oid + (5, 0)) + + +class DecryptionError(OID): + __slots__ = () + + +DecryptionError = DecryptionError(_base_status_oid + (6, 0)) + + +class SysDescr(OID): + __slots__ = () + + +SysDescr = SysDescr((1, 3, 6, 1, 2, 1, 1, 1, 0)) diff --git a/pynetsnmp/twistedsnmp.py b/pynetsnmp/twistedsnmp.py index e5fae50..52142de 100644 --- a/pynetsnmp/twistedsnmp.py +++ b/pynetsnmp/twistedsnmp.py @@ -3,13 +3,12 @@ import logging import struct -from ipaddr import IPAddress from twisted.internet import defer, reactor from twisted.internet.selectreactor import SelectReactor from twisted.internet.error import TimeoutError from twisted.python import failure -from . import netsnmp +from . import netsnmp, oids from .CONSTANTS import ( SNMP_ERR_AUTHORIZATIONERROR, SNMP_ERR_BADVALUE, @@ -31,14 +30,21 @@ SNMP_ERR_WRONGTYPE, SNMP_ERR_WRONGVALUE, ) -from .conversions import asOidStr, asOid +from .conversions import asAgent, asOidStr, asOid +from .errors import SnmpError, SnmpUsmError, get_stats_error from .tableretriever import TableRetriever +log = netsnmp._getLogger("agentproxy") + class Timer(object): callLater = None +DEFAULT_PORT = 161 +DEFAULT_TIMEOUT = 2 +DEFAULT_RETRIES = 6 + timer = Timer() fdMap = {} @@ -103,9 +109,12 @@ def updateReactor(): log.debug("reactor settings: %r, %r", fds, t) for fd in fds: if isSelect and fd > netsnmp.MAXFD: - log.error("fd > %d detected!!" + - " This will not work properly with the SelectReactor and is being ignored." + - " Timeouts will occur unless you switch to EPollReactor instead!") + log.error( + "fd > %d detected!! " + "This will not work properly with the SelectReactor and " + "is being ignored. Timeouts will occur unless you switch " + "to EPollReactor instead!" + ) continue if fd not in fdMap: @@ -125,75 +134,6 @@ def updateReactor(): timer.callLater = reactor.callLater(t, checkTimeouts) -class SnmpNameError(Exception): - def __init__(self, oid): - Exception.__init__(self, "Bad Name", oid) - - -def _get_agent_spec(ipobj, interface, port): - """take a google ipaddr object and port number and produce a net-snmp - agent specification (see the snmpcmd manpage)""" - if ipobj.version == 4: - agent = "udp:%s:%s" % (ipobj.compressed, port) - elif ipobj.version == 6: - if ipobj.is_link_local: - if interface is None: - raise RuntimeError( - "Cannot create agent specification from link local " - "IPv6 address without an interface" - ) - else: - agent = "udp6:[%s%%%s]:%s" % ( - ipobj.compressed, - interface, - port, - ) - else: - agent = "udp6:[%s]:%s" % (ipobj.compressed, port) - else: - raise RuntimeError( - "Cannot create agent specification for IP address version: %s" - % ipobj.version - ) - return agent - - -class SnmpError(Exception): - def __init__(self, message, *args, **kwargs): - self.message = message - - def __str__(self): - return self.message - - def __repr__(self): - return self.message - - -class Snmpv3Error(SnmpError): - pass - - -USM_STATS_OIDS = { - # usmStatsWrongDigests - ".1.3.6.1.6.3.15.1.1.5.0": ( - "check zSnmpAuthType and zSnmpAuthPassword, " - "packet did not include the expected digest value" - ), - # usmStatsUnknownUserNames - ".1.3.6.1.6.3.15.1.1.3.0": ( - "check zSnmpSecurityName, packet referenced an unknown user" - ), - # usmStatsUnsupportedSecLevels - ".1.3.6.1.6.3.15.1.1.1.0": ( - "packet requested an unknown or unavailable security level" - ), - # usmStatsDecryptionErrors - ".1.3.6.1.6.3.15.1.1.6.0": ( - "check zSnmpPrivType, packet could not be decrypted" - ), -} - - class AgentProxy(object): """The public methods on AgentProxy (get, walk, getbulk) expect input OIDs to be strings, and the result they produce is a dictionary. The @@ -206,130 +146,121 @@ class AgentProxy(object): the SNMP query. The list is ordered correctly by the OID (i.e. it is not ordered by the OID string).""" + @classmethod + def create( + cls, + address, + security=None, + timeout=DEFAULT_TIMEOUT, + retries=DEFAULT_RETRIES, + ): + try: + ip, port = address + except ValueError: + port = DEFAULT_PORT + try: + ip = address.pop(0) + except AttributeError: + ip = address + return cls( + ip, port=port, security=security, timeout=timeout, tries=retries + ) + def __init__( self, ip, port=161, community="public", snmpVersion="1", - protocol=None, - allowCache=False, + protocol=None, # no longer used + allowCache=False, # no longer used timeout=1.5, tries=3, cmdLineArgs=(), + security=None, ): + if security is not None: + self._security = security + self.snmpVersion = security.version + else: + self._security = None + self.snmpVersion = snmpVersion self.ip = ip self.port = port self.community = community - self.snmpVersion = snmpVersion self.timeout = timeout self.tries = tries self.cmdLineArgs = cmdLineArgs - self.defers = {} + self.defers = _DeferredMap() self.session = None - self._log = netsnmp._getLogger("agentproxy") - def _signSafePop(self, d, intkey): - """ - Attempt to pop the item at intkey from dictionary d. - Upon failure, try to convert intkey from a signed to an unsigned - integer and try to pop again. + def open(self): + if self.session is not None: + self.session.close() + self.session = None + updateReactor() - This addresses potential integer rollover issues caused by the fact - that netsnmp_pdu.reqid is a c_long and the netsnmp_callback function - pointer definition specifies it as a c_int. See ZEN-4481. - """ - try: - return d.pop(intkey) - except KeyError as ex: - if intkey < 0: - self._log.debug("Negative ID for _signSafePop: %s", intkey) - # convert to unsigned, try that key - uintkey = struct.unpack("I", struct.pack("i", intkey))[0] - try: - return d.pop(uintkey) - except KeyError: - # Nothing by the unsigned key either, - # throw the original KeyError for consistency - raise ex - raise + if self._security: + agent = asAgent(self.ip, self.port) + cmdlineargs = self._security.getArguments() + ( + ("-t", str(self.timeout), "-r", str(self.tries), agent) + ) + self.session = netsnmp.Session(cmdLineArgs=cmdlineargs) + else: + self.session = netsnmp.Session( + version=netsnmp.SNMP_VERSION_MAP.get( + self.snmpVersion, netsnmp.SNMP_VERSION_2c + ), + timeout=int(self.timeout), + retries=int(self.tries), + peername="%s:%d" % (self.ip, self.port), + community=self.community, + community_len=len(self.community), + cmdLineArgs=self._getCmdLineArgs(), + ) + + self.session.callback = self.callback + self.session.timeout = self._handle_timeout + self.session.open() + updateReactor() + + def close(self): + if self.session is not None: + self.session.close() + self.session = None + updateReactor() def callback(self, pdu): """netsnmp session callback""" - result = [] - response = netsnmp.getResult(pdu, self._log) + response = netsnmp.getResult(pdu, log) try: - d, oids_requested = self._signSafePop(self.defers, pdu.reqid) + d, oids_requested = self.defers.pop(pdu.reqid) except KeyError: - # We seem to end up here if we use bad credentials with authPriv. - # The only reasonable thing to do is call all of the deferreds with - # Snmpv3Errors. - for usmStatsOid, count in response: - usmStatsOidStr = asOidStr(usmStatsOid) - - if usmStatsOidStr == ".1.3.6.1.6.3.15.1.1.2.0": - # Some devices use usmStatsNotInTimeWindows as a normal - # part of the SNMPv3 handshake (JIRA-1565). - # net-snmp automatically retries the request with the - # previous request_id and the values for - # msgAuthoritativeEngineBoots and - # msgAuthoritativeEngineTime from this error packet. - self._log.debug( - "Received a usmStatsNotInTimeWindows error. Some " - "devices use usmStatsNotInTimeWindows as a normal " - "part of the SNMPv3 handshake." - ) - return - - if usmStatsOidStr == ".1.3.6.1.2.1.1.1.0": - # Some devices (Cisco Nexus/MDS) use sysDescr as a normal - # part of the SNMPv3 handshake (JIRA-7943) - self._log.debug( - "Received sysDescr during handshake. Some devices use " - "sysDescr as a normal part of the SNMPv3 handshake." - ) - return - - default_msg = "packet dropped (OID: {0})".format( - usmStatsOidStr - ) - message = USM_STATS_OIDS.get(usmStatsOidStr, default_msg) - break - else: - message = "packet dropped" - - for d in ( - d for d, rOids in self.defers.itervalues() if not d.called - ): - reactor.callLater( - 0, d.errback, failure.Failure(Snmpv3Error(message)) - ) - + self._handle_missing_request(response) return - for oid, value in response: - if isinstance(value, tuple): - value = asOidStr(value) - result.append((oid, value)) + result = tuple( + (oid, asOidStr(value) if isinstance(value, tuple) else value) + for oid, value in response + ) + if len(result) == 1 and result[0][0] not in oids_requested: - usmStatsOidStr = asOidStr(result[0][0]) - if usmStatsOidStr in USM_STATS_OIDS: - msg = USM_STATS_OIDS.get(usmStatsOidStr) - reactor.callLater( - 0, d.errback, failure.Failure(Snmpv3Error(msg)) - ) + statsOid = result[0][0] + error = get_stats_error(statsOid) + if error: + reactor.callLater(0, d.errback, failure.Failure(error)) return - elif usmStatsOidStr == ".1.3.6.1.6.3.15.1.1.2.0": + if statsOid == oids.NotInTimeWindow: # we may get a subsequent snmp result with the correct value # if not the timeout will be called at some point self.defers[pdu.reqid] = (d, oids_requested) return if pdu.errstat != SNMP_ERR_NOERROR: pduError = PDU_ERRORS.get( - pdu.errstat, "Unknown error (%d)" % pdu.errstat + pdu.errstat, "unknown error (%d)" % pdu.errstat ) - message = "Packet for %s has error: %s" % (self.ip, pduError) + message = "packet for %s has error: %s" % (self.ip, pduError) if pdu.errstat in ( SNMP_ERR_NOACCESS, SNMP_ERR_RESOURCEUNAVAILABLE, @@ -341,13 +272,53 @@ def callback(self, pdu): return else: result = [] - self._log.warning(message + ". OIDS: %s", oids_requested) + log.warning(message + ". OIDS: %s", oids_requested) reactor.callLater(0, d.callback, result) - def timeout_(self, reqid): - d = self._signSafePop(self.defers, reqid)[0] - reactor.callLater(0, d.errback, failure.Failure(TimeoutError())) + def _handle_missing_request(self, response): + usmStatsOid, _ = next(iter(response), (None, None)) + + if usmStatsOid == oids.NotInTimeWindow: + # Some devices use usmStatsNotInTimeWindows as a normal part of + # the SNMPv3 handshake (JIRA-1565). net-snmp automatically + # retries the request with the previous request_id and the + # values for msgAuthoritativeEngineBoots and + # msgAuthoritativeEngineTime from this error packet. + log.debug( + "Received a usmStatsNotInTimeWindows error. Some " + "devices use usmStatsNotInTimeWindows as a normal " + "part of the SNMPv3 handshake." + ) + return + + if usmStatsOid == oids.SysDescr: + # Some devices (Cisco Nexus/MDS) use sysDescr as a normal + # part of the SNMPv3 handshake (JIRA-7943) + log.debug( + "Received sysDescr during handshake. Some devices use " + "sysDescr as a normal part of the SNMPv3 handshake." + ) + return + + if usmStatsOid is not None: + error = get_stats_error(usmStatsOid) + if not error: + error = SnmpUsmError( + "packet dropped (OID: {0})".format(asOidStr(usmStatsOid)) + ) + else: + error = SnmpUsmError("packet dropped") + + for d in (d for d, _ in self.defers.itervalues() if not d.called): + reactor.callLater(0, d.errback, failure.Failure(error)) + + def _handle_timeout(self, reqid): + try: + d = self.defers.pop(reqid)[0] + reactor.callLater(0, d.errback, failure.Failure(TimeoutError())) + except KeyError: + log.warning("handled timeout for unknown request") def _getCmdLineArgs(self): if not self.cmdLineArgs: @@ -357,18 +328,7 @@ def _getCmdLineArgs(self): if version == "2": version += "c" - if "%" in self.ip: - address, interface = self.ip.split("%") - else: - address = self.ip - interface = None - - self._log.debug( - "AgentProxy._getCmdLineArgs: using google ipaddr on %s", address - ) - - ipobj = IPAddress(address) - agent = _get_agent_spec(ipobj, interface, self.port) + agent = asAgent(self.ip, self.port) cmdLineArgs = list(self.cmdLineArgs) + [ "-v", @@ -383,34 +343,6 @@ def _getCmdLineArgs(self): ] return cmdLineArgs - def open(self): - if self.session is not None: - self.session.close() - self.session = None - - self.session = netsnmp.Session( - version=netsnmp.SNMP_VERSION_MAP.get( - self.snmpVersion, netsnmp.SNMP_VERSION_2c - ), - timeout=int(self.timeout), - retries=int(self.tries), - peername="%s:%d" % (self.ip, self.port), - community=self.community, - community_len=len(self.community), - cmdLineArgs=self._getCmdLineArgs(), - ) - - self.session.callback = self.callback - self.session.timeout = self.timeout_ - self.session.open() - updateReactor() - - def close(self): - if self.session is not None: - self.session.close() - self.session = None - updateReactor() - def _get(self, oids, timeout=None, retryCount=None): d = defer.Deferred() try: @@ -468,11 +400,8 @@ def getbulk(self, nonrepeaters, maxrepititions, oidStrs): return deferred def _convertToDict(self, result): - def strKey(item): - return asOidStr(item[0]), item[1] - - if isinstance(result, list): - return dict(map(strKey, result)) + if isinstance(result, (list, tuple)): + return {asOidStr(key): value for key, value in result} return result @@ -484,3 +413,25 @@ def port(self): snmpprotocol = _FakeProtocol() + + +class _DeferredMap(dict): + """ + Wrap the dict type to add extra behavior. + """ + + def pop(self, key): + """ + Attempt to pop the item at key from the dictionary. + """ + # Check for negative key to address potential integer rollover issues + # caused by the fact that netsnmp_pdu.reqid is a c_long and the + # netsnmp_callback function pointer definition specifies it as a + # c_int. See ZEN-4481. + if key not in self and key < 0: + log.debug("try negative ID for deferred map: %s", key) + # convert to unsigned, try that key + uintkey = struct.unpack("I", struct.pack("i", key))[0] + if uintkey in self: + key = uintkey + return super(_DeferredMap, self).pop(key) diff --git a/pynetsnmp/usm/__init__.py b/pynetsnmp/usm/__init__.py new file mode 100644 index 0000000..d089315 --- /dev/null +++ b/pynetsnmp/usm/__init__.py @@ -0,0 +1,43 @@ +from __future__ import absolute_import + +from .auth import Authentication +from .community import Community +from .priv import Privacy +from .user import User +from .protocols import ( + AUTH_MD5, + AUTH_NOAUTH, + auth_protocols, + AUTH_SHA, + AUTH_SHA_224, + AUTH_SHA_256, + AUTH_SHA_384, + AUTH_SHA_512, + PRIV_AES, + PRIV_AES_192, + PRIV_AES_256, + PRIV_DES, + PRIV_NOPRIV, + priv_protocols, +) + +__all__ = ( + "Authentication", + "AUTH_MD5", + "AUTH_NOAUTH", + "auth_protocols", + "AUTH_SHA", + "AUTH_SHA_224", + "AUTH_SHA_256", + "AUTH_SHA_384", + "AUTH_SHA_512", + "Community", + "Privacy", + "PRIV_AES", + "PRIV_AES_192", + "PRIV_AES_256", + "PRIV_DES", + "PRIV_NOPRIV", + "priv_protocols", + "User", +) diff --git a/pynetsnmp/usm/auth.py b/pynetsnmp/usm/auth.py new file mode 100644 index 0000000..38e3efc --- /dev/null +++ b/pynetsnmp/usm/auth.py @@ -0,0 +1,50 @@ +from __future__ import absolute_import + +from .protocols import AUTH_NOAUTH, auth_protocols + + +class Authentication(object): + """ + Provides the authentication data for User objects. + """ + + __slots__ = ("protocol", "passphrase") + + @classmethod + def new_noauth(cls): + return cls(None, None) + + def __init__(self, protocol, passphrase): + if ( + not protocol + or protocol is AUTH_NOAUTH + or protocol == "AUTH_NOAUTH" + ): + self.protocol = AUTH_NOAUTH + self.passphrase = None + else: + self.protocol = auth_protocols[protocol] + if not passphrase: + raise ValueError( + "Authentication protocol requires a passphrase" + ) + self.passphrase = passphrase + + def __eq__(self, other): + if not isinstance(other, Authentication): + return NotImplemented + return ( + self.protocol == other.protocol + and self.passphrase == other.passphrase + ) + + def __nonzero__(self): + return self.protocol is not AUTH_NOAUTH + + def __repr__(self): + return ( + "<{0.__module__}.{0.__class__.__name__} protocol={0.protocol}>" + ).format(self) + + def __str__(self): + return "{0.__class__.__name__}(protocol={0.protocol})".format(self) diff --git a/pynetsnmp/usm/common.py b/pynetsnmp/usm/common.py new file mode 100644 index 0000000..24b0b6a --- /dev/null +++ b/pynetsnmp/usm/common.py @@ -0,0 +1,19 @@ +from __future__ import absolute_import + +from ..CONSTANTS import ( + SNMP_VERSION_1 as _V1, + SNMP_VERSION_2c as _V2C, + SNMP_VERSION_3 as _V3, +) + +version_map = { + "1": "1", + "2c": "2c", + "3": "3", + _V1: "1", + _V2C: "2c", + _V3: "3", + "v1": "1", + "v2c": "2c", + "v3": "3", +} diff --git a/pynetsnmp/usm/community.py b/pynetsnmp/usm/community.py new file mode 100644 index 0000000..44c3a73 --- /dev/null +++ b/pynetsnmp/usm/community.py @@ -0,0 +1,23 @@ +from __future__ import absolute_import + +from ..CONSTANTS import SNMP_VERSION_2c as _V2C +from .common import version_map + + +class Community(object): + """ + Provides the community based security model for SNMP v1/V2c. + """ + + def __init__(self, name, version=_V2C): + mapped = version_map.get(version) + if mapped is None or mapped == "3": + raise ValueError( + "SNMP version '{}' not supported for Community".format(version) + ) + self.name = name + self.version = mapped + + def getArguments(self): + community = ("-c", str(self.name)) if self.name else () + return ("-v", self.version) + community diff --git a/pynetsnmp/usm/priv.py b/pynetsnmp/usm/priv.py new file mode 100644 index 0000000..22aa210 --- /dev/null +++ b/pynetsnmp/usm/priv.py @@ -0,0 +1,48 @@ +from __future__ import absolute_import + +from .protocols import PRIV_NOPRIV, priv_protocols + + +class Privacy(object): + """ + Provides the privacy data for User objects. + """ + + __slots__ = ("protocol", "passphrase") + + @classmethod + def new_nopriv(cls): + return cls(None, None) + + def __init__(self, protocol, passphrase): + if ( + not protocol + or protocol is PRIV_NOPRIV + or protocol == "PRIV_NOPRIV" + ): + self.protocol = PRIV_NOPRIV + self.passphrase = None + else: + self.protocol = priv_protocols[protocol] + if not passphrase: + raise ValueError("Privacy protocol requires a passphrase") + self.passphrase = passphrase + + def __eq__(self, other): + if not isinstance(other, Privacy): + return NotImplemented + return ( + self.protocol == other.protocol + and self.passphrase == other.passphrase + ) + + def __nonzero__(self): + return self.protocol is not PRIV_NOPRIV + + def __repr__(self): + return ( + "<{0.__module__}.{0.__class__.__name__} protocol={0.protocol}>" + ).format(self) + + def __str__(self): + return "{0.__class__.__name__}(protocol={0.protocol})".format(self) diff --git a/pynetsnmp/usm/protocols.py b/pynetsnmp/usm/protocols.py new file mode 100644 index 0000000..fd69de2 --- /dev/null +++ b/pynetsnmp/usm/protocols.py @@ -0,0 +1,94 @@ +from __future__ import absolute_import + + +class _Protocol(object): + """ """ + + __slots__ = ("name", "oid") + + def __init__(self, name, oid): + self.name = name + self.oid = oid + + def __eq__(self, other): + if not isinstance(other, type(self)): + return NotImplemented + return self.name == other.name and self.oid == other.oid + + def __str__(self): + return self.name + + def __repr__(self): + return "<{0.__module__}.{0.__name__} {1} {2}>".format( + self.__class__, self.name, ".".join(str(v) for v in self.oid) + ) + + +class _Protocols(object): + __slots__ = ("__protocols", "__kind") + + def __init__(self, protocols, kind): + self.__protocols = protocols + self.__kind = kind + + def __len__(self): + return len(self.__protocols) + + def __iter__(self): + return iter(self.__protocols) + + def __contains__(self, proto): + if not proto: + proto = self.__noargs + if proto not in self.__protocols: + return any(str(p) == proto for p in self.__protocols) + return True + + def __getitem__(self, name): + name = str(name) + proto = next((p for p in self.__protocols if str(p) == name), None) + if proto is None: + raise KeyError( + "unknown {} protocol '{}'".format(self.__kind, name) + ) + return proto + + def __repr__(self): + return "<{0.__module__}.{0.__name__} {1}>".format( + self.__class__, ", ".join(str(p) for p in self.__protocols) + ) + + +AUTH_NOAUTH = _Protocol("NOAUTH", (1, 3, 6, 1, 6, 3, 10, 1, 1, 1)) +AUTH_MD5 = _Protocol("MD5", (1, 3, 6, 1, 6, 3, 10, 1, 1, 2)) +AUTH_SHA = _Protocol("SHA", (1, 3, 6, 1, 6, 3, 10, 1, 1, 3)) +AUTH_SHA_224 = _Protocol("SHA-224", (1, 3, 6, 1, 6, 3, 10, 1, 1, 4)) +AUTH_SHA_256 = _Protocol("SHA-256", (1, 3, 6, 1, 6, 3, 10, 1, 1, 5)) +AUTH_SHA_384 = _Protocol("SHA-384", (1, 3, 6, 1, 6, 3, 10, 1, 1, 6)) +AUTH_SHA_512 = _Protocol("SHA-512", (1, 3, 6, 1, 6, 3, 10, 1, 1, 7)) + +auth_protocols = _Protocols( + ( + AUTH_NOAUTH, + AUTH_MD5, + AUTH_SHA, + AUTH_SHA_224, + AUTH_SHA_256, + AUTH_SHA_384, + AUTH_SHA_512, + ), + "authentication", +) + +PRIV_NOPRIV = _Protocol("NOPRIV", (1, 3, 6, 1, 6, 3, 10, 1, 2, 1)) +PRIV_DES = _Protocol("DES", (1, 3, 6, 1, 6, 3, 10, 1, 2, 2)) +PRIV_AES = _Protocol("AES", (1, 3, 6, 1, 6, 3, 10, 1, 2, 4)) +PRIV_AES_192 = _Protocol("AES-192", (1, 3, 6, 1, 4, 1, 14832, 1, 3)) +PRIV_AES_256 = _Protocol("AES-256", (1, 3, 6, 1, 4, 1, 14832, 1, 4)) + +priv_protocols = _Protocols( + (PRIV_NOPRIV, PRIV_DES, PRIV_AES, PRIV_AES_192, PRIV_AES_256), "privacy" +) + +del _Protocol +del _Protocols diff --git a/pynetsnmp/usm/user.py b/pynetsnmp/usm/user.py new file mode 100644 index 0000000..eca9a2b --- /dev/null +++ b/pynetsnmp/usm/user.py @@ -0,0 +1,89 @@ +from __future__ import absolute_import + +from ..CONSTANTS import SNMP_VERSION_3 as _V3 + +from .auth import Authentication +from .common import version_map +from .priv import Privacy + + +_sec_level = { + (True, True): "authPriv", + (True, False): "authNoPriv", + (False, False): "noAuthNoPriv", +} + + +class User(object): + """ + Provides User-based Security Model configuration for SNMP v3. + """ + + def __init__(self, name, auth=None, priv=None, engine=None, context=None): + self.name = name + if auth is None: + auth = Authentication.new_noauth() + if not isinstance(auth, Authentication): + raise ValueError("invalid authentication object") + self.auth = auth + if priv is None: + priv = Privacy.new_nopriv() + if not isinstance(priv, Privacy): + raise ValueError("invalid privacy object") + self.priv = priv + self.engine = engine + self.context = context + self.version = version_map.get(_V3) + + def getArguments(self): + auth_args = ( + ("-a", self.auth.protocol.name, "-A", self.auth.passphrase) + if self.auth + else () + ) + if auth_args: + # The privacy arguments are only given if the authentication + # arguments are also provided. + priv_args = ( + ("-x", self.priv.protocol.name, "-X", self.priv.passphrase) + if self.priv + else () + ) + else: + priv_args = () + seclevel_arg = ("-l", _sec_level[(bool(self.auth), bool(self.priv))]) + + return ( + ("-v", self.version) + + (("-u", self.name) if self.name else ()) + + seclevel_arg + + auth_args + + priv_args + + (("-e", self.engine) if self.engine else ()) + + (("-n", self.context) if self.context else ()) + ) + + def __eq__(self, other): + return ( + self.name == other.name + and self.auth == other.auth + and self.priv == other.priv + and self.engine == other.engine + and self.context == other.context + ) + + def __str__(self): + info = ", ".join( + "{0}={1}".format(k, v) + for k, v in ( + ("name", self.name), + ("auth", self.auth), + ("priv", self.priv), + ("engine", self.engine), + ("context", self.context), + ) + if v + ) + return "{0.__class__.__name__}(version={0.version}{1}{2})".format( + self, ", " if info else "", info + ) diff --git a/setup.py b/setup.py index fdd4f30..6a1c130 100644 --- a/setup.py +++ b/setup.py @@ -4,7 +4,7 @@ setup( name="pynetsnmp", - version="0.42.0", + version="0.43.0", packages=find_packages(), install_requires=["setuptools"], include_package_data=True, diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_auth.py b/tests/test_auth.py new file mode 100644 index 0000000..a17d68f --- /dev/null +++ b/tests/test_auth.py @@ -0,0 +1,92 @@ +import unittest + +from pynetsnmp import usm + + +class TestAuthentication(unittest.TestCase): + passwd = "security123" # noqa: S105 + + def test_noauth_classmethod(t): + auth = usm.Authentication.new_noauth() + t.assertEqual(usm.AUTH_NOAUTH, auth.protocol) + t.assertIsNone(auth.passphrase) + + def test_none_init(t): + auth = usm.Authentication(None, None) + t.assertEqual(usm.AUTH_NOAUTH, auth.protocol) + t.assertIsNone(auth.passphrase) + + auth = usm.Authentication(None, t.passwd) + t.assertEqual(usm.AUTH_NOAUTH, auth.protocol) + t.assertIsNone(auth.passphrase) + + def test_noauth_init(t): + auth = usm.Authentication(usm.AUTH_NOAUTH, None) + t.assertEqual(usm.AUTH_NOAUTH, auth.protocol) + t.assertIsNone(auth.passphrase) + + auth = usm.Authentication(usm.AUTH_NOAUTH, t.passwd) + t.assertEqual(usm.AUTH_NOAUTH, auth.protocol) + t.assertIsNone(auth.passphrase) + + def test_noauth_str_init(t): + auth = usm.Authentication("AUTH_NOAUTH", None) + t.assertEqual(usm.AUTH_NOAUTH, auth.protocol) + t.assertIsNone(auth.passphrase) + + auth = usm.Authentication("AUTH_NOAUTH", t.passwd) + t.assertEqual(usm.AUTH_NOAUTH, auth.protocol) + t.assertIsNone(auth.passphrase) + + def test_noauth_is_false(t): + auth = usm.Authentication.new_noauth() + t.assertFalse(auth) + + def test_md5(t): + auth = usm.Authentication(usm.AUTH_MD5, t.passwd) + t.assertTrue(auth) + t.assertEqual(auth.protocol, usm.AUTH_MD5) + t.assertEqual(auth.passphrase, t.passwd) + + def test_sha(t): + auth = usm.Authentication(usm.AUTH_SHA, t.passwd) + t.assertTrue(auth) + t.assertEqual(auth.protocol, usm.AUTH_SHA) + t.assertEqual(auth.passphrase, t.passwd) + + def test_sha_224(t): + auth = usm.Authentication(usm.AUTH_SHA_224, t.passwd) + t.assertTrue(auth) + t.assertEqual(auth.protocol, usm.AUTH_SHA_224) + t.assertEqual(auth.passphrase, t.passwd) + + def test_sha_256(t): + auth = usm.Authentication(usm.AUTH_SHA_256, t.passwd) + t.assertTrue(auth) + t.assertEqual(auth.protocol, usm.AUTH_SHA_256) + t.assertEqual(auth.passphrase, t.passwd) + + def test_sha_384(t): + auth = usm.Authentication(usm.AUTH_SHA_384, t.passwd) + t.assertTrue(auth) + t.assertEqual(auth.protocol, usm.AUTH_SHA_384) + t.assertEqual(auth.passphrase, t.passwd) + + def test_sha_512(t): + auth = usm.Authentication(usm.AUTH_SHA_512, t.passwd) + t.assertTrue(auth) + t.assertEqual(auth.protocol, usm.AUTH_SHA_512) + t.assertEqual(auth.passphrase, t.passwd) + + def test_equal(t): + auth1 = usm.Authentication(usm.AUTH_MD5, t.passwd) + auth2 = usm.Authentication(usm.AUTH_MD5, t.passwd) + t.assertEqual(auth1, auth2) + + def test_not_equal(t): + auth1 = usm.Authentication(usm.AUTH_MD5, t.passwd) + auth2 = usm.Authentication(usm.AUTH_SHA, t.passwd) + t.assertNotEqual(auth1, auth2) + + auth3 = usm.Authentication(usm.AUTH_SHA, t.passwd + "456") + t.assertNotEqual(auth2, auth3) diff --git a/tests/test_community.py b/tests/test_community.py new file mode 100644 index 0000000..7432548 --- /dev/null +++ b/tests/test_community.py @@ -0,0 +1,62 @@ +import unittest + +from pynetsnmp import CONSTANTS, usm + + +class TestCommunity(unittest.TestCase): + name = "public" + + def test_default(t): + c = usm.Community(t.name) + t.assertEqual(c.name, t.name) + t.assertEqual(c.version, "2c") + expected = ("-v", "2c", "-c", t.name) + t.assertSequenceEqual(c.getArguments(), expected) + + def test_v1_constant(t): + c = usm.Community(t.name, CONSTANTS.SNMP_VERSION_1) + t.assertEqual(c.name, t.name) + t.assertEqual(c.version, "1") + expected = ("-v", "1", "-c", t.name) + t.assertSequenceEqual(c.getArguments(), expected) + + def test_v1_v1(t): + c = usm.Community(t.name, "v1") + t.assertEqual(c.name, t.name) + t.assertEqual(c.version, "1") + expected = ("-v", "1", "-c", t.name) + t.assertSequenceEqual(c.getArguments(), expected) + + def test_v2c_constant(t): + c = usm.Community(t.name, CONSTANTS.SNMP_VERSION_2c) + t.assertEqual(c.name, t.name) + t.assertEqual(c.version, "2c") + expected = ("-v", "2c", "-c", t.name) + t.assertSequenceEqual(c.getArguments(), expected) + + def test_v2c_v2c(t): + c = usm.Community(t.name, "v2c") + t.assertEqual(c.name, t.name) + t.assertEqual(c.version, "2c") + expected = ("-v", "2c", "-c", t.name) + t.assertSequenceEqual(c.getArguments(), expected) + + def test_v3_constant(t): + with t.assertRaises(ValueError): + usm.Community(t.name, CONSTANTS.SNMP_VERSION_3) + + def test_v3_v3(t): + with t.assertRaises(ValueError): + usm.Community(t.name, "v3") + + def test_none_version(t): + with t.assertRaises(ValueError): + usm.Community(t.name, None) + + def test_not_a_version_str(t): + with t.assertRaises(ValueError): + usm.Community(t.name, "oi") + + def test_not_a_version_number(t): + with t.assertRaises(ValueError): + usm.Community(t.name, 3947) diff --git a/tests/test_priv.py b/tests/test_priv.py new file mode 100644 index 0000000..9dc3842 --- /dev/null +++ b/tests/test_priv.py @@ -0,0 +1,80 @@ +import unittest + +from pynetsnmp import usm + + +class TestPrivacy(unittest.TestCase): + passwd = "security123" # noqa: S105 + + def test_nopriv_classmethod(t): + priv = usm.Privacy.new_nopriv() + t.assertEqual(usm.PRIV_NOPRIV, priv.protocol) + t.assertIsNone(priv.passphrase) + + def test_none_init(t): + priv = usm.Privacy(None, None) + t.assertEqual(usm.PRIV_NOPRIV, priv.protocol) + t.assertIsNone(priv.passphrase) + + priv = usm.Privacy(None, t.passwd) + t.assertEqual(usm.PRIV_NOPRIV, priv.protocol) + t.assertIsNone(priv.passphrase) + + def test_nopriv_init(t): + priv = usm.Privacy(usm.PRIV_NOPRIV, None) + t.assertEqual(usm.PRIV_NOPRIV, priv.protocol) + t.assertIsNone(priv.passphrase) + + priv = usm.Privacy(usm.PRIV_NOPRIV, t.passwd) + t.assertEqual(usm.PRIV_NOPRIV, priv.protocol) + t.assertIsNone(priv.passphrase) + + def test_nopriv_str_init(t): + priv = usm.Privacy("PRIV_NOPRIV", None) + t.assertEqual(usm.PRIV_NOPRIV, priv.protocol) + t.assertIsNone(priv.passphrase) + + priv = usm.Privacy("PRIV_NOPRIV", t.passwd) + t.assertEqual(usm.PRIV_NOPRIV, priv.protocol) + t.assertIsNone(priv.passphrase) + + def test_nopriv_is_false(t): + priv = usm.Privacy.new_nopriv() + t.assertFalse(priv) + + def test_des(t): + priv = usm.Privacy(usm.PRIV_DES, t.passwd) + t.assertTrue(priv) + t.assertEqual(priv.protocol, usm.PRIV_DES) + t.assertEqual(priv.passphrase, t.passwd) + + def test_aes(t): + priv = usm.Privacy(usm.PRIV_AES, t.passwd) + t.assertTrue(priv) + t.assertEqual(priv.protocol, usm.PRIV_AES) + t.assertEqual(priv.passphrase, t.passwd) + + def test_aes_192(t): + priv = usm.Privacy(usm.PRIV_AES_192, t.passwd) + t.assertTrue(priv) + t.assertEqual(priv.protocol, usm.PRIV_AES_192) + t.assertEqual(priv.passphrase, t.passwd) + + def test_aes_256(t): + priv = usm.Privacy(usm.PRIV_AES_256, t.passwd) + t.assertTrue(priv) + t.assertEqual(priv.protocol, usm.PRIV_AES_256) + t.assertEqual(priv.passphrase, t.passwd) + + def test_equal(t): + priv1 = usm.Privacy(usm.PRIV_DES, t.passwd) + priv2 = usm.Privacy(usm.PRIV_DES, t.passwd) + t.assertEqual(priv1, priv2) + + def test_not_equal(t): + priv1 = usm.Privacy(usm.PRIV_DES, t.passwd) + priv2 = usm.Privacy(usm.PRIV_AES, t.passwd) + t.assertNotEqual(priv1, priv2) + + priv3 = usm.Privacy(usm.PRIV_AES, t.passwd + "456") + t.assertNotEqual(priv2, priv3) diff --git a/tests/test_usm.py b/tests/test_usm.py new file mode 100644 index 0000000..3a6e4e9 --- /dev/null +++ b/tests/test_usm.py @@ -0,0 +1,114 @@ +import unittest + +from pynetsnmp import usm + +_sorted_auth_names = sorted( + ["NOAUTH", "MD5", "SHA", "SHA-224", "SHA-256", "SHA-384", "SHA-512"] +) + + +class TestAuthProtocols(unittest.TestCase): + def test_noauth_contained(t): + t.assertIn(usm.AUTH_NOAUTH, usm.auth_protocols) + + def test_md5_contained(t): + t.assertIn(usm.AUTH_MD5, usm.auth_protocols) + + def test_sha_contained(t): + t.assertIn(usm.AUTH_SHA, usm.auth_protocols) + + def test_sha_224_contained(t): + t.assertIn(usm.AUTH_SHA_224, usm.auth_protocols) + + def test_sha_256_contained(t): + t.assertIn(usm.AUTH_SHA_256, usm.auth_protocols) + + def test_sha_384_contained(t): + t.assertIn(usm.AUTH_SHA_384, usm.auth_protocols) + + def test_sha_512_contained(t): + t.assertIn(usm.AUTH_SHA_512, usm.auth_protocols) + + def test_length(t): + t.assertEqual(7, len(usm.auth_protocols)) + + def test_iterable(t): + names = sorted(str(p) for p in usm.auth_protocols) + t.assertEqual(7, len(names)) + t.assertListEqual(_sorted_auth_names, names) + + def test_noauth_getitem(t): + proto = usm.auth_protocols[usm.AUTH_NOAUTH.name] + t.assertEqual(usm.AUTH_NOAUTH, proto) + + def test_md5_getitem(t): + proto = usm.auth_protocols[usm.AUTH_MD5.name] + t.assertEqual(usm.AUTH_MD5, proto) + + def test_sha_getitem(t): + proto = usm.auth_protocols[usm.AUTH_SHA.name] + t.assertEqual(usm.AUTH_SHA, proto) + + def test_sha_224_getitem(t): + proto = usm.auth_protocols[usm.AUTH_SHA_224.name] + t.assertEqual(usm.AUTH_SHA_224, proto) + + def test_sha_256_getitem(t): + proto = usm.auth_protocols[usm.AUTH_SHA_256.name] + t.assertEqual(usm.AUTH_SHA_256, proto) + + def test_sha_384_getitem(t): + proto = usm.auth_protocols[usm.AUTH_SHA_384.name] + t.assertEqual(usm.AUTH_SHA_384, proto) + + def test_sha_512_getitem(t): + proto = usm.auth_protocols[usm.AUTH_SHA_512.name] + t.assertEqual(usm.AUTH_SHA_512, proto) + + +_sorted_priv_names = sorted(["NOPRIV", "DES", "AES", "AES-192", "AES-256"]) + + +class TestPrivProtocols(unittest.TestCase): + def test_nopriv_contained(t): + t.assertIn(usm.PRIV_NOPRIV, usm.priv_protocols) + + def test_des_contained(t): + t.assertIn(usm.PRIV_DES, usm.priv_protocols) + + def test_aes_contained(t): + t.assertIn(usm.PRIV_AES, usm.priv_protocols) + + def test_aes_192_contained(t): + t.assertIn(usm.PRIV_AES_192, usm.priv_protocols) + + def test_aes_256_contained(t): + t.assertIn(usm.PRIV_AES_256, usm.priv_protocols) + + def test_length(t): + t.assertEqual(5, len(usm.priv_protocols)) + + def test_iterable(t): + names = sorted(str(p) for p in usm.priv_protocols) + t.assertEqual(5, len(names)) + t.assertListEqual(_sorted_priv_names, names) + + def test_nopriv_getitem(t): + proto = usm.priv_protocols[usm.PRIV_NOPRIV.name] + t.assertEqual(usm.PRIV_NOPRIV, proto) + + def test_des_getitem(t): + proto = usm.priv_protocols[usm.PRIV_DES.name] + t.assertEqual(usm.PRIV_DES, proto) + + def test_aes_getitem(t): + proto = usm.priv_protocols[usm.PRIV_AES.name] + t.assertEqual(usm.PRIV_AES, proto) + + def test_aes_192_getitem(t): + proto = usm.priv_protocols[usm.PRIV_AES_192.name] + t.assertEqual(usm.PRIV_AES_192, proto) + + def test_aes_256_getitem(t): + proto = usm.priv_protocols[usm.PRIV_AES_256.name] + t.assertEqual(usm.PRIV_AES_256, proto) diff --git a/tests/test_usmuser.py b/tests/test_usmuser.py new file mode 100644 index 0000000..d28c025 --- /dev/null +++ b/tests/test_usmuser.py @@ -0,0 +1,157 @@ +import unittest + +from pynetsnmp import usm + + +class TestUser(unittest.TestCase): + name = "john_doe" + passwd = "secured123" # noqa: S105 + + def test_default(t): + user = usm.User(t.name) + t.assertEqual(t.name, user.name) + t.assertEqual(usm.Authentication.new_noauth(), user.auth) + t.assertEqual(usm.Privacy.new_nopriv(), user.priv) + t.assertIsNone(user.engine) + t.assertIsNone(user.context) + t.assertEqual(user.version, "3") + expected = ("-v", "3", "-u", t.name, "-l", "noAuthNoPriv") + t.assertSequenceEqual(expected, user.getArguments()) + + def test_engineid(t): + engineid = hex(3443489794829589283483234)[2:].strip("L") + user = usm.User(t.name, engine=engineid) + t.assertEqual(t.name, user.name) + t.assertEqual(usm.Authentication.new_noauth(), user.auth) + t.assertEqual(usm.Privacy.new_nopriv(), user.priv) + t.assertEqual(engineid, user.engine) + t.assertIsNone(user.context) + expected = ( + "-v", + "3", + "-u", + t.name, + "-l", + "noAuthNoPriv", + "-e", + engineid, + ) + t.assertSequenceEqual(expected, user.getArguments()) + + def test_contextid(t): + contextid = hex(9084090984572743455234)[2:].strip("L") + user = usm.User(t.name, context=contextid) + t.assertEqual(t.name, user.name) + t.assertEqual(usm.Authentication.new_noauth(), user.auth) + t.assertEqual(usm.Privacy.new_nopriv(), user.priv) + t.assertIsNone(user.engine) + t.assertEqual(contextid, user.context) + expected = ( + "-v", + "3", + "-u", + t.name, + "-l", + "noAuthNoPriv", + "-n", + contextid, + ) + t.assertSequenceEqual(expected, user.getArguments()) + + def test_auth(t): + auth = usm.Authentication(usm.AUTH_SHA_224, t.passwd) + user = usm.User(t.name, auth=auth) + t.assertEqual(t.name, user.name) + t.assertEqual(auth, user.auth) + t.assertEqual(usm.Privacy.new_nopriv(), user.priv) + t.assertIsNone(user.engine) + t.assertIsNone(user.context) + t.assertEqual(user.version, "3") + expected = ( + "-v", + "3", + "-u", + t.name, + "-l", + "authNoPriv", + "-a", + auth.protocol.name, + "-A", + auth.passphrase, + ) + t.assertSequenceEqual(expected, user.getArguments()) + + def test_authpriv(t): + auth = usm.Authentication(usm.AUTH_SHA_224, t.passwd) + priv = usm.Privacy(usm.PRIV_AES_256, t.passwd) + user = usm.User(t.name, auth=auth, priv=priv) + t.assertEqual(t.name, user.name) + t.assertEqual(auth, user.auth) + t.assertEqual(priv, user.priv) + t.assertIsNone(user.engine) + t.assertIsNone(user.context) + t.assertEqual(user.version, "3") + expected = ( + "-v", + "3", + "-u", + t.name, + "-l", + "authPriv", + "-a", + auth.protocol.name, + "-A", + auth.passphrase, + "-x", + priv.protocol.name, + "-X", + priv.passphrase, + ) + t.assertSequenceEqual(expected, user.getArguments()) + + def test_all_args(t): + auth = usm.Authentication(usm.AUTH_SHA_224, t.passwd) + priv = usm.Privacy(usm.PRIV_AES_256, t.passwd) + contextid = hex(9084090984572743455234)[2:].strip("L") + engineid = hex(3443489794829589283483234)[2:].strip("L") + user = usm.User( + t.name, auth=auth, priv=priv, engine=engineid, context=contextid + ) + t.assertEqual(t.name, user.name) + t.assertEqual(auth, user.auth) + t.assertEqual(priv, user.priv) + t.assertEqual(engineid, user.engine) + t.assertEqual(contextid, user.context) + t.assertEqual(user.version, "3") + expected = ( + "-v", + "3", + "-u", + t.name, + "-l", + "authPriv", + "-a", + auth.protocol.name, + "-A", + auth.passphrase, + "-x", + priv.protocol.name, + "-X", + priv.passphrase, + "-e", + engineid, + "-n", + contextid, + ) + t.assertSequenceEqual(expected, user.getArguments()) + + def test_equality(t): + auth1 = usm.Authentication(usm.AUTH_SHA_224, t.passwd) + user1 = usm.User(t.name, auth=auth1) + auth2 = usm.Authentication(usm.AUTH_SHA_224, t.passwd) + user2 = usm.User(t.name, auth=auth2) + auth3 = usm.Authentication(usm.AUTH_SHA_256, t.passwd) + priv3 = usm.Privacy(usm.PRIV_AES_256, t.passwd) + user3 = usm.User(t.name, auth=auth3, priv=priv3) + t.assertEqual(user1, user2) + t.assertNotEqual(user1, user3)