dta, infoflow: Use dataclasses where relevant.

Signed-off-by: Chris PeBenito <chpebeni@linux.microsoft.com>
This commit is contained in:
Chris PeBenito 2023-02-02 16:25:42 -05:00
parent 51a3e3aa98
commit ebbfef2482
2 changed files with 30 additions and 24 deletions

View File

@ -8,7 +8,8 @@ import itertools
import logging import logging
from collections import defaultdict from collections import defaultdict
from contextlib import suppress from contextlib import suppress
from typing import DefaultDict, Iterable, List, NamedTuple, Optional, Union from dataclasses import dataclass, InitVar
from typing import DefaultDict, Iterable, List, Optional, Union
try: try:
import networkx as nx import networkx as nx
@ -22,9 +23,10 @@ from .policyrep import AnyTERule, SELinuxPolicy, TERuletype, Type
__all__ = ['DomainTransitionAnalysis', 'DomainTransition', 'DomainEntrypoint', 'DTAPath'] __all__ = ['DomainTransitionAnalysis', 'DomainTransition', 'DomainEntrypoint', 'DTAPath']
class DomainEntrypoint(NamedTuple): @dataclass
class DomainEntrypoint:
"""Entrypoint list entry named tuple output format.""" """Entrypoint list entry."""
name: Type name: Type
entrypoint: List[AnyTERule] entrypoint: List[AnyTERule]
@ -32,9 +34,10 @@ class DomainEntrypoint(NamedTuple):
type_transition: List[AnyTERule] type_transition: List[AnyTERule]
class DomainTransition(NamedTuple): @dataclass
class DomainTransition:
"""Transition step output named tuple format.""" """Transition step output."""
source: Type source: Type
target: Type target: Type
@ -580,6 +583,7 @@ class DomainTransitionAnalysis:
nx.number_of_edges(self.subG))) nx.number_of_edges(self.subG)))
@dataclass
class Edge: class Edge:
""" """
@ -595,6 +599,10 @@ class Edge:
The default is False. The default is False.
""" """
G: nx.DiGraph
source: Type
target: Type
create: InitVar[bool] = False
transition = EdgeAttrList() transition = EdgeAttrList()
setexec = EdgeAttrList() setexec = EdgeAttrList()
dyntransition = EdgeAttrList() dyntransition = EdgeAttrList()
@ -603,16 +611,10 @@ class Edge:
execute = EdgeAttrDict() execute = EdgeAttrDict()
type_transition = EdgeAttrDict() type_transition = EdgeAttrDict()
def __init__(self, graph, source: Type, target: Type, create: bool = False) -> None: def __post_init__(self, create) -> None:
self.G = graph if not self.G.has_edge(self.source, self.target):
self.source: Type = source if create:
self.target: Type = target self.G.add_edge(self.source, self.target)
if not self.G.has_edge(source, target):
if not create:
raise ValueError("Edge does not exist in graph")
else:
self.G.add_edge(source, target)
self.transition = None self.transition = None
self.entrypoint = None self.entrypoint = None
self.execute = None self.execute = None
@ -620,6 +622,8 @@ class Edge:
self.setexec = None self.setexec = None
self.dyntransition = None self.dyntransition = None
self.setcurrent = None self.setcurrent = None
else:
raise ValueError("Edge does not exist in graph")
def __getitem__(self, key): def __getitem__(self, key):
# This is implemented so this object can be used in NetworkX # This is implemented so this object can be used in NetworkX

View File

@ -5,6 +5,7 @@
import itertools import itertools
import logging import logging
from contextlib import suppress from contextlib import suppress
from dataclasses import dataclass, InitVar
from typing import cast, Iterable, List, Mapping, Optional, Union from typing import cast, Iterable, List, Mapping, Optional, Union
try: try:
@ -390,6 +391,7 @@ class InfoFlowAnalysis:
nx.number_of_edges(self.subG))) nx.number_of_edges(self.subG)))
@dataclass
class InfoFlowStep: class InfoFlowStep:
""" """
@ -405,6 +407,10 @@ class InfoFlowStep:
The default is False. The default is False.
""" """
G: nx.DiGraph
source: Type
target: Type
create: InitVar[bool] = False
rules = EdgeAttrList() rules = EdgeAttrList()
# use capacity to store the info flow weight so # use capacity to store the info flow weight so
@ -414,14 +420,10 @@ class InfoFlowStep:
# (see below add_edge() call) # (see below add_edge() call)
weight = EdgeAttrIntMax('capacity') weight = EdgeAttrIntMax('capacity')
def __init__(self, graph, source: Type, target: Type, create: bool = False) -> None: def __post_init__(self, create) -> None:
self.G = graph if not self.G.has_edge(self.source, self.target):
self.source = source
self.target = target
if not self.G.has_edge(source, target):
if create: if create:
self.G.add_edge(source, target, weight=1) self.G.add_edge(self.source, self.target, weight=1)
self.rules = None self.rules = None
self.weight = None self.weight = None
else: else:
@ -435,11 +437,11 @@ class InfoFlowStep:
else: else:
return self._index_to_item(key) return self._index_to_item(key)
def _index_to_item(self, index): def _index_to_item(self, index: int) -> Type:
"""Return source or target based on index.""" """Return source or target based on index."""
if index == 0: if index == 0:
return self.source return self.source
elif index == 1: elif index == 1:
return self.target return self.target
else: else:
raise IndexError("Invalid index (edges only have 2 items): {0}".format(index)) raise IndexError("Invalid index (InfoFlowSteps only have 2 items): {0}".format(index))