swmpo.sequence_distance
Comparing sequences of mode sequences.
"""Comparing sequences of mode sequences."""
from itertools import islice, permutations
import math


# Number of candidate relabelings examined by default. The original guard
# (`if i > 1000: break` with a counter starting at 0) was meant to allow
# exactly this many iterations.
_DEFAULT_MAX_PERMUTATIONS = 1001


def _get_levenshtein(a: list[int], b: list[int]) -> float:
    """Return the Levenshtein (edit) distance between `a` and `b`.

    Insertions, deletions and substitutions all cost 1. The value is
    computed with the standard dynamic-programming table (adapted from
    Wikipedia) and is always a non-negative integer.
    """
    m = len(a)
    n = len(b)

    # d is an (m+1) by (n+1) table; d[i][j] holds the distance between the
    # first i items of `a` and the first j items of `b`.
    d = [
        [0 for _ in range(n + 1)]
        for _ in range(m + 1)
    ]

    # Turning a prefix into the empty sequence takes one deletion per item...
    for i in range(1, m + 1):
        d[i][0] = i

    # ...and building a prefix from the empty sequence one insertion per item.
    for j in range(1, n + 1):
        d[0][j] = j

    for j in range(1, n + 1):
        for i in range(1, m + 1):
            substitution_cost = 0 if a[i - 1] == b[j - 1] else 1
            d[i][j] = min(
                d[i - 1][j] + 1,  # deletion
                d[i][j - 1] + 1,  # insertion
                d[i - 1][j - 1] + substitution_cost,  # substitution
            )

    return d[m][n]


def get_error(visited_states: list[int], ground_truth: list[int]) -> float:
    """Return the Levenshtein distance between the two mode sequences."""
    return _get_levenshtein(visited_states, ground_truth)


def _get_best_permutation(
    sequence: list[int],
    ground_truth: list[int],
    indices: list[int],
    max_permutations: int = _DEFAULT_MAX_PERMUTATIONS,
) -> dict[int, int]:
    """Return the label mapping `perm` that minimizes the error.

    `perm` maps every label that appears in `sequence` or `ground_truth` to
    a label. Labels listed in `indices` are permuted among themselves; all
    other labels are mapped to themselves. Among the candidates examined,
    the mapping returned is the first one for which
    `[perm[s] for s in sequence]` has the smallest `get_error` against
    `ground_truth`.

    That is, for the sequence `[0, 2, 3, 2, 4]` with `indices = [2, 3, 4]`,
    a possible result is `{0: 0, 2: 3, 3: 2, 4: 4}`, which relabels the
    sequence as `[0, 3, 2, 3, 4]`.

    Args:
        sequence: Sequence of mode labels to relabel.
        ground_truth: Reference sequence of mode labels.
        indices: The labels that may be permuted.
        max_permutations: Upper bound on the number of permutations that are
            tried, in the order produced by `itertools.permutations`. For
            large label sets the search is therefore not exhaustive.

    Raises:
        ValueError: If `max_permutations` is smaller than 1.
    """
    if max_permutations < 1:
        raise ValueError("max_permutations must be at least 1")

    all_indices = set(sequence) | set(ground_truth)
    non_permuted_indices = all_indices - set(indices)

    best_error = math.inf
    best_perm = None
    # The search is bounded with `islice` rather than a hand-maintained
    # counter: the previous counter shared its name with the loop variable
    # used to complete the permutation, so the cap either never triggered or
    # triggered after a single candidate.
    # TODO: do something about large index sets
    for permutation in islice(permutations(indices), max_permutations):
        perm = dict(zip(indices, permutation))
        # Complete the permutation: labels outside `indices` stay fixed.
        for label in non_permuted_indices:
            perm[label] = label
        new_sequence = [perm[original] for original in sequence]
        new_error = get_error(new_sequence, ground_truth)
        if new_error < best_error:
            best_error = new_error
            best_perm = perm
            if best_error == 0:
                # A perfect match cannot be improved upon.
                break

    # `permutations` always yields at least one item (the empty tuple for an
    # empty `indices`), so a candidate has been recorded at this point.
    assert best_perm is not None
    return best_perm


def get_best_permutation(
    sequence: list[int],
    ground_truth: list[int],
    initial_state: int,
) -> list[int]:
    """Returns a sequence `new` with the property that `new[i] = perm[sequence[i]]`,
    where `perm` is a permutation of the mode indices that preserves the initial
    mode. The permutation is such that the returned sequence has the minimum
    error among the permutations examined by `_get_best_permutation`.

    That is, if we have a sequence `[1, 2, 3, 2, 4]`, we could return
    `[1, 3, 2, 3, 4]`.
    """
    # Every label except the initial state may be relabeled.
    indices = list((set(sequence) | set(ground_truth)) - {initial_state})
    perm = _get_best_permutation(
        sequence=sequence,
        ground_truth=ground_truth,
        indices=indices,
    )
    new_sequence = [
        perm[original]
        for original in sequence
    ]
    return new_sequence
def
get_error(visited_states: list[int], ground_truth: list[int]) -> float:
def get_error(visited_states: list[int], ground_truth: list[int]) -> float:
    """Return the error of `visited_states` measured against `ground_truth`.

    The error is whatever `_get_levenshtein` reports for the two sequences,
    i.e. their Levenshtein distance.
    """
    distance = _get_levenshtein(visited_states, ground_truth)
    return distance
Levenshtein distance.
def
get_best_permutation( sequence: list[int], ground_truth: list[int], initial_state: int) -> list[int]:
def get_best_permutation(
    sequence: list[int],
    ground_truth: list[int],
    initial_state: int,
) -> list[int]:
    """Relabel `sequence` with the mode permutation found by `_get_best_permutation`.

    The result `new` satisfies `new[i] = perm[sequence[i]]`, where `perm` is
    the mapping returned by `_get_best_permutation`. The initial mode is left
    out of the set of labels offered for permutation, so it is preserved.

    For example, the sequence `[1, 2, 3, 2, 4]` could be returned as
    `[1, 3, 2, 3, 4]`.
    """
    every_label = set(sequence) | set(ground_truth)
    # Only labels other than the initial state are candidates for relabeling.
    permutable = list(every_label - {initial_state})
    mapping = _get_best_permutation(
        sequence=sequence,
        ground_truth=ground_truth,
        indices=permutable,
    )
    return [mapping[label] for label in sequence]
Returns a sequence `new` with the property that `new[i] = perm[sequence[i]]`,
where `perm` is a permutation of the mode indices that preserves the initial
mode. The permutation is chosen so that the returned sequence has the minimum
error.
That is, if we have a sequence `[1, 2, 3, 2, 4]`, we could return
`[1, 3, 2, 3, 4]`.