policyrep: Refactor MLS classes to load most attributes on construction.

Category sets and aliases deferred still, to prevent too much construction
up-front. However, the results are saved, rather than lost.
This commit is contained in:
Chris PeBenito 2018-08-06 15:44:44 -04:00
parent bfdb1e66d8
commit 745a7ae9bd

View File

@ -47,28 +47,34 @@ cdef class Category(PolicySymbol):
"""An MLS category.""" """An MLS category."""
cdef sepol.cat_datum_t *handle cdef:
sepol.cat_datum_t *handle
readonly str name
readonly uint32_t _value
list _aliases
@staticmethod @staticmethod
cdef factory(SELinuxPolicy policy, sepol.cat_datum_t *symbol): cdef inline Category factory(SELinuxPolicy policy, sepol.cat_datum_t *symbol):
"""Factory function for creating Category objects.""" """Factory function for creating Category objects."""
cdef Category c
if not policy.mls: if not policy.mls:
raise MLSDisabled raise MLSDisabled
try: try:
return _cat_cache[<uintptr_t>symbol] return _cat_cache[<uintptr_t>symbol]
except KeyError: except KeyError:
c = Category() c = Category.__new__(Category)
c.policy = policy c.policy = policy
c.handle = symbol c.name = policy.category_value_to_name(symbol.s.value - 1)
c._value = symbol.s.value
_cat_cache[<uintptr_t>symbol] = c _cat_cache[<uintptr_t>symbol] = c
return c return c
def __str__(self): def __str__(self):
return self.policy.category_value_to_name(self.handle.s.value - 1) return self.name
def __hash__(self): def __hash__(self):
return hash(str(self)) return hash(self.name)
def __lt__(self, other): def __lt__(self, other):
# Comparison based on their index instead of their names. # Comparison based on their index instead of their names.
@ -78,27 +84,29 @@ cdef class Category(PolicySymbol):
"""Low-level equality check (C pointers).""" """Low-level equality check (C pointers)."""
return self.handle == other.handle return self.handle == other.handle
@property cdef inline void _load_aliases(self):
def _value(self): """Helper method to load aliases."""
""" if self._aliases is None:
The value of the category. self._aliases = list(self.policy.category_aliases(self))
This is a low-level policy detail exposed for internal use only.
"""
return self.handle.s.value
def aliases(self): def aliases(self):
"""Generator that yields all aliases for this category.""" """Generator that yields all aliases for this category."""
self._load_aliases()
return self.policy.category_aliases(self) return self.policy.category_aliases(self)
def statement(self): def statement(self):
aliases = list(self.aliases()) cdef:
stmt = "category {0}".format(self) str stmt
if aliases: size_t count
if len(aliases) > 1:
stmt += " alias {{ {0} }}".format(' '.join(aliases)) self._load_aliases()
else: count = len(self._aliases)
stmt += " alias {0}".format(aliases[0])
stmt = "category {0}".format(self.name)
if count > 1:
stmt += " alias {{ {0} }}".format(' '.join(self._aliases))
elif count == 1:
stmt += " alias {0}".format(self._aliases[0])
stmt += ";" stmt += ";"
return stmt return stmt
@ -107,28 +115,36 @@ cdef class Sensitivity(PolicySymbol):
"""An MLS sensitivity""" """An MLS sensitivity"""
cdef sepol.level_datum_t *handle cdef:
sepol.level_datum_t *handle
readonly str name
readonly uint32_t _value
list _aliases
LevelDecl _leveldecl
@staticmethod @staticmethod
cdef factory(SELinuxPolicy policy, sepol.level_datum_t *symbol): cdef inline Sensitivity factory(SELinuxPolicy policy, sepol.level_datum_t *symbol):
"""Factory function for creating Sensitivity objects.""" """Factory function for creating Sensitivity objects."""
cdef Sensitivity s
if not policy.mls: if not policy.mls:
raise MLSDisabled raise MLSDisabled
try: try:
return _sens_cache[<uintptr_t>symbol] return _sens_cache[<uintptr_t>symbol]
except KeyError: except KeyError:
s = Sensitivity() s = Sensitivity.__new__(Sensitivity)
_sens_cache[<uintptr_t>symbol] = s
s.policy = policy s.policy = policy
s.handle = symbol s.handle = symbol
_sens_cache[<uintptr_t>symbol] = s s.name = policy.level_value_to_name(symbol.level.sens - 1)
s._value = symbol.level.sens
return s return s
def __str__(self): def __str__(self):
return self.policy.level_value_to_name(self.handle.level.sens - 1) return self.name
def __hash__(self): def __hash__(self):
return hash(str(self)) return hash(self.name)
def __ge__(self, other): def __ge__(self, other):
return self._value >= other._value return self._value >= other._value
@ -146,31 +162,36 @@ cdef class Sensitivity(PolicySymbol):
"""Low-level equality check (C pointers).""" """Low-level equality check (C pointers)."""
return self.handle == other.handle return self.handle == other.handle
@property cdef inline void _load_aliases(self):
def _value(self): """Helper method to load aliases."""
""" if self._aliases is None:
The value of the component. self._aliases = list(self.policy.sensitivity_aliases(self))
This is a low-level policy detail exposed for internal use only.
"""
return self.handle.level.sens
def aliases(self): def aliases(self):
"""Generator that yields all aliases for this sensitivity.""" """Generator that yields all aliases for this sensitivity."""
return self.policy.sensitivity_aliases(self) self._load_aliases()
return iter(self._aliases)
def level_decl(self): def level_decl(self):
"""Get the level declaration corresponding to this sensitivity.""" """Get the level declaration corresponding to this sensitivity."""
return LevelDecl.factory(self.policy, self.handle) if self._leveldecl is None:
self._leveldecl = LevelDecl.factory(self.policy, self.handle)
return self._leveldecl
def statement(self): def statement(self):
aliases = list(self.aliases()) cdef:
stmt = "sensitivity {0}".format(self) str stmt
if aliases: size_t count
if len(aliases) > 1:
stmt += " alias {{ {0} }}".format(' '.join(aliases)) self._load_aliases()
else: count = len(self._aliases)
stmt += " alias {0}".format(aliases[0])
stmt = "sensitivity {0}".format(self.name)
if count > 1:
stmt += " alias {{ {0} }}".format(' '.join(self._aliases))
elif count == 1:
stmt += " alias {0}".format(self._aliases[0])
stmt += ";" stmt += ";"
return stmt return stmt
@ -179,11 +200,15 @@ cdef class BaseMLSLevel(PolicySymbol):
"""Base class for MLS levels.""" """Base class for MLS levels."""
cdef:
set _categories
readonly Sensitivity sensitivity
def __str__(self): def __str__(self):
lvl = str(self.sensitivity) lvl = str(self.sensitivity)
# sort by policy declaration order # sort by policy declaration order
cats = sorted(self.categories(), key=lambda k: k._value) cats = sorted(self._categories, key=lambda k: k._value)
if cats: if cats:
# generate short category notation # generate short category notation
@ -206,11 +231,7 @@ cdef class BaseMLSLevel(PolicySymbol):
All categories are yielded, not a compact notation such as All categories are yielded, not a compact notation such as
c0.c255 c0.c255
""" """
raise NotImplementedError return iter(self._categories)
@property
def sensitivity(self):
raise NotImplementedError
cdef class LevelDecl(BaseMLSLevel): cdef class LevelDecl(BaseMLSLevel):
@ -221,21 +242,22 @@ cdef class LevelDecl(BaseMLSLevel):
level s7:c0.c1023; level s7:c0.c1023;
""" """
cdef sepol.level_datum_t *handle
@staticmethod @staticmethod
cdef factory(SELinuxPolicy policy, sepol.level_datum_t *symbol): cdef inline LevelDecl factory(SELinuxPolicy policy, sepol.level_datum_t *symbol):
"""Factory function for creating LevelDecl objects.""" """Factory function for creating LevelDecl objects."""
cdef LevelDecl l
if not policy.mls: if not policy.mls:
raise MLSDisabled raise MLSDisabled
try: try:
return _leveldecl_cache[<uintptr_t>symbol] return _leveldecl_cache[<uintptr_t>symbol]
except KeyError: except KeyError:
l = LevelDecl() l = LevelDecl.__new__(LevelDecl)
l.policy = policy
l.handle = symbol
_leveldecl_cache[<uintptr_t>symbol] = l _leveldecl_cache[<uintptr_t>symbol] = l
l.policy = policy
l._categories = set(CategoryEbitmapIterator.factory(policy, &symbol.level.cat))
# the datum for levels is also used for Sensitivity objects
l.sensitivity = Sensitivity.factory(policy, symbol)
return l return l
def __hash__(self): def __hash__(self):
@ -271,25 +293,6 @@ cdef class LevelDecl(BaseMLSLevel):
assert not isinstance(other, Level), "Levels cannot be compared to level declarations" assert not isinstance(other, Level), "Levels cannot be compared to level declarations"
return self.sensitivity < other.sensitivity return self.sensitivity < other.sensitivity
def _eq(self, LevelDecl other):
"""Low-level equality check (C pointers)."""
return self.handle == other.handle
def categories(self):
"""
Generator that yields all individual categories for this level.
All categories are yielded, not a compact notation such as
c0.c255
"""
return CategoryEbitmapIterator.factory(self.policy, &self.handle.level.cat)
@property
def sensitivity(self):
"""The sensitivity of the level."""
# since the datum for levels is also used for
# Sensitivity objects, use self's datum
return Sensitivity.factory(self.policy, self.handle)
def statement(self): def statement(self):
return "level {0};".format(self) return "level {0};".format(self)
@ -303,39 +306,43 @@ cdef class Level(BaseMLSLevel):
if the level is user-generated. if the level is user-generated.
""" """
cdef:
sepol.mls_level_t *handle
list _categories
Sensitivity _sensitivity
@staticmethod @staticmethod
cdef factory(SELinuxPolicy policy, sepol.mls_level_t *symbol): cdef inline Level factory(SELinuxPolicy policy, sepol.mls_level_t *symbol):
"""Factory function for creating Level objects.""" """Factory function for creating Level objects."""
if not policy.mls: if not policy.mls:
raise MLSDisabled raise MLSDisabled
l = Level() cdef Level l = Level.__new__(Level)
l.policy = policy l.policy = policy
l.handle = symbol l.sensitivity = Sensitivity.factory(policy, policy.level_value_to_datum(symbol.sens - 1))
l._categories = set(CategoryEbitmapIterator.factory(policy, &symbol.cat))
return l return l
@staticmethod @staticmethod
cdef factory_from_string(SELinuxPolicy policy, str name): cdef inline Level factory_from_string(SELinuxPolicy policy, str name):
"""Factory function variant for constructing Level objects by a string.""" """Factory function variant for constructing Level objects by a string."""
if not policy.mls: if not policy.mls:
raise MLSDisabled raise MLSDisabled
sens_split = name.split(":") cdef:
sens = sens_split[0] Level l = Level.__new__(Level)
list sens_split = name.split(":")
str sens = sens_split[0]
Sensitivity s
list c
str cats
list catrange
str group
l.policy = policy
try: try:
s = policy.lookup_sensitivity(sens) l.sensitivity = policy.lookup_sensitivity(sens)
except InvalidSensitivity as ex: except InvalidSensitivity as ex:
raise InvalidLevel("{0} is not a valid level ({1} is not a valid sensitivity)". \ raise InvalidLevel("{0} is not a valid level ({1} is not a valid sensitivity)". \
format(name, sens)) from ex format(name, sens)) from ex
c = [] l._categories = set()
try: try:
cats = sens_split[1] cats = sens_split[1]
@ -346,9 +353,9 @@ cdef class Level(BaseMLSLevel):
catrange = group.split(".") catrange = group.split(".")
if len(catrange) == 2: if len(catrange) == 2:
try: try:
c.extend(expand_cat_range(policy, l._categories.update(expand_cat_range(policy,
policy.lookup_category(catrange[0]), policy.lookup_category(catrange[0]),
policy.lookup_category(catrange[1]))) policy.lookup_category(catrange[1])))
except InvalidCategory as ex: except InvalidCategory as ex:
raise InvalidLevel( raise InvalidLevel(
"{0} is not a valid level ({1} is not a valid category range)". "{0} is not a valid level ({1} is not a valid category range)".
@ -356,7 +363,7 @@ cdef class Level(BaseMLSLevel):
elif len(catrange) == 1: elif len(catrange) == 1:
try: try:
c.append(policy.lookup_category(catrange[0])) l._categories.add(policy.lookup_category(catrange[0]))
except InvalidCategory as ex: except InvalidCategory as ex:
raise InvalidLevel("{0} is not a valid level ({1} is not a valid category)". raise InvalidLevel("{0} is not a valid level ({1} is not a valid category)".
format(name, group)) from ex format(name, group)) from ex
@ -364,15 +371,8 @@ cdef class Level(BaseMLSLevel):
else: else:
raise InvalidLevel("{0} is not a valid level (level parsing error)".format(name)) raise InvalidLevel("{0} is not a valid level (level parsing error)".format(name))
# build object
l = Level()
l.policy = policy
l.handle = NULL
l._sensitivity = s
l._categories = c
# verify level is valid # verify level is valid
if not l <= s.level_decl(): if not l <= l.sensitivity.level_decl():
raise InvalidLevel( raise InvalidLevel(
"{0} is not a valid level (one or more categories are not associated with the " "{0} is not a valid level (one or more categories are not associated with the "
"sensitivity)".format(name)) "sensitivity)".format(name))
@ -388,61 +388,32 @@ cdef class Level(BaseMLSLevel):
except AttributeError: except AttributeError:
return str(self) == str(other) return str(self) == str(other)
else: else:
selfcats = set(self.categories()) return self.sensitivity == other.sensitivity and self._categories == othercats
return self.sensitivity == other.sensitivity and selfcats == othercats
def __ge__(self, other): def __ge__(self, other):
# Dom operator # Dom operator
selfcats = set(self.categories())
othercats = set(other.categories()) othercats = set(other.categories())
return self.sensitivity >= other.sensitivity and selfcats >= othercats return self.sensitivity >= other.sensitivity and self._categories >= othercats
def __gt__(self, other): def __gt__(self, other):
selfcats = set(self.categories())
othercats = set(other.categories()) othercats = set(other.categories())
return ((self.sensitivity > other.sensitivity and selfcats >= othercats) or return ((self.sensitivity > other.sensitivity and self._categories >= othercats) or
(self.sensitivity >= other.sensitivity and selfcats > othercats)) (self.sensitivity >= other.sensitivity and self._categories > othercats))
def __le__(self, other): def __le__(self, other):
# Domby operator # Domby operator
selfcats = set(self.categories())
othercats = set(other.categories()) othercats = set(other.categories())
return self.sensitivity <= other.sensitivity and selfcats <= othercats return self.sensitivity <= other.sensitivity and self._categories <= othercats
def __lt__(self, other): def __lt__(self, other):
selfcats = set(self.categories())
othercats = set(other.categories()) othercats = set(other.categories())
return ((self.sensitivity < other.sensitivity and selfcats <= othercats) or return ((self.sensitivity < other.sensitivity and self._categories <= othercats) or
(self.sensitivity <= other.sensitivity and selfcats < othercats)) (self.sensitivity <= other.sensitivity and self._categories < othercats))
def __xor__(self, other): def __xor__(self, other):
# Incomp operator # Incomp operator
return not (self >= other or self <= other) return not (self >= other or self <= other)
def _eq(self, Level other):
"""Low-level equality check (C pointers)."""
return self.handle == other.handle
def categories(self):
"""
Generator that yields all individual categories for this level.
All categories are yielded, not a compact notation such as
c0.c255
"""
if self.handle == NULL:
return iter(self._categories)
else:
return CategoryEbitmapIterator.factory(self.policy, &self.handle.cat)
@property
def sensitivity(self):
"""The sensitivity of the level."""
if self.handle == NULL:
return self._sensitivity
else:
return Sensitivity.factory(self.policy,
self.policy.level_value_to_datum(self.handle.sens - 1))
def statement(self): def statement(self):
raise NoStatement raise NoStatement
@ -452,63 +423,58 @@ cdef class Range(PolicySymbol):
"""An MLS range""" """An MLS range"""
cdef: cdef:
sepol.mls_range_t *handle readonly Level low
Level _low readonly Level high
Level _high
@staticmethod @staticmethod
cdef factory(SELinuxPolicy policy, sepol.mls_range_t *symbol): cdef inline Range factory(SELinuxPolicy policy, sepol.mls_range_t *symbol):
"""Factory function for creating Range objects.""" """Factory function for creating Range objects."""
if not policy.mls: if not policy.mls:
raise MLSDisabled raise MLSDisabled
r = Range() cdef Range r = Range.__new__(Range)
r.policy = policy r.policy = policy
r.handle = symbol r.low = Level.factory(policy, &symbol.level[0])
r.high = Level.factory(policy, &symbol.level[1])
return r return r
@staticmethod @staticmethod
cdef factory_from_string(SELinuxPolicy policy, str name): cdef inline Range factory_from_string(SELinuxPolicy policy, str name):
"""Factory function variant for constructing Range objects by name.""" """Factory function variant for constructing Range objects by name."""
if not policy.mls: if not policy.mls:
raise MLSDisabled raise MLSDisabled
cdef Range r = Range.__new__(Range)
r.policy = policy
# build range: # build range:
levels = name.split("-") cdef list levels = name.split("-")
# strip() levels to handle ranges with spaces in them, # strip() levels to handle ranges with spaces in them,
# e.g. s0:c1 - s0:c0.c255 # e.g. s0:c1 - s0:c0.c255
try: try:
low = Level.factory_from_string(policy, levels[0].strip()) r.low = Level.factory_from_string(policy, levels[0].strip())
except InvalidLevel as ex: except InvalidLevel as ex:
raise InvalidRange("{0} is not a valid range ({1}).".format(name, ex)) from ex raise InvalidRange("{0} is not a valid range ({1}).".format(name, ex)) from ex
try: try:
high = Level.factory_from_string(policy, levels[1].strip()) r.high = Level.factory_from_string(policy, levels[1].strip())
except InvalidLevel as ex: except InvalidLevel as ex:
raise InvalidRange("{0} is not a valid range ({1}).".format(name, ex)) from ex raise InvalidRange("{0} is not a valid range ({1}).".format(name, ex)) from ex
except IndexError: except IndexError:
high = low r.high = r.low
# verify high level dominates low range # verify high level dominates low range
if not high >= low: if not r.high >= r.low:
raise InvalidRange("{0} is not a valid range ({1} is not dominated by {2})". raise InvalidRange("{0} is not a valid range ({1.low} is not dominated by {1.high})".
format(name, low, high)) format(name, r))
r = Range()
r.policy = policy
r.handle = NULL
r._low = low
r._high = high
return r return r
def __str__(self): def __str__(self):
high = self.high if self.high == self.low:
low = self.low return str(self.low)
if high == low:
return str(low)
return "{0} - {1}".format(low, high) return "{0.low} - {0.high}".format(self)
def __hash__(self): def __hash__(self):
return hash(str(self)) return hash(str(self))
@ -527,26 +493,6 @@ cdef class Range(PolicySymbol):
def __contains__(self, other): def __contains__(self, other):
return self.low <= other <= self.high return self.low <= other <= self.high
def _eq(self, Range other):
"""Low-level equality check (C pointers)."""
return self.handle == other.handle
@property
def high(self):
"""The high end/clearance level of this range."""
if self.handle == NULL:
return self._high
else:
return Level.factory(self.policy, &self.handle.level[1])
@property
def low(self):
"""The low end/current level of this range."""
if self.handle == NULL:
return self._low
else:
return Level.factory(self.policy, &self.handle.level[0])
def statement(self): def statement(self):
raise NoStatement raise NoStatement