swmpo.gymnasium_wrapper

Augment MDPs with state machines using Gymnasium's API.

  1"""Augment MDPs with state machines using Gymnasium's API."""
  2from collections import defaultdict
  3from gymnasium import Wrapper
  4from swmpo.state_machine import StateMachine
  5from swmpo.state_machine import state_machine_model
  6from swmpo.transition import Transition
  7import numpy as np
  8from gymnasium import spaces
  9from gymnasium import Env
 10import statistics
 11import torch
 12
 13
def get_one_hot_encoding(
    active_i: int,
    total_n: int,
) -> list[float]:
    """Get the one-hot encoding of the given active class."""
    one_hot_vector = [
        0.0 if i != active_i else 1.0
        for i in range(total_n)
    ]
    return one_hot_vector


def get_augmented_observation(
    obs: np.ndarray,
    active_i: int,
    total_n: int,
) -> np.ndarray:
    """Augment the observation with a one-hot encoding
    vector of the current state of the state machine."""
    one_hot_vector = np.array(get_one_hot_encoding(active_i, total_n))
    augmented_observation = np.concatenate((
        obs,
        one_hot_vector,
    ))
    return augmented_observation


def get_exploration_reward(
    current_episode_mode_rewards: list[float],
    prev_episode_mode_rewards: list[float] | None,
    window_size: int,
) -> float:
    """
    Reward the agent for exploring modes where reward has been improving.
    - current_episode_mode_rewards: rewards collected for the mode during
      the current episode
    - prev_episode_mode_rewards: rewards collected for the same mode during
      the previous episode, or None if there is no previous episode
    - window_size: number of most recent rewards compared between episodes
    """
    if len(current_episode_mode_rewards) == 0:
        return 0.0
    if prev_episode_mode_rewards is None:
        return 0.0
    if len(prev_episode_mode_rewards) == 0:
        return 0.0

    # Return the rate of improvement over the window of most recent rewards
    new_mean = statistics.mean(current_episode_mode_rewards[-window_size:])
    old_mean = statistics.mean(prev_episode_mode_rewards[-window_size:])
    delta_mean = new_mean - old_mean
    return max(0.0, delta_mean)


def get_total_reward(
    base_reward: float,
    current_episode_mode_rewards: list[float],
    prev_episode_mode_rewards: list[float] | None,
    current_mode: int,
    current_episode_visited_modes: set[int],
    exploration_window_size: int,
    extrinsic_reward_scale: float,
) -> float:
    """Combine the base reward with an extrinsic bonus for visiting a new
    mode and an exploration reward for modes whose reward is improving."""
    # Get exploration reward
    exploration_reward = get_exploration_reward(
        prev_episode_mode_rewards=prev_episode_mode_rewards,
        current_episode_mode_rewards=current_episode_mode_rewards,
        window_size=exploration_window_size,
    )

    # Get extrinsic reward: a bonus the first time a mode is visited
    should_reward = current_mode not in current_episode_visited_modes
    extrinsic_reward = extrinsic_reward_scale if should_reward else 0.0

    # Get total reward
    total_reward = sum((
        base_reward,
        extrinsic_reward,
        exploration_reward,
    ))
    return total_reward


def cast_state(state):
    """Cast the given state to a torch.Tensor if it is not one already."""
    if not isinstance(state, torch.Tensor):
        if isinstance(state, np.ndarray):
            state = torch.from_numpy(state)
        else:
            state = torch.tensor(state)
    return state


class DeepSynthWrapper(Wrapper):
    """Augment an environment with information from a state machine.

    This wrapper modifies observations and rewards:
    - the reward is augmented with a constant bonus for each new
      node visited during an episode.
    - observations are augmented with a one-hot encoding of the state
      of the state machine.

    The wrapped env is expected to return an `info` dict with a `state` key
    containing a `torch.Tensor` that is used to update the state machine.

    This augmented version of an MDP was proposed in Hasanbeig et al. (2021).
    """

    def __init__(
            self,
            env: Env,
            state_machine: StateMachine,
            initial_state_machine_state: int,
            extrinsic_reward_scale: float,
            exploration_window_size: int,
            dt: float,
            ):
        super().__init__(env)
        old_space = env.observation_space
        assert isinstance(old_space, spaces.Box)
        encoding_len = len(get_one_hot_encoding(
            initial_state_machine_state, len(state_machine.local_models),
        ))
        low = np.concatenate((
            old_space.low,
            np.array([0.0 for _ in range(encoding_len)]),
        ))
        high = np.concatenate((
            old_space.high,
            np.array([1.0 for _ in range(encoding_len)]),
        ))
        self.observation_space = spaces.Box(
            low=low,
            high=high,
            dtype=np.float32,
        )
        self.initial_node = initial_state_machine_state
        self.current_node = initial_state_machine_state
        self.state_machine = state_machine
        self.extrinsic_reward_scale = extrinsic_reward_scale
        self.visited_nodes = set()
        self.exploration_window_size = exploration_window_size
        self.dt = dt

        # Keep track of per-mode rewards
        self.mode_rewards = [defaultdict(list)]

    def reset(self, *args, **kwargs):
        self.current_node = self.initial_node
        self.visited_nodes = set()
        self.mode_rewards.append(defaultdict(list))

        obs, info = self.env.reset(*args, **kwargs)
        augmented_obs = get_augmented_observation(
            obs=obs,
            active_i=self.current_node,
            total_n=len(self.state_machine.local_models),
        )
        info["active_mode"] = self.current_node
        self.state = info["state"]
        return augmented_obs, info

    def step(self, action):
        # Step environment
        prev_state = self.state
        obs, reward, terminated, truncated, info = self.env.step(action)
        info["active_mode"] = self.current_node

        # Step the state machine using the previous state and the action
        new_state = info["state"]
        _, new_node = state_machine_model(
            state_machine=self.state_machine,
            action=torch.from_numpy(action),
            state=cast_state(prev_state),
            current_node=self.current_node,
            dt=self.dt,
        )

        # Get total reward, comparing this episode's rewards for the new
        # node against the previous episode's rewards for the same node
        if len(self.mode_rewards) < 2:
            prev_episode_mode_rewards = None
        else:
            prev_episode_mode_rewards = self.mode_rewards[-2].get(new_node, [])
        total_reward = get_total_reward(
            base_reward=float(reward),
            prev_episode_mode_rewards=prev_episode_mode_rewards,
            current_episode_mode_rewards=self.mode_rewards[-1][new_node],
            current_mode=new_node,
            current_episode_visited_modes=self.visited_nodes,
            exploration_window_size=self.exploration_window_size,
            extrinsic_reward_scale=self.extrinsic_reward_scale,
        )

        # Get augmented observation
        augmented_obs = get_augmented_observation(
            obs=obs,
            active_i=new_node,
            total_n=len(self.state_machine.local_models),
        )

        # Update per-mode rewards
        self.mode_rewards[-1][new_node].append(float(reward))

        # Update state machine and visited nodes
        self.visited_nodes = self.visited_nodes | {new_node}
        self.current_node = new_node

        self.state = new_state

        return augmented_obs, total_reward, terminated, truncated, info
def get_one_hot_encoding(active_i: int, total_n: int) -> list[float]:

Get the one-hot encoding of the given active class.
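
For example, the encoding of class 2 out of 4:

get_one_hot_encoding(active_i=2, total_n=4)
# [0.0, 0.0, 1.0, 0.0]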

def get_augmented_observation(obs: numpy.ndarray, active_i: int, total_n: int) -> numpy.ndarray:

Augment the observation with a one-hot encoding vector of the current state of the state machine.
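
For example, appending a 3-mode encoding to a 2-dimensional observation:

import numpy as np

obs = np.array([0.5, -1.0])
get_augmented_observation(obs=obs, active_i=1, total_n=3)
# array with entries [0.5, -1.0, 0.0, 1.0, 0.0]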

def get_exploration_reward(current_episode_mode_rewards: list[float], prev_episode_mode_rewards: list[float] | None, window_size: int) -> float:

Reward the agent for exploring modes where reward has been improving.

  • current_episode_mode_rewards: rewards collected for the mode during the current episode
  • prev_episode_mode_rewards: rewards collected for the same mode during the previous episode, or None if there is no previous episode
  • window_size: number of most recent rewards compared between the two episodes
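
A small example of the improvement signal, with illustrative values; note that decreases are clipped to 0.0:

current = [1.0, 2.0, 3.0]   # rewards for the mode during this episode
previous = [1.0, 1.0, 1.0]  # rewards for the same mode during the last episode
get_exploration_reward(
    current_episode_mode_rewards=current,
    prev_episode_mode_rewards=previous,
    window_size=3,
)
# 1.0  (window mean improved from 1.0 to 2.0)
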
def get_total_reward(base_reward: float, current_episode_mode_rewards: list[float], prev_episode_mode_rewards: list[float] | None, current_mode: int, current_episode_visited_modes: set[int], exploration_window_size: int, extrinsic_reward_scale: float) -> float:
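
Combine the base environment reward with an extrinsic bonus for visiting a mode for the first time in the episode and an exploration reward for modes whose reward is improving.

A worked example with illustrative values:

get_total_reward(
    base_reward=0.5,
    current_episode_mode_rewards=[1.0, 2.0, 3.0],
    prev_episode_mode_rewards=[1.0, 1.0, 1.0],
    current_mode=2,
    current_episode_visited_modes={0, 1},  # mode 2 has not been visited yet
    exploration_window_size=3,
    extrinsic_reward_scale=10.0,
)
# 11.5 = 0.5 (base) + 10.0 (new-mode bonus) + 1.0 (exploration reward)
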
def cast_state(state):
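
Cast the given state to a torch.Tensor: numpy arrays are converted with torch.from_numpy, other values with torch.tensor, and tensors are returned unchanged. For example:

import numpy as np
import torch

cast_state(np.zeros(3))     # tensor([0., 0., 0.], dtype=torch.float64)
cast_state(torch.zeros(3))  # returned as-is
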
class DeepSynthWrapper(gymnasium.Wrapper):

Augment an environment with information from a state machine.

This wrapper modifies observations and rewards:

  • the reward is augmented with a constant bonus for each new node visited during an episode.
  • observations are augmented with a one-hot encoding of the state of the state machine.

The wrapped env is expected to return an info dict with a state key containing a torch.Tensor that is used to update the state machine.

This augmented version of an MDP was proposed in Hasanbeig et al. (2021).
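
A minimal usage sketch. Here base_env and machine are placeholders: base_env is any Box-observation environment whose info dict carries a torch.Tensor under the "state" key in both reset() and step(), machine is a swmpo.state_machine.StateMachine, and the numeric arguments are illustrative.

wrapped = DeepSynthWrapper(
    env=base_env,
    state_machine=machine,
    initial_state_machine_state=0,
    extrinsic_reward_scale=1.0,
    exploration_window_size=10,
    dt=0.05,
)
obs, info = wrapped.reset()
action = wrapped.action_space.sample()  # actions are passed to torch.from_numpy, so numpy arrays are expected
obs, reward, terminated, truncated, info = wrapped.step(action)
info["active_mode"]  # index of the active state-machine node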

DeepSynthWrapper(env: gymnasium.core.Env, state_machine: swmpo.state_machine.StateMachine, initial_state_machine_state: int, extrinsic_reward_scale: float, exploration_window_size: int, dt: float)

Wrap the environment and extend its observation space with the one-hot encoding of the state machine's nodes.

Args:
  env: The environment to wrap; its observation space must be a gymnasium.spaces.Box.
  state_machine: State machine used to track the active mode.
  initial_state_machine_state: Node the state machine starts in at every reset.
  extrinsic_reward_scale: Bonus granted the first time a mode is visited during an episode.
  exploration_window_size: Window size used by the exploration reward.
  dt: Time step passed to state_machine_model when stepping the state machine.

observation_space: Union[gymnasium.spaces.space.Space[~ObsType], gymnasium.spaces.space.Space[~WrapperObsType]]

Return the wrapped Env's observation_space unless it has been overwritten, in which case the wrapper's observation_space is used. DeepSynthWrapper overwrites it in __init__, appending bounds for the one-hot mode encoding.

initial_node
current_node
state_machine
extrinsic_reward_scale
visited_nodes
exploration_window_size
dt
mode_rewards
def reset(self, *args, **kwargs):

Reset the wrapped environment and the state machine: the active node returns to the initial node, the set of visited nodes is cleared, a fresh per-mode reward record is started, and the returned observation is augmented with the one-hot encoding of the active node. info["active_mode"] is set to the active node.
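
A sketch of the effect on the returned observation, reusing the base_env and machine placeholders from the class example and assuming a 1-D Box observation:

obs, info = wrapped.reset()
# One extra entry per state-machine node is appended to the base observation.
assert obs.shape[-1] == base_env.observation_space.shape[-1] + len(machine.local_models)
assert info["active_mode"] == 0  # the initial node passed to the constructor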

def step(self, action):

Step the wrapped environment, advance the state machine using the previous state and the action, and return the augmented observation together with the total reward (base reward plus the new-mode bonus and the exploration reward). The base reward is recorded for the new node, which is marked as visited and becomes the current node. Note that info["active_mode"] reports the node that was active before the transition, while the returned observation encodes the new node.