swmpo.sequence_distance

Utilities for comparing mode sequences (sequences of mode indices).

  1"""Comparing sequences of mode sequences."""
  2from itertools import permutations
  3import math
  4
  5
  6def _get_levenshtein(a: list[int], b: list[int]) -> float:
  7    """Levenshtein distance."""
  8    # Adapted from Wikipedia
  9    m = len(a)
 10    n = len(b)
 11
 12    # d is an m by n array of zeros
 13    d = [
 14        [0 for _ in range(n+1)]
 15        for _ in range(m+1)
 16    ]
 17    for i in range(1, m+1):
 18        d[i][0] = i
 19
 20    for j in range(1, n+1):
 21        d[0][j] = j
 22
 23    for j in range(1, n+1):
 24        for i in range(1, m+1):
 25            if a[i-1] == b[j-1]:
 26                substitutionCost = 0
 27            else:
 28                substitutionCost = 1
 29
 30            d[i][j] = min(
 31                d[i-1][j] + 1,                   # deletion
 32                d[i][j-1] + 1,                   # insertion
 33                d[i-1][j-1] + substitutionCost)  # substitution
 34
 35    return d[m][n]
 36
 37
 38def get_error(visited_states: list[int], ground_truth: list[int]) -> float:
 39    """Levenshtein distance."""
 40    return _get_levenshtein(visited_states, ground_truth)
 41
 42
 43def _get_best_permutation(
 44    sequence: list[int],
 45    ground_truth: list[int],
 46    indices: list[int],
 47) -> dict[int, int]:
 48    """Returns a sequence `new` with the property that `new[i] = perm[sequence[i]]`,
 49    where `perm` is a permutation of the mode indices that preserves the initial
 50    mode. The permutation is such that the returned sequence has the minimum
 51    error.
 52
 53    That is, if we have a sequence `[0, 2, 3, 2, 4]`, we could return
 54    `[0, 3, 2, 3, 4]`.
 55
 56    Indices is the set of labels that can be permuted.
 57    """
 58    best_error = math.inf
 59    best_perm = None
 60    i = 0
 61    all_indices = set(sequence) | set(ground_truth)
 62    non_permuted_indices = all_indices - set(indices)
 63    for permutation in permutations(indices):
 64        if i > 1000:
 65            break  # TODO: do something about large index sets
 66        i += 1
 67        perm = {
 68            original: new
 69            for original, new in zip(indices, permutation)
 70        }
 71        # complete permutation
 72        for i in non_permuted_indices:
 73            perm[i] = i
 74        new_sequence = [
 75            perm[original]
 76            for original in sequence
 77        ]
 78        new_error = get_error(
 79            new_sequence,
 80            ground_truth,
 81        )
 82        if new_error < best_error:
 83            best_error = new_error
 84            best_perm = perm
 85    assert best_perm is not None
 86    return best_perm
 87
 88
 89def get_best_permutation(
 90    sequence: list[int],
 91    ground_truth: list[int],
 92    initial_state: int,
 93) -> list[int]:
 94    """Returns a sequence `new` with the property that `new[i] = perm[sequence[i]]`,
 95    where `perm` is a permutation of the mode indices that preserves the initial
 96    mode. The permutation is such that the returned sequence has the minimum
 97    error.
 98
 99    That is, if we have a sequence `[1, 2, 3, 2, 4]`, we could return
100    `[1, 3, 2, 3, 4]`.
101    """
102    indices = list((set(sequence)|set(ground_truth)) - set([initial_state]))
103    perm = _get_best_permutation(
104        sequence=sequence,
105        ground_truth=ground_truth,
106        indices=indices,
107    )
108    new_sequence = [
109        perm[original]
110        for original in sequence
111    ]
112    return new_sequence
def get_error(visited_states: list[int], ground_truth: list[int]) -> float:
    """Score `visited_states` against `ground_truth`.

    The score is the Levenshtein (edit) distance between the two mode
    sequences, as computed by `_get_levenshtein`; lower is better and 0 means
    the sequences are identical.
    """
    distance = _get_levenshtein(visited_states, ground_truth)
    return distance

Levenshtein distance.

def get_best_permutation(
    sequence: list[int],
    ground_truth: list[int],
    initial_state: int,
) -> list[int]:
    """Relabel `sequence` so that it best matches `ground_truth`.

    The result `new` satisfies `new[i] = perm[sequence[i]]`, where `perm` is
    a permutation of the mode labels that keeps `initial_state` fixed. It is
    selected by `_get_best_permutation` to minimize the error against
    `ground_truth`.

    For instance, the sequence `[1, 2, 3, 2, 4]` might be returned as
    `[1, 3, 2, 3, 4]`.
    """
    # Every label seen in either sequence may be permuted, except the initial mode.
    candidate_labels = (set(sequence) | set(ground_truth)) - {initial_state}
    relabeling = _get_best_permutation(
        sequence=sequence,
        ground_truth=ground_truth,
        indices=list(candidate_labels),
    )
    return [relabeling[label] for label in sequence]

Returns a sequence `new` with the property that `new[i] = perm[sequence[i]]`, where `perm` is a permutation of the mode indices that preserves the initial mode. The permutation is chosen so that the returned sequence has the minimum error.

That is, if we have a sequence `[1, 2, 3, 2, 4]`, we could return `[1, 3, 2, 3, 4]`.