diff --git a/sedta b/sedta index b1dd836..fba56c7 100755 --- a/sedta +++ b/sedta @@ -13,51 +13,6 @@ import warnings import setools -def print_transition(trans: setools.DomainTransition) -> None: - print() - - if trans.transition: - print("Domain transition rule(s):") - for t in trans.transition: - print(t) - - if trans.setexec: - print("\nSet execution context rule(s):") - for s in trans.setexec: - print(s) - - for entrypoint in trans.entrypoints: - print("\nEntrypoint {0}:".format(entrypoint.name)) - - print("\tDomain entrypoint rule(s):") - for e in entrypoint.entrypoint: - print("\t{0}".format(e)) - - print("\n\tFile execute rule(s):") - for e in entrypoint.execute: - print("\t{0}".format(e)) - - if entrypoint.type_transition: - print("\n\tType transition rule(s):") - for t in entrypoint.type_transition: - print("\t{0}".format(t)) - - print() - - if trans.dyntransition: - print("Dynamic transition rule(s):") - for d in trans.dyntransition: - print(d) - - print("\nSet current process context rule(s):") - for s in trans.setcurrent: - print(s) - - print() - - print() - - signal.signal(signal.SIGPIPE, signal.SIG_DFL) parser = argparse.ArgumentParser( @@ -113,6 +68,10 @@ try: p = setools.SELinuxPolicy(args.policy) g = setools.DomainTransitionAnalysis(p, exclude=args.exclude) + pathnum: int = 0 + path: setools.DTAPath + stepnum: int = 0 + step: setools.DomainTransition if args.shortest_path or args.all_paths: g.source = args.source g.target = args.target @@ -123,23 +82,21 @@ try: g.mode = setools.DomainTransitionAnalysis.Mode.AllPaths g.all_paths_max_steps = args.all_paths - i = 0 - for i, path in enumerate(g.results(), start=1): - print("Domain transition path {0}:".format(i)) + for pathnum, path in enumerate(g.results(), start=1): # type: ignore + print(f"Domain transition path {pathnum}:") for stepnum, step in enumerate(path, start=1): - - print("Step {0}: {1} -> {2}".format(stepnum, step.source, step.target)) - if args.full: - print_transition(step) + print(f"Step {stepnum}: {step:full}\n") + else: + print(f"Step {stepnum}: {step}") - if args.limit_trans and i >= args.limit_trans: + if args.limit_trans and pathnum >= args.limit_trans: break print() - print("\n{} domain transition path(s) found.".format(i)) + print(f"\n{pathnum} domain transition path(s) found.") else: # single transition if args.reverse: @@ -149,19 +106,16 @@ try: g.mode = setools.DomainTransitionAnalysis.Mode.TransitionsOut g.source = args.source - transitions = g.results() - - i = 0 - for i, step in enumerate(transitions, start=1): - print("Transition {0}: {1} -> {2}".format(i, step.source, step.target)) - + for pathnum, step in enumerate(g.results(), start=1): # type: ignore if args.full: - print_transition(step) + print(f"Transition {pathnum}: {step:full}\n") + else: + print(f"Transition {pathnum}: {step}") - if args.limit_trans and i >= args.limit_trans: + if args.limit_trans and pathnum >= args.limit_trans: break - print("\n{} domain transition(s) found.".format(i)) + print(f"\n{pathnum} domain transition(s) found.") if args.stats: print("\nGraph statistics:") diff --git a/setools/dta.py b/setools/dta.py index 9a7e8b2..261e14f 100644 --- a/setools/dta.py +++ b/setools/dta.py @@ -40,6 +40,24 @@ class DomainEntrypoint: execute: List[AnyTERule] type_transition: List[AnyTERule] + def __lt__(self, other: "DomainEntrypoint") -> bool: + # basic comparison for sorting + return self.name < other.name + + def __str__(self) -> str: + lines: List[str] = [f"\nEntrypoint {self.name}:", + "\tDomain entrypoint rule(s):"] + lines.extend(f"\t{e}" for e in sorted(self.entrypoint)) + + lines.append("\n\tFile execute rule(s):") + lines.extend(f"\t{e}" for e in sorted(self.execute)) + + if self.type_transition: + lines.append("\n\tType transition rule(s):") + lines.extend(f"\t{t}" for t in sorted(self.type_transition)) + + return "\n".join(lines) + @dataclass class DomainTransition: @@ -54,6 +72,38 @@ class DomainTransition: dyntransition: List[AnyTERule] setcurrent: List[AnyTERule] + def __format__(self, spec: str) -> str: + lines: List[str] = [f"{self.source} -> {self.target}\n"] + if spec == "full": + if self.transition: + lines.append("Domain transition rule(s):") + lines.extend(str(t) for t in sorted(self.transition)) + + if self.setexec: + lines.append("\nSet execution context rule(s):") + lines.extend(str(s) for s in sorted(self.setexec)) + + lines.extend(f"{e}\n" for e in sorted(self.entrypoints)) + + if self.dyntransition: + lines.append("Dynamic transition rule(s):") + lines.extend(str(d) for d in sorted(self.dyntransition)) + + lines.append("\nSet current process context rule(s):") + lines.extend(str(s) for s in sorted(self.setcurrent)) + + lines.append("") + + return "\n".join(lines) + + if not spec: + return lines[0] + + return super().__format__(spec) + + def __str__(self) -> str: + return self.__format__("full") + # # Typing