swmpo.predicates
Simple transformation of sk-learn decision trees to predicates.
"""Simple transformation of sk-learn decision trees to predicates."""
from dataclasses import dataclass
from typing import Any
import jsonpickle


@dataclass
class And:
    """Conjunction of two predicates."""
    left: "Predicate"
    right: "Predicate"


@dataclass
class Or:
    """Disjunction of two predicates."""
    left: "Predicate"
    right: "Predicate"


@dataclass
class LessThan:
    """Strict comparison `left < right` between two values."""
    left: "Value"
    right: "Value"


@dataclass
class GreaterThan:
    """Strict comparison `left > right` between two values."""
    left: "Value"
    right: "Value"


@dataclass
class Feature:
    """Reference to the input component `x[index]`."""
    index: int


Value = (
    Feature |
    float
)


Predicate = (
    And |
    Or |
    LessThan |
    GreaterThan |
    bool
)


def predicate_to_str(predicate: Predicate) -> str:
    """Serialize predicate as a string."""
    return jsonpickle.encode(
        predicate,
        make_refs=False,
        indent=2,
    )


def type_check_value(obj: Any) -> bool:
    """Return True if `obj` is a well-formed `Value`.

    A value is either a `float` or a `Feature` whose `index` is an `int`.
    Plain `int` constants are rejected.
    """
    if isinstance(obj, float):
        return True
    if isinstance(obj, Feature):
        return isinstance(obj.index, int)
    return False


def type_check_predicate(obj: Any) -> bool:
    """Return True if `obj` is a well-formed `Predicate`.

    The whole tree is checked: both children of every `And`/`Or` must be
    predicates, and both operands of every `LessThan`/`GreaterThan` must
    pass `type_check_value`.
    """
    if isinstance(obj, bool):
        return True
    if isinstance(obj, (And, Or)):
        return (
            type_check_predicate(obj.left)
            and type_check_predicate(obj.right)
        )
    if isinstance(obj, (LessThan, GreaterThan)):
        return type_check_value(obj.left) and type_check_value(obj.right)
    return False


def str_to_predicate(predicate: str) -> Predicate:
    """Deserialize predicate from a string.

    Raises:
        AssertionError: If the decoded object is not a well-formed
            predicate according to `type_check_predicate`.
    """
    # NOTE(security): jsonpickle reconstructs Python objects named in the
    # payload, so only strings from a trusted source should be decoded.
    obj = jsonpickle.decode(
        predicate,
        on_missing="error",
    )
    # Raised explicitly (instead of using an `assert` statement) so that
    # the validation is not stripped when Python runs with `-O`. The
    # exception type is kept for backward compatibility.
    if not type_check_predicate(obj):
        raise AssertionError(
            f"Decoded object of type '{type(obj)}' is not a valid predicate!"
        )
    return obj


def get_value(
    value: Value,
    x: list[float],
) -> float:
    """Resolve a predicate value against the input `x`.

    A `float` is returned unchanged; a `Feature` is looked up as
    `x[value.index]`.

    Raises:
        ValueError: If `value` is neither a `float` nor a `Feature`.
    """
    if isinstance(value, float):
        return value
    elif isinstance(value, Feature):
        return x[value.index]
    else:
        raise ValueError(
            f"Input of type '{type(value)}' not a predicate value!"
        )


def get_robustness_value(
    predicate: Predicate,
    x: list[float],
) -> float:
    """Get the robustness value of the predicate evaluated on the given
    input.

    A predicate is said to be True for a given input if and only if
    the corresponding robustness value is greater than zero.

    Raises:
        ValueError: If `predicate` (or a nested node) has an invalid type.
    """
    if isinstance(predicate, bool):
        if predicate:
            return 1.0
        else:
            return -1.0
    elif isinstance(predicate, And):
        # A conjunction is only as robust as its weakest operand.
        lv = get_robustness_value(predicate.left, x)
        rv = get_robustness_value(predicate.right, x)
        return min(lv, rv)
    elif isinstance(predicate, Or):
        # A disjunction is as robust as its strongest operand.
        lv = get_robustness_value(predicate.left, x)
        rv = get_robustness_value(predicate.right, x)
        return max(lv, rv)
    elif isinstance(predicate, LessThan):
        # Positive exactly when left < right.
        lv = get_value(predicate.left, x)
        rv = get_value(predicate.right, x)
        return rv - lv
    elif isinstance(predicate, GreaterThan):
        # Positive exactly when left > right.
        lv = get_value(predicate.left, x)
        rv = get_value(predicate.right, x)
        return lv - rv
    else:
        raise ValueError(f"Predicate of invalid type '{type(predicate)}'!")


def get_pretty_str(predicate: Predicate) -> str:
    """Render the predicate as a multi-line, tab-indented string.

    Raises:
        ValueError: If `predicate` (or a nested node) has an invalid type.
    """
    def add_indent(s: str) -> str:
        # Prefix every line of `s` with one tab.
        lines = s.splitlines()
        lines = [f"\t{li}" for li in lines]
        return "\n".join(lines)

    def pretty_value(s: Value) -> str:
        # Numbers are printed as-is; anything else is treated as a Feature.
        if isinstance(s, (float, int)):
            return str(s)
        else:
            return f"x[{s.index}]"

    def pretty_connective(node: And | Or, keyword: str) -> str:
        # Shared layout for And/Or: each operand is indented and wrapped
        # in parentheses, with the keyword on its own line in between.
        ls = add_indent(recurse(node.left))
        rs = add_indent(recurse(node.right))
        ls = f"(\n{ls}\n)"
        rs = f"(\n{rs}\n)"
        return f"{ls}\n{keyword}\n{rs}"

    def recurse(predicate: Predicate) -> str:
        if isinstance(predicate, bool):
            if predicate:
                return "True"
            else:
                return "False"
        elif isinstance(predicate, And):
            return pretty_connective(predicate, "AND")
        elif isinstance(predicate, Or):
            return pretty_connective(predicate, "OR")
        elif isinstance(predicate, LessThan):
            ls = pretty_value(predicate.left)
            rs = pretty_value(predicate.right)
            return f"{ls} < {rs}"
        elif isinstance(predicate, GreaterThan):
            ls = pretty_value(predicate.left)
            rs = pretty_value(predicate.right)
            return f"{ls} > {rs}"
        else:
            raise ValueError(f"Predicate of invalid type '{type(predicate)}'!")

    return recurse(predicate)
@dataclass
class
And:
And( left: And | Or | LessThan | GreaterThan | bool, right: And | Or | LessThan | GreaterThan | bool)
left: And | Or | LessThan | GreaterThan | bool
right: And | Or | LessThan | GreaterThan | bool
@dataclass
class
Or:
Or( left: And | Or | LessThan | GreaterThan | bool, right: And | Or | LessThan | GreaterThan | bool)
left: And | Or | LessThan | GreaterThan | bool
right: And | Or | LessThan | GreaterThan | bool
@dataclass
class
LessThan:
left: Feature | float
right: Feature | float
@dataclass
class
GreaterThan:
left: Feature | float
right: Feature | float
@dataclass
class
Feature:
Value =
Feature | float
Predicate =
And | Or | LessThan | GreaterThan | bool
def predicate_to_str(predicate: Predicate) -> str:
    """Encode a predicate tree as an indented jsonpickle string."""
    # make_refs=False is passed through to jsonpickle unchanged.
    encoded = jsonpickle.encode(predicate, make_refs=False, indent=2)
    return encoded
Serialize predicate as a string.
def
type_check_value(obj: Any) -> bool:
def
type_check_predicate(obj: Any) -> bool:
def type_check_predicate(obj: Any) -> bool:
    """Return True if `obj` is a well-formed predicate tree.

    Booleans are always valid. `And`/`Or` nodes are valid when both
    children are valid predicates. `LessThan`/`GreaterThan` nodes are
    valid when both operands pass `type_check_value`. Anything else is
    rejected.
    """
    if not isinstance(obj, (And, Or, LessThan, GreaterThan, bool)):
        return False

    if isinstance(obj, bool):
        return True

    # Connectives recurse into predicates; comparisons check plain values.
    check_child = (
        type_check_predicate
        if isinstance(obj, (And, Or))
        else type_check_value
    )
    return check_child(obj.left) and check_child(obj.right)
def str_to_predicate(predicate: str) -> Predicate:
    """Deserialize predicate from a string.

    Raises:
        AssertionError: If the decoded object is not a well-formed
            predicate according to `type_check_predicate`.
    """
    # NOTE(security): jsonpickle reconstructs Python objects named in the
    # payload, so only strings from a trusted source should be decoded.
    obj = jsonpickle.decode(
        predicate,
        on_missing="error",
    )
    # Raised explicitly (instead of using an `assert` statement) so that
    # the validation is not stripped when Python runs with `-O`. The
    # exception type is kept for backward compatibility.
    if not type_check_predicate(obj):
        raise AssertionError(
            f"Decoded object of type '{type(obj)}' is not a valid predicate!"
        )
    return obj
Deserialize predicate from a string.
def
get_robustness_value( predicate: And | Or | LessThan | GreaterThan | bool, x: list[float]) -> float:
def get_robustness_value(
    predicate: Predicate,
    x: list[float],
) -> float:
    """Compute the robustness of `predicate` for the input `x`.

    The predicate counts as True for `x` exactly when the returned
    value is strictly positive.

    Raises:
        ValueError: If `predicate` (or a nested node) has an invalid type.
    """
    # Constants map to a fixed margin of +/- 1.
    if isinstance(predicate, bool):
        return 1.0 if predicate else -1.0

    # Connectives: weakest operand for And, strongest operand for Or.
    if isinstance(predicate, (And, Or)):
        left_score = get_robustness_value(predicate.left, x)
        right_score = get_robustness_value(predicate.right, x)
        combine = min if isinstance(predicate, And) else max
        return combine(left_score, right_score)

    # Comparisons: signed distance between the two operands.
    if isinstance(predicate, (LessThan, GreaterThan)):
        left_operand = get_value(predicate.left, x)
        right_operand = get_value(predicate.right, x)
        if isinstance(predicate, LessThan):
            return right_operand - left_operand
        return left_operand - right_operand

    raise ValueError(f"Predicate of invalid type '{type(predicate)}'!")
Get the robustness value of the predicate evaluated on the given input.
A predicate is said to be True for a given input if and only if the corresponding robustness value is greater than zero.
def get_pretty_str(predicate: Predicate) -> str:
    """Render a predicate as a multi-line, tab-indented string.

    Raises:
        ValueError: If `predicate` (or a nested node) has an invalid type.
    """
    def indent_block(text: str) -> str:
        # Prefix every line of `text` with a single tab.
        return "\n".join("\t" + line for line in text.splitlines())

    def render_operand(operand: Value) -> str:
        # Numbers print as-is; anything else is treated as a Feature.
        if not isinstance(operand, (float, int)):
            return f"x[{operand.index}]"
        return str(operand)

    def render_connective(node: And | Or, keyword: str) -> str:
        # Each operand is indented and wrapped in parentheses, with the
        # keyword on its own line between the two groups.
        groups = [
            f"(\n{indent_block(render(child))}\n)"
            for child in (node.left, node.right)
        ]
        return f"{groups[0]}\n{keyword}\n{groups[1]}"

    def render(predicate: Predicate) -> str:
        if isinstance(predicate, bool):
            return "True" if predicate else "False"
        if isinstance(predicate, And):
            return render_connective(predicate, "AND")
        if isinstance(predicate, Or):
            return render_connective(predicate, "OR")
        if isinstance(predicate, LessThan):
            lhs = render_operand(predicate.left)
            rhs = render_operand(predicate.right)
            return f"{lhs} < {rhs}"
        if isinstance(predicate, GreaterThan):
            lhs = render_operand(predicate.left)
            rhs = render_operand(predicate.right)
            return f"{lhs} > {rhs}"
        raise ValueError(f"Predicate of invalid type '{type(predicate)}'!")

    return render(predicate)