dta, infoflow: Use dataclasses where relevant.

Signed-off-by: Chris PeBenito <chpebeni@linux.microsoft.com>
This commit is contained in:
Chris PeBenito 2023-02-02 16:25:42 -05:00
parent 51a3e3aa98
commit ebbfef2482
2 changed files with 30 additions and 24 deletions

View File

@ -8,7 +8,8 @@ import itertools
import logging
from collections import defaultdict
from contextlib import suppress
from typing import DefaultDict, Iterable, List, NamedTuple, Optional, Union
from dataclasses import dataclass, InitVar
from typing import DefaultDict, Iterable, List, Optional, Union
try:
import networkx as nx
@ -22,9 +23,10 @@ from .policyrep import AnyTERule, SELinuxPolicy, TERuletype, Type
__all__ = ['DomainTransitionAnalysis', 'DomainTransition', 'DomainEntrypoint', 'DTAPath']
class DomainEntrypoint(NamedTuple):
@dataclass
class DomainEntrypoint:
"""Entrypoint list entry named tuple output format."""
"""Entrypoint list entry."""
name: Type
entrypoint: List[AnyTERule]
@ -32,9 +34,10 @@ class DomainEntrypoint(NamedTuple):
type_transition: List[AnyTERule]
class DomainTransition(NamedTuple):
@dataclass
class DomainTransition:
"""Transition step output named tuple format."""
"""Transition step output."""
source: Type
target: Type
@ -580,6 +583,7 @@ class DomainTransitionAnalysis:
nx.number_of_edges(self.subG)))
@dataclass
class Edge:
"""
@ -595,6 +599,10 @@ class Edge:
The default is False.
"""
G: nx.DiGraph
source: Type
target: Type
create: InitVar[bool] = False
transition = EdgeAttrList()
setexec = EdgeAttrList()
dyntransition = EdgeAttrList()
@ -603,16 +611,10 @@ class Edge:
execute = EdgeAttrDict()
type_transition = EdgeAttrDict()
def __init__(self, graph, source: Type, target: Type, create: bool = False) -> None:
self.G = graph
self.source: Type = source
self.target: Type = target
if not self.G.has_edge(source, target):
if not create:
raise ValueError("Edge does not exist in graph")
else:
self.G.add_edge(source, target)
def __post_init__(self, create) -> None:
if not self.G.has_edge(self.source, self.target):
if create:
self.G.add_edge(self.source, self.target)
self.transition = None
self.entrypoint = None
self.execute = None
@ -620,6 +622,8 @@ class Edge:
self.setexec = None
self.dyntransition = None
self.setcurrent = None
else:
raise ValueError("Edge does not exist in graph")
def __getitem__(self, key):
# This is implemented so this object can be used in NetworkX

View File

@ -5,6 +5,7 @@
import itertools
import logging
from contextlib import suppress
from dataclasses import dataclass, InitVar
from typing import cast, Iterable, List, Mapping, Optional, Union
try:
@ -390,6 +391,7 @@ class InfoFlowAnalysis:
nx.number_of_edges(self.subG)))
@dataclass
class InfoFlowStep:
"""
@ -405,6 +407,10 @@ class InfoFlowStep:
The default is False.
"""
G: nx.DiGraph
source: Type
target: Type
create: InitVar[bool] = False
rules = EdgeAttrList()
# use capacity to store the info flow weight so
@ -414,14 +420,10 @@ class InfoFlowStep:
# (see below add_edge() call)
weight = EdgeAttrIntMax('capacity')
def __init__(self, graph, source: Type, target: Type, create: bool = False) -> None:
self.G = graph
self.source = source
self.target = target
if not self.G.has_edge(source, target):
def __post_init__(self, create) -> None:
if not self.G.has_edge(self.source, self.target):
if create:
self.G.add_edge(source, target, weight=1)
self.G.add_edge(self.source, self.target, weight=1)
self.rules = None
self.weight = None
else:
@ -435,11 +437,11 @@ class InfoFlowStep:
else:
return self._index_to_item(key)
def _index_to_item(self, index):
def _index_to_item(self, index: int) -> Type:
"""Return source or target based on index."""
if index == 0:
return self.source
elif index == 1:
return self.target
else:
raise IndexError("Invalid index (edges only have 2 items): {0}".format(index))
raise IndexError("Invalid index (InfoFlowSteps only have 2 items): {0}".format(index))