19. Policy Parametrization Methods#
In the previous chapter, we explored various approaches to approximate dynamic programming, focusing on ways to handle large state spaces through function approximation. However, these methods still face significant challenges when dealing with large or continuous action spaces. The need to maximize over actions during the Bellman operator evaluation becomes computationally prohibitive as the action space grows.
This chapter explores a natural evolution of these ideas: rather than exhaustively searching over actions, we can parameterize and directly optimize the policy itself. We begin by examining how fitted Q methods, while powerful for handling large state spaces, still struggle with action space complexity.
20. Embedded Optimization#
Recall that in fitted Q methods, the main idea is to compute the Bellman operator only at a subset of all states, relying on function approximation to generalize to the remaining states. At each step of the successive approximation loop, we build a dataset of input state-action pairs mapped to their corresponding optimality operator evaluations:
This dataset is then fed to our function approximator (neural network, random forest, linear model) to obtain the next set of parameters:
While this strategy allows us to handle very large or even infinite (continuous) state spaces, it still requires maximizing over actions (\(\max_{a \in A}\)) during the dataset creation when computing the operator \(L\) for each basepoint. This maximization becomes computationally expensive for large action spaces. A natural improvement is to add another level of optimization: for each sample added to our regression dataset, we can employ numerical optimization methods to find actions that maximize the Bellman operator for the given state.
(Fitted Q-Iteration with Explicit Optimization)
Input Given an MDP \((S, A, P, R, \gamma)\), base points \(\mathcal{B}\), function approximator class \(q(s,a; \boldsymbol{\theta})\), maximum iterations \(N\), tolerance \(\varepsilon > 0\)
Output Parameters \(\boldsymbol{\theta}\) for Q-function approximation
Initialize \(\boldsymbol{\theta}_0\) (e.g., zero initialization)
\(n \leftarrow 0\)
repeat
\(\mathcal{D} \leftarrow \emptyset\) // Regression Dataset
For each \((s,a,r,s') \in \mathcal{B}\): // Assumes Monte Carlo Integration with one sample
\(y_{s,a} \leftarrow r + \gamma \texttt{maximize}(q(s', \cdot; \boldsymbol{\theta}_n))\) // \(s'\) and \(\boldsymbol{\theta}_n\) are kept fixed
\(\mathcal{D} \leftarrow \mathcal{D} \cup \{((s,a), y_{s,a})\}\)
\(\boldsymbol{\theta}_{n+1} \leftarrow \texttt{fit}(\mathcal{D})\)
\(\delta \leftarrow \frac{1}{|\mathcal{D}|}\sum_{(s,a) \in \mathcal{D}} (q(s,a; \boldsymbol{\theta}_{n+1}) - q(s,a; \boldsymbol{\theta}_n))^2\)
\(n \leftarrow n + 1\)
until (\(\delta < \varepsilon\) or \(n \geq N\))
return \(\boldsymbol{\theta}_n\)
The above pseudocode introduces a generic \(\texttt{maximize}\) routine which represents any numerical optimization method that searches for an action maximizing the given function. This approach is versatile and can be adapted to different types of action spaces. For continuous action spaces, we can employ standard nonlinear optimization methods like gradient descent or L-BFGS (e.g., using scipy.optimize.minimize). For large discrete action spaces, we can use integer programming solvers - linear integer programming if the Q-function approximator is linear in actions, or mixed-integer nonlinear programming (MINLP) solvers for nonlinear Q-functions. The choice of solver depends on the structure of our Q-function approximator and the constraints on our action space.
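To make this concrete, here is a minimal sketch of how such a \(\texttt{maximize}\) routine could look for a continuous, box-constrained action space, using scipy.optimize.minimize with L-BFGS-B and random restarts. The helper name `maximize_q`, its arguments, and the toy Q-function are illustrative assumptions, not part of any particular library or of the pseudocode above.

```python
import numpy as np
from scipy.optimize import minimize

def maximize_q(q_fn, s_next, action_low, action_high, n_restarts=5, seed=0):
    """Approximate max_a q_fn(s_next, a) over the box [action_low, action_high]
    using L-BFGS-B with random restarts (all names here are illustrative)."""
    rng = np.random.default_rng(seed)
    bounds = list(zip(action_low, action_high))
    best_a, best_val = None, -np.inf
    for _ in range(n_restarts):
        a0 = rng.uniform(action_low, action_high)
        # scipy minimizes, so we negate the Q-value
        res = minimize(lambda a: -q_fn(s_next, a), a0, method="L-BFGS-B", bounds=bounds)
        if -res.fun > best_val:
            best_a, best_val = res.x, -res.fun
    return best_a, best_val

# Example usage with a toy concave Q-function (for illustration only)
q_fn = lambda s, a: -np.sum((a - s) ** 2)
a_star, q_star = maximize_q(q_fn, s_next=np.array([0.3, -0.2]),
                            action_low=np.array([-1.0, -1.0]),
                            action_high=np.array([1.0, 1.0]))
```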
20.1. Amortized Optimization Approach#
This process is computationally intensive. A natural question is whether we can “amortize” some of this computation by replacing the explicit optimization for each sample with a direct mapping that gives us an approximate maximizer directly. For Q-functions, recall that the operator is given by:
If \(q^*\) is the optimal state-action value function, then \(v^*(s) = \max_a q^*(s,a)\), and we can derive the optimal policy directly by computing the decision rule:
Since \(q^*\) is a fixed point of \(L\), we can write:
Note that \(d^\star\) is implemented by our \(\texttt{maximize}\) numerical solver in the procedure above. A practical strategy would be to collect these maximizer values at each step and use them to train a function approximator that directly predicts these solutions. Due to computational constraints, we might want to compute these exact maximizer values only for a subset of states, based on some computational budget, and use the fitted decision rule to generalize to the remaining states. This leads to the following amortized version:
(Fitted Q-Iteration with Amortized Optimization)
Input Given an MDP \((S, A, P, R, \gamma)\), base points \(\mathcal{B}\), subset for exact optimization \(\mathcal{B}_{\text{opt}} \subset \mathcal{B}\), Q-function approximator \(q(s,a; \boldsymbol{\theta})\), policy approximator \(d(s; \boldsymbol{w})\), maximum iterations \(N\), tolerance \(\varepsilon > 0\)
Output Parameters \(\boldsymbol{\theta}\) for Q-function, \(\boldsymbol{w}\) for policy
Initialize \(\boldsymbol{\theta}_0\), \(\boldsymbol{w}_0\)
\(n \leftarrow 0\)
repeat
\(\mathcal{D}_q \leftarrow \emptyset\) // Q-function regression dataset
\(\mathcal{D}_d \leftarrow \emptyset\) // Policy regression dataset
For each \((s,a,r,s') \in \mathcal{B}\):
// Determine next state’s action using either exact optimization or approximation
if \(s' \in \mathcal{B}_{\text{opt}}\) then
\(a^*_{s'} \leftarrow \texttt{maximize}(q(s', \cdot; \boldsymbol{\theta}_n))\)
\(\mathcal{D}_d \leftarrow \mathcal{D}_d \cup \{(s', a^*_{s'})\}\)
else
\(a^*_{s'} \leftarrow d(s'; \boldsymbol{w}_n)\)
// Compute Q-function target using chosen action
\(y_{s,a} \leftarrow r + \gamma q(s', a^*_{s'}; \boldsymbol{\theta}_n)\)
\(\mathcal{D}_q \leftarrow \mathcal{D}_q \cup \{((s,a), y_{s,a})\}\)
// Update both function approximators
\(\boldsymbol{\theta}_{n+1} \leftarrow \texttt{fit}(\mathcal{D}_q)\)
\(\boldsymbol{w}_{n+1} \leftarrow \texttt{fit}(\mathcal{D}_d)\)
// Compute convergence criteria
\(\delta_q \leftarrow \frac{1}{|\mathcal{D}_q|}\sum_{(s,a) \in \mathcal{D}_q} (q(s,a; \boldsymbol{\theta}_{n+1}) - q(s,a; \boldsymbol{\theta}_n))^2\)
\(\delta_d \leftarrow \frac{1}{|\mathcal{D}_d|}\sum_{(s,a^*) \in \mathcal{D}_d} \|a^* - d(s; \boldsymbol{w}_{n+1})\|^2\)
\(n \leftarrow n + 1\)
until (\(\max(\delta_q, \delta_d) < \varepsilon\) or \(n \geq N\))
return \(\boldsymbol{\theta}_n\), \(\boldsymbol{w}_n\)
An important observation about this procedure is that the policy \(d(s; \boldsymbol{w})\) is being trained on a dataset \(\mathcal{D}_d\) containing optimal actions computed with respect to an evolving Q-function. Specifically, at iteration n, we collect pairs \((s', a^*_{s'})\) where \(a^*_{s'} = \arg\max_a q(s', a; \boldsymbol{\theta}_n)\). However, after updating to \(\boldsymbol{\theta}_{n+1}\), these actions may no longer be optimal with respect to the new Q-function.
A natural approach to handle this staleness would be to maintain only the most recent optimization data. We could modify our procedure to keep a sliding window of K iterations, where at iteration n, we only use data from iterations max(0, n-K) to n. This would be implemented by augmenting each entry in \(\mathcal{D}_d\) with a timestamp:
where t indicates the iteration at which the optimal action was computed. When fitting the policy network, we would then only use data points that are at most K iterations old:
This introduces a trade-off between using more data (larger K) versus using more recent, accurate data (smaller K). The choice of K would depend on how quickly the Q-function evolves and the computational budget available for computing exact optimal actions.
Now the main issue with this approach, apart from the intrinsic out-of-distribution drift that we are trying to track, is that it requires “ground truth” - samples of optimal actions computed by the actual solver. This raises an intriguing question: how few samples do we actually need? Could we even envision eliminating the solver entirely? What seems impossible at first glance turns out to be achievable. The intuition is that as our policy improves at selecting actions, we can bootstrap from these increasingly better choices. As we continuously amortize these improving actions over time, it creates a virtuous cycle of self-improvement towards the optimal policy. But for this bootstrapping process to work, we need careful management - move too quickly and the process may become unstable. Let’s examine how this balance can be achieved.
21. Deterministic Parametrized Policies#
In this section, we consider deterministic parametrized policies of the form \(d(s; \boldsymbol{w})\) which directly output an action given a state. This approach differs from stochastic policies that output probability distributions over actions, making it particularly suitable for continuous control problems where the optimal policy is often deterministic. We’ll see how fitted Q-value methods can be naturally extended to simultaneously learn both the Q-function and such a deterministic policy.
21.1. Neural Fitted Q-iteration for Continuous Actions (NFQCA)#
To develop this approach, let’s first consider an idealized setting where we have access to \(q^\star\), the optimal Q-function. Then we can state our goal as finding policy parameters \(\boldsymbol{w}\) that maximize \(q^\star\) with respect to the actions chosen by our policy across the state space:
However, it’s computationally infeasible to satisfy this condition for every possible state \(s\), especially in large or continuous state spaces. To address this, we assume a distribution of states, denoted \(\mu(s)\), and take the expectation, leading to the problem:
However, in practice we do not have access to \(q^*\). Instead, we approximate \(q^*\) with a Q-function \(q(s, a; \boldsymbol{\theta})\), parameterized by \(\boldsymbol{\theta}\), which we learn simultaneously with the policy function \(d(s; \boldsymbol{w})\). Given a sample of initial states drawn from \(\mu\), we then maximize this objective via a Monte Carlo surrogate problem:
When using neural networks to parametrize \(q\) and \(d\), we obtain the Neural Fitted Q-Iteration with Continuous Actions (NFQCA) algorithm proposed by [19].
(Neural Fitted Q-Iteration with Continuous Actions (NFQCA))
Input MDP \((S, A, P, R, \gamma)\), base points \(\mathcal{B}\), Q-function \(q(s,a; \boldsymbol{\theta})\), policy \(d(s; \boldsymbol{w})\)
Output Parameters \(\boldsymbol{\theta}\) for Q-function, \(\boldsymbol{w}\) for policy
Initialize \(\boldsymbol{\theta}_0\), \(\boldsymbol{w}_0\)
for \(n = 0,1,2,...\) do
\(\mathcal{D}_q \leftarrow \emptyset\)
For each \((s,a,r,s') \in \mathcal{B}\):
\(a'_{s'} \leftarrow d(s'; \boldsymbol{w}_n)\)
\(y_{s,a} \leftarrow r + \gamma q(s', a'_{s'}; \boldsymbol{\theta}_n)\)
\(\mathcal{D}_q \leftarrow \mathcal{D}_q \cup \{((s,a), y_{s,a})\}\)
\(\boldsymbol{\theta}_{n+1} \leftarrow \texttt{fit}(\mathcal{D}_q)\)
\(\boldsymbol{w}_{n+1} \leftarrow \texttt{minimize}_{\boldsymbol{w}} -\frac{1}{|\mathcal{B}|} \sum_{(s,a,r,s') \in \mathcal{B}} q(s, d(s; \boldsymbol{w}); \boldsymbol{\theta}_{n+1})\)
return \(\boldsymbol{\theta}_n\), \(\boldsymbol{w}_n\)
In practice, both the \(\texttt{fit}\) and \(\texttt{minimize}\) operations above are implemented using gradient descent. For the Q-function, the \(\texttt{fit}\) operation minimizes the mean squared error between the network's predictions and the target values:
For the policy update, the \(\texttt{minimize}\) operation uses gradient descent on the composition of the “critic” network \(q\) and the “actor” network \(d\). This results in the following update rule:
where \(\alpha\) is the learning rate. Both operations can be efficiently implemented using modern automatic differentiation libraries and stochastic gradient descent variants like Adam or RMSProp.
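To make the two operations concrete, the following is a minimal JAX sketch of a single NFQCA step. The tiny linear networks `q_net` and `d_net`, the parameter dictionaries, and the step sizes are illustrative stand-ins, not a prescribed implementation.

```python
import jax
import jax.numpy as jnp

# Tiny linear "networks" purely for illustration; any differentiable q and d would do.
def q_net(theta, s, a):
    x = jnp.concatenate([s, a], axis=-1)
    return x @ theta["W"] + theta["b"]           # one scalar Q-value per sample (W a vector, b a scalar)

def d_net(w, s):
    return jnp.tanh(s @ w["W"] + w["b"])         # bounded continuous action

def q_loss(theta, w, batch, gamma=0.99):
    s, a, r, s_next = batch
    a_next = d_net(w, s_next)
    # stop_gradient keeps the target fixed with respect to theta, as in the outer loop
    y = r + gamma * q_net(jax.lax.stop_gradient(theta), s_next, a_next)
    return jnp.mean((q_net(theta, s, a) - y) ** 2)

def policy_loss(w, theta, s):
    # Maximizing q(s, d(s; w)) is done by minimizing its negation
    return -jnp.mean(q_net(theta, s, d_net(w, s)))

def nfqca_step(theta, w, batch, lr_q=1e-3, lr_d=1e-4):
    # One gradient step for each of the fit (critic) and minimize (actor) operations
    theta = jax.tree_util.tree_map(lambda p, g: p - lr_q * g, theta,
                                   jax.grad(q_loss)(theta, w, batch))
    w = jax.tree_util.tree_map(lambda p, g: p - lr_d * g, w,
                               jax.grad(policy_loss)(w, theta, batch[0]))
    return theta, w
```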
21.2. Deep Deterministic Policy Gradient (DDPG)#
Just as DQN adapted Neural Fitted Q-Iteration to the online setting, DDPG [28] extends NFQCA to learn from data collected online. Like NFQCA, DDPG simultaneously learns a Q-function and a deterministic policy that maximizes it, but differs in how it collects and processes data.
Instead of maintaining a fixed set of basepoints, DDPG uses a replay buffer that continuously stores new transitions as the agent interacts with the environment. Since the policy is deterministic, exploration becomes challenging. DDPG addresses this by adding noise to the policy’s actions during data collection:
where \(\mathcal{N}\) represents exploration noise drawn from an Ornstein-Uhlenbeck (OU) process. The OU process is particularly well-suited for control tasks as it generates temporally correlated noise, leading to smoother exploration trajectories compared to independent random noise. It is defined by the stochastic differential equation:
where \(\mu\) is the long-term mean value (typically set to 0), \(\theta\) determines how strongly the noise is pulled toward this mean, \(\sigma\) scales the random fluctuations, and \(dW_t\) is a Wiener process (continuous-time random walk). For implementation, we discretize this continuous-time process using the Euler-Maruyama method:
where \(\Delta t\) is the time step and \(\epsilon_t \sim \mathcal{N}(0,1)\) is standard Gaussian noise. Think of this process like a spring mechanism: when the noise value \(\mathcal{N}_t\) deviates from \(\mu\), the term \(\theta(\mu - \mathcal{N}_t)\Delta t\) acts like a spring force, continuously pulling it back. Unlike a spring, however, this return to \(\mu\) is not oscillatory - it’s more like motion through a viscous fluid, where the force simply decreases as the noise gets closer to \(\mu\). The random term \(\sigma\sqrt{\Delta t}\epsilon_t\) then adds perturbations to this smooth return trajectory. This creates noise that wanders away from \(\mu\) (enabling exploration) but is always gently pulled back (preventing the actions from wandering too far), with \(\theta\) controlling the strength of this pulling force.
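For concreteness, here is a small sketch of the discretized process; the parameter values are common illustrative defaults rather than prescribed choices.

```python
import numpy as np

def ou_noise(n_steps, theta=0.15, mu=0.0, sigma=0.2, dt=1e-2, x0=0.0, seed=0):
    """Simulate N_{t+1} = N_t + theta*(mu - N_t)*dt + sigma*sqrt(dt)*eps_t (Euler-Maruyama)."""
    rng = np.random.default_rng(seed)
    x = np.empty(n_steps)
    x[0] = x0
    for t in range(1, n_steps):
        x[t] = x[t - 1] + theta * (mu - x[t - 1]) * dt + sigma * np.sqrt(dt) * rng.standard_normal()
    return x

noise = ou_noise(n_steps=1000)  # temporally correlated exploration noise
```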
The policy gradient update follows the same principle as NFQCA:
We then embed this exploration mechanism into the data collection procedure and use the same flattened FQI structure that we adopted in DQN. Similar to DQN, flattening the outer-inner optimization structure leads to the need for target networks - both for the Q-function and the policy.
(Deep Deterministic Policy Gradient (DDPG))
Input MDP \((S, A, P, R, \gamma)\), Q-network \(q(s,a; \boldsymbol{\theta})\), policy network \(d(s; \boldsymbol{w})\), learning rates \(\alpha_q, \alpha_d\), replay buffer size \(B\), mini-batch size \(b\), target update frequency \(K\)
Initialize
Parameters \(\boldsymbol{\theta}_0\), \(\boldsymbol{w}_0\) randomly
Target parameters: \(\boldsymbol{\theta}_{target} \leftarrow \boldsymbol{\theta}_0\), \(\boldsymbol{w}_{target} \leftarrow \boldsymbol{w}_0\)
Initialize replay buffer \(\mathcal{R}\) with capacity \(B\)
Initialize exploration noise process \(\mathcal{N}\)
\(n \leftarrow 0\)
while training:
Observe current state \(s\)
Select action with noise: \(a = d(s; \boldsymbol{w}_n) + \mathcal{N}\)
Execute \(a\), observe reward \(r\) and next state \(s'\)
Store \((s,a,r,s')\) in \(\mathcal{R}\), replacing oldest if full
Sample mini-batch of \(b\) transitions \((s_i,a_i,r_i,s'_i)\) from \(\mathcal{R}\)
For each sampled transition:
\(y_i \leftarrow r_i + \gamma q(s'_i, d(s'_i; \boldsymbol{w}_{target}); \boldsymbol{\theta}_{target})\)
Update Q-network: \(\boldsymbol{\theta}_{n+1} \leftarrow \boldsymbol{\theta}_n - \alpha_q \nabla_{\boldsymbol{\theta}} \frac{1}{b}\sum_i(y_i - q(s_i,a_i;\boldsymbol{\theta}_n))^2\)
Update policy: \(\boldsymbol{w}_{n+1} \leftarrow \boldsymbol{w}_n + \alpha_d \frac{1}{b}\sum_i \nabla_a q(s_i,a;\boldsymbol{\theta}_{n+1})|_{a=d(s_i;\boldsymbol{w}_n)} \nabla_{\boldsymbol{w}} d(s_i;\boldsymbol{w}_n)\)
If \(n \bmod K = 0\):
\(\boldsymbol{\theta}_{target} \leftarrow \boldsymbol{\theta}_n\)
\(\boldsymbol{w}_{target} \leftarrow \boldsymbol{w}_n\)
\(n \leftarrow n + 1\)
return \(\boldsymbol{\theta}_n\), \(\boldsymbol{w}_n\)
21.3. Twin Delayed Deep Deterministic Policy Gradient (TD3)#
While DDPG provided a foundation for continuous control with deep RL, it suffers from similar overestimation issues as DQN. TD3 [9] addresses these challenges through three key modifications: clipped double Q-learning (taking the minimum of two critics) to reduce overestimation bias, delayed policy updates to reduce per-update error, and target policy smoothing to prevent the policy from exploiting Q-function errors.
(Twin Delayed Deep Deterministic Policy Gradient (TD3))
Input MDP \((S, A, P, R, \gamma)\), twin Q-networks \(q^A(s,a; \boldsymbol{\theta}^A)\), \(q^B(s,a; \boldsymbol{\theta}^B)\), policy network \(d(s; \boldsymbol{w})\), learning rates \(\alpha_q, \alpha_d\), replay buffer size \(B\), mini-batch size \(b\), policy delay \(k\), target smoothing coefficient \(\tau\), target noise scale \(\sigma\), noise clip \(c\), exploration noise std \(\sigma_{explore}\)
Initialize
Parameters \(\boldsymbol{\theta}^A_0\), \(\boldsymbol{\theta}^B_0\), \(\boldsymbol{w}_0\) randomly
Target parameters: \(\boldsymbol{\theta}^A_{target} \leftarrow \boldsymbol{\theta}^A_0\), \(\boldsymbol{\theta}^B_{target} \leftarrow \boldsymbol{\theta}^B_0\), \(\boldsymbol{w}_{target} \leftarrow \boldsymbol{w}_0\)
Initialize replay buffer \(\mathcal{R}\) with capacity \(B\)
\(n \leftarrow 0\)
while training:
Observe current state \(s\)
Select action with Gaussian noise: \(a = d(s; \boldsymbol{w}_n) + \epsilon\), where \(\epsilon \sim \mathcal{N}(0, \sigma_{explore})\)
Execute \(a\), observe reward \(r\) and next state \(s'\)
Store \((s,a,r,s')\) in \(\mathcal{R}\), replacing oldest if full
Sample mini-batch of \(b\) transitions \((s_i,a_i,r_i,s'_i)\) from \(\mathcal{R}\)
For each sampled transition:
\(\tilde{a}_i \leftarrow d(s'_i; \boldsymbol{w}_{target}) + \text{clip}(\mathcal{N}(0, \sigma), -c, c)\) // Add clipped noise
\(q_{target} \leftarrow \min(q^A(s'_i, \tilde{a}_i; \boldsymbol{\theta}^A_{target}), q^B(s'_i, \tilde{a}_i; \boldsymbol{\theta}^B_{target}))\)
\(y_i \leftarrow r_i + \gamma q_{target}\)
Update Q-networks:
\(\boldsymbol{\theta}^A_{n+1} \leftarrow \boldsymbol{\theta}^A_n - \alpha_q \nabla_{\boldsymbol{\theta}} \frac{1}{b}\sum_i(y_i - q^A(s_i,a_i;\boldsymbol{\theta}^A_n))^2\)
\(\boldsymbol{\theta}^B_{n+1} \leftarrow \boldsymbol{\theta}^B_n - \alpha_q \nabla_{\boldsymbol{\theta}} \frac{1}{b}\sum_i(y_i - q^B(s_i,a_i;\boldsymbol{\theta}^B_n))^2\)
If \(n \bmod k = 0\): // Delayed policy update
Update policy: \(\boldsymbol{w}_{n+1} \leftarrow \boldsymbol{w}_n + \alpha_d \frac{1}{b}\sum_i \nabla_a q^A(s_i,a;\boldsymbol{\theta}^A_{n+1})|_{a=d(s_i;\boldsymbol{w}_n)} \nabla_{\boldsymbol{w}} d(s_i;\boldsymbol{w}_n)\)
Soft update of target networks:
\(\boldsymbol{\theta}^A_{target} \leftarrow \tau\boldsymbol{\theta}^A_{n+1} + (1-\tau)\boldsymbol{\theta}^A_{target}\)
\(\boldsymbol{\theta}^B_{target} \leftarrow \tau\boldsymbol{\theta}^B_{n+1} + (1-\tau)\boldsymbol{\theta}^B_{target}\)
\(\boldsymbol{w}_{target} \leftarrow \tau\boldsymbol{w}_{n+1} + (1-\tau)\boldsymbol{w}_{target}\)
\(n \leftarrow n + 1\)
return \(\boldsymbol{\theta}^A_n\), \(\boldsymbol{\theta}^B_n\), \(\boldsymbol{w}_n\)
Similar to Double Q-learning, TD3 decouples action selection from evaluation when forming the targets. However, rather than reusing the existing online and target networks for this decoupling, TD3 learns two Q-functions simultaneously and uses their minimum when computing target values, which further combats the overestimation bias.
Furthermore, when computing target Q-values, TD3 adds small random noise to the target policy’s actions and clips it to keep the perturbations bounded. This regularization technique essentially implements a form of “policy smoothing” that prevents the policy from exploiting areas where the Q-function may have erroneously high values:
$$\tilde{a} = d(s'; \boldsymbol{w}_{target}) + \text{clip}(\mathcal{N}(0, \sigma), -c, c)$$
While DDPG used the OU process which generates temporally correlated noise, TD3’s authors found that simple uncorrelated Gaussian noise works just as well for exploration. It is also easier to implement and tune since you only need to set a single parameter (\(\sigma_{explore}\)) for exploration rather than the multiple parameters required by the OU process (\(\theta\), \(\mu\), \(\sigma\)).
Finally, TD3 updates the policy network (and target networks) less frequently than the Q-networks, typically once every \(k\) Q-function updates. This helps reduce the per-update error and gives the Q-functions time to become more accurate before they are used to update the policy.
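As a compact sketch, the target computation that combines clipped double Q-learning with target policy smoothing might look as follows in JAX; the critic and policy callables (`q_a`, `q_b`, `d_target`) and the batch layout are assumptions of this sketch, not TD3's reference implementation.

```python
import jax
import jax.numpy as jnp

def td3_targets(key, batch, q_a, q_b, d_target, theta_a_t, theta_b_t, w_t,
                gamma=0.99, sigma=0.2, c=0.5):
    """Compute y_i = r_i + gamma * min(q_A, q_B)(s'_i, d_target(s'_i) + clipped noise)."""
    s, a, r, s_next = batch
    # Target policy smoothing: clipped Gaussian noise added to the target action
    noise = jnp.clip(sigma * jax.random.normal(key, a.shape), -c, c)
    a_tilde = d_target(w_t, s_next) + noise
    # Clipped double Q-learning: evaluate both target critics and take the minimum
    q_min = jnp.minimum(q_a(theta_a_t, s_next, a_tilde),
                        q_b(theta_b_t, s_next, a_tilde))
    return r + gamma * q_min
```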
22. Stochastic Policy Parameterization#
22.1. Soft Actor Critic#
Adapting the intuition of NFQCA to the smooth Bellman optimality equations leads us to the soft actor-critic algorithm [18]. To understand this connection, let’s first examine how the smooth Bellman equations emerge naturally from entropy regularization.
Consider the standard Bellman operator augmented with an entropy term. The smooth Bellman operator \(\mathrm{L}_\beta\) takes the form:
where \(\mathcal{H}(d) = -\mathbb{E}_{a \sim d}[\log d(a|s)]\) represents the entropy of the policy. To find the solution to the optimization problem embedded in the operator \(\mathrm{L}_\beta\), we set the functional derivative of the objective with respect to the decision rule to zero:
Enforcing that \(\int_A d(a|s)da = 1\) leads to the following Lagrangian:
Solving for \(d\) shows that the optimal policy is a Boltzmann distribution
When we substitute this optimal policy back into the entropy-regularized objective, we obtain:
As we saw at the beginning of this chapter, the smooth Bellman optimality operator for Q-factors is defined as:
This operator maintains the contraction property of its standard counterpart, guaranteeing a unique fixed point \(q^*\). The optimal policy takes the form:
where \(Z(s) = \int_A \exp(\frac{1}{\beta}q^*(s,a))da\). The optimal value function can be recovered as:
22.1.1. Fitted Q-Iteration for the Smooth Bellman Equations#
Following the principles of fitted value iteration, we can approximate the effect of the smooth Bellman operator by computing it exactly at a number of basepoints and generalizing elsewhere using function approximation. Concretely, given a collection of states \(s_i\) and actions \(a_i\), we would compute regression target values:
and fit our Q-function approximator by minimizing:
The expectation over next states can be handled through Monte Carlo estimation using samples from the environment: given a transition \((s_i,a_i,s'_i)\), we can approximate:
However, we still face the challenge of computing the integral over actions. This motivates maintaining separate function approximators for both Q and V, using samples from the current policy to estimate the value function:
By maintaining both approximators, we can estimate targets using sampled actions from our policy. Specifically, if we have a transition \((s_i,a_i,s'_i)\) and sample \(a'_i \sim d(\cdot|s'_i;\phi)\), our target becomes:
This is a remarkable idea, one made possible by the dual representation of the smooth Bellman equations as an entropy-regularized problem, which transforms the intractable log-sum-exp into a form we can estimate efficiently through sampling.
22.1.2. Approximating Boltzmann Policies by Gaussians#
The entropy-regularized objective and the smooth Bellman equation are mathematically equivalent. However, both formulations face a practical challenge: they require evaluating an intractable integral due to the Boltzmann distribution. Soft Actor-Critic (SAC) addresses this problem by approximating the optimal policy with a simpler, more tractable Gaussian distribution. Given the optimal soft policy:
we seek to approximate it with a Gaussian policy:
This approximation task naturally raises the question of how to measure the “closeness” between the target Boltzmann distribution and a candidate Gaussian approximation. Following common practice in deep learning, we employ the Kullback-Leibler (KL) divergence as our measure of distributional distance. To find the best approximation, we minimize the KL divergence between our policy and the optimal policy, using our current estimate \(q_\theta\) of \(q^*\):
However, an important question remains: how can we solve this optimization problem when it involves the intractable partition function \(Z(s)\)? To see this, recall that for two distributions p and q, the KL divergence takes the form \(D_{KL}(p\|q) = \mathbb{E}_{x \sim p}[\log p(x) - \log q(x)]\). Let’s denote the target Boltzmann distribution based on our current Q-estimate as:
Then the KL minimization becomes:
Since \(\log Z(s)\) is constant with respect to \(\phi\), minimizing this KL divergence is equivalent to:
22.1.3. Reparameterizing the Objective#
One last challenge remains: \(\phi\) appears in the distribution underlying the inner expectation, not just in the integrand. This setting departs from standard empirical risk minimization (ERM) in supervised learning where the distribution of the data (e.g., cats and dogs in image classification) remains fixed regardless of model parameters. Here, however, the “data” - our sampled actions - depends directly on the parameters \(\phi\) we’re trying to optimize.
This dependence prevents us from simply using sample average estimators and differentiating through them, as we typically do in supervised learning. The challenge of correctly and efficiently estimating such derivatives has been extensively studied in the simulation literature under the umbrella of “derivative estimation.” SAC adopts a particular solution known as the reparameterization trick in deep learning (or the IPA estimator in simulation literature). This approach transforms the problem by pushing \(\phi\) inside the expectation through a change of variables.
To address this, we can express our Gaussian policy through a deterministic function \(f_\phi\) that transforms noise samples to actions:
This transformation allows us to rewrite our objective using an expectation over the fixed noise distribution:
Now \(\phi\) appears only in the integrand through the function \(f_\phi\), not in the sampling distribution. The objective involves two terms. First, the log-probability of our Gaussian policy has a simple closed form:
Second, \(\phi\) enters through the composition of \(q^\star\) with \(f_\phi\): \(q^\star(s,f_\phi(s,\epsilon))\). The chain rule for this composition involves derivatives of both functions. This could be problematic if the Q-function came from outside of our control (i.e., were not part of the computational graph), but since SAC learns it simultaneously with the policy, we can simply compute all required derivatives through automatic differentiation.
This composition of policy and value functions - where \(f_\phi\) enters as input to \(q_\theta\) - directly parallels the structure we encountered in deterministic policy methods like NFQCA and DDPG. In those methods, we optimized:
where \(f_\phi(s)\) was a deterministic policy. SAC extends this idea to stochastic policies by having \(f_\phi\) transform both state and noise:
Thus, rather than learning a single action for each state as in DDPG, we learn a function that transforms random noise into actions, explicitly parameterizing a distribution over actions while maintaining the same underlying principle of differentiating through composed policy and value functions.
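A compact sketch of this reparameterized policy objective in JAX could look as follows, assuming a `policy_net` that returns the mean and log standard deviation of the Gaussian and a critic `q_net`; these names and the fixed temperature `alpha` are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

def sample_action_and_logp(phi, s, eps, policy_net):
    # Reparameterized Gaussian policy: a = mu(s) + sigma(s) * eps with eps ~ N(0, I)
    mu, log_sigma = policy_net(phi, s)
    sigma = jnp.exp(log_sigma)
    a = mu + sigma * eps
    # Log-density of the diagonal Gaussian evaluated at the sampled action
    logp = jnp.sum(-0.5 * ((a - mu) / sigma) ** 2 - log_sigma
                   - 0.5 * jnp.log(2.0 * jnp.pi), axis=-1)
    return a, logp

def sac_policy_loss(phi, theta, s, eps, policy_net, q_net, alpha=0.2):
    # Minimizing E[alpha * log pi - q] is the KL objective, up to the constant log Z(s)
    a, logp = sample_action_and_logp(phi, s, eps, policy_net)
    return jnp.mean(alpha * logp - q_net(theta, s, a))

# The gradient with respect to phi flows through both the log-density and the critic:
# grad_phi = jax.grad(sac_policy_loss)(phi, theta, s, eps, policy_net, q_net)
```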
(Soft Actor-Critic)
Input MDP \((S, A, P, R, \gamma)\), Q-networks \(q^1(s,a; \boldsymbol{\theta}^1)\), \(q^2(s,a; \boldsymbol{\theta}^2)\), value network \(v(s; \boldsymbol{\psi})\), policy network \(d(a|s; \boldsymbol{\phi})\), temperature \(\alpha\), learning rates \(\alpha_q, \alpha_v, \alpha_\pi\), replay buffer size \(B\), mini-batch size \(b\), target smoothing coefficient \(\tau\)
Initialize
Parameters \(\boldsymbol{\theta}^1_0\), \(\boldsymbol{\theta}^2_0\), \(\boldsymbol{\psi}_0\), \(\boldsymbol{\phi}_0\) randomly
Target parameters: \(\boldsymbol{\bar{\psi}}_0 \leftarrow \boldsymbol{\psi}_0\)
Initialize replay buffer \(\mathcal{R}\) with capacity \(B\)
while training:
Observe current state \(s\)
Sample action from policy: \(a \sim d(a|s; \boldsymbol{\phi})\)
Execute \(a\), observe reward \(r\) and next state \(s'\)
Store \((s, a, r, s')\) in \(\mathcal{R}\), replacing oldest if full
Sample mini-batch of \(b\) transitions \((s_i, a_i, r_i, s'_i)\) from \(\mathcal{R}\)
Update Value Network:
Compute target for value network:
\[ y_v = \mathbb{E}_{a' \sim d(\cdot|s'; \boldsymbol{\phi})} \left[ \min \left( q^1(s', a'; \boldsymbol{\theta}^1), q^2(s', a'; \boldsymbol{\theta}^2) \right) - \alpha \log d(a'|s'; \boldsymbol{\phi}) \right] \]
Update \(\boldsymbol{\psi}\) via gradient descent:
\[ \boldsymbol{\psi} \leftarrow \boldsymbol{\psi} - \alpha_v \nabla_{\boldsymbol{\psi}} \frac{1}{b} \sum_i (v(s_i; \boldsymbol{\psi}) - y_v)^2 \]
Update Q-Networks:
Compute targets for Q-networks:
\[ y_q = r_i + \gamma \cdot v(s'_i; \boldsymbol{\bar{\psi}}) \]
Update \(\boldsymbol{\theta}^1\) and \(\boldsymbol{\theta}^2\) via gradient descent:
\[ \boldsymbol{\theta}^j \leftarrow \boldsymbol{\theta}^j - \alpha_q \nabla_{\boldsymbol{\theta}^j} \frac{1}{b} \sum_i (q^j(s_i, a_i; \boldsymbol{\theta}^j) - y_q)^2, \quad j \in \{1, 2\} \]
Update Policy Network:
Sample actions \(a \sim d(\cdot|s_i; \boldsymbol{\phi})\) for each \(s_i\) in the mini-batch
Update \(\boldsymbol{\phi}\) via gradient descent:
\[ \boldsymbol{\phi} \leftarrow \boldsymbol{\phi} - \alpha_\pi \nabla_{\boldsymbol{\phi}} \frac{1}{b} \sum_i \left[ \alpha \log d(a|s_i; \boldsymbol{\phi}) - q^1(s_i, a; \boldsymbol{\theta}^1) \right] \]
Update Target Value Network: \(\boldsymbol{\bar{\psi}} \leftarrow \tau \boldsymbol{\psi} + (1 - \tau) \boldsymbol{\bar{\psi}}\)
return Learned parameters \(\boldsymbol{\theta}^1\), \(\boldsymbol{\theta}^2\), \(\boldsymbol{\psi}\), \(\boldsymbol{\phi}\)
23. Derivative Estimation for Stochastic Optimization#
Consider optimizing an objective that involves an expectation:
For concreteness, let’s examine a simple example where \(x \sim \mathcal{N}(\theta,1)\) and \(f(x,\theta) = x^2\theta\). The derivative we seek is:
While we can compute this exactly for the Gaussian example, this is often impossible for more general problems. We might then be tempted to approximate our objective using samples:
Then differentiate this approximation:
However, this naive approach ignores that the samples themselves depend on \(\theta\). The correct derivative requires the product rule:
The issue here is that while the first term can be estimated by Monte Carlo integration, the second cannot, as it is not an expectation.
Is there a way to transform our objective so that its Monte Carlo estimator could be differentiated directly while ensuring that the resulting derivative is unbiased? We will see that there are two main solutions to this problem: a change of measure, or a change of variables.
23.1. Change of Measure: The Likelihood Ratio Method#
One solution comes from rewriting our objective using any distribution \(q(x)\):
Let’s write this more functionally by defining:
Now when we differentiate \(J\), it’s clear that we must take the partial derivative of \(h\) with respect to its second argument:
The so-called “score function” derivative estimator is obtained for the choice of \(q(x) = p(x;\theta)\), where the ratio simplifies to \(1\) and its derivative becomes the score function:
23.2. A Change of Variables Approach: The Reparameterization Trick#
An alternative approach eliminates the \(\theta\)-dependence in the sampling distribution by expressing \(x\) through a deterministic transformation of the noise:
Therefore if we want to sample from some target distribution \(p(x;\theta)\), we can do so by first sampling from a simple base distribution \(q(\epsilon)\) (like a standard normal) and then transforming those samples through a carefully chosen function \(g\). If \(g(\cdot,\theta)\) is invertible, the change of variables formula tells us how these distributions relate:
For example, if we want to sample from a multivariate Gaussian distribution with mean \(\mu\) and covariance matrix \(\Sigma\), it suffices to sample standard normal noise and compute the linear transformation:
where \(\Sigma^{1/2}\) is a matrix square-root factor of \(\Sigma\), for example the Cholesky factor satisfying \(\Sigma^{1/2}(\Sigma^{1/2})^\top = \Sigma\). In the univariate case, this transformation is simply:
where \(\sigma = \sqrt{\sigma^2}\) is the standard deviation (square root of the variance).
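A brief sketch of this transformation, with the Cholesky factor \(L\) (where \(LL^\top = \Sigma\)) playing the role of \(\Sigma^{1/2}\):

```python
import jax
import jax.numpy as jnp

def sample_multivariate_gaussian(key, mu, Sigma, n_samples):
    # x = mu + L @ eps with L the Cholesky factor of Sigma and eps ~ N(0, I)
    L = jnp.linalg.cholesky(Sigma)
    eps = jax.random.normal(key, (n_samples, mu.shape[0]))
    return mu + eps @ L.T

key = jax.random.PRNGKey(0)
mu = jnp.array([1.0, -2.0])
Sigma = jnp.array([[2.0, 0.5],
                   [0.5, 1.0]])
x = sample_multivariate_gaussian(key, mu, Sigma, n_samples=1000)
```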
23.2.1. Common Examples of Reparameterization#
23.2.1.1. Bounded Intervals: The Truncated Normal#
When we need samples constrained to an interval \([a,b]\), we can use the truncated normal distribution. To sample from it, we transform uniform noise through the inverse cumulative distribution function (CDF) of the standard normal:
Here:
\(\Phi(z) = \frac{1}{2}\left[1 + \text{erf}\left(\frac{z}{\sqrt{2}}\right)\right]\) is the CDF of the standard normal distribution
\(\Phi^{-1}\) is its inverse (the quantile function)
\(\text{erf}(z) = \frac{2}{\sqrt{\pi}}\int_0^z e^{-t^2}dt\) is the error function
The resulting samples follow a normal distribution restricted to \([a,b]\), with the density properly normalized over this interval.
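A small sketch of this inverse-CDF construction, written with the error function exactly as defined above (the helper names are ours):

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import erf, erfinv

def Phi(z):
    # Standard normal CDF via the error function
    return 0.5 * (1.0 + erf(z / jnp.sqrt(2.0)))

def Phi_inv(p):
    # Inverse standard normal CDF (quantile function)
    return jnp.sqrt(2.0) * erfinv(2.0 * p - 1.0)

def sample_truncated_normal(key, mu, sigma, a, b, shape):
    # Map uniform noise into the CDF mass [Phi(alpha), Phi(beta)],
    # then pull it back through the inverse CDF and rescale.
    alpha, beta = (a - mu) / sigma, (b - mu) / sigma
    u = jax.random.uniform(key, shape)
    p = Phi(alpha) + u * (Phi(beta) - Phi(alpha))
    return mu + sigma * Phi_inv(p)

key = jax.random.PRNGKey(0)
samples = sample_truncated_normal(key, mu=0.0, sigma=1.0, a=-0.5, b=2.0, shape=(1000,))
```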
23.2.1.2. Sampling from [0,1]: The Kumaraswamy Distribution#
When we need samples in the unit interval [0,1], a natural choice might be the Beta distribution. However, its inverse CDF doesn’t have a closed form. Instead, we can use the Kumaraswamy distribution as a convenient approximation, which allows for a simple reparameterization:
where:
\(\alpha, \beta > 0\) are shape parameters that control the distribution
\(\alpha\) determines the concentration around 0
\(\beta\) determines the concentration around 1
The distribution is similar to Beta(α,β) but with analytically tractable CDF and inverse CDF
The Kumaraswamy distribution has density \(f(x; \alpha, \beta) = \alpha \beta x^{\alpha - 1} (1 - x^{\alpha})^{\beta - 1}\) for \(x \in [0, 1]\).
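A short sketch of this reparameterization, using the closed-form inverse CDF \(F^{-1}(u) = (1 - (1 - u)^{1/\beta})^{1/\alpha}\) (function names are ours):

```python
import jax
import jax.numpy as jnp

def sample_kumaraswamy(key, alpha, beta, shape):
    # Inverse-CDF sampling: F(x) = 1 - (1 - x^alpha)^beta on [0, 1]
    u = jax.random.uniform(key, shape)
    return (1.0 - (1.0 - u) ** (1.0 / beta)) ** (1.0 / alpha)

def kumaraswamy_logpdf(x, alpha, beta):
    # log f(x) = log(alpha * beta) + (alpha - 1) log x + (beta - 1) log(1 - x^alpha)
    return (jnp.log(alpha) + jnp.log(beta)
            + (alpha - 1.0) * jnp.log(x)
            + (beta - 1.0) * jnp.log1p(-x ** alpha))

key = jax.random.PRNGKey(0)
x = sample_kumaraswamy(key, alpha=2.0, beta=5.0, shape=(1000,))
```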
23.2.1.3. Discrete Actions: The Gumbel-Softmax#
When sampling from a categorical distribution with probabilities \(\{\pi_i\}\), one approach uses \(\text{Gumbel}(0,1)\) noise combined with the argmax of log-perturbed probabilities:
This approach, known in machine learning as the Gumbel-Max trick, relies on sampling Gumbel noise from uniform random variables through the transformation \(g_i = -\log(-\log(u_i))\) where \(u_i \sim \text{Uniform}(0,1)\). To see why this gives us samples from the categorical distribution, consider the probability of selecting category \(i\):
Since the difference of two Gumbel random variables follows a logistic distribution, \(g_i - g_j \sim \text{Logistic}(0,1)\), and these differences are independent for different \(j\) (due to the independence of the original Gumbel variables), we can write:
The last equality requires some additional algebra to show, but follows from the fact that these probabilities must sum to 1 over all \(i\).
While we have shown that the Gumbel-Max trick gives us exact samples from a categorical distribution, the argmax operation isn’t differentiable. For stochastic optimization problems of the form:
we need \(g\) to be differentiable with respect to \(\theta\). This leads us to consider a continuous relaxation where we replace the hard argmax with a temperature-controlled softmax:
As \(\tau \to 0\), this approximation approaches the argmax:
The resulting distribution over the probability simplex is called the Gumbel-Softmax (or Concrete) distribution. The temperature parameter \(\tau\) controls the discreteness of our samples: smaller values give samples closer to one-hot vectors but with less stable gradients, while larger values give smoother gradients but more diffuse samples.
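As a sketch of both the exact Gumbel-Max sample and its relaxation (clipping the uniform samples away from 0 and 1 is a numerical safeguard we add, not part of the definition):

```python
import jax
import jax.numpy as jnp

def gumbel_softmax_sample(key, logits, tau=0.5):
    """Draw a relaxed one-hot sample from the Gumbel-Softmax (Concrete) distribution."""
    u = jax.random.uniform(key, logits.shape, minval=1e-6, maxval=1.0 - 1e-6)
    g = -jnp.log(-jnp.log(u))                     # Gumbel(0, 1) noise
    return jax.nn.softmax((logits + g) / tau)     # temperature-controlled softmax

key = jax.random.PRNGKey(0)
logits = jnp.log(jnp.array([0.2, 0.5, 0.3]))      # log pi_i
y = gumbel_softmax_sample(key, logits, tau=0.5)   # sums to 1, nearly one-hot for small tau
hard = jnp.argmax(y)                              # exact Gumbel-Max sample (non-differentiable)
```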
23.3. Demonstration: Numerical Analysis of Gradient Estimators#
Let us examine the behavior of our three gradient estimators for the stochastic optimization objective:
To get an analytical expression for the derivative, first note that we can factor out \(\theta\) to obtain \(J(\theta) = \theta\mathbb{E}[x^2]\) where \(x \sim \mathcal{N}(\theta,1)\). By definition of the variance, we know that \(\text{Var}(x) = \mathbb{E}[x^2] - (\mathbb{E}[x])^2\), which we can rearrange to \(\mathbb{E}[x^2] = \text{Var}(x) + (\mathbb{E}[x])^2\). Since \(x \sim \mathcal{N}(\theta,1)\), we have \(\text{Var}(x) = 1\) and \(\mathbb{E}[x] = \theta\), therefore \(\mathbb{E}[x^2] = 1 + \theta^2\). This gives us:
Now differentiating with respect to \(\theta\) using the product rule yields:
For concreteness, we fix \(\theta = 1.0\) and analyze samples drawn using Monte Carlo estimation with batch size 1000 and 1000 independent trials. Evaluating at \(\theta = 1\) gives us \(\frac{d}{d\theta}J(\theta)\big|_{\theta=1} = 1 + 3(1)^2 = 4\), which serves as our ground truth against which we compare our estimators:
First, we consider the naive estimator that incorrectly differentiates the Monte Carlo approximation:
\[\hat{g}_{\text{naive}}(\theta) = \frac{1}{N}\sum_{i=1}^N x_i^2\]
For \(x \sim \mathcal{N}(1,1)\), we have \(\mathbb{E}[x^2] = \theta^2 + 1 = 2.0\) and \(\mathbb{E}[\hat{g}_{\text{naive}}] = 2.0\). We should therefore expect a bias of about \(-2\) in our experiment.
Then we compute the score function estimator:
\[\hat{g}_{\text{SF}}(\theta) = \frac{1}{N}\sum_{i=1}^N \left[x_i^2\theta(x_i - \theta) + x_i^2\right]\]
This estimator is unbiased, with \(\mathbb{E}[\hat{g}_{\text{SF}}] = 4\).
Finally, through the reparameterization \(x = \theta + \epsilon\) where \(\epsilon \sim \mathcal{N}(0,1)\), we obtain:
\[\hat{g}_{\text{RT}}(\theta) = \frac{1}{N}\sum_{i=1}^N \left[2\theta(\theta + \epsilon_i) + (\theta + \epsilon_i)^2\right]\]
This estimator is also unbiased, with \(\mathbb{E}[\hat{g}_{\text{RT}}] = 4\).
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import seaborn as sns
key = jax.random.PRNGKey(0)
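# Naive estimator: differentiate the Monte Carlo average while (incorrectly) treating the samples as independent of theta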
@jax.jit
def naive_gradient_batch(key, theta):
samples = jax.random.normal(key, (1000,)) + theta
return jnp.mean(samples**2)
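# Score function (likelihood ratio) estimator: f(x, theta) * d/dtheta log p(x; theta) + df/dtheta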
@jax.jit
def score_function_batch(key, theta):
samples = jax.random.normal(key, (1000,)) + theta
return jnp.mean(samples**2 * theta * (samples - theta) + samples**2)
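# Reparameterization estimator: substitute x = theta + eps with eps ~ N(0, 1) and differentiate through x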
@jax.jit
def reparam_gradient_batch(key, theta):
eps = jax.random.normal(key, (1000,))
samples = theta + eps
return jnp.mean(samples**2 + 2*theta*samples)
# Run trials
n_trials = 1000
theta = 1.0
true_grad = 1 + 3 * theta**2  # analytical derivative dJ/dtheta = 1 + 3*theta^2
keys = jax.random.split(key, n_trials)
naive_estimates = jnp.array([naive_gradient_batch(k, theta) for k in keys])
score_estimates = jnp.array([score_function_batch(k, theta) for k in keys])
reparam_estimates = jnp.array([reparam_gradient_batch(k, theta) for k in keys])
# Create violin plots with individual points
plt.figure(figsize=(12, 6))
data = [naive_estimates, score_estimates, reparam_estimates]
colors = ['#ff9999', '#66b3ff', '#99ff99']
parts = plt.violinplot(data, showextrema=False)
for i, pc in enumerate(parts['bodies']):
pc.set_facecolor(colors[i])
pc.set_alpha(0.7)
# Add box plots
plt.boxplot(data, notch=True, showfliers=False)
# Add true gradient line
plt.axhline(y=true_grad, color='r', linestyle='--', label='True Gradient')
plt.xticks([1, 2, 3], ['Naive', 'Score Function', 'Reparam'])
plt.ylabel('Gradient Estimate')
plt.title(f'Gradient Estimators (θ={theta}, true grad={true_grad:.2f})')
plt.grid(True, alpha=0.3)
plt.legend()
# Print statistics
methods = {
'Naive': naive_estimates,
'Score Function': score_estimates,
'Reparameterization': reparam_estimates
}
for name, estimates in methods.items():
bias = jnp.mean(estimates) - true_grad
variance = jnp.var(estimates)
print(f"\n{name}:")
print(f"Mean: {jnp.mean(estimates):.6f}")
print(f"Bias: {bias:.6f}")
print(f"Variance: {variance:.6f}")
print(f"MSE: {bias**2 + variance:.6f}")
Naive:
Mean: 1.999266
Bias: -2.000734
Variance: 0.005756
MSE: 4.008693
Score Function:
Mean: 3.995299
Bias: -0.004701
Variance: 0.058130
MSE: 0.058152
Reparameterization:
Mean: 3.999579
Bias: -0.000421
Variance: 0.017229
MSE: 0.017230
The numerical experiments corroborate our theory. The naive estimator consistently underestimates the true gradient by 2.0, though it maintains a relatively small variance. This systematic bias would make it unsuitable for optimization despite its low variance. The score function estimator corrects this bias but introduces substantial variance. While unbiased, this estimator would require many samples to achieve reliable gradient estimates. Finally, the reparameterization trick achieves a much lower variance while remaining unbiased. While this experiment is for didactic purposes only, it reproduces what is commonly found in practice: when applicable, the reparameterization estimator tends to perform better than its score function counterpart.
24. Stochastic Value Gradients with Model-Based Rollouts#
Building on the idea of amortized optimization from deterministic policy gradient methods, together with the reparameterization trick, we can develop a more general framework for policy optimization that handles stochastic dynamics and policies. Let's start by considering a stochastic policy that can be reparameterized:
Similarly, we can express stochastic dynamics through a reparameterized model:
where \(\boldsymbol{\phi}\) are the parameters of our dynamics model. This allows us to write the expected sum of rewards over a trajectory as:
where states and actions are determined by:
We can now compute gradients through this entire trajectory by sampling the primitive random variables \(\epsilon_t\) and \(\omega_t\) once and applying the chain rule. For a finite batch of \(N\) trajectories, our Monte Carlo approximation becomes:
where each trajectory \(i\) is generated using:
The gradient of this objective with respect to the policy parameters \(\boldsymbol{w}\) can be computed by backpropagation through time:
where the partial derivatives can be computed recursively. The recursive state equations are:
The partial derivatives follow:
With base cases:
If we were to implement the gradient accumulation manually forward in time, we would get the following algorithm:
(Stochastic Value Gradients (SVG) Infinity (manual gradient evaluation))
Initialize
Policy parameters \(\boldsymbol{w}_0\) randomly
\(n \leftarrow 0\)
while not converged:
Sample batch of \(N\) initial states \(s_0^i \sim \rho_0(s)\)
For each initial state:
Sample noise sequences \(\epsilon_{0:T}^i, \omega_{0:T}^i \sim \mathcal{N}(0,I)\)
Generate trajectory using:
\(a_t^i(\boldsymbol{w}) = d(s_t^i(\boldsymbol{w});\boldsymbol{w}) + \sigma(s_t^i(\boldsymbol{w});\boldsymbol{w})\epsilon_t^i\)
\(s_{t+1}^i(\boldsymbol{w}) = f(s_t^i(\boldsymbol{w}),a_t^i(\boldsymbol{w}),\omega_t^i;\boldsymbol{\phi})\)
Compute gradient contributions recursively:
Initialize \(\frac{\partial s_0^i(\boldsymbol{w})}{\partial \boldsymbol{w}} = 0\)
For \(t=0\) to \(T\):
\(\frac{\partial a_t^i(\boldsymbol{w})}{\partial \boldsymbol{w}} = \frac{\partial d}{\partial s_t}\frac{\partial s_t^i(\boldsymbol{w})}{\partial \boldsymbol{w}} + \frac{\partial d}{\partial \boldsymbol{w}} + \epsilon_t^i(\frac{\partial \sigma}{\partial s_t}\frac{\partial s_t^i(\boldsymbol{w})}{\partial \boldsymbol{w}} + \frac{\partial \sigma}{\partial \boldsymbol{w}})\)
\(\frac{\partial s_{t+1}^i(\boldsymbol{w})}{\partial \boldsymbol{w}} = \frac{\partial f}{\partial s_t}\frac{\partial s_t^i(\boldsymbol{w})}{\partial \boldsymbol{w}} + \frac{\partial f}{\partial a_t}\frac{\partial a_t^i(\boldsymbol{w})}{\partial \boldsymbol{w}}\)
Update policy: \(\boldsymbol{w}_{n+1} \leftarrow \boldsymbol{w}_n + \alpha_w \frac{1}{N}\sum_{i=1}^N \sum_{t=0}^T \gamma^t (\frac{\partial r}{\partial s_t^i}\frac{\partial s_t^i(\boldsymbol{w})}{\partial \boldsymbol{w}} + \frac{\partial r}{\partial a_t^i}\frac{\partial a_t^i(\boldsymbol{w})}{\partial \boldsymbol{w}})\)
\(n \leftarrow n + 1\)
return \(\boldsymbol{w}_n\)
Most likely, however, this procedure will be implemented within an automatic differentiation framework, in which case it suffices to implement the following variant:
(Stochastic Value Gradients (SVG) Infinity (automatic differentiation))
Input Initial state distribution \(\rho_0(s)\), policy networks \(d(s;\boldsymbol{w})\) and \(\sigma(s;\boldsymbol{w})\), dynamics model \(f(s,a,\omega;\boldsymbol{\phi})\), reward function \(r(s,a)\), rollout horizon \(T\), learning rate \(\alpha_w\), batch size \(N\)
Initialize
Policy parameters \(\boldsymbol{w}_0\) randomly
\(n \leftarrow 0\)
while not converged:
Sample batch of \(N\) initial states \(s_0^i \sim \rho_0(s)\)
Sample noise sequences \(\epsilon_{0:T}^i, \omega_{0:T}^i \sim \mathcal{N}(0,I)\) for \(i=1,\ldots,N\)
Compute objective using autodiff-enabled computation graph:
For each \(i=1,\ldots,N\):
Initialize \(s_0^i(\boldsymbol{w}) = s_0^i\)
For \(t=0\) to \(T\):
\(a_t^i(\boldsymbol{w}) = d(s_t^i(\boldsymbol{w});\boldsymbol{w}) + \sigma(s_t^i(\boldsymbol{w});\boldsymbol{w})\epsilon_t^i\)
\(s_{t+1}^i(\boldsymbol{w}) = f(s_t^i(\boldsymbol{w}),a_t^i(\boldsymbol{w}),\omega_t^i;\boldsymbol{\phi})\)
\(J(\boldsymbol{w}) = \frac{1}{N}\sum_{i=1}^N \sum_{t=0}^T \gamma^t r(s_t^i(\boldsymbol{w}),a_t^i(\boldsymbol{w}))\)
Compute gradient using autodiff: \(\nabla_{\boldsymbol{w}}J\)
Update policy: \(\boldsymbol{w}_{n+1} \leftarrow \boldsymbol{w}_n + \alpha_w \nabla_{\boldsymbol{w}}J\)
\(n \leftarrow n + 1\)
return \(\boldsymbol{w}_n\)
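To illustrate the automatic differentiation variant, here is a minimal, self-contained JAX sketch in which simple linear-Gaussian policy and dynamics functions stand in for the networks \(d\), \(\sigma\) and the learned model \(f\); all names and constants are illustrative.

```python
import jax
import jax.numpy as jnp

# Illustrative stand-ins for the policy mean d, policy std sigma, dynamics f, and reward r
def d_mean(w, s):           return jnp.tanh(s @ w["Wd"])
def d_std(w, s):            return jnp.exp(s @ w["Ws"])
def dynamics(s, a, omega):  return 0.9 * s + 0.1 * a + 0.01 * omega
def reward(s, a):           return -jnp.sum(s**2 + 0.1 * a**2, axis=-1)

def svg_objective(w, s0, eps_seq, omega_seq, gamma=0.99):
    # Roll the trajectory forward with the noise held fixed; everything is
    # differentiable with respect to w, so jax.grad backpropagates through time.
    s, J = s0, 0.0
    for t in range(eps_seq.shape[0]):
        a = d_mean(w, s) + d_std(w, s) * eps_seq[t]
        J = J + (gamma ** t) * jnp.mean(reward(s, a))
        s = dynamics(s, a, omega_seq[t])
    return J

key = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(key, 3)
N, T, dim = 16, 20, 3                                   # batch size, horizon, state/action dim
w = {"Wd": 0.1 * jax.random.normal(k1, (dim, dim)),
     "Ws": jnp.zeros((dim, dim))}
s0 = jax.random.normal(k2, (N, dim))
eps = jax.random.normal(k3, (T, N, dim))
omega = jax.random.normal(jax.random.PRNGKey(1), (T, N, dim))

grad_w = jax.grad(svg_objective)(w, s0, eps, omega)     # gradient of the Monte Carlo objective
w = jax.tree_util.tree_map(lambda p, g: p + 1e-2 * g, w, grad_w)  # ascent step w <- w + alpha * grad
```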