swmpo.partition
Partition of a dataset of state transitions guided by synthesis of local neural models.
1"""Partition of a dataset of state transitions guided by synthesis of 2local neural models.""" 3from dataclasses import dataclass 4from swmpo.transition import Transition 5from swmpo.transition import get_vector 6from swmpo.transition import serialize 7from swmpo.transition import deserialize 8from swmpo.world_models.model import get_input_output_size 9from swmpo.world_models.world_model import serialize_model 10from swmpo.world_models.world_model import deserialize_model 11from swmpo.world_models.world_model import WorldModel 12from swmpo.world_models.world_model import get_optimized_model 13from swmpo.world_models.mode_world_model import ModeWorldModel 14from swmpo.world_models.mode_world_model import get_predictive_residual_encoder 15from swmpo.world_models.mode_world_model import get_mode_vector 16from swmpo.transition_prunning.island_prunning import prune_short_transitions 17from sklearn.preprocessing import StandardScaler 18from functools import cached_property 19import sklearn.cluster 20import random 21import torch 22import tempfile 23from pathlib import Path 24import shutil 25import umap 26import json 27import zipfile 28from concurrent.futures import ThreadPoolExecutor 29 30# Avoid pytorch from doing threading. This is so that the script doesn't 31# take over the computer's resources. You can remove these lines if not running 32# on a lab computer. 33torch.set_num_threads(1) 34 35 36@dataclass 37class StatePartitionItem: 38 """A partition of a dataset of partitions.""" 39 local_model: WorldModel 40 subset: list[Transition] 41 hidden_sizes: list[int] 42 43 @property 44 def local_model_input_size(self) -> int: 45 input_size, _ = get_input_output_size(self.subset[0]) 46 return input_size 47 48 @property 49 def local_model_output_size(self) -> int: 50 _, output_size = get_input_output_size(self.subset[0]) 51 return output_size 52 53 @cached_property 54 def transition_vectors_as_set_of_tuples(self) -> set[tuple[float, ...]]: 55 vectors = set[tuple[float, ...]]() 56 for transition in self.subset: 57 x = get_vector(transition) 58 vectors.add(tuple[float, ...](x.tolist())) 59 return vectors 60 61 62def serialize_partition_item( 63 partition_item: StatePartitionItem, 64 output_zip_path: Path, 65): 66 """Serialize a partition item to a ZIP file.""" 67 assert output_zip_path.suffix == ".zip" 68 69 with tempfile.TemporaryDirectory() as tmpdirname: 70 output_dir = Path(tmpdirname) 71 72 # Serialize local model 73 model_path = output_dir/"model.zip" 74 serialize_model( 75 model=partition_item.local_model, 76 output_zip_path=model_path, 77 ) 78 79 # Serialize subset of transitions 80 subset_dir = output_dir/"transition_subset" 81 subset_dir.mkdir() 82 transition_directory = list[str]() 83 for i, transition in enumerate(partition_item.subset): 84 transition_zip = f"transition_{i}.zip" 85 serialize(transition, subset_dir/transition_zip) 86 transition_directory.append(str(transition_zip)) 87 88 # Serialize directory 89 transition_directory_json_path = output_dir/"transition_directory.json" 90 with open(transition_directory_json_path, "wt") as fp: 91 json.dump(transition_directory, fp, indent=2) 92 93 # Serialize hidden sizes 94 hidden_sizes_json_path = output_dir/"hidden_sizes.json" 95 with open(hidden_sizes_json_path, "wt") as fp: 96 json.dump(partition_item.hidden_sizes, fp, indent=2) 97 98 # ZIP directory 99 _ = shutil.make_archive( 100 str(output_zip_path.with_suffix("")), 101 'zip', 102 output_dir 103 ) 104 105 106def deserialize_partition_item( 107 zip_path: Path, 108) -> StatePartitionItem: 
109 """Deserialize a partition item from a ZIP file.""" 110 with tempfile.TemporaryDirectory() as tmpdirname: 111 output_dir = Path(tmpdirname) 112 113 with zipfile.ZipFile(zip_path, "r") as zip_ref: 114 zip_ref.extractall(output_dir) 115 116 # Load model 117 model_path = output_dir/"model.zip" 118 local_model = deserialize_model(model_path) 119 120 # Load transition directory 121 transition_directory_json_path = output_dir/"transition_directory.json" 122 with open(transition_directory_json_path, "rt") as fp: 123 transition_directory = list[str](json.load(fp)) 124 125 # Load transitions 126 futures = list() 127 with ThreadPoolExecutor() as executor: 128 for i, transition_path in enumerate(transition_directory): 129 zip_path = output_dir/"transition_subset"/transition_path 130 future = executor.submit( 131 deserialize, 132 zip_path, 133 ) 134 futures.append(future) 135 136 subset = list[Transition]() 137 for i, future in enumerate(futures): 138 print(f"Deserializing partition item {i}/{len(futures)}") 139 transition = future.result() 140 subset.append(transition) 141 142 # Load hidden sizes 143 hidden_sizes_json_path = output_dir/"hidden_sizes.json" 144 with open(hidden_sizes_json_path, "rt") as fp: 145 hidden_sizes = list[int](json.load(fp)) 146 147 item = StatePartitionItem( 148 local_model=local_model, 149 subset=subset, 150 hidden_sizes=hidden_sizes, 151 ) 152 return item 153 154 155def serialize_partition( 156 partition: list[StatePartitionItem], 157 output_zip_path: Path, 158): 159 assert output_zip_path.suffix == ".zip" 160 161 with tempfile.TemporaryDirectory() as tmpdirname: 162 output_dir = Path(tmpdirname) 163 164 # Serialize each item 165 item_directory = list[str]() 166 for i, item in enumerate(partition): 167 item_path = f"item_{i}.zip" 168 serialize_partition_item( 169 partition_item=item, 170 output_zip_path=output_dir/item_path, 171 ) 172 item_directory.append(item_path) 173 174 # Serialize directory 175 item_directory_json_path = output_dir/"item_directory.json" 176 with open(item_directory_json_path, "wt") as fp: 177 json.dump(item_directory, fp, indent=2) 178 179 # ZIP directory 180 _ = shutil.make_archive( 181 str(output_zip_path.with_suffix("")), 182 'zip', 183 output_dir 184 ) 185 186 187def deserialize_partition(zip_path: Path) -> list[StatePartitionItem]: 188 items = list[StatePartitionItem]() 189 190 with tempfile.TemporaryDirectory() as tmpdirname: 191 output_dir = Path(tmpdirname) 192 193 with zipfile.ZipFile(zip_path, "r") as zip_ref: 194 zip_ref.extractall(output_dir) 195 196 # Load directory 197 item_directory_json_path = output_dir/"item_directory.json" 198 with open(item_directory_json_path, "rt") as fp: 199 item_paths = list[str](json.load(fp)) 200 201 # Load each partition item 202 for i, item_path in enumerate(item_paths): 203 zip_path = output_dir/item_path 204 item = deserialize_partition_item(zip_path) 205 items.append(item) 206 207 return items 208 209 210def deserialize_partition_item_local_model( 211 zip_path: Path, 212) -> WorldModel: 213 """Deserialize a partition item from a ZIP file.""" 214 with tempfile.TemporaryDirectory() as tmpdirname: 215 output_dir = Path(tmpdirname) 216 217 with zipfile.ZipFile(zip_path, "r") as zip_ref: 218 zip_ref.extract("model.zip", path=output_dir) 219 220 # Load model 221 model_path = output_dir/"model.zip" 222 local_model = deserialize_model(model_path) 223 return local_model 224 225 226def deserialize_partition_local_models(zip_path: Path) -> list[WorldModel]: 227 models = list[WorldModel]() 228 229 with 
tempfile.TemporaryDirectory() as tmpdirname: 230 output_dir = Path(tmpdirname) 231 232 with zipfile.ZipFile(zip_path, "r") as zip_ref: 233 zip_ref.extractall(output_dir) 234 235 # Load directory 236 item_directory_json_path = output_dir/"item_directory.json" 237 with open(item_directory_json_path, "rt") as fp: 238 item_paths = list[str](json.load(fp)) 239 240 # Load each partition item 241 for i, item_path in enumerate(item_paths): 242 zip_path = output_dir/item_path 243 model = deserialize_partition_item_local_model(zip_path) 244 models.append(model) 245 246 return models 247 248 249class PartitionSortingError(Exception): 250 pass 251 252 253def item_contains_transition( 254 item: StatePartitionItem, 255 transition: Transition, 256 ) -> bool: 257 """Return whether the transition appears 258 in the partition item.""" 259 x = tuple(get_vector(transition).tolist()) 260 return x in item.transition_vectors_as_set_of_tuples 261 262 263def get_initial_transition_n( 264 item: StatePartitionItem, 265 episodes: list[list[Transition]], 266 ) -> int: 267 """Return the number of times a transition occurs in the partition item.""" 268 # Extract initial transitions 269 initial_transitions = [ 270 episode[0] 271 for episode in episodes 272 if len(episode) > 0 273 ] 274 275 # Filter-in the initial transitions that 276 # appear in the partition item 277 occurrences = [ 278 initial_transition 279 for initial_transition in initial_transitions 280 if item_contains_transition( 281 item=item, 282 transition=initial_transition 283 ) 284 ] 285 return len(occurrences) 286 287 288def get_sorted_partition( 289 partition: list[StatePartitionItem], 290 episodes: list[list[Transition]], 291) -> list[StatePartitionItem]: 292 """Sort the partition so that the partition item with the most 293 initial transitions is first.""" 294 # Identify first item 295 sorted_partition = list(reversed(sorted( 296 partition, 297 key=lambda item: get_initial_transition_n( 298 item=item, 299 episodes=episodes 300 ) 301 ))) 302 return sorted_partition 303 304 305def get_partition_modes( 306 trajectory: list[Transition], 307 partition: list[StatePartitionItem], 308) -> list[int]: 309 """Return the list of indices of each transition in the trajectory.""" 310 modes = list[int]() 311 for transition in trajectory: 312 index = None 313 for i, item in enumerate(partition): 314 if item_contains_transition(item, transition): 315 index = i 316 assert index is not None, "Partition doesn't contain transition!" 317 modes.append(index) 318 return modes 319 320 321def get_clusters( 322 mode_world_model: ModeWorldModel, 323 trajectories: list[list[Transition]], 324 cluster_n: int, 325 min_island_size: int, 326 dimensionality_reduce: int | None, 327 seed: str, 328 device: str, 329) -> list[set[tuple[int, int]]]: 330 """Partition the given dataset of transitions into disjoint subsets. 
331 The returned sets contain the indices of the transitions in the set.""" 332 _random = random.Random(seed) 333 334 # Bookkeeping 335 vector_indices = list[tuple[int, int]]() 336 location_index = dict[tuple[int, int], int]() 337 for i, trajectory in enumerate(trajectories): 338 for j, _ in enumerate(trajectory): 339 location = (i, j) 340 index = len(vector_indices) 341 vector_indices.append(location) 342 location_index[location] = index 343 344 # Get the latent vector for each transition 345 encoded_vectors = list[list[float]]() 346 for trajectory in trajectories: 347 for transition in trajectory: 348 embedding = get_mode_vector( 349 transition, 350 mode_world_model=mode_world_model, 351 device=device, 352 ) 353 encoded_vectors.append(embedding) 354 X = torch.tensor(encoded_vectors) 355 356 # Normalize embeddings 357 X = StandardScaler().fit_transform(X) 358 if dimensionality_reduce is not None: 359 reducer = umap.UMAP( 360 n_components=dimensionality_reduce, 361 random_state=int.from_bytes(_random.randbytes(3), 'big', signed=False), 362 ) 363 X = reducer.fit_transform(X) 364 365 # Cluster latent vectors 366 cluster = sklearn.cluster.KMeans( 367 n_clusters=cluster_n, 368 random_state=int.from_bytes(_random.randbytes(3), 'big', signed=False), 369 ) 370 labels = list[int](cluster.fit_predict(X)) 371 372 for trajectory in trajectories: 373 assert len(trajectory) > 0 374 375 # Prune short transitions 376 new_labels = list(labels) 377 for i, trajectory in enumerate(trajectories): 378 # Reconstruct sequence of assigned modes 379 modes = list[int]() 380 for j, transition in enumerate(trajectory): 381 location = (i, j) 382 index = location_index[location] 383 mode = labels[index] 384 modes.append(mode) 385 386 # Prune sequence of modes 387 new_modes = prune_short_transitions(modes, min_island_size) 388 389 # Add new labels 390 for j, new_mode in enumerate(new_modes): 391 location = (i, j) 392 index = location_index[location] 393 new_labels[index] = new_mode 394 labels = new_labels 395 396 # Assemble clusters 397 clusters = [ 398 set[tuple[int, int]]() 399 for _ in range(cluster_n) 400 ] 401 for i, cluster_i in enumerate(labels): 402 location = vector_indices[i] 403 clusters[cluster_i].add(location) 404 405 # Remove empty clusters 406 clusters = [ 407 cluster 408 for cluster in clusters 409 if len(cluster) > 0 410 ] 411 return clusters 412 413 414@dataclass 415class OptimizationResult: 416 partition: list[StatePartitionItem] 417 loss_log: list[float] 418 mode_world_model: ModeWorldModel 419 420 421def get_optimized_error_partition( 422 trajectories: list[list[Transition]], 423 hidden_sizes: list[int], 424 learning_rate: float, 425 latent_size: int, 426 iter_n: int, 427 clustering_dimensionality_reduce: int | None, 428 clustering_information_content_regularization_scale: float, 429 clustering_mutual_information_regularization_scale: float, 430 local_model_hyperparameters: dict[str, str | int | float], 431 dt: float, 432 size: int, 433 seed: str, 434 min_island_size: int, 435 batch_size: int, 436 mutual_information_mini_batch_size: int, 437 device: str, 438 verbose: bool, 439) -> OptimizationResult: 440 """Helper function to optimize the partition ensemble by error weighting.""" 441 _random = random.Random(seed) 442 443 # Normalize trajectory vectors 444 if verbose: 445 print("Normalizing trajectories for mode_world_model") 446 447 # Train mode_world_model 448 mode_world_model = get_predictive_residual_encoder( 449 trajectories=trajectories, 450 hidden_sizes=hidden_sizes, 451 
learning_rate=learning_rate, 452 latent_size=latent_size, 453 iter_n=iter_n, 454 seed=str(_random.random()), 455 dt=dt, 456 batch_size=batch_size, 457 information_content_regularization_scale=clustering_information_content_regularization_scale, 458 mutual_information_regularization_scale=clustering_mutual_information_regularization_scale, 459 mutual_information_mini_batch_size=mutual_information_mini_batch_size, 460 device=device, 461 verbose=verbose, 462 ) 463 464 # Cluster latent state 465 clusters = get_clusters( 466 mode_world_model=mode_world_model, 467 trajectories=trajectories, 468 cluster_n=size, 469 min_island_size=min_island_size, 470 dimensionality_reduce=clustering_dimensionality_reduce, 471 seed=str(_random.random()), 472 device=device, 473 ) 474 475 # Assemble partition items 476 partition = list[StatePartitionItem]() 477 for cluster in clusters: 478 479 cluster_sub_trajectories = list[list[Transition]]() 480 481 for i, trajectory in enumerate(trajectories): 482 sub_trajectory = list[Transition]() 483 484 for j, t in enumerate(trajectory): 485 if (i, j) not in cluster: 486 if len(sub_trajectory) > 0: 487 # We have transitioned out of the cluster 488 cluster_sub_trajectories.append(sub_trajectory) 489 sub_trajectory = list[Transition]() 490 else: 491 # We continue building the sub-trajectory 492 sub_trajectory.append(t) 493 494 if len(sub_trajectory) > 1: 495 cluster_sub_trajectories.append(sub_trajectory) 496 497 subset = [ 498 trajectories[i][j] 499 for (i, j) in cluster 500 ] 501 model = get_optimized_model( 502 hyperparameters=local_model_hyperparameters, 503 trajectories=cluster_sub_trajectories, 504 dt=dt, 505 seed=str(_random.random()), 506 verbose=verbose, 507 ) 508 partition_item = StatePartitionItem( 509 local_model=model, 510 subset=subset, 511 hidden_sizes=hidden_sizes, 512 ) 513 partition.append(partition_item) 514 515 # Assemble result 516 loss_log = mode_world_model.loss_log 517 optimization_result = OptimizationResult( 518 partition=partition, 519 loss_log=loss_log, 520 mode_world_model=mode_world_model, 521 ) 522 return optimization_result 523 524 525def get_partition( 526 episodes: list[list[Transition]], 527 hidden_sizes: list[int], 528 learning_rate: float, 529 latent_size: int, 530 mode_model_iter_n: int, 531 clustering_dimensionality_reduce: int | None, 532 clustering_information_content_regularization_scale: float, 533 clustering_mutual_information_regularization_scale: float, 534 dt: float, 535 size: int, 536 min_island_size: int, 537 seed: str, 538 batch_size: int, 539 local_model_hyperparameters: dict[str, str | int | float], 540 mutual_information_mini_batch_size: int, 541 device: str, 542 verbose: bool, 543) -> OptimizationResult: 544 """Returns a partition of the set of transitions in the given episodes. 545 Each subset of the partition has a corresponding neural model of the 546 dynamics in that subset. 547 548 The returned partition will be sorted so that the first subset contains 549 the maximum number of initial transitions from the input episodes. 
550 """ 551 _random = random.Random(seed) 552 553 # Optimize smooth partition 554 optimization_result = get_optimized_error_partition( 555 trajectories=episodes, 556 hidden_sizes=hidden_sizes, 557 learning_rate=learning_rate, 558 iter_n=mode_model_iter_n, 559 dt=dt, 560 latent_size=latent_size, 561 clustering_dimensionality_reduce=clustering_dimensionality_reduce, 562 clustering_information_content_regularization_scale=clustering_information_content_regularization_scale, 563 clustering_mutual_information_regularization_scale=clustering_mutual_information_regularization_scale, 564 local_model_hyperparameters=local_model_hyperparameters, 565 size=size, 566 verbose=verbose, 567 batch_size=batch_size, 568 mutual_information_mini_batch_size=mutual_information_mini_batch_size, 569 min_island_size=min_island_size, 570 device=device, 571 seed=str(_random.random()), 572 ) 573 574 # Sort state partition so that the first item contains the first 575 # transition of the episodes 576 state_partition = get_sorted_partition( 577 episodes=episodes, 578 partition=optimization_result.partition, 579 ) 580 581 optimization_result = OptimizationResult( 582 partition=state_partition, 583 loss_log=optimization_result.loss_log, 584 mode_world_model=optimization_result.mode_world_model, 585 ) 586 return optimization_result
@dataclass
class StatePartitionItem:
    """One subset of a partition of a dataset of transitions, together
    with the local model trained on it."""
    local_model: WorldModel
    subset: list[Transition]
    hidden_sizes: list[int]

    @property
    def local_model_input_size(self) -> int:
        input_size, _ = get_input_output_size(self.subset[0])
        return input_size

    @property
    def local_model_output_size(self) -> int:
        _, output_size = get_input_output_size(self.subset[0])
        return output_size

    @cached_property
    def transition_vectors_as_set_of_tuples(self) -> set[tuple[float, ...]]:
        vectors = set[tuple[float, ...]]()
        for transition in self.subset:
            x = get_vector(transition)
            vectors.add(tuple[float, ...](x.tolist()))
        return vectors
One subset of a partition of a dataset of transitions, together with the local model trained on it.
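A minimal sketch of constructing and inspecting a partition item. It assumes `model` (a trained WorldModel) and `transitions` (a non-empty list[Transition]) were obtained elsewhere; the hidden sizes are illustrative.

item = StatePartitionItem(
    local_model=model,
    subset=transitions,
    hidden_sizes=[64, 64],
)

# Input/output sizes are derived from the first transition in the subset.
print(item.local_model_input_size, item.local_model_output_size)

# The first access builds the cached vector set used for membership
# tests; later accesses reuse it.
vectors = item.transition_vectors_as_set_of_tuples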
def serialize_partition_item(
    partition_item: StatePartitionItem,
    output_zip_path: Path,
):
    """Serialize a partition item to a ZIP file."""
    assert output_zip_path.suffix == ".zip"

    with tempfile.TemporaryDirectory() as tmpdirname:
        output_dir = Path(tmpdirname)

        # Serialize local model
        model_path = output_dir/"model.zip"
        serialize_model(
            model=partition_item.local_model,
            output_zip_path=model_path,
        )

        # Serialize subset of transitions
        subset_dir = output_dir/"transition_subset"
        subset_dir.mkdir()
        transition_directory = list[str]()
        for i, transition in enumerate(partition_item.subset):
            transition_zip = f"transition_{i}.zip"
            serialize(transition, subset_dir/transition_zip)
            transition_directory.append(str(transition_zip))

        # Serialize directory
        transition_directory_json_path = output_dir/"transition_directory.json"
        with open(transition_directory_json_path, "wt") as fp:
            json.dump(transition_directory, fp, indent=2)

        # Serialize hidden sizes
        hidden_sizes_json_path = output_dir/"hidden_sizes.json"
        with open(hidden_sizes_json_path, "wt") as fp:
            json.dump(partition_item.hidden_sizes, fp, indent=2)

        # ZIP directory
        _ = shutil.make_archive(
            str(output_zip_path.with_suffix("")),
            'zip',
            output_dir
        )
Serialize a partition item to a ZIP file.
def deserialize_partition_item(
    zip_path: Path,
) -> StatePartitionItem:
    """Deserialize a partition item from a ZIP file."""
    with tempfile.TemporaryDirectory() as tmpdirname:
        output_dir = Path(tmpdirname)

        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(output_dir)

        # Load model
        model_path = output_dir/"model.zip"
        local_model = deserialize_model(model_path)

        # Load transition directory
        transition_directory_json_path = output_dir/"transition_directory.json"
        with open(transition_directory_json_path, "rt") as fp:
            transition_directory = list[str](json.load(fp))

        # Load transitions concurrently
        futures = list()
        with ThreadPoolExecutor() as executor:
            for transition_path in transition_directory:
                transition_zip_path = output_dir/"transition_subset"/transition_path
                future = executor.submit(
                    deserialize,
                    transition_zip_path,
                )
                futures.append(future)

            subset = list[Transition]()
            for i, future in enumerate(futures):
                print(f"Deserializing transition {i}/{len(futures)}")
                transition = future.result()
                subset.append(transition)

        # Load hidden sizes
        hidden_sizes_json_path = output_dir/"hidden_sizes.json"
        with open(hidden_sizes_json_path, "rt") as fp:
            hidden_sizes = list[int](json.load(fp))

        item = StatePartitionItem(
            local_model=local_model,
            subset=subset,
            hidden_sizes=hidden_sizes,
        )
        return item
Deserialize a partition item from a ZIP file.
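A round-trip sketch for the two functions above, assuming `item` is a StatePartitionItem obtained elsewhere (e.g. one element of the partition returned by get_partition):

from pathlib import Path

out_path = Path("item_0.zip")  # the .zip suffix is required
serialize_partition_item(partition_item=item, output_zip_path=out_path)

# Restore the model, transition subset, and hidden sizes from the archive.
restored = deserialize_partition_item(out_path)
assert restored.hidden_sizes == item.hidden_sizes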
def serialize_partition(
    partition: list[StatePartitionItem],
    output_zip_path: Path,
):
    """Serialize a partition (a list of partition items) to a ZIP file."""
    assert output_zip_path.suffix == ".zip"

    with tempfile.TemporaryDirectory() as tmpdirname:
        output_dir = Path(tmpdirname)

        # Serialize each item
        item_directory = list[str]()
        for i, item in enumerate(partition):
            item_path = f"item_{i}.zip"
            serialize_partition_item(
                partition_item=item,
                output_zip_path=output_dir/item_path,
            )
            item_directory.append(item_path)

        # Serialize directory
        item_directory_json_path = output_dir/"item_directory.json"
        with open(item_directory_json_path, "wt") as fp:
            json.dump(item_directory, fp, indent=2)

        # ZIP directory
        _ = shutil.make_archive(
            str(output_zip_path.with_suffix("")),
            'zip',
            output_dir
        )
def deserialize_partition(zip_path: Path) -> list[StatePartitionItem]:
    """Deserialize a partition from a ZIP file."""
    items = list[StatePartitionItem]()

    with tempfile.TemporaryDirectory() as tmpdirname:
        output_dir = Path(tmpdirname)

        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(output_dir)

        # Load directory
        item_directory_json_path = output_dir/"item_directory.json"
        with open(item_directory_json_path, "rt") as fp:
            item_paths = list[str](json.load(fp))

        # Load each partition item
        for item_path in item_paths:
            item_zip_path = output_dir/item_path
            item = deserialize_partition_item(item_zip_path)
            items.append(item)

    return items
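The same round trip works for a whole partition. A sketch, assuming `partition` is the list[StatePartitionItem] produced by get_partition:

from pathlib import Path

serialize_partition(
    partition=partition,
    output_zip_path=Path("partition.zip"),
)
restored = deserialize_partition(Path("partition.zip"))
assert len(restored) == len(partition)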
def deserialize_partition_item_local_model(
    zip_path: Path,
) -> WorldModel:
    """Deserialize only the local model from a partition item ZIP file."""
    with tempfile.TemporaryDirectory() as tmpdirname:
        output_dir = Path(tmpdirname)

        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extract("model.zip", path=output_dir)

        # Load model
        model_path = output_dir/"model.zip"
        local_model = deserialize_model(model_path)
        return local_model
Deserialize only the local model from a partition item ZIP file.
def deserialize_partition_local_models(zip_path: Path) -> list[WorldModel]:
    """Deserialize only the local models from a partition ZIP file."""
    models = list[WorldModel]()

    with tempfile.TemporaryDirectory() as tmpdirname:
        output_dir = Path(tmpdirname)

        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(output_dir)

        # Load directory
        item_directory_json_path = output_dir/"item_directory.json"
        with open(item_directory_json_path, "rt") as fp:
            item_paths = list[str](json.load(fp))

        # Load each item's local model
        for item_path in item_paths:
            item_zip_path = output_dir/item_path
            model = deserialize_partition_item_local_model(item_zip_path)
            models.append(model)

    return models
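When only the dynamics models are needed (for example, for rollouts), skipping the transition subsets is much faster, since only model.zip is extracted from each item. A sketch, assuming "partition.zip" was written by serialize_partition:

from pathlib import Path

models = deserialize_partition_local_models(Path("partition.zip"))
print(f"Loaded {len(models)} local models")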
class PartitionSortingError(Exception):
    pass

Exception type reserved for errors encountered while sorting a partition.
def item_contains_transition(
    item: StatePartitionItem,
    transition: Transition,
) -> bool:
    """Return whether the transition appears in the partition item."""
    x = tuple(get_vector(transition).tolist())
    return x in item.transition_vectors_as_set_of_tuples
Return whether the transition appears in the partition item.
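Membership is decided by exact equality of the transition's feature vector against the item's cached vector set, not by object identity. A minimal sketch, assuming `item` and `transition` were obtained elsewhere:

if item_contains_transition(item=item, transition=transition):
    print("The transition belongs to this partition item")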
def get_initial_transition_n(
    item: StatePartitionItem,
    episodes: list[list[Transition]],
) -> int:
    """Return the number of episodes whose initial transition appears
    in the partition item."""
    # Extract initial transitions
    initial_transitions = [
        episode[0]
        for episode in episodes
        if len(episode) > 0
    ]

    # Keep only the initial transitions that appear in the partition item
    occurrences = [
        initial_transition
        for initial_transition in initial_transitions
        if item_contains_transition(
            item=item,
            transition=initial_transition,
        )
    ]
    return len(occurrences)
Return the number of episodes whose initial transition appears in the partition item.
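A short sketch, assuming `partition` and `episodes` come from get_partition below: count, for each partition item, how many episodes start inside it.

counts = [
    get_initial_transition_n(item=item, episodes=episodes)
    for item in partition
]
# The item with the largest count is the one get_sorted_partition
# places first.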
def get_sorted_partition(
    partition: list[StatePartitionItem],
    episodes: list[list[Transition]],
) -> list[StatePartitionItem]:
    """Sort the partition so that the partition item with the most
    initial transitions is first."""
    # Sort in descending order of initial-transition count
    sorted_partition = list(reversed(sorted(
        partition,
        key=lambda item: get_initial_transition_n(
            item=item,
            episodes=episodes,
        )
    )))
    return sorted_partition
Sort the partition so that the partition item with the most initial transitions is first.
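A minimal usage sketch, assuming `partition` and `episodes` as above:

sorted_partition = get_sorted_partition(
    partition=partition,
    episodes=episodes,
)
# The first item is the mode most episodes start in, which makes it a
# natural choice of initial mode.
initial_item = sorted_partition[0]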
def get_partition_modes(
    trajectory: list[Transition],
    partition: list[StatePartitionItem],
) -> list[int]:
    """Return, for each transition in the trajectory, the index of the
    partition item that contains it."""
    modes = list[int]()
    for transition in trajectory:
        index = None
        for i, item in enumerate(partition):
            if item_contains_transition(item, transition):
                index = i
        assert index is not None, "Partition doesn't contain transition!"
        modes.append(index)
    return modes
Return, for each transition in the trajectory, the index of the partition item that contains it.
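A short sketch labeling one episode with its mode sequence, assuming `episodes` and `partition` as above. The result holds one partition-item index per transition, e.g. [0, 0, 0, 2, 2, 1].

modes = get_partition_modes(
    trajectory=episodes[0],
    partition=partition,
)
assert len(modes) == len(episodes[0])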
def get_clusters(
    mode_world_model: ModeWorldModel,
    trajectories: list[list[Transition]],
    cluster_n: int,
    min_island_size: int,
    dimensionality_reduce: int | None,
    seed: str,
    device: str,
) -> list[set[tuple[int, int]]]:
    """Partition the given dataset of transitions into disjoint subsets.
    Each returned set contains the (trajectory, step) index pairs of the
    transitions assigned to that cluster."""
    _random = random.Random(seed)

    # Bookkeeping: map each (trajectory, step) location to a flat index
    vector_indices = list[tuple[int, int]]()
    location_index = dict[tuple[int, int], int]()
    for i, trajectory in enumerate(trajectories):
        for j, _ in enumerate(trajectory):
            location = (i, j)
            index = len(vector_indices)
            vector_indices.append(location)
            location_index[location] = index

    # Get the latent vector for each transition
    encoded_vectors = list[list[float]]()
    for trajectory in trajectories:
        for transition in trajectory:
            embedding = get_mode_vector(
                transition,
                mode_world_model=mode_world_model,
                device=device,
            )
            encoded_vectors.append(embedding)
    X = torch.tensor(encoded_vectors)

    # Normalize embeddings
    X = StandardScaler().fit_transform(X)
    if dimensionality_reduce is not None:
        reducer = umap.UMAP(
            n_components=dimensionality_reduce,
            random_state=int.from_bytes(_random.randbytes(3), 'big', signed=False),
        )
        X = reducer.fit_transform(X)

    # Cluster latent vectors
    cluster = sklearn.cluster.KMeans(
        n_clusters=cluster_n,
        random_state=int.from_bytes(_random.randbytes(3), 'big', signed=False),
    )
    labels = list[int](cluster.fit_predict(X))

    for trajectory in trajectories:
        assert len(trajectory) > 0

    # Prune short mode islands
    new_labels = list(labels)
    for i, trajectory in enumerate(trajectories):
        # Reconstruct the sequence of assigned modes
        modes = list[int]()
        for j, transition in enumerate(trajectory):
            location = (i, j)
            index = location_index[location]
            mode = labels[index]
            modes.append(mode)

        # Prune the sequence of modes
        new_modes = prune_short_transitions(modes, min_island_size)

        # Record the new labels
        for j, new_mode in enumerate(new_modes):
            location = (i, j)
            index = location_index[location]
            new_labels[index] = new_mode
    labels = new_labels

    # Assemble clusters
    clusters = [
        set[tuple[int, int]]()
        for _ in range(cluster_n)
    ]
    for i, cluster_i in enumerate(labels):
        location = vector_indices[i]
        clusters[cluster_i].add(location)

    # Remove empty clusters
    clusters = [
        cluster
        for cluster in clusters
        if len(cluster) > 0
    ]
    return clusters
Partition the given dataset of transitions into disjoint subsets. Each returned set contains the (trajectory, step) index pairs of the transitions assigned to that cluster.
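A sketch of how the returned clusters map back to transitions, assuming `mode_world_model` and `trajectories` were prepared as in get_optimized_error_partition below; the argument values are illustrative.

clusters = get_clusters(
    mode_world_model=mode_world_model,
    trajectories=trajectories,
    cluster_n=4,
    min_island_size=5,
    dimensionality_reduce=2,
    seed="0",
    device="cpu",
)

# Each cluster is a set of (trajectory, step) index pairs.
for k, cluster in enumerate(clusters):
    i, j = next(iter(cluster))
    transition = trajectories[i][j]  # one transition assigned to cluster k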
@dataclass
class OptimizationResult:
    """Result of partition optimization: the partition itself, the mode
    model's training loss log, and the trained mode world model."""
    partition: list[StatePartitionItem]
    loss_log: list[float]
    mode_world_model: ModeWorldModel
def get_optimized_error_partition(
    trajectories: list[list[Transition]],
    hidden_sizes: list[int],
    learning_rate: float,
    latent_size: int,
    iter_n: int,
    clustering_dimensionality_reduce: int | None,
    clustering_information_content_regularization_scale: float,
    clustering_mutual_information_regularization_scale: float,
    local_model_hyperparameters: dict[str, str | int | float],
    dt: float,
    size: int,
    seed: str,
    min_island_size: int,
    batch_size: int,
    mutual_information_mini_batch_size: int,
    device: str,
    verbose: bool,
) -> OptimizationResult:
    """Helper that builds a partition by training a predictive residual
    encoder, clustering its latent codes, and fitting a local model to
    each cluster."""
    _random = random.Random(seed)

    if verbose:
        print("Normalizing trajectories for mode_world_model")

    # Train mode_world_model
    mode_world_model = get_predictive_residual_encoder(
        trajectories=trajectories,
        hidden_sizes=hidden_sizes,
        learning_rate=learning_rate,
        latent_size=latent_size,
        iter_n=iter_n,
        seed=str(_random.random()),
        dt=dt,
        batch_size=batch_size,
        information_content_regularization_scale=clustering_information_content_regularization_scale,
        mutual_information_regularization_scale=clustering_mutual_information_regularization_scale,
        mutual_information_mini_batch_size=mutual_information_mini_batch_size,
        device=device,
        verbose=verbose,
    )

    # Cluster latent state
    clusters = get_clusters(
        mode_world_model=mode_world_model,
        trajectories=trajectories,
        cluster_n=size,
        min_island_size=min_island_size,
        dimensionality_reduce=clustering_dimensionality_reduce,
        seed=str(_random.random()),
        device=device,
    )

    # Assemble partition items
    partition = list[StatePartitionItem]()
    for cluster in clusters:
        # Split each trajectory into maximal sub-trajectories that stay
        # inside the cluster
        cluster_sub_trajectories = list[list[Transition]]()
        for i, trajectory in enumerate(trajectories):
            sub_trajectory = list[Transition]()

            for j, t in enumerate(trajectory):
                if (i, j) not in cluster:
                    if len(sub_trajectory) > 0:
                        # We have transitioned out of the cluster
                        cluster_sub_trajectories.append(sub_trajectory)
                        sub_trajectory = list[Transition]()
                else:
                    # We continue building the sub-trajectory
                    sub_trajectory.append(t)

            if len(sub_trajectory) > 1:
                cluster_sub_trajectories.append(sub_trajectory)

        subset = [
            trajectories[i][j]
            for (i, j) in cluster
        ]
        model = get_optimized_model(
            hyperparameters=local_model_hyperparameters,
            trajectories=cluster_sub_trajectories,
            dt=dt,
            seed=str(_random.random()),
            verbose=verbose,
        )
        partition_item = StatePartitionItem(
            local_model=model,
            subset=subset,
            hidden_sizes=hidden_sizes,
        )
        partition.append(partition_item)

    # Assemble result
    loss_log = mode_world_model.loss_log
    optimization_result = OptimizationResult(
        partition=partition,
        loss_log=loss_log,
        mode_world_model=mode_world_model,
    )
    return optimization_result
Helper that builds a partition by training a predictive residual encoder, clustering its latent codes, and fitting a local model to each cluster.
def get_partition(
    episodes: list[list[Transition]],
    hidden_sizes: list[int],
    learning_rate: float,
    latent_size: int,
    mode_model_iter_n: int,
    clustering_dimensionality_reduce: int | None,
    clustering_information_content_regularization_scale: float,
    clustering_mutual_information_regularization_scale: float,
    dt: float,
    size: int,
    min_island_size: int,
    seed: str,
    batch_size: int,
    local_model_hyperparameters: dict[str, str | int | float],
    mutual_information_mini_batch_size: int,
    device: str,
    verbose: bool,
) -> OptimizationResult:
    """Return a partition of the set of transitions in the given episodes.
    Each subset of the partition has a corresponding neural model of the
    dynamics in that subset.

    The returned partition is sorted so that the first subset contains
    the maximum number of initial transitions from the input episodes.
    """
    _random = random.Random(seed)

    # Optimize smooth partition
    optimization_result = get_optimized_error_partition(
        trajectories=episodes,
        hidden_sizes=hidden_sizes,
        learning_rate=learning_rate,
        iter_n=mode_model_iter_n,
        dt=dt,
        latent_size=latent_size,
        clustering_dimensionality_reduce=clustering_dimensionality_reduce,
        clustering_information_content_regularization_scale=clustering_information_content_regularization_scale,
        clustering_mutual_information_regularization_scale=clustering_mutual_information_regularization_scale,
        local_model_hyperparameters=local_model_hyperparameters,
        size=size,
        verbose=verbose,
        batch_size=batch_size,
        mutual_information_mini_batch_size=mutual_information_mini_batch_size,
        min_island_size=min_island_size,
        device=device,
        seed=str(_random.random()),
    )

    # Sort the partition so that the first item contains the most
    # initial transitions of the episodes
    state_partition = get_sorted_partition(
        episodes=episodes,
        partition=optimization_result.partition,
    )

    optimization_result = OptimizationResult(
        partition=state_partition,
        loss_log=optimization_result.loss_log,
        mode_world_model=optimization_result.mode_world_model,
    )
    return optimization_result
Returns a partition of the set of transitions in the given episodes. Each subset of the partition has a corresponding neural model of the dynamics in that subset.
The returned partition will be sorted so that the first subset contains the maximum number of initial transitions from the input episodes.
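An end-to-end sketch of the module's main entry point. The episode loader is hypothetical and the hyperparameter values (including the keys of local_model_hyperparameters) are illustrative assumptions; only the keyword names come from the signature above.

from pathlib import Path
from swmpo.partition import get_partition, serialize_partition

episodes = load_episodes()  # hypothetical loader: list[list[Transition]]

result = get_partition(
    episodes=episodes,
    hidden_sizes=[64, 64],
    learning_rate=1e-3,
    latent_size=8,
    mode_model_iter_n=1000,
    clustering_dimensionality_reduce=2,
    clustering_information_content_regularization_scale=0.1,
    clustering_mutual_information_regularization_scale=0.1,
    dt=0.01,
    size=4,  # number of requested subsets (modes)
    min_island_size=5,
    seed="42",
    batch_size=256,
    local_model_hyperparameters={"hidden_size": 64},  # illustrative keys
    mutual_information_mini_batch_size=64,
    device="cpu",
    verbose=True,
)

# The result bundles the sorted partition, the mode model's training
# loss log, and the mode world model itself.
partition = result.partition
serialize_partition(partition=partition, output_zip_path=Path("partition.zip"))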