diff --git a/setools/policyrep/netcontext.py b/setools/policyrep/netcontext.py index d6576b0..dba84e7 100644 --- a/setools/policyrep/netcontext.py +++ b/setools/policyrep/netcontext.py @@ -16,7 +16,7 @@ # License along with SETools. If not, see # . # -import socket +from socket import IPPROTO_TCP, IPPROTO_UDP, getprotobyname from collections import namedtuple from . import qpol @@ -146,16 +146,22 @@ class PortconProtocol(int): corresponding protocol string (udp, tcp). """ - _proto_to_text = {socket.IPPROTO_TCP: 'tcp', - socket.IPPROTO_UDP: 'udp'} + _proto_to_text = {IPPROTO_TCP: 'tcp', + IPPROTO_UDP: 'udp'} def __new__(cls, value): - if value not in cls._proto_to_text: - raise ValueError("{0} is not a supported IP protocol number. " - "Values such as {1} (TCP) or {2} (UDP) should be used.". - format(value, socket.IPPROTO_TCP, socket.IPPROTO_UDP)) + try: + # convert string representation + num = getprotobyname(value) + except TypeError: + num = value - return super(PortconProtocol, cls).__new__(cls, value) + if num not in cls._proto_to_text: + raise ValueError("{0} is not a supported IP protocol. " + "Values such as {1} (TCP) or {2} (UDP) should be used.". + format(value, IPPROTO_TCP, IPPROTO_UDP)) + + return super(PortconProtocol, cls).__new__(cls, num) def __str__(self): return self._proto_to_text[self]