swmpo.transition_normalization
Normalize datasets of trajectories.
"""Normalize datasets of trajectories."""
from dataclasses import dataclass
from swmpo.transition import Transition
import statistics
import torch
import json
from pathlib import Path


@dataclass
class VectorStats:
    """Per-feature summary statistics of a collection of vectors.

    Every list has one entry per feature, in feature order.
    """
    # Arithmetic mean of each feature.
    means: list[float]
    # Largest observed value of each feature.
    maxs: list[float]
    # Smallest observed value of each feature.
    mins: list[float]
    # Sample standard deviation of each feature (0 for constant features).
    stdevs: list[float]


def serialize_vector_stats(
    stats: VectorStats,
    output_json: Path,
):
    """Write `stats` to `output_json` as a JSON object.

    The object has the keys `means`, `maxs`, `mins` and `stdevs`, which is
    the layout `deserialize_vector_stats` reads back.
    """
    obj = dict(
        means=stats.means,
        maxs=stats.maxs,
        mins=stats.mins,
        stdevs=stats.stdevs,
    )
    # Explicit encoding so the file does not depend on the platform default.
    with open(output_json, "wt", encoding="utf-8") as fp:
        json.dump(obj, fp, indent=2)


def deserialize_vector_stats(
    serialized_json: Path,
) -> VectorStats:
    """Load a `VectorStats` from a file written by `serialize_vector_stats`.

    Raises:
        KeyError: If one of the keys `means`, `maxs`, `mins` or `stdevs`
            is missing from the JSON object.
    """
    with open(serialized_json, "rt", encoding="utf-8") as fp:
        obj = json.load(fp)
    vector_stats = VectorStats(
        means=obj["means"],
        maxs=obj["maxs"],
        mins=obj["mins"],
        stdevs=obj["stdevs"],
    )
    return vector_stats


def get_vector_stats(vectors: list[list[float]]) -> VectorStats:
    """Compute per-feature statistics over a collection of vectors.

    Args:
        vectors: Equally sized vectors; `vectors[j][i]` is the value of
            feature `i` in sample `j`. An empty list is not supported.

    Returns:
        A `VectorStats` with one mean, max, min and sample standard
        deviation per feature, in feature order. When only one vector is
        given, every standard deviation is reported as 0.0.
    """
    # Transpose so that each inner list holds every sample of one feature.
    # float64 matches Python's float, so values are not rounded to torch's
    # default float32 on their way through the tensor.
    feature_vals: list[list[float]] = torch.tensor(
        vectors,
        dtype=torch.float64,
    ).transpose(0, 1).tolist()
    # `statistics.stdev` needs at least two samples; a single sample has no
    # spread, so report 0.0 instead of raising `StatisticsError`.
    has_spread = len(vectors) > 1
    means = [
        statistics.mean(vals)
        for vals in feature_vals
    ]
    stdevs = [
        statistics.stdev(vals) if has_spread else 0.0
        for vals in feature_vals
    ]
    maxs = [
        max(vals)
        for vals in feature_vals
    ]
    mins = [
        min(vals)
        for vals in feature_vals
    ]
    normalization = VectorStats(
        means=means,
        maxs=maxs,
        mins=mins,
        stdevs=stdevs,
    )
    return normalization


def _get_safe_stdevs(stats: VectorStats) -> list[float]:
    """Return the stdevs of `stats` with non-positive entries replaced by 1."""
    return [
        stdev_i
        if stdev_i > 0
        else 1.0
        for stdev_i in stats.stdevs
    ]


def get_normalized_vector(
    vector: torch.Tensor,
    stats: VectorStats,
) -> torch.Tensor:
    """Standardize `vector` feature by feature.

    Element `i` becomes `(vector[i] - mean_i) / stdev_i`. Features whose
    standard deviation is not positive are only shifted by their mean, which
    avoids a division by zero.
    """
    stdevs = _get_safe_stdevs(stats)
    vals = [
        (val-stats.means[i])/stdevs[i]
        for i, val in enumerate(vector)
    ]
    norm_vec = torch.stack(vals)
    return norm_vec


def get_raw_vector(
    vector: torch.Tensor,
    stats: VectorStats,
) -> torch.Tensor:
    """'Unnormalize' a vector.

    Element `i` becomes `vector[i] * stdev_i + mean_i`, using the same
    patched standard deviations as `get_normalized_vector` so that the two
    functions are inverses of each other for every feature.
    """
    stdevs = _get_safe_stdevs(stats)
    vals = [
        (val*stdevs[i]+stats.means[i])
        for i, val in enumerate(vector)
    ]
    raw_vec = torch.stack(vals)
    return raw_vec


@dataclass
class TransitionStatistics:
    """Per-feature statistics."""
    # Statistics of the state vectors.
    state_normalization: VectorStats
    # Statistics of the action vectors.
    action_normalization: VectorStats


def get_transition_statistics(
    transitions: list[Transition],
) -> TransitionStatistics:
    """Compute state and action statistics over `transitions`."""
    state_vectors = list()
    action_vectors = list()

    for transition in transitions:
        # We only use source state because next state will be a source
        # state of some other transition
        state_vectors.append(transition.source_state.tolist())
        action_vectors.append(transition.action.tolist())

    state_normalization = get_vector_stats(state_vectors)
    action_normalization = get_vector_stats(action_vectors)

    stats = TransitionStatistics(
        state_normalization=state_normalization,
        action_normalization=action_normalization,
    )
    return stats


def get_normalized_state(
    state: torch.Tensor,
    stats: TransitionStatistics,
) -> torch.Tensor:
    """Normalize `state` with the state statistics in `stats`."""
    return get_normalized_vector(state, stats.state_normalization)


def get_normalized_action(
    action: torch.Tensor,
    stats: TransitionStatistics,
) -> torch.Tensor:
    """Normalize `action` with the action statistics in `stats`."""
    return get_normalized_vector(action, stats.action_normalization)


def get_raw_state(
    state: torch.Tensor,
    stats: TransitionStatistics,
) -> torch.Tensor:
    """Undo the normalization of `state` using the state statistics."""
    return get_raw_vector(state, stats.state_normalization)


def get_raw_action(
    action: torch.Tensor,
    stats: TransitionStatistics,
) -> torch.Tensor:
    """Undo the normalization of `action` using the action statistics."""
    return get_raw_vector(action, stats.action_normalization)


def get_normalized_transition(
    transition: Transition,
    stats: TransitionStatistics,
) -> Transition:
    """Return a new `Transition` whose states and action are normalized."""
    norm_transition = Transition(
        source_state=get_normalized_state(transition.source_state, stats),
        next_state=get_normalized_state(transition.next_state, stats),
        action=get_normalized_action(transition.action, stats),
    )
    return norm_transition


def get_normalized_trajectories(
    trajectories: list[list[Transition]],
    stats: TransitionStatistics,
) -> list[list[Transition]]:
    """Normalize trajectory vectors using dataset statistics."""
    normalized_trajectories = [
        [
            get_normalized_transition(
                transition,
                stats,
            )
            for transition in trajectory
        ]
        for trajectory in trajectories
    ]
    return normalized_trajectories
@dataclass
class
VectorStats:
def get_vector_stats(vectors: list[list[float]]) -> VectorStats:
    """Compute per-feature statistics over a collection of vectors.

    Args:
        vectors: Equally sized vectors; `vectors[j][i]` is the value of
            feature `i` in sample `j`. An empty list is not supported.

    Returns:
        A `VectorStats` with one mean, max, min and sample standard
        deviation per feature, in feature order. When only one vector is
        given, every standard deviation is reported as 0.0.
    """
    # Transpose so that each inner list holds every sample of one feature.
    # float64 matches Python's float, so values are not rounded to torch's
    # default float32 on their way through the tensor.
    feature_vals: list[list[float]] = torch.tensor(
        vectors,
        dtype=torch.float64,
    ).transpose(0, 1).tolist()
    # `statistics.stdev` needs at least two samples; a single sample has no
    # spread, so report 0.0 instead of raising `StatisticsError`.
    has_spread = len(vectors) > 1
    means = [
        statistics.mean(vals)
        for vals in feature_vals
    ]
    stdevs = [
        statistics.stdev(vals) if has_spread else 0.0
        for vals in feature_vals
    ]
    maxs = [
        max(vals)
        for vals in feature_vals
    ]
    mins = [
        min(vals)
        for vals in feature_vals
    ]
    normalization = VectorStats(
        means=means,
        maxs=maxs,
        mins=mins,
        stdevs=stdevs,
    )
    return normalization
def get_normalized_vector(
    vector: torch.Tensor,
    stats: VectorStats,
) -> torch.Tensor:
    """Standardize `vector` feature by feature.

    Element `i` becomes `(vector[i] - mean_i) / stdev_i`, with the mean and
    standard deviation taken from `stats`. A standard deviation that is not
    positive is replaced by 1.0, so such a feature is only shifted by its
    mean and no division by zero occurs.
    """
    # Apply the same non-positive-stdev guard to every entry up front.
    scales = []
    for raw_scale in stats.stdevs:
        if raw_scale > 0:
            scales.append(raw_scale)
        else:
            scales.append(1.0)

    normalized_components = []
    for index, component in enumerate(vector):
        centered = component - stats.means[index]
        normalized_components.append(centered / scales[index])

    return torch.stack(normalized_components)
def get_raw_vector(
    vector: torch.Tensor,
    stats: VectorStats,
) -> torch.Tensor:
    """'Unnormalize' a vector.

    Element `i` becomes `vector[i] * stdev_i + mean_i`, with the mean and
    standard deviation taken from `stats`.

    Standard deviations that are not positive are replaced by 1.0, the same
    patch that normalization applies. Without it, a zero-variance feature
    would always map back to its mean, and this function would not be the
    inverse of the normalization for that feature.
    """
    # Mirror the stdev patch used when normalizing.
    stdevs = [
        stdev_i
        if stdev_i > 0
        else 1.0
        for stdev_i in stats.stdevs
    ]
    vals = [
        (val*stdevs[i]+stats.means[i])
        for i, val in enumerate(vector)
    ]
    raw_vec = torch.stack(vals)
    return raw_vec
'Unnormalize' a vector.
@dataclass
class
TransitionStatistics:
@dataclass
class TransitionStatistics:
    """Normalization statistics for the two vector kinds of a transition.

    Holds one set of per-feature statistics for states and one for actions.
    """
    # Per-feature statistics of the state vectors.
    state_normalization: VectorStats
    # Per-feature statistics of the action vectors.
    action_normalization: VectorStats
Per-feature statistics.
TransitionStatistics( state_normalization: VectorStats, action_normalization: VectorStats)
state_normalization: VectorStats
action_normalization: VectorStats
def
get_transition_statistics( transitions: list[swmpo.transition.Transition]) -> TransitionStatistics:
def get_transition_statistics(
    transitions: list[Transition],
) -> TransitionStatistics:
    """Compute state and action statistics over a list of transitions.

    State statistics are computed from the source state of every transition
    and action statistics from every action.
    """
    # Next states are skipped on purpose: per the original author's note,
    # each next state shows up as the source state of another transition.
    source_states = [
        transition.source_state.tolist()
        for transition in transitions
    ]
    actions = [
        transition.action.tolist()
        for transition in transitions
    ]
    return TransitionStatistics(
        state_normalization=get_vector_stats(source_states),
        action_normalization=get_vector_stats(actions),
    )
def
get_normalized_transition( transition: swmpo.transition.Transition, stats: TransitionStatistics) -> swmpo.transition.Transition:
def get_normalized_transition(
    transition: Transition,
    stats: TransitionStatistics,
) -> Transition:
    """Build a new `Transition` with normalized states and action.

    Both the source state and the next state go through
    `get_normalized_state`; the action goes through `get_normalized_action`.
    """
    normalized_source = get_normalized_state(transition.source_state, stats)
    normalized_next = get_normalized_state(transition.next_state, stats)
    normalized_action = get_normalized_action(transition.action, stats)
    return Transition(
        source_state=normalized_source,
        next_state=normalized_next,
        action=normalized_action,
    )
def
get_normalized_trajectories( trajectories: list[list[swmpo.transition.Transition]], stats: TransitionStatistics) -> list[list[swmpo.transition.Transition]]:
def get_normalized_trajectories(
    trajectories: list[list[Transition]],
    stats: TransitionStatistics,
) -> list[list[Transition]]:
    """Normalize every transition of every trajectory with `stats`.

    Returns new lists with the same nesting and ordering as `trajectories`.
    """
    result = []
    for trajectory in trajectories:
        normalized_trajectory = [
            get_normalized_transition(transition, stats)
            for transition in trajectory
        ]
        result.append(normalized_trajectory)
    return result
Normalize trajectory vectors using dataset statistics.