swmpo.gymnasium_wrapper
Augment MDPs with state machines using Gymnasium's API.
1"""Augment MDPs with state machines using Gymnasium's API.""" 2from collections import defaultdict 3from gymnasium import Wrapper 4from swmpo.state_machine import StateMachine 5from swmpo.state_machine import state_machine_model 6from swmpo.transition import Transition 7import numpy as np 8from gymnasium import spaces 9from gymnasium import Env 10import statistics 11import torch 12 13 14def get_one_hot_encoding( 15 active_i: int, 16 total_n: int, 17) -> list[float]: 18 """Get the one-hot encoding of the given active class.""" 19 one_hot_vector = [ 20 0.0 if i != active_i else 1.0 21 for i in range(total_n) 22 ] 23 return one_hot_vector 24 25 26def get_augmented_observation( 27 obs: np.ndarray, 28 active_i: int, 29 total_n: int, 30) -> np.ndarray: 31 """Augment the observation with a one-hot encoding 32 vector of the current state of the state machine.""" 33 one_hot_vector = np.array(get_one_hot_encoding(active_i, total_n)) 34 augmented_observation = np.concatenate(( 35 obs, 36 one_hot_vector, 37 )) 38 return augmented_observation 39 40 41def get_exploration_reward( 42 current_episode_mode_rewards: list[float], 43 prev_episode_mode_rewards: list[float] | None, 44 window_size: int, 45) -> float: 46 """ 47 Reward the agent for exploring modes where reward has been improving. 48 - current_reward: reward in current timestep 49 - mode_rewards: previous rewards in current timestep 50 """ 51 if len(current_episode_mode_rewards) == 0: 52 return 0.0 53 if prev_episode_mode_rewards is None: 54 return 0.0 55 if len(prev_episode_mode_rewards) == 0: 56 return 0.0 57 58 # Return the rate of improvement in the window of the last rewards 59 new_mean = statistics.mean(current_episode_mode_rewards) 60 old_mean = statistics.mean(prev_episode_mode_rewards) 61 delta_mean = new_mean - old_mean 62 return max(0.0, delta_mean) 63 64 65def get_total_reward( 66 base_reward: float, 67 current_episode_mode_rewards: list[float], 68 prev_episode_mode_rewards: list[float] | None, 69 current_mode: int, 70 current_episode_visited_modes: set[int], 71 exploration_window_size: int, 72 extrinsic_reward_scale: float, 73) -> float: 74 # Get exploration reward 75 exploration_reward = get_exploration_reward( 76 prev_episode_mode_rewards=prev_episode_mode_rewards, 77 current_episode_mode_rewards=current_episode_mode_rewards, 78 window_size=exploration_window_size, 79 ) 80 81 # Get extrinsic reward 82 should_reward = current_mode not in current_episode_visited_modes 83 extrinsic_reward = extrinsic_reward_scale if should_reward else 0.0 84 85 # Get total reward 86 total_reward = sum(( 87 base_reward, 88 extrinsic_reward, 89 exploration_reward 90 )) 91 return total_reward 92 93 94def cast_state(state): 95 if not isinstance(state, torch.Tensor): 96 if isinstance(state, np.ndarray): 97 state = torch.from_numpy(state) 98 else: 99 state = torch.tensor(state) 100 else: 101 state = state 102 return state 103 104 105class DeepSynthWrapper(Wrapper): 106 """Augment an environment with information from a state machine. 107 108 This wrapper modifies observations and rewards: 109 - reward is augmented providing a constant value for each new 110 node visited during an episode. 111 - observations are augmented with a one-hot encoding of the state 112 of the state machine. 113 114 It is expected the wrapped env returns a dict with a `state` key that 115 contains a `torch.Tensor` that will be used to update the state machine. 116 117 This augmented version of an MDP was proposed in Hasanbeig et al, 2021. 
118 """ 119 120 def __init__( 121 self, 122 env: Env, 123 state_machine: StateMachine, 124 initial_state_machine_state: int, 125 extrinsic_reward_scale: float, 126 exploration_window_size: int, 127 dt: float, 128 ): 129 super().__init__(env) 130 old_space = env.observation_space 131 assert isinstance(old_space, spaces.Box) 132 encoding_len = len(get_one_hot_encoding( 133 initial_state_machine_state, len(state_machine.local_models), 134 )) 135 low = np.concatenate(( 136 old_space.low, 137 np.array([0.0 for _ in range(encoding_len)]), 138 )) 139 high = np.concatenate(( 140 old_space.high, 141 np.array([1.0 for _ in range(encoding_len)]), 142 )) 143 self.observation_space = spaces.Box( 144 low=low, 145 high=high, 146 dtype=np.float32, 147 ) 148 self.initial_node = initial_state_machine_state 149 self.current_node = initial_state_machine_state 150 self.state_machine = state_machine 151 self.extrinsic_reward_scale = extrinsic_reward_scale 152 self.visited_nodes = set() 153 self.exploration_window_size = exploration_window_size 154 self.dt = dt 155 156 # Keep track of per-mode rewards 157 self.mode_rewards = [defaultdict(list)] 158 159 def reset(self, *args, **kwargs): 160 self.current_node = self.initial_node 161 self.visited_nodes = set() 162 self.mode_rewards.append(defaultdict(list)) 163 164 obs, info = self.env.reset(*args, **kwargs) 165 augmented_obs = get_augmented_observation( 166 obs=obs, 167 active_i=self.current_node, 168 total_n=len(self.state_machine.local_models), 169 ) 170 info["active_mode"] = self.current_node 171 self.state = info["state"] 172 return augmented_obs, info 173 174 def step(self, action): 175 # Step environment 176 prev_state = self.state 177 obs, reward, terminated, truncated, info = self.env.step(action) 178 info["active_mode"] = self.current_node 179 180 # Step node machine 181 new_state = info["state"] 182 _, new_node = state_machine_model( 183 state_machine=self.state_machine, 184 action=torch.from_numpy(action), 185 state=cast_state(prev_state), 186 current_node=self.current_node, 187 dt=self.dt, 188 ) 189 190 # Get total reward 191 if len(self.mode_rewards) < 2: 192 prev_episode_mode_rewards = None 193 else: 194 prev_episode_mode_rewards = self.mode_rewards[-2] 195 total_reward = get_total_reward( 196 base_reward=float(reward), 197 prev_episode_mode_rewards=prev_episode_mode_rewards, 198 current_episode_mode_rewards=self.mode_rewards[-1][new_node], 199 current_mode=new_node, 200 current_episode_visited_modes=self.visited_nodes, 201 exploration_window_size=self.exploration_window_size, 202 extrinsic_reward_scale=self.extrinsic_reward_scale, 203 ) 204 205 # Get augmented observation 206 augmented_obs = get_augmented_observation( 207 obs=obs, 208 active_i=new_node, 209 total_n=len(self.state_machine.local_models), 210 ) 211 212 # Update per-mode rewards 213 self.mode_rewards[-1][new_node].append(float(reward)) 214 215 # Update state machine and visited nodes 216 self.visited_nodes = self.visited_nodes | {new_node} 217 self.current_node = new_node 218 219 self.state = new_state 220 221 return augmented_obs, total_reward, terminated, truncated, info
def
get_one_hot_encoding(active_i: int, total_n: int) -> list[float]:
def get_one_hot_encoding(
    active_i: int,
    total_n: int,
) -> list[float]:
    """Get the one-hot encoding of the given active class."""
    one_hot_vector = [
        0.0 if i != active_i else 1.0
        for i in range(total_n)
    ]
    return one_hot_vector
Get the one-hot encoding of the given active class.
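A small usage example, following directly from the source above:

get_one_hot_encoding(active_i=1, total_n=3)  # -> [0.0, 1.0, 0.0]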
def
get_augmented_observation(obs: numpy.ndarray, active_i: int, total_n: int) -> numpy.ndarray:
def get_augmented_observation(
    obs: np.ndarray,
    active_i: int,
    total_n: int,
) -> np.ndarray:
    """Augment the observation with a one-hot encoding
    vector of the current state of the state machine."""
    one_hot_vector = np.array(get_one_hot_encoding(active_i, total_n))
    augmented_observation = np.concatenate((
        obs,
        one_hot_vector,
    ))
    return augmented_observation
Augment the observation with a one-hot encoding vector of the current state of the state machine.
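For example, with three state-machine nodes and node 0 active, the one-hot vector is appended to the raw observation:

import numpy as np

obs = np.array([0.5, -1.2])                            # raw environment observation
get_augmented_observation(obs, active_i=0, total_n=3)  # -> array([ 0.5, -1.2,  1. ,  0. ,  0. ])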
def
get_exploration_reward( current_episode_mode_rewards: list[float], prev_episode_mode_rewards: list[float] | None, window_size: int) -> float:
def get_exploration_reward(
    current_episode_mode_rewards: list[float],
    prev_episode_mode_rewards: list[float] | None,
    window_size: int,
) -> float:
    """
    Reward the agent for exploring modes where reward has been improving.

    - current_episode_mode_rewards: rewards collected in the current mode
      during the current episode.
    - prev_episode_mode_rewards: rewards collected in the same mode during
      the previous episode, or None if there is no previous episode.
    - window_size: currently unused.
    """
    if len(current_episode_mode_rewards) == 0:
        return 0.0
    if prev_episode_mode_rewards is None:
        return 0.0
    if len(prev_episode_mode_rewards) == 0:
        return 0.0

    # Return the improvement of the mean reward over the previous episode
    new_mean = statistics.mean(current_episode_mode_rewards)
    old_mean = statistics.mean(prev_episode_mode_rewards)
    delta_mean = new_mean - old_mean
    return max(0.0, delta_mean)
Reward the agent for exploring modes where reward has been improving.
- current_episode_mode_rewards: rewards collected in the current mode during the current episode.
- prev_episode_mode_rewards: rewards collected in the same mode during the previous episode, or None if there is no previous episode.
- window_size: currently unused.
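A worked example, derived from the source above: the reward is the improvement of the mean per-mode reward relative to the previous episode, clipped at zero.

current = [1.0, 2.0, 3.0]    # rewards for this mode in the current episode (mean 2.0)
previous = [0.5, 1.5]        # rewards for the same mode in the previous episode (mean 1.0)
get_exploration_reward(current, previous, window_size=10)   # -> 1.0

# If the mode got worse, the reward is clipped to zero:
get_exploration_reward(previous, current, window_size=10)   # -> 0.0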
def
get_total_reward( base_reward: float, current_episode_mode_rewards: list[float], prev_episode_mode_rewards: list[float] | None, current_mode: int, current_episode_visited_modes: set[int], exploration_window_size: int, extrinsic_reward_scale: float) -> float:
def get_total_reward(
    base_reward: float,
    current_episode_mode_rewards: list[float],
    prev_episode_mode_rewards: list[float] | None,
    current_mode: int,
    current_episode_visited_modes: set[int],
    exploration_window_size: int,
    extrinsic_reward_scale: float,
) -> float:
    # Get exploration reward
    exploration_reward = get_exploration_reward(
        prev_episode_mode_rewards=prev_episode_mode_rewards,
        current_episode_mode_rewards=current_episode_mode_rewards,
        window_size=exploration_window_size,
    )

    # Get extrinsic reward
    should_reward = current_mode not in current_episode_visited_modes
    extrinsic_reward = extrinsic_reward_scale if should_reward else 0.0

    # Get total reward
    total_reward = sum((
        base_reward,
        extrinsic_reward,
        exploration_reward
    ))
    return total_reward
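To illustrate how the three terms combine (base reward, a one-off bonus for visiting a new mode, and the exploration reward), with illustrative values:

total = get_total_reward(
    base_reward=0.1,
    current_episode_mode_rewards=[1.0, 2.0, 3.0],   # mean 2.0
    prev_episode_mode_rewards=[0.5, 1.5],           # mean 1.0
    current_mode=2,
    current_episode_visited_modes={0, 1},           # mode 2 not yet visited this episode
    exploration_window_size=10,
    extrinsic_reward_scale=5.0,
)
# 0.1 (base) + 5.0 (first visit to mode 2) + 1.0 (mean improvement) = 6.1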
def
cast_state(state):
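cast_state converts the given state to a torch.Tensor if it is not one already, using torch.from_numpy for numpy arrays (see the module source above). A small usage example:

import numpy as np
import torch

cast_state(np.zeros(3))     # -> tensor([0., 0., 0.], dtype=torch.float64)
cast_state([1.0, 2.0])      # -> tensor([1., 2.])
cast_state(torch.ones(2))   # already a tensor, returned unchanged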
class
DeepSynthWrapper(gymnasium.core.Wrapper):
Augment an environment with information from a state machine.
This wrapper modifies observations and rewards:
- the reward is augmented with a constant bonus for each new node visited during an episode.
- observations are augmented with a one-hot encoding of the state of the state machine.
The wrapped env is expected to return an info dict with a state key that contains a torch.Tensor, which is used to update the state machine.
This augmented version of an MDP was proposed in Hasanbeig et al., 2021.
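A minimal usage sketch, assuming a state machine has already been synthesized (a swmpo.state_machine.StateMachine) and an environment whose reset and step place the raw state tensor into info["state"]; make_env and all hyperparameter values below are illustrative placeholders, not part of this module:

# hypothetical setup; `make_env`, `state_machine`, and the numbers are placeholders
env = make_env()                       # must return info dicts with a "state" entry
wrapped = DeepSynthWrapper(
    env=env,
    state_machine=state_machine,       # a swmpo.state_machine.StateMachine
    initial_state_machine_state=0,
    extrinsic_reward_scale=1.0,        # bonus for each newly visited node
    exploration_window_size=10,
    dt=0.05,                           # time step passed to state_machine_model
)

obs, info = wrapped.reset()
action = wrapped.action_space.sample()            # numpy array for a Box action space
obs, reward, terminated, truncated, info = wrapped.step(action)
print(info["active_mode"], obs.shape)             # active node and augmented observation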
DeepSynthWrapper( env: gymnasium.core.Env, state_machine: swmpo.state_machine.StateMachine, initial_state_machine_state: int, extrinsic_reward_scale: float, exploration_window_size: int, dt: float)
    def __init__(
        self,
        env: Env,
        state_machine: StateMachine,
        initial_state_machine_state: int,
        extrinsic_reward_scale: float,
        exploration_window_size: int,
        dt: float,
    ):
        super().__init__(env)
        old_space = env.observation_space
        assert isinstance(old_space, spaces.Box)
        encoding_len = len(get_one_hot_encoding(
            initial_state_machine_state, len(state_machine.local_models),
        ))
        low = np.concatenate((
            old_space.low,
            np.array([0.0 for _ in range(encoding_len)]),
        ))
        high = np.concatenate((
            old_space.high,
            np.array([1.0 for _ in range(encoding_len)]),
        ))
        self.observation_space = spaces.Box(
            low=low,
            high=high,
            dtype=np.float32,
        )
        self.initial_node = initial_state_machine_state
        self.current_node = initial_state_machine_state
        self.state_machine = state_machine
        self.extrinsic_reward_scale = extrinsic_reward_scale
        self.visited_nodes = set()
        self.exploration_window_size = exploration_window_size
        self.dt = dt

        # Keep track of per-mode rewards
        self.mode_rewards = [defaultdict(list)]
observation_space: Union[gymnasium.spaces.space.Space[~ObsType], gymnasium.spaces.space.Space[~WrapperObsType]]
    @property
    def observation_space(
        self,
    ) -> spaces.Space[ObsType] | spaces.Space[WrapperObsType]:
        """Return the :attr:`Env` :attr:`observation_space` unless overwritten then the wrapper :attr:`observation_space` is used."""
        if self._observation_space is None:
            return self.env.observation_space
        return self._observation_space
Return the Env observation_space unless overwritten, in which case the wrapper observation_space is used.
def
reset(self, *args, **kwargs):
    def reset(self, *args, **kwargs):
        self.current_node = self.initial_node
        self.visited_nodes = set()
        self.mode_rewards.append(defaultdict(list))

        obs, info = self.env.reset(*args, **kwargs)
        augmented_obs = get_augmented_observation(
            obs=obs,
            active_i=self.current_node,
            total_n=len(self.state_machine.local_models),
        )
        info["active_mode"] = self.current_node
        self.state = info["state"]
        return augmented_obs, info
def
step(self, action):
    def step(self, action):
        # Step environment
        prev_state = self.state
        obs, reward, terminated, truncated, info = self.env.step(action)
        info["active_mode"] = self.current_node

        # Step the state machine
        new_state = info["state"]
        _, new_node = state_machine_model(
            state_machine=self.state_machine,
            action=torch.from_numpy(action),
            state=cast_state(prev_state),
            current_node=self.current_node,
            dt=self.dt,
        )

        # Get total reward
        if len(self.mode_rewards) < 2:
            prev_episode_mode_rewards = None
        else:
            # Rewards collected for this mode during the previous episode
            prev_episode_mode_rewards = self.mode_rewards[-2][new_node]
        total_reward = get_total_reward(
            base_reward=float(reward),
            prev_episode_mode_rewards=prev_episode_mode_rewards,
            current_episode_mode_rewards=self.mode_rewards[-1][new_node],
            current_mode=new_node,
            current_episode_visited_modes=self.visited_nodes,
            exploration_window_size=self.exploration_window_size,
            extrinsic_reward_scale=self.extrinsic_reward_scale,
        )

        # Get augmented observation
        augmented_obs = get_augmented_observation(
            obs=obs,
            active_i=new_node,
            total_n=len(self.state_machine.local_models),
        )

        # Update per-mode rewards
        self.mode_rewards[-1][new_node].append(float(reward))

        # Update state machine and visited nodes
        self.visited_nodes = self.visited_nodes | {new_node}
        self.current_node = new_node

        self.state = new_state

        return augmented_obs, total_reward, terminated, truncated, info
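To make the reward bookkeeping concrete: self.mode_rewards holds one defaultdict per episode, keyed by node id, and the exploration term compares the current episode's rewards for a node against the previous episode's rewards for the same node (as in step above). A sketch with illustrative values:

from collections import defaultdict

# Illustrative shape of DeepSynthWrapper.mode_rewards after two episodes:
mode_rewards = [
    defaultdict(list),                                  # created in __init__, never filled
    defaultdict(list, {0: [0.1, 0.2], 1: [0.5]}),       # first episode
    defaultdict(list, {0: [0.3], 1: [0.7, 0.9]}),       # current episode
]

# A step that lands in node 1 compares the current episode's rewards for node 1
# with the previous episode's rewards for the same node:
get_exploration_reward(
    current_episode_mode_rewards=mode_rewards[-1][1],   # [0.7, 0.9], mean 0.8
    prev_episode_mode_rewards=mode_rewards[-2][1],      # [0.5], mean 0.5
    window_size=10,
)                                                       # -> approximately 0.3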