Truncated Backpropagation for Bilevel Optimization
This post is a paper-reading note. Paper: *Truncated Back-propagation for Bilevel Optimization*, Amirreza Shaban, Ching-An Cheng, et al., AISTATS 2019.
Bilevel Optimization (BO)
- one application of BO in Machine Learning is Hyperparameter Optimization (HO)
\[ \min_{\lambda} f(\hat{w}, \lambda) \quad \text{s.t.} \quad \hat{w} \approx \operatorname{argmin}_w g(w, \lambda) \]
- denote the parameter and hyperparameter by $w$ and $\lambda$, and the outer and inner objectives by $f$ (the validation loss) and $g$ (the training loss)
- in this setup, the outer objective (the validation loss) depends on $\lambda$ only through $\hat{w}$, not directly, so the direct gradient $\frac{\partial f}{\partial \lambda} = 0$
- note: we follow the paper in treating the algorithm that solves the inner objective as part of the problem formulation; thus $\hat{w}$ denotes the solution returned by the solver rather than the exact minimizer $w^{*}$ (a toy instance of this setup is sketched below)
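- to make the notation concrete, here is a minimal toy instance of this setup (my own example, not from the paper): a per-coordinate $\ell_2$ penalty $\lambda$ tuned against a validation loss, with the inner problem solved by $T$ steps of gradient descent; all names below (`X_tr`, `inner_solve`, `eta`, ...) are assumptions of this sketch and are reused in the later snippets

```python
import numpy as np

rng = np.random.default_rng(0)
M = N = 5                                   # dim(w) = M, dim(lambda) = N
X_tr, y_tr = rng.normal(size=(20, M)), rng.normal(size=20)   # training data
X_va, y_va = rng.normal(size=(20, M)), rng.normal(size=20)   # validation data

def g_grad(w, lam):
    # grad_w of the inner objective g(w, lam) = 0.5*||X_tr w - y_tr||^2 + 0.5*sum_j exp(lam_j)*w_j^2
    return X_tr.T @ (X_tr @ w - y_tr) + np.exp(lam) * w

def f_grad_w(w):
    # grad_w of the outer objective f(w) = 0.5*||X_va w - y_va||^2 (no direct dependence on lam)
    return X_va.T @ (X_va @ w - y_va)

eta, T = 0.02, 200                          # inner step size and number of inner iterations

def inner_solve(lam):
    # the "solver included in the formulation": T gradient steps, returning all iterates w_0 .. w_T
    ws = [np.zeros(M)]                      # w_0 = Xi_0(lam) = 0 (independent of lam)
    for _ in range(T):
        ws.append(ws[-1] - eta * g_grad(ws[-1], lam))
    return ws

lam = np.log(5.0) * np.ones(N)              # current hyperparameter
ws = inner_solve(lam)                       # ws[-1] is w_hat, the approximate inner solution
```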
- challenges:
- the outer objective depends on $\lambda$ only through the solution of the inner problem, so evaluating the exact hypergradient requires differentiating through the inner optimization, which does not scale to high-dimensional problems
- prior works:
- other approaches for HO: Grid Search (black-box: runs the full training procedure many times), Random Search, Bayesian Optimization, hypernetworks (MacKay, Vicol, et al. 2019; Bae et al. 2020)
- Implicit Differentiation: relies on an estimate of the Jacobian $J_{\lambda} \hat{w} = \frac{\partial \hat{w}}{\partial \lambda}$
- note: this method relies on the assumption that $\hat{w} = w^{*}$, i.e. optimality of the approximate solution
- Dynamical System perspective: treat the iterative algorithm to solve the inner problem as a dynamical system and apply backpropagation through the system
Truncated Backpropagation for Bilevel Optimization
- idea: approximate gradients using Truncated Backpropagation (TBP) through “time” (as discussed in a previous post, TBP reduces time and space complexity by dropping long-term dependencies)
- note: here “time” refers to the optimization steps performed to solve the inner problem
- hypergradient
- following the Dynamical System perspective, the iterative optimization algorithm that solves the inner problem is viewed as a dynamical system: denote the transition function by $\Xi_t$ and the number of iterations by $T$ \[ w_{t + 1} = \Xi_{t + 1} (w_t, \lambda), \quad w_0 = \Xi_0(\lambda), \quad \hat{w} = w_T \]
- unrolling the computation graph (by the chain rule, with the empty product at $t = T$ taken to be the identity): \[ \frac{df}{d\lambda} = \frac{\partial f}{\partial \lambda} + \sum_{t = 0}^T B_t A_{t + 1} \cdots A_T \frac{\partial f}{\partial \hat{w}} \] \[ \text{where } A_{t + 1} = \frac{\partial \Xi_{t + 1}(w_t, \lambda)}{\partial w_t} = \frac{\partial w_{t + 1}}{\partial w_t}, ~~~ B_{t + 1} = \frac{\partial \Xi_{t + 1}(w_t, \lambda)}{\partial \lambda} = \frac{\partial w_{t + 1}}{\partial \lambda} ~ \text{ for } t \geq 0, ~~~ B_0 = \frac{d \Xi_{0} (\lambda)}{d \lambda}\]
- dimensions:
- denote $w_t \in \mathbb{R}^M$ and $\lambda \in \mathbb{R}^N$; then $A_t \in \mathbb{R}^{M \times M}$ and $B_t \in \mathbb{R}^{N \times M}$ (these dimensions imply the convention $\frac{\partial y}{\partial x} \in \mathbb{R}^{\dim(x) \times \dim(y)}$, i.e. the transposed Jacobian); a numerical sketch of $A_t$ and $B_t$ for the gradient-descent inner solver follows below
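- for the toy problem above, the transposed Jacobians $A_t$ and $B_t$ of the gradient-descent transition are available in closed form; this sketch (reusing `X_tr`, `ws`, `eta`, `lam`, `inner_solve` from the earlier snippet) also computes a finite-difference hypergradient that the recursions below can be checked against

```python
H = X_tr.T @ X_tr + np.diag(np.exp(lam))    # inner Hessian; constant since g is quadratic in w
A = np.eye(M) - eta * H                     # A_t = d w_{t+1} / d w_t, identical for every t here

def B(t):
    # B_t in R^{N x M}: transposed Jacobian of w_t w.r.t. lam; B_0 = 0 since w_0 does not depend on lam
    if t == 0:
        return np.zeros((N, M))
    return -eta * np.diag(np.exp(lam) * ws[t - 1])   # from differentiating the penalty gradient exp(lam)*w

def f_of_lam(lam_):
    # outer objective as a function of lam alone: f(w_T(lam_))
    w_hat = inner_solve(lam_)[-1]
    return 0.5 * np.sum((X_va @ w_hat - y_va) ** 2)

# brute-force check of df/dlam by central finite differences through the unrolled solver
eps = 1e-5
fd_grad = np.array([(f_of_lam(lam + eps * e) - f_of_lam(lam - eps * e)) / (2 * eps)
                    for e in np.eye(N)])
```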
- reverse mode differentiation (RMD):
\[ \frac{d f}{d \lambda} = h_{-1}, ~~~ \alpha_T = \frac{\partial f}{\partial \hat{w}}, ~~~ h_T = \frac{\partial f}{\partial \lambda}\]
\[ h_{t - 1} = h_t + B_t \alpha_t, ~~~ \alpha_{t - 1} = A_t \alpha_t, ~~~ t = T, T - 1, \dots, 0 \]
- needs to store the intermediate iterates $\{w_t \in \mathbb{R}^M\}_{t = 1}^T$, so the space requirement is $O(MT)$ (see the sketch below)
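- a minimal sketch of the RMD recursion on the toy problem (reusing `A`, `B`, `ws`, `f_grad_w` from the snippets above); it reproduces the finite-difference hypergradient `fd_grad`

```python
alpha = f_grad_w(ws[T])          # alpha_T = partial f / partial w_hat
h = np.zeros(N)                  # h_T = partial f / partial lam = 0 in this setup
for t in range(T, -1, -1):       # reverse pass: t = T, T-1, ..., 0
    h = h + B(t) @ alpha         # h_{t-1} = h_t + B_t alpha_t
    alpha = A @ alpha            # alpha_{t-1} = A_t alpha_t (A_t is constant here)
rmd_grad = h                     # h_{-1} = df/dlam
print(np.allclose(rmd_grad, fd_grad, atol=1e-5))   # matches the finite-difference check
```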
- forward mode differentiation (FMD):
\[ \frac{d f}{d \lambda} = Z_T \frac{\partial f}{\partial \hat{w}} + \frac{\partial f}{\partial \lambda}, ~~~ Z_0 = B_0\]
\[ Z_{t + 1} = Z_t A_{t + 1} + B_{t + 1}, ~~~ t = 0, \dots, T - 1\]
- needs to propagate the matrices $Z_t \in \mathbb{R}^{N \times M}$, so the time complexity is $N$ times that of RMD (matrix-matrix vs. matrix-vector multiplications), but there is no need to store the intermediate iterates $w_t$ from the forward pass (sketch below)
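- the same hypergradient via the forward recursion, again on the toy setup; note that in a true FMD implementation $B_{t+1}$ would be formed on the fly during the forward pass, whereas this sketch reuses the stored list `ws` purely for brevity

```python
Z = B(0)                                   # Z_0 = B_0  (N x M)
for t in range(T):                         # t = 0, ..., T-1
    Z = Z @ A + B(t + 1)                   # Z_{t+1} = Z_t A_{t+1} + B_{t+1}
fmd_grad = Z @ f_grad_w(ws[T])             # + partial f / partial lam, which is zero here
print(np.allclose(fmd_grad, rmd_grad))     # forward and reverse mode agree
```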
- K-step truncated backpropagation (K-RMD):
\[ h_{T - K} = \frac{\partial f}{\partial \lambda} + \sum_{t = T - K + 1}^T B_t A_{t + 1} \cdots A_T \frac{\partial f}{\partial \hat{w}} \]
- only needs to store the last iterates $\{w_t\}_{t = T - K + 1}^T$, so the space requirement is $O(MK)$ (sketch below)
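- a sketch of K-RMD on the toy problem: the reverse pass is simply stopped after $K$ steps, and the bias relative to full RMD shrinks quickly as $K$ grows (roughly exponentially here, at a rate governed by the spectral radius of $A$)

```python
def k_rmd(K):
    # truncated reverse pass: visits only t = T, ..., T-K+1, so only ws[T-K:] would need to be stored
    alpha = f_grad_w(ws[T])
    h = np.zeros(N)
    for t in range(T, T - K, -1):
        h = h + B(t) @ alpha
        alpha = A @ alpha
    return h                               # h_{T-K}

for K in (1, 5, 20, 50, T + 1):            # K = T+1 visits t = 0 as well, recovering full RMD
    print(K, np.linalg.norm(k_rmd(K) - rmd_grad))   # bias decays as K grows
```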
- main theoretical results (informal):
- (Accuracy) if the inner problem is locally strongly convex around $\hat{w}$, then the bias of $h_{T - K}$ decays exponentially in K.
- (Sufficient Descent) if the inner problem $g$ is second-order continuously differentiable, then $-h_{T - K}$ is a sufficient descent direction.
- (Convergence) combining the above results, on-average convergence to an $\epsilon$-approximate stationary point is guaranteed with $O(\log(1/\epsilon))$-step truncated backpropagation
- relation with Implicit Differentiation:
- in the limit where $\hat{w}$ converges to $w^{*}$, $h_{T - K}$ can be viewed as an order-$K$ (i.e. first $K$ terms) Taylor-series approximation of the matrix inverse in the total derivative below; the residual term admits an upper bound
\[ \frac{df}{d\lambda} = \frac{\partial f}{\partial \lambda} - \frac{\partial^2 g}{\partial \lambda \partial w} \bigg( \frac{\partial^2 g}{\partial w^2} \bigg)^{-1} \frac{\partial f}{\partial \hat{w}}\]
- note: the above equation relies on the assumptions that (1) $g$ is second-order continuously differentiable and (2) there exists a unique optimal solution $w^{*}$; all derivatives of $g$ are evaluated at $w^{*}$ (a sketch for the toy problem follows below)
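- for the toy quadratic, both the exact minimizer $w^{*}$ and the implicit-differentiation formula are available in closed form; this sketch (reusing `H`, `X_tr`, `lam`, `f_grad_w` from above) agrees closely with `rmd_grad` because $w_T \approx w^{*}$ after many inner steps

```python
w_star = np.linalg.solve(H, X_tr.T @ y_tr)          # exact inner minimizer: H w* = X_tr^T y_tr
d2g_dlam_dw = np.diag(np.exp(lam) * w_star)         # cross derivative d^2 g / (dlam dw), N x M
implicit_grad = -d2g_dlam_dw @ np.linalg.solve(H, f_grad_w(w_star))   # formula above; df/dlam term is 0
print(np.linalg.norm(implicit_grad - rmd_grad))     # small, since w_T is close to w* here
```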
- experiment: compare K-step truncated RMD (K-RMD) to K-step Conjugate Gradient (CG)
- both require local strong-convexity to ensure a good approximation
- if $w^{*}$ is available, then CG gives a smaller bias
- in practice $w^{*}$ is usually unavailable, so K-step truncated RMD rests on a weaker assumption (a CG-based sketch for the toy problem follows below)
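- as a rough stand-in for the CG baseline (not the paper's implementation), the inverse-Hessian-vector product in the implicit formula can be approximated with at most $K$ conjugate-gradient iterations, e.g. via `scipy.sparse.linalg.cg` with `maxiter=K`; this sketch assumes `H`, `w_star`, `lam`, `f_grad_w`, and `implicit_grad` from the snippets above

```python
from scipy.sparse.linalg import cg

def cg_hypergrad(K):
    # approximately solve H v = df/dw_hat with at most K CG iterations, then plug into the implicit formula
    v, _ = cg(H, f_grad_w(w_star), maxiter=K)
    return -np.diag(np.exp(lam) * w_star) @ v

for K in (1, 5, 20):
    print(K, np.linalg.norm(cg_hypergrad(K) - implicit_grad))   # the CG bias also shrinks with K
```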