swmpo.transition_predicates
Characterization of the transitions between sets of states.
A transition predicate is a function that takes as input a vector corresponding to the concatenation of a state and an action and maps it to a number. If the number is greater than zero, then the predicate is true. Otherwise, the predicate is said to be false.
"""Characterization of the transitions between sets of states.

A transition predicate is a function that takes as input a vector
corresponding to the concatenation of a state and an action and maps it
to a number. If the number is greater than zero, then the predicate is
true. Otherwise, the predicate is said to be false.
"""
import random
from collections import defaultdict
from dataclasses import dataclass
from itertools import product

from sklearn.neural_network import MLPClassifier
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.tree import DecisionTreeClassifier

from swmpo.predicates import Predicate
from swmpo.transition import Transition
from swmpo.transition import get_vector
from swmpo.tree_to_predicate import tree_to_predicate


@dataclass
class TransitionPredicates:
    """Pairwise transition data between the nodes of a partition."""

    # transition_predicates[i][j] is the predicate synthesized for moving
    # from node i to node j (or `False` when no training data was left).
    transition_predicates: list[list[Predicate]]
    # transition_histogram[i][j] is the number of positive examples that
    # were used for the (i, j) predicate.
    transition_histogram: list[list[int]]


def get_transition_predicates(
    partition: list[list[Transition]],
    predicate_hyperparameters: dict[str, float | int | str],
    seed: str,
) -> TransitionPredicates:
    """Return the adjacency matrices of predicates and transition counts.

    Each element of `partition` is treated as a node. For every ordered
    pair of nodes `(i, j)`, the transitions in node `i` are labelled
    positive when their next state is the source state of some transition
    in node `j`, and negative otherwise. The two classes are balanced by
    shuffling and truncating to the smaller size, a decision tree is
    fitted to them, and the tree is converted into a predicate with
    `tree_to_predicate`.

    Args:
        partition: One list of transitions per node.
        predicate_hyperparameters: Keyword arguments forwarded to
            `DecisionTreeClassifier`.
        seed: Seed of the private random number generator that drives the
            shuffling and each tree's `random_state`.

    Returns:
        A `TransitionPredicates` where `transition_predicates[i][j]` is
        the synthesized predicate (or `False` when the balanced dataset
        is empty) and `transition_histogram[i][j]` is the number of
        positive examples kept after balancing.
    """
    _random = random.Random(seed)
    node_count = len(partition)

    # SPEED: hash the source states of every node once so that checking
    # whether a "next state" is a "source state" of a node is a set
    # look-up instead of a scan over all of the node's transitions.
    source_states = [
        {tuple(transition.source_state.tolist()) for transition in subset}
        for subset in partition
    ]

    transition_predicates = []
    transition_histogram = []
    for i in range(node_count):
        # The hashable next state and the feature vector of a transition
        # depend only on the transition itself, so compute them once per
        # source node rather than once per (i, j) pair.
        # NOTE(review): assumes `get_vector` is side-effect free and
        # returns a torch tensor -- confirm against swmpo.transition.
        examples = [
            (
                tuple(transition.next_state.tolist()),
                get_vector(transition).detach().numpy(),
            )
            for transition in partition[i]
        ]

        predicate_row = []
        histogram_row = []
        for j in range(node_count):
            # Split the transitions of node i into two sets:
            # - positive: the next state is a source state in node j
            #   ("yes, transition from node i to node j").
            # - negative: the next state is not a source state in node j
            #   ("no, do not transition from node i to node j").
            targets = source_states[j]
            positive = []
            negative = []
            for next_state, vector in examples:
                if next_state in targets:
                    positive.append(vector)
                else:
                    negative.append(vector)

            # Balance the dataset by keeping a random subset of the
            # larger class.
            min_size = min(len(positive), len(negative))
            _random.shuffle(positive)
            _random.shuffle(negative)
            positive = positive[:min_size]
            negative = negative[:min_size]

            # Turn the two sets into a binary classification problem.
            X = positive + negative
            Y = [1] * len(positive) + [0] * len(negative)

            # Synthesize the transition predicate. Without any training
            # data the predicate is the constant `False`.
            if not X:
                transition_predicate = False
            else:
                tree = DecisionTreeClassifier(
                    random_state=_random.getrandbits(32),
                    **predicate_hyperparameters,
                )
                tree.fit(X, Y)
                transition_predicate = tree_to_predicate(tree)

            predicate_row.append(transition_predicate)

            # NOTE(review): this is the positive count *after* balancing,
            # i.e. min(#positive, #negative), not the raw number of
            # i -> j transitions -- confirm that this is intended.
            histogram_row.append(len(positive))

        transition_predicates.append(predicate_row)
        transition_histogram.append(histogram_row)

    return TransitionPredicates(
        transition_predicates=transition_predicates,
        transition_histogram=transition_histogram,
    )
@dataclass
class
TransitionPredicates:
@dataclass
class TransitionPredicates:
    """Pairwise transition data between the nodes of a partition."""

    # transition_predicates[i][j] is the predicate synthesized for moving
    # from node i to node j (or `False` when no training data was left).
    transition_predicates: list[list[Predicate]]
    # transition_histogram[i][j] is the number of positive examples that
    # were used for the (i, j) predicate.
    transition_histogram: list[list[int]]
TransitionPredicates( transition_predicates: list[list[swmpo.predicates.And | swmpo.predicates.Or | swmpo.predicates.LessThan | swmpo.predicates.GreaterThan | bool]], transition_histogram: list[list[int]])
transition_predicates: list[list[swmpo.predicates.And | swmpo.predicates.Or | swmpo.predicates.LessThan | swmpo.predicates.GreaterThan | bool]]
def
get_transition_predicates( partition: list[list[swmpo.transition.Transition]], predicate_hyperparameters: dict[str, float | int | str], seed: str) -> TransitionPredicates:
def get_transition_predicates(
    partition: list[list[Transition]],
    predicate_hyperparameters: dict[str, float | int | str],
    seed: str,
) -> TransitionPredicates:
    """Return the adjacency matrices of predicates and transition counts.

    Each element of `partition` is treated as a node. For every ordered
    pair of nodes `(i, j)`, the transitions in node `i` are labelled
    positive when their next state is the source state of some transition
    in node `j`, and negative otherwise. The two classes are balanced by
    shuffling and truncating to the smaller size, a decision tree is
    fitted to them, and the tree is converted into a predicate with
    `tree_to_predicate`.

    Args:
        partition: One list of transitions per node.
        predicate_hyperparameters: Keyword arguments forwarded to
            `DecisionTreeClassifier`.
        seed: Seed of the private random number generator that drives the
            shuffling and each tree's `random_state`.

    Returns:
        A `TransitionPredicates` where `transition_predicates[i][j]` is
        the synthesized predicate (or `False` when the balanced dataset
        is empty) and `transition_histogram[i][j]` is the number of
        positive examples kept after balancing.
    """
    _random = random.Random(seed)
    node_count = len(partition)

    # SPEED: hash the source states of every node once so that checking
    # whether a "next state" is a "source state" of a node is a set
    # look-up instead of a scan over all of the node's transitions.
    source_states = [
        {tuple(transition.source_state.tolist()) for transition in subset}
        for subset in partition
    ]

    transition_predicates = []
    transition_histogram = []
    for i in range(node_count):
        # The hashable next state and the feature vector of a transition
        # depend only on the transition itself, so compute them once per
        # source node rather than once per (i, j) pair.
        # NOTE(review): assumes `get_vector` is side-effect free and
        # returns a torch tensor -- confirm against swmpo.transition.
        examples = [
            (
                tuple(transition.next_state.tolist()),
                get_vector(transition).detach().numpy(),
            )
            for transition in partition[i]
        ]

        predicate_row = []
        histogram_row = []
        for j in range(node_count):
            # Split the transitions of node i into two sets:
            # - positive: the next state is a source state in node j
            #   ("yes, transition from node i to node j").
            # - negative: the next state is not a source state in node j
            #   ("no, do not transition from node i to node j").
            targets = source_states[j]
            positive = []
            negative = []
            for next_state, vector in examples:
                if next_state in targets:
                    positive.append(vector)
                else:
                    negative.append(vector)

            # Balance the dataset by keeping a random subset of the
            # larger class.
            min_size = min(len(positive), len(negative))
            _random.shuffle(positive)
            _random.shuffle(negative)
            positive = positive[:min_size]
            negative = negative[:min_size]

            # Turn the two sets into a binary classification problem.
            X = positive + negative
            Y = [1] * len(positive) + [0] * len(negative)

            # Synthesize the transition predicate. Without any training
            # data the predicate is the constant `False`.
            if not X:
                transition_predicate = False
            else:
                tree = DecisionTreeClassifier(
                    random_state=_random.getrandbits(32),
                    **predicate_hyperparameters,
                )
                tree.fit(X, Y)
                transition_predicate = tree_to_predicate(tree)

            predicate_row.append(transition_predicate)

            # NOTE(review): this is the positive count *after* balancing,
            # i.e. min(#positive, #negative), not the raw number of
            # i -> j transitions -- confirm that this is intended.
            histogram_row.append(len(positive))

        transition_predicates.append(predicate_row)
        transition_histogram.append(histogram_row)

    return TransitionPredicates(
        transition_predicates=transition_predicates,
        transition_histogram=transition_histogram,
    )
Returns the adjacency matrix containing the predicates and histogram of transitions between states.