swmpo.state_machine

State machine synthesizer inspired by the STUN algorithm.

  1"""State machine synthesizer inspired by the STUN algorithm."""
  2from collections import defaultdict
  3from dataclasses import dataclass
  4from itertools import product
  5from swmpo.predicates import Predicate
  6from swmpo.predicates import predicate_to_str
  7from swmpo.predicates import str_to_predicate
  8from swmpo.predicates import get_robustness_value
  9from swmpo.transition import get_vector
 10from swmpo.transition import Transition
 11from swmpo.transition_predicates import get_transition_predicates
 12from swmpo.partition import StatePartitionItem
 13from swmpo.world_models.model import get_input_output_size
 14from swmpo.world_models.world_model import serialize_model
 15from swmpo.world_models.world_model import deserialize_model
 16import torch
 17from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
 18from matplotlib.figure import Figure
 19from pathlib import Path
 20import random
 21import json
 22import tempfile
 23import shutil
 24import zipfile
 25
 26from swmpo.world_models.world_model import WorldModel
 27
 28
# File and directory names used by (de)serialization
LOCAL_MODELS_ARCHITECTURE = "local_models_structure.json"
LOCAL_MODELS_DIR = "local_models"
TRANSITION_PREDICATES_DIR = "transition_predicates"
TRANSITION_HISTOGRAM_PATH = "transition_histogram.json"

# Minimum number of consecutive visits to the current state before an
# outgoing transition is taken (see get_visited_states)
NUM_CONSECUTIVE = 0


@dataclass
class StateMachine:
    """
    State machine with one local world model per node.

    - local_models[i]: world model used to predict the next state while in
      node i.
    - transition_predicates[i][j]: predicate that must hold (robustness
      value > 0) for the machine to move from node i to node j.
    - transition_histogram[i][j]: how many times transition (i, j) was
      taken in the training data. Only for logging purposes.
    - local_models_hidden_sizes, local_models_input_size,
      local_models_output_size: architecture of the local models, stored
      for (de)serialization.
    """
    local_models: list[WorldModel]
    transition_predicates: list[list[Predicate]]
    transition_histogram: list[list[int]]
    local_models_hidden_sizes: list[int]
    local_models_input_size: int
    local_models_output_size: int


@dataclass
class StateMachineOptimizationResult:
    """Synthesized state machine together with the partition loss recorded
    during optimization."""
    state_machine: StateMachine
    partition_loss_log: list[float]


def state_machine_model(
    state_machine: StateMachine,
    state: torch.Tensor,
    action: torch.Tensor,
    current_node: int,
    dt: float,
) -> tuple[torch.Tensor, int]:
    """Advance the state machine by one step.

    The current node's local model predicts the next state, then the
    outgoing transition predicates are evaluated on the resulting
    transition to pick the next node. If no predicate accepts, we stay in
    the current node; if several accept, we take the first one.
    """
    node_indices = list(range(len(state_machine.local_models)))

    # Predict next state
    local_model = state_machine.local_models[current_node]
    next_state = local_model.get_prediction(
        source_state=state,
        action=action,
        dt=dt,
    )

    # Test each transition predicate
    transition = Transition(
        source_state=state,
        action=action,
        next_state=next_state,
    )
    x = get_vector(transition)
    acceptable_next_states = list[int]()
    for j in node_indices:
        predicate = state_machine.transition_predicates[current_node][j]

        x_list = list[float](x.detach().flatten().cpu().numpy().tolist())
        robustness_value = get_robustness_value(
            predicate,
            x_list,
        )

        if robustness_value > 0.0:
            acceptable_next_states.append(j)

    # Identify next state
    if len(acceptable_next_states) == 0:
        next_node = current_node
    else:
        next_node = acceptable_next_states[0]
    return (next_state, next_node)


def get_visited_states(
    state_machine: StateMachine,
    initial_state: int,
    episode: list[Transition],
    dt: float,
) -> list[int]:
    """Return the sequence of states that the state machine traversed when
    processing the list of state-action tuples.

    In case of ties, we arbitrarily choose the first accepted transition.
    """
    current_node = initial_state
    visited_nodes = [current_node]

    consecutive_visits = 0  # Track consecutive visits for the current state

    for transition in episode:
        _, next_node = state_machine_model(
            state_machine=state_machine,
            state=transition.source_state,
            action=transition.action,
            current_node=current_node,
            dt=dt,
        )

        # Only transition if we have visited the current state at least
        # NUM_CONSECUTIVE times consecutively
        if next_node != current_node:
            if consecutive_visits >= NUM_CONSECUTIVE:
                current_node = next_node
                consecutive_visits = 1  # Reset count for the new state
            else:
                consecutive_visits += 1
        else:
            consecutive_visits += 1

        visited_nodes.append(current_node)

    return visited_nodes


def get_local_model_errors(
        state_machine: StateMachine,
        episode: list[Transition],
        dt: float,
        ) -> list[list[float]]:
    """Return the error of every local model for each transition in the
    episode. The returned value is a list of `len(episode)` lists, each of
    size `len(state_machine.local_models)`."""
    episode_errors = list[list[float]]()
    for transition in episode:
        # Evaluate each local model on the current transition
        transition_errors = [
            model.get_raw_error(transition, dt)
            for model in state_machine.local_models
        ]
        episode_errors.append(transition_errors)
    return episode_errors


def serialize_state_machine(
        state_machine: StateMachine,
        output_zip_path: Path,
        ):
    """Serialize the state machine to the given ZIP file.
    The parent directory of `output_zip_path` is assumed to exist.
    """
    assert output_zip_path.suffix == ".zip"
    with tempfile.TemporaryDirectory() as tmpdirname:
        output_dir = Path(tmpdirname)

        # Serialize local models
        local_models_dir = output_dir/LOCAL_MODELS_DIR
        local_models_dir.mkdir()
        for i, model in enumerate(state_machine.local_models):
            model_path = local_models_dir/f"{i}.zip"
            serialize_model(
                model,
                output_zip_path=model_path,
            )

        # Serialize local models architecture
        architecture_path = output_dir/LOCAL_MODELS_ARCHITECTURE
        architecture = dict(
            local_models_input_size=state_machine.local_models_input_size,
            local_models_output_size=state_machine.local_models_output_size,
            local_models_hidden_sizes=state_machine.local_models_hidden_sizes,
        )
        with open(architecture_path, "wt") as fp:
            json.dump(architecture, fp)

        # Serialize transition predicates
        transition_predicates_dir = output_dir/TRANSITION_PREDICATES_DIR
        transition_predicates_dir.mkdir()
        state_indices = list(range(len(state_machine.local_models)))
        for i, j in product(state_indices, state_indices):
            # Serialize actual predicate
            # ### If it's sklearn
            # transition_predicate = state_machine.transition_predicates[i][j]
            # transition_predicate_path = transition_predicates_dir/f"{i}-{j}.joblib"
            # dump(transition_predicate, transition_predicate_path)

            # ### If it's a swmpo predicate
            # TODO: handle case where predicate is None
            predicate = state_machine.transition_predicates[i][j]
            file_id = f"{i}-{j}.json"
            transition_predicate_path = transition_predicates_dir/file_id
            json_str = predicate_to_str(predicate)
            with open(transition_predicate_path, "wt") as fp:
                _ = fp.write(json_str)

            # Also save a diagram for visualization
            #if transition_predicate is not None:
            #    transition_predicate_plot_path = transition_predicates_dir/f"{i}-{j}.svg"
            #    fig = Figure()
            #    _ = FigureCanvas(fig)
            #    ax = fig.add_subplot()
            #    sklearn.tree.plot_tree(transition_predicate, ax=ax)
            #    fig.savefig(transition_predicate_plot_path)

        # Serialize transition histogram
        transition_histogram_path = output_dir/TRANSITION_HISTOGRAM_PATH
        with open(transition_histogram_path, "wt") as fp:
            json.dump(state_machine.transition_histogram, fp)

        # Plot transition matrix
        transition_histogram_plot_path = output_dir/"transition_histogram.svg"
        fig = Figure()
        _ = FigureCanvas(fig)
        ax = fig.add_subplot()
        M = state_machine.transition_histogram
        _ = ax.imshow(M)
        for i in range(len(M)):
            for j in range(len(M[i])):
                _ = ax.text(
                    j,
                    i,
                    str(state_machine.transition_histogram[i][j]),
                    ha="center", va="center", color="w",
                )
        ticks = list(range(len(M)))
        _ = ax.set_xticks(ticks, labels=[str(t) for t in ticks])
        _ = ax.set_yticks(ticks, labels=[str(t) for t in ticks])
        _ = ax.set_title("Transition histogram")
        fig.savefig(transition_histogram_plot_path)

        _ = shutil.make_archive(
            str(output_zip_path.with_suffix("")),
            'zip',
            output_dir
        )


def deserialize_state_machine(
        zip_path: Path,
        ) -> StateMachine:
    """Load the state machine in the given ZIP file written by
    `swmpo.state_machine.serialize_state_machine`."""
    with tempfile.TemporaryDirectory() as tmpdirname:
        output_dir = Path(tmpdirname)

        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(output_dir)

        architecture_path = output_dir/LOCAL_MODELS_ARCHITECTURE
        with open(architecture_path, "rt") as fp:
            architecture = json.load(fp)
            local_models_hidden_sizes = list[int](
                architecture["local_models_hidden_sizes"]
            )
            input_size = int(architecture["local_models_input_size"])
            output_size = int(architecture["local_models_output_size"])

        # Load local models
        local_model_paths = (output_dir/LOCAL_MODELS_DIR).glob("*.zip")
        sorted_local_model_paths = sorted(
            local_model_paths,
            key=lambda path: int(path.stem),
        )
        local_models = list[WorldModel]()
        for local_model_path in sorted_local_model_paths:
            local_model = deserialize_model(local_model_path)
            local_models.append(local_model)

        # Load transition predicates
        state_indices = list(range(len(local_models)))
        transition_predicates_dir = output_dir/TRANSITION_PREDICATES_DIR
        transition_predicates = defaultdict[int, dict[int, Predicate]](dict)
        for i, j in product(state_indices, state_indices):
            # TODO: put a JSON index file with the (i, j) -> path mapping
            #transition_predicate_path = transition_predicates_dir/f"{i}-{j}.joblib"
            file_id = f"{i}-{j}.json"
            transition_predicate_path = transition_predicates_dir/file_id

            error_message = f"'{transition_predicate_path}' not found!"
            assert transition_predicate_path.exists(), error_message

            #predicate = load(transition_predicate_path)
            with open(transition_predicate_path, "rt") as fp:
                predicate_str = fp.read()
            predicate = str_to_predicate(predicate_str)
            transition_predicates[i][j] = predicate

        transition_predicates = [
            [
                transition_predicates[i][j]
                for j in state_indices
            ]
            for i in state_indices
        ]

        # Load transition histogram
        transition_histogram_path = output_dir/TRANSITION_HISTOGRAM_PATH
        with open(transition_histogram_path, "rt") as fp:
            transition_histogram = list[list[int]](json.load(fp))

    state_machine = StateMachine(
        local_models=local_models,
        transition_predicates=transition_predicates,
        transition_histogram=transition_histogram,
        local_models_hidden_sizes=local_models_hidden_sizes,
        local_models_input_size=input_size,
        local_models_output_size=output_size,
    )
    return state_machine


def get_partition_induced_state_machine(
    partition: list[StatePartitionItem],
    predicate_hyperparameters: dict[str, float | int | str],
    seed: str,
) -> StateMachine:
    """Build the state machine induced by the given state partition: reuse
    each partition item's local model as a node and fit transition
    predicates between the subsets of the partition."""
    _random = random.Random(seed)

    # Characterize the transition predicates between the sets of the partition.
    subsets = [
        item.subset
        for item in partition
    ]
    transition_predicates = get_transition_predicates(
        partition=subsets,
        predicate_hyperparameters=predicate_hyperparameters,
        seed=str(_random.random()),
    )

    # Assemble state machine
    local_models = [
        item.local_model
        for item in partition
    ]
    all_transitions = [
        transition
        for subset in subsets
        for transition in subset
    ]
    input_size, output_size = get_input_output_size(all_transitions[0])
    all_hidden_sizes = [tuple(item.hidden_sizes) for item in partition]
    error_message = "Local models have different hidden sizes!"
    assert len(set(all_hidden_sizes)) == 1, error_message
    assert partition
    hidden_sizes = partition[0].hidden_sizes
    state_machine = StateMachine(
        local_models=local_models,
        transition_predicates=transition_predicates.transition_predicates,
        transition_histogram=transition_predicates.transition_histogram,
        local_models_hidden_sizes=hidden_sizes,
        local_models_input_size=input_size,
        local_models_output_size=output_size,
    )
    return state_machine


def get_state_machine_errors(
        state_machine: StateMachine,
        episode: list[Transition],
        initial_state: int,
        dt: float,
        ) -> list[float]:
    """Return the one-step prediction error of the state machine for each
    transition in the episode, starting from `initial_state`."""
    current_node = initial_state
    errors = list[float]()
    for transition in episode:
        predicted_next_state, next_node = state_machine_model(
            state_machine=state_machine,
            state=transition.source_state,
            action=transition.action,
            current_node=current_node,
            dt=dt,
        )

        # Log error
        error = float(
            (predicted_next_state - transition.next_state).norm().item()
        )
        errors.append(error)

        # Transition
        current_node = next_node
    return errors

LOCAL_MODELS_ARCHITECTURE = 'local_models_structure.json'
LOCAL_MODELS_DIR = 'local_models'
TRANSITION_PREDICATES_DIR = 'transition_predicates'
TRANSITION_HISTOGRAM_PATH = 'transition_histogram.json'
NUM_CONSECUTIVE = 0
@dataclass
class StateMachine:
State machine with one local world model per node.

  • local_models[i]: world model used to predict the next state while in node i.
  • transition_predicates[i][j]: predicate that must hold (robustness value > 0) for the machine to move from node i to node j.
  • transition_histogram[i][j]: how many times transition (i, j) was taken in the training data. Only for logging purposes.
  • local_models_hidden_sizes, local_models_input_size, local_models_output_size: architecture of the local models, stored for (de)serialization.
StateMachine( local_models: list[swmpo.world_models.world_model.WorldModel], transition_predicates: list[list[swmpo.predicates.And | swmpo.predicates.Or | swmpo.predicates.LessThan | swmpo.predicates.GreaterThan | bool]], transition_histogram: list[list[int]], local_models_hidden_sizes: list[int], local_models_input_size: int, local_models_output_size: int)
local_models: list[swmpo.world_models.world_model.WorldModel]
transition_predicates: list[list[swmpo.predicates.And | swmpo.predicates.Or | swmpo.predicates.LessThan | swmpo.predicates.GreaterThan | bool]]
transition_histogram: list[list[int]]
local_models_hidden_sizes: list[int]
local_models_input_size: int
local_models_output_size: int
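
For orientation, a minimal indexing sketch; `sm` here is assumed to be an already-built StateMachine (for example one returned by deserialize_state_machine):

num_nodes = len(sm.local_models)               # one local world model per node
guard_0_to_1 = sm.transition_predicates[0][1]  # predicate guarding transition 0 -> 1
count_0_to_1 = sm.transition_histogram[0][1]   # how often 0 -> 1 was taken in training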
@dataclass
class StateMachineOptimizationResult:
StateMachineOptimizationResult( state_machine: StateMachine, partition_loss_log: list[float])
state_machine: StateMachine
partition_loss_log: list[float]
def state_machine_model( state_machine: StateMachine, state: torch.Tensor, action: torch.Tensor, current_node: int, dt: float) -> tuple[torch.Tensor, int]:
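
Performs one step of the machine: the current node's local model predicts the next state, and the outgoing transition predicates pick the next node (falling back to the current node if none accepts). A usage sketch; `sm` is an assumed existing StateMachine, and the tensor shapes and `dt` below are placeholders:

import torch

state = torch.zeros(4)   # placeholder: use the environment's state size
action = torch.zeros(1)  # placeholder: use the environment's action size

predicted_next_state, next_node = state_machine_model(
    state_machine=sm,
    state=state,
    action=action,
    current_node=0,
    dt=0.05,
)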
def get_visited_states( state_machine: StateMachine, initial_state: int, episode: list[swmpo.transition.Transition], dt: float) -> list[int]:

Return the sequence of states that the state machine traversed when processing the list of state-action tuples.

In case of ties, we arbitrarily choose the first accepted transition.
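
A usage sketch; `sm` and `episode` (a list of swmpo.transition.Transition) are assumed to exist, and `dt` is a placeholder:

visited = get_visited_states(
    state_machine=sm,
    initial_state=0,
    episode=episode,
    dt=0.05,
)
# `visited` has len(episode) + 1 entries: the initial node plus one node per transition.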

def get_local_model_errors( state_machine: StateMachine, episode: list[swmpo.transition.Transition], dt: float) -> list[list[float]]:

Return the error of every local model for each transition in the episode. The returned value is a list of len(episode) lists, each of size len(state_machine.local_models).
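
A sketch under the same assumptions as above (`sm`, `episode`, placeholder `dt`):

errors = get_local_model_errors(state_machine=sm, episode=episode, dt=0.05)
# errors[t][k] is the raw error of local model k on the t-th transition.
best_model_per_step = [min(range(len(row)), key=row.__getitem__) for row in errors]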

def serialize_state_machine( state_machine: StateMachine, output_zip_path: pathlib.Path):

Serialize the state machine to the given ZIP file. The parent directory of output_zip_path is assumed to exist.
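
A sketch of writing a machine to disk; `sm` is an assumed existing StateMachine and the output path is a placeholder:

from pathlib import Path

serialize_state_machine(sm, output_zip_path=Path("state_machine.zip"))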

def deserialize_state_machine(zip_path: pathlib.Path) -> StateMachine:

Load the state machine in the given ZIP file written by swmpo.state_machine.serialize_state_machine.
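
Loading a previously serialized machine (sketch; the path is a placeholder matching the serialization example above):

from pathlib import Path

sm = deserialize_state_machine(Path("state_machine.zip"))
assert len(sm.local_models) == len(sm.transition_predicates)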

def get_partition_induced_state_machine( partition: list[swmpo.partition.StatePartitionItem], predicate_hyperparameters: dict[str, float | int | str], seed: str) -> StateMachine:
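
Builds the state machine induced by a state partition: the partition items supply the local models, and transition predicates are fit between the partition's subsets. A sketch; `partition` (a list of swmpo.partition.StatePartitionItem) and `predicate_hyperparameters` are assumed to come from the rest of the swmpo pipeline:

sm = get_partition_induced_state_machine(
    partition=partition,
    predicate_hyperparameters=predicate_hyperparameters,  # project-specific predicate settings
    seed="42",
)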
def get_state_machine_errors( state_machine: StateMachine, episode: list[swmpo.transition.Transition], initial_state: int, dt: float) -> list[float]:

Return the one-step prediction error of the state machine for each transition in the episode, starting from initial_state.
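
A sketch, again assuming `sm` and `episode` exist and `dt` is a placeholder:

errors = get_state_machine_errors(
    state_machine=sm,
    episode=episode,
    initial_state=0,
    dt=0.05,
)
mean_error = sum(errors) / len(errors)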