This post is a paper-reading note. Paper: Truncated Backpropagation for Bilevel Optimization, Amirreza Shaban & Ching-An Cheng et al. AISTATS 2018.

Bilevel Optimization (BO)

  • one application of BO in Machine Learning is Hyperparameter Optimization (HO) \[ \text{min}_{\lambda} f(\hat{w}, \lambda) \text{ s.t. } \hat{w} \approx \text{argmin}_w g(w, \lambda) \]
    • denote parameter and hyperparameter by $w$ and $\lambda$, the outer and inner objective function by $f$ (validation loss function) and $g$ (training loss function)
    • in this setup, the outer objective (the validation loss) does not depend directly on the hyperparameter $\lambda$; it depends on $\lambda$ only through $\hat{w}$, so the direct gradient $\frac{\partial f}{\partial \lambda} = 0$
    • note: we follow Shaban et al. 2018, who include the algorithm used to solve the inner problem as part of the problem’s formulation. Thus, we use $\hat{w}$ to denote the solution returned by the solver instead of the exact minimizer $w^{*}$
  • challenges:
    • the outer optimization over $\lambda$ depends on the inner problem in a complicated way, so evaluating exact gradients does not scale to high-dimensional problems
  • prior works:
    • other approaches for HO: Grid Search (black box, run the training procedure many times), Random Search, Bayesian Optimization, hypernetwork (MacKay & Vicol et al. 2019, Bae et al. 2020)
    • Implicit Differentiation: relies on an estimate of the Jacobian $J_{\lambda} \hat{w} = \frac{\partial \hat{w}}{\partial \lambda}$
      • note: this method relies on the assumption that $\hat{w} = w^{*}$, i.e. optimality of the approximate solution
    • Dynamical System perspective: treat the iterative algorithm to solve the inner problem as a dynamical system and apply backpropagation through the system
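
To make the bilevel setup concrete, here is a minimal sketch of hyperparameter optimization on a scalar toy problem. Everything in it is an assumption made for illustration and is not taken from the paper: the inner loss $g(w, \lambda) = \frac{1}{2}(w - a)^2 + \frac{1}{2}\lambda w^2$, the outer loss $f(\hat{w}) = \frac{1}{2}(\hat{w} - b)^2$, plain gradient descent as the inner solver, and the black-box grid search over $\lambda$.

```python
def solve_inner(lam, T=200, eta=0.1, a=2.0):
    """Approximate argmin_w g(w, lam) with T gradient steps from w_0 = 0.

    Toy inner loss (an assumption): g(w, lam) = 0.5*(w - a)**2 + 0.5*lam*w**2.
    The return value plays the role of w_hat = w_T, not the exact minimizer.
    """
    w = 0.0
    for _ in range(T):
        grad = (w - a) + lam * w  # dg/dw
        w -= eta * grad
    return w


def outer_loss(w_hat, b=1.0):
    """Toy validation loss f; it depends on lam only through w_hat."""
    return 0.5 * (w_hat - b) ** 2


def grid_search(lams):
    """Black-box HO: rerun the inner solver once per candidate lam."""
    return min(lams, key=lambda lam: outer_loss(solve_inner(lam)))


candidates = [0.25 * k for k in range(13)]  # 0.0, 0.25, ..., 3.0
best_lam = grid_search(candidates)
print(best_lam, outer_loss(solve_inner(best_lam)))
```

For $a = 2$ and $b = 1$ the exact inner minimizer is $w^{*} = a / (1 + \lambda)$, which equals $b$ at $\lambda = 1$, so the grid search should select that candidate. Note that every candidate costs a full inner solve, which is what gradient-based methods try to avoid.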

Truncated Backpropagation for Bilevel Optimization

  • idea: approximate gradients using Truncated Backpropagation (TBP) through “time” (as in a previous post, TBP reduces time and space complexities by removing long term dependencies)
    • note: here “time” refers to the optimization steps performed to solve the inner problem
  • hypergradient
    • following the Dynamical System perspective, the iterative optimization algorithm that solves the inner problem is viewed as a dynamical system: denote the transition function by $\Xi_t$ and the number of iterations by $T$ \[ w_{t + 1} = \Xi_{t + 1} (w_t, \lambda), ~~~ w_0 = \Xi_0 (\lambda), ~~~ \hat{w} = w_T\]
    • unrolling the computation graph: \[ \frac{df}{d\lambda} = \frac{\partial f}{\partial \lambda} + \sum_{t = 0}^T B_t A_{t + 1} \cdots A_T \frac{\partial f}{\partial \hat{w}} ~~~ \text{(by the chain rule)}\] \[ \text{where } A_{t + 1} = \frac{\partial \Xi_{t + 1}(w_t, \lambda)}{\partial w_t} = \frac{\partial w_{t + 1}}{\partial w_t}, ~~ B_{t + 1} = \frac{\partial \Xi_{t + 1}(w_t, \lambda)}{\partial \lambda} = \frac{\partial w_{t + 1}}{\partial \lambda} \text{ for } t \geq 0, ~~ B_0 = \frac{d \Xi_{0} (\lambda)}{d \lambda}\]
    • dimensions:
      • denote $w_t \in \mathbb{R}^M$ and $\lambda \in \mathbb{R}^N$, then $A_t \in \mathbb{R}^{M \times M}$, $B_t \in \mathbb{R}^{N \times M}$
    • reverse mode differentiation (RMD): \[ \frac{d f}{d \lambda} = h_{-1}, ~~~ \alpha_T = \frac{\partial f}{\partial \hat{w}}, ~~~ h_T = \frac{\partial f}{\partial \lambda}\] \[ h_{t - 1} = h_t + B_t \alpha_t, ~~~ t = T, \dots, 0, ~~~ \alpha_{t - 1} = A_t\alpha_t, ~~~ t = T, \dots, 1 \]
      • need to store intermediate values $\{w_t \in \mathbb{R}^M\}_{t = 1}^T$, then the space requirement is $O(MT)$
    • forward mode differentiation (FMD): \[ \frac{d f}{d \lambda} = Z_T \frac{\partial f}{\partial \hat{w}} + \frac{\partial f}{\partial \lambda}, ~~~ Z_0 = B_0\] \[ Z_{t + 1} = Z_t A_{t + 1} + B_{t + 1}, ~~~ t = 0, \dots, T - 1\]
      • need to propagate the matrices $Z_t \in \mathbb{R}^{N \times M}$ (the same shape as $B_t$), so the time complexity is $N$ times that of RMD (matrix-matrix vs. matrix-vector multiplications), but there is no need to store the intermediate values $w_t$ from the forward pass
    • K-step truncated backpropagation (K-RMD): \[ h_{T - K} = \frac{\partial f}{\partial \lambda} + \sum_{t = T - K + 1}^T B_t A_{t + 1} \cdots A_T \frac{\partial f}{\partial \hat{w}} \]
      • only need to store $\{w_t\}_{t = T - K + 1}^T$, so the space requirement is $O(MK)$
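
The three recursions above (RMD, FMD, K-RMD) can be sketched on a scalar toy problem, where $A_t$ and $B_t$ are plain numbers. The inner loss $g(w, \lambda) = \frac{1}{2}(w - a)^2 + \frac{1}{2}\lambda w^2$, the outer loss $f(\hat{w}) = \frac{1}{2}(\hat{w} - b)^2$, gradient descent with step size `eta` as $\Xi_t$, and the initialization $w_0 = 0$ (so $B_0 = 0$) are all assumptions chosen for illustration, not the paper's experiments.

```python
def unroll(lam, T, eta=0.1, a=2.0):
    """Trajectory [w_0, ..., w_T] of gradient descent on the toy inner loss."""
    ws = [0.0]  # w_0 = Xi_0(lam) = 0, hence B_0 = 0
    for _ in range(T):
        w = ws[-1]
        ws.append(w - eta * ((w - a) + lam * w))  # w_{t+1} = Xi_{t+1}(w_t, lam)
    return ws


def k_rmd(lam, T, K, eta=0.1, a=2.0, b=1.0):
    """h_{T-K}: backpropagate through the last K steps only (K = T is full RMD)."""
    ws = unroll(lam, T, eta, a)  # a memory-aware version keeps only the tail
    alpha = ws[T] - b            # alpha_T = df/dw_hat
    h = 0.0                      # h_T = direct gradient of f w.r.t. lam = 0
    for t in range(T, T - K, -1):          # t = T, ..., T-K+1
        B_t = -eta * ws[t - 1]             # d Xi_t / d lam at w_{t-1}
        A_t = 1.0 - eta * (1.0 + lam)      # d Xi_t / d w at w_{t-1}
        h += B_t * alpha                   # h_{t-1} = h_t + B_t alpha_t
        alpha = A_t * alpha                # alpha_{t-1} = A_t alpha_t
    return h


def fmd(lam, T, eta=0.1, a=2.0, b=1.0):
    """Forward mode: carry Z_t alongside w_t, no stored trajectory."""
    w, Z = 0.0, 0.0  # w_0 and Z_0 = B_0 = 0
    for _ in range(T):
        B_next = -eta * w
        A_next = 1.0 - eta * (1.0 + lam)
        Z = Z * A_next + B_next            # Z_{t+1} = Z_t A_{t+1} + B_{t+1}
        w = w - eta * ((w - a) + lam * w)
    return Z * (w - b)                     # plus direct gradient, which is 0


def finite_difference(lam, T, eps=1e-6, b=1.0):
    """Central difference of lam -> f(w_T(lam)), used only as a sanity check."""
    f = lambda l: 0.5 * (unroll(l, T)[-1] - b) ** 2
    return (f(lam + eps) - f(lam - eps)) / (2 * eps)


full = k_rmd(0.5, T=100, K=100)
for K in (1, 5, 10, 20, 50):
    print(K, abs(full - k_rmd(0.5, T=100, K=K)))  # truncation error vs. K
```

In this sketch `k_rmd` with `K = T` and `fmd` compute the same hypergradient, and the printed truncation error shrinks as `K` grows, in line with the exponential decay stated in the theoretical results for a strongly convex inner problem.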
  • main theoretical results (informal):
    1. (Accuracy) if the inner problem is locally strongly convex around $\hat{w}$, then the bias of $h_{T - K}$ decays exponentially in K.
    2. (Sufficient Descent) if the inner problem $g$ is second-order continuously differentiable, then $-h_{T - K}$ is a sufficient descent direction.
    3. (Convergence) following the above results, on-average convergence to an $\epsilon$-approximate stationary point is guaranteed by $O(\log (1 / \epsilon))$-step truncated backpropagation
  • relation with Implicit Differentiation:
    • in the limit where $\hat{w}$ converges to $w^{*}$, $h_{T - K}$ can be viewed as an order-K Taylor series (i.e. its first K terms) approximating the matrix inverse in the total derivative below, and the residual term has an upper bound \[ \frac{df}{d\lambda} = \frac{\partial f}{\partial \lambda} - \frac{\partial^2 g}{\partial \lambda \partial w} \bigg( \frac{\partial^2 g}{\partial w^2} \bigg)^{-1} \frac{\partial f}{\partial \hat{w}}\]
      • note: the above equation relies on the assumptions that (1) g is second-order continuously differentiable (2) there exists a unique optimal solution $w^{*}$ and all the derivatives are evaluated at $w^{*}$
    • experiment: compare K-step truncated RMD to K-step Conjugate Gradient Descent (CG)
      • both require local strong-convexity to ensure a good approximation
      • if $w^{*}$ is available, then CG gives a smaller bias
      • in practice, $w^{*}$ is usually unknown, and K-step truncated RMD relies on a weaker assumption
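
The series view of the matrix inverse can be checked numerically on a scalar toy problem. The sketch below assumes gradient descent with step size `eta` as the inner solver, so that the inverse Hessian expands as $H^{-1} = \eta \sum_{j \geq 0} (1 - \eta H)^j$ whenever $|1 - \eta H| < 1$. The toy losses $g(w, \lambda) = \frac{1}{2}(w - a)^2 + \frac{1}{2}\lambda w^2$ and $f(\hat{w}) = \frac{1}{2}(\hat{w} - b)^2$ and all function names are assumptions made for illustration.

```python
def implicit_gradient(lam, a=2.0, b=1.0):
    """Total derivative df/dlam at the exact minimizer w* of the toy inner loss."""
    w_star = a / (1.0 + lam)   # solves dg/dw = 0
    hess = 1.0 + lam           # d^2 g / dw^2
    cross = w_star             # d^2 g / (dlam dw), evaluated at w*
    df_dw = w_star - b         # df/dw_hat at w*
    return -cross * (1.0 / hess) * df_dw  # direct gradient of f is 0


def truncated_inverse(hess, K, eta=0.1):
    """First K terms of hess^{-1} = eta * sum_j (1 - eta * hess)**j."""
    return eta * sum((1.0 - eta * hess) ** j for j in range(K))


def truncated_implicit_gradient(lam, K, eta=0.1, a=2.0, b=1.0):
    """Same formula as implicit_gradient, with the inverse replaced by K terms."""
    w_star = a / (1.0 + lam)
    hess = 1.0 + lam
    return -w_star * truncated_inverse(hess, K, eta) * (w_star - b)


exact = implicit_gradient(0.5)
for K in (1, 5, 10, 20, 50):
    approx = truncated_implicit_gradient(0.5, K)
    print(K, exact - approx)  # residual of the K-term truncation
```

In this scalar case the residual after $K$ terms is the exact gradient multiplied by $(1 - \eta H)^K$, so it shrinks geometrically in $K$ as long as the inner problem is strongly convex and the step size is small enough.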