swmpo.plotting
Plotting utilities.
"""Plotting utilities."""
from swmpo.state_machine import StateMachine
from pathlib import Path
from itertools import product
import matplotlib as mpl
from matplotlib import cm
import matplotlib.colors
import tempfile
import networkx as nx
import ffmpeg
import numpy as np
import json
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from matplotlib.figure import Figure
from swmpo.transition import Transition
from swmpo.state_machine import get_local_model_errors
from swmpo.state_machine import get_state_machine_errors
from swmpo.state_machine import get_visited_states
from swmpo.state_machine import state_machine_model
from swmpo.sequence_distance import _get_best_permutation
import random


def plot_state_machine(
    state_machine: StateMachine,
    active_state: int,
    output_path: Path,
):
    """Draw the state machine's transition graph, highlighting one state.

    One node is drawn per local model (circular layout). A directed edge
    ``i -> j`` is drawn whenever ``state_machine.transition_histogram[i][j]``
    is positive. Node ``active_state`` is blue, all others black. The figure,
    titled "Active state", is written to ``output_path``.
    """
    num_states = len(state_machine.local_models)
    histogram = state_machine.transition_histogram

    graph = nx.DiGraph()
    graph.add_nodes_from(range(num_states))
    # Only keep transitions that were observed at least once.
    graph.add_edges_from(
        (src, dst)
        for src, dst in product(range(num_states), repeat=2)
        if histogram[src][dst] > 0
    )

    # Standalone figure with an Agg canvas (not registered with pyplot).
    fig = Figure()
    FigureCanvas(fig)
    ax = fig.add_subplot()

    colors = [
        "blue" if node == active_state else "black"
        for node in range(num_states)
    ]
    nx.draw_circular(graph, ax=ax, node_color=colors)

    fig.suptitle('Active state')
    fig.savefig(output_path)


def plot_animation(
    state_machine: StateMachine,
    visited_states: list[int],
    output_path: Path,
    fps: int,
):
    """Plot an animation of the given state machine.

    Frame ``k`` is drawn by ``plot_state_machine`` with ``visited_states[k]``
    as the highlighted state. The frames are encoded with ffmpeg into
    ``output_path`` at ``fps`` frames per second.
    """
    # Zero-pad frame indices to a common width so that sorting the file
    # names (ffmpeg's glob input) reproduces the frame order for any number
    # of frames; a fixed-width pad breaks once indices outgrow it.
    pad = len(str(max(len(visited_states) - 1, 0)))
    with tempfile.TemporaryDirectory() as tdir:
        frame_dir = Path(tdir)
        for i, state in enumerate(visited_states):
            plot_state_machine(
                state_machine=state_machine,
                active_state=state,
                output_path=frame_dir/f"{i:0{pad}d}.png",
            )
        (
            ffmpeg
            .input(str(frame_dir/"*.png"), pattern_type="glob", framerate=fps)
            .output(str(output_path))
            .run(quiet=True)
        )


def plot_state_machine_errors_diagram(
    state_machine: StateMachine,
    local_model_errors: list[float],
    max_error: float,
    output_path: Path,
):
    """Draw the state machine's transition graph, coloring each node by the
    prediction error of its local model.

    ``local_model_errors[i]`` is mapped to a ``viridis`` color using a linear
    scale from 0 to ``max_error``; a colorbar shows that scale. Edges are
    drawn as in ``plot_state_machine``. The figure, titled "Node errors", is
    written to ``output_path``.
    """
    num_states = len(state_machine.local_models)
    histogram = state_machine.transition_histogram

    graph = nx.DiGraph()
    graph.add_nodes_from(range(num_states))
    # Only keep transitions that were observed at least once.
    graph.add_edges_from(
        (src, dst)
        for src, dst in product(range(num_states), repeat=2)
        if histogram[src][dst] > 0
    )

    # Standalone figure with an Agg canvas (not registered with pyplot).
    fig = Figure()
    FigureCanvas(fig)
    ax = fig.add_subplot()

    # Fixed error scale, shared by the node colors and the colorbar.
    scalar_map = cm.ScalarMappable(
        norm=matplotlib.colors.Normalize(vmin=0.0, vmax=max_error),
        cmap=mpl.colormaps['viridis'],
    )
    colors = [
        scalar_map.to_rgba(np.array([err]))[0]
        for err in local_model_errors
    ]
    nx.draw_circular(graph, ax=ax, node_color=colors)
    fig.colorbar(scalar_map, ax=ax)

    fig.suptitle('Node errors')
    fig.savefig(output_path)


def plot_errors_animation(
    state_machine: StateMachine,
    local_model_errors: list[list[float]],
    output_path: Path,
    fps: int,
):
    """Plot an animation of the per-node prediction errors of the given
    state machine.

    Frame ``k`` is drawn by ``plot_state_machine_errors_diagram`` with
    ``local_model_errors[k]`` (one error per node). All frames share the
    color scale ``[0, max_error]``, where ``max_error`` is the largest error
    over all steps, so colors are comparable across frames. The frames are
    encoded with ffmpeg into ``output_path`` at ``fps`` frames per second.

    Raises:
        ValueError: if ``local_model_errors`` is empty.
    """
    if len(local_model_errors) == 0:
        raise ValueError("local_model_errors must contain at least one step")
    max_error = max(max(errors) for errors in local_model_errors)
    # Zero-pad frame indices (with "0", not spaces) to a common width so that
    # sorting the file names (ffmpeg's glob input) reproduces frame order.
    pad = len(str(len(local_model_errors) - 1))
    with tempfile.TemporaryDirectory() as tdir:
        frame_dir = Path(tdir)
        for i, errors in enumerate(local_model_errors):
            plot_state_machine_errors_diagram(
                state_machine=state_machine,
                local_model_errors=errors,
                max_error=max_error,
                output_path=frame_dir/f"{i:0{pad}d}.png",
            )
        (
            ffmpeg
            .input(str(frame_dir/"*.png"), pattern_type="glob", framerate=fps)
            .output(str(output_path))
            .run(quiet=True)
        )


def get_starts_widths(
    visited_states: list[int],
) -> tuple[list[int], list[int], list[int]]:
    """Helper function to compute the starts and widths for matplotlib's `barh`
    function, for visualizing sequences of visited states.

    The sequence is run-length encoded: each maximal run of equal consecutive
    states becomes one bar. Returns ``(starts, widths, colors)`` where
    ``starts[k]`` is the index of the run's first step, ``widths[k]`` the
    number of steps in the run and ``colors[k]`` the run's state. The widths
    sum to ``len(visited_states)``; an empty input yields three empty lists.

    >>> get_starts_widths([0, 0, 1, 1, 1])
    ([0, 2], [2, 3], [0, 1])
    """
    starts: list[int] = []
    widths: list[int] = []
    colors: list[int] = []
    for i, state in enumerate(visited_states):
        if colors and colors[-1] == state:
            # Same state as the previous step: extend the current run.
            widths[-1] += 1
        else:
            # State changed (or first step): open a new run.
            starts.append(i)
            widths.append(1)
            colors.append(state)
    return starts, widths, colors


def get_partition_colors() -> list[tuple[float, float, float, float]]:
    """Get colors with which partitions can be plotted.

    Returns every entry of matplotlib's ``tab20b`` colormap as an RGBA tuple,
    shuffled with a fixed seed so that every call returns the same order
    (presumably so consecutive indices do not get adjacent, similar-looking
    colormap entries).
    """
    cmap = mpl.colormaps['tab20b']
    # Convert the RGBA rows to tuples to match the return annotation.
    colors = [tuple(rgba) for rgba in cmap(list(range(cmap.N))).tolist()]
    _random = random.Random("seed")
    _random.shuffle(colors)
    return colors


def plot_state_machine_errors(
    state_machine: StateMachine,
    episode: list[Transition],
    initial_state: int,
    dt: float,
    ground_truth_visited_states: list[int] | None,
    output_dir: Path,
):
    """Plot the error of the state machine for predicting the given episode.

    Writes two files into ``output_dir``:

    - ``state_machine_errors.svg``: top panel with the per-step error of
      every local model and of the full state machine; bottom panel with
      bars for the states visited by the state machine, the minimum-error
      local model at each step and, when given, the ground-truth states.
    - ``data.json``: the plotted errors and state sequences.

    When non-empty ground-truth states are given, local model ``i`` is
    colored like the ground-truth state ``_get_best_permutation`` maps it to.
    """
    # Run state machine
    local_model_errors = get_local_model_errors(
        state_machine=state_machine,
        episode=episode,
        dt=dt,
    )
    state_machine_errors = get_state_machine_errors(
        state_machine=state_machine,
        episode=episode,
        initial_state=initial_state,
        dt=dt,
    )
    visited_states = get_visited_states(
        state_machine=state_machine,
        initial_state=initial_state,
        episode=episode,
        dt=dt,
    )

    # Create matplotlib figure
    fig = Figure()
    _ = FigureCanvas(fig)

    # Choose model colors. Indices wrap around the palette so that more
    # models/states than palette entries do not raise IndexError.
    partition_colors = get_partition_colors()
    n_colors = len(partition_colors)
    n_models = len(state_machine.local_models)
    has_ground_truth = (
        ground_truth_visited_states is not None
        and len(ground_truth_visited_states) > 0
    )

    # Permute local models for easy visualization
    if has_ground_truth:
        perm = _get_best_permutation(
            sequence=visited_states,
            ground_truth=ground_truth_visited_states,
            indices=list(range(n_models)),
        )
        color_indices = [
            perm[i] if i in perm.keys() else i
            for i in range(n_models)
        ]
    else:
        color_indices = list(range(n_models))
    local_model_colors = [
        partition_colors[k % n_colors] for k in color_indices
    ]

    # Plot errors
    ax1 = fig.add_subplot(2, 1, 1)
    steps = list(range(len(local_model_errors)))
    for i in range(n_models):
        ax1.plot(
            steps,
            [step_errors[i] for step_errors in local_model_errors],
            label=f"Local model {i}",
            color=local_model_colors[i],
        )
    horizon = len(state_machine_errors)
    ax1.plot(
        list(range(horizon)),
        state_machine_errors,
        color='black',
        linestyle=':',
        label="Full state machine",
    )
    ax1.set_xlim(left=0, right=horizon)

    # Plot active states
    ax2 = fig.add_subplot(2, 1, 2)
    starts, widths, colors = get_starts_widths(visited_states)
    ax2.barh(
        ["FSM states"],
        widths,
        left=starts,
        height=0.5,
        color=[local_model_colors[s] for s in colors],
    )

    # Plot "partition": the local model with the lowest error at each step
    induced_visited_states = [
        min(range(len(step_losses)), key=step_losses.__getitem__)
        for step_losses in local_model_errors
    ]
    starts, widths, colors = get_starts_widths(induced_visited_states)
    ax2.barh(
        ["Minimum-loss induced states"],
        widths,
        left=starts,
        height=0.5,
        color=[local_model_colors[s] for s in colors],
    )
    ax2.set_xlim(left=0, right=horizon)

    # Plot ground truth
    if has_ground_truth:
        starts, widths, colors = get_starts_widths(ground_truth_visited_states)
        ax2.barh(
            ["Ground truth states"],
            widths,
            left=starts,
            height=1.5,
            color=[partition_colors[s % n_colors] for s in colors],
        )

    # Save figure
    fig.legend()
    fig.suptitle('State machine prediction errors')
    fig.tight_layout()
    fig.savefig(output_dir/"state_machine_errors.svg")

    # Save data
    data = dict(
        local_model_errors=local_model_errors,
        state_machine_errors=state_machine_errors,
        visited_states=visited_states,
        ground_truth_visited_states=ground_truth_visited_states,
    )
    with open(output_dir/"data.json", "wt") as fp:
        json.dump(data, fp, indent=2)


def get_state_machine_forecasting_errors(
    state_machine: StateMachine,
    episode: list[Transition],
    initial_state: int,
    dt: float,
) -> list[float]:
    """Return the forecasting errors of the state machine over the given
    episode.

    Starting from ``episode[0].source_state`` in node ``initial_state``, the
    state machine is rolled out with the episode's actions and is never
    corrected by the observed states. Entry ``k`` is the norm of the
    difference between ``episode[k].source_state`` and the simulated state
    after ``k`` steps, so entry 0 compares the start state with itself.

    Returns an empty list for an empty episode.
    """
    if not episode:
        return []

    errors: list[float] = []
    state = episode[0].source_state
    node = initial_state
    for transition in episode:
        # Compare before stepping, so entry k is the error after k steps.
        # NOTE(review): assumes states are tensors exposing `.norm()` and
        # `.item()` (e.g. torch) — confirm.
        errors.append((transition.source_state - state).norm().item())

        # Step the simulation
        state, node = state_machine_model(
            state_machine=state_machine,
            state=state,
            action=transition.action,
            current_node=node,
            dt=dt,
        )
    return errors


def plot_state_machine_forecasting_errors(
    state_machine: StateMachine,
    episode: list[Transition],
    initial_state: int,
    dt: float,
    ground_truth_visited_states: list[int] | None,
    output_path: Path,
):
    """Plot the long-horizon forecasting error of the state machine over the
    given episode and save the figure to ``output_path``.

    Top panel: the error computed by ``get_state_machine_forecasting_errors``.
    Bottom panel: bars for the states visited by the state machine, the
    minimum-error local model at each step and, when given, the ground-truth
    states. When non-empty ground-truth states are given, local model ``i``
    is colored like the ground-truth state ``_get_best_permutation`` maps
    it to.
    """
    # Run state machine
    local_model_errors = get_local_model_errors(
        state_machine=state_machine,
        episode=episode,
        dt=dt,
    )
    forecasting_errors = get_state_machine_forecasting_errors(
        state_machine=state_machine,
        episode=episode,
        initial_state=initial_state,
        dt=dt,
    )
    visited_states = get_visited_states(
        state_machine=state_machine,
        initial_state=initial_state,
        episode=episode,
        dt=dt,
    )

    # Create matplotlib figure
    fig = Figure()
    _ = FigureCanvas(fig)

    # Choose model colors. Indices wrap around the palette so that more
    # models/states than palette entries do not raise IndexError.
    partition_colors = get_partition_colors()
    n_colors = len(partition_colors)
    n_models = len(state_machine.local_models)
    has_ground_truth = (
        ground_truth_visited_states is not None
        and len(ground_truth_visited_states) > 0
    )

    # Permute local models for easy visualization
    if has_ground_truth:
        perm = _get_best_permutation(
            sequence=visited_states,
            ground_truth=ground_truth_visited_states,
            indices=list(range(n_models)),
        )
        color_indices = [
            perm[i] if i in perm.keys() else i
            for i in range(n_models)
        ]
    else:
        color_indices = list(range(n_models))
    local_model_colors = [
        partition_colors[k % n_colors] for k in color_indices
    ]

    # Plot errors; the x-axis is sized by the series actually plotted.
    ax1 = fig.add_subplot(2, 1, 1)
    horizon = len(forecasting_errors)
    ax1.plot(
        list(range(horizon)),
        forecasting_errors,
        color='black',
        linestyle='-',
        label="Full state machine",
    )
    ax1.set_xlim(left=0, right=horizon)

    # Plot active states
    ax2 = fig.add_subplot(2, 1, 2)
    starts, widths, colors = get_starts_widths(visited_states)
    ax2.barh(
        ["FSM states"],
        widths,
        left=starts,
        height=0.5,
        color=[local_model_colors[s] for s in colors],
    )

    # Plot "partition": the local model with the lowest error at each step
    induced_visited_states = [
        min(range(len(step_losses)), key=step_losses.__getitem__)
        for step_losses in local_model_errors
    ]
    starts, widths, colors = get_starts_widths(induced_visited_states)
    ax2.barh(
        ["Minimum-loss induced states"],
        widths,
        left=starts,
        height=0.5,
        color=[local_model_colors[s] for s in colors],
    )
    ax2.set_xlim(left=0, right=horizon)

    # Plot ground truth
    if has_ground_truth:
        starts, widths, colors = get_starts_widths(ground_truth_visited_states)
        ax2.barh(
            ["Ground truth states"],
            widths,
            left=starts,
            height=1.5,
            color=[partition_colors[s % n_colors] for s in colors],
        )

    # Save figure
    fig.legend()
    fig.suptitle('Long-horizon forecasting performance')
    fig.tight_layout()
    fig.savefig(output_path)
def
plot_state_machine( state_machine: swmpo.state_machine.StateMachine, active_state: int, output_path: pathlib.Path):
def plot_state_machine(
    state_machine: StateMachine,
    active_state: int,
    output_path: Path,
):
    """Draw the state machine's transition graph, highlighting one state.

    One node is drawn per local model (circular layout). A directed edge
    ``i -> j`` is drawn whenever ``state_machine.transition_histogram[i][j]``
    is positive. Node ``active_state`` is blue, all others black. The figure,
    titled "Active state", is written to ``output_path``.
    """
    num_states = len(state_machine.local_models)
    histogram = state_machine.transition_histogram

    graph = nx.DiGraph()
    graph.add_nodes_from(range(num_states))
    # Only keep transitions that were observed at least once.
    graph.add_edges_from(
        (src, dst)
        for src, dst in product(range(num_states), repeat=2)
        if histogram[src][dst] > 0
    )

    # Standalone figure with an Agg canvas (not registered with pyplot).
    fig = Figure()
    FigureCanvas(fig)
    ax = fig.add_subplot()

    colors = [
        "blue" if node == active_state else "black"
        for node in range(num_states)
    ]
    nx.draw_circular(graph, ax=ax, node_color=colors)

    fig.suptitle('Active state')
    fig.savefig(output_path)
Plot the state machine diagram.
def
plot_animation( state_machine: swmpo.state_machine.StateMachine, visited_states: list[int], output_path: pathlib.Path, fps: int):
def plot_animation(
    state_machine: StateMachine,
    visited_states: list[int],
    output_path: Path,
    fps: int,
):
    """Plot an animation of the given state machine.

    Frame ``k`` is drawn by ``plot_state_machine`` with ``visited_states[k]``
    as the highlighted state. The frames are encoded with ffmpeg into
    ``output_path`` at ``fps`` frames per second.
    """
    # Zero-pad frame indices to a common width so that sorting the file
    # names (ffmpeg's glob input) reproduces the frame order for any number
    # of frames; a fixed-width pad breaks once indices outgrow it.
    pad = len(str(max(len(visited_states) - 1, 0)))
    with tempfile.TemporaryDirectory() as tdir:
        frame_dir = Path(tdir)
        for i, state in enumerate(visited_states):
            plot_state_machine(
                state_machine=state_machine,
                active_state=state,
                output_path=frame_dir/f"{i:0{pad}d}.png",
            )
        (
            ffmpeg
            .input(str(frame_dir/"*.png"), pattern_type="glob", framerate=fps)
            .output(str(output_path))
            .run(quiet=True)
        )
Plot an animation of the given state machine.
def
plot_state_machine_errors_diagram( state_machine: swmpo.state_machine.StateMachine, local_model_errors: list[float], max_error: float, output_path: pathlib.Path):
def plot_state_machine_errors_diagram(
    state_machine: StateMachine,
    local_model_errors: list[float],
    max_error: float,
    output_path: Path,
):
    """Draw the state machine's transition graph, coloring each node by the
    prediction error of its local model.

    ``local_model_errors[i]`` is mapped to a ``viridis`` color using a linear
    scale from 0 to ``max_error``; a colorbar shows that scale. Edges are
    drawn wherever ``state_machine.transition_histogram[i][j]`` is positive.
    The figure, titled "Node errors", is written to ``output_path``.
    """
    num_states = len(state_machine.local_models)
    histogram = state_machine.transition_histogram

    graph = nx.DiGraph()
    graph.add_nodes_from(range(num_states))
    # Only keep transitions that were observed at least once.
    graph.add_edges_from(
        (src, dst)
        for src, dst in product(range(num_states), repeat=2)
        if histogram[src][dst] > 0
    )

    # Standalone figure with an Agg canvas (not registered with pyplot).
    fig = Figure()
    FigureCanvas(fig)
    ax = fig.add_subplot()

    # Fixed error scale, shared by the node colors and the colorbar.
    scalar_map = cm.ScalarMappable(
        norm=matplotlib.colors.Normalize(vmin=0.0, vmax=max_error),
        cmap=mpl.colormaps['viridis'],
    )
    colors = [
        scalar_map.to_rgba(np.array([err]))[0]
        for err in local_model_errors
    ]
    nx.draw_circular(graph, ax=ax, node_color=colors)
    fig.colorbar(scalar_map, ax=ax)

    fig.suptitle('Node errors')
    fig.savefig(output_path)
Plot the state machine diagram coloring the nodes according to their prediction error.
def
plot_errors_animation( state_machine: swmpo.state_machine.StateMachine, local_model_errors: list[list[float]], output_path: pathlib.Path, fps: int):
def plot_errors_animation(
    state_machine: StateMachine,
    local_model_errors: list[list[float]],
    output_path: Path,
    fps: int,
):
    """Plot an animation of the per-node prediction errors of the given
    state machine.

    Frame ``k`` is drawn by ``plot_state_machine_errors_diagram`` with
    ``local_model_errors[k]`` (one error per node). All frames share the
    color scale ``[0, max_error]``, where ``max_error`` is the largest error
    over all steps, so colors are comparable across frames. The frames are
    encoded with ffmpeg into ``output_path`` at ``fps`` frames per second.

    Raises:
        ValueError: if ``local_model_errors`` is empty.
    """
    if len(local_model_errors) == 0:
        raise ValueError("local_model_errors must contain at least one step")
    max_error = max(max(errors) for errors in local_model_errors)
    # Zero-pad frame indices (with "0", not spaces) to a common width so that
    # sorting the file names (ffmpeg's glob input) reproduces frame order.
    pad = len(str(len(local_model_errors) - 1))
    with tempfile.TemporaryDirectory() as tdir:
        frame_dir = Path(tdir)
        for i, errors in enumerate(local_model_errors):
            plot_state_machine_errors_diagram(
                state_machine=state_machine,
                local_model_errors=errors,
                max_error=max_error,
                output_path=frame_dir/f"{i:0{pad}d}.png",
            )
        (
            ffmpeg
            .input(str(frame_dir/"*.png"), pattern_type="glob", framerate=fps)
            .output(str(output_path))
            .run(quiet=True)
        )
Plot an animation of the per-node prediction errors of the given state machine.
def
get_starts_widths(visited_states: list[int]) -> tuple[list[int], list[int], list[int]]:
def get_starts_widths(
    visited_states: list[int],
) -> tuple[list[int], list[int], list[int]]:
    """Helper function to compute the starts and widths for matplotlib's `barh`
    function, for visualizing sequences of visited states.

    The sequence is run-length encoded: each maximal run of equal consecutive
    states becomes one bar. Returns ``(starts, widths, colors)`` where
    ``starts[k]`` is the index of the run's first step, ``widths[k]`` the
    number of steps in the run and ``colors[k]`` the run's state. The widths
    sum to ``len(visited_states)``; an empty input yields three empty lists.

    >>> get_starts_widths([0, 0, 1, 1, 1])
    ([0, 2], [2, 3], [0, 1])
    """
    starts: list[int] = []
    widths: list[int] = []
    colors: list[int] = []
    for i, state in enumerate(visited_states):
        if colors and colors[-1] == state:
            # Same state as the previous step: extend the current run.
            widths[-1] += 1
        else:
            # State changed (or first step): open a new run.
            starts.append(i)
            widths.append(1)
            colors.append(state)
    return starts, widths, colors
Helper function to compute the starts, widths and colors (state indices) for matplotlib's `barh`
function, for visualizing sequences of visited states.
def
get_partition_colors() -> list[tuple[float, float, float, float]]:
def get_partition_colors() -> list[tuple[float, float, float, float]]:
    """Get colors with which partitions can be plotted.

    Returns every entry of matplotlib's ``tab20b`` colormap as an RGBA tuple,
    shuffled with a fixed seed so that every call returns the same order
    (presumably so consecutive indices do not get adjacent, similar-looking
    colormap entries).
    """
    cmap = mpl.colormaps['tab20b']
    # Convert the RGBA rows to tuples to match the return annotation.
    colors = [tuple(rgba) for rgba in cmap(list(range(cmap.N))).tolist()]
    _random = random.Random("seed")
    _random.shuffle(colors)
    return colors
Get colors with which partitions can be plotted.
def
plot_state_machine_errors( state_machine: swmpo.state_machine.StateMachine, episode: list[swmpo.transition.Transition], initial_state: int, dt: float, ground_truth_visited_states: list[int] | None, output_dir: pathlib.Path):
def plot_state_machine_errors(
    state_machine: StateMachine,
    episode: list[Transition],
    initial_state: int,
    dt: float,
    ground_truth_visited_states: list[int] | None,
    output_dir: Path,
):
    """Plot the error of the state machine for predicting the given episode.

    Writes two files into ``output_dir``:

    - ``state_machine_errors.svg``: top panel with the per-step error of
      every local model and of the full state machine; bottom panel with
      bars for the states visited by the state machine, the minimum-error
      local model at each step and, when given, the ground-truth states.
    - ``data.json``: the plotted errors and state sequences.

    When non-empty ground-truth states are given, local model ``i`` is
    colored like the ground-truth state ``_get_best_permutation`` maps it to.
    """
    # Run state machine
    local_model_errors = get_local_model_errors(
        state_machine=state_machine,
        episode=episode,
        dt=dt,
    )
    state_machine_errors = get_state_machine_errors(
        state_machine=state_machine,
        episode=episode,
        initial_state=initial_state,
        dt=dt,
    )
    visited_states = get_visited_states(
        state_machine=state_machine,
        initial_state=initial_state,
        episode=episode,
        dt=dt,
    )

    # Create matplotlib figure
    fig = Figure()
    _ = FigureCanvas(fig)

    # Choose model colors. Indices wrap around the palette so that more
    # models/states than palette entries do not raise IndexError.
    partition_colors = get_partition_colors()
    n_colors = len(partition_colors)
    n_models = len(state_machine.local_models)
    has_ground_truth = (
        ground_truth_visited_states is not None
        and len(ground_truth_visited_states) > 0
    )

    # Permute local models for easy visualization
    if has_ground_truth:
        perm = _get_best_permutation(
            sequence=visited_states,
            ground_truth=ground_truth_visited_states,
            indices=list(range(n_models)),
        )
        color_indices = [
            perm[i] if i in perm.keys() else i
            for i in range(n_models)
        ]
    else:
        color_indices = list(range(n_models))
    local_model_colors = [
        partition_colors[k % n_colors] for k in color_indices
    ]

    # Plot errors
    ax1 = fig.add_subplot(2, 1, 1)
    steps = list(range(len(local_model_errors)))
    for i in range(n_models):
        ax1.plot(
            steps,
            [step_errors[i] for step_errors in local_model_errors],
            label=f"Local model {i}",
            color=local_model_colors[i],
        )
    horizon = len(state_machine_errors)
    ax1.plot(
        list(range(horizon)),
        state_machine_errors,
        color='black',
        linestyle=':',
        label="Full state machine",
    )
    ax1.set_xlim(left=0, right=horizon)

    # Plot active states
    ax2 = fig.add_subplot(2, 1, 2)
    starts, widths, colors = get_starts_widths(visited_states)
    ax2.barh(
        ["FSM states"],
        widths,
        left=starts,
        height=0.5,
        color=[local_model_colors[s] for s in colors],
    )

    # Plot "partition": the local model with the lowest error at each step
    induced_visited_states = [
        min(range(len(step_losses)), key=step_losses.__getitem__)
        for step_losses in local_model_errors
    ]
    starts, widths, colors = get_starts_widths(induced_visited_states)
    ax2.barh(
        ["Minimum-loss induced states"],
        widths,
        left=starts,
        height=0.5,
        color=[local_model_colors[s] for s in colors],
    )
    ax2.set_xlim(left=0, right=horizon)

    # Plot ground truth
    if has_ground_truth:
        starts, widths, colors = get_starts_widths(ground_truth_visited_states)
        ax2.barh(
            ["Ground truth states"],
            widths,
            left=starts,
            height=1.5,
            color=[partition_colors[s % n_colors] for s in colors],
        )

    # Save figure
    fig.legend()
    fig.suptitle('State machine prediction errors')
    fig.tight_layout()
    fig.savefig(output_dir/"state_machine_errors.svg")

    # Save data
    data = dict(
        local_model_errors=local_model_errors,
        state_machine_errors=state_machine_errors,
        visited_states=visited_states,
        ground_truth_visited_states=ground_truth_visited_states,
    )
    with open(output_dir/"data.json", "wt") as fp:
        json.dump(data, fp, indent=2)
Plot the error of the state machine for predicting the given episode.
def
get_state_machine_forecasting_errors( state_machine: swmpo.state_machine.StateMachine, episode: list[swmpo.transition.Transition], initial_state: int, dt: float) -> list[float]:
def get_state_machine_forecasting_errors(
    state_machine: StateMachine,
    episode: list[Transition],
    initial_state: int,
    dt: float,
) -> list[float]:
    """Return the forecasting errors of the state machine over the given
    episode.

    Starting from ``episode[0].source_state`` in node ``initial_state``, the
    state machine is rolled out with the episode's actions and is never
    corrected by the observed states. Entry ``k`` is the norm of the
    difference between ``episode[k].source_state`` and the simulated state
    after ``k`` steps, so entry 0 compares the start state with itself.

    Returns an empty list for an empty episode.
    """
    if not episode:
        return []

    errors: list[float] = []
    state = episode[0].source_state
    node = initial_state
    for transition in episode:
        # Compare before stepping, so entry k is the error after k steps.
        # NOTE(review): assumes states are tensors exposing `.norm()` and
        # `.item()` (e.g. torch) — confirm.
        errors.append((transition.source_state - state).norm().item())

        # Step the simulation
        state, node = state_machine_model(
            state_machine=state_machine,
            state=state,
            action=transition.action,
            current_node=node,
            dt=dt,
        )
    return errors
Return the forecasting errors of the state machine over the given episode (i.e., using the state machine from the initial state to simulate the system and comparing with the ground-truth states).
def
plot_state_machine_forecasting_errors( state_machine: swmpo.state_machine.StateMachine, episode: list[swmpo.transition.Transition], initial_state: int, dt: float, ground_truth_visited_states: list[int] | None, output_path: pathlib.Path):
def plot_state_machine_forecasting_errors(
    state_machine: StateMachine,
    episode: list[Transition],
    initial_state: int,
    dt: float,
    ground_truth_visited_states: list[int] | None,
    output_path: Path,
):
    """Plot the long-horizon forecasting error of the state machine over the
    given episode and save the figure to ``output_path``.

    Top panel: the error computed by ``get_state_machine_forecasting_errors``.
    Bottom panel: bars for the states visited by the state machine, the
    minimum-error local model at each step and, when given, the ground-truth
    states. When non-empty ground-truth states are given, local model ``i``
    is colored like the ground-truth state ``_get_best_permutation`` maps
    it to.
    """
    # Run state machine
    local_model_errors = get_local_model_errors(
        state_machine=state_machine,
        episode=episode,
        dt=dt,
    )
    forecasting_errors = get_state_machine_forecasting_errors(
        state_machine=state_machine,
        episode=episode,
        initial_state=initial_state,
        dt=dt,
    )
    visited_states = get_visited_states(
        state_machine=state_machine,
        initial_state=initial_state,
        episode=episode,
        dt=dt,
    )

    # Create matplotlib figure
    fig = Figure()
    _ = FigureCanvas(fig)

    # Choose model colors. Indices wrap around the palette so that more
    # models/states than palette entries do not raise IndexError.
    partition_colors = get_partition_colors()
    n_colors = len(partition_colors)
    n_models = len(state_machine.local_models)
    has_ground_truth = (
        ground_truth_visited_states is not None
        and len(ground_truth_visited_states) > 0
    )

    # Permute local models for easy visualization
    if has_ground_truth:
        perm = _get_best_permutation(
            sequence=visited_states,
            ground_truth=ground_truth_visited_states,
            indices=list(range(n_models)),
        )
        color_indices = [
            perm[i] if i in perm.keys() else i
            for i in range(n_models)
        ]
    else:
        color_indices = list(range(n_models))
    local_model_colors = [
        partition_colors[k % n_colors] for k in color_indices
    ]

    # Plot errors; the x-axis is sized by the series actually plotted.
    ax1 = fig.add_subplot(2, 1, 1)
    horizon = len(forecasting_errors)
    ax1.plot(
        list(range(horizon)),
        forecasting_errors,
        color='black',
        linestyle='-',
        label="Full state machine",
    )
    ax1.set_xlim(left=0, right=horizon)

    # Plot active states
    ax2 = fig.add_subplot(2, 1, 2)
    starts, widths, colors = get_starts_widths(visited_states)
    ax2.barh(
        ["FSM states"],
        widths,
        left=starts,
        height=0.5,
        color=[local_model_colors[s] for s in colors],
    )

    # Plot "partition": the local model with the lowest error at each step
    induced_visited_states = [
        min(range(len(step_losses)), key=step_losses.__getitem__)
        for step_losses in local_model_errors
    ]
    starts, widths, colors = get_starts_widths(induced_visited_states)
    ax2.barh(
        ["Minimum-loss induced states"],
        widths,
        left=starts,
        height=0.5,
        color=[local_model_colors[s] for s in colors],
    )
    ax2.set_xlim(left=0, right=horizon)

    # Plot ground truth
    if has_ground_truth:
        starts, widths, colors = get_starts_widths(ground_truth_visited_states)
        ax2.barh(
            ["Ground truth states"],
            widths,
            left=starts,
            height=1.5,
            color=[partition_colors[s % n_colors] for s in colors],
        )

    # Save figure
    fig.legend()
    fig.suptitle('Long-horizon forecasting performance')
    fig.tight_layout()
    fig.savefig(output_path)
Plot the long-horizon forecasting error of the state machine over the given episode.