Implement an enumeration for MLS rule type.

This commit is contained in:
Chris PeBenito 2016-09-03 16:40:54 -04:00
parent a441a92937
commit 4b5b6c0970
10 changed files with 70 additions and 51 deletions

View File

@ -77,7 +77,7 @@ rbacrtypes.add_argument("--role_trans", action="append_const",
mlsrtypes = parser.add_argument_group("MLS Rule Types") mlsrtypes = parser.add_argument_group("MLS Rule Types")
mlsrtypes.add_argument("--range_trans", action="append_const", mlsrtypes.add_argument("--range_trans", action="append_const",
const="range_transition", dest="mlsrtypes", const=setools.MLSRuletype.range_transition, dest="mlsrtypes",
help="Search range_transition rules.") help="Search range_transition rules.")
expr = parser.add_argument_group("Expressions") expr = parser.add_argument_group("Expressions")

View File

@ -18,6 +18,7 @@
# #
from collections import defaultdict, namedtuple from collections import defaultdict, namedtuple
from ..policyrep import MLSRuletype
from .descriptors import DiffResultDescriptor from .descriptors import DiffResultDescriptor
from .difference import Difference, SymbolWrapper, Wrapper from .difference import Difference, SymbolWrapper, Wrapper
from .mls import RangeWrapper from .mls import RangeWrapper
@ -51,8 +52,10 @@ class MLSRulesDifference(Difference):
self._create_mls_rule_lists() self._create_mls_rule_lists()
added, removed, matched = self._set_diff( added, removed, matched = self._set_diff(
self._expand_generator(self._left_mls_rules["range_transition"], MLSRuleWrapper), self._expand_generator(self._left_mls_rules[MLSRuletype.range_transition],
self._expand_generator(self._right_mls_rules["range_transition"], MLSRuleWrapper)) MLSRuleWrapper),
self._expand_generator(self._right_mls_rules[MLSRuletype.range_transition],
MLSRuleWrapper))
modified = [] modified = []

View File

@ -20,6 +20,7 @@ import logging
from .descriptors import CriteriaDescriptor, CriteriaSetDescriptor from .descriptors import CriteriaDescriptor, CriteriaSetDescriptor
from .mixins import MatchObjClass from .mixins import MatchObjClass
from .policyrep import MLSRuletype
from .query import PolicyQuery from .query import PolicyQuery
from .util import match_indirect_regex, match_range from .util import match_indirect_regex, match_range
@ -45,7 +46,7 @@ class MLSRuleQuery(MatchObjClass, PolicyQuery):
matching the rule's object class. matching the rule's object class.
""" """
ruletype = CriteriaSetDescriptor(lookup_function="validate_mls_ruletype") ruletype = CriteriaSetDescriptor(enum_class=MLSRuletype)
source = CriteriaDescriptor("source_regex", "lookup_type_or_attr") source = CriteriaDescriptor("source_regex", "lookup_type_or_attr")
source_regex = False source_regex = False
source_indirect = True source_indirect = True

View File

@ -23,6 +23,7 @@
from . import exception from . import exception
from .netcontext import PortconProtocol, PortconRange from .netcontext import PortconProtocol, PortconRange
from .mlsrule import MLSRuletype
from .rbacrule import RBACRuletype from .rbacrule import RBACRuletype
from .selinuxpolicy import SELinuxPolicy from .selinuxpolicy import SELinuxPolicy
from .terule import IoctlSet, TERuletype from .terule import IoctlSet, TERuletype

View File

@ -1,4 +1,5 @@
# Copyright 2014, 2016, Tresys Technology, LLC # Copyright 2014, 2016, Tresys Technology, LLC
# Copyright 2016, Chris PeBenito <pebenito@ieee.org>
# #
# This file is part of SETools. # This file is part of SETools.
# #
@ -23,6 +24,7 @@ from . import qpol
from . import rule from . import rule
from . import typeattr from . import typeattr
from . import mls from . import mls
from .util import PolicyEnum
def mls_rule_factory(policy, symbol): def mls_rule_factory(policy, symbol):
@ -57,10 +59,17 @@ def expanded_mls_rule_factory(original, source, target):
def validate_ruletype(t): def validate_ruletype(t):
"""Validate MLS rule types.""" """Validate MLS rule types."""
if t not in ["range_transition"]: try:
return MLSRuletype.lookup(t)
except KeyError:
raise exception.InvalidMLSRuleType("{0} is not a valid MLS rule type.".format(t)) raise exception.InvalidMLSRuleType("{0} is not a valid MLS rule type.".format(t))
return t
class MLSRuletype(PolicyEnum):
"""An enumeration of MLS rule types."""
range_transition = 1
class MLSRule(rule.PolicyRule): class MLSRule(rule.PolicyRule):
@ -70,7 +79,7 @@ class MLSRule(rule.PolicyRule):
def __str__(self): def __str__(self):
return "{0.ruletype} {0.source} {0.target}:{0.tclass} {0.default};".format(self) return "{0.ruletype} {0.source} {0.target}:{0.tclass} {0.default};".format(self)
ruletype = "range_transition" ruletype = MLSRuletype.range_transition
@property @property
def source(self): def source(self):

View File

@ -603,6 +603,8 @@ class SELinuxPolicy(object):
@staticmethod @staticmethod
def validate_mls_ruletype(types): def validate_mls_ruletype(types):
"""Validate MLS rule types.""" """Validate MLS rule types."""
warnings.warn("MLS ruletypes have changed to an enumeration.",
DeprecationWarning)
return mlsrule.validate_ruletype(types) return mlsrule.validate_ruletype(types)
@staticmethod @staticmethod

View File

@ -35,7 +35,7 @@ class MLSRuleTableModel(SEToolsTableModel):
if role == Qt.DisplayRole: if role == Qt.DisplayRole:
if col == 0: if col == 0:
return rule.ruletype return rule.ruletype.name
elif col == 1: elif col == 1:
return str(rule.source) return str(rule.source)
elif col == 2: elif col == 2:

View File

@ -20,6 +20,7 @@ import unittest
from socket import IPPROTO_TCP, IPPROTO_UDP from socket import IPPROTO_TCP, IPPROTO_UDP
from setools import SELinuxPolicy, PolicyDifference from setools import SELinuxPolicy, PolicyDifference
from setools import MLSRuletype as MRT
from setools import RBACRuletype as RRT from setools import RBACRuletype as RRT
from setools import TERuletype as TRT from setools import TERuletype as TRT
@ -749,11 +750,11 @@ class PolicyDifferenceTest(ValidateRule, unittest.TestCase):
self.assertEqual(2, len(rules)) self.assertEqual(2, len(rules))
# added rule with new type # added rule with new type
self.validate_rule(rules[0], "range_transition", "added_type", "system", "infoflow4", self.validate_rule(rules[0], MRT.range_transition, "added_type", "system", "infoflow4",
"s3") "s3")
# added rule with existing types # added rule with existing types
self.validate_rule(rules[1], "range_transition", "rt_added_rule_source", self.validate_rule(rules[1], MRT.range_transition, "rt_added_rule_source",
"rt_added_rule_target", "infoflow", "s3") "rt_added_rule_target", "infoflow", "s3")
def test_removed_range_transition_rules(self): def test_removed_range_transition_rules(self):
@ -762,11 +763,11 @@ class PolicyDifferenceTest(ValidateRule, unittest.TestCase):
self.assertEqual(2, len(rules)) self.assertEqual(2, len(rules))
# removed rule with new type # removed rule with new type
self.validate_rule(rules[0], "range_transition", "removed_type", "system", "infoflow4", self.validate_rule(rules[0], MRT.range_transition, "removed_type", "system", "infoflow4",
"s1") "s1")
# removed rule with existing types # removed rule with existing types
self.validate_rule(rules[1], "range_transition", "rt_removed_rule_source", self.validate_rule(rules[1], MRT.range_transition, "rt_removed_rule_source",
"rt_removed_rule_target", "infoflow", "s1") "rt_removed_rule_target", "infoflow", "s1")
def test_modified_range_transition_rules(self): def test_modified_range_transition_rules(self):
@ -775,7 +776,7 @@ class PolicyDifferenceTest(ValidateRule, unittest.TestCase):
self.assertEqual(1, len(l)) self.assertEqual(1, len(l))
rule, added_default, removed_default = l[0] rule, added_default, removed_default = l[0]
self.assertEqual("range_transition", rule.ruletype) self.assertEqual(MRT.range_transition, rule.ruletype)
self.assertEqual("rt_matched_source", rule.source) self.assertEqual("rt_matched_source", rule.source)
self.assertEqual("system", rule.target) self.assertEqual("system", rule.target)
self.assertEqual("infoflow", rule.tclass) self.assertEqual("infoflow", rule.tclass)

View File

@ -18,7 +18,7 @@
import unittest import unittest
from setools import SELinuxPolicy, MLSRuleQuery from setools import SELinuxPolicy, MLSRuleQuery
from setools.policyrep.exception import InvalidMLSRuleType from setools import MLSRuletype as RT
from . import mixins from . import mixins
@ -52,7 +52,7 @@ class MLSRuleQueryTest(mixins.ValidateRule, unittest.TestCase):
r = sorted(q.results()) r = sorted(q.results())
self.assertEqual(len(r), 1) self.assertEqual(len(r), 1)
self.validate_rule(r[0], "range_transition", "test1s", "test1t", "infoflow", "s0") self.validate_rule(r[0], RT.range_transition, "test1s", "test1t", "infoflow", "s0")
def test_003_source_direct_regex(self): def test_003_source_direct_regex(self):
"""MLS rule query with regex, direct, source match.""" """MLS rule query with regex, direct, source match."""
@ -61,8 +61,8 @@ class MLSRuleQueryTest(mixins.ValidateRule, unittest.TestCase):
r = sorted(q.results()) r = sorted(q.results())
self.assertEqual(len(r), 2) self.assertEqual(len(r), 2)
self.validate_rule(r[0], "range_transition", "test3s", "test3t", "infoflow", "s1") self.validate_rule(r[0], RT.range_transition, "test3s", "test3t", "infoflow", "s1")
self.validate_rule(r[1], "range_transition", "test3s", "test3t", "infoflow2", "s2") self.validate_rule(r[1], RT.range_transition, "test3s", "test3t", "infoflow2", "s2")
def test_005_issue111(self): def test_005_issue111(self):
"""MLS rule query with attribute source criteria, indirect match.""" """MLS rule query with attribute source criteria, indirect match."""
@ -71,8 +71,8 @@ class MLSRuleQueryTest(mixins.ValidateRule, unittest.TestCase):
r = sorted(q.results()) r = sorted(q.results())
self.assertEqual(len(r), 2) self.assertEqual(len(r), 2)
self.validate_rule(r[0], "range_transition", "test5t1", "test5target", "infoflow", "s1") self.validate_rule(r[0], RT.range_transition, "test5t1", "test5target", "infoflow", "s1")
self.validate_rule(r[1], "range_transition", "test5t2", "test5target", "infoflow7", "s2") self.validate_rule(r[1], RT.range_transition, "test5t2", "test5target", "infoflow7", "s2")
def test_010_target_direct(self): def test_010_target_direct(self):
"""MLS rule query with exact, direct, target match.""" """MLS rule query with exact, direct, target match."""
@ -81,8 +81,8 @@ class MLSRuleQueryTest(mixins.ValidateRule, unittest.TestCase):
r = sorted(q.results()) r = sorted(q.results())
self.assertEqual(len(r), 2) self.assertEqual(len(r), 2)
self.validate_rule(r[0], "range_transition", "test10s", "test10t", "infoflow", "s0") self.validate_rule(r[0], RT.range_transition, "test10s", "test10t", "infoflow", "s0")
self.validate_rule(r[1], "range_transition", "test10s", "test10t", "infoflow2", "s1") self.validate_rule(r[1], RT.range_transition, "test10s", "test10t", "infoflow2", "s1")
def test_012_target_direct_regex(self): def test_012_target_direct_regex(self):
"""MLS rule query with regex, direct, target match.""" """MLS rule query with regex, direct, target match."""
@ -91,7 +91,7 @@ class MLSRuleQueryTest(mixins.ValidateRule, unittest.TestCase):
r = sorted(q.results()) r = sorted(q.results())
self.assertEqual(len(r), 1) self.assertEqual(len(r), 1)
self.validate_rule(r[0], "range_transition", "test12s", "test12aFAIL", "infoflow", "s2") self.validate_rule(r[0], RT.range_transition, "test12s", "test12aFAIL", "infoflow", "s2")
def test_014_issue111(self): def test_014_issue111(self):
"""MLS rule query with attribute target criteria, indirect match.""" """MLS rule query with attribute target criteria, indirect match."""
@ -100,8 +100,8 @@ class MLSRuleQueryTest(mixins.ValidateRule, unittest.TestCase):
r = sorted(q.results()) r = sorted(q.results())
self.assertEqual(len(r), 2) self.assertEqual(len(r), 2)
self.validate_rule(r[0], "range_transition", "test14source", "test14t1", "infoflow", "s1") self.validate_rule(r[0], RT.range_transition, "test14source", "test14t1", "infoflow", "s1")
self.validate_rule(r[1], "range_transition", "test14source", "test14t2", "infoflow7", "s2") self.validate_rule(r[1], RT.range_transition, "test14source", "test14t2", "infoflow7", "s2")
@unittest.skip("Setting tclass to a string is no longer supported.") @unittest.skip("Setting tclass to a string is no longer supported.")
def test_020_class(self): def test_020_class(self):
@ -110,7 +110,7 @@ class MLSRuleQueryTest(mixins.ValidateRule, unittest.TestCase):
r = sorted(q.results()) r = sorted(q.results())
self.assertEqual(len(r), 1) self.assertEqual(len(r), 1)
self.validate_rule(r[0], "range_transition", "test20", "test20", "infoflow7", "s1") self.validate_rule(r[0], RT.range_transition, "test20", "test20", "infoflow7", "s1")
def test_021_class_list(self): def test_021_class_list(self):
"""MLS rule query with object class list match.""" """MLS rule query with object class list match."""
@ -119,8 +119,8 @@ class MLSRuleQueryTest(mixins.ValidateRule, unittest.TestCase):
r = sorted(q.results()) r = sorted(q.results())
self.assertEqual(len(r), 2) self.assertEqual(len(r), 2)
self.validate_rule(r[0], "range_transition", "test21", "test21", "infoflow3", "s2") self.validate_rule(r[0], RT.range_transition, "test21", "test21", "infoflow3", "s2")
self.validate_rule(r[1], "range_transition", "test21", "test21", "infoflow4", "s1") self.validate_rule(r[1], RT.range_transition, "test21", "test21", "infoflow4", "s1")
def test_022_class_regex(self): def test_022_class_regex(self):
"""MLS rule query with object class regex match.""" """MLS rule query with object class regex match."""
@ -128,8 +128,8 @@ class MLSRuleQueryTest(mixins.ValidateRule, unittest.TestCase):
r = sorted(q.results()) r = sorted(q.results())
self.assertEqual(len(r), 2) self.assertEqual(len(r), 2)
self.validate_rule(r[0], "range_transition", "test22", "test22", "infoflow5", "s1") self.validate_rule(r[0], RT.range_transition, "test22", "test22", "infoflow5", "s1")
self.validate_rule(r[1], "range_transition", "test22", "test22", "infoflow6", "s2") self.validate_rule(r[1], RT.range_transition, "test22", "test22", "infoflow6", "s2")
def test_040_range_exact(self): def test_040_range_exact(self):
"""MLS rule query with context range exact match""" """MLS rule query with context range exact match"""
@ -137,7 +137,7 @@ class MLSRuleQueryTest(mixins.ValidateRule, unittest.TestCase):
r = sorted(q.results()) r = sorted(q.results())
self.assertEqual(len(r), 1) self.assertEqual(len(r), 1)
self.validate_rule(r[0], "range_transition", "test40", "test40", "infoflow", self.validate_rule(r[0], RT.range_transition, "test40", "test40", "infoflow",
"s40:c1 - s40:c0.c4") "s40:c1 - s40:c0.c4")
def test_041_range_overlap1(self): def test_041_range_overlap1(self):
@ -146,7 +146,7 @@ class MLSRuleQueryTest(mixins.ValidateRule, unittest.TestCase):
r = sorted(q.results()) r = sorted(q.results())
self.assertEqual(len(r), 1) self.assertEqual(len(r), 1)
self.validate_rule(r[0], "range_transition", "test41", "test41", "infoflow", self.validate_rule(r[0], RT.range_transition, "test41", "test41", "infoflow",
"s41:c1 - s41:c1.c3") "s41:c1 - s41:c1.c3")
def test_041_range_overlap2(self): def test_041_range_overlap2(self):
@ -155,7 +155,7 @@ class MLSRuleQueryTest(mixins.ValidateRule, unittest.TestCase):
r = sorted(q.results()) r = sorted(q.results())
self.assertEqual(len(r), 1) self.assertEqual(len(r), 1)
self.validate_rule(r[0], "range_transition", "test41", "test41", "infoflow", self.validate_rule(r[0], RT.range_transition, "test41", "test41", "infoflow",
"s41:c1 - s41:c1.c3") "s41:c1 - s41:c1.c3")
def test_041_range_overlap3(self): def test_041_range_overlap3(self):
@ -164,7 +164,7 @@ class MLSRuleQueryTest(mixins.ValidateRule, unittest.TestCase):
r = sorted(q.results()) r = sorted(q.results())
self.assertEqual(len(r), 1) self.assertEqual(len(r), 1)
self.validate_rule(r[0], "range_transition", "test41", "test41", "infoflow", self.validate_rule(r[0], RT.range_transition, "test41", "test41", "infoflow",
"s41:c1 - s41:c1.c3") "s41:c1 - s41:c1.c3")
def test_041_range_overlap4(self): def test_041_range_overlap4(self):
@ -173,7 +173,7 @@ class MLSRuleQueryTest(mixins.ValidateRule, unittest.TestCase):
r = sorted(q.results()) r = sorted(q.results())
self.assertEqual(len(r), 1) self.assertEqual(len(r), 1)
self.validate_rule(r[0], "range_transition", "test41", "test41", "infoflow", self.validate_rule(r[0], RT.range_transition, "test41", "test41", "infoflow",
"s41:c1 - s41:c1.c3") "s41:c1 - s41:c1.c3")
def test_041_range_overlap5(self): def test_041_range_overlap5(self):
@ -182,7 +182,7 @@ class MLSRuleQueryTest(mixins.ValidateRule, unittest.TestCase):
r = sorted(q.results()) r = sorted(q.results())
self.assertEqual(len(r), 1) self.assertEqual(len(r), 1)
self.validate_rule(r[0], "range_transition", "test41", "test41", "infoflow", self.validate_rule(r[0], RT.range_transition, "test41", "test41", "infoflow",
"s41:c1 - s41:c1.c3") "s41:c1 - s41:c1.c3")
def test_042_range_subset1(self): def test_042_range_subset1(self):
@ -191,7 +191,7 @@ class MLSRuleQueryTest(mixins.ValidateRule, unittest.TestCase):
r = sorted(q.results()) r = sorted(q.results())
self.assertEqual(len(r), 1) self.assertEqual(len(r), 1)
self.validate_rule(r[0], "range_transition", "test42", "test42", "infoflow", self.validate_rule(r[0], RT.range_transition, "test42", "test42", "infoflow",
"s42:c1 - s42:c1.c3") "s42:c1 - s42:c1.c3")
def test_042_range_subset2(self): def test_042_range_subset2(self):
@ -200,7 +200,7 @@ class MLSRuleQueryTest(mixins.ValidateRule, unittest.TestCase):
r = sorted(q.results()) r = sorted(q.results())
self.assertEqual(len(r), 1) self.assertEqual(len(r), 1)
self.validate_rule(r[0], "range_transition", "test42", "test42", "infoflow", self.validate_rule(r[0], RT.range_transition, "test42", "test42", "infoflow",
"s42:c1 - s42:c1.c3") "s42:c1 - s42:c1.c3")
def test_043_range_superset1(self): def test_043_range_superset1(self):
@ -209,7 +209,7 @@ class MLSRuleQueryTest(mixins.ValidateRule, unittest.TestCase):
r = sorted(q.results()) r = sorted(q.results())
self.assertEqual(len(r), 1) self.assertEqual(len(r), 1)
self.validate_rule(r[0], "range_transition", "test43", "test43", "infoflow", self.validate_rule(r[0], RT.range_transition, "test43", "test43", "infoflow",
"s43:c1 - s43:c1.c3") "s43:c1 - s43:c1.c3")
def test_043_range_superset2(self): def test_043_range_superset2(self):
@ -218,7 +218,7 @@ class MLSRuleQueryTest(mixins.ValidateRule, unittest.TestCase):
r = sorted(q.results()) r = sorted(q.results())
self.assertEqual(len(r), 1) self.assertEqual(len(r), 1)
self.validate_rule(r[0], "range_transition", "test43", "test43", "infoflow", self.validate_rule(r[0], RT.range_transition, "test43", "test43", "infoflow",
"s43:c1 - s43:c1.c3") "s43:c1 - s43:c1.c3")
def test_044_range_proper_subset1(self): def test_044_range_proper_subset1(self):
@ -227,7 +227,7 @@ class MLSRuleQueryTest(mixins.ValidateRule, unittest.TestCase):
r = sorted(q.results()) r = sorted(q.results())
self.assertEqual(len(r), 1) self.assertEqual(len(r), 1)
self.validate_rule(r[0], "range_transition", "test44", "test44", "infoflow", self.validate_rule(r[0], RT.range_transition, "test44", "test44", "infoflow",
"s44:c1 - s44:c1.c3") "s44:c1 - s44:c1.c3")
def test_044_range_proper_subset2(self): def test_044_range_proper_subset2(self):
@ -245,7 +245,7 @@ class MLSRuleQueryTest(mixins.ValidateRule, unittest.TestCase):
r = sorted(q.results()) r = sorted(q.results())
self.assertEqual(len(r), 1) self.assertEqual(len(r), 1)
self.validate_rule(r[0], "range_transition", "test44", "test44", "infoflow", self.validate_rule(r[0], RT.range_transition, "test44", "test44", "infoflow",
"s44:c1 - s44:c1.c3") "s44:c1 - s44:c1.c3")
def test_044_range_proper_subset4(self): def test_044_range_proper_subset4(self):
@ -255,7 +255,7 @@ class MLSRuleQueryTest(mixins.ValidateRule, unittest.TestCase):
r = sorted(q.results()) r = sorted(q.results())
self.assertEqual(len(r), 1) self.assertEqual(len(r), 1)
self.validate_rule(r[0], "range_transition", "test44", "test44", "infoflow", self.validate_rule(r[0], RT.range_transition, "test44", "test44", "infoflow",
"s44:c1 - s44:c1.c3") "s44:c1 - s44:c1.c3")
def test_045_range_proper_superset1(self): def test_045_range_proper_superset1(self):
@ -265,7 +265,7 @@ class MLSRuleQueryTest(mixins.ValidateRule, unittest.TestCase):
r = sorted(q.results()) r = sorted(q.results())
self.assertEqual(len(r), 1) self.assertEqual(len(r), 1)
self.validate_rule(r[0], "range_transition", "test45", "test45", "infoflow", self.validate_rule(r[0], RT.range_transition, "test45", "test45", "infoflow",
"s45:c1 - s45:c1.c3") "s45:c1 - s45:c1.c3")
def test_045_range_proper_superset2(self): def test_045_range_proper_superset2(self):
@ -283,7 +283,7 @@ class MLSRuleQueryTest(mixins.ValidateRule, unittest.TestCase):
r = sorted(q.results()) r = sorted(q.results())
self.assertEqual(len(r), 1) self.assertEqual(len(r), 1)
self.validate_rule(r[0], "range_transition", "test45", "test45", "infoflow", self.validate_rule(r[0], RT.range_transition, "test45", "test45", "infoflow",
"s45:c1 - s45:c1.c3") "s45:c1 - s45:c1.c3")
def test_045_range_proper_superset4(self): def test_045_range_proper_superset4(self):
@ -293,10 +293,10 @@ class MLSRuleQueryTest(mixins.ValidateRule, unittest.TestCase):
r = sorted(q.results()) r = sorted(q.results())
self.assertEqual(len(r), 1) self.assertEqual(len(r), 1)
self.validate_rule(r[0], "range_transition", "test45", "test45", "infoflow", self.validate_rule(r[0], RT.range_transition, "test45", "test45", "infoflow",
"s45:c1 - s45:c1.c3") "s45:c1 - s45:c1.c3")
def test_900_invalid_ruletype(self): def test_900_invalid_ruletype(self):
"""MLS rule query with invalid rule type.""" """MLS rule query with invalid rule type."""
with self.assertRaises(InvalidMLSRuleType): with self.assertRaises(KeyError):
q = MLSRuleQuery(self.p, ruletype="type_transition") q = MLSRuleQuery(self.p, ruletype=["type_transition"])

View File

@ -22,6 +22,7 @@ try:
except ImportError: except ImportError:
from mock import Mock, patch from mock import Mock, patch
from setools import MLSRuletype as MRT
from setools.policyrep.qpol import qpol_policy_t, qpol_range_trans_t from setools.policyrep.qpol import qpol_policy_t, qpol_range_trans_t
from setools.policyrep.mlsrule import mls_rule_factory, validate_ruletype from setools.policyrep.mlsrule import mls_rule_factory, validate_ruletype
from setools.policyrep.exception import InvalidMLSRuleType, RuleNotConditional from setools.policyrep.exception import InvalidMLSRuleType, RuleNotConditional
@ -34,7 +35,7 @@ class MLSRuleTest(unittest.TestCase):
def mock_rangetrans_factory(self, source, target, tclass, default): def mock_rangetrans_factory(self, source, target, tclass, default):
mock_rule = Mock(qpol_range_trans_t) mock_rule = Mock(qpol_range_trans_t)
mock_rule.rule_type.return_value = "range_transition" mock_rule.rule_type.return_value = MRT.range_transition
mock_rule.source_type.return_value = source mock_rule.source_type.return_value = source
mock_rule.target_type.return_value = target mock_rule.target_type.return_value = target
mock_rule.object_class.return_value = tclass mock_rule.object_class.return_value = tclass
@ -52,8 +53,9 @@ class MLSRuleTest(unittest.TestCase):
def test_001_validate_ruletype(self): def test_001_validate_ruletype(self):
"""RangeTransition valid rule types.""" """RangeTransition valid rule types."""
self.assertEqual("range_transition", validate_ruletype("range_transition")) self.assertEqual(MRT.range_transition, validate_ruletype("range_transition"))
@unittest.skip("MLS ruletype changed to an enumeration.")
def test_002_validate_ruletype_invalid(self): def test_002_validate_ruletype_invalid(self):
"""RangeTransition valid rule types.""" """RangeTransition valid rule types."""
with self.assertRaises(InvalidMLSRuleType): with self.assertRaises(InvalidMLSRuleType):
@ -62,7 +64,7 @@ class MLSRuleTest(unittest.TestCase):
def test_010_ruletype(self): def test_010_ruletype(self):
"""RangeTransition rule type""" """RangeTransition rule type"""
rule = self.mock_rangetrans_factory("a", "b", "c", "d") rule = self.mock_rangetrans_factory("a", "b", "c", "d")
self.assertEqual("range_transition", rule.ruletype) self.assertEqual(MRT.range_transition, rule.ruletype)
def test_020_source_type(self): def test_020_source_type(self):
"""RangeTransition source type""" """RangeTransition source type"""