swmpo.partition

Partition of a dataset of state transitions guided by synthesis of local neural models.

  1"""Partition of a dataset of state transitions guided by synthesis of
  2local neural models."""
  3from dataclasses import dataclass
  4from swmpo.transition import Transition
  5from swmpo.transition import get_vector
  6from swmpo.transition import serialize
  7from swmpo.transition import deserialize
  8from swmpo.world_models.model import get_input_output_size
  9from swmpo.world_models.world_model import serialize_model
 10from swmpo.world_models.world_model import deserialize_model
 11from swmpo.world_models.world_model import WorldModel
 12from swmpo.world_models.world_model import get_optimized_model
 13from swmpo.world_models.mode_world_model import ModeWorldModel
 14from swmpo.world_models.mode_world_model import get_predictive_residual_encoder
 15from swmpo.world_models.mode_world_model import get_mode_vector
 16from swmpo.transition_prunning.island_prunning import prune_short_transitions
 17from sklearn.preprocessing import StandardScaler
 18from functools import cached_property
 19import sklearn.cluster
 20import random
 21import torch
 22import tempfile
 23from pathlib import Path
 24import shutil
 25import umap
 26import json
 27import zipfile
 28from concurrent.futures import ThreadPoolExecutor
 29
 30# Avoid pytorch from doing threading. This is so that the script doesn't
 31# take over the computer's resources. You can remove these lines if not running
 32# on a lab computer.
 33torch.set_num_threads(1)
 34
 35
@dataclass
class StatePartitionItem:
    """One element of a partition: a subset of the transition dataset
    together with the local neural model synthesized for it."""
    local_model: WorldModel
    subset: list[Transition]
    hidden_sizes: list[int]

    @property
    def local_model_input_size(self) -> int:
        input_size, _ = get_input_output_size(self.subset[0])
        return input_size

    @property
    def local_model_output_size(self) -> int:
        _, output_size = get_input_output_size(self.subset[0])
        return output_size

    @cached_property
    def transition_vectors_as_set_of_tuples(self) -> set[tuple[float, ...]]:
        """The subset's transition vectors, cached as a set of tuples for
        fast membership tests."""
        vectors = set[tuple[float, ...]]()
        for transition in self.subset:
            x = get_vector(transition)
            vectors.add(tuple[float, ...](x.tolist()))
        return vectors


def serialize_partition_item(
    partition_item: StatePartitionItem,
    output_zip_path: Path,
):
    """Serialize a partition item to a ZIP file."""
    assert output_zip_path.suffix == ".zip"

    with tempfile.TemporaryDirectory() as tmpdirname:
        output_dir = Path(tmpdirname)

        # Serialize local model
        model_path = output_dir/"model.zip"
        serialize_model(
            model=partition_item.local_model,
            output_zip_path=model_path,
        )

        # Serialize subset of transitions
        subset_dir = output_dir/"transition_subset"
        subset_dir.mkdir()
        transition_directory = list[str]()
        for i, transition in enumerate(partition_item.subset):
            transition_zip = f"transition_{i}.zip"
            serialize(transition, subset_dir/transition_zip)
            transition_directory.append(str(transition_zip))

        # Serialize directory
        transition_directory_json_path = output_dir/"transition_directory.json"
        with open(transition_directory_json_path, "wt") as fp:
            json.dump(transition_directory, fp, indent=2)

        # Serialize hidden sizes
        hidden_sizes_json_path = output_dir/"hidden_sizes.json"
        with open(hidden_sizes_json_path, "wt") as fp:
            json.dump(partition_item.hidden_sizes, fp, indent=2)

        # ZIP directory
        _ = shutil.make_archive(
            str(output_zip_path.with_suffix("")),
            'zip',
            output_dir
        )


def deserialize_partition_item(
    zip_path: Path,
) -> StatePartitionItem:
    """Deserialize a partition item from a ZIP file."""
    with tempfile.TemporaryDirectory() as tmpdirname:
        output_dir = Path(tmpdirname)

        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(output_dir)

        # Load model
        model_path = output_dir/"model.zip"
        local_model = deserialize_model(model_path)

        # Load transition directory
        transition_directory_json_path = output_dir/"transition_directory.json"
        with open(transition_directory_json_path, "rt") as fp:
            transition_directory = list[str](json.load(fp))

        # Load transitions in parallel
        futures = list()
        with ThreadPoolExecutor() as executor:
            for transition_path in transition_directory:
                transition_zip_path = output_dir/"transition_subset"/transition_path
                future = executor.submit(
                    deserialize,
                    transition_zip_path,
                )
                futures.append(future)

        subset = list[Transition]()
        for i, future in enumerate(futures):
            print(f"Deserializing transition {i + 1}/{len(futures)}")
            transition = future.result()
            subset.append(transition)

        # Load hidden sizes
        hidden_sizes_json_path = output_dir/"hidden_sizes.json"
        with open(hidden_sizes_json_path, "rt") as fp:
            hidden_sizes = list[int](json.load(fp))

    item = StatePartitionItem(
        local_model=local_model,
        subset=subset,
        hidden_sizes=hidden_sizes,
    )
    return item


def serialize_partition(
    partition: list[StatePartitionItem],
    output_zip_path: Path,
):
    """Serialize a partition to a ZIP file."""
    assert output_zip_path.suffix == ".zip"

    with tempfile.TemporaryDirectory() as tmpdirname:
        output_dir = Path(tmpdirname)

        # Serialize each item
        item_directory = list[str]()
        for i, item in enumerate(partition):
            item_path = f"item_{i}.zip"
            serialize_partition_item(
                partition_item=item,
                output_zip_path=output_dir/item_path,
            )
            item_directory.append(item_path)

        # Serialize directory
        item_directory_json_path = output_dir/"item_directory.json"
        with open(item_directory_json_path, "wt") as fp:
            json.dump(item_directory, fp, indent=2)

        # ZIP directory
        _ = shutil.make_archive(
            str(output_zip_path.with_suffix("")),
            'zip',
            output_dir
        )


def deserialize_partition(zip_path: Path) -> list[StatePartitionItem]:
    """Deserialize a partition from a ZIP file."""
    items = list[StatePartitionItem]()

    with tempfile.TemporaryDirectory() as tmpdirname:
        output_dir = Path(tmpdirname)

        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(output_dir)

        # Load directory
        item_directory_json_path = output_dir/"item_directory.json"
        with open(item_directory_json_path, "rt") as fp:
            item_paths = list[str](json.load(fp))

        # Load each partition item
        for item_path in item_paths:
            item_zip_path = output_dir/item_path
            item = deserialize_partition_item(item_zip_path)
            items.append(item)

    return items


def deserialize_partition_item_local_model(
    zip_path: Path,
) -> WorldModel:
    """Deserialize only the local model of a partition item, without
    loading its transition subset."""
    with tempfile.TemporaryDirectory() as tmpdirname:
        output_dir = Path(tmpdirname)

        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extract("model.zip", path=output_dir)

        # Load model
        model_path = output_dir/"model.zip"
        local_model = deserialize_model(model_path)
    return local_model


def deserialize_partition_local_models(zip_path: Path) -> list[WorldModel]:
    """Deserialize only the local models of each partition item in a
    partition ZIP file."""
    models = list[WorldModel]()

    with tempfile.TemporaryDirectory() as tmpdirname:
        output_dir = Path(tmpdirname)

        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(output_dir)

        # Load directory
        item_directory_json_path = output_dir/"item_directory.json"
        with open(item_directory_json_path, "rt") as fp:
            item_paths = list[str](json.load(fp))

        # Load the local model of each partition item
        for item_path in item_paths:
            item_zip_path = output_dir/item_path
            model = deserialize_partition_item_local_model(item_zip_path)
            models.append(model)

    return models


class PartitionSortingError(Exception):
    pass


def item_contains_transition(
        item: StatePartitionItem,
        transition: Transition,
        ) -> bool:
    """Return whether the transition appears in the partition item."""
    x = tuple(get_vector(transition).tolist())
    return x in item.transition_vectors_as_set_of_tuples


def get_initial_transition_n(
        item: StatePartitionItem,
        episodes: list[list[Transition]],
        ) -> int:
    """Return the number of episodes whose initial transition appears in
    the partition item."""
    # Extract initial transitions
    initial_transitions = [
        episode[0]
        for episode in episodes
        if len(episode) > 0
    ]

    # Keep only the initial transitions that appear in the partition item
    occurrences = [
        initial_transition
        for initial_transition in initial_transitions
        if item_contains_transition(
            item=item,
            transition=initial_transition
        )
    ]
    return len(occurrences)


def get_sorted_partition(
    partition: list[StatePartitionItem],
    episodes: list[list[Transition]],
) -> list[StatePartitionItem]:
    """Sort the partition so that the partition item with the most
    initial transitions is first."""
    # Sort items by descending count of initial transitions
    sorted_partition = list(reversed(sorted(
        partition,
        key=lambda item: get_initial_transition_n(
            item=item,
            episodes=episodes
        )
    )))
    return sorted_partition


def get_partition_modes(
    trajectory: list[Transition],
    partition: list[StatePartitionItem],
) -> list[int]:
    """Return, for each transition in the trajectory, the index of the
    partition item that contains it."""
    modes = list[int]()
    for transition in trajectory:
        index = None
        for i, item in enumerate(partition):
            if item_contains_transition(item, transition):
                index = i
        assert index is not None, "Partition doesn't contain transition!"
        modes.append(index)
    return modes


def get_clusters(
    mode_world_model: ModeWorldModel,
    trajectories: list[list[Transition]],
    cluster_n: int,
    min_island_size: int,
    dimensionality_reduce: int | None,
    seed: str,
    device: str,
) -> list[set[tuple[int, int]]]:
    """Partition the given dataset of transitions into disjoint subsets.
    Each returned set contains (trajectory index, transition index) pairs
    identifying its transitions."""
    _random = random.Random(seed)

    for trajectory in trajectories:
        assert len(trajectory) > 0

    # Bookkeeping: map flat vector indices to (trajectory, transition)
    # locations and back
    vector_indices = list[tuple[int, int]]()
    location_index = dict[tuple[int, int], int]()
    for i, trajectory in enumerate(trajectories):
        for j, _ in enumerate(trajectory):
            location = (i, j)
            index = len(vector_indices)
            vector_indices.append(location)
            location_index[location] = index

    # Get the latent vector for each transition
    encoded_vectors = list[list[float]]()
    for trajectory in trajectories:
        for transition in trajectory:
            embedding = get_mode_vector(
                transition,
                mode_world_model=mode_world_model,
                device=device,
            )
            encoded_vectors.append(embedding)
    X = torch.tensor(encoded_vectors)

    # Normalize embeddings and optionally reduce their dimensionality
    X = StandardScaler().fit_transform(X)
    if dimensionality_reduce is not None:
        reducer = umap.UMAP(
            n_components=dimensionality_reduce,
            random_state=int.from_bytes(_random.randbytes(3), 'big', signed=False),
        )
        X = reducer.fit_transform(X)

    # Cluster latent vectors
    cluster = sklearn.cluster.KMeans(
        n_clusters=cluster_n,
        random_state=int.from_bytes(_random.randbytes(3), 'big', signed=False),
    )
    labels = list[int](cluster.fit_predict(X))

    # Prune short mode islands
    new_labels = list(labels)
    for i, trajectory in enumerate(trajectories):
        # Reconstruct sequence of assigned modes
        modes = list[int]()
        for j, _ in enumerate(trajectory):
            location = (i, j)
            index = location_index[location]
            mode = labels[index]
            modes.append(mode)

        # Prune sequence of modes
        new_modes = prune_short_transitions(modes, min_island_size)

        # Add new labels
        for j, new_mode in enumerate(new_modes):
            location = (i, j)
            index = location_index[location]
            new_labels[index] = new_mode
    labels = new_labels

    # Assemble clusters
    clusters = [
        set[tuple[int, int]]()
        for _ in range(cluster_n)
    ]
    for i, cluster_i in enumerate(labels):
        location = vector_indices[i]
        clusters[cluster_i].add(location)

    # Remove empty clusters
    clusters = [
        cluster
        for cluster in clusters
        if len(cluster) > 0
    ]
    return clusters


@dataclass
class OptimizationResult:
    partition: list[StatePartitionItem]
    loss_log: list[float]
    mode_world_model: ModeWorldModel


def get_optimized_error_partition(
    trajectories: list[list[Transition]],
    hidden_sizes: list[int],
    learning_rate: float,
    latent_size: int,
    iter_n: int,
    clustering_dimensionality_reduce: int | None,
    clustering_information_content_regularization_scale: float,
    clustering_mutual_information_regularization_scale: float,
    local_model_hyperparameters: dict[str, str | int | float],
    dt: float,
    size: int,
    seed: str,
    min_island_size: int,
    batch_size: int,
    mutual_information_mini_batch_size: int,
    device: str,
    verbose: bool,
) -> OptimizationResult:
    """Helper function to optimize the partition ensemble by error weighting."""
    _random = random.Random(seed)

    if verbose:
        print("Normalizing trajectories for mode_world_model")

    # Train mode_world_model
    mode_world_model = get_predictive_residual_encoder(
        trajectories=trajectories,
        hidden_sizes=hidden_sizes,
        learning_rate=learning_rate,
        latent_size=latent_size,
        iter_n=iter_n,
        seed=str(_random.random()),
        dt=dt,
        batch_size=batch_size,
        information_content_regularization_scale=clustering_information_content_regularization_scale,
        mutual_information_regularization_scale=clustering_mutual_information_regularization_scale,
        mutual_information_mini_batch_size=mutual_information_mini_batch_size,
        device=device,
        verbose=verbose,
    )

    # Cluster latent state
    clusters = get_clusters(
        mode_world_model=mode_world_model,
        trajectories=trajectories,
        cluster_n=size,
        min_island_size=min_island_size,
        dimensionality_reduce=clustering_dimensionality_reduce,
        seed=str(_random.random()),
        device=device,
    )

    # Assemble partition items
    partition = list[StatePartitionItem]()
    for cluster in clusters:
        # Collect the contiguous sub-trajectories that stay inside the cluster
        cluster_sub_trajectories = list[list[Transition]]()

        for i, trajectory in enumerate(trajectories):
            sub_trajectory = list[Transition]()

            for j, t in enumerate(trajectory):
                if (i, j) not in cluster:
                    if len(sub_trajectory) > 0:
                        # We have transitioned out of the cluster
                        cluster_sub_trajectories.append(sub_trajectory)
                        sub_trajectory = list[Transition]()
                else:
                    # We continue building the sub-trajectory
                    sub_trajectory.append(t)

            if len(sub_trajectory) > 1:
                cluster_sub_trajectories.append(sub_trajectory)

        subset = [
            trajectories[i][j]
            for (i, j) in cluster
        ]
        model = get_optimized_model(
            hyperparameters=local_model_hyperparameters,
            trajectories=cluster_sub_trajectories,
            dt=dt,
            seed=str(_random.random()),
            verbose=verbose,
        )
        partition_item = StatePartitionItem(
            local_model=model,
            subset=subset,
            hidden_sizes=hidden_sizes,
        )
        partition.append(partition_item)

    # Assemble result
    loss_log = mode_world_model.loss_log
    optimization_result = OptimizationResult(
        partition=partition,
        loss_log=loss_log,
        mode_world_model=mode_world_model,
    )
    return optimization_result


def get_partition(
    episodes: list[list[Transition]],
    hidden_sizes: list[int],
    learning_rate: float,
    latent_size: int,
    mode_model_iter_n: int,
    clustering_dimensionality_reduce: int | None,
    clustering_information_content_regularization_scale: float,
    clustering_mutual_information_regularization_scale: float,
    dt: float,
    size: int,
    min_island_size: int,
    seed: str,
    batch_size: int,
    local_model_hyperparameters: dict[str, str | int | float],
    mutual_information_mini_batch_size: int,
    device: str,
    verbose: bool,
) -> OptimizationResult:
    """Returns a partition of the set of transitions in the given episodes.
    Each subset of the partition has a corresponding neural model of the
    dynamics in that subset.

    The returned partition will be sorted so that the first subset contains
    the maximum number of initial transitions from the input episodes.
    """
    _random = random.Random(seed)

    # Optimize smooth partition
    optimization_result = get_optimized_error_partition(
        trajectories=episodes,
        hidden_sizes=hidden_sizes,
        learning_rate=learning_rate,
        iter_n=mode_model_iter_n,
        dt=dt,
        latent_size=latent_size,
        clustering_dimensionality_reduce=clustering_dimensionality_reduce,
        clustering_information_content_regularization_scale=clustering_information_content_regularization_scale,
        clustering_mutual_information_regularization_scale=clustering_mutual_information_regularization_scale,
        local_model_hyperparameters=local_model_hyperparameters,
        size=size,
        verbose=verbose,
        batch_size=batch_size,
        mutual_information_mini_batch_size=mutual_information_mini_batch_size,
        min_island_size=min_island_size,
        device=device,
        seed=str(_random.random()),
    )

    # Sort the partition so that the first item contains the most initial
    # transitions of the episodes
    state_partition = get_sorted_partition(
        episodes=episodes,
        partition=optimization_result.partition,
    )

    optimization_result = OptimizationResult(
        partition=state_partition,
        loss_log=optimization_result.loss_log,
        mode_world_model=optimization_result.mode_world_model,
    )
    return optimization_result

@dataclass
class StatePartitionItem:

One element of a partition: a subset of the transition dataset together with the local neural model synthesized for it.

StatePartitionItem(local_model: swmpo.world_models.world_model.WorldModel, subset: list[swmpo.transition.Transition], hidden_sizes: list[int])

local_model: swmpo.world_models.world_model.WorldModel
subset: list[swmpo.transition.Transition]
hidden_sizes: list[int]

local_model_input_size: int
The input size of the local model, derived from the first transition in the subset.

local_model_output_size: int
The output size of the local model, derived from the first transition in the subset.

transition_vectors_as_set_of_tuples: set[tuple[float, ...]]
The subset's transition vectors, cached as a set of tuples for fast membership tests.

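A quick usage sketch; `item` is assumed to be a StatePartitionItem taken from an existing partition (the variable name is illustrative):

    # Hypothetical: `item` is one StatePartitionItem of a partition.
    in_size = item.local_model_input_size
    out_size = item.local_model_output_size
    print(f"The local model maps R^{in_size} to R^{out_size}")
    print(f"The subset holds {len(item.subset)} transitions")
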
def serialize_partition_item(partition_item: StatePartitionItem, output_zip_path: pathlib.Path):

Serialize a partition item to a ZIP file.

def deserialize_partition_item(zip_path: pathlib.Path) -> StatePartitionItem:

Deserialize a partition item from a ZIP file.
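
A minimal round-trip sketch for the two functions above; `item` is assumed to be an existing StatePartitionItem, and the file name is illustrative:

    from pathlib import Path
    from swmpo.partition import (
        serialize_partition_item,
        deserialize_partition_item,
    )

    zip_path = Path("partition_item.zip")  # must end in ".zip"
    serialize_partition_item(partition_item=item, output_zip_path=zip_path)

    # Reload the item; the archived transitions are deserialized in parallel.
    restored = deserialize_partition_item(zip_path)
    assert len(restored.subset) == len(item.subset)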

def serialize_partition(partition: list[StatePartitionItem], output_zip_path: pathlib.Path):

Serialize a partition to a ZIP file.

def deserialize_partition(zip_path: pathlib.Path) -> list[StatePartitionItem]:

Deserialize a partition from a ZIP file.

def deserialize_partition_item_local_model(zip_path: pathlib.Path) -> swmpo.world_models.world_model.WorldModel:

Deserialize only the local model of a partition item, without loading its transition subset.

def deserialize_partition_local_models(zip_path: pathlib.Path) -> list[swmpo.world_models.world_model.WorldModel]:

Deserialize only the local models of each partition item in a partition ZIP file.

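When only the dynamics models are needed, the model-only helpers skip deserializing the transition subsets, which dominates loading time for large partitions. A sketch, assuming "partition.zip" was previously written by serialize_partition:

    from pathlib import Path
    from swmpo.partition import (
        deserialize_partition,
        deserialize_partition_local_models,
    )

    zip_path = Path("partition.zip")  # hypothetical archive from serialize_partition

    # Full load: local models plus every serialized transition.
    items = deserialize_partition(zip_path)

    # Faster load: extracts only "model.zip" from each item archive.
    models = deserialize_partition_local_models(zip_path)
    assert len(models) == len(items)
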
class PartitionSortingError(builtins.Exception):

def item_contains_transition(item: StatePartitionItem, transition: swmpo.transition.Transition) -> bool:

Return whether the transition appears in the partition item.

def get_initial_transition_n(item: StatePartitionItem, episodes: list[list[swmpo.transition.Transition]]) -> int:

Return the number of episodes whose initial transition appears in the partition item.

def get_sorted_partition(partition: list[StatePartitionItem], episodes: list[list[swmpo.transition.Transition]]) -> list[StatePartitionItem]:

Sort the partition so that the partition item with the most initial transitions is first.

def get_partition_modes(trajectory: list[swmpo.transition.Transition], partition: list[StatePartitionItem]) -> list[int]:

Return, for each transition in the trajectory, the index of the partition item that contains it.

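A sketch of labeling a trajectory with its mode sequence; `partition` and `trajectory` are assumed to come from the same dataset, since every transition must belong to some partition item or the internal assertion fails:

    from swmpo.partition import get_partition_modes

    modes = get_partition_modes(trajectory=trajectory, partition=partition)
    # e.g. [0, 0, 0, 2, 2, 1, 1, ...]: the index of the partition item
    # (and hence of the local model) governing each step.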

def get_clusters(mode_world_model: swmpo.world_models.mode_world_model.ModeWorldModel, trajectories: list[list[swmpo.transition.Transition]], cluster_n: int, min_island_size: int, dimensionality_reduce: int | None, seed: str, device: str) -> list[set[tuple[int, int]]]:

Partition the given dataset of transitions into disjoint subsets. Each returned set contains (trajectory index, transition index) pairs identifying its transitions.

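The clusters index transitions by position rather than by value. A sketch of inverting them back into per-trajectory mode sequences, assuming `clusters` and `trajectories` as above:

    # Invert the clusters into a location -> cluster-index mapping.
    labels = {
        location: cluster_index
        for cluster_index, cluster in enumerate(clusters)
        for location in cluster
    }
    mode_sequences = [
        [labels[(i, j)] for j in range(len(trajectory))]
        for i, trajectory in enumerate(trajectories)
    ]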

@dataclass
class OptimizationResult:

OptimizationResult(partition: list[StatePartitionItem], loss_log: list[float], mode_world_model: swmpo.world_models.mode_world_model.ModeWorldModel)

partition: list[StatePartitionItem]
loss_log: list[float]
mode_world_model: swmpo.world_models.mode_world_model.ModeWorldModel

def get_optimized_error_partition(trajectories: list[list[swmpo.transition.Transition]], hidden_sizes: list[int], learning_rate: float, latent_size: int, iter_n: int, clustering_dimensionality_reduce: int | None, clustering_information_content_regularization_scale: float, clustering_mutual_information_regularization_scale: float, local_model_hyperparameters: dict[str, str | int | float], dt: float, size: int, seed: str, min_island_size: int, batch_size: int, mutual_information_mini_batch_size: int, device: str, verbose: bool) -> OptimizationResult:

Helper function to optimize the partition ensemble by error weighting.

def get_partition(episodes: list[list[swmpo.transition.Transition]], hidden_sizes: list[int], learning_rate: float, latent_size: int, mode_model_iter_n: int, clustering_dimensionality_reduce: int | None, clustering_information_content_regularization_scale: float, clustering_mutual_information_regularization_scale: float, dt: float, size: int, min_island_size: int, seed: str, batch_size: int, local_model_hyperparameters: dict[str, str | int | float], mutual_information_mini_batch_size: int, device: str, verbose: bool) -> OptimizationResult:

Returns a partition of the set of transitions in the given episodes. Each subset of the partition has a corresponding neural model of the dynamics in that subset.

The returned partition will be sorted so that the first subset contains the maximum number of initial transitions from the input episodes.
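
An end-to-end sketch; every hyperparameter value below is a placeholder to adapt to the dataset, and `episodes` is assumed to be an existing list[list[Transition]]:

    from pathlib import Path
    from swmpo.partition import get_partition, serialize_partition

    result = get_partition(
        episodes=episodes,
        hidden_sizes=[64, 64],
        learning_rate=1e-3,
        latent_size=8,
        mode_model_iter_n=10_000,
        clustering_dimensionality_reduce=2,  # None skips the UMAP step
        clustering_information_content_regularization_scale=1.0,
        clustering_mutual_information_regularization_scale=1.0,
        dt=0.05,
        size=4,  # number of partition items (KMeans clusters)
        min_island_size=5,
        seed="example-seed",
        batch_size=256,
        local_model_hyperparameters={},  # placeholder; see get_optimized_model
        mutual_information_mini_batch_size=64,
        device="cpu",
        verbose=True,
    )

    # The first item of the sorted partition covers the most episode starts.
    serialize_partition(result.partition, Path("partition.zip"))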