swmpo.biased_sample

Biased replay buffer for StableBaselines3 OffPolicy algorithms.

It stores the info["ground_truth_mode"] of each experience, and uses it to perform biased sampling.

  1"""Biased replay buffer for StableBaselines3 OffPolicy algorithms.
  2
  3It stores the `info["ground_truth_mode"]` of each experience, and uses
  4it to perform biased sampling.
  5"""
  6from stable_baselines3.common.buffers import ReplayBuffer
  7from stable_baselines3.common.buffers import ReplayBufferSamples
  8from stable_baselines3.common.vec_env import VecNormalize
  9from collections import defaultdict
 10from itertools import product
 11from typing import List
 12from typing import Any
 13from typing import Dict
 14from typing import Optional
 15import random
 16import numpy as np
 17
 18
 19class BiasedModeReplayBuffer(ReplayBuffer):
 20
 21    def __init__(self, *args, **kwargs):
 22        super().__init__(*args, **kwargs)
 23        assert not self.optimize_memory_usage, "Optimized memory usage not supported!"
 24
 25        # Initialize mode
 26        self.modes = np.zeros(
 27            (self.buffer_size, self.n_envs,),
 28            dtype=np.int32,
 29        )
 30
 31    def add(
 32        self,
 33        obs: np.ndarray,
 34        next_obs: np.ndarray,
 35        action: np.ndarray,
 36        reward: np.ndarray,
 37        done: np.ndarray,
 38        infos: List[Dict[str, Any]],
 39    ) -> None:
 40        # The original function adds to the current `self.pos`. So the
 41        # new data should be saved to that location first, then the
 42        # original function should be called
 43
 44        # Extract modes
 45        modes = [
 46            info["ground_truth_mode"] if "ground_truth_mode" in info.keys() else info["active_mode"]
 47            for info in infos
 48        ]
 49
 50        # Save modes
 51        self.modes[self.pos] = np.array(modes, dtype=np.int32)
 52
 53        # Call the original function
 54        super().add(obs, next_obs, action, reward, done, infos)
 55
 56    def sample(
 57        self,
 58        batch_size: int,
 59        env: Optional[VecNormalize],
 60    ) -> None:
 61        """This function returns a sample of experiences from the replay buffer.
 62        The distribution over experiences is independent of the mode.
 63
 64        That is, in the limit, each mode represented in the buffer is sampled
 65        an equal amount of times.
 66        """
 67        # Notes on the implementation of the original function:
 68        # - The replay buffer saves experiences in a collection of array mostly
 69        # as one would expect from theory. The only difference is that
 70        # the implementation is written with `VecEnv` in mind; so instead
 71        # of accessing the arrays with a single index as one would
 72        # expect (`arr[experience_i]`) we have to access with two
 73        # indices, one for step index and the second one for environment
 74        # instance (`arr[batch_i, env_i]`).
 75        # - `ReplayBuffer.sample` simply calls `BaseBuffer.sample` because the
 76        # optimize flag is `false`
 77        # - `BaseBuffer.sample` samples a set of `batch_i` and calls
 78        # `ReplayBuffer._get_sample`, which samples a set of `env_i`.
 79        # We will call a (`batch_i, env_i`) tuple an "experience index".
 80
 81        # Organize experience indices by mode
 82        experiences = defaultdict(list)
 83        batch_upper_bound = self.buffer_size if self.full else self.pos
 84        all_batch_inds = list(range(batch_upper_bound))
 85        all_env_indices = list(range(self.n_envs))
 86        for batch_i, env_i in product(all_batch_inds, all_env_indices):
 87            experience_idx = (batch_i, env_i)
 88            mode_i = self.modes[experience_idx]
 89            experiences[mode_i].append(experience_idx)
 90
 91        # Compute experience weights
 92        # Formula:
 93        # sample weight = (total_n - mode_size)/total_n
 94        # We don't normalize because it will be normalized down the line anyway
 95        exp_total_n = sum(
 96            len(experience_idxs)
 97            for experience_idxs in experiences.values()
 98        )
 99        mode_weights = {
100            mode_i: exp_total_n - len(experience_idxs)
101            for mode_i, experience_idxs in experiences.items()
102        }
103
104        # Create list with all experiences
105        population = list()
106        weights = list()
107        for mode_i, experience_idxs in experiences.items():
108            mode_weight = mode_weights[mode_i]
109            # Special case: there is only one mode. In that case,
110            # the only mode weight is zero. We replace that with one.
111            if mode_weight == 0.0:
112                mode_weight = 1
113            for experience_idx in experience_idxs:
114                population.append(experience_idx)
115                weights.append(mode_weight)
116
117        # Choose experiences
118        choices = random.choices(population, weights, k=batch_size)
119        batch_inds = [batch_i for batch_i, _ in choices]
120        env_indices = [env_i for _, env_i in choices]
121
122        # Packaging logic. Exactly as in `ReplayBuffer._get_samples`
123        next_obs = self._normalize_obs(
124            self.next_observations[batch_inds, env_indices, :], env
125        )
126        data = (
127            self._normalize_obs(self.observations[batch_inds, env_indices, :], env),
128            self.actions[batch_inds, env_indices, :],
129            next_obs,
130            # Only use dones that are not due to timeouts
131            # deactivated by default (timeouts is initialized as an array of False)
132            (self.dones[batch_inds, env_indices] * (1 - self.timeouts[batch_inds, env_indices])).reshape(-1, 1),
133            self._normalize_reward(self.rewards[batch_inds, env_indices].reshape(-1, 1), env),
134        )
135        return ReplayBufferSamples(*tuple(map(self.to_torch, data)))
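
A minimal usage sketch (not part of the source): the buffer can be plugged into an SB3 off-policy algorithm through the replay_buffer_class argument, provided the environment reports the mode in info at every step. make_mode_env is a hypothetical factory used only for illustration.

from stable_baselines3 import SAC

from swmpo.biased_sample import BiasedModeReplayBuffer

# Hypothetical environment; its `info` dict must contain
# "ground_truth_mode" (or "active_mode") at every step.
env = make_mode_env()

model = SAC(
    "MlpPolicy",
    env,
    replay_buffer_class=BiasedModeReplayBuffer,
)
model.learn(total_timesteps=10_000)
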
class BiasedModeReplayBuffer(stable_baselines3.common.buffers.ReplayBuffer):

Replay buffer used in off-policy algorithms like SAC/TD3.

Parameters: see stable_baselines3.common.buffers.ReplayBuffer; this subclass forwards *args and **kwargs to it unchanged.
BiasedModeReplayBuffer(*args, **kwargs)
modes
Per-experience mode labels: a numpy array of shape (buffer_size, n_envs) with dtype int32.
def add(self, obs: numpy.ndarray, next_obs: numpy.ndarray, action: numpy.ndarray, reward: numpy.ndarray, done: numpy.ndarray, infos: List[Dict[str, Any]]) -> None:

Add elements to the buffer, additionally recording each experience's mode taken from infos.
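
The buffer expects every entry of infos to carry either a "ground_truth_mode" or an "active_mode" integer label. Below is a sketch of how an environment could expose that label, using a hypothetical Gymnasium wrapper (ModeInfoWrapper and the current_mode attribute are illustrative names, not part of this package).

import gymnasium as gym


class ModeInfoWrapper(gym.Wrapper):
    """Hypothetical wrapper exposing the environment's current mode
    through info["ground_truth_mode"], as BiasedModeReplayBuffer expects."""

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        # `current_mode` is an illustrative attribute; adapt to your environment.
        info["ground_truth_mode"] = int(self.env.unwrapped.current_mode)
        return obs, reward, terminated, truncated, info
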

def sample(self, batch_size: int, env: Optional[stable_baselines3.common.vec_env.vec_normalize.VecNormalize] = None) -> ReplayBufferSamples:

Return a sample of experiences from the replay buffer. The sampling distribution is balanced across modes rather than proportional to how often each mode appears in the buffer.

That is, in the limit, each mode represented in the buffer is sampled an equal number of times.
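
To make the weighting concrete, here is a small self-contained illustration (numbers chosen arbitrarily, not taken from the source) of how the total_n - mode_size weights rebalance a two-mode buffer in which mode 0 is nine times more frequent than mode 1.

import random
from collections import Counter

# Mode label of each stored experience: 90 of mode 0, 10 of mode 1.
population = [0] * 90 + [1] * 10
total_n = len(population)

# Weight of each experience: total_n minus the size of its mode.
# Mode 0 experiences get 100 - 90 = 10, mode 1 experiences get 100 - 10 = 90,
# so each mode carries the same total weight (90 * 10 == 10 * 90 == 900).
weights = [total_n - population.count(mode) for mode in population]

draws = random.choices(population, weights, k=10_000)
print(Counter(draws))  # roughly 5_000 draws of each mode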