swmpo.state_machine
State machine synthesizer inspired by the STUN algorithm.
1"""State machine synthesizer inspired by the STUN algorithm.""" 2from collections import defaultdict 3from dataclasses import dataclass 4from itertools import product 5from swmpo.predicates import Predicate 6from swmpo.predicates import predicate_to_str 7from swmpo.predicates import str_to_predicate 8from swmpo.predicates import get_robustness_value 9from swmpo.transition import get_vector 10from swmpo.transition import Transition 11from swmpo.transition_predicates import get_transition_predicates 12from swmpo.partition import StatePartitionItem 13from swmpo.world_models.model import get_input_output_size 14from swmpo.world_models.world_model import serialize_model 15from swmpo.world_models.world_model import deserialize_model 16import torch 17from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas 18from matplotlib.figure import Figure 19from pathlib import Path 20import random 21import json 22import tempfile 23import shutil 24import zipfile 25 26from swmpo.world_models.world_model import WorldModel 27 28 29# Serialize local models 30LOCAL_MODELS_ARCHITECTURE = "local_models_structure.json" 31LOCAL_MODELS_DIR = "local_models" 32TRANSITION_PREDICATES_DIR = "transition_predicates" 33TRANSITION_HISTOGRAM_PATH = "transition_histogram.json" 34NUM_CONSECUTIVE = 0 35 36 37@dataclass 38class StateMachine: 39 """ 40 - transition_histogram[i][j]: how many times transition (i, j) was 41 taken in the training data. Only for logging purposes. 42 """ 43 local_models: list[WorldModel] 44 transition_predicates: list[list[Predicate]] 45 transition_histogram: list[list[int]] 46 local_models_hidden_sizes: list[int] 47 local_models_input_size: int 48 local_models_output_size: int 49 50 51@dataclass 52class StateMachineOptimizationResult: 53 state_machine: StateMachine 54 partition_loss_log: list[float] 55 56 57def state_machine_model( 58 state_machine: StateMachine, 59 state: torch.Tensor, 60 action: torch.Tensor, 61 current_node: int, 62 dt: float, 63) -> tuple[torch.Tensor, int]: 64 node_indices = list(range(len(state_machine.local_models))) 65 66 # Predict next state 67 local_model = state_machine.local_models[current_node] 68 next_state = local_model.get_prediction( 69 source_state=state, 70 action=action, 71 dt=dt, 72 ) 73 74 # Test each transition predicate 75 transition = Transition( 76 source_state=state, 77 action=action, 78 next_state=next_state, 79 ) 80 x = get_vector(transition) 81 acceptable_next_states = list[int]() 82 for j in node_indices: 83 predicate = state_machine.transition_predicates[current_node][j] 84 85 x_list = list[float](x.detach().flatten().cpu().numpy().tolist()) 86 robustness_value = get_robustness_value( 87 predicate, 88 x_list, 89 ) 90 91 if robustness_value > 0.0: 92 acceptable_next_states.append(j) 93 94 # Identify next state 95 if len(acceptable_next_states) == 0: 96 next_node = current_node 97 else: 98 next_node = acceptable_next_states[0] 99 return (next_state, next_node) 100 101 102def get_visited_states( 103 state_machine: StateMachine, 104 initial_state: int, 105 episode: list[Transition], 106 dt: float, 107) -> list[int]: 108 """Return the sequence of states that the state machine traversed when 109 processing the list of state-action tuples. 110 111 In case of ties, we arbitrarily choose the first accepted transition. 
112 """ 113 current_node = initial_state 114 visited_nodes = [current_node] 115 116 consecutive_visits = 0 # Track consecutive visits for the current state 117 118 for transition in episode: 119 _, next_node = state_machine_model( 120 state_machine=state_machine, 121 state=transition.source_state, 122 action=transition.action, 123 current_node=current_node, 124 dt=dt, 125 ) 126 127 # Only transition if we have visited the current state at least 128 # 10 times consecutively 129 if next_node != current_node: 130 if consecutive_visits >= NUM_CONSECUTIVE: 131 # print(f"Transitioned from state {current_node} to {next_node} after {consecutive_visits} visits.") 132 current_node = next_node 133 consecutive_visits = 1 # Reset count for the new state 134 else: 135 consecutive_visits += 1 136 else: 137 consecutive_visits += 1 138 139 visited_nodes.append(current_node) 140 141 return visited_nodes 142 143 144def get_local_model_errors( 145 state_machine: StateMachine, 146 episode: list[Transition], 147 dt: float, 148 ) -> list[list[float]]: 149 """Return the list of errors of each state for each transition in the 150 episode. The returned value is a list of `len(state_machine.local_models)` 151 lists of size `len(episode)`.""" 152 episode_errors = list[list[float]]() 153 for transition in episode: 154 # Evaluate each model in the current transition 155 transition_errors = [ 156 model.get_raw_error(transition, dt) 157 for model in state_machine.local_models 158 ] 159 episode_errors.append(transition_errors) 160 return episode_errors 161 162 163def serialize_state_machine( 164 state_machine: StateMachine, 165 output_zip_path: Path, 166 ): 167 """Serialize the state machine to the given directory. 168 Output directory is assumed to exist. 169 """ 170 assert output_zip_path.suffix == ".zip" 171 with tempfile.TemporaryDirectory() as tmpdirname: 172 output_dir = Path(tmpdirname) 173 174 # Serialize local models 175 local_models_dir = output_dir/LOCAL_MODELS_DIR 176 local_models_dir.mkdir() 177 for i, model in enumerate(state_machine.local_models): 178 model_path = local_models_dir/f"{i}.zip" 179 serialize_model( 180 model, 181 output_zip_path=model_path, 182 ) 183 184 # Serialize local models architecture 185 architecture_path = output_dir/LOCAL_MODELS_ARCHITECTURE 186 architecture = dict( 187 local_models_input_size=state_machine.local_models_input_size, 188 local_models_output_size=state_machine.local_models_output_size, 189 local_models_hidden_sizes=state_machine.local_models_hidden_sizes, 190 ) 191 with open(architecture_path, "wt") as fp: 192 json.dump(architecture, fp) 193 194 # Serialize transition predicates 195 transition_predicates_dir = output_dir/TRANSITION_PREDICATES_DIR 196 transition_predicates_dir.mkdir() 197 state_indices = list(range(len(state_machine.local_models))) 198 for i, j in product(state_indices, state_indices): 199 # Serialize actual predicate 200 # ### If it's sklearn 201 # transition_predicate = state_machine.transition_predicates[i][j] 202 # transition_predicate_path = transition_predicates_dir/f"{i}-{j}.joblib" 203 # dump(transition_predicate, transition_predicate_path) 204 205 # ### If it's a swmpo predicate 206 # TODO: handle case where predicate is None 207 predicate = state_machine.transition_predicates[i][j] 208 file_id = f"{i}-{j}.json" 209 transition_predicate_path = transition_predicates_dir/file_id 210 json_str = predicate_to_str(predicate) 211 with open(transition_predicate_path, "wt") as fp: 212 _ = fp.write(json_str) 213 214 # Also save a diagram for visualization 
215 #if transition_predicate is not None: 216 # transition_predicate_plot_path = transition_predicates_dir/f"{i}-{j}.svg" 217 # fig = Figure() 218 # _ = FigureCanvas(fig) 219 # ax = fig.add_subplot() 220 # sklearn.tree.plot_tree(transition_predicate, ax=ax) 221 # fig.savefig(transition_predicate_plot_path) 222 223 # Serialize transition histogram 224 transition_histogram_path = output_dir/TRANSITION_HISTOGRAM_PATH 225 with open(transition_histogram_path, "wt") as fp: 226 json.dump(state_machine.transition_histogram, fp) 227 228 # Plot transition matrix 229 transition_histogram_plot_path = output_dir/"transition_histogram.svg" 230 fig = Figure() 231 _ = FigureCanvas(fig) 232 ax = fig.add_subplot() 233 M = state_machine.transition_histogram 234 _ = ax.imshow(M) 235 for i in range(len(M)): 236 for j in range(len(M[i])): 237 _ = ax.text( 238 j, 239 i, 240 str(state_machine.transition_histogram[i][j]), 241 ha="center", va="center", color="w", 242 ) 243 ticks = list(range(len(M))) 244 _ = ax.set_xticks(ticks, labels=[str(t) for t in ticks]) 245 _ = ax.set_yticks(ticks, labels=[str(t) for t in ticks]) 246 _ = ax.set_title("Transition histogram") 247 fig.savefig(transition_histogram_plot_path) 248 249 _ = shutil.make_archive( 250 str(output_zip_path.with_suffix("")), 251 'zip', 252 output_dir 253 ) 254 255 256def deserialize_state_machine( 257 zip_path: Path, 258 ) -> StateMachine: 259 """Load the state machine in the given ZIP file written by 260 `swmpo.state_machine.serialize_state_machine`.""" 261 with tempfile.TemporaryDirectory() as tmpdirname: 262 output_dir = Path(tmpdirname) 263 264 with zipfile.ZipFile(zip_path, "r") as zip_ref: 265 zip_ref.extractall(output_dir) 266 267 architecture_path = output_dir/LOCAL_MODELS_ARCHITECTURE 268 with open(architecture_path, "rt") as fp: 269 architecture = json.load(fp) 270 local_models_hidden_sizes = list[int]( 271 architecture["local_models_hidden_sizes"] 272 ) 273 input_size = int(architecture["local_models_input_size"]) 274 output_size = int(architecture["local_models_output_size"]) 275 276 # Load local models 277 local_model_paths = (output_dir/LOCAL_MODELS_DIR).glob("*.zip") 278 sorted_local_model_paths = sorted( 279 local_model_paths, 280 key=lambda path: int(path.stem), 281 ) 282 local_models = list[WorldModel]() 283 for local_model_path in sorted_local_model_paths: 284 local_model = deserialize_model(local_model_path) 285 local_models.append(local_model) 286 287 # Load transition predicates 288 state_indices = list(range(len(local_models))) 289 transition_predicates_dir = output_dir/TRANSITION_PREDICATES_DIR 290 transition_predicates = defaultdict[int, dict[int, Predicate]](dict) 291 for i, j in product(state_indices, state_indices): 292 # TODO: put a JSON index file with the (i, j) -> pathmapping 293 #transition_predicate_path = transition_predicates_dir/f"{i}-{j}.joblib" 294 file_id = f"{i}-{j}.json" 295 transition_predicate_path = transition_predicates_dir/file_id 296 297 error_message = f"'{transition_predicate_path}' not found!" 
298 assert transition_predicate_path.exists(), error_message 299 300 #predicate = load(transition_predicate_path) 301 with open(transition_predicate_path, "rt") as fp: 302 predicate_str = fp.read() 303 predicate = str_to_predicate(predicate_str) 304 transition_predicates[i][j] = predicate 305 306 transition_predicates = [ 307 [ 308 transition_predicates[i][j] 309 for j in state_indices 310 ] 311 for i in state_indices 312 ] 313 314 # Load transition histogram 315 transition_histogram_path = output_dir/TRANSITION_HISTOGRAM_PATH 316 with open(transition_histogram_path, "rt") as fp: 317 transition_histogram = list[list[int]](json.load(fp)) 318 319 state_machine = StateMachine( 320 local_models=local_models, 321 transition_predicates=transition_predicates, 322 transition_histogram=transition_histogram, 323 local_models_hidden_sizes=local_models_hidden_sizes, 324 local_models_input_size=input_size, 325 local_models_output_size=output_size, 326 ) 327 return state_machine 328 329 330def get_partition_induced_state_machine( 331 partition: list[StatePartitionItem], 332 predicate_hyperparameters: dict[str, float | int | str], 333 seed: str, 334) -> StateMachine: 335 _random = random.Random(seed) 336 337 # Characterize the transition predicates between the sets of the partition. 338 subsets = [ 339 item.subset 340 for item in partition 341 ] 342 transition_predicates = get_transition_predicates( 343 partition=subsets, 344 predicate_hyperparameters=predicate_hyperparameters, 345 seed=str(_random.random()), 346 ) 347 348 # Assemble state machine 349 local_models = [ 350 item.local_model 351 for item in partition 352 ] 353 all_transitions = [ 354 transition 355 for subset in subsets 356 for transition in subset 357 ] 358 input_size, output_size = get_input_output_size(all_transitions[0]) 359 all_hidden_sizes = [tuple(item.hidden_sizes) for item in partition] 360 error_message = "Local models have different hidden sizes!" 361 assert len(set(all_hidden_sizes)) == 1, error_message 362 assert partition 363 hidden_sizes = partition[0].hidden_sizes 364 state_machine = StateMachine( 365 local_models=local_models, 366 transition_predicates=transition_predicates.transition_predicates, 367 transition_histogram=transition_predicates.transition_histogram, 368 local_models_hidden_sizes=hidden_sizes, 369 local_models_input_size=input_size, 370 local_models_output_size=output_size, 371 ) 372 return state_machine 373 374 375def get_state_machine_errors( 376 state_machine: StateMachine, 377 episode: list[Transition], 378 initial_state: int, 379 dt: float, 380 ) -> list[float]: 381 """Return the errors of each state of the state machine.""" 382 current_node = initial_state 383 errors = list[float]() 384 for transition in episode: 385 predicted_next_state, next_node = state_machine_model( 386 state_machine=state_machine, 387 state=transition.source_state, 388 action=transition.action, 389 current_node=current_node, 390 dt=dt, 391 ) 392 393 # Log error 394 error = float( 395 (predicted_next_state - transition.next_state).norm().item() 396 ) 397 errors.append(error) 398 399 # Transition 400 current_node = next_node 401 return errors
LOCAL_MODELS_ARCHITECTURE = 'local_models_structure.json'
LOCAL_MODELS_DIR = 'local_models'
TRANSITION_PREDICATES_DIR = 'transition_predicates'
TRANSITION_HISTOGRAM_PATH = 'transition_histogram.json'
NUM_CONSECUTIVE = 0
@dataclass
class StateMachine:
@dataclass
class StateMachine:
    """
    - transition_histogram[i][j]: how many times transition (i, j) was
      taken in the training data. Only for logging purposes.
    """
    local_models: list[WorldModel]
    transition_predicates: list[list[Predicate]]
    transition_histogram: list[list[int]]
    local_models_hidden_sizes: list[int]
    local_models_input_size: int
    local_models_output_size: int
- transition_histogram[i][j]: how many times transition (i, j) was taken in the training data. Only for logging purposes.
StateMachine( local_models: list[swmpo.world_models.world_model.WorldModel], transition_predicates: list[list[swmpo.predicates.And | swmpo.predicates.Or | swmpo.predicates.LessThan | swmpo.predicates.GreaterThan | bool]], transition_histogram: list[list[int]], local_models_hidden_sizes: list[int], local_models_input_size: int, local_models_output_size: int)
transition_predicates: list[list[swmpo.predicates.And | swmpo.predicates.Or | swmpo.predicates.LessThan | swmpo.predicates.GreaterThan | bool]]
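The snippet below is a minimal sketch of how these fields are typically read. It assumes an already-synthesized StateMachine instance named sm (not constructed here).

# Minimal sketch: inspect a synthesized StateMachine.  `sm` is assumed to
# be an existing StateMachine instance (e.g. one loaded from disk).
num_modes = len(sm.local_models)
print(f"{num_modes} modes, "
      f"input size {sm.local_models_input_size}, "
      f"output size {sm.local_models_output_size}")
# transition_histogram[i][j]: how many times transition (i, j) was taken
# in the training data (logging only).
for i in range(num_modes):
    for j in range(num_modes):
        if i != j and sm.transition_histogram[i][j] > 0:
            print(f"{i} -> {j}: {sm.transition_histogram[i][j]} times")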
@dataclass
class StateMachineOptimizationResult:
@dataclass
class StateMachineOptimizationResult:
    state_machine: StateMachine
    partition_loss_log: list[float]
StateMachineOptimizationResult( state_machine: StateMachine, partition_loss_log: list[float])
state_machine: StateMachine
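Since partition_loss_log is just a list of floats, it can be plotted directly. The sketch below assumes a StateMachineOptimizationResult named result returned by the synthesis routine (not part of this module) and follows the same Agg-backend plotting pattern used by serialize_state_machine; the output path is illustrative.

# Minimal sketch: plot the partition loss over the optimization run.
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from matplotlib.figure import Figure

fig = Figure()
_ = FigureCanvas(fig)
ax = fig.add_subplot()
_ = ax.plot(result.partition_loss_log)  # `result` is an assumption
_ = ax.set_xlabel("Iteration")
_ = ax.set_ylabel("Partition loss")
fig.savefig("partition_loss.svg")  # illustrative output path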
def state_machine_model(state_machine: StateMachine, state: torch.Tensor, action: torch.Tensor, current_node: int, dt: float) -> tuple[torch.Tensor, int]:
def state_machine_model(
    state_machine: StateMachine,
    state: torch.Tensor,
    action: torch.Tensor,
    current_node: int,
    dt: float,
) -> tuple[torch.Tensor, int]:
    node_indices = list(range(len(state_machine.local_models)))

    # Predict next state
    local_model = state_machine.local_models[current_node]
    next_state = local_model.get_prediction(
        source_state=state,
        action=action,
        dt=dt,
    )

    # Test each transition predicate
    transition = Transition(
        source_state=state,
        action=action,
        next_state=next_state,
    )
    x = get_vector(transition)
    acceptable_next_states = list[int]()
    for j in node_indices:
        predicate = state_machine.transition_predicates[current_node][j]

        x_list = list[float](x.detach().flatten().cpu().numpy().tolist())
        robustness_value = get_robustness_value(
            predicate,
            x_list,
        )

        if robustness_value > 0.0:
            acceptable_next_states.append(j)

    # Identify next state
    if len(acceptable_next_states) == 0:
        next_node = current_node
    else:
        next_node = acceptable_next_states[0]
    return (next_state, next_node)
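state_machine_model performs one step of the hybrid model: the current mode's local model predicts the next state, and the transition predicates decide which mode to move to. The sketch below rolls the machine forward in closed loop; sm, the timestep dt, the archive path, and the STATE_SIZE/ACTION_SIZE constants are hypothetical placeholders for illustration.

from pathlib import Path
import torch
from swmpo.state_machine import deserialize_state_machine, state_machine_model

STATE_SIZE = 4    # hypothetical; must match the training data
ACTION_SIZE = 2   # hypothetical
dt = 0.05         # hypothetical timestep

sm = deserialize_state_machine(Path("machine.zip"))  # illustrative path
state = torch.zeros(STATE_SIZE)
action = torch.zeros(ACTION_SIZE)
current_node = 0
for _ in range(100):
    # One step: predict the next state with the current mode's local model,
    # then follow the first accepted transition predicate (if any).
    state, current_node = state_machine_model(
        state_machine=sm,
        state=state,
        action=action,
        current_node=current_node,
        dt=dt,
    )
print("final mode:", current_node)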
def get_visited_states(state_machine: StateMachine, initial_state: int, episode: list[swmpo.transition.Transition], dt: float) -> list[int]:
def get_visited_states(
    state_machine: StateMachine,
    initial_state: int,
    episode: list[Transition],
    dt: float,
) -> list[int]:
    """Return the sequence of states that the state machine traversed when
    processing the list of state-action tuples.

    In case of ties, we arbitrarily choose the first accepted transition.
    """
    current_node = initial_state
    visited_nodes = [current_node]

    consecutive_visits = 0  # Track consecutive visits for the current state

    for transition in episode:
        _, next_node = state_machine_model(
            state_machine=state_machine,
            state=transition.source_state,
            action=transition.action,
            current_node=current_node,
            dt=dt,
        )

        # Only transition if the current state has been visited at least
        # NUM_CONSECUTIVE times consecutively
        if next_node != current_node:
            if consecutive_visits >= NUM_CONSECUTIVE:
                # print(f"Transitioned from state {current_node} to {next_node} after {consecutive_visits} visits.")
                current_node = next_node
                consecutive_visits = 1  # Reset count for the new state
            else:
                consecutive_visits += 1
        else:
            consecutive_visits += 1

        visited_nodes.append(current_node)

    return visited_nodes
Return the sequence of states that the state machine traversed when processing the list of state-action tuples.
In case of ties, we arbitrarily choose the first accepted transition.
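A typical use of get_visited_states is to label an episode with the mode that was active at each step. The sketch below assumes sm (a StateMachine), episode (a list of swmpo Transition objects), and dt already exist.

from collections import Counter
from swmpo.state_machine import get_visited_states

visited = get_visited_states(
    state_machine=sm,
    initial_state=0,
    episode=episode,
    dt=dt,
)
# `visited` has len(episode) + 1 entries; visited[k] is the mode used to
# process episode[k].
mode_counts = Counter(visited[:-1])
for mode, steps in sorted(mode_counts.items()):
    print(f"mode {mode}: active for {steps} steps")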
def get_local_model_errors(state_machine: StateMachine, episode: list[swmpo.transition.Transition], dt: float) -> list[list[float]]:
def get_local_model_errors(
    state_machine: StateMachine,
    episode: list[Transition],
    dt: float,
) -> list[list[float]]:
    """Return the error of each local model for each transition in the
    episode. The returned value is a list of `len(episode)` lists of size
    `len(state_machine.local_models)`."""
    episode_errors = list[list[float]]()
    for transition in episode:
        # Evaluate each model in the current transition
        transition_errors = [
            model.get_raw_error(transition, dt)
            for model in state_machine.local_models
        ]
        episode_errors.append(transition_errors)
    return episode_errors
Return the error of each local model for each transition in the episode. The returned value is a list of len(episode) lists of size len(state_machine.local_models).
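One way to use this error matrix is as a partition diagnostic: for each transition, check which local model fits best. The sketch below assumes sm, episode, and dt already exist.

from swmpo.state_machine import get_local_model_errors

episode_errors = get_local_model_errors(
    state_machine=sm,
    episode=episode,
    dt=dt,
)
# For each transition, the index of the local model with the lowest error.
best_models = [
    min(range(len(errors)), key=lambda k: errors[k])
    for errors in episode_errors
]
print(best_models)  # e.g. [0, 0, 1, 1, ...] (illustrative)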
def serialize_state_machine(
    state_machine: StateMachine,
    output_zip_path: Path,
):
    """Serialize the state machine to the given ZIP file path.
    The directory containing `output_zip_path` is assumed to exist.
    """
    assert output_zip_path.suffix == ".zip"
    with tempfile.TemporaryDirectory() as tmpdirname:
        output_dir = Path(tmpdirname)

        # Serialize local models
        local_models_dir = output_dir/LOCAL_MODELS_DIR
        local_models_dir.mkdir()
        for i, model in enumerate(state_machine.local_models):
            model_path = local_models_dir/f"{i}.zip"
            serialize_model(
                model,
                output_zip_path=model_path,
            )

        # Serialize local models architecture
        architecture_path = output_dir/LOCAL_MODELS_ARCHITECTURE
        architecture = dict(
            local_models_input_size=state_machine.local_models_input_size,
            local_models_output_size=state_machine.local_models_output_size,
            local_models_hidden_sizes=state_machine.local_models_hidden_sizes,
        )
        with open(architecture_path, "wt") as fp:
            json.dump(architecture, fp)

        # Serialize transition predicates
        transition_predicates_dir = output_dir/TRANSITION_PREDICATES_DIR
        transition_predicates_dir.mkdir()
        state_indices = list(range(len(state_machine.local_models)))
        for i, j in product(state_indices, state_indices):
            # Serialize actual predicate
            # ### If it's sklearn
            # transition_predicate = state_machine.transition_predicates[i][j]
            # transition_predicate_path = transition_predicates_dir/f"{i}-{j}.joblib"
            # dump(transition_predicate, transition_predicate_path)

            # ### If it's a swmpo predicate
            # TODO: handle case where predicate is None
            predicate = state_machine.transition_predicates[i][j]
            file_id = f"{i}-{j}.json"
            transition_predicate_path = transition_predicates_dir/file_id
            json_str = predicate_to_str(predicate)
            with open(transition_predicate_path, "wt") as fp:
                _ = fp.write(json_str)

            # Also save a diagram for visualization
            #if transition_predicate is not None:
            #    transition_predicate_plot_path = transition_predicates_dir/f"{i}-{j}.svg"
            #    fig = Figure()
            #    _ = FigureCanvas(fig)
            #    ax = fig.add_subplot()
            #    sklearn.tree.plot_tree(transition_predicate, ax=ax)
            #    fig.savefig(transition_predicate_plot_path)

        # Serialize transition histogram
        transition_histogram_path = output_dir/TRANSITION_HISTOGRAM_PATH
        with open(transition_histogram_path, "wt") as fp:
            json.dump(state_machine.transition_histogram, fp)

        # Plot transition matrix
        transition_histogram_plot_path = output_dir/"transition_histogram.svg"
        fig = Figure()
        _ = FigureCanvas(fig)
        ax = fig.add_subplot()
        M = state_machine.transition_histogram
        _ = ax.imshow(M)
        for i in range(len(M)):
            for j in range(len(M[i])):
                _ = ax.text(
                    j,
                    i,
                    str(state_machine.transition_histogram[i][j]),
                    ha="center", va="center", color="w",
                )
        ticks = list(range(len(M)))
        _ = ax.set_xticks(ticks, labels=[str(t) for t in ticks])
        _ = ax.set_yticks(ticks, labels=[str(t) for t in ticks])
        _ = ax.set_title("Transition histogram")
        fig.savefig(transition_histogram_plot_path)

        _ = shutil.make_archive(
            str(output_zip_path.with_suffix("")),
            'zip',
            output_dir
        )
Serialize the state machine to the given ZIP file path. The directory containing output_zip_path is assumed to exist.
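A minimal usage sketch, assuming an existing StateMachine named sm and an illustrative output path:

from pathlib import Path
from swmpo.state_machine import serialize_state_machine

output_path = Path("results/machine.zip")  # illustrative
output_path.parent.mkdir(parents=True, exist_ok=True)
serialize_state_machine(sm, output_zip_path=output_path)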
def deserialize_state_machine(
    zip_path: Path,
) -> StateMachine:
    """Load the state machine in the given ZIP file written by
    `swmpo.state_machine.serialize_state_machine`."""
    with tempfile.TemporaryDirectory() as tmpdirname:
        output_dir = Path(tmpdirname)

        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(output_dir)

        architecture_path = output_dir/LOCAL_MODELS_ARCHITECTURE
        with open(architecture_path, "rt") as fp:
            architecture = json.load(fp)
        local_models_hidden_sizes = list[int](
            architecture["local_models_hidden_sizes"]
        )
        input_size = int(architecture["local_models_input_size"])
        output_size = int(architecture["local_models_output_size"])

        # Load local models
        local_model_paths = (output_dir/LOCAL_MODELS_DIR).glob("*.zip")
        sorted_local_model_paths = sorted(
            local_model_paths,
            key=lambda path: int(path.stem),
        )
        local_models = list[WorldModel]()
        for local_model_path in sorted_local_model_paths:
            local_model = deserialize_model(local_model_path)
            local_models.append(local_model)

        # Load transition predicates
        state_indices = list(range(len(local_models)))
        transition_predicates_dir = output_dir/TRANSITION_PREDICATES_DIR
        transition_predicates = defaultdict[int, dict[int, Predicate]](dict)
        for i, j in product(state_indices, state_indices):
            # TODO: put a JSON index file with the (i, j) -> path mapping
            #transition_predicate_path = transition_predicates_dir/f"{i}-{j}.joblib"
            file_id = f"{i}-{j}.json"
            transition_predicate_path = transition_predicates_dir/file_id

            error_message = f"'{transition_predicate_path}' not found!"
            assert transition_predicate_path.exists(), error_message

            #predicate = load(transition_predicate_path)
            with open(transition_predicate_path, "rt") as fp:
                predicate_str = fp.read()
            predicate = str_to_predicate(predicate_str)
            transition_predicates[i][j] = predicate

        transition_predicates = [
            [
                transition_predicates[i][j]
                for j in state_indices
            ]
            for i in state_indices
        ]

        # Load transition histogram
        transition_histogram_path = output_dir/TRANSITION_HISTOGRAM_PATH
        with open(transition_histogram_path, "rt") as fp:
            transition_histogram = list[list[int]](json.load(fp))

        state_machine = StateMachine(
            local_models=local_models,
            transition_predicates=transition_predicates,
            transition_histogram=transition_histogram,
            local_models_hidden_sizes=local_models_hidden_sizes,
            local_models_input_size=input_size,
            local_models_output_size=output_size,
        )
    return state_machine
Load the state machine in the given ZIP file written by
swmpo.state_machine.serialize_state_machine.
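Continuing the serialization sketch above, a round trip looks like this (the path and the comparison against sm are again illustrative):

from pathlib import Path
from swmpo.state_machine import deserialize_state_machine

sm_loaded = deserialize_state_machine(Path("results/machine.zip"))
assert len(sm_loaded.local_models) == len(sm_loaded.transition_predicates)
assert sm_loaded.transition_histogram == sm.transition_histogram  # same counts as before saving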
def get_partition_induced_state_machine(partition: list[swmpo.partition.StatePartitionItem], predicate_hyperparameters: dict[str, float | int | str], seed: str) -> StateMachine:
def get_partition_induced_state_machine(
    partition: list[StatePartitionItem],
    predicate_hyperparameters: dict[str, float | int | str],
    seed: str,
) -> StateMachine:
    _random = random.Random(seed)

    # Characterize the transition predicates between the sets of the partition.
    subsets = [
        item.subset
        for item in partition
    ]
    transition_predicates = get_transition_predicates(
        partition=subsets,
        predicate_hyperparameters=predicate_hyperparameters,
        seed=str(_random.random()),
    )

    # Assemble state machine
    local_models = [
        item.local_model
        for item in partition
    ]
    all_transitions = [
        transition
        for subset in subsets
        for transition in subset
    ]
    input_size, output_size = get_input_output_size(all_transitions[0])
    all_hidden_sizes = [tuple(item.hidden_sizes) for item in partition]
    error_message = "Local models have different hidden sizes!"
    assert len(set(all_hidden_sizes)) == 1, error_message
    assert partition
    hidden_sizes = partition[0].hidden_sizes
    state_machine = StateMachine(
        local_models=local_models,
        transition_predicates=transition_predicates.transition_predicates,
        transition_histogram=transition_predicates.transition_histogram,
        local_models_hidden_sizes=hidden_sizes,
        local_models_input_size=input_size,
        local_models_output_size=output_size,
    )
    return state_machine
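The sketch below shows how a partition is turned into a state machine. It assumes a list of StatePartitionItem objects named partition produced elsewhere in swmpo; the hyperparameter dictionary is left empty because the accepted keys depend on swmpo.transition_predicates.get_transition_predicates and are not documented here.

from swmpo.state_machine import get_partition_induced_state_machine

predicate_hyperparameters: dict[str, float | int | str] = {
    # illustrative: actual keys are defined by get_transition_predicates
}
sm = get_partition_induced_state_machine(
    partition=partition,
    predicate_hyperparameters=predicate_hyperparameters,
    seed="42",
)
print(len(sm.local_models), "modes")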
def get_state_machine_errors(state_machine: StateMachine, episode: list[swmpo.transition.Transition], initial_state: int, dt: float) -> list[float]:
def get_state_machine_errors(
    state_machine: StateMachine,
    episode: list[Transition],
    initial_state: int,
    dt: float,
) -> list[float]:
    """Return the state machine's one-step prediction error for each
    transition in the episode."""
    current_node = initial_state
    errors = list[float]()
    for transition in episode:
        predicted_next_state, next_node = state_machine_model(
            state_machine=state_machine,
            state=transition.source_state,
            action=transition.action,
            current_node=current_node,
            dt=dt,
        )

        # Log error
        error = float(
            (predicted_next_state - transition.next_state).norm().item()
        )
        errors.append(error)

        # Transition
        current_node = next_node
    return errors
Return the state machine's one-step prediction error for each transition in the episode.
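A minimal sketch summarizing the per-transition errors over a held-out episode; sm, episode, and dt are assumed to exist.

from swmpo.state_machine import get_state_machine_errors

errors = get_state_machine_errors(
    state_machine=sm,
    episode=episode,
    initial_state=0,
    dt=dt,
)
print(f"mean one-step error: {sum(errors) / len(errors):.4f}, "
      f"max: {max(errors):.4f}")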