Stochastic gradient descent

Rezana Dowra
Nov 27, 2022 · 5 min read



Stochastic Gradient Descent: The case of large m

Once the number n of dimensions is moderately large, the matrix operations needed to solve the normal equations require a prohibitive amount of computation (roughly cubic in n for the matrix inversion alone).

In general you want to use as much data as possible to learn more accurate models. For very large m, it can be problematic to spend O(mn) time computing the gradient in every single iteration of gradient descent.

One way to think about the competing methods of gradient descent and solving the normal equations is as a trade-off between the number of iterations and the computation required per iteration. Stochastic gradient descent pushes this trade-off further: instead of touching all m data points to compute the exact gradient, each iteration estimates it from a single randomly chosen point (or a small mini-batch), cutting the per-iteration cost from O(mn) to O(n).
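To make the contrast concrete, here is a minimal sketch in Python/NumPy (the data, step size, and function names are illustrative, not from the lecture notes): a full-batch gradient of the MSE costs O(mn) per iteration, while a stochastic step estimates the gradient from one random point for O(n).

```python
import numpy as np

def full_batch_gradient(w, X, y):
    """Gradient of MSE(w) = (1/m) * ||Xw - y||^2 over all m points: O(mn) per step."""
    m = X.shape[0]
    return (2.0 / m) * X.T @ (X @ w - y)

def stochastic_gradient(w, X, y, i):
    """Gradient estimate from the single data point i: O(n) per step."""
    return 2.0 * (X[i] @ w - y[i]) * X[i]

# Illustrative training loop (data, step size, and iteration count are made up).
rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 5))                       # m = 1000 points, n = 5 features
y = X @ np.array([1.0, -2.0, 0.5, 0.0, 3.0]) + 0.1 * rng.normal(size=1000)

w = np.zeros(5)
step = 0.01
for t in range(10_000):
    i = rng.integers(len(y))                         # one random point per iteration
    w -= step * stochastic_gradient(w, X, y, i)
```

Because the point is chosen uniformly at random, each stochastic gradient is a noisy but unbiased estimate of the full gradient, so many cheap steps can stand in for a few expensive ones.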

Solving the right problem

Increasing the number of dimensions

So far we’ve discussed computational issues — how to actually implement gradient descent so that it scales to very large data sets. We’ve been taking the function f as given (e.g., as MSE). How do we know that we’re minimising the “right” f?

One reason that minimising MSE might be the wrong optimisation problem is because of simple type-checking errors. The point of linear regression is to predict real-valued labels for data points. But what if we don’t want to predict a real value? For example, what if we want to predict a binary value, like whether or not a given email is spam? There turn out to be analogs of linear regression and the mean-squared error objective for several other prediction tasks. For example, for binary classification, the most common approach is called “logistic regression,” where the linear prediction functions we’ve been studying are replaced by functions bounded between 0 and 1.
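As a minimal illustration (a sketch; the function names and the spam example are just for concreteness), logistic regression passes the familiar linear score wᵀx through the sigmoid function so that the output always lies strictly between 0 and 1 and can be read as a probability:

```python
import numpy as np

def sigmoid(z):
    """Squash any real number into the open interval (0, 1)."""
    return 1.0 / (1.0 + np.exp(-z))

def predict_spam_probability(w, x):
    """Linear score w^T x mapped to a probability; classify as spam if it exceeds 0.5."""
    return sigmoid(w @ x)
```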

Encoding nonlinear relationships with extra dimensions

So far we have looked at tools directed towards linear functions, but can they also capture nonlinear relationships? The idea is to increase the number of dimensions of the data points.

With n = 1, consider mapping each data point x⁽ⁱ⁾ ∈ ℝ to a (d + 1)-dimensional vector: x ↦ x̂ = (1, x, x², . . . , xᵈ).
For example, if d = 4, a point with value 3 is mapped to the 5-tuple (1, 3, 9, 27, 81).

Now imagine solving the linear regression problem not with the original data points x¹, . . . , xᵐ , but with their expanded versions xˆ¹, . . . , xˆᵐ.

The result is a linear function in d + 1 dimensions, specified by coefficients w₀, . . . , w_d. The prediction of w for an expanded data point x̂⁽ⁱ⁾ is wᵀx̂⁽ⁱ⁾ = ∑ⱼ₌₀ᵈ wⱼ (x⁽ⁱ⁾)ʲ, which we can interpret as a nonlinear (in fact, polynomial) prediction for the original data point x⁽ⁱ⁾.
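A minimal sketch of this expansion (the helper names and the chosen degree are illustrative): map each scalar to the vector of its powers, then make predictions with an ordinary inner product in the expanded space.

```python
import numpy as np

def expand(x, d):
    """Map a scalar x to the expanded point x_hat = (1, x, x^2, ..., x^d)."""
    return np.array([x ** j for j in range(d + 1)])

def predict(w, x, d):
    """Nonlinear prediction for the original x: w^T x_hat = sum_j w_j * x^j."""
    return w @ expand(x, d)

print(expand(3.0, 4))   # [ 1.  3.  9. 27. 81.]
```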

Overfitting

Computational resources permitting, more data is always better. But are more features always better? The benefit is intuitively clear: a more expressive model can fit the data more accurately.

But there is a catch: with many features there is a strong risk of overfitting, that is, of learning a prediction function that is an artefact of the specific training data set rather than one that captures something more fundamental and works well outside the training data.

Thus there is a trade-off between expressivity and predictive power. To understand the issue, let’s return to the case of n = 1 and the map x ↦ x̂ = (1, x, x², . . . , xᵈ).

Suppose we take d = m, so that we are free to use polynomials with degree equal to the number of data points. Now we can get a mean-squared error of zero!

The reason is just polynomial interpolation: given m pairs (x¹, y¹), . . . , (xᵐ, yᵐ) with distinct x-values, there is always a polynomial of degree at most m that passes through all of them.

[Figure: a high-degree polynomial interpolating the data points, alongside a straight line that fits them closely]

Is an MSE of zero good? Not necessarily. The polynomial in the graph above is quite “squiggly,” and meanwhile there is a line that fits the points quite closely. If the true relationship between the x⁽ⁱ⁾s and y⁽ⁱ⁾s is indeed roughly linear (plus some noise), then the line will give much more accurate predictions on new data points x than the squiggly polynomial.

In machine learning terminology, the squiggly polynomial fails to “generalise.” Remember that the point of this exercise is to learn something “general,” meaning a prediction function that transcends the particular data set and remains approximately accurate even for data points that you’ve never seen.
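This failure to generalise is easy to reproduce numerically. The sketch below (made-up data; it uses NumPy's polyfit rather than solving the regression by hand) fits both a degree-(m − 1) interpolating polynomial and a plain line to a few roughly linear points, then compares their errors on fresh points from the same distribution.

```python
import numpy as np

rng = np.random.default_rng(1)
m = 8
x_train = np.sort(rng.uniform(-1, 1, size=m))
y_train = 2.0 * x_train + 0.1 * rng.normal(size=m)      # roughly linear, plus noise

squiggly = np.polyfit(x_train, y_train, deg=m - 1)       # interpolates: training MSE ~ 0
line = np.polyfit(x_train, y_train, deg=1)

x_new = rng.uniform(-1, 1, size=1000)                    # unseen data points
y_new = 2.0 * x_new + 0.1 * rng.normal(size=1000)

def test_mse(coeffs):
    return np.mean((np.polyval(coeffs, x_new) - y_new) ** 2)

print("test MSE, degree m-1 polynomial:", test_mse(squiggly))   # typically much larger
print("test MSE, straight line        :", test_mse(line))       # near the noise level
```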

Regularisation: The case of large n

In machine learning it is common to throw in every conceivable feature and make n as large as possible. With this approach it is easy to overfit, so overfitting is something you must guard against.

Philosophically, the solution is Occam’s razor: be biased towards simpler models, on the basis that they capture something more fundamental.

Regularisation is a concrete method for implementing the principle of Occam’s razor. The idea is to add a penalty term. For the case of linear regression, the new optimisation problem is to minimise:

MSE(w) + penalty(w)

where penalty(w) is increasing in the complexity of w. Thus a complex solution will be chosen over a simple one only if it leads to a big decrease in the mean-squared error.

L₂ Regularisation

There are many ways to define the penalty term. The most widely used one goes by many names: ridge regression, L₂ regularisation, or Tikhonov regularisation. They all mean that we set

penalty(w) = λ‖w‖₂²

where λ is a positive hyperparameter that lets you trade off a smaller MSE (favoured by small λ) against lower model complexity (favoured by larger λ).

That is, we identify “complex” functions with those that have large weights. To see how this addresses the overfitting problem, note that a degree-m polynomial passing through m points is likely to have all non-zero coefficients, including some large ones.
A linear function, written in the expanded coordinates, has all coefficients beyond w₀ and w₁ equal to zero, so it will be preferred over the squiggly polynomial when the data points have an approximately linear relationship (unless λ is very small). The penalty thus biases the learner towards simplicity.
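To see why this costs essentially nothing extra (a sketch; the factor of 2 depends on exactly how the penalty is written), note that the gradient of MSE(w) + λ‖w‖₂² is just the MSE gradient plus 2λw, so each (stochastic) gradient step needs only one additional vector operation:

```python
import numpy as np

def ridge_gradient(w, X, y, lam):
    """Gradient of MSE(w) + lam * ||w||_2^2: the plain MSE gradient plus 2 * lam * w."""
    m = X.shape[0]
    mse_grad = (2.0 / m) * X.T @ (X @ w - y)
    return mse_grad + 2.0 * lam * w
```

Equivalently, every update shrinks the weights slightly towards zero (often called “weight decay”), which is exactly the bias against large coefficients described above.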

In summary

  1. For many machine learning problems, replacing the basic gradient descent method by stochastic gradient descent is critical for big data. While gradient descent touches every data point every iteration, stochastic gradient descent uses as little as one data point in each iteration. Stochastic gradient descent is the dominant paradigm in modern machine learning (e.g., in most deep learning work).
  2. More data is always better, as long as you have the computational resources to handle it.
  3. More features (or dimensions) offer a trade-off, allowing more expressive power at the risk of overfitting the data set. Still, these days the prevailing trend is to include as many features as possible.
  4. Regularisation is key to pushing back on overfitting risk with high-dimensional data. The general idea is to trade off the “complexity” of the learned model (e.g., the magnitudes of weights of a linear prediction function) with its error on the data set.
    The goal is to learn the simplest model with good explanatory power, on the basis that this explanation is the most likely to generalise to unseen data points.
  5. Adding regularisation imposes essentially no extra computational demands on (stochastic) gradient descent.

Reference:

https://timroughgarden.org/s16/l/l6.pdf
