swmpo.tree_to_predicate

Simple transformation of scikit-learn decision trees to predicates.

 1"""Simple transformation of sk-learn decision trees to predicates."""
 2from sklearn.tree import DecisionTreeClassifier
 3from functools import reduce
 4from swmpo.predicates import Predicate
 5from swmpo.predicates import Feature
 6from swmpo.predicates import LessThan
 7from swmpo.predicates import GreaterThan
 8from swmpo.predicates import And
 9from swmpo.predicates import Or
10
11
def tree_to_predicate(tree: DecisionTreeClassifier) -> Predicate:
    """Transform the decision tree to a predicate.

    It is assumed that there are only two classes: 1 and 0, which
    correspond to True and False, respectively.

    Every root-to-leaf path that ends in a leaf predicting class 1 is
    turned into the conjunction (`And`) of the split conditions along
    that path, and all such paths are joined with `Or`.

    Args:
        tree: A fitted classifier exposing `classes_` and `tree_`.

    Returns:
        `False` when the only class is 0 or no leaf predicts class 1;
        `True` when the only class is 1 or the tree is a single leaf
        predicting class 1; otherwise the predicate described above.
    """
    # Check special case where there is only one class
    classes = list(tree.classes_)
    if classes == [0]:
        return False
    if classes == [1]:
        return True

    tree_ = tree.tree_

    def is_leaf(node: int) -> bool:
        # A node whose two child indices are equal is a leaf.
        return bool(tree_.children_left[node] == tree_.children_right[node])

    def leaf_is_true(node: int) -> bool:
        class0_weight, class1_weight = tree_.value[node][0]
        # Class 1 must strictly win: the classifier predicts through an
        # argmax over the class weights, which resolves ties in favour
        # of the first class (0).
        return bool(class1_weight > class0_weight)

    # A tree without any split has no branch condition to return, so
    # its constant prediction is the predicate.
    if is_leaf(0):
        return leaf_is_true(0)

    def find_true_branches(
        node: int,
        branch_predicate: Predicate | None,
    ) -> list[Predicate]:
        """Collect the path predicates of all True leaves below `node`."""
        if is_leaf(node):
            # The root is known to be a split here, so every leaf is
            # reached with a non-None path predicate.
            return [branch_predicate] if leaf_is_true(node) else []

        feature = Feature(int(tree_.feature[node]))
        threshold = float(tree_.threshold[node])
        # NOTE(review): scikit-learn sends samples with
        # `feature <= threshold` to the left child. If `LessThan` is a
        # strict comparison, inputs exactly equal to a threshold satisfy
        # neither branch -- confirm against swmpo.predicates.
        left_branch_predicate = LessThan(left=feature, right=threshold)
        right_branch_predicate = GreaterThan(left=feature, right=threshold)

        if branch_predicate is not None:
            left_branch_predicate = And(
                branch_predicate,
                left_branch_predicate,
            )
            right_branch_predicate = And(
                branch_predicate,
                right_branch_predicate,
            )

        return (
            find_true_branches(
                tree_.children_left[node],
                left_branch_predicate,
            )
            + find_true_branches(
                tree_.children_right[node],
                right_branch_predicate,
            )
        )

    true_predicates = find_true_branches(0, None)
    if not true_predicates:
        return False
    # `reduce` returns a lone element unchanged, so a single True branch
    # is returned as-is without being wrapped in `Or`.
    return reduce(Or, true_predicates)
def tree_to_predicate( tree: sklearn.tree._classes.DecisionTreeClassifier) -> swmpo.predicates.And | swmpo.predicates.Or | swmpo.predicates.LessThan | swmpo.predicates.GreaterThan | bool:
def tree_to_predicate(tree: DecisionTreeClassifier) -> Predicate:
    """Transform the decision tree to a predicate.

    It is assumed that there are only two classes: 1 and 0, which
    correspond to True and False, respectively.

    Every root-to-leaf path that ends in a leaf predicting class 1 is
    turned into the conjunction (`And`) of the split conditions along
    that path, and all such paths are joined with `Or`.

    Args:
        tree: A fitted classifier exposing `classes_` and `tree_`.

    Returns:
        `False` when the only class is 0 or no leaf predicts class 1;
        `True` when the only class is 1 or the tree is a single leaf
        predicting class 1; otherwise the predicate described above.
    """
    # Check special case where there is only one class
    classes = list(tree.classes_)
    if classes == [0]:
        return False
    if classes == [1]:
        return True

    tree_ = tree.tree_

    def is_leaf(node: int) -> bool:
        # A node whose two child indices are equal is a leaf.
        return bool(tree_.children_left[node] == tree_.children_right[node])

    def leaf_is_true(node: int) -> bool:
        class0_weight, class1_weight = tree_.value[node][0]
        # Class 1 must strictly win: the classifier predicts through an
        # argmax over the class weights, which resolves ties in favour
        # of the first class (0).
        return bool(class1_weight > class0_weight)

    # A tree without any split has no branch condition to return, so
    # its constant prediction is the predicate.
    if is_leaf(0):
        return leaf_is_true(0)

    def find_true_branches(
        node: int,
        branch_predicate: Predicate | None,
    ) -> list[Predicate]:
        """Collect the path predicates of all True leaves below `node`."""
        if is_leaf(node):
            # The root is known to be a split here, so every leaf is
            # reached with a non-None path predicate.
            return [branch_predicate] if leaf_is_true(node) else []

        feature = Feature(int(tree_.feature[node]))
        threshold = float(tree_.threshold[node])
        # NOTE(review): scikit-learn sends samples with
        # `feature <= threshold` to the left child. If `LessThan` is a
        # strict comparison, inputs exactly equal to a threshold satisfy
        # neither branch -- confirm against swmpo.predicates.
        left_branch_predicate = LessThan(left=feature, right=threshold)
        right_branch_predicate = GreaterThan(left=feature, right=threshold)

        if branch_predicate is not None:
            left_branch_predicate = And(
                branch_predicate,
                left_branch_predicate,
            )
            right_branch_predicate = And(
                branch_predicate,
                right_branch_predicate,
            )

        return (
            find_true_branches(
                tree_.children_left[node],
                left_branch_predicate,
            )
            + find_true_branches(
                tree_.children_right[node],
                right_branch_predicate,
            )
        )

    true_predicates = find_true_branches(0, None)
    if not true_predicates:
        return False
    # `reduce` returns a lone element unchanged, so a single True branch
    # is returned as-is without being wrapped in `Or`.
    return reduce(Or, true_predicates)

Transform the decision tree to a predicate.

It is assumed that there are only two classes, labelled 1 and 0, which correspond to True and False, respectively.