diff --git a/setools/dta.py b/setools/dta.py index 3f87677..31faa45 100644 --- a/setools/dta.py +++ b/setools/dta.py @@ -8,7 +8,8 @@ import itertools import logging from collections import defaultdict from contextlib import suppress -from typing import DefaultDict, Iterable, List, NamedTuple, Optional, Union +from dataclasses import dataclass, InitVar +from typing import DefaultDict, Iterable, List, Optional, Union try: import networkx as nx @@ -22,9 +23,10 @@ from .policyrep import AnyTERule, SELinuxPolicy, TERuletype, Type __all__ = ['DomainTransitionAnalysis', 'DomainTransition', 'DomainEntrypoint', 'DTAPath'] -class DomainEntrypoint(NamedTuple): +@dataclass +class DomainEntrypoint: - """Entrypoint list entry named tuple output format.""" + """Entrypoint list entry.""" name: Type entrypoint: List[AnyTERule] @@ -32,9 +34,10 @@ class DomainEntrypoint(NamedTuple): type_transition: List[AnyTERule] -class DomainTransition(NamedTuple): +@dataclass +class DomainTransition: - """Transition step output named tuple format.""" + """Transition step output.""" source: Type target: Type @@ -580,6 +583,7 @@ class DomainTransitionAnalysis: nx.number_of_edges(self.subG))) +@dataclass class Edge: """ @@ -595,6 +599,10 @@ class Edge: The default is False. """ + G: nx.DiGraph + source: Type + target: Type + create: InitVar[bool] = False transition = EdgeAttrList() setexec = EdgeAttrList() dyntransition = EdgeAttrList() @@ -603,16 +611,10 @@ class Edge: execute = EdgeAttrDict() type_transition = EdgeAttrDict() - def __init__(self, graph, source: Type, target: Type, create: bool = False) -> None: - self.G = graph - self.source: Type = source - self.target: Type = target - - if not self.G.has_edge(source, target): - if not create: - raise ValueError("Edge does not exist in graph") - else: - self.G.add_edge(source, target) + def __post_init__(self, create) -> None: + if not self.G.has_edge(self.source, self.target): + if create: + self.G.add_edge(self.source, self.target) self.transition = None self.entrypoint = None self.execute = None @@ -620,6 +622,8 @@ class Edge: self.setexec = None self.dyntransition = None self.setcurrent = None + else: + raise ValueError("Edge does not exist in graph") def __getitem__(self, key): # This is implemented so this object can be used in NetworkX diff --git a/setools/infoflow.py b/setools/infoflow.py index cc45a34..177cd9c 100644 --- a/setools/infoflow.py +++ b/setools/infoflow.py @@ -5,6 +5,7 @@ import itertools import logging from contextlib import suppress +from dataclasses import dataclass, InitVar from typing import cast, Iterable, List, Mapping, Optional, Union try: @@ -390,6 +391,7 @@ class InfoFlowAnalysis: nx.number_of_edges(self.subG))) +@dataclass class InfoFlowStep: """ @@ -405,6 +407,10 @@ class InfoFlowStep: The default is False. """ + G: nx.DiGraph + source: Type + target: Type + create: InitVar[bool] = False rules = EdgeAttrList() # use capacity to store the info flow weight so @@ -414,14 +420,10 @@ class InfoFlowStep: # (see below add_edge() call) weight = EdgeAttrIntMax('capacity') - def __init__(self, graph, source: Type, target: Type, create: bool = False) -> None: - self.G = graph - self.source = source - self.target = target - - if not self.G.has_edge(source, target): + def __post_init__(self, create) -> None: + if not self.G.has_edge(self.source, self.target): if create: - self.G.add_edge(source, target, weight=1) + self.G.add_edge(self.source, self.target, weight=1) self.rules = None self.weight = None else: @@ -435,11 +437,11 @@ class InfoFlowStep: else: return self._index_to_item(key) - def _index_to_item(self, index): + def _index_to_item(self, index: int) -> Type: """Return source or target based on index.""" if index == 0: return self.source elif index == 1: return self.target else: - raise IndexError("Invalid index (edges only have 2 items): {0}".format(index)) + raise IndexError("Invalid index (InfoFlowSteps only have 2 items): {0}".format(index))