date-created: 2024-06-28 08:47:43 date-modified: 2024-06-28 08:49:55

Understanding L2 regularization, Weight decay and AdamW

anchored to 116.00_anchor_machine_learning requires / attaches to 116.14_deep_learning

Source: link


What is regularization?

In simple words, regularization helps reduce over-fitting of the model to the training data. There are many regularization strategies.

The major regularization techniques used in practice are:

  • L2 Regularization
  • L1 Regularization
  • Data Augmentation
  • Dropout
  • Early Stopping

In L2 regularization, an extra term, often referred to as the regularization term, is added to the loss function of the network.

Consider the following cross-entropy loss function (without regularization):

$$\text{loss} = -\frac{1}{m}\sum_{i=1}^{m}\left(y^{(i)}\log(\hat{y}^{(i)}) + (1-y^{(i)})\log(1-\hat{y}^{(i)})\right)$$

To apply L2 regularization, we add the following term to the loss function above:

$$\frac{\lambda}{2m}\sum_{w} w^{2}$$

where λ is a hyperparameter of the model known as the regularization parameter. Being a hyperparameter, λ is not learned during training but is tuned manually by the user.

After applying the regularization term to our original loss function:

$$\text{finalLoss} = -\frac{1}{m}\sum_{i=1}^{m}\left(y^{(i)}\log(\hat{y}^{(i)}) + (1-y^{(i)})\log(1-\hat{y}^{(i)})\right) + \frac{\lambda}{2m}\sum_{w} w^{2}$$

or, more compactly:

$$\text{finalLoss} = \text{loss} + \frac{\lambda}{2m}\sum_{w} w^{2}$$

or in simple code :

```python
# `lambda` is a reserved word in Python, hence lambda_
# (the 1/m factor from the equation above is folded into lambda_ here)
l2_reg_term = np.sum(np.square(weights)) / 2
final_loss = loss_fn(y, y_hat) + lambda_ * l2_reg_term
```

Note: all code equations are written in python, numpy notation.
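As a runnable sketch of the formula above, here is the regularized loss computed end to end with numpy. The helper name `l2_regularized_loss` and the toy labels, predictions, and weights are all made up for illustration:

```python
import numpy as np

def l2_regularized_loss(y, y_hat, weights, lambda_):
    """Cross-entropy loss plus the L2 penalty (lambda / 2m) * sum(w^2)."""
    m = y.shape[0]
    cross_entropy = -np.mean(y * np.log(y_hat) + (1 - y) * np.log(1 - y_hat))
    l2_reg_term = np.sum(np.square(weights)) / (2 * m)
    return cross_entropy + lambda_ * l2_reg_term

# toy data: three examples, two weights
y = np.array([1.0, 0.0, 1.0])
y_hat = np.array([0.9, 0.2, 0.8])
weights = np.array([0.5, -0.3])

loss = l2_regularized_loss(y, y_hat, weights, lambda_=0.1)
```

Setting `lambda_=0` recovers the plain cross-entropy loss; any positive `lambda_` adds a penalty that grows with the magnitude of the weights.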

Consequently, the weight update step for vanilla SGD is going to look something like this:

```python
w = w - learning_rate * grad_w - learning_rate * lambda_ * grad(l2_reg_term, w)
# since grad(l2_reg_term, w) = w, this simplifies to:
w = w - learning_rate * grad_w - learning_rate * lambda_ * w
```

Note: assume that grad_w is the gradient of the model's loss with respect to the model's weights.

Note: assume that grad(a, b) computes the gradient of a with respect to b.

In major deep-learning libraries, L2 regularization is implemented by adding lambda_ * w to the gradients, rather than by actually changing the loss function.
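To check that the two formulations coincide for vanilla SGD, here is a small numeric sketch (toy numbers; a fixed grad_w stands in for the real loss gradient):

```python
import numpy as np

learning_rate, lambda_ = 0.1, 0.01
w = np.array([1.0, -2.0, 0.5])
grad_w = np.array([0.3, -0.1, 0.2])   # pretend gradient of the unregularized loss

# L2 regularization: fold lambda_ * w into the gradient, then do a plain SGD step
w_l2 = w - learning_rate * (grad_w + lambda_ * w)

# Weight decay: plain gradient step, then subtract learning_rate * lambda_ * w
w_wd = w - learning_rate * grad_w - learning_rate * lambda_ * w

assert np.allclose(w_l2, w_wd)  # identical for vanilla SGD
```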

Weight Decay :

In weight decay we do not modify the loss function; the loss function stays the same and we modify the update step instead.

The loss remains:

```python
final_loss = loss_fn(y, y_hat)
```

During the parameter update:

```python
w = w - learning_rate * grad_w - learning_rate * lambda_ * w
```

Tip: The major difference between L2 regularization and weight decay is that while the former modifies the gradients to add lambda_ * w, weight decay does not modify the gradients at all; instead it subtracts learning_rate * lambda_ * w from the weights in the update step.

A weight decay update is going to look like this:

$$w_{t+1} = w_t - \eta \cdot \nabla_w \text{loss} - \eta \lambda w_t = (1 - \eta\lambda)\, w_t - \eta \cdot \nabla_w \text{loss}$$

where η is the learning rate. In this equation we see how we subtract a little portion of the weight at each step, hence the name decay.

Important: From the equations above, weight decay and L2 regularization may seem the same, and they are in fact equivalent for vanilla SGD; but as soon as we add momentum, or use a more sophisticated optimizer like Adam, L2 regularization and weight decay become different.

Weight Decay != L2 regularization

SGD with Momentum :

To prove this point, let's first take a look at SGD with momentum.

In SGD with momentum the gradients are not directly subtracted from the weights in the update step.

  • First, we calculate a moving average of the gradients.
  • Then, we subtract the moving average from the weights.

For L2 regularization the steps will be :

```python
# Compute gradients (L2 term folded in)
gradients = grad_w + lambda_ * w
# Compute the moving average
Vdw = beta * Vdw + (1 - beta) * gradients
# Update the weights of the model
w = w - learning_rate * Vdw
```

Now, the weight decay update will look like:

```python
# Compute gradients
gradients = grad_w
# Compute the moving average
Vdw = beta * Vdw + (1 - beta) * gradients
# Update the weights of the model
w = w - learning_rate * Vdw - learning_rate * lambda_ * w
```

Note: Vdw is a moving average of the gradients of w. It starts at 0 and at each step is updated using the formula for Vdw given above. beta is a hyperparameter.
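To see the two variants drift apart under momentum, here is a small runnable sketch (toy numbers; a constant grad_w stands in for the real loss gradient):

```python
import numpy as np

learning_rate, lambda_, beta = 0.1, 0.01, 0.9
w_l2 = w_wd = np.array([1.0, -2.0])
Vdw_l2 = Vdw_wd = np.zeros(2)
grad_w = np.array([0.3, -0.1])         # pretend gradient of the unregularized loss

for _ in range(5):                     # run a few identical steps for both variants
    # L2 regularization: lambda_ * w enters the moving average
    Vdw_l2 = beta * Vdw_l2 + (1 - beta) * (grad_w + lambda_ * w_l2)
    w_l2 = w_l2 - learning_rate * Vdw_l2

    # Weight decay: the moving average sees only grad_w; decay is applied directly
    Vdw_wd = beta * Vdw_wd + (1 - beta) * grad_w
    w_wd = w_wd - learning_rate * Vdw_wd - learning_rate * lambda_ * w_wd

# With momentum the two variants no longer coincide
assert not np.allclose(w_l2, w_wd)
```

The momentum buffer only slowly absorbs the lambda_ * w term in the L2 variant, while weight decay applies the full learning_rate * lambda_ * w shrinkage at every step.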

Adam :

This difference is much more visible with the Adam optimizer. Adam computes adaptive learning rates for each parameter: it stores a moving average of past gradients (Vdw) and a moving average of past squared gradients (Sdw). These moving averages are computed as follows:

```python
Vdw = beta1 * Vdw + (1 - beta1) * gradients
Sdw = beta2 * Sdw + (1 - beta2) * np.square(gradients)
```

Note: similar to SGD with momentum, Vdw and Sdw are moving averages of the gradients and squared gradients of w. They start from 0 and at each step are updated with the formulas given above. beta1 and beta2 are hyperparameters.

and the update step is computed as :

```python
w = w - learning_rate * (Vdw / (np.sqrt(Sdw) + eps))
```

Note: eps is a hyperparameter added for numerical stability. Commonly, eps = 1e-08.
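Putting the formulas together, one Adam step can be sketched as follows. Bias correction is omitted to match the formulas above; the helper name `adam_step` and the toy numbers are illustrative:

```python
import numpy as np

def adam_step(w, grad, Vdw, Sdw, learning_rate=0.001,
              beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update (no bias correction, matching the text)."""
    Vdw = beta1 * Vdw + (1 - beta1) * grad
    Sdw = beta2 * Sdw + (1 - beta2) * np.square(grad)
    w = w - learning_rate * (Vdw / (np.sqrt(Sdw) + eps))
    return w, Vdw, Sdw

w, Vdw, Sdw = np.array([1.0, -2.0]), np.zeros(2), np.zeros(2)
grad = np.array([0.3, -0.1])           # pretend gradient
w, Vdw, Sdw = adam_step(w, grad, Vdw, Sdw)
```

Note how the per-parameter division by np.sqrt(Sdw) rescales each coordinate of the step, which is exactly why folding the regularization term into the gradients behaves differently here than in plain SGD.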

For L2 regularization the steps will be :

```python
# Compute gradients (L2 term folded in) and moving averages
gradients = grad_w + lambda_ * w
Vdw = beta1 * Vdw + (1 - beta1) * gradients
Sdw = beta2 * Sdw + (1 - beta2) * np.square(gradients)
# Update the parameters
w = w - learning_rate * (Vdw / (np.sqrt(Sdw) + eps))
```

For weight-decay the steps will be :

```python
# Compute gradients and moving averages
gradients = grad_w
Vdw = beta1 * Vdw + (1 - beta1) * gradients
Sdw = beta2 * Sdw + (1 - beta2) * np.square(gradients)
# Update the parameters, applying the decay directly
w = w - learning_rate * (Vdw / (np.sqrt(Sdw) + eps)) - learning_rate * lambda_ * w
```

The difference between L2 regularization and weight decay is clearly visible now.

In the case of L2 regularization we add lambda_ * w to the gradients, then compute moving averages of the gradients and their squares, before using both of them in the update.

Whereas the weight decay method leaves the gradients (and hence both moving averages) untouched, and simply subtracts learning_rate * lambda_ * w from each weight during the update.
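A small runnable comparison makes the divergence concrete (toy numbers, constant gradient, names illustrative):

```python
import numpy as np

learning_rate, lambda_ = 0.001, 0.01
beta1, beta2, eps = 0.9, 0.999, 1e-8
w_l2 = w_wd = np.array([1.0, -2.0])
V_l2 = V_wd = np.zeros(2)
S_l2 = S_wd = np.zeros(2)
grad_w = np.array([0.3, -0.1])         # pretend gradient of the unregularized loss

for _ in range(5):
    # Adam + L2: lambda_ * w flows through both moving averages
    g = grad_w + lambda_ * w_l2
    V_l2 = beta1 * V_l2 + (1 - beta1) * g
    S_l2 = beta2 * S_l2 + (1 - beta2) * np.square(g)
    w_l2 = w_l2 - learning_rate * (V_l2 / (np.sqrt(S_l2) + eps))

    # AdamW: the moving averages never see lambda_; decay is applied directly
    V_wd = beta1 * V_wd + (1 - beta1) * grad_w
    S_wd = beta2 * S_wd + (1 - beta2) * np.square(grad_w)
    w_wd = (w_wd - learning_rate * (V_wd / (np.sqrt(S_wd) + eps))
            - learning_rate * lambda_ * w_wd)

# The two weight trajectories have already diverged
assert not np.allclose(w_l2, w_wd)
```

In the L2 variant the penalty term is normalized away by the division by np.sqrt(Sdw), so weights with large gradient history are barely regularized; the weight decay variant shrinks every weight by the same relative amount regardless of its gradient history.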

After much experimentation, Ilya Loshchilov and Frank Hutter suggest in their paper Decoupled Weight Decay Regularization that we should use weight decay with Adam, and not the L2 regularization that classic deep learning libraries implement. This is what gave rise to AdamW.

In simple terms, AdamW is simply the Adam optimizer used with weight decay instead of classic L2 regularization.