swmpo.plotting

Plotting utilities.

  1"""Plotting utilities."""
  2from swmpo.state_machine import StateMachine
  3from pathlib import Path
  4from itertools import product
  5import matplotlib as mpl
  6from matplotlib import cm
  7import matplotlib.colors
  8import tempfile
  9import networkx as nx
 10import ffmpeg
 11import numpy as np
 12import json
 13from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
 14from matplotlib.figure import Figure
 15from swmpo.transition import Transition
 16from swmpo.state_machine import get_local_model_errors
 17from swmpo.state_machine import get_state_machine_errors
 18from swmpo.state_machine import get_visited_states
 19from swmpo.state_machine import state_machine_model
 20from swmpo.sequence_distance import _get_best_permutation
 21import random
 22
 23
 24def plot_state_machine(
 25    state_machine: StateMachine,
 26    active_state: int,
 27    output_path: Path,
 28):
 29    """Plot the state machine diagram."""
 30    G = nx.DiGraph()
 31    state_indices = list(range(len(state_machine.local_models)))
 32    for i in state_indices:
 33        G.add_node(i)
 34    for i, j in product(state_indices, state_indices):
 35        # Check that the transition is possible
 36        if state_machine.transition_histogram[i][j] > 0:
 37            G.add_edge(i, j)
 38
 39    # Create matplotlib figure
 40    fig = Figure()
 41    _ = FigureCanvas(fig)
 42    ax = fig.add_subplot()
 43
 44    # Decide node colors
 45    node_color = [
 46        "blue" if state == active_state else "black"
 47        for state in state_indices
 48    ]
 49
 50    # Plot
 51    nx.draw_circular(G, ax=ax, node_color=node_color)
 52
 53    # Save figure
 54    fig.suptitle('Active state')
 55    fig.savefig(output_path)
 56
 57
 58def plot_animation(
 59    state_machine: StateMachine,
 60    visited_states: list[int],
 61    output_path: Path,
 62    fps: int,
 63):
 64    """Plot an animation of the given state machine."""
 65    with tempfile.TemporaryDirectory() as tdir:
 66        frame_dir = Path(tdir)
 67        for i, state in enumerate(visited_states):
 68            state_path = frame_dir/(f"{i}.png").rjust(10, "0")
 69            plot_state_machine(
 70                state_machine=state_machine,
 71                active_state=state,
 72                output_path=state_path,
 73            )
 74        (
 75            ffmpeg
 76            .input(frame_dir/"*.png", pattern_type="glob", framerate=fps)
 77            .output(str(output_path))
 78            .run(quiet=True)
 79        )
 80
 81
 82def plot_state_machine_errors_diagram(
 83    state_machine: StateMachine,
 84    local_model_errors: list[float],
 85    max_error: float,
 86    output_path: Path,
 87):
 88    """Plot the state machine diagram coloring the nodes according to
 89    their prediction error."""
 90    G = nx.DiGraph()
 91    state_indices = list(range(len(state_machine.local_models)))
 92    for i in state_indices:
 93        G.add_node(i)
 94    for i, j in product(state_indices, state_indices):
 95        # Check that the transition is possible
 96        if state_machine.transition_histogram[i][j] > 0:
 97            G.add_edge(i, j)
 98
 99    # Create matplotlib figure
100    fig = Figure()
101    _ = FigureCanvas(fig)
102    ax = fig.add_subplot()
103
104    # Decide node colors
105    cmap = mpl.colormaps['viridis']
106    norm = matplotlib.colors.Normalize(vmin=0.0, vmax=max_error)
107    mappable = cm.ScalarMappable(norm=norm, cmap=cmap)
108    node_color = [
109        mappable.to_rgba(np.array([error]))[0]
110        for error in local_model_errors
111    ]
112
113    # Plot
114    nx.draw_circular(G, ax=ax, node_color=node_color)
115
116    # Add colorbar
117    fig.colorbar(mappable, ax=ax)
118
119    # Save figure
120    fig.suptitle('Node errors')
121    fig.savefig(output_path)
122
123
def plot_errors_animation(
    state_machine: StateMachine,
    local_model_errors: list[list[float]],
    output_path: Path,
    fps: int,
):
    """Render a video of the per-node prediction errors over time.

    Frame t is drawn with `plot_state_machine_errors_diagram`, using
    `local_model_errors[t]` as the node errors. All frames share one color
    scale whose maximum is the largest error in the whole sequence. The
    frames are stitched by ffmpeg into `output_path` at `fps` frames per
    second.

    Raises:
        ValueError: if `local_model_errors` is empty.
    """
    if not local_model_errors:
        raise ValueError("local_model_errors must contain at least one frame")
    # Shared normalization so colors are comparable across frames.
    max_error = max(max(errors) for errors in local_model_errors)
    with tempfile.TemporaryDirectory() as tdir:
        frame_dir = Path(tdir)
        for i, errors in enumerate(local_model_errors):
            # Zero-pad (not space-pad) so frame file names contain no
            # whitespace and lexicographic order matches frame order.
            state_path = frame_dir / f"{i}.png".rjust(10, "0")
            plot_state_machine_errors_diagram(
                state_machine=state_machine,
                local_model_errors=errors,
                max_error=max_error,
                output_path=state_path,
            )
        (
            ffmpeg
            .input(str(frame_dir / "*.png"), pattern_type="glob", framerate=fps)
            .output(str(output_path))
            .run(quiet=True)
        )
148
149
def get_starts_widths(
    visited_states: list[int],
) -> tuple[list[int], list[int], list[int]]:
    """Run-length encode a state sequence for matplotlib's `barh`.

    Each maximal run of equal consecutive states becomes one bar. Index i
    occupies the interval [i, i + 1), so the bars cover [0, len(visited_states)).

    Args:
        visited_states: Sequence of state indices, one per time step.

    Returns:
        A tuple `(starts, widths, colors)` with one entry per run:
        - `starts`: the index where the run begins;
        - `widths`: the run length;
        - `colors`: the state of the run (used by callers as a palette index).
        All three lists are empty for an empty input.
    """
    from itertools import groupby

    starts: list[int] = []
    widths: list[int] = []
    colors: list[int] = []
    position = 0
    for state, run in groupby(visited_states):
        run_length = sum(1 for _ in run)
        starts.append(position)
        widths.append(run_length)
        colors.append(state)
        position += run_length
    return starts, widths, colors
176
177
def get_partition_colors() -> list[tuple[float, float, float, float]]:
    """Return the palette used to color partitions / local models.

    Every entry of matplotlib's 'tab20b' colormap is sampled as RGBA. The
    list is then shuffled with an RNG seeded by the constant "seed", so the
    order is the same on every call.

    NOTE: `.tolist()` yields inner lists rather than the annotated tuples.
    """
    palette = mpl.colormaps['tab20b']
    rgba = palette(list(range(palette.N))).tolist()
    random.Random("seed").shuffle(rgba)
    return rgba
185
186
def plot_state_machine_errors(
    state_machine: StateMachine,
    episode: list[Transition],
    initial_state: int,
    dt: float,
    ground_truth_visited_states: list[int] | None,
    output_dir: Path,
):
    """Plot and save how well the state machine predicts `episode`.

    Writes two files into `output_dir`:

    * ``state_machine_errors.svg``:
      - top panel: the per-step error of every local model, plus the error
        of the full state machine (dotted black);
      - bottom panel: horizontal bars for the states visited by the state
        machine, the per-step minimum-error ("induced") local model and,
        if given and non-empty, the ground-truth states.
    * ``data.json``: the raw errors and visited-state sequences.

    When `ground_truth_visited_states` is not None, local-model colors are
    permuted with `_get_best_permutation` so that models line up with the
    ground-truth palette.
    """
    # Evaluate the state machine on the episode.
    per_model_errors = get_local_model_errors(
        state_machine=state_machine,
        episode=episode,
        dt=dt,
    )
    fsm_errors = get_state_machine_errors(
        state_machine=state_machine,
        episode=episode,
        initial_state=initial_state,
        dt=dt,
    )
    fsm_states = get_visited_states(
        state_machine=state_machine,
        initial_state=initial_state,
        episode=episode,
        dt=dt,
    )
    n_models = len(state_machine.local_models)

    # Pyplot-free figure backed by an Agg canvas.
    fig = Figure()
    _ = FigureCanvas(fig)

    palette = get_partition_colors()

    # Re-map local-model colors onto the ground-truth palette when possible.
    if ground_truth_visited_states is None:
        model_palette = list(palette)
    else:
        mapping = _get_best_permutation(
            sequence=fsm_states,
            ground_truth=ground_truth_visited_states,
            indices=list(range(n_models)),
        )
        model_palette = [
            palette[mapping[k]] if k in mapping.keys() else palette[k]
            for k in range(len(palette))
        ]

    # Top panel: prediction errors.
    error_ax = fig.add_subplot(2, 1, 1)
    steps = list(range(len(per_model_errors)))
    for model in range(n_models):
        error_ax.plot(
            steps,
            [per_model_errors[t][model] for t in steps],
            label=f"Local model {model}",
            color=model_palette[model],
        )
    fsm_steps = list(range(len(fsm_errors)))
    error_ax.plot(
        fsm_steps,
        fsm_errors,
        color='black',
        linestyle=':',
        label="Full state machine",
    )
    horizon = len(fsm_steps)
    error_ax.set_xlim(left=0, right=horizon)

    # Bottom panel: states actually visited by the state machine.
    bar_ax = fig.add_subplot(2, 1, 2)
    starts, widths, run_states = get_starts_widths(fsm_states)
    bar_ax.barh(
        ["FSM states"],
        widths,
        left=starts,
        height=0.5,
        color=[model_palette[s] for s in run_states],
    )
    bar_ax.set_xlim(left=0, right=horizon)

    # Bottom panel: the "partition" induced by picking the lowest-error model.
    induced_states = [
        min(range(len(losses)), key=lambda m: losses[m])
        for losses in per_model_errors
    ]
    starts, widths, run_states = get_starts_widths(induced_states)
    bar_ax.barh(
        ["Minimum-loss induced states"],
        widths,
        left=starts,
        height=0.5,
        color=[model_palette[s] for s in run_states],
    )
    bar_ax.set_xlim(left=0, right=horizon)

    # Bottom panel: ground truth, colored with the unpermuted palette.
    if ground_truth_visited_states is not None and len(ground_truth_visited_states) > 0:
        starts, widths, run_states = get_starts_widths(ground_truth_visited_states)
        bar_ax.barh(
            ["Ground truth states"],
            widths,
            left=starts,
            height=1.5,
            color=[palette[s] for s in run_states],
        )

    fig.legend()
    fig.suptitle('State machine prediction errors')
    fig.tight_layout()
    fig.savefig(output_dir / "state_machine_errors.svg")

    payload = {
        "local_model_errors": per_model_errors,
        "state_machine_errors": fsm_errors,
        "visited_states": fsm_states,
        "ground_truth_visited_states": ground_truth_visited_states,
    }
    with open(output_dir / "data.json", "wt") as fp:
        json.dump(payload, fp, indent=2)
301
302
def get_state_machine_forecasting_errors(
    state_machine: StateMachine,
    episode: list[Transition],
    initial_state: int,
    dt: float,
) -> list[float]:
    """Return the open-loop forecasting errors of the state machine.

    The simulation starts from `episode[0].source_state` in mode
    `initial_state`. It is then rolled forward with `state_machine_model`,
    driven only by the recorded actions: the predicted state is fed back
    and the real state is never re-injected. Before each step, the norm of
    (recorded `source_state` - predicted state) is recorded.

    Returns:
        One error per transition in `episode`. The first entry compares
        `episode[0].source_state` with itself. Returns an empty list for an
        empty episode.
    """
    if not episode:
        return []

    errors: list[float] = []
    state = episode[0].source_state
    mode = initial_state

    for transition in episode:
        # Error of the current prediction against the recorded state.
        step_error = (transition.source_state - state).norm()
        errors.append(step_error.item())

        # Advance the simulation using the predicted state (open loop).
        state, mode = state_machine_model(
            state_machine=state_machine,
            state=state,
            action=transition.action,
            current_node=mode,
            dt=dt,
        )

    return errors
334
335
def plot_state_machine_forecasting_errors(
    state_machine: StateMachine,
    episode: list[Transition],
    initial_state: int,
    dt: float,
    ground_truth_visited_states: list[int] | None,
    output_path: Path,
):
    """Plot the long-horizon forecasting error of the state machine.

    The figure has two panels and is saved to `output_path`:

    * top: the open-loop forecasting error returned by
      `get_state_machine_forecasting_errors` (solid black line);
    * bottom: horizontal bars for the states visited by the state machine,
      the per-step minimum-error ("induced") local model and, if
      `ground_truth_visited_states` is given and non-empty, the
      ground-truth states.

    When `ground_truth_visited_states` is not None, local-model colors are
    permuted with `_get_best_permutation` so that models line up with the
    ground-truth palette.
    """
    # Run state machine
    local_model_errors = get_local_model_errors(
        state_machine=state_machine,
        episode=episode,
        dt=dt,
    )
    forecasting_errors = get_state_machine_forecasting_errors(
        state_machine=state_machine,
        episode=episode,
        initial_state=initial_state,
        dt=dt,
    )
    visited_states = get_visited_states(
        state_machine=state_machine,
        initial_state=initial_state,
        episode=episode,
        dt=dt,
    )

    # Create matplotlib figure
    fig = Figure()
    _ = FigureCanvas(fig)

    # Choose model colors
    partition_colors = get_partition_colors()

    # Permute local models for easy visualization
    if ground_truth_visited_states is not None:
        # Get mode permutation
        perm = _get_best_permutation(
            sequence=visited_states,
            ground_truth=ground_truth_visited_states,
            indices=list(range(len(state_machine.local_models))),
        )
        local_model_colors = [
            partition_colors[perm[i]] if i in perm.keys() else partition_colors[i]
            for i in range(len(partition_colors))
        ]
    else:
        local_model_colors = list(partition_colors)

    # Plot forecasting errors. The x-axis is derived from the plotted series
    # itself (not from `local_model_errors`), so x and y always have the
    # same length.
    ax1 = fig.add_subplot(2, 1, 1)
    x = list(range(len(forecasting_errors)))
    ax1.plot(
        x,
        forecasting_errors,
        color='black',
        linestyle='-',
        label="Full state machine",
    )
    ax1.set_xlim(left=0, right=len(x))

    # Plot active states
    ax2 = fig.add_subplot(2, 1, 2)
    labels = ["FSM states"]
    starts, widths, colors = get_starts_widths(visited_states)
    color = [local_model_colors[i] for i in colors]
    ax2.barh(labels, widths, left=starts, height=0.5, color=color)
    ax2.set_xlim(left=0, right=len(x))

    # Plot the "partition" induced by picking the lowest-error model per step
    labels = ["Minimum-loss induced states"]
    induced_visited_states = list()
    for step_losses in local_model_errors:
        indices = list(range(len(step_losses)))
        min_loss_i = min(indices, key=lambda i: step_losses[i])
        induced_visited_states.append(min_loss_i)
    starts, widths, colors = get_starts_widths(induced_visited_states)
    color = [local_model_colors[i] for i in colors]
    ax2.barh(labels, widths, left=starts, height=0.5, color=color)
    ax2.set_xlim(left=0, right=len(x))

    # Plot ground truth (colored with the unpermuted palette)
    if ground_truth_visited_states is not None and len(ground_truth_visited_states) > 0:
        labels = ["Ground truth states"]
        starts, widths, colors = get_starts_widths(ground_truth_visited_states)
        color = [partition_colors[i] for i in colors]
        ax2.barh(labels, widths, left=starts, height=1.5, color=color)

    # Save figure
    fig.legend()
    fig.suptitle('Long-horizon forecasting performance')
    fig.tight_layout()
    fig.savefig(output_path)
def plot_state_machine( state_machine: swmpo.state_machine.StateMachine, active_state: int, output_path: pathlib.Path):
def plot_state_machine(
    state_machine: StateMachine,
    active_state: int,
    output_path: Path,
):
    """Draw the state machine graph with one node highlighted.

    Each local model of `state_machine` becomes a node. A directed edge
    i -> j is added whenever `state_machine.transition_histogram[i][j]` is
    positive. Nodes are laid out on a circle: `active_state` is drawn in
    blue and every other node in black. The figure, titled 'Active state',
    is written to `output_path`.
    """
    n_states = len(state_machine.local_models)

    graph = nx.DiGraph()
    graph.add_nodes_from(range(n_states))
    # Keep only transitions that were observed at least once.
    graph.add_edges_from(
        (src, dst)
        for src, dst in product(range(n_states), repeat=2)
        if state_machine.transition_histogram[src][dst] > 0
    )

    # Pyplot-free figure backed by an Agg canvas.
    figure = Figure()
    _ = FigureCanvas(figure)
    axes = figure.add_subplot()

    palette = ["blue" if node == active_state else "black" for node in range(n_states)]
    nx.draw_circular(graph, ax=axes, node_color=palette)

    figure.suptitle('Active state')
    figure.savefig(output_path)

Plot the state machine diagram.

def plot_animation( state_machine: swmpo.state_machine.StateMachine, visited_states: list[int], output_path: pathlib.Path, fps: int):
def plot_animation(
    state_machine: StateMachine,
    visited_states: list[int],
    output_path: Path,
    fps: int,
):
    """Render a video in which the highlighted node follows `visited_states`.

    One PNG frame per entry of `visited_states` is drawn with
    `plot_state_machine` into a temporary directory. ffmpeg then stitches
    the frames, matched by a `*.png` glob, into `output_path` at `fps`
    frames per second.
    """
    with tempfile.TemporaryDirectory() as scratch:
        frames = Path(scratch)
        for frame_index, active in enumerate(visited_states):
            # Zero-padded names so lexicographic order matches frame order.
            frame_name = f"{frame_index}.png".rjust(10, "0")
            plot_state_machine(
                state_machine=state_machine,
                active_state=active,
                output_path=frames / frame_name,
            )
        stream = ffmpeg.input(frames / "*.png", pattern_type="glob", framerate=fps)
        stream.output(str(output_path)).run(quiet=True)

Plot an animation of the given state machine.

def plot_state_machine_errors_diagram( state_machine: swmpo.state_machine.StateMachine, local_model_errors: list[float], max_error: float, output_path: pathlib.Path):
 83def plot_state_machine_errors_diagram(
 84    state_machine: StateMachine,
 85    local_model_errors: list[float],
 86    max_error: float,
 87    output_path: Path,
 88):
 89    """Plot the state machine diagram coloring the nodes according to
 90    their prediction error."""
 91    G = nx.DiGraph()
 92    state_indices = list(range(len(state_machine.local_models)))
 93    for i in state_indices:
 94        G.add_node(i)
 95    for i, j in product(state_indices, state_indices):
 96        # Check that the transition is possible
 97        if state_machine.transition_histogram[i][j] > 0:
 98            G.add_edge(i, j)
 99
100    # Create matplotlib figure
101    fig = Figure()
102    _ = FigureCanvas(fig)
103    ax = fig.add_subplot()
104
105    # Decide node colors
106    cmap = mpl.colormaps['viridis']
107    norm = matplotlib.colors.Normalize(vmin=0.0, vmax=max_error)
108    mappable = cm.ScalarMappable(norm=norm, cmap=cmap)
109    node_color = [
110        mappable.to_rgba(np.array([error]))[0]
111        for error in local_model_errors
112    ]
113
114    # Plot
115    nx.draw_circular(G, ax=ax, node_color=node_color)
116
117    # Add colorbar
118    fig.colorbar(mappable, ax=ax)
119
120    # Save figure
121    fig.suptitle('Node errors')
122    fig.savefig(output_path)

Plot the state machine diagram coloring the nodes according to their prediction error.

def plot_errors_animation( state_machine: swmpo.state_machine.StateMachine, local_model_errors: list[list[float]], output_path: pathlib.Path, fps: int):
def plot_errors_animation(
    state_machine: StateMachine,
    local_model_errors: list[list[float]],
    output_path: Path,
    fps: int,
):
    """Render a video of the per-node prediction errors over time.

    Frame t is drawn with `plot_state_machine_errors_diagram`, using
    `local_model_errors[t]` as the node errors. All frames share one color
    scale whose maximum is the largest error in the whole sequence. The
    frames are stitched by ffmpeg into `output_path` at `fps` frames per
    second.

    Raises:
        ValueError: if `local_model_errors` is empty.
    """
    if not local_model_errors:
        raise ValueError("local_model_errors must contain at least one frame")
    # Shared normalization so colors are comparable across frames.
    max_error = max(max(errors) for errors in local_model_errors)
    with tempfile.TemporaryDirectory() as tdir:
        frame_dir = Path(tdir)
        for i, errors in enumerate(local_model_errors):
            # Zero-pad (not space-pad) so frame file names contain no
            # whitespace and lexicographic order matches frame order.
            state_path = frame_dir / f"{i}.png".rjust(10, "0")
            plot_state_machine_errors_diagram(
                state_machine=state_machine,
                local_model_errors=errors,
                max_error=max_error,
                output_path=state_path,
            )
        (
            ffmpeg
            .input(str(frame_dir / "*.png"), pattern_type="glob", framerate=fps)
            .output(str(output_path))
            .run(quiet=True)
        )

Plot an animation of the state machine whose nodes are colored by each local model's prediction error at every step.

def get_starts_widths(visited_states: list[int]) -> tuple[list[int], list[int], list[int]]:
def get_starts_widths(
    visited_states: list[int],
) -> tuple[list[int], list[int], list[int]]:
    """Run-length encode a state sequence for matplotlib's `barh`.

    Each maximal run of equal consecutive states becomes one bar. Index i
    occupies the interval [i, i + 1), so the bars cover [0, len(visited_states)).

    Args:
        visited_states: Sequence of state indices, one per time step.

    Returns:
        A tuple `(starts, widths, colors)` with one entry per run:
        - `starts`: the index where the run begins;
        - `widths`: the run length;
        - `colors`: the state of the run (used by callers as a palette index).
        All three lists are empty for an empty input.
    """
    from itertools import groupby

    starts: list[int] = []
    widths: list[int] = []
    colors: list[int] = []
    position = 0
    for state, run in groupby(visited_states):
        run_length = sum(1 for _ in run)
        starts.append(position)
        widths.append(run_length)
        colors.append(state)
        position += run_length
    return starts, widths, colors

Helper function to compute the starts and widths for matplotlib's barh function, for visualizing sequences of visited states.

def get_partition_colors() -> list[tuple[float, float, float, float]]:
def get_partition_colors() -> list[tuple[float, float, float, float]]:
    """Return the palette used to color partitions / local models.

    Every entry of matplotlib's 'tab20b' colormap is sampled as RGBA. The
    list is then shuffled with an RNG seeded by the constant "seed", so the
    order is the same on every call.

    NOTE: `.tolist()` yields inner lists rather than the annotated tuples.
    """
    palette = mpl.colormaps['tab20b']
    rgba = palette(list(range(palette.N))).tolist()
    random.Random("seed").shuffle(rgba)
    return rgba

Get colors with which partitions can be plotted.

def plot_state_machine_errors( state_machine: swmpo.state_machine.StateMachine, episode: list[swmpo.transition.Transition], initial_state: int, dt: float, ground_truth_visited_states: list[int] | None, output_dir: pathlib.Path):
def plot_state_machine_errors(
    state_machine: StateMachine,
    episode: list[Transition],
    initial_state: int,
    dt: float,
    ground_truth_visited_states: list[int] | None,
    output_dir: Path,
):
    """Plot and save how well the state machine predicts `episode`.

    Writes two files into `output_dir`:

    * ``state_machine_errors.svg``:
      - top panel: the per-step error of every local model, plus the error
        of the full state machine (dotted black);
      - bottom panel: horizontal bars for the states visited by the state
        machine, the per-step minimum-error ("induced") local model and,
        if given and non-empty, the ground-truth states.
    * ``data.json``: the raw errors and visited-state sequences.

    When `ground_truth_visited_states` is not None, local-model colors are
    permuted with `_get_best_permutation` so that models line up with the
    ground-truth palette.
    """
    # Evaluate the state machine on the episode.
    per_model_errors = get_local_model_errors(
        state_machine=state_machine,
        episode=episode,
        dt=dt,
    )
    fsm_errors = get_state_machine_errors(
        state_machine=state_machine,
        episode=episode,
        initial_state=initial_state,
        dt=dt,
    )
    fsm_states = get_visited_states(
        state_machine=state_machine,
        initial_state=initial_state,
        episode=episode,
        dt=dt,
    )
    n_models = len(state_machine.local_models)

    # Pyplot-free figure backed by an Agg canvas.
    fig = Figure()
    _ = FigureCanvas(fig)

    palette = get_partition_colors()

    # Re-map local-model colors onto the ground-truth palette when possible.
    if ground_truth_visited_states is None:
        model_palette = list(palette)
    else:
        mapping = _get_best_permutation(
            sequence=fsm_states,
            ground_truth=ground_truth_visited_states,
            indices=list(range(n_models)),
        )
        model_palette = [
            palette[mapping[k]] if k in mapping.keys() else palette[k]
            for k in range(len(palette))
        ]

    # Top panel: prediction errors.
    error_ax = fig.add_subplot(2, 1, 1)
    steps = list(range(len(per_model_errors)))
    for model in range(n_models):
        error_ax.plot(
            steps,
            [per_model_errors[t][model] for t in steps],
            label=f"Local model {model}",
            color=model_palette[model],
        )
    fsm_steps = list(range(len(fsm_errors)))
    error_ax.plot(
        fsm_steps,
        fsm_errors,
        color='black',
        linestyle=':',
        label="Full state machine",
    )
    horizon = len(fsm_steps)
    error_ax.set_xlim(left=0, right=horizon)

    # Bottom panel: states actually visited by the state machine.
    bar_ax = fig.add_subplot(2, 1, 2)
    starts, widths, run_states = get_starts_widths(fsm_states)
    bar_ax.barh(
        ["FSM states"],
        widths,
        left=starts,
        height=0.5,
        color=[model_palette[s] for s in run_states],
    )
    bar_ax.set_xlim(left=0, right=horizon)

    # Bottom panel: the "partition" induced by picking the lowest-error model.
    induced_states = [
        min(range(len(losses)), key=lambda m: losses[m])
        for losses in per_model_errors
    ]
    starts, widths, run_states = get_starts_widths(induced_states)
    bar_ax.barh(
        ["Minimum-loss induced states"],
        widths,
        left=starts,
        height=0.5,
        color=[model_palette[s] for s in run_states],
    )
    bar_ax.set_xlim(left=0, right=horizon)

    # Bottom panel: ground truth, colored with the unpermuted palette.
    if ground_truth_visited_states is not None and len(ground_truth_visited_states) > 0:
        starts, widths, run_states = get_starts_widths(ground_truth_visited_states)
        bar_ax.barh(
            ["Ground truth states"],
            widths,
            left=starts,
            height=1.5,
            color=[palette[s] for s in run_states],
        )

    fig.legend()
    fig.suptitle('State machine prediction errors')
    fig.tight_layout()
    fig.savefig(output_dir / "state_machine_errors.svg")

    payload = {
        "local_model_errors": per_model_errors,
        "state_machine_errors": fsm_errors,
        "visited_states": fsm_states,
        "ground_truth_visited_states": ground_truth_visited_states,
    }
    with open(output_dir / "data.json", "wt") as fp:
        json.dump(payload, fp, indent=2)

Plot the error of the state machine for predicting the given episode.

def get_state_machine_forecasting_errors( state_machine: swmpo.state_machine.StateMachine, episode: list[swmpo.transition.Transition], initial_state: int, dt: float) -> list[float]:
def get_state_machine_forecasting_errors(
    state_machine: StateMachine,
    episode: list[Transition],
    initial_state: int,
    dt: float,
) -> list[float]:
    """Return the open-loop forecasting errors of the state machine.

    The simulation starts from `episode[0].source_state` in mode
    `initial_state`. It is then rolled forward with `state_machine_model`,
    driven only by the recorded actions: the predicted state is fed back
    and the real state is never re-injected. Before each step, the norm of
    (recorded `source_state` - predicted state) is recorded.

    Returns:
        One error per transition in `episode`. The first entry compares
        `episode[0].source_state` with itself. Returns an empty list for an
        empty episode.
    """
    if not episode:
        return []

    errors: list[float] = []
    state = episode[0].source_state
    mode = initial_state

    for transition in episode:
        # Error of the current prediction against the recorded state.
        step_error = (transition.source_state - state).norm()
        errors.append(step_error.item())

        # Advance the simulation using the predicted state (open loop).
        state, mode = state_machine_model(
            state_machine=state_machine,
            state=state,
            action=transition.action,
            current_node=mode,
            dt=dt,
        )

    return errors

Return the forecasting errors of the state machine over the given episode (i.e., using the state machine from the initial state to simulate the system and comparing with the ground-truth states).

def plot_state_machine_forecasting_errors( state_machine: swmpo.state_machine.StateMachine, episode: list[swmpo.transition.Transition], initial_state: int, dt: float, ground_truth_visited_states: list[int] | None, output_path: pathlib.Path):
def plot_state_machine_forecasting_errors(
    state_machine: StateMachine,
    episode: list[Transition],
    initial_state: int,
    dt: float,
    ground_truth_visited_states: list[int] | None,
    output_path: Path,
):
    """Plot the long-horizon forecasting error of the state machine.

    The figure has two panels and is saved to `output_path`:

    * top: the open-loop forecasting error returned by
      `get_state_machine_forecasting_errors` (solid black line);
    * bottom: horizontal bars for the states visited by the state machine,
      the per-step minimum-error ("induced") local model and, if
      `ground_truth_visited_states` is given and non-empty, the
      ground-truth states.

    When `ground_truth_visited_states` is not None, local-model colors are
    permuted with `_get_best_permutation` so that models line up with the
    ground-truth palette.
    """
    # Run state machine
    local_model_errors = get_local_model_errors(
        state_machine=state_machine,
        episode=episode,
        dt=dt,
    )
    forecasting_errors = get_state_machine_forecasting_errors(
        state_machine=state_machine,
        episode=episode,
        initial_state=initial_state,
        dt=dt,
    )
    visited_states = get_visited_states(
        state_machine=state_machine,
        initial_state=initial_state,
        episode=episode,
        dt=dt,
    )

    # Create matplotlib figure
    fig = Figure()
    _ = FigureCanvas(fig)

    # Choose model colors
    partition_colors = get_partition_colors()

    # Permute local models for easy visualization
    if ground_truth_visited_states is not None:
        # Get mode permutation
        perm = _get_best_permutation(
            sequence=visited_states,
            ground_truth=ground_truth_visited_states,
            indices=list(range(len(state_machine.local_models))),
        )
        local_model_colors = [
            partition_colors[perm[i]] if i in perm.keys() else partition_colors[i]
            for i in range(len(partition_colors))
        ]
    else:
        local_model_colors = list(partition_colors)

    # Plot forecasting errors. The x-axis is derived from the plotted series
    # itself (not from `local_model_errors`), so x and y always have the
    # same length.
    ax1 = fig.add_subplot(2, 1, 1)
    x = list(range(len(forecasting_errors)))
    ax1.plot(
        x,
        forecasting_errors,
        color='black',
        linestyle='-',
        label="Full state machine",
    )
    ax1.set_xlim(left=0, right=len(x))

    # Plot active states
    ax2 = fig.add_subplot(2, 1, 2)
    labels = ["FSM states"]
    starts, widths, colors = get_starts_widths(visited_states)
    color = [local_model_colors[i] for i in colors]
    ax2.barh(labels, widths, left=starts, height=0.5, color=color)
    ax2.set_xlim(left=0, right=len(x))

    # Plot the "partition" induced by picking the lowest-error model per step
    labels = ["Minimum-loss induced states"]
    induced_visited_states = list()
    for step_losses in local_model_errors:
        indices = list(range(len(step_losses)))
        min_loss_i = min(indices, key=lambda i: step_losses[i])
        induced_visited_states.append(min_loss_i)
    starts, widths, colors = get_starts_widths(induced_visited_states)
    color = [local_model_colors[i] for i in colors]
    ax2.barh(labels, widths, left=starts, height=0.5, color=color)
    ax2.set_xlim(left=0, right=len(x))

    # Plot ground truth (colored with the unpermuted palette)
    if ground_truth_visited_states is not None and len(ground_truth_visited_states) > 0:
        labels = ["Ground truth states"]
        starts, widths, colors = get_starts_widths(ground_truth_visited_states)
        color = [partition_colors[i] for i in colors]
        ax2.barh(labels, widths, left=starts, height=1.5, color=color)

    # Save figure
    fig.legend()
    fig.suptitle('Long-horizon forecasting performance')
    fig.tight_layout()
    fig.savefig(output_path)

Plot the long-horizon (open-loop) forecasting error of the state machine over the given episode.