swmpo.transition_predicates

Characterization of the transitions between sets of states.

A transition predicate is a function that takes as input a vector corresponding to the concatenation of a state and an action and maps it to a number. If the number is greater than zero, the predicate is true; otherwise, it is false.

  1"""Characterization of the transitions between sets of states.
  2
  3A transition predicate is a function that takes as input a vector
  4corresponding to the concatenation of a state and action and maps it
  5to number. If the number is greater than zero, then the predicate is true.
  6Otherwise the predicate is said to be false.
  7"""
  8from swmpo.transition import Transition
  9from swmpo.transition import get_vector
 10from collections import defaultdict
 11from itertools import product
 12from sklearn.preprocessing import StandardScaler
 13from sklearn.pipeline import Pipeline
 14from sklearn.neural_network import MLPClassifier
 15from sklearn.tree import DecisionTreeClassifier
 16from swmpo.tree_to_predicate import tree_to_predicate
 17from swmpo.predicates import Predicate
 18import random
 19from dataclasses import dataclass
 20
 21
 22@dataclass
 23class TransitionPredicates:
 24    transition_predicates: list[list[Predicate]]
 25    transition_histogram: list[list[int]]
 26
 27
 28def get_transition_predicates(
 29    partition: list[list[Transition]],
 30    predicate_hyperparameters: dict[str, float | int | str],
 31    seed: str,
 32) -> TransitionPredicates:
 33    """Returns the adjacency matrix containing the
 34    predicates and histogram of transitions between states."""
 35    _random = random.Random(seed)
 36    state_indices = list(range(len(partition)))
 37    transition_predicate_matrix = defaultdict(dict)
 38    transition_histogram = [[0 for _ in state_indices] for _ in state_indices]
 39
 40    # SPEED: make source state existence look-up fast for
 41    # predicate dataset construction. Otherwise checking if a
 42    # "next state" is the "source state" of any transition is
 43    # prohibitively slow.
 44    source_states = [
 45        set([
 46            tuple(transition.source_state.tolist())
 47            for transition in subset
 48        ])
 49        for subset in partition
 50    ]
 51    # /SPEED
 52
 53    for i, j in product(state_indices, state_indices):
 54        # Partition the set of next states in node i into two sets:
 55        # - All the next states in node i that are a source state
 56        # in node j. These are states which say "yes, transition from
 57        # node i to node j".
 58        # - All the next states in node i that are not a source state
 59        # in node j. These are states which say "no, do not transition from
 60        # node i to node j".
 61        positive = list()
 62        negative = list()
 63        for transition in partition[i]:
 64            next_state = tuple(transition.next_state.tolist())
 65            is_next_source_in_j = next_state in source_states[j]
 66            vector = get_vector(transition)
 67            if is_next_source_in_j:
 68                positive.append(vector.detach().numpy())
 69            else:
 70                negative.append(vector.detach().numpy())
 71
 72        # Make the dataset balanced
 73        min_size = min(len(positive), len(negative))
 74        _random.shuffle(positive)
 75        _random.shuffle(negative)
 76        positive = positive[:min_size]
 77        negative = negative[:min_size]
 78
 79        # Then, turn these two sets into classification problem.
 80        X = positive + negative
 81        Y = [1 for _ in positive] + [0 for _ in negative]
 82
 83        # Synthesize a transition predicate
 84        if len(X) == 0:
 85            transition_predicate = False
 86        else:
 87            #transition_predicate = Pipeline([
 88            #    #("normalizer", StandardScaler()),
 89            #    #("classifier", MLPClassifier(
 90            #    #    random_state=_random.getrandbits(32),
 91            #    #    **predicate_hyperparameters,
 92            #    #)),
 93            #    ("classifier", DecisionTreeClassifier(
 94            #        random_state=_random.getrandbits(32),
 95            #    )),
 96            #])
 97            tree = DecisionTreeClassifier(
 98                random_state=_random.getrandbits(32),
 99                **predicate_hyperparameters,
100            )
101            tree.fit(X, Y)
102
103            transition_predicate = tree_to_predicate(tree)
104
105        # Store the transition predicate
106        transition_predicate_matrix[i][j] = transition_predicate
107
108        # Store the non-conditional transition probability
109        transition_histogram[i][j] = len(positive)
110
111    # Turn matrix into lists to adhere to API
112    transition_predicates = [
113        [
114            transition_predicate_matrix[i][j]
115            for j in transition_predicate_matrix[i].keys()
116        ]
117        for i in transition_predicate_matrix.keys()
118    ]
119    data = TransitionPredicates(
120        transition_predicates=transition_predicates,
121        transition_histogram=transition_histogram,
122    )
123    return data
@dataclass
class TransitionPredicates:
@dataclass
class TransitionPredicates:
    """Pairwise transition data between the subsets of a state partition.

    Both matrices are indexed ``[i][j]``: ``i`` is the index of the source
    subset and ``j`` the index of the target subset, as built by
    ``get_transition_predicates``.
    """

    # Predicate deciding whether to move from subset i to subset j. An entry
    # may also be a plain bool (``False`` is stored when there is no usable
    # training data for the pair).
    transition_predicates: list[list[Predicate]]
    # Integer count recorded for the i -> j pair.
    transition_histogram: list[list[int]]
TransitionPredicates( transition_predicates: list[list[swmpo.predicates.And | swmpo.predicates.Or | swmpo.predicates.LessThan | swmpo.predicates.GreaterThan | bool]], transition_histogram: list[list[int]])
transition_histogram: list[list[int]]
def get_transition_predicates( partition: list[list[swmpo.transition.Transition]], predicate_hyperparameters: dict[str, float | int | str], seed: str) -> TransitionPredicates:
 29def get_transition_predicates(
 30    partition: list[list[Transition]],
 31    predicate_hyperparameters: dict[str, float | int | str],
 32    seed: str,
 33) -> TransitionPredicates:
 34    """Returns the adjacency matrix containing the
 35    predicates and histogram of transitions between states."""
 36    _random = random.Random(seed)
 37    state_indices = list(range(len(partition)))
 38    transition_predicate_matrix = defaultdict(dict)
 39    transition_histogram = [[0 for _ in state_indices] for _ in state_indices]
 40
 41    # SPEED: make source state existence look-up fast for
 42    # predicate dataset construction. Otherwise checking if a
 43    # "next state" is the "source state" of any transition is
 44    # prohibitively slow.
 45    source_states = [
 46        set([
 47            tuple(transition.source_state.tolist())
 48            for transition in subset
 49        ])
 50        for subset in partition
 51    ]
 52    # /SPEED
 53
 54    for i, j in product(state_indices, state_indices):
 55        # Partition the set of next states in node i into two sets:
 56        # - All the next states in node i that are a source state
 57        # in node j. These are states which say "yes, transition from
 58        # node i to node j".
 59        # - All the next states in node i that are not a source state
 60        # in node j. These are states which say "no, do not transition from
 61        # node i to node j".
 62        positive = list()
 63        negative = list()
 64        for transition in partition[i]:
 65            next_state = tuple(transition.next_state.tolist())
 66            is_next_source_in_j = next_state in source_states[j]
 67            vector = get_vector(transition)
 68            if is_next_source_in_j:
 69                positive.append(vector.detach().numpy())
 70            else:
 71                negative.append(vector.detach().numpy())
 72
 73        # Make the dataset balanced
 74        min_size = min(len(positive), len(negative))
 75        _random.shuffle(positive)
 76        _random.shuffle(negative)
 77        positive = positive[:min_size]
 78        negative = negative[:min_size]
 79
 80        # Then, turn these two sets into classification problem.
 81        X = positive + negative
 82        Y = [1 for _ in positive] + [0 for _ in negative]
 83
 84        # Synthesize a transition predicate
 85        if len(X) == 0:
 86            transition_predicate = False
 87        else:
 88            #transition_predicate = Pipeline([
 89            #    #("normalizer", StandardScaler()),
 90            #    #("classifier", MLPClassifier(
 91            #    #    random_state=_random.getrandbits(32),
 92            #    #    **predicate_hyperparameters,
 93            #    #)),
 94            #    ("classifier", DecisionTreeClassifier(
 95            #        random_state=_random.getrandbits(32),
 96            #    )),
 97            #])
 98            tree = DecisionTreeClassifier(
 99                random_state=_random.getrandbits(32),
100                **predicate_hyperparameters,
101            )
102            tree.fit(X, Y)
103
104            transition_predicate = tree_to_predicate(tree)
105
106        # Store the transition predicate
107        transition_predicate_matrix[i][j] = transition_predicate
108
109        # Store the non-conditional transition probability
110        transition_histogram[i][j] = len(positive)
111
112    # Turn matrix into lists to adhere to API
113    transition_predicates = [
114        [
115            transition_predicate_matrix[i][j]
116            for j in transition_predicate_matrix[i].keys()
117        ]
118        for i in transition_predicate_matrix.keys()
119    ]
120    data = TransitionPredicates(
121        transition_predicates=transition_predicates,
122        transition_histogram=transition_histogram,
123    )
124    return data

Returns the adjacency matrices containing the predicates and the histogram of transitions between states.