Subho's research at your service 🫡

Preconditioned SGD can level up your training game

If you're a bit new to ML, parts of this post might be a little hard to follow, but I'll try my best to make everything clear by the end.

As the title suggests, we'll be discussing preconditioned stochastic gradient descent (PSGD) and uncovering some very interesting intuitions behind these optimizers and how they converge.

To start, let's scratch the basics of SGD! In simple terms, SGD is an optimization technique that randomly selects a training sample (or a small batch) at a time, computes the gradient of the loss on it, and takes a step against that gradient to minimize the loss function:

$$w_{t+1} = w_t - \eta \nabla L(w_t; x_i, y_i)$$
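To make the update rule concrete, here is a minimal NumPy sketch of mini-batch SGD on a toy linear-regression problem. The synthetic data, the batch size of 16 and the learning rate are arbitrary choices for illustration, and the helper `sgd_step` is hypothetical, not part of any library.

```python
import numpy as np

def sgd_step(w, X_batch, y_batch, lr):
    """One SGD step for mean-squared-error linear regression on a mini-batch."""
    preds = X_batch @ w
    # Gradient of 0.5 * mean((Xw - y)^2) with respect to w
    grad = X_batch.T @ (preds - y_batch) / len(y_batch)
    return w - lr * grad

rng = np.random.default_rng(0)
w_true = np.array([2.0, -1.0, 0.5])
X = rng.normal(size=(1000, 3))
y = X @ w_true  # noiseless synthetic targets

w = np.zeros(3)
for _ in range(2000):
    idx = rng.integers(0, len(X), size=16)  # sample a random mini-batch
    w = sgd_step(w, X[idx], y[idx], lr=0.05)

print(np.round(w, 3))
```

Every step uses the same scalar `lr` for all three weights, which is exactly the limitation discussed next.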

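But this doesn't work well every time. The reason is that SGD applies a single scale (the learning rate $\eta$) to every direction in every iteration: whether the slope is steep or gentle, the gradient gets scaled the same way. Same shoe for all terrains. On a loss surface that is steep in some directions and flat in others, a learning rate small enough to stay stable in the steep directions makes progress painfully slow in the flat ones.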
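A tiny example shows the "same shoe for all terrains" problem. The sketch below runs plain gradient descent (no sampling noise) on a made-up quadratic with curvatures 100 and 1. The steep direction caps the learning rate at 2/100, so the flat direction barely moves, and going just past the cap makes the steep direction diverge.

```python
import numpy as np

# Ill-conditioned quadratic: L(w) = 0.5 * (100 * w0^2 + 1 * w1^2)
curvatures = np.array([100.0, 1.0])

def grad(w):
    return curvatures * w

def run_gd(lr, steps=100):
    w = np.array([1.0, 1.0])
    for _ in range(steps):
        w = w - lr * grad(w)  # one learning rate for every direction
    return w

w_stable = run_gd(lr=0.019)    # just below the stability limit 2/100
w_unstable = run_gd(lr=0.021)  # just above it: the steep direction blows up
print(w_stable)    # steep coordinate is ~0, flat coordinate is still far from 0
print(w_unstable)
```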

Preconditioned SGD tackles this with a preconditioner matrix that normalizes the scale of different parameters. It takes larger steps in gentle (low-curvature) directions and smaller steps in steep (high-curvature) directions, which can lead to much faster convergence:

$$w_{t+1} = w_t - \eta P_t \nabla L(w_t; x_i, y_i)$$

The preconditioner $P_t$ is typically designed to approximate the inverse of the Hessian matrix ($H^{-1}$) or a similar curvature metric. This is powerful because the Hessian contains information about how different parameters interact and how quickly the loss changes in different directions.
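On a toy quadratic, where the Hessian is known exactly, the ideal preconditioner $P = H^{-1}$ can be applied directly. This is only a sketch with made-up numbers: in a real network $H$ is unknown and huge, which is exactly why PSGD fits $P$ from Hessian-vector products instead of inverting $H$.

```python
import numpy as np

H = np.array([[100.0, 2.0],
              [2.0, 1.0]])      # positive-definite Hessian of a toy quadratic
b = np.array([1.0, 1.0])
w_star = np.linalg.solve(H, b)  # exact minimizer of L(w) = 0.5 w^T H w - b^T w

def grad(w):
    return H @ w - b

P = np.linalg.inv(H)            # ideal preconditioner: the inverse Hessian
w = np.zeros(2)
w = w - 1.0 * P @ grad(w)       # one preconditioned step with eta = 1
print(w, w_star)
```

With $\eta = 1$ this single preconditioned step is a Newton step, so it lands on the minimizer no matter how badly conditioned $H$ is.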

So how do we find the best preconditioner? There must be a way, right? Yup, there is, and finding the right preconditioner matrix helps us a ton when optimizing our objective.

There are multiple criteria for finding the best preconditioner, which you can find here, but in this post I'll only explain the Newton preconditioner fitting criterion.

It goes like this:

$$\mathbb{E}_v[h^T P h + v^T P^{-1} v]$$

where $h = Hv$ is the Hessian-vector product, which represents how the curvature of the function acts along a direction, and $v$ is a random vector that helps us explore different directions.

This criterion has the solution $P^{-2} = H^2$ (for probe vectors $v$ with identity covariance, such as standard Gaussian noise). It balances two effects: how the preconditioner acts on the curvature, $h^T P h$, and its inverse effect, $v^T P^{-1} v$, with the expectation over random vectors $v$ giving us a more global view of the curvature. When the Hessian $H$ is positive definite, the solution is simply $P = H^{-1}$, so the preconditioned update becomes equivalent to Newton's method. It's like finding glasses that perfectly correct for the curvature in all directions.
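We can sanity-check the claimed solution numerically. For Gaussian probes $v$ and $h = Hv$, the criterion has the closed form $\mathrm{tr}(P H^2) + \mathrm{tr}(P^{-1})$. The NumPy sketch below uses a random positive-definite matrix as a stand-in Hessian (an assumption for illustration) and compares $P = H^{-1}$ against nearby symmetric matrices and against a Monte Carlo estimate of the expectation.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(4, 4))
H = A @ A.T + np.eye(4)          # random symmetric positive-definite "Hessian"

def criterion(P):
    # For v ~ N(0, I) and h = H v:
    #   E[h^T P h] = tr(P H^2)  and  E[v^T P^{-1} v] = tr(P^{-1})
    return np.trace(P @ H @ H) + np.trace(np.linalg.inv(P))

P_star = np.linalg.inv(H)        # candidate solution of P^{-2} = H^2
best = criterion(P_star)         # equals 2 * tr(H)

# Monte Carlo version of the same expectation, using sampled probes v
vs = rng.normal(size=(200_000, 4))
hs = vs @ H                      # each row is (H v)^T since H is symmetric
mc = np.mean(np.einsum("ni,ij,nj->n", hs, P_star, hs)
             + np.einsum("ni,ij,nj->n", vs, H, vs))   # P_star^{-1} = H

# Small symmetric perturbations of P_star should never lower the criterion
perturbed = []
for _ in range(5):
    S = rng.normal(size=(4, 4))
    perturbed.append(criterion(P_star + 1e-3 * (S + S.T)))

print(best, mc, min(perturbed) >= best)
```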

Working with the full preconditioner matrix $P$ is computationally expensive (for $n$ parameters it has $n^2$ entries), so let's represent it in a factored, Kronecker-product form:

$$P = Q_1 \otimes Q_2 \otimes \cdots \otimes Q_d$$

where $\otimes$ denotes the Kronecker product and $d$ is the number of dimensions of our parameter tensor. The fitting criterion for the preconditioner (from the document) then takes one of these forms:
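Before going through those forms, a quick aside on why the Kronecker form is so much cheaper. For a 2D parameter (an $m \times n$ matrix), a Kronecker-factored matrix $Q_1 \otimes Q_2$ can be applied to the flattened gradient as $Q_1 G Q_2^T$, without ever forming the $(mn) \times (mn)$ matrix. A small NumPy check with random upper-triangular factors (sizes chosen arbitrarily):

```python
import numpy as np

rng = np.random.default_rng(0)
m, n = 3, 4                          # a 3x4 weight matrix has 12 parameters
Q1 = np.triu(rng.normal(size=(m, m)))  # small factor acting on rows
Q2 = np.triu(rng.normal(size=(n, n)))  # small factor acting on columns
G = rng.normal(size=(m, n))          # e.g. a gradient with the parameter's shape

# Naive: build the full (m*n) x (m*n) matrix and multiply the flattened gradient
full = np.kron(Q1, Q2) @ G.reshape(-1)

# Factored: never form the Kronecker product, just two small matrix products
factored = (Q1 @ G @ Q2.T).reshape(-1)

print(np.allclose(full, factored))
print(np.kron(Q1, Q2).size, Q1.size + Q2.size)  # stored entries: full vs factored
```

The identity relies on row-major flattening (NumPy's default `reshape`); with column-major flattening the factor order inside `np.kron` flips. With that in mind, here are the criteria.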

  1. For Newton-type preconditioners, we use the same criterion as above:

$$\mathbb{E}_v[h^T P h + v^T P^{-1} v]$$

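  2. For gradient whitening, which adjusts PSGD steps based on the historical "steepness" of each direction, the criterion becomes:

$$\mathbb{E}_{v,z}[g_z^T P g_z + v^T P^{-1} v]$$

with solution

$$P^{-2} = \sum_t g_t g_t^T \quad \text{or} \quad P^{-2} = \mathbb{E}_z[g_z g_z^T]$$

where $g_t$ is the gradient at step $t$ and $g_z$ denotes per-sample gradients.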

# From the implementation
if precond_type != "Newton":
    v, g_damped = damped_pair_vg(g)
    update_precond_kron_math_(*Q_exprs, v, g_damped, lr_precond, step_normalizer, self._tiny)
  3. The update rule for each Kronecker factor $Q_i$ takes a multiplicative form:

$$Q_i \leftarrow Q_i - \mu \frac{Q_i h h^T Q_i^T - Q_i^{-T} v v^T Q_i^{-1}}{\lVert Q_i h \rVert^2 + \lVert Q_i^{-T} v \rVert^2} Q_i$$

This update keeps each matrix $Q_i$ in a Lie group, because we want the preconditioner to change smoothly during optimization and every update to stay invertible. But why? Time to test your memory. Don't worry, I won't make much of a scene here, but recall that the solution to the criterion we picked is $P^{-2} = H^2$.

So we need $P^{-1}$ to exist in order to properly normalize the optimization space. That's why, instead of an additive update ($Q + \Delta Q$), which could make $Q$ singular, we use:

$$Q_{\text{new}} = (I - \mu \cdot \text{update})\, Q$$

This multiplicative form keeps us in the Lie group: as long as the step is small enough that $I - \mu \cdot \text{update}$ is invertible, the new matrix remains invertible too.
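The multiplicative update can be tried in isolation on a toy problem. The NumPy sketch below assumes a single upper-triangular factor ($d = 1$) that parameterizes the preconditioner as $P = Q^T Q$ (as the $\lVert Q_i h \rVert^2$ and $\lVert Q_i^{-T} v \rVert^2$ terms in the update rule suggest), a known 2×2 "Hessian", and the expected values of the two terms instead of random probes. If the reasoning above holds, $Q^T Q$ should approach $H^{-1}$ while $Q$ stays upper-triangular and invertible.

```python
import numpy as np

# Toy setting (assumption): a single upper-triangular factor Q with P = Q^T Q,
# a known SPD "Hessian" H, and the *expected* terms for v ~ N(0, I):
# E[(Q h)(Q h)^T] = Q H^2 Q^T and E[(Q^{-T} v)(Q^{-T} v)^T] = Q^{-T} Q^{-1}.
H = np.array([[3.0, 1.0],
              [1.0, 2.0]])
Q = np.eye(2)
mu = 0.1

for _ in range(20000):
    Q_inv = np.linalg.inv(Q)
    term1 = Q @ H @ H @ Q.T
    term2 = Q_inv.T @ Q_inv
    # Normalized, upper-triangular update direction (Frobenius norm of mu * update <= mu)
    update = np.triu(term1 - term2) / np.linalg.norm(term1 + term2)
    # Multiplicative step: (I - mu * update) is invertible, so Q stays invertible
    Q = (np.eye(2) - mu * update) @ Q

P = Q.T @ Q
print(P)
print(np.linalg.inv(H))
```

The normalization by `norm(term1 + term2)` mirrors the triangular branch of the implementation's update: it bounds the step so that $I - \mu \cdot \text{update}$ can never become singular.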

Now let's get to the main objective: measuring how well our preconditioner $P$ fits both the Hessian-vector product term $h^T P h$ and its inverse behavior $v^T P^{-1} v$.

First, we compute the Hessian-vector product $h$:

if exact_hvp:
    # Exact Hessian-vector product via double backprop
    with torch.enable_grad():
        closure_returns = closure()
        loss = closure_returns if isinstance(closure_returns, torch.Tensor) else closure_returns[0]
        grads = torch.autograd.grad(loss, self._params, create_graph=True)

        vs = [torch.randn_like(p) for p in self._params]      # random probe vectors v
        Hvs = torch.autograd.grad(grads, self._params, vs)    # h = H v
else:
    # Approximate Hessian-vector product with finite differences: H v ≈ g(w + v) - g(w)
    with torch.enable_grad():
        closure_returns = closure()
        loss = closure_returns if isinstance(closure_returns, torch.Tensor) else closure_returns[0]
        grads = torch.autograd.grad(loss, self._params)

    # Small random perturbation v (assuming step() runs under torch.no_grad(),
    # as optimizer steps usually do, so in-place add_ on parameters is allowed)
    vs = [self._delta * torch.randn_like(p) for p in self._params]
    for param, v in zip(self._params, vs):
        param.add_(v)

    with torch.enable_grad():
        perturbed_returns = closure()  # re-evaluate the loss at w + v
        perturbed_loss = perturbed_returns if isinstance(perturbed_returns, torch.Tensor) else perturbed_returns[0]
        perturbed_grads = torch.autograd.grad(perturbed_loss, self._params)

    # Compute approximate Hvp: the difference of gradients is roughly H v
    Hvs = [pg - g for pg, g in zip(perturbed_grads, grads)]

    # Undo the perturbation so the parameter update starts from w again
    for param, v in zip(self._params, vs):
        param.sub_(v)
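The finite-difference branch relies on $\nabla L(w + v) - \nabla L(w) \approx Hv$ for a small $v$. The NumPy sketch below checks that on a logistic-regression loss with synthetic data, where $Hv$ has a closed form. Here `delta` plays the role of `self._delta`, and its value is an arbitrary choice for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 5))
t = (rng.random(200) < 0.5).astype(float)   # synthetic binary labels
w = rng.normal(size=5)

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def grad(w):
    # Gradient of the mean binary cross-entropy of a logistic-regression model
    return X.T @ (sigmoid(X @ w) - t) / len(t)

def exact_hvp(w, v):
    # H v with H = X^T diag(s * (1 - s)) X / N, computed without forming H
    s = sigmoid(X @ w)
    return X.T @ (s * (1 - s) * (X @ v)) / len(t)

delta = 1e-6
v = delta * rng.normal(size=5)        # small probe, scaled like self._delta * randn
fd_hvp = grad(w + v) - grad(w)        # finite-difference estimate of H v
exact = exact_hvp(w, v)
rel_err = np.linalg.norm(fd_hvp - exact) / np.linalg.norm(exact)
print(rel_err)
```

The truncation error shrinks roughly in proportion to `delta`, until floating-point cancellation takes over for very small values, which is why the exact double-backprop branch is preferable when it is affordable.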

These values are then used in the preconditioner update, where we build the $h^T P h$ and $v^T P^{-1} v$ terms. Here is the code for reference:

def update_precond_kron_math_(Q, exprs, V, G, step, step_normalizer, tiny):
    # G is our Hessian-vector product h
    # V is our random vector v
    order = G.dim()  # number of dimensions of the parameter tensor

    # A = Q h, applied factor by factor via einsum (A A^H later gives the h^T P h term)
    exprA, exprGs, _ = exprs
    A = exprA(*Q, G)

    conjB = None
    if V is not None:
        # Build the Q^{-T} v term (stored conjugated and permuted) by applying
        # the inverse of each Kronecker factor to v
        p = list(range(order))
        conjB = torch.permute(V.conj(), p[1:] + p[:1])

        for i, q in enumerate(Q):
            # Diagonal factor: elementwise division; triangular factor: right-side triangular solve
            conjB = conjB/q if q.dim()<2 else solve_triangular_right(conjB, q)
            if i < order - 1:
                conjB = torch.transpose(conjB, i, order - 1)

The update itself is a descent step on the Lie group, minimizing the fitting criterion with respect to each preconditioner factor:

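    # Update each Q factor
    for i, q in enumerate(Q):
        # term1: contribution of h^T P h (A A^H contracted over every dim except i)
        term1 = exprGs[i](A, A.conj())

        # term2: contribution of v^T P^{-1} v, built from the Q^{-T} v term
        # (this excerpt assumes V was given, so conjB is defined)
        term2 = exprGs[i](conjB.conj(), conjB)

        # Multiplicative update (I - step * normalized update) q keeps the Lie group structure
        if q.dim() < 2:
            # Diagonal factor: normalize by the largest entry
            q.sub_(step/(torch.max(torch.abs(term1 + term2)) + tiny) *
                   (term1 - term2) * q)
        else:
            # Triangular factor: normalize by the norm and keep the upper-triangular part
            q.sub_(step/(torch.linalg.norm(term1 + term2) + tiny) *
                   torch.triu(term1 - term2) @ q)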
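For a 2D parameter, the two factors get the updates above. Below is a NumPy sketch of the same logic, with explicit matrix products in place of the implementation's `exprA`/`exprGs` einsum expressions. Assumptions (not from the implementation): $P = Q^T Q$ with $Q = Q_1 \otimes Q_2$, a toy Hessian with Kronecker structure $H_1 \otimes H_2$, and the expectation over $v$ computed exactly by summing over all basis probes instead of drawing random ones.

```python
import numpy as np

# Toy setting (assumption): a 2x3 parameter matrix whose Hessian is H1 kron H2,
# so the Hessian-vector product of a probe V (2x3) is H1 @ V @ H2.
H1 = np.array([[2.0, 0.5],
               [0.5, 1.0]])
H2 = np.array([[3.0, 1.0, 0.0],
               [1.0, 2.0, 0.5],
               [0.0, 0.5, 1.0]])
m, n = 2, 3
Q1, Q2 = np.eye(m), np.eye(n)
mu = 0.1

# Every basis matrix E_ab as a probe: summing over them reproduces E[v v^T] = I exactly
probes = [np.outer(np.eye(m)[a], np.eye(n)[b]) for a in range(m) for b in range(n)]

for _ in range(5000):
    Q1_inv, Q2_inv = np.linalg.inv(Q1), np.linalg.inv(Q2)
    t1_q1, t2_q1 = np.zeros((m, m)), np.zeros((m, m))
    t1_q2, t2_q2 = np.zeros((n, n)), np.zeros((n, n))
    for V in probes:
        A = Q1 @ (H1 @ V @ H2) @ Q2.T   # Q h, in matrix form
        B = Q1_inv.T @ V @ Q2_inv       # Q^{-T} v, in matrix form
        t1_q1 += A @ A.T
        t2_q1 += B @ B.T
        t1_q2 += A.T @ A
        t2_q2 += B.T @ B
    # Same normalized, upper-triangular multiplicative step as the triangular branch above
    Q1 = Q1 - mu / np.linalg.norm(t1_q1 + t2_q1) * np.triu(t1_q1 - t2_q1) @ Q1
    Q2 = Q2 - mu / np.linalg.norm(t1_q2 + t2_q2) * np.triu(t1_q2 - t2_q2) @ Q2

P = np.kron(Q1.T @ Q1, Q2.T @ Q2)      # full preconditioner, only formed for the check
print(np.abs(P - np.linalg.inv(np.kron(H1, H2))).max())
```

Each factor only ever touches $m \times m$ or $n \times n$ matrices; the full $6 \times 6$ preconditioner is built only for the final comparison against $(H_1 \otimes H_2)^{-1}$.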

Let's hop on to the results :)

[Figure: MNIST_OPTIMS, decision boundaries of KronPSGD, Adam and SGD on MNIST digits 0 vs 1 in 2D PCA space]

I trained a binary classifier on MNIST (digits 0 vs 1) and used PCA to project the data to 2D for visualization. The figure shows each optimizer's decision boundary in this reduced space.
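The post doesn't spell out how the boundary plots were produced, so here is one common approach, sketched under that assumption with synthetic stand-in data rather than MNIST: PCA via SVD gives the 2D coordinates, and a grid over that plane is mapped back to input space so that a trained classifier could be evaluated at each grid point to colour the regions.

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic stand-in for flattened images: 500 samples, 64 features, two classes
X = rng.normal(size=(500, 64))
y = (rng.random(500) < 0.5).astype(int)
X[y == 1, :5] += 3.0               # make the classes differ along a few features

# PCA via SVD of the centered data
mean = X.mean(axis=0)
U, S, Vt = np.linalg.svd(X - mean, full_matrices=False)
components = Vt[:2]                 # top-2 principal directions
Z = (X - mean) @ components.T       # 2D coordinates used for plotting

# A grid over the 2D plane, mapped back to input space so a classifier
# trained on the original features could be evaluated at every grid point
xs = np.linspace(Z[:, 0].min(), Z[:, 0].max(), 100)
ys = np.linspace(Z[:, 1].min(), Z[:, 1].max(), 100)
grid_2d = np.stack(np.meshgrid(xs, ys), axis=-1).reshape(-1, 2)
grid_inputs = grid_2d @ components + mean
print(Z.shape, grid_inputs.shape)
```

Mapping back through only two components discards everything outside the PCA plane, so the plotted boundary is a 2D slice of the classifier, not the full decision surface.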

KronPSGD shows a sharper, better-defined decision boundary between digits 0 and 1, with less uncertainty in the transition region than Adam and SGD.

Adam, in contrast, shows more speckled, noisy regions, while the red and blue regions are more uniform with KronPSGD.

Wanna test it out yourself? Here is where you can find the impl 🥂.

Follow me on X for more implementations like this, see ya!