swmpo.predicates

Simple transformation of sk-learn decision trees to predicates.

  1"""Simple transformation of sk-learn decision trees to predicates."""
  2from dataclasses import dataclass
  3from typing import Any
  4import jsonpickle
  5
  6
  7@dataclass
  8class And:
  9    left: "Predicate"
 10    right: "Predicate"
 11
 12
 13@dataclass
 14class Or:
 15    left: "Predicate"
 16    right: "Predicate"
 17
 18
 19@dataclass
 20class LessThan:
 21    left: "Value"
 22    right: "Value"
 23
 24
 25@dataclass
 26class GreaterThan:
 27    left: "Value"
 28    right: "Value"
 29
 30
 31@dataclass
 32class Feature:
 33    index: int
 34
 35
 36Value = (
 37    Feature |
 38    float
 39)
 40
 41
 42Predicate = (
 43    And |
 44    Or |
 45    LessThan |
 46    GreaterThan |
 47    bool
 48)
 49
 50
 51def predicate_to_str(predicate: Predicate) -> str:
 52    """Serialize predicate as a string."""
 53    return jsonpickle.encode(
 54        predicate,
 55        make_refs=False,
 56        indent=2,
 57    )
 58
 59
 60def type_check_value(obj: Any) -> bool:
 61    allowed_classes = (
 62        Feature,
 63        float,
 64    )
 65    if not isinstance(obj, allowed_classes):
 66        return False
 67
 68    if isinstance(obj, float):
 69        return True
 70
 71    # obj is Feature
 72    if not isinstance(obj.index, int):
 73        return False
 74
 75    return True
 76
 77
 78def type_check_predicate(obj: Any) -> bool:
 79    allowed_classes = (
 80        And,
 81        Or,
 82        LessThan,
 83        GreaterThan,
 84        bool,
 85    )
 86    if not isinstance(obj, allowed_classes):
 87        return False
 88
 89    if isinstance(obj, bool):
 90        return True
 91
 92    if isinstance(obj, (And, Or)):
 93        if not type_check_predicate(obj.left):
 94            return False
 95        if not type_check_predicate(obj.right):
 96            return False
 97        return True
 98
 99    # Obj is LessThan or GreaterThan
100    if not type_check_value(obj.left):
101        return False
102    if not type_check_value(obj.right):
103        return False
104    return True
105
106
def str_to_predicate(predicate: str) -> Predicate:
    """Deserialize a predicate from a string produced by ``predicate_to_str``.

    Args:
        predicate: jsonpickle-encoded JSON text.

    Returns:
        The decoded predicate tree.

    Raises:
        ValueError: if the decoded object is not a well-formed predicate
            according to ``type_check_predicate``.

    Security: jsonpickle should be treated like ``pickle`` -- decoding can
    reconstruct arbitrary importable Python objects named in the payload.
    Do not call this on untrusted input.
    """
    # on_missing="error": fail instead of warning/ignoring when the payload
    # names a class that cannot be found.
    obj = jsonpickle.decode(
        predicate,
        on_missing="error",
    )
    # Explicit check instead of `assert`, so the validation still runs
    # when Python is started with -O (which strips asserts).
    if not type_check_predicate(obj):
        raise ValueError(
            f"Decoded object of type '{type(obj)}' is not a valid predicate!"
        )
    return obj
115
116
def get_value(
    value: Value,
    x: list[float],
) -> float:
    """Resolve a comparison operand to a number for the input ``x``.

    Args:
        value: A numeric constant, or a ``Feature`` selecting
            ``x[value.index]``. Plain ``int`` constants are accepted and
            converted to ``float``: type checkers allow ``int`` where
            ``float`` is annotated, e.g. ``LessThan(Feature(0), 1)``.
        x: The input vector.

    Returns:
        The constant (``float`` inputs are returned unchanged), or the
        selected entry of ``x``.

    Raises:
        ValueError: if ``value`` is neither a number nor a ``Feature``.
            ``bool`` is rejected even though it subclasses ``int``.
        IndexError: (from list indexing) if a ``Feature`` index is out of
            range for ``x``.
    """
    if isinstance(value, float):
        return value
    # bool is an int subclass; True/False are not meaningful thresholds.
    if isinstance(value, int) and not isinstance(value, bool):
        return float(value)
    if isinstance(value, Feature):
        return x[value.index]
    raise ValueError(
        f"Input of type '{type(value)}' not a predicate value!"
    )
129
130
131def get_robustness_value(
132    predicate: Predicate,
133    x: list[float],
134) -> float:
135    """Get the robustness value of the predicate evaluated on the given
136    input.
137
138    A predicate is said to be True for a given input if and only if
139    the corresponding robustness value is greater than zero.
140    """
141    if isinstance(predicate, bool):
142        if predicate:
143            return 1.0
144        else:
145            return -1.0
146    elif isinstance(predicate, And):
147        lv = get_robustness_value(predicate.left, x)
148        rv = get_robustness_value(predicate.right, x)
149        return min(lv, rv)
150    elif isinstance(predicate, Or):
151        lv = get_robustness_value(predicate.left, x)
152        rv = get_robustness_value(predicate.right, x)
153        return max(lv, rv)
154    elif isinstance(predicate, LessThan):
155        lv = get_value(predicate.left, x)
156        rv = get_value(predicate.right, x)
157        return rv - lv
158    elif isinstance(predicate, GreaterThan):
159        lv = get_value(predicate.left, x)
160        rv = get_value(predicate.right, x)
161        return lv - rv
162    else:
163        raise ValueError(f"Predicate of invalid type '{type(predicate)}'!")
164
165
166def get_pretty_str(predicate: Predicate) -> str:
167    def add_indent(s: str) -> str:
168        lines = s.splitlines()
169        lines = [f"\t{li}" for li in lines]
170        return "\n".join(lines)
171
172    def pretty_value(s: Value) -> str:
173        if isinstance(s, (float, int)):
174            return str(s)
175        else:
176            return f"x[{s.index}]"
177
178    def recurse(predicate: Predicate) -> str:
179        if isinstance(predicate, bool):
180            if predicate:
181                return "True"
182            else:
183                return "False"
184        elif isinstance(predicate, And):
185            ls = recurse(predicate.left)
186            rs = recurse(predicate.right)
187            ls = add_indent(ls)
188            rs = add_indent(rs)
189            ls = f"(\n{ls}\n)"
190            rs = f"(\n{rs}\n)"
191            return f"{ls}\nAND\n{rs}"
192        elif isinstance(predicate, Or):
193            ls = recurse(predicate.left)
194            rs = recurse(predicate.right)
195            ls = add_indent(ls)
196            rs = add_indent(rs)
197            ls = f"(\n{ls}\n)"
198            rs = f"(\n{rs}\n)"
199            return f"{ls}\nOR\n{rs}"
200        elif isinstance(predicate, LessThan):
201            ls = pretty_value(predicate.left)
202            rs = pretty_value(predicate.right)
203            return f"{ls} < {rs}"
204        elif isinstance(predicate, GreaterThan):
205            ls = pretty_value(predicate.left)
206            rs = pretty_value(predicate.right)
207            return f"{ls} > {rs}"
208        else:
209            raise ValueError(f"Predicate of invalid type '{type(predicate)}'!")
210
211    return recurse(predicate)
@dataclass
class And:
 8@dataclass
 9class And:
10    left: "Predicate"
11    right: "Predicate"
And( left: And | Or | LessThan | GreaterThan | bool, right: And | Or | LessThan | GreaterThan | bool)
left: And | Or | LessThan | GreaterThan | bool
right: And | Or | LessThan | GreaterThan | bool
@dataclass
class Or:
@dataclass
class Or:
    """Disjunction of two predicates (robustness: maximum of the two)."""

    left: "Predicate"
    right: "Predicate"
Or( left: And | Or | LessThan | GreaterThan | bool, right: And | Or | LessThan | GreaterThan | bool)
left: And | Or | LessThan | GreaterThan | bool
right: And | Or | LessThan | GreaterThan | bool
@dataclass
class LessThan:
@dataclass
class LessThan:
    """Comparison ``left < right`` (robustness: ``right - left``)."""

    left: "Value"
    right: "Value"
LessThan( left: Feature | float, right: Feature | float)
left: Feature | float
right: Feature | float
@dataclass
class GreaterThan:
@dataclass
class GreaterThan:
    """Comparison ``left > right`` (robustness: ``left - right``)."""

    left: "Value"
    right: "Value"
GreaterThan( left: Feature | float, right: Feature | float)
left: Feature | float
right: Feature | float
@dataclass
class Feature:
@dataclass
class Feature:
    """Reference to the input entry ``x[index]``."""

    index: int
Feature(index: int)
index: int
# Comparison operand: a feature reference or a numeric constant.
Value = (
    Feature
    | float
)
# Any node of a predicate tree; a bare bool is a constant predicate.
Predicate = (
    And
    | Or
    | LessThan
    | GreaterThan
    | bool
)
def predicate_to_str( predicate: And | Or | LessThan | GreaterThan | bool) -> str:
def predicate_to_str(predicate: Predicate) -> str:
    """Encode *predicate* as a 2-space-indented JSON string via jsonpickle.

    ``make_refs=False`` disables jsonpickle's id()-based referencing, so
    identical sub-objects are not written as references. The output is
    meant to be read back with ``str_to_predicate``.
    """
    encoder_options = {"make_refs": False, "indent": 2}
    return jsonpickle.encode(predicate, **encoder_options)

Serialize predicate as a string.

def type_check_value(obj: Any) -> bool:
def type_check_value(obj: Any) -> bool:
    """Return whether *obj* is a structurally valid comparison operand.

    Valid operands are ``float`` instances (plain ``int`` does not qualify)
    and ``Feature`` instances whose ``index`` is an ``int``.
    """
    if isinstance(obj, float):
        return True
    return isinstance(obj, Feature) and isinstance(obj.index, int)
def type_check_predicate(obj: Any) -> bool:
 79def type_check_predicate(obj: Any) -> bool:
 80    allowed_classes = (
 81        And,
 82        Or,
 83        LessThan,
 84        GreaterThan,
 85        bool,
 86    )
 87    if not isinstance(obj, allowed_classes):
 88        return False
 89
 90    if isinstance(obj, bool):
 91        return True
 92
 93    if isinstance(obj, (And, Or)):
 94        if not type_check_predicate(obj.left):
 95            return False
 96        if not type_check_predicate(obj.right):
 97            return False
 98        return True
 99
100    # Obj is LessThan or GreaterThan
101    if not type_check_value(obj.left):
102        return False
103    if not type_check_value(obj.right):
104        return False
105    return True
def str_to_predicate( predicate: str) -> And | Or | LessThan | GreaterThan | bool:
def str_to_predicate(predicate: str) -> Predicate:
    """Deserialize a predicate from a string produced by ``predicate_to_str``.

    Args:
        predicate: jsonpickle-encoded JSON text.

    Returns:
        The decoded predicate tree.

    Raises:
        ValueError: if the decoded object is not a well-formed predicate
            according to ``type_check_predicate``.

    Security: jsonpickle should be treated like ``pickle`` -- decoding can
    reconstruct arbitrary importable Python objects named in the payload.
    Do not call this on untrusted input.
    """
    # on_missing="error": fail instead of warning/ignoring when the payload
    # names a class that cannot be found.
    obj = jsonpickle.decode(
        predicate,
        on_missing="error",
    )
    # Explicit check instead of `assert`, so the validation still runs
    # when Python is started with -O (which strips asserts).
    if not type_check_predicate(obj):
        raise ValueError(
            f"Decoded object of type '{type(obj)}' is not a valid predicate!"
        )
    return obj

Deserialize predicate from a string.

def get_value(value: Feature | float, x: list[float]) -> float:
def get_value(
    value: Value,
    x: list[float],
) -> float:
    """Resolve a comparison operand to a number for the input ``x``.

    Args:
        value: A numeric constant, or a ``Feature`` selecting
            ``x[value.index]``. Plain ``int`` constants are accepted and
            converted to ``float``: type checkers allow ``int`` where
            ``float`` is annotated, e.g. ``LessThan(Feature(0), 1)``.
        x: The input vector.

    Returns:
        The constant (``float`` inputs are returned unchanged), or the
        selected entry of ``x``.

    Raises:
        ValueError: if ``value`` is neither a number nor a ``Feature``.
            ``bool`` is rejected even though it subclasses ``int``.
        IndexError: (from list indexing) if a ``Feature`` index is out of
            range for ``x``.
    """
    if isinstance(value, float):
        return value
    # bool is an int subclass; True/False are not meaningful thresholds.
    if isinstance(value, int) and not isinstance(value, bool):
        return float(value)
    if isinstance(value, Feature):
        return x[value.index]
    raise ValueError(
        f"Input of type '{type(value)}' not a predicate value!"
    )
def get_robustness_value( predicate: And | Or | LessThan | GreaterThan | bool, x: list[float]) -> float:
132def get_robustness_value(
133    predicate: Predicate,
134    x: list[float],
135) -> float:
136    """Get the robustness value of the predicate evaluated on the given
137    input.
138
139    A predicate is said to be True for a given input if and only if
140    the corresponding robustness value is greater than zero.
141    """
142    if isinstance(predicate, bool):
143        if predicate:
144            return 1.0
145        else:
146            return -1.0
147    elif isinstance(predicate, And):
148        lv = get_robustness_value(predicate.left, x)
149        rv = get_robustness_value(predicate.right, x)
150        return min(lv, rv)
151    elif isinstance(predicate, Or):
152        lv = get_robustness_value(predicate.left, x)
153        rv = get_robustness_value(predicate.right, x)
154        return max(lv, rv)
155    elif isinstance(predicate, LessThan):
156        lv = get_value(predicate.left, x)
157        rv = get_value(predicate.right, x)
158        return rv - lv
159    elif isinstance(predicate, GreaterThan):
160        lv = get_value(predicate.left, x)
161        rv = get_value(predicate.right, x)
162        return lv - rv
163    else:
164        raise ValueError(f"Predicate of invalid type '{type(predicate)}'!")

Get the robustness value of the predicate evaluated on the given input.

A predicate is said to be True for a given input if and only if the corresponding robustness value is greater than zero.

def get_pretty_str( predicate: And | Or | LessThan | GreaterThan | bool) -> str:
167def get_pretty_str(predicate: Predicate) -> str:
168    def add_indent(s: str) -> str:
169        lines = s.splitlines()
170        lines = [f"\t{li}" for li in lines]
171        return "\n".join(lines)
172
173    def pretty_value(s: Value) -> str:
174        if isinstance(s, (float, int)):
175            return str(s)
176        else:
177            return f"x[{s.index}]"
178
179    def recurse(predicate: Predicate) -> str:
180        if isinstance(predicate, bool):
181            if predicate:
182                return "True"
183            else:
184                return "False"
185        elif isinstance(predicate, And):
186            ls = recurse(predicate.left)
187            rs = recurse(predicate.right)
188            ls = add_indent(ls)
189            rs = add_indent(rs)
190            ls = f"(\n{ls}\n)"
191            rs = f"(\n{rs}\n)"
192            return f"{ls}\nAND\n{rs}"
193        elif isinstance(predicate, Or):
194            ls = recurse(predicate.left)
195            rs = recurse(predicate.right)
196            ls = add_indent(ls)
197            rs = add_indent(rs)
198            ls = f"(\n{ls}\n)"
199            rs = f"(\n{rs}\n)"
200            return f"{ls}\nOR\n{rs}"
201        elif isinstance(predicate, LessThan):
202            ls = pretty_value(predicate.left)
203            rs = pretty_value(predicate.right)
204            return f"{ls} < {rs}"
205        elif isinstance(predicate, GreaterThan):
206            ls = pretty_value(predicate.left)
207            rs = pretty_value(predicate.right)
208            return f"{ls} > {rs}"
209        else:
210            raise ValueError(f"Predicate of invalid type '{type(predicate)}'!")
211
212    return recurse(predicate)