swmpo.tree_to_predicate
Simple transformation of scikit-learn decision trees into predicates.
"""Simple transformation of scikit-learn decision trees into predicates."""
from functools import reduce

from sklearn.tree import DecisionTreeClassifier

from swmpo.predicates import Predicate
from swmpo.predicates import Feature
from swmpo.predicates import LessThan
from swmpo.predicates import GreaterThan
from swmpo.predicates import And
from swmpo.predicates import Or


def tree_to_predicate(tree: DecisionTreeClassifier) -> Predicate:
    """Transform the decision tree to a predicate.

    It is assumed that there are only two class labels, 1 and 0, which
    correspond to True and False, respectively.

    The result is the disjunction (``Or``) of the path conditions of every
    leaf that predicts class 1, where each path condition is the
    conjunction (``And``) of the split tests from the root to that leaf.

    Args:
        tree: A fitted ``DecisionTreeClassifier`` trained on labels 0/1.

    Returns:
        ``False`` / ``True`` when the tree always predicts 0 / 1;
        otherwise a predicate built from ``LessThan`` / ``GreaterThan``
        split tests combined with ``And`` / ``Or``.
    """
    # Special case: only one class was seen during fitting, so the tree
    # is constant.
    classes = list(tree.classes_)
    if classes == [0]:
        return False
    elif classes == [1]:
        return True

    tree_ = tree.tree_

    def find_true_branches(
        node: int,
        branch_predicate: Predicate | None,
    ) -> list[Predicate]:
        """Collect the path predicates of all True leaves below `node`.

        `branch_predicate` is the conjunction of split tests leading to
        `node`, or None when `node` is the root.
        """
        # Leaves have both child indices set to the same sentinel value.
        is_leaf = tree_.children_left[node] == tree_.children_right[node]

        if not is_leaf:
            feature = Feature(int(tree_.feature[node]))
            threshold = float(tree_.threshold[node])
            # NOTE(review): sklearn routes samples with
            # `x <= threshold` to the left child. If LessThan is a strict
            # `<`, points exactly on the threshold satisfy neither branch;
            # confirm LessThan/GreaterThan semantics in swmpo.predicates.
            left_branch_predicate = LessThan(
                left=feature,
                right=threshold,
            )
            right_branch_predicate = GreaterThan(
                left=feature,
                right=threshold,
            )

            if branch_predicate is not None:
                left_branch_predicate = And(
                    branch_predicate,
                    left_branch_predicate,
                )
                right_branch_predicate = And(
                    branch_predicate,
                    right_branch_predicate,
                )

            true_left_branches = find_true_branches(
                tree_.children_left[node],
                left_branch_predicate,
            )
            true_right_branches = find_true_branches(
                tree_.children_right[node],
                right_branch_predicate,
            )
            return true_left_branches + true_right_branches

        # Node is a leaf. Row 0 of `value[node]` holds one weight per
        # class, in `tree.classes_` order (0, then 1).
        class0_weight, class1_weight = tree_.value[node][0]

        # sklearn's `predict` takes the argmax over classes, which resolves
        # ties in favor of the first class (0). Mirror that: ties are False.
        if class0_weight >= class1_weight:
            # Leaf is False. There are no True branches.
            return []
        if branch_predicate is None:
            # The root itself is a True leaf: the tree always predicts 1.
            return [True]
        # Leaf is True.
        return [branch_predicate]

    true_predicates = find_true_branches(0, None)
    if len(true_predicates) == 0:
        return False
    if len(true_predicates) == 1:
        return true_predicates[0]
    return reduce(Or, true_predicates)
def
tree_to_predicate( tree: sklearn.tree._classes.DecisionTreeClassifier) -> swmpo.predicates.And | swmpo.predicates.Or | swmpo.predicates.LessThan | swmpo.predicates.GreaterThan | bool:
def tree_to_predicate(tree: DecisionTreeClassifier) -> Predicate:
    """Transform the decision tree to a predicate.

    It is assumed that there are only two class labels, 1 and 0, which
    correspond to True and False, respectively.

    The result is the disjunction (``Or``) of the path conditions of every
    leaf that predicts class 1, where each path condition is the
    conjunction (``And``) of the split tests from the root to that leaf.

    Args:
        tree: A fitted ``DecisionTreeClassifier`` trained on labels 0/1.

    Returns:
        ``False`` / ``True`` when the tree always predicts 0 / 1;
        otherwise a predicate built from ``LessThan`` / ``GreaterThan``
        split tests combined with ``And`` / ``Or``.
    """
    # Special case: only one class was seen during fitting, so the tree
    # is constant.
    classes = list(tree.classes_)
    if classes == [0]:
        return False
    elif classes == [1]:
        return True

    tree_ = tree.tree_

    def find_true_branches(
        node: int,
        branch_predicate: Predicate | None,
    ) -> list[Predicate]:
        """Collect the path predicates of all True leaves below `node`.

        `branch_predicate` is the conjunction of split tests leading to
        `node`, or None when `node` is the root.
        """
        # Leaves have both child indices set to the same sentinel value.
        is_leaf = tree_.children_left[node] == tree_.children_right[node]

        if not is_leaf:
            feature = Feature(int(tree_.feature[node]))
            threshold = float(tree_.threshold[node])
            # NOTE(review): sklearn routes samples with
            # `x <= threshold` to the left child. If LessThan is a strict
            # `<`, points exactly on the threshold satisfy neither branch;
            # confirm LessThan/GreaterThan semantics in swmpo.predicates.
            left_branch_predicate = LessThan(
                left=feature,
                right=threshold,
            )
            right_branch_predicate = GreaterThan(
                left=feature,
                right=threshold,
            )

            if branch_predicate is not None:
                left_branch_predicate = And(
                    branch_predicate,
                    left_branch_predicate,
                )
                right_branch_predicate = And(
                    branch_predicate,
                    right_branch_predicate,
                )

            true_left_branches = find_true_branches(
                tree_.children_left[node],
                left_branch_predicate,
            )
            true_right_branches = find_true_branches(
                tree_.children_right[node],
                right_branch_predicate,
            )
            return true_left_branches + true_right_branches

        # Node is a leaf. Row 0 of `value[node]` holds one weight per
        # class, in `tree.classes_` order (0, then 1).
        class0_weight, class1_weight = tree_.value[node][0]

        # sklearn's `predict` takes the argmax over classes, which resolves
        # ties in favor of the first class (0). Mirror that: ties are False.
        if class0_weight >= class1_weight:
            # Leaf is False. There are no True branches.
            return []
        if branch_predicate is None:
            # The root itself is a True leaf: the tree always predicts 1.
            return [True]
        # Leaf is True.
        return [branch_predicate]

    true_predicates = find_true_branches(0, None)
    if len(true_predicates) == 0:
        return False
    if len(true_predicates) == 1:
        return true_predicates[0]
    return reduce(Or, true_predicates)
Transform the decision tree to a predicate.
It is assumed that there are only two class labels, 1 and 0, which correspond to True and False, respectively.