Merge pull request #81 from pebenito/nx-dataclasses

Convert data structures to Python dataclasses where relevant.
This commit is contained in:
Chris PeBenito 2023-02-07 11:47:14 -05:00 committed by GitHub
commit d491963133
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 60 additions and 59 deletions

View File

@ -16,6 +16,7 @@ To run SETools command line tools, the following packages are required:
* Python 3.6+ * Python 3.6+
* NetworkX 2.0+ (2.6+ for Python 3.9+) * NetworkX 2.0+ (2.6+ for Python 3.9+)
* setuptools * setuptools
* dataclasses (Python 3.6 only)
* libselinux * libselinux
* libsepol 3.2+ * libsepol 3.2+

View File

@ -8,7 +8,8 @@ import itertools
import logging import logging
from collections import defaultdict from collections import defaultdict
from contextlib import suppress from contextlib import suppress
from typing import DefaultDict, Iterable, List, NamedTuple, Optional, Union from dataclasses import dataclass, InitVar
from typing import DefaultDict, Iterable, List, Optional, Union
try: try:
import networkx as nx import networkx as nx
@ -17,14 +18,16 @@ except ImportError:
logging.getLogger(__name__).debug("NetworkX failed to import.") logging.getLogger(__name__).debug("NetworkX failed to import.")
from .descriptors import EdgeAttrDict, EdgeAttrList from .descriptors import EdgeAttrDict, EdgeAttrList
from .mixins import NetworkXGraphEdge
from .policyrep import AnyTERule, SELinuxPolicy, TERuletype, Type from .policyrep import AnyTERule, SELinuxPolicy, TERuletype, Type
__all__ = ['DomainTransitionAnalysis', 'DomainTransition', 'DomainEntrypoint', 'DTAPath'] __all__ = ['DomainTransitionAnalysis', 'DomainTransition', 'DomainEntrypoint', 'DTAPath']
class DomainEntrypoint(NamedTuple): @dataclass
class DomainEntrypoint:
"""Entrypoint list entry named tuple output format.""" """Entrypoint list entry."""
name: Type name: Type
entrypoint: List[AnyTERule] entrypoint: List[AnyTERule]
@ -32,9 +35,10 @@ class DomainEntrypoint(NamedTuple):
type_transition: List[AnyTERule] type_transition: List[AnyTERule]
class DomainTransition(NamedTuple): @dataclass
class DomainTransition:
"""Transition step output named tuple format.""" """Transition step output."""
source: Type source: Type
target: Type target: Type
@ -580,7 +584,8 @@ class DomainTransitionAnalysis:
nx.number_of_edges(self.subG))) nx.number_of_edges(self.subG)))
class Edge: @dataclass
class Edge(NetworkXGraphEdge):
""" """
A graph edge. Also used for returning domain transition steps. A graph edge. Also used for returning domain transition steps.
@ -595,6 +600,10 @@ class Edge:
The default is False. The default is False.
""" """
G: nx.DiGraph
source: Type
target: Type
create: InitVar[bool] = False
transition = EdgeAttrList() transition = EdgeAttrList()
setexec = EdgeAttrList() setexec = EdgeAttrList()
dyntransition = EdgeAttrList() dyntransition = EdgeAttrList()
@ -603,16 +612,10 @@ class Edge:
execute = EdgeAttrDict() execute = EdgeAttrDict()
type_transition = EdgeAttrDict() type_transition = EdgeAttrDict()
def __init__(self, graph, source: Type, target: Type, create: bool = False) -> None: def __post_init__(self, create) -> None:
self.G = graph if not self.G.has_edge(self.source, self.target):
self.source: Type = source if create:
self.target: Type = target self.G.add_edge(self.source, self.target)
if not self.G.has_edge(source, target):
if not create:
raise ValueError("Edge does not exist in graph")
else:
self.G.add_edge(source, target)
self.transition = None self.transition = None
self.entrypoint = None self.entrypoint = None
self.execute = None self.execute = None
@ -620,20 +623,5 @@ class Edge:
self.setexec = None self.setexec = None
self.dyntransition = None self.dyntransition = None
self.setcurrent = None self.setcurrent = None
def __getitem__(self, key):
# This is implemented so this object can be used in NetworkX
# functions that operate on (source, target) tuples
if isinstance(key, slice):
return [self._index_to_item(i) for i in range(* key.indices(2))]
else: else:
return self._index_to_item(key) raise ValueError("Edge does not exist in graph")
def _index_to_item(self, index: int) -> Type:
"""Return source or target based on index."""
if index == 0:
return self.source
elif index == 1:
return self.target
else:
raise IndexError("Invalid index (edges only have 2 items): {0}".format(index))

View File

@ -5,6 +5,7 @@
import itertools import itertools
import logging import logging
from contextlib import suppress from contextlib import suppress
from dataclasses import dataclass, InitVar
from typing import cast, Iterable, List, Mapping, Optional, Union from typing import cast, Iterable, List, Mapping, Optional, Union
try: try:
@ -14,6 +15,7 @@ except ImportError:
logging.getLogger(__name__).debug("NetworkX failed to import.") logging.getLogger(__name__).debug("NetworkX failed to import.")
from .descriptors import EdgeAttrIntMax, EdgeAttrList from .descriptors import EdgeAttrIntMax, EdgeAttrList
from .mixins import NetworkXGraphEdge
from .permmap import PermissionMap from .permmap import PermissionMap
from .policyrep import AVRule, SELinuxPolicy, TERuletype, Type from .policyrep import AVRule, SELinuxPolicy, TERuletype, Type
@ -390,7 +392,8 @@ class InfoFlowAnalysis:
nx.number_of_edges(self.subG))) nx.number_of_edges(self.subG)))
class InfoFlowStep: @dataclass
class InfoFlowStep(NetworkXGraphEdge):
""" """
A graph edge. Also used for returning information flow steps. A graph edge. Also used for returning information flow steps.
@ -405,6 +408,10 @@ class InfoFlowStep:
The default is False. The default is False.
""" """
G: nx.DiGraph
source: Type
target: Type
create: InitVar[bool] = False
rules = EdgeAttrList() rules = EdgeAttrList()
# use capacity to store the info flow weight so # use capacity to store the info flow weight so
@ -414,32 +421,11 @@ class InfoFlowStep:
# (see below add_edge() call) # (see below add_edge() call)
weight = EdgeAttrIntMax('capacity') weight = EdgeAttrIntMax('capacity')
def __init__(self, graph, source: Type, target: Type, create: bool = False) -> None: def __post_init__(self, create) -> None:
self.G = graph if not self.G.has_edge(self.source, self.target):
self.source = source
self.target = target
if not self.G.has_edge(source, target):
if create: if create:
self.G.add_edge(source, target, weight=1) self.G.add_edge(self.source, self.target, weight=1)
self.rules = None self.rules = None
self.weight = None self.weight = None
else: else:
raise ValueError("InfoFlowStep does not exist in graph") raise ValueError("InfoFlowStep does not exist in graph")
def __getitem__(self, key):
# This is implemented so this object can be used in NetworkX
# functions that operate on (source, target) tuples
if isinstance(key, slice):
return [self._index_to_item(i) for i in range(* key.indices(2))]
else:
return self._index_to_item(key)
def _index_to_item(self, index):
"""Return source or target based on index."""
if index == 0:
return self.source
elif index == 1:
return self.target
else:
raise IndexError("Invalid index (edges only have 2 items): {0}".format(index))

View File

@ -6,7 +6,7 @@
# pylint: disable=attribute-defined-outside-init,no-member # pylint: disable=attribute-defined-outside-init,no-member
import re import re
from logging import Logger from logging import Logger
from typing import Iterable from typing import Any
from .descriptors import CriteriaDescriptor, CriteriaSetDescriptor, CriteriaPermissionSetDescriptor from .descriptors import CriteriaDescriptor, CriteriaSetDescriptor, CriteriaPermissionSetDescriptor
from .policyrep import Context from .policyrep import Context
@ -208,3 +208,28 @@ class MatchPermission:
return obj.perms >= self.perms return obj.perms >= self.perms
else: else:
return match_regex_or_set(obj.perms, self.perms, self.perms_equal, self.perms_regex) return match_regex_or_set(obj.perms, self.perms, self.perms_equal, self.perms_regex)
class NetworkXGraphEdge:
"""Mixin enabling use in NetworkX functions."""
source: Any
target: Any
def __getitem__(self, key):
# This is implemented so this object can be used in NetworkX
# functions that operate on (source, target) tuples
if isinstance(key, slice):
return [self._index_to_item(i) for i in range(* key.indices(2))]
else:
return self._index_to_item(key)
def _index_to_item(self, index: int):
"""Return source or target based on index."""
if index == 0:
return self.source
elif index == 1:
return self.target
else:
raise IndexError(f"Invalid index (NetworkXGraphEdge only has 2 items): {index}")

View File

@ -48,6 +48,7 @@ commands = mypy -p setools
passenv = USERSPACE_SRC passenv = USERSPACE_SRC
deps = networkx>=2.0 deps = networkx>=2.0
cython>=0.27 cython>=0.27
py36: dataclasses
py38: cython>=0.29.14 py38: cython>=0.29.14
py39: networkx>=2.6 py39: networkx>=2.6
py39: cython>=0.29.14 py39: cython>=0.29.14