diff --git a/setools/dta.py b/setools/dta.py index 31faa45..f9f3c58 100644 --- a/setools/dta.py +++ b/setools/dta.py @@ -18,6 +18,7 @@ except ImportError: logging.getLogger(__name__).debug("NetworkX failed to import.") from .descriptors import EdgeAttrDict, EdgeAttrList +from .mixins import NetworkXGraphEdge from .policyrep import AnyTERule, SELinuxPolicy, TERuletype, Type __all__ = ['DomainTransitionAnalysis', 'DomainTransition', 'DomainEntrypoint', 'DTAPath'] @@ -584,7 +585,7 @@ class DomainTransitionAnalysis: @dataclass -class Edge: +class Edge(NetworkXGraphEdge): """ A graph edge. Also used for returning domain transition steps. @@ -624,20 +625,3 @@ class Edge: self.setcurrent = None else: raise ValueError("Edge does not exist in graph") - - def __getitem__(self, key): - # This is implemented so this object can be used in NetworkX - # functions that operate on (source, target) tuples - if isinstance(key, slice): - return [self._index_to_item(i) for i in range(* key.indices(2))] - else: - return self._index_to_item(key) - - def _index_to_item(self, index: int) -> Type: - """Return source or target based on index.""" - if index == 0: - return self.source - elif index == 1: - return self.target - else: - raise IndexError("Invalid index (edges only have 2 items): {0}".format(index)) diff --git a/setools/infoflow.py b/setools/infoflow.py index 177cd9c..af615ca 100644 --- a/setools/infoflow.py +++ b/setools/infoflow.py @@ -15,6 +15,7 @@ except ImportError: logging.getLogger(__name__).debug("NetworkX failed to import.") from .descriptors import EdgeAttrIntMax, EdgeAttrList +from .mixins import NetworkXGraphEdge from .permmap import PermissionMap from .policyrep import AVRule, SELinuxPolicy, TERuletype, Type @@ -392,7 +393,7 @@ class InfoFlowAnalysis: @dataclass -class InfoFlowStep: +class InfoFlowStep(NetworkXGraphEdge): """ A graph edge. Also used for returning information flow steps. @@ -428,20 +429,3 @@ class InfoFlowStep: self.weight = None else: raise ValueError("InfoFlowStep does not exist in graph") - - def __getitem__(self, key): - # This is implemented so this object can be used in NetworkX - # functions that operate on (source, target) tuples - if isinstance(key, slice): - return [self._index_to_item(i) for i in range(* key.indices(2))] - else: - return self._index_to_item(key) - - def _index_to_item(self, index: int) -> Type: - """Return source or target based on index.""" - if index == 0: - return self.source - elif index == 1: - return self.target - else: - raise IndexError("Invalid index (InfoFlowSteps only have 2 items): {0}".format(index)) diff --git a/setools/mixins.py b/setools/mixins.py index cbde443..5da75fb 100644 --- a/setools/mixins.py +++ b/setools/mixins.py @@ -6,7 +6,7 @@ # pylint: disable=attribute-defined-outside-init,no-member import re from logging import Logger -from typing import Iterable +from typing import Any from .descriptors import CriteriaDescriptor, CriteriaSetDescriptor, CriteriaPermissionSetDescriptor from .policyrep import Context @@ -208,3 +208,28 @@ class MatchPermission: return obj.perms >= self.perms else: return match_regex_or_set(obj.perms, self.perms, self.perms_equal, self.perms_regex) + + +class NetworkXGraphEdge: + + """Mixin enabling use in NetworkX functions.""" + + source: Any + target: Any + + def __getitem__(self, key): + # This is implemented so this object can be used in NetworkX + # functions that operate on (source, target) tuples + if isinstance(key, slice): + return [self._index_to_item(i) for i in range(* key.indices(2))] + else: + return self._index_to_item(key) + + def _index_to_item(self, index: int): + """Return source or target based on index.""" + if index == 0: + return self.source + elif index == 1: + return self.target + else: + raise IndexError(f"Invalid index (NetworkXGraphEdge only has 2 items): {index}")