From d015ef75578deab5bb7c5215bb906b6026f487ca Mon Sep 17 00:00:00 2001 From: Jason Peacock Date: Mon, 28 Oct 2024 12:21:29 -0500 Subject: [PATCH] Add types for SNMP security fields. The changes made are to support clients specifying the configuration without having to know about net-snmp's command line arguments. ZEN-35109 --- pynetsnmp/SnmpSession.py | 4 +- pynetsnmp/conversions.py | 29 ++++- pynetsnmp/netsnmp.py | 8 +- pynetsnmp/security.py | 111 ++++++++++++++++++++ pynetsnmp/twistedsnmp.py | 221 ++++++++++++++++++++------------------- pynetsnmp/usm.py | 91 ++++++++++++++++ 6 files changed, 350 insertions(+), 114 deletions(-) create mode 100644 pynetsnmp/security.py create mode 100644 pynetsnmp/usm.py diff --git a/pynetsnmp/SnmpSession.py b/pynetsnmp/SnmpSession.py index 1e5b118..f9d969d 100644 --- a/pynetsnmp/SnmpSession.py +++ b/pynetsnmp/SnmpSession.py @@ -1,7 +1,7 @@ -from __future__ import absolute_import - """Backwards compatible API for SnmpSession""" +from __future__ import absolute_import + from . import netsnmp diff --git a/pynetsnmp/conversions.py b/pynetsnmp/conversions.py index beac056..0398a01 100644 --- a/pynetsnmp/conversions.py +++ b/pynetsnmp/conversions.py @@ -1,11 +1,36 @@ from __future__ import absolute_import +from ipaddr import IPAddress + def asOidStr(oid): """converts an oid int sequence to an oid string""" - return "." + ".".join([str(x) for x in oid]) + return "." + ".".join(str(x) for x in oid) def asOid(oidStr): """converts an OID string into a tuple of integers""" - return tuple([int(x) for x in oidStr.strip(".").split(".")]) + return tuple(int(x) for x in oidStr.strip(".").split(".")) + + +def asAgent(ip, port): + """take a google ipaddr object and port number and produce a net-snmp + agent specification (see the snmpcmd manpage)""" + ip, interface = ip.split("%") if "%" in ip else (ip, None) + address = IPAddress(ip) + + if address.version == 4: + return "udp:{}:{}".format(address.compressed, port) + + if address.version == 6: + if address.is_link_local: + if interface is None: + raise RuntimeError( + "Cannot create agent specification from link local " + "IPv6 address without an interface" + ) + else: + return "udp6:[{}%{}]:{}".format( + address.compressed, interface, port + ) + return "udp6:[{}]:{}".format(address.compressed, port) diff --git a/pynetsnmp/netsnmp.py b/pynetsnmp/netsnmp.py index 4c8bcc3..c5d6c39 100644 --- a/pynetsnmp/netsnmp.py +++ b/pynetsnmp/netsnmp.py @@ -684,9 +684,7 @@ def _doNothingProc(argc, argv, arg): def parse_args(args, session): - args = [ - sys.argv[0], - ] + args + args = [sys.argv[0]] + args argc = len(args) argv = (c_char_p * argc)() for i in range(argc): @@ -694,7 +692,9 @@ def parse_args(args, session): argv[i] = create_string_buffer(args[i]).raw # WARNING: Usage of snmp_parse_args call causes memory leak. if lib.snmp_parse_args(argc, argv, session, "", _doNothingProc) < 0: - raise ArgumentParseError("Unable to parse arguments", " ".join(argv)) + raise ArgumentParseError( + "Unable to parse arguments arguments='{}'".format(" ".join(argv)) + ) # keep a reference to the args for as long as sess is alive return argv diff --git a/pynetsnmp/security.py b/pynetsnmp/security.py new file mode 100644 index 0000000..302761c --- /dev/null +++ b/pynetsnmp/security.py @@ -0,0 +1,111 @@ +from __future__ import absolute_import + +from .CONSTANTS import SNMP_VERSION_1, SNMP_VERSION_2c, SNMP_VERSION_3 +from .usm import auth_protocols, priv_protocols + + +class Community(object): + """ + Provides the community based security model for SNMP v1/V2c. + """ + + def __init__(self, name, version=SNMP_VERSION_2c): + version = _version_map.get(version) + if version is None: + raise ValueError("Unsupported SNMP version '{}'".format(version)) + self.name = name + self.version = version + + def getArguments(self): + community = ("-c", str(self.name)) if self.name else () + return ("-v", self.version) + community + + +class UsmUser(object): + """ + Provides User-based Security Model configuration for SNMP v3. + """ + + def __init__(self, name, auth=None, priv=None, engine=None, context=None): + self.name = name + if not isinstance(auth, (type(None), Authentication)): + raise ValueError("invalid authentication protocol") + self.auth = auth + if not isinstance(auth, (type(None), Privacy)): + raise ValueError("invalid privacy protocol") + self.priv = priv + self.engine = engine + self.context = context + self.version = _version_map.get(SNMP_VERSION_3) + + def getArguments(self): + auth = ( + ("-a", str(self.auth.protocol), "-A", self.auth.passphrase) + if self.auth + else () + ) + if auth: + # The privacy arguments are only given if the authentication + # arguments are also provided. + priv = ( + ("-x", str(self.priv.protocol), "-X", self.priv.passphrase) + if self.priv + else () + ) + else: + priv = () + seclevel = ("-l", _sec_level.get((auth, priv), "noAuthNoPriv")) + + return ( + ("-v", self.version) + + (("-u", self.name) if self.name else ()) + + seclevel + + auth + + priv + + (("-e", self.engine) if self.engine else ()) + + (("-n", self.context) if self.context else ()) + ) + + +_sec_level = {(True, True): "authPriv", (True, False): "authNoPriv"} +_version_map = { + SNMP_VERSION_1: "1", + SNMP_VERSION_2c: "2c", + SNMP_VERSION_3: "3", + "v1": "1", + "v2c": "2c", + "v3": "3", +} + + +class Authentication(object): + """ + Provides the authentication data for UsmUser objects. + """ + + def __init__(self, protocol, passphrase): + if protocol is None: + raise ValueError( + "Invalid Authentication protocol '{}'".format(protocol) + ) + self.protocol = auth_protocols[protocol] + if not passphrase: + raise ValueError( + "authentication protocol requires an " + "authentication passphrase" + ) + self.passphrase = passphrase + + +class Privacy(object): + """ + Provides the privacy data for UsmUser objects. + """ + + def __init__(self, protocol, passphrase): + if protocol is None: + raise ValueError("Invalid Privacy protocol '{}'".format(protocol)) + self.protocol = priv_protocols[protocol] + if not passphrase: + raise ValueError("privacy protocol requires a privacy passphrase") + self.passphrase = passphrase diff --git a/pynetsnmp/twistedsnmp.py b/pynetsnmp/twistedsnmp.py index e5fae50..701e804 100644 --- a/pynetsnmp/twistedsnmp.py +++ b/pynetsnmp/twistedsnmp.py @@ -3,7 +3,6 @@ import logging import struct -from ipaddr import IPAddress from twisted.internet import defer, reactor from twisted.internet.selectreactor import SelectReactor from twisted.internet.error import TimeoutError @@ -31,7 +30,7 @@ SNMP_ERR_WRONGTYPE, SNMP_ERR_WRONGVALUE, ) -from .conversions import asOidStr, asOid +from .conversions import asAgent, asOidStr, asOid from .tableretriever import TableRetriever @@ -39,6 +38,10 @@ class Timer(object): callLater = None +DEFAULT_PORT = 161 +DEFAULT_TIMEOUT = 2 +DEFAULT_RETRIES = 6 + timer = Timer() fdMap = {} @@ -103,9 +106,12 @@ def updateReactor(): log.debug("reactor settings: %r, %r", fds, t) for fd in fds: if isSelect and fd > netsnmp.MAXFD: - log.error("fd > %d detected!!" + - " This will not work properly with the SelectReactor and is being ignored." + - " Timeouts will occur unless you switch to EPollReactor instead!") + log.error( + "fd > %d detected!! " + "This will not work properly with the SelectReactor and " + "is being ignored. Timeouts will occur unless you switch " + "to EPollReactor instead!" + ) continue if fd not in fdMap: @@ -130,34 +136,6 @@ def __init__(self, oid): Exception.__init__(self, "Bad Name", oid) -def _get_agent_spec(ipobj, interface, port): - """take a google ipaddr object and port number and produce a net-snmp - agent specification (see the snmpcmd manpage)""" - if ipobj.version == 4: - agent = "udp:%s:%s" % (ipobj.compressed, port) - elif ipobj.version == 6: - if ipobj.is_link_local: - if interface is None: - raise RuntimeError( - "Cannot create agent specification from link local " - "IPv6 address without an interface" - ) - else: - agent = "udp6:[%s%%%s]:%s" % ( - ipobj.compressed, - interface, - port, - ) - else: - agent = "udp6:[%s]:%s" % (ipobj.compressed, port) - else: - raise RuntimeError( - "Cannot create agent specification for IP address version: %s" - % ipobj.version - ) - return agent - - class SnmpError(Exception): def __init__(self, message, *args, **kwargs): self.message = message @@ -206,6 +184,30 @@ class AgentProxy(object): the SNMP query. The list is ordered correctly by the OID (i.e. it is not ordered by the OID string).""" + @classmethod + def create( + cls, + address, + security=None, + timeout=DEFAULT_TIMEOUT, + retries=DEFAULT_RETRIES, + ): + try: + ip, port = address + except ValueError: + port = DEFAULT_PORT + try: + ip = address.pop(0) + except AttributeError: + ip = address + return cls( + ip, + port=port, + security=security, + timeout=timeout, + tries=retries, + ) + def __init__( self, ip, @@ -213,15 +215,21 @@ def __init__( community="public", snmpVersion="1", protocol=None, - allowCache=False, + allowCache=False, # no longer used timeout=1.5, tries=3, cmdLineArgs=(), + security=None, ): + if security is not None: + self._security = security + self.snmpVersion = security.version + else: + self._security = None + self.snmpVersion = snmpVersion self.ip = ip self.port = port self.community = community - self.snmpVersion = snmpVersion self.timeout = timeout self.tries = tries self.cmdLineArgs = cmdLineArgs @@ -256,16 +264,59 @@ def _signSafePop(self, d, intkey): def callback(self, pdu): """netsnmp session callback""" - result = [] response = netsnmp.getResult(pdu, self._log) try: - d, oids_requested = self._signSafePop(self.defers, pdu.reqid) + d, oids_requested = self._pop_requested_oids(pdu, response) + except RuntimeError: + return + + result = tuple( + (oid, asOidStr(value) if isinstance(value, tuple) else value) + for oid, value in response + ) + + if len(result) == 1 and result[0][0] not in oids_requested: + usmStatsOidStr = asOidStr(result[0][0]) + if usmStatsOidStr in USM_STATS_OIDS: + msg = USM_STATS_OIDS.get(usmStatsOidStr) + reactor.callLater( + 0, d.errback, failure.Failure(Snmpv3Error(msg)) + ) + return + elif usmStatsOidStr == ".1.3.6.1.6.3.15.1.1.2.0": + # we may get a subsequent snmp result with the correct value + # if not the timeout will be called at some point + self.defers[pdu.reqid] = (d, oids_requested) + return + if pdu.errstat != SNMP_ERR_NOERROR: + pduError = PDU_ERRORS.get( + pdu.errstat, "Unknown error (%d)" % pdu.errstat + ) + message = "Packet for %s has error: %s" % (self.ip, pduError) + if pdu.errstat in ( + SNMP_ERR_NOACCESS, + SNMP_ERR_RESOURCEUNAVAILABLE, + SNMP_ERR_AUTHORIZATIONERROR, + ): + reactor.callLater( + 0, d.errback, failure.Failure(SnmpError(message)) + ) + return + else: + result = [] + self._log.warning(message + ". OIDS: %s", oids_requested) + + reactor.callLater(0, d.callback, result) + + def _pop_requested_oids(self, pdu, response): + try: + return self._signSafePop(self.defers, pdu.reqid) except KeyError: # We seem to end up here if we use bad credentials with authPriv. # The only reasonable thing to do is call all of the deferreds with # Snmpv3Errors. - for usmStatsOid, count in response: + for usmStatsOid, _ in response: usmStatsOidStr = asOidStr(usmStatsOid) if usmStatsOidStr == ".1.3.6.1.6.3.15.1.1.2.0": @@ -280,7 +331,7 @@ def callback(self, pdu): "devices use usmStatsNotInTimeWindows as a normal " "part of the SNMPv3 handshake." ) - return + raise RuntimeError("usmStatsNotInTimeWindows error") if usmStatsOidStr == ".1.3.6.1.2.1.1.1.0": # Some devices (Cisco Nexus/MDS) use sysDescr as a normal @@ -289,7 +340,7 @@ def callback(self, pdu): "Received sysDescr during handshake. Some devices use " "sysDescr as a normal part of the SNMPv3 handshake." ) - return + raise RuntimeError("sysDescr during handshake") default_msg = "packet dropped (OID: {0})".format( usmStatsOidStr @@ -306,44 +357,7 @@ def callback(self, pdu): 0, d.errback, failure.Failure(Snmpv3Error(message)) ) - return - - for oid, value in response: - if isinstance(value, tuple): - value = asOidStr(value) - result.append((oid, value)) - if len(result) == 1 and result[0][0] not in oids_requested: - usmStatsOidStr = asOidStr(result[0][0]) - if usmStatsOidStr in USM_STATS_OIDS: - msg = USM_STATS_OIDS.get(usmStatsOidStr) - reactor.callLater( - 0, d.errback, failure.Failure(Snmpv3Error(msg)) - ) - return - elif usmStatsOidStr == ".1.3.6.1.6.3.15.1.1.2.0": - # we may get a subsequent snmp result with the correct value - # if not the timeout will be called at some point - self.defers[pdu.reqid] = (d, oids_requested) - return - if pdu.errstat != SNMP_ERR_NOERROR: - pduError = PDU_ERRORS.get( - pdu.errstat, "Unknown error (%d)" % pdu.errstat - ) - message = "Packet for %s has error: %s" % (self.ip, pduError) - if pdu.errstat in ( - SNMP_ERR_NOACCESS, - SNMP_ERR_RESOURCEUNAVAILABLE, - SNMP_ERR_AUTHORIZATIONERROR, - ): - reactor.callLater( - 0, d.errback, failure.Failure(SnmpError(message)) - ) - return - else: - result = [] - self._log.warning(message + ". OIDS: %s", oids_requested) - - reactor.callLater(0, d.callback, result) + raise RuntimeError(message) def timeout_(self, reqid): d = self._signSafePop(self.defers, reqid)[0] @@ -357,18 +371,7 @@ def _getCmdLineArgs(self): if version == "2": version += "c" - if "%" in self.ip: - address, interface = self.ip.split("%") - else: - address = self.ip - interface = None - - self._log.debug( - "AgentProxy._getCmdLineArgs: using google ipaddr on %s", address - ) - - ipobj = IPAddress(address) - agent = _get_agent_spec(ipobj, interface, self.port) + agent = asAgent(self.ip, self.port) cmdLineArgs = list(self.cmdLineArgs) + [ "-v", @@ -388,17 +391,26 @@ def open(self): self.session.close() self.session = None - self.session = netsnmp.Session( - version=netsnmp.SNMP_VERSION_MAP.get( - self.snmpVersion, netsnmp.SNMP_VERSION_2c - ), - timeout=int(self.timeout), - retries=int(self.tries), - peername="%s:%d" % (self.ip, self.port), - community=self.community, - community_len=len(self.community), - cmdLineArgs=self._getCmdLineArgs(), - ) + if self._security: + agent = asAgent(self.ip, self.port) + cmdlineargs = self._security.getArguments() + ( + ("-t", str(self.timeout), "-r", str(self.tries), agent) + ) + self.session = netsnmp.Session( + cmdLineArgs=cmdlineargs + ) + else: + self.session = netsnmp.Session( + version=netsnmp.SNMP_VERSION_MAP.get( + self.snmpVersion, netsnmp.SNMP_VERSION_2c + ), + timeout=int(self.timeout), + retries=int(self.tries), + peername="%s:%d" % (self.ip, self.port), + community=self.community, + community_len=len(self.community), + cmdLineArgs=self._getCmdLineArgs(), + ) self.session.callback = self.callback self.session.timeout = self.timeout_ @@ -468,11 +480,8 @@ def getbulk(self, nonrepeaters, maxrepititions, oidStrs): return deferred def _convertToDict(self, result): - def strKey(item): - return asOidStr(item[0]), item[1] - - if isinstance(result, list): - return dict(map(strKey, result)) + if isinstance(result, (list, tuple)): + return {asOidStr(key): value for key, value in result} return result diff --git a/pynetsnmp/usm.py b/pynetsnmp/usm.py new file mode 100644 index 0000000..bccd675 --- /dev/null +++ b/pynetsnmp/usm.py @@ -0,0 +1,91 @@ +class _Protocol(object): + __slots__ = ("__name",) + + def __init__(self, name): + self.__name = name + + def __str__(self): + return self.__name + + def __repr__(self): + return "<{0.__module__}.{0.__name__} {1}>".format( + self.__class__, self.__name + ) + + +class _Protocols(object): + __slots__ = ("__protocols", "__kind") + + def __init__(self, protocols, kind): + self.__protocols = protocols + self.__kind = kind + + def __len__(self): + return len(self.__protocols) + + def __iter__(self): + return iter(self.__protocols) + + def __contains__(self, proto): + if proto not in self.__protocols: + return any(str(p) == proto for p in self.__protocols) + return True + + def __getitem__(self, name): + name = str(name) + proto = next((p for p in self.__protocols if str(p) == name), None) + if proto is None: + raise KeyError("No {} protocol '{}'".format(self.__kind, name)) + return proto + + def __repr__(self): + return "<{0.__module__}.{0.__name__} {1}>".format( + self.__class__, ", ".join(str(p) for p in self.__protocols) + ) + + +AUTH_MD5 = _Protocol("MD5") +AUTH_SHA = _Protocol("SHA") +AUTH_SHA_224 = _Protocol("SHA-224") +AUTH_SHA_256 = _Protocol("SHA-256") +AUTH_SHA_384 = _Protocol("SHA-384") +AUTH_SHA_512 = _Protocol("SHA-512") + +auth_protocols = _Protocols( + ( + AUTH_MD5, + AUTH_SHA, + AUTH_SHA_224, + AUTH_SHA_256, + AUTH_SHA_384, + AUTH_SHA_512, + ), + "authentication", +) + +PRIV_DES = _Protocol("DES") +PRIV_AES = _Protocol("AES") +PRIV_AES_192 = _Protocol("AES-192") +PRIV_AES_256 = _Protocol("AES-256") + +priv_protocols = _Protocols( + (PRIV_DES, PRIV_AES, PRIV_AES_192, PRIV_AES_256), "privacy" +) + +del _Protocol +del _Protocols + +__all__ = ( + "AUTH_MD5", + "AUTH_SHA", + "AUTH_SHA_224", + "AUTH_SHA_256", + "AUTH_SHA_384", + "AUTH_SHA_512", + "auth_protocols", + "PRIV_DES", + "PRIV_AES", + "PRIV_AES_192", + "PRIV_AES_256", + "priv_protocols", +)