swmpo.transition_normalization

Normalize datasets of trajectories.

  1"""Normalize datasets of trajectories."""
  2from dataclasses import dataclass
  3from swmpo.transition import Transition
  4import statistics
  5import torch
  6import json
  7from pathlib import Path
  8
  9
 10@dataclass
 11class VectorStats:
 12    means: list[float]
 13    maxs: list[float]
 14    mins: list[float]
 15    stdevs: list[float]
 16
 17
 18def serialize_vector_stats(
 19    stats: VectorStats,
 20    output_json: Path,
 21):
 22    obj = dict(
 23        means=stats.means,
 24        maxs=stats.maxs,
 25        mins=stats.mins,
 26        stdevs=stats.stdevs,
 27    )
 28    with open(output_json, "wt") as fp:
 29        json.dump(obj, fp, indent=2)
 30
 31
 32def deserialize_vector_stats(
 33    serialized_json: Path,
 34) -> VectorStats:
 35    with open(serialized_json, "rt") as fp:
 36        obj = json.load(fp)
 37    vector_stats = VectorStats(
 38        means=obj["means"],
 39        maxs=obj["maxs"],
 40        mins=obj["mins"],
 41        stdevs=obj["stdevs"],
 42    )
 43    return vector_stats
 44
 45
 46def get_vector_stats(vectors: list[list[float]]) -> VectorStats:
 47    feature_vals: list[list[float]] = torch.tensor(
 48        vectors
 49    ).transpose(0, 1).tolist()
 50    means = [
 51        statistics.mean(vals)
 52        for vals in feature_vals
 53    ]
 54    stdevs = [
 55        statistics.stdev(vals)
 56        for vals in feature_vals
 57    ]
 58    maxs = [
 59        max(vals)
 60        for vals in feature_vals
 61    ]
 62    mins = [
 63        min(vals)
 64        for vals in feature_vals
 65    ]
 66    normalization = VectorStats(
 67        means=means,
 68        maxs=maxs,
 69        mins=mins,
 70        stdevs=stdevs,
 71    )
 72    return normalization
 73
 74
 75def get_normalized_vector(
 76    vector: torch.Tensor,
 77    stats: VectorStats,
 78) -> torch.Tensor:
 79    # Patch stdevs
 80    stdevs = [
 81        stdev_i
 82        if stdev_i > 0
 83        else 1.0
 84        for stdev_i in stats.stdevs
 85    ]
 86    vals = [
 87        (val-stats.means[i])/stdevs[i]
 88        for i, val in enumerate(vector)
 89    ]
 90    norm_vec = torch.stack(vals)
 91    return norm_vec
 92
 93
 94def get_raw_vector(
 95    vector: torch.Tensor,
 96    stats: VectorStats,
 97) -> torch.Tensor:
 98    """'Unnormalize' a vector."""
 99    vals = [
100        (val*stats.stdevs[i]+stats.means[i])
101        for i, val in enumerate(vector)
102    ]
103    norm_vec = torch.stack(vals)
104    return norm_vec
105
106
@dataclass(init=True, repr=True, eq=True)
class TransitionStatistics:
    """Per-feature statistics of the states and actions of a dataset."""

    # Statistics of state vectors (applied to both source and next states).
    state_normalization: VectorStats
    # Statistics of action vectors.
    action_normalization: VectorStats
112
113
def get_transition_statistics(
    transitions: list[Transition],
) -> TransitionStatistics:
    """Compute per-feature state and action statistics over `transitions`.

    State statistics use only each transition's ``source_state``. The
    rationale is that a next state presumably reappears as the source state
    of another transition. Note that a trajectory's final ``next_state``
    therefore does not contribute.
    """
    # Single pass, reading source_state then action for each transition.
    pairs = [
        (transition.source_state.tolist(), transition.action.tolist())
        for transition in transitions
    ]
    state_rows = [state for state, _ in pairs]
    action_rows = [action for _, action in pairs]
    return TransitionStatistics(
        state_normalization=get_vector_stats(state_rows),
        action_normalization=get_vector_stats(action_rows),
    )
134
135
def get_normalized_state(
    state: torch.Tensor,
    stats: TransitionStatistics,
) -> torch.Tensor:
    """Normalize a state vector using the dataset's state statistics."""
    state_stats = stats.state_normalization
    return get_normalized_vector(state, state_stats)
141
142
def get_normalized_action(
    action: torch.Tensor,
    stats: TransitionStatistics,
) -> torch.Tensor:
    """Normalize an action vector using the dataset's action statistics."""
    action_stats = stats.action_normalization
    return get_normalized_vector(action, action_stats)
148
149
def get_raw_state(
    state: torch.Tensor,
    stats: TransitionStatistics,
) -> torch.Tensor:
    """Map a normalized state vector back to raw scale via the state statistics."""
    state_stats = stats.state_normalization
    return get_raw_vector(state, state_stats)
155
156
def get_raw_action(
    action: torch.Tensor,
    stats: TransitionStatistics,
) -> torch.Tensor:
    """Map a normalized action vector back to raw scale via the action statistics."""
    action_stats = stats.action_normalization
    return get_raw_vector(action, action_stats)
162
163
def get_normalized_transition(
    transition: Transition,
    stats: TransitionStatistics,
) -> Transition:
    """Build a new Transition with normalized states and action.

    Source and next states are normalized with the state statistics; the
    action is normalized with the action statistics. Only these three
    fields are passed to the new Transition.
    NOTE(review): if Transition has other fields, they are not copied —
    confirm.
    """
    normalized_source = get_normalized_state(transition.source_state, stats)
    normalized_next = get_normalized_state(transition.next_state, stats)
    normalized_action = get_normalized_action(transition.action, stats)
    return Transition(
        source_state=normalized_source,
        next_state=normalized_next,
        action=normalized_action,
    )
174
175
def get_normalized_trajectories(
    trajectories: list[list[Transition]],
    stats: TransitionStatistics,
) -> list[list[Transition]]:
    """Normalize trajectory vectors using dataset statistics.

    Returns a new nested list with the same shape as `trajectories`, where
    every transition is replaced by its normalized copy.
    """
    result: list[list[Transition]] = []
    for trajectory in trajectories:
        result.append(
            [get_normalized_transition(step, stats) for step in trajectory]
        )
    return result
@dataclass
class VectorStats:
@dataclass(init=True, repr=True, eq=True)
class VectorStats:
    """Per-feature statistics of a collection of vectors.

    Every list holds one entry per feature (vector component), in feature
    order.
    """

    # Arithmetic mean of each feature.
    means: list[float]
    # Largest observed value of each feature.
    maxs: list[float]
    # Smallest observed value of each feature.
    mins: list[float]
    # Sample standard deviation of each feature (0.0 for constant features).
    stdevs: list[float]
VectorStats( means: list[float], maxs: list[float], mins: list[float], stdevs: list[float])
means: list[float]
maxs: list[float]
mins: list[float]
stdevs: list[float]
def serialize_vector_stats( stats: VectorStats, output_json: pathlib.Path):
def serialize_vector_stats(
    stats: VectorStats,
    output_json: Path,
):
    """Write `stats` to `output_json` as an indented JSON object.

    The object has the keys ``means``, ``maxs``, ``mins`` and ``stdevs``
    (in that order), each mapped to the matching list of `stats`. An
    existing file at `output_json` is overwritten.
    """
    payload = {
        "means": stats.means,
        "maxs": stats.maxs,
        "mins": stats.mins,
        "stdevs": stats.stdevs,
    }
    with open(output_json, "wt") as handle:
        json.dump(payload, handle, indent=2)
def deserialize_vector_stats( serialized_json: pathlib.Path) -> VectorStats:
def deserialize_vector_stats(
    serialized_json: Path,
) -> VectorStats:
    """Read a VectorStats back from a JSON file written by serialize_vector_stats.

    Only the keys ``means``, ``maxs``, ``mins`` and ``stdevs`` are used;
    any other keys in the file are ignored.

    Raises:
        KeyError: If one of the four expected keys is missing.
    """
    with open(serialized_json, "rt") as handle:
        payload = json.load(handle)
    keys = ("means", "maxs", "mins", "stdevs")
    return VectorStats(**{key: payload[key] for key in keys})
def get_vector_stats(vectors: list[list[float]]) -> VectorStats:
def get_vector_stats(vectors: list[list[float]]) -> VectorStats:
    """Compute per-feature summary statistics of a set of vectors.

    Args:
        vectors: Samples as equal-length lists; ``vectors[k][i]`` is
            feature ``i`` of sample ``k``.

    Returns:
        VectorStats holding, for every feature, its mean, max, min and
        sample standard deviation.

    Raises:
        ValueError: If ``vectors`` is empty, since no statistic is defined.
    """
    if len(vectors) == 0:
        raise ValueError("get_vector_stats requires at least one vector")
    # Transpose samples x features into features x samples. torch.tensor
    # also rejects ragged input (vectors of different lengths).
    # NOTE: Python floats are stored in torch's default dtype (float32 unless
    # changed), so statistics are computed on float32-rounded values.
    feature_vals: list[list[float]] = torch.tensor(
        vectors
    ).transpose(0, 1).tolist()
    means = [statistics.mean(vals) for vals in feature_vals]
    # statistics.stdev raises StatisticsError for fewer than two samples.
    # A single sample has no observed spread, so report 0.0 instead;
    # get_normalized_vector already treats a zero stdev as 1.0.
    stdevs = [
        statistics.stdev(vals) if len(vals) > 1 else 0.0
        for vals in feature_vals
    ]
    maxs = [max(vals) for vals in feature_vals]
    mins = [min(vals) for vals in feature_vals]
    return VectorStats(
        means=means,
        maxs=maxs,
        mins=mins,
        stdevs=stdevs,
    )
def get_normalized_vector( vector: torch.Tensor, stats: VectorStats) -> torch.Tensor:
def get_normalized_vector(
    vector: torch.Tensor,
    stats: VectorStats,
) -> torch.Tensor:
    """Standardize `vector` feature-wise as ``(x_i - mean_i) / stdev_i``.

    Any feature whose recorded standard deviation is not positive (e.g. a
    constant feature) is divided by 1.0 instead. Such a feature is only
    centered and never causes a division by zero.

    Args:
        vector: Tensor iterated entry by entry; entry ``i`` is paired with
            ``stats.means[i]`` and ``stats.stdevs[i]``. Presumably 1-D —
            TODO confirm for batched inputs.
        stats: Per-feature statistics used for normalization.

    Returns:
        A new tensor built by stacking the normalized entries.
    """
    safe_stdevs = [stdev if stdev > 0 else 1.0 for stdev in stats.stdevs]
    centered_scaled = [
        (entry - stats.means[idx]) / safe_stdevs[idx]
        for idx, entry in enumerate(vector)
    ]
    return torch.stack(centered_scaled)
def get_raw_vector( vector: torch.Tensor, stats: VectorStats) -> torch.Tensor:
 95def get_raw_vector(
 96    vector: torch.Tensor,
 97    stats: VectorStats,
 98) -> torch.Tensor:
 99    """'Unnormalize' a vector."""
100    vals = [
101        (val*stats.stdevs[i]+stats.means[i])
102        for i, val in enumerate(vector)
103    ]
104    norm_vec = torch.stack(vals)
105    return norm_vec

'Unnormalize' a vector.

@dataclass
class TransitionStatistics:
@dataclass(init=True, repr=True, eq=True)
class TransitionStatistics:
    """Per-feature statistics of the states and actions of a dataset."""

    # Statistics of state vectors (applied to both source and next states).
    state_normalization: VectorStats
    # Statistics of action vectors.
    action_normalization: VectorStats

Per-feature statistics of the state and action vectors of a dataset.

TransitionStatistics( state_normalization: VectorStats, action_normalization: VectorStats)
state_normalization: VectorStats
action_normalization: VectorStats
def get_transition_statistics( transitions: list[swmpo.transition.Transition]) -> TransitionStatistics:
def get_transition_statistics(
    transitions: list[Transition],
) -> TransitionStatistics:
    """Compute per-feature state and action statistics over `transitions`.

    State statistics use only each transition's ``source_state``. The
    rationale is that a next state presumably reappears as the source state
    of another transition. Note that a trajectory's final ``next_state``
    therefore does not contribute.
    """
    # Single pass, reading source_state then action for each transition.
    pairs = [
        (transition.source_state.tolist(), transition.action.tolist())
        for transition in transitions
    ]
    state_rows = [state for state, _ in pairs]
    action_rows = [action for _, action in pairs]
    return TransitionStatistics(
        state_normalization=get_vector_stats(state_rows),
        action_normalization=get_vector_stats(action_rows),
    )
def get_normalized_state( state: torch.Tensor, stats: TransitionStatistics) -> torch.Tensor:
def get_normalized_state(
    state: torch.Tensor,
    stats: TransitionStatistics,
) -> torch.Tensor:
    """Normalize a state vector using the dataset's state statistics."""
    state_stats = stats.state_normalization
    return get_normalized_vector(state, state_stats)
def get_normalized_action( action: torch.Tensor, stats: TransitionStatistics) -> torch.Tensor:
def get_normalized_action(
    action: torch.Tensor,
    stats: TransitionStatistics,
) -> torch.Tensor:
    """Normalize an action vector using the dataset's action statistics."""
    action_stats = stats.action_normalization
    return get_normalized_vector(action, action_stats)
def get_raw_state( state: torch.Tensor, stats: TransitionStatistics) -> torch.Tensor:
def get_raw_state(
    state: torch.Tensor,
    stats: TransitionStatistics,
) -> torch.Tensor:
    """Map a normalized state vector back to raw scale via the state statistics."""
    state_stats = stats.state_normalization
    return get_raw_vector(state, state_stats)
def get_raw_action( action: torch.Tensor, stats: TransitionStatistics) -> torch.Tensor:
def get_raw_action(
    action: torch.Tensor,
    stats: TransitionStatistics,
) -> torch.Tensor:
    """Map a normalized action vector back to raw scale via the action statistics."""
    action_stats = stats.action_normalization
    return get_raw_vector(action, action_stats)
def get_normalized_transition( transition: swmpo.transition.Transition, stats: TransitionStatistics) -> swmpo.transition.Transition:
def get_normalized_transition(
    transition: Transition,
    stats: TransitionStatistics,
) -> Transition:
    """Build a new Transition with normalized states and action.

    Source and next states are normalized with the state statistics; the
    action is normalized with the action statistics. Only these three
    fields are passed to the new Transition.
    NOTE(review): if Transition has other fields, they are not copied —
    confirm.
    """
    normalized_source = get_normalized_state(transition.source_state, stats)
    normalized_next = get_normalized_state(transition.next_state, stats)
    normalized_action = get_normalized_action(transition.action, stats)
    return Transition(
        source_state=normalized_source,
        next_state=normalized_next,
        action=normalized_action,
    )
def get_normalized_trajectories( trajectories: list[list[swmpo.transition.Transition]], stats: TransitionStatistics) -> list[list[swmpo.transition.Transition]]:
def get_normalized_trajectories(
    trajectories: list[list[Transition]],
    stats: TransitionStatistics,
) -> list[list[Transition]]:
    """Normalize trajectory vectors using dataset statistics.

    Returns a new nested list with the same shape as `trajectories`, where
    every transition is replaced by its normalized copy.
    """
    result: list[list[Transition]] = []
    for trajectory in trajectories:
        result.append(
            [get_normalized_transition(step, stats) for step in trajectory]
        )
    return result

Normalize trajectory vectors using dataset statistics.