swmpo.biased_sample
Biased replay buffer for Stable-Baselines3 off-policy algorithms. It stores the
`info["ground_truth_mode"]` of each experience and uses it to perform biased sampling.
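For context, the buffer is typically plugged into an off-policy algorithm through Stable-Baselines3's `replay_buffer_class` argument. The sketch below is a minimal usage example, not part of the module; `ModeSwitchingEnv` is a hypothetical environment whose `step` puts `"ground_truth_mode"` (or `"active_mode"`) into each info dict, and everything else is standard SB3 usage.

```python
from stable_baselines3 import SAC

from swmpo.biased_sample import BiasedModeReplayBuffer
from my_envs import ModeSwitchingEnv  # hypothetical env that reports "ground_truth_mode" in info

env = ModeSwitchingEnv()

# Any SB3 off-policy algorithm (SAC, TD3, DQN, ...) accepts a custom replay buffer class.
model = SAC(
    "MlpPolicy",
    env,
    replay_buffer_class=BiasedModeReplayBuffer,
    # Note: the buffer asserts optimize_memory_usage=False (the SB3 default).
)
model.learn(total_timesteps=10_000)
```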
1"""Biased replay buffer for StableBaselines3 OffPolicy algorithms. 2 3It stores the `info["ground_truth_mode"]` of each experience, and uses 4it to perform biased sampling. 5""" 6from stable_baselines3.common.buffers import ReplayBuffer 7from stable_baselines3.common.buffers import ReplayBufferSamples 8from stable_baselines3.common.vec_env import VecNormalize 9from collections import defaultdict 10from itertools import product 11from typing import List 12from typing import Any 13from typing import Dict 14from typing import Optional 15import random 16import numpy as np 17 18 19class BiasedModeReplayBuffer(ReplayBuffer): 20 21 def __init__(self, *args, **kwargs): 22 super().__init__(*args, **kwargs) 23 assert not self.optimize_memory_usage, "Optimized memory usage not supported!" 24 25 # Initialize mode 26 self.modes = np.zeros( 27 (self.buffer_size, self.n_envs,), 28 dtype=np.int32, 29 ) 30 31 def add( 32 self, 33 obs: np.ndarray, 34 next_obs: np.ndarray, 35 action: np.ndarray, 36 reward: np.ndarray, 37 done: np.ndarray, 38 infos: List[Dict[str, Any]], 39 ) -> None: 40 # The original function adds to the current `self.pos`. So the 41 # new data should be saved to that location first, then the 42 # original function should be called 43 44 # Extract modes 45 modes = [ 46 info["ground_truth_mode"] if "ground_truth_mode" in info.keys() else info["active_mode"] 47 for info in infos 48 ] 49 50 # Save modes 51 self.modes[self.pos] = np.array(modes, dtype=np.int32) 52 53 # Call the original function 54 super().add(obs, next_obs, action, reward, done, infos) 55 56 def sample( 57 self, 58 batch_size: int, 59 env: Optional[VecNormalize], 60 ) -> None: 61 """This function returns a sample of experiences from the replay buffer. 62 The distribution over experiences is independent of the mode. 63 64 That is, in the limit, each mode represented in the buffer is sampled 65 an equal amount of times. 66 """ 67 # Notes on the implementation of the original function: 68 # - The replay buffer saves experiences in a collection of array mostly 69 # as one would expect from theory. The only difference is that 70 # the implementation is written with `VecEnv` in mind; so instead 71 # of accessing the arrays with a single index as one would 72 # expect (`arr[experience_i]`) we have to access with two 73 # indices, one for step index and the second one for environment 74 # instance (`arr[batch_i, env_i]`). 75 # - `ReplayBuffer.sample` simply calls `BaseBuffer.sample` because the 76 # optimize flag is `false` 77 # - `BaseBuffer.sample` samples a set of `batch_i` and calls 78 # `ReplayBuffer._get_sample`, which samples a set of `env_i`. 79 # We will call a (`batch_i, env_i`) tuple an "experience index". 
80 81 # Organize experience indices by mode 82 experiences = defaultdict(list) 83 batch_upper_bound = self.buffer_size if self.full else self.pos 84 all_batch_inds = list(range(batch_upper_bound)) 85 all_env_indices = list(range(self.n_envs)) 86 for batch_i, env_i in product(all_batch_inds, all_env_indices): 87 experience_idx = (batch_i, env_i) 88 mode_i = self.modes[experience_idx] 89 experiences[mode_i].append(experience_idx) 90 91 # Compute experience weights 92 # Formula: 93 # sample weight = (total_n - mode_size)/total_n 94 # We don't normalize because it will be normalized down the line anyway 95 exp_total_n = sum( 96 len(experience_idxs) 97 for experience_idxs in experiences.values() 98 ) 99 mode_weights = { 100 mode_i: exp_total_n - len(experience_idxs) 101 for mode_i, experience_idxs in experiences.items() 102 } 103 104 # Create list with all experiences 105 population = list() 106 weights = list() 107 for mode_i, experience_idxs in experiences.items(): 108 mode_weight = mode_weights[mode_i] 109 # Special case: there is only one mode. In that case, 110 # the only mode weight is zero. We replace that with one. 111 if mode_weight == 0.0: 112 mode_weight = 1 113 for experience_idx in experience_idxs: 114 population.append(experience_idx) 115 weights.append(mode_weight) 116 117 # Choose experiences 118 choices = random.choices(population, weights, k=batch_size) 119 batch_inds = [batch_i for batch_i, _ in choices] 120 env_indices = [env_i for _, env_i in choices] 121 122 # Packaging logic. Exactly as in `ReplayBuffer._get_samples` 123 next_obs = self._normalize_obs( 124 self.next_observations[batch_inds, env_indices, :], env 125 ) 126 data = ( 127 self._normalize_obs(self.observations[batch_inds, env_indices, :], env), 128 self.actions[batch_inds, env_indices, :], 129 next_obs, 130 # Only use dones that are not due to timeouts 131 # deactivated by default (timeouts is initialized as an array of False) 132 (self.dones[batch_inds, env_indices] * (1 - self.timeouts[batch_inds, env_indices])).reshape(-1, 1), 133 self._normalize_reward(self.rewards[batch_inds, env_indices].reshape(-1, 1), env), 134 ) 135 return ReplayBufferSamples(*tuple(map(self.to_torch, data)))
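To make the weighting concrete, here is a standalone sketch (not part of the module) of the same formula on toy data: with modes of sizes 9 and 1, each experience in the large mode gets weight 1 and each experience in the small mode gets weight 9, so in this two-mode case both modes end up with equal total probability mass.

```python
from collections import Counter
import random

# Toy data: 9 experiences from mode 0 and 1 experience from mode 1.
modes = [0] * 9 + [1]
counts = Counter(modes)                        # {0: 9, 1: 1}
total = sum(counts.values())                   # 10

# Per-experience weight = total_n - mode_size, as in `BiasedModeReplayBuffer.sample`.
weights = [total - counts[m] for m in modes]   # mode 0 -> 1 each, mode 1 -> 9 each

draws = random.choices(modes, weights=weights, k=100_000)
print(Counter(draws))                          # roughly 50,000 draws from each mode
```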
class BiasedModeReplayBuffer(stable_baselines3.common.buffers.ReplayBuffer):
Replay buffer used in off-policy algorithms like SAC/TD3.
Parameters
- buffer_size: Max number of elements in the buffer
- observation_space: Observation space
- action_space: Action space
- device: PyTorch device
- n_envs: Number of parallel environments
- optimize_memory_usage: Enable a memory efficient variant of the replay buffer which reduces by almost a factor two the memory used, at a cost of more complexity. See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195 and https://github.com/DLR-RM/stable-baselines3/pull/28#issuecomment-637559274 Cannot be used in combination with handle_timeout_termination.
- handle_timeout_termination: Handle timeout termination (due to timelimit) separately and treat the task as infinite horizon task. https://github.com/DLR-RM/stable-baselines3/issues/284
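For illustration, the buffer can also be constructed directly with the inherited `ReplayBuffer` arguments; the only extra constraint added by this subclass is the assertion that `optimize_memory_usage` stays `False`. A minimal sketch, assuming arbitrary Box spaces (the space shapes are placeholders):

```python
import numpy as np
from gymnasium import spaces  # for SB3 >= 2.0; older versions use `gym.spaces`

from swmpo.biased_sample import BiasedModeReplayBuffer

obs_space = spaces.Box(low=-1.0, high=1.0, shape=(4,), dtype=np.float32)
act_space = spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32)

buffer = BiasedModeReplayBuffer(
    buffer_size=100_000,
    observation_space=obs_space,
    action_space=act_space,
    device="cpu",
    n_envs=1,
    optimize_memory_usage=False,   # required: the subclass asserts this is False
)
```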
def add(self, obs: numpy.ndarray, next_obs: numpy.ndarray, action: numpy.ndarray, reward: numpy.ndarray, done: numpy.ndarray, infos: List[Dict[str, Any]]) -> None:
Add elements to the buffer. In addition to the standard `ReplayBuffer.add` behaviour, the mode of each transition is read from `info["ground_truth_mode"]` (falling back to `info["active_mode"]`) and stored at the current buffer position.
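During normal training SB3 calls `add` itself during rollout collection; the only requirement this subclass introduces is that each per-env info dict carries a mode label. A sketch of a manual call, assuming the single-env `buffer` constructed above and placeholder transition arrays:

```python
import numpy as np

# Hypothetical single-env transition; shapes follow the (n_envs, ...) convention.
obs = np.zeros((1, 4), dtype=np.float32)
next_obs = np.zeros((1, 4), dtype=np.float32)
action = np.zeros((1, 2), dtype=np.float32)
reward = np.array([0.0], dtype=np.float32)
done = np.array([False])

buffer.add(
    obs=obs,
    next_obs=next_obs,
    action=action,
    reward=reward,
    done=done,
    infos=[{"ground_truth_mode": 2}],  # one dict per env; "active_mode" is used as a fallback
)
```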
def sample(self, batch_size: int, env: Optional[stable_baselines3.common.vec_env.vec_normalize.VecNormalize]) -> stable_baselines3.common.buffers.ReplayBufferSamples:
This function returns a sample of experiences from the replay buffer. The distribution over experiences is independent of the mode.
That is, in the limit, each mode represented in the buffer is sampled an equal number of times.
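A sketch of pulling a batch and unpacking the returned `ReplayBufferSamples` named tuple, assuming a populated `buffer` as above (pass the `VecNormalize` wrapper instead of `None` if observations and rewards are normalized):

```python
batch = buffer.sample(batch_size=256, env=None)

# ReplayBufferSamples is a named tuple of torch tensors.
obs_batch = batch.observations        # shape (256, *obs_shape)
actions = batch.actions               # shape (256, action_dim)
next_obs = batch.next_observations
dones = batch.dones                   # shape (256, 1); timeout-only dones are masked out
rewards = batch.rewards               # shape (256, 1)
```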