Cutting Your Losses: Loss Functions & the Sum of Squares Loss
Oftentimes, particularly in a regression framework, we are given a set of inputs $x$ (independent variables) and a set of outputs $y$ (dependent variables), and we want to devise a model function $f(x)$ that predicts the outputs given some inputs as best as possible. But what exactly does it mean for a model to predict “as best as possible”? To make the notion of how good a model is explicit, it is common to adopt a loss function.
The loss function is some function of the model’s prediction errors (a.k.a. residuals) at predicting the outputs given the inputs (the loss function is also often referred to as the cost function, as it makes explicit the “cost” of incorrect prediction). “Good” models of a dataset will have small prediction errors and therefore produce small loss function values. Determining the “best” model is equivalent to finding the model function that minimizes the loss function. A common choice for this loss function is the sum of squared errors (SSE) loss. If there are $N$ input-output pairs $(x_i, y_i)$, the SSE loss function is formally:

$$\text{SSE} = \sum_{i=1}^{N} \big(y_i - f(x_i)\big)^2$$
This formula states that, for each output predicted by the model, we determine how far the prediction is from the actual value (i.e., we subtract the two). Each of these individual distances is then squared, and the squares are added to give a single number indicating how well (or badly) the model function captures the structure of the data across all the datapoints. The “best” model under this loss is called the least sum of squares (LSS) solution.
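As a minimal sketch of the formula, here is the SSE loss computed directly from its definition. It is written in Python rather than the MATLAB used later in this post, and the function name `sse_loss`, the toy data, and the candidate line are all hypothetical:

```python
def sse_loss(x, y, f):
    """Sum of squared errors of model function f on paired data (x, y)."""
    # Residual = actual output minus predicted output; square each one and add them up
    return sum((y_i - f(x_i)) ** 2 for x_i, y_i in zip(x, y))


def candidate_line(x_i):
    """Hypothetical candidate model f(x) = 0.5 + 1.0 * x."""
    return 0.5 + 1.0 * x_i


# Hypothetical toy data
x = [0.0, 1.0, 2.0, 3.0]
y = [0.0, 2.0, 2.0, 4.0]

print(sse_loss(x, y, candidate_line))  # 1.0: four residuals of +/-0.5, each squared to 0.25
```

Passing the model as a function keeps the loss independent of the model's form, so the same helper scores lines, constants, or anything else.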
But why square the errors before summing them? At first, this seems somewhat unintuitive (or even ad hoc!). Surely there are other, more straightforward loss functions we can devise. An initial notion of just adding the errors leads to a dead end, because adding many positive and negative errors (i.e., errors resulting from data located above and below the model function) just cancels them out; we want our measure of errors to be all positive (or all negative). Therefore, another idea would be to take the absolute value of the errors before summing. It turns out this is a known loss function, called the sum of absolute errors (SAE) or sum of absolute deviations (SAD) loss function. Finding the “best” SAE/SAD model is called the least absolute error (LAE/LAD) solution, and such a solution was actually proposed decades before LSS. Though LAE is indeed used in contemporary methods (we’ll talk more about LAE later), the sum of squares loss function is far more popular in practice. Why does the sum of squares always make the cut?
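The cancellation problem is easy to see numerically. The Python sketch below (hypothetical data and function names) scores the same flat prediction with a plain sum of signed errors, the SAE loss, and the SSE loss:

```python
def raw_error_sum(y, y_hat):
    """Plain sum of signed errors: positive and negative errors cancel."""
    return sum(y_i - p_i for y_i, p_i in zip(y, y_hat))


def sae_loss(y, y_hat):
    """Sum of absolute errors (SAE/SAD)."""
    return sum(abs(y_i - p_i) for y_i, p_i in zip(y, y_hat))


def sse_loss(y, y_hat):
    """Sum of squared errors (SSE)."""
    return sum((y_i - p_i) ** 2 for y_i, p_i in zip(y, y_hat))


# Hypothetical data scattered above and below a flat prediction of 0
y = [2.0, -2.0, 1.0, -1.0]
y_hat = [0.0, 0.0, 0.0, 0.0]

print(raw_error_sum(y, y_hat))  # 0.0: errors cancel even though no prediction is correct
print(sae_loss(y, y_hat))       # 6.0
print(sse_loss(y, y_hat))       # 10.0
```

Both rectified losses register the misfit; they differ in how heavily they weight the larger errors (2.0 contributes 2 to the SAE but 4 to the SSE).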
Useful Interpretations of the Sum of Squares Loss for Linear Regression
Areas of squares
Figure 1 demonstrates a set of 2D data (blue dots) and the LSS linear function (black line) of the form

$$f(x) = \beta_0 + \beta_1 x$$

where the parameters $\beta_0$ (the offset of the line from $0$) and $\beta_1$ (the slope) have been estimated (via LSS) to “best” fit the data.
A helpful interpretation is to represent each squared error literally as the area of a square whose sides have the length of the corresponding error (Figure 2, red squares). In this interpretation, finding the LSS solution means finding the line that results in the smallest total red area. This interpretation is also useful for understanding the important regression metric known as the coefficient of determination, $R^2$, which indicates how well a linear model function explains or predicts a dataset. Imagine that instead of the line fit in Figures 1-2, we fit a simpler model that has no slope parameter and only a bias/offset parameter (Figure 3). In this case the simpler model only captures the mean value of the data along the y-dimension.
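The sum of squared errors of this mean-only model corresponds to the (unscaled) variance of the data. Let’s denote the total area of the green squares in Figure 3 as the total sum of squares (TSS) error, and the total area of the red squares in Figure 2 as the residual sum of squares (RSS) of the linear model fit:

$$\text{TSS} = \sum_{i=1}^{N} (y_i - \bar{y})^2, \qquad \text{RSS} = \sum_{i=1}^{N} \big(y_i - f(x_i)\big)^2$$

where $\bar{y}$ is the mean of the outputs. The $R^2$ metric is related to the red and green areas as follows:

$$R^2 = 1 - \frac{\text{RSS}}{\text{TSS}}$$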
If the linear model is doing a good job of fitting the data, then the variance of the model errors/residuals (the RSS term) will be small compared to the variance of the dataset (TSS), and the $R^2$ metric will be close to one. If the model is doing a poor job of fitting the data, then the variance of the residuals will approach that of the data itself, and the $R^2$ metric will be close to zero. (Note too that $R^2$ can take negative values when the RSS is larger than the TSS, indicating a very poor model. This cannot happen for an LSS line that includes an offset term, since that line always does at least as well as the mean-only model, but it can happen for models fit in other ways.)
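A minimal Python sketch of $R^2 = 1 - \text{RSS}/\text{TSS}$ reproduces the three regimes described above. The helper name `r_squared` and the toy predictions are hypothetical:

```python
def r_squared(y, y_hat):
    """Coefficient of determination: 1 - RSS / TSS."""
    y_mean = sum(y) / len(y)
    rss = sum((y_i - p_i) ** 2 for y_i, p_i in zip(y, y_hat))  # "red squares"
    tss = sum((y_i - y_mean) ** 2 for y_i in y)                # "green squares"
    return 1.0 - rss / tss


y = [1.0, 2.0, 3.0, 4.0]
print(r_squared(y, [1.0, 2.0, 3.0, 4.0]))  # perfect fit: 1.0
print(r_squared(y, [2.5, 2.5, 2.5, 2.5]))  # mean-only model: 0.0
print(r_squared(y, [4.0, 3.0, 2.0, 1.0]))  # worse than the mean: -3.0
```

The sketch assumes the outputs are not all identical; otherwise the TSS is zero and the ratio is undefined.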
A bar suspended by springs
We can gain some important insight into the least squares loss by developing its concepts within the framework of a physical system (Figure 4). In this formulation, a set of springs (red dashed lines, our errors $e_i$) attach a bar (solid black line, our linear function $f(x)$) to a set of anchors (blue datapoints, our outputs $y_i$). Note that in this formulation the springs are constrained to operate only along the vertical direction (y-dimension). This constraint is equivalent to saying that only our measurements of the dependent variable contain error, which is a common assumption in regression frameworks.
From Hooke’s Law, the force created by each spring on the bar is proportional to the distance (error) from the bar (linear function) to its corresponding anchor (datapoint):

$$F_i = k\,e_i = k\big(y_i - f(x_i)\big)$$

where $k$ is the spring constant.
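Further, there is a potential energy associated with each spring (datapoint). The total potential energy for the entire system is:

$$U = \sum_{i=1}^{N} \frac{1}{2} k\,e_i^2 = \frac{k}{2} \sum_{i=1}^{N} \big(y_i - f(x_i)\big)^2$$

which is proportional to the SSE loss. The bar settles at the position that minimizes this energy, where it is in static equilibrium: the net force and the net torque (taken about $x = 0$) on the bar are both zero:

$$\sum_{i=1}^{N} k\,e_i = 0, \qquad \sum_{i=1}^{N} k\,e_i\,x_i = 0$$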
From the first (zero-net-force) condition, we can solve for the bias parameter $\beta_0$ of the linear function that describes the position of the bar:

$$\beta_0 = \bar{y} - \beta_1 \bar{x}$$
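Here the bar over a symbol (e.g., $\bar{x}$, pronounced “x-bar”) denotes its average value. Plugging this expression into the second (zero-net-torque) condition, we discover that the slope of the line has an interesting interpretation related to the variance and covariance of the data:

$$\beta_1 = \frac{\sum_{i=1}^{N} (x_i - \bar{x})(y_i - \bar{y})}{\sum_{i=1}^{N} (x_i - \bar{x})^2} = \frac{\operatorname{cov}(x, y)}{\operatorname{var}(x)}$$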
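For completeness, here is a sketch of the algebra behind the offset and slope expressions. It starts from the two equilibrium conditions with the spring constant $k$ divided out and $f(x_i) = \beta_0 + \beta_1 x_i$ substituted in, and it uses $\sum_i (y_i - \bar{y}) = \sum_i (x_i - \bar{x}) = 0$ in the last step:

$$
\begin{aligned}
\sum_{i=1}^{N} (y_i - \beta_0 - \beta_1 x_i) = 0
  &\;\Rightarrow\; N\bar{y} - N\beta_0 - N\beta_1 \bar{x} = 0
  \;\Rightarrow\; \beta_0 = \bar{y} - \beta_1 \bar{x} \\
\sum_{i=1}^{N} x_i (y_i - \beta_0 - \beta_1 x_i) = 0
  &\;\Rightarrow\; \sum_{i=1}^{N} x_i \big[(y_i - \bar{y}) - \beta_1 (x_i - \bar{x})\big] = 0 \\
  &\;\Rightarrow\; \beta_1 = \frac{\sum_{i} x_i (y_i - \bar{y})}{\sum_{i} x_i (x_i - \bar{x})}
   = \frac{\sum_{i} (x_i - \bar{x})(y_i - \bar{y})}{\sum_{i} (x_i - \bar{x})^2}
\end{aligned}
$$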
The expressions for the parameters $\beta_0$ and $\beta_1$ tell us that, under the least squares linear regression framework, the average of the dependent variable is equal to a scaled version of the average of the independent variable plus an offset:

$$\bar{y} = \beta_0 + \beta_1 \bar{x}$$
Further, the scaling factor (the slope) is equal to the ratio of the covariance between the dependent and independent variables to the variance of the independent variable. Therefore, if $x$ and $y$ are positively correlated, the slope will be positive; if they are negatively correlated, the slope will be negative.
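The closed-form expressions translate directly into code. Below is a Python sketch (the helper name `fit_lss_line` and the data are hypothetical) that computes the slope as a covariance-to-variance ratio and the offset from the means:

```python
def fit_lss_line(x, y):
    """Closed-form least-squares line; returns (beta0, beta1)."""
    n = len(x)
    x_bar = sum(x) / n
    y_bar = sum(y) / n
    # Slope = cov(x, y) / var(x); the 1/n normalizations cancel in the ratio
    cov_xy = sum((x_i - x_bar) * (y_i - y_bar) for x_i, y_i in zip(x, y))
    var_x = sum((x_i - x_bar) ** 2 for x_i in x)
    beta1 = cov_xy / var_x
    # Offset from the zero-net-force condition: y_bar = beta0 + beta1 * x_bar
    beta0 = y_bar - beta1 * x_bar
    return beta0, beta1


x = [0.0, 1.0, 2.0, 3.0, 4.0]
y = [1.0, 3.0, 5.0, 7.0, 9.0]  # exactly y = 1 + 2x
print(fit_lss_line(x, y))  # (1.0, 2.0)
```

Reversing `y` (a negative correlation) flips the sign of the covariance and therefore of the slope, as described above.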
Because of these relationships, the LSS solution has a number of useful properties:
- The sum of the residuals under the LSS solution is zero (this is equivalent to the first zero-net-force condition above).
- Because of the previous property, the average residual of the LSS solution is zero.
- The covariance between the independent variables and the residuals is zero.
- The LSS solution always passes through the mean (center of mass) of the sample, the point $(\bar{x}, \bar{y})$.
- The LSS solution minimizes the variance of the residuals/model errors.
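These properties can be checked numerically. The following Python sketch uses hypothetical data, and `fit_lss_line` re-implements the closed-form formulas so that the block is self-contained. It fits a line and prints each quantity, which should be zero up to floating-point rounding:

```python
def fit_lss_line(x, y):
    """Closed-form LSS line: slope = cov(x, y) / var(x), offset from the means."""
    n = len(x)
    x_bar, y_bar = sum(x) / n, sum(y) / n
    beta1 = (sum((x_i - x_bar) * (y_i - y_bar) for x_i, y_i in zip(x, y))
             / sum((x_i - x_bar) ** 2 for x_i in x))
    return y_bar - beta1 * x_bar, beta1


def sse(x, y, beta0, beta1):
    """Sum of squared errors of the line beta0 + beta1 * x."""
    return sum((y_i - (beta0 + beta1 * x_i)) ** 2 for x_i, y_i in zip(x, y))


# Hypothetical noisy measurements
x = [-2.0, -1.0, 0.0, 1.0, 2.0, 3.0]
y = [-3.1, -0.4, 0.8, 2.9, 4.2, 6.8]

beta0, beta1 = fit_lss_line(x, y)
residuals = [y_i - (beta0 + beta1 * x_i) for x_i, y_i in zip(x, y)]
n = len(x)
x_bar, y_bar = sum(x) / n, sum(y) / n
e_bar = sum(residuals) / n

# Sum (and hence average) of the residuals: zero net force
print("sum of residuals:", sum(residuals))
# Covariance between the independent variable and the residuals
cov_x_e = sum((x_i - x_bar) * (e_i - e_bar) for x_i, e_i in zip(x, residuals)) / n
print("cov(x, residuals):", cov_x_e)
# Height of the line at x_bar minus y_bar: zero if the line passes through the mean
print("f(x_bar) - y_bar:", beta0 + beta1 * x_bar - y_bar)

# Sanity check: nudging either parameter away from the LSS values raises the SSE
best = sse(x, y, beta0, beta1)
nudges_increase = all(sse(x, y, beta0 + d0, beta1 + d1) > best
                      for d0, d1 in [(0.1, 0.0), (-0.1, 0.0), (0.0, 0.1), (0.0, -0.1)])
print("every nudge increases the SSE:", nudges_increase)
```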
Therefore, the least squares loss function directly relates the model residuals to how the independent and dependent variables covary. These relationships do not hold in general for other loss functions, such as the least absolute deviation loss.
What’s interesting is that the two constraint equations derived from the physical system above are also obtained through other analyses of linear regression, including formulating the LSS problem via maximum likelihood estimation (assuming Normally distributed errors) or via the method of moments.
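To sketch the maximum likelihood connection, assume that $y_i = \beta_0 + \beta_1 x_i + \epsilon_i$ with independent errors $\epsilon_i \sim \mathcal{N}(0, \sigma^2)$. This is a modeling assumption, not something the data guarantee. The log-likelihood of the parameters is then:

$$
\log \mathcal{L}(\beta_0, \beta_1) = -\frac{N}{2}\log\!\left(2\pi\sigma^2\right) - \frac{1}{2\sigma^2}\sum_{i=1}^{N}\big(y_i - \beta_0 - \beta_1 x_i\big)^2
$$

Only the second term depends on $\beta_0$ and $\beta_1$, so maximizing the likelihood is the same as minimizing the SSE. Writing $e_i = y_i - \beta_0 - \beta_1 x_i$ and setting the two partial derivatives to zero gives $\sum_i e_i = 0$ and $\sum_i x_i e_i = 0$, which are the same zero-net-force and zero-net-torque conditions as the spring system.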
There are many other reasons (or at least suggestions) why squared errors are often preferred over other rectifying functions of the errors (i.e., functions that make all errors positive or zero):
- The least squares solution can be derived in closed form, allowing simple analytic implementations and fast computation of model parameters.
- Unlike the LAE loss, the SSE loss is differentiable (i.e., smooth) everywhere, which allows model parameters to be estimated using straightforward, gradient-based optimization.
- Squared errors have deep ties to statistics and maximum likelihood estimation methods (as mentioned above), particularly when the errors are distributed according to the Golden Boy of statistical distributions, the Normal distribution.
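- There are a number of geometric and linear algebra theorems that support using least squares. For instance, the Gauss-Markov theorem states that if the errors of a linear model have zero mean and equal variance and are uncorrelated with one another, then the LSS solution gives the best (i.e., minimum-variance) linear unbiased estimator of the parameters $\beta_0$ and $\beta_1$. Notably, the theorem does not require the errors to be Normally distributed.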
- Squared functions have a long history of simplifying the calculus used throughout the physical sciences.
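To illustrate the first two points, the Python sketch below minimizes the SSE with plain gradient descent and compares the result with the closed-form solution. It uses the analytic gradient $\partial\,\text{SSE}/\partial\beta_0 = -2\sum_i e_i$ and $\partial\,\text{SSE}/\partial\beta_1 = -2\sum_i x_i e_i$. The learning rate, step count, data, and function names are hypothetical choices for this toy problem, not tuned or general-purpose settings:

```python
def sse_gradient(beta0, beta1, x, y):
    """Partial derivatives of SSE(beta0, beta1) = sum((y - beta0 - beta1 * x)^2)."""
    residuals = [y_i - (beta0 + beta1 * x_i) for x_i, y_i in zip(x, y)]
    d_beta0 = -2.0 * sum(residuals)
    d_beta1 = -2.0 * sum(x_i * e_i for x_i, e_i in zip(x, residuals))
    return d_beta0, d_beta1


def fit_by_gradient_descent(x, y, learning_rate=0.01, n_steps=2000):
    """Plain batch gradient descent on the SSE, starting from beta0 = beta1 = 0."""
    beta0, beta1 = 0.0, 0.0
    for _ in range(n_steps):
        g0, g1 = sse_gradient(beta0, beta1, x, y)
        beta0 -= learning_rate * g0
        beta1 -= learning_rate * g1
    return beta0, beta1


def fit_closed_form(x, y):
    """Closed-form LSS line for comparison: slope = cov(x, y) / var(x)."""
    n = len(x)
    x_bar, y_bar = sum(x) / n, sum(y) / n
    beta1 = (sum((x_i - x_bar) * (y_i - y_bar) for x_i, y_i in zip(x, y))
             / sum((x_i - x_bar) ** 2 for x_i in x))
    return y_bar - beta1 * x_bar, beta1


# Hypothetical noisy measurements
x = [-2.0, -1.0, 0.0, 1.0, 2.0, 3.0]
y = [-3.1, -0.4, 0.8, 2.9, 4.2, 6.8]

gd = fit_by_gradient_descent(x, y)
cf = fit_closed_form(x, y)
print(gd)
print(cf)
```

Both printed pairs should agree to many decimal places. A learning rate that is too large makes this iteration diverge; the closed form avoids that tuning entirely, which is one practical reason it is preferred for simple linear regression.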
The SSE loss does have a number of drawbacks as well. For instance, because each error is squared, any outliers in the dataset can dominate the parameter estimation process. For this reason, the LSS loss is said to lack robustness. Therefore, preprocessing of the dataset (e.g., removing or thresholding outlier values) may be necessary when using the LSS loss.
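A quick way to see this lack of robustness is the simplest possible model, a constant $f(x) = c$. The SSE-minimizing constant is the sample mean, whereas an SAE-minimizing constant is the sample median. The Python sketch below uses hypothetical measurements to show how a single outlier drags the mean but barely moves the median:

```python
import statistics

# Hypothetical repeated measurements of one quantity, then the same data plus a gross outlier
clean = [9.5, 10.0, 10.5, 9.75, 10.25]
with_outlier = clean + [100.0]

# For a constant model f(x) = c:
#   the SSE-minimizing c is the sample mean,
#   an SAE-minimizing c is the sample median.
for data in (clean, with_outlier):
    print(statistics.mean(data), statistics.median(data))
# 10.0 10.0
# 25.0 10.125
```

The same effect carries over to fitted lines, where one distant point can noticeably tilt the LSS slope.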
The code to produce the figures in this post is below.
```matlab
close all; clear
randn('state',12345);  % seed the legacy random number generator (reproducible data)

x = -5:5;                    % INPUTS
y = x + .5*randn(size(x));   % OUTPUTS

% PART 1 - LINEAR MODEL
% PLOT DATA
figure(1);
h1 = scatter(x,y,'filled');
xlim([min(x)-1 max(x)+1]);
ylim([min(y)-1 max(y)+1]);
axis square
xlabel('x'); ylabel('y');
fprintf('\nHere are a set of 2D data pairs...');
pause
clc

% FIND LSS SOLUTION
bias = ones(size(x));
% Least-squares solve via backslash (same estimates as regress,
% without requiring the Statistics Toolbox)
beta = [bias;x]' \ y';
intercept = beta(1);
slope = beta(2);
yHat = x*slope + intercept;

% PLOT MODEL FIT
hold on;
h2 = plot(x,yHat,'k-','Linewidth',2);
fprintf('\n...and the LSS linear function determined for the points.\n');
pause
clc

% CALCULATE PREDICTION ERROR
e = y - yHat;
posErrors = find(e>=0);
negErrors = setdiff(1:numel(x),posErrors);

% PLOT "SQUARED" ERRORS (each square has side length |e|)
cnt = 1;
for iP = 1:numel(posErrors)
    xs = [x(posErrors(iP))-e(posErrors(iP)), ...
          x(posErrors(iP)), ...
          x(posErrors(iP)), ...
          x(posErrors(iP))-e(posErrors(iP))];
    ys = [y(posErrors(iP))-e(posErrors(iP)), ...
          y(posErrors(iP))-e(posErrors(iP)), ...
          y(posErrors(iP)), ...
          y(posErrors(iP))];
    hS(cnt) = patch(xs,ys,'r');
    set(hS(cnt),'EdgeColor','r');
    set(hS(cnt),'FaceAlpha',.5);
    cnt = cnt+1;
end
for iN = 1:numel(negErrors)
    xs = [x(negErrors(iN))-e(negErrors(iN)), ...
          x(negErrors(iN)), ...
          x(negErrors(iN)), ...
          x(negErrors(iN))-e(negErrors(iN))];
    ys = [y(negErrors(iN)), ...
          y(negErrors(iN)), ...
          y(negErrors(iN))-e(negErrors(iN)), ...
          y(negErrors(iN))-e(negErrors(iN))];
    hS(cnt) = patch(xs,ys,'r');
    set(hS(cnt),'EdgeColor','r');
    set(hS(cnt),'FaceAlpha',.5);
    cnt = cnt+1;
end
uistack(h2,'top');
uistack(h1,'top');
fprintf('\nOne helpful interpretation is to represent the')
fprintf('\nsquared errors literally as the area spanned')
fprintf('\nin the space (red squares).\n')
fprintf('\nFinding the LSS solution is equivalent to minimizing')
fprintf('\nthe sum of the area of these squares.\n')
pause
clc

% PART 2 - TRIVIAL LINEAR MODEL
figure(2);
h1 = scatter(x,y,'filled');
xlim([min(x)-1 max(x)+1]);
ylim([min(y)-1 max(y)+1]);
axis square
xlabel('x'); ylabel('y');
yHat0 = mean(y).*ones(size(x));
e0 = y - yHat0;
posErrors0 = find(e0>=0);
negErrors0 = setdiff(1:numel(x),posErrors0);

% PLOT TRIVIAL MODEL FIT
hold on;
h2 = plot(x,yHat0,'k-','Linewidth',2);
fprintf('\nNow, imagine that we fit a simpler model to the datapoints')
fprintf('\nthat is a line with no slope parameter and an offset parameter.')
fprintf('\nIn this case we''re essentially fitting the mean of the data.')
pause
clc

% PLOT TRIVIAL MODEL "SQUARED" ERRORS
cnt = 1;
for iP = 1:numel(posErrors0)
    xs = [x(posErrors0(iP))-e0(posErrors0(iP)), ...
          x(posErrors0(iP)), ...
          x(posErrors0(iP)), ...
          x(posErrors0(iP))-e0(posErrors0(iP))];
    ys = [y(posErrors0(iP))-e0(posErrors0(iP)), ...
          y(posErrors0(iP))-e0(posErrors0(iP)), ...
          y(posErrors0(iP)), ...
          y(posErrors0(iP))];
    hS(cnt) = patch(xs,ys,'g');
    set(hS(cnt),'EdgeColor','g');
    set(hS(cnt),'FaceAlpha',.5);
    cnt = cnt+1;
end
for iN = 1:numel(negErrors0)
    xs = [x(negErrors0(iN))-e0(negErrors0(iN)), ...
          x(negErrors0(iN)), ...
          x(negErrors0(iN)), ...
          x(negErrors0(iN))-e0(negErrors0(iN))];
    ys = [y(negErrors0(iN)), ...
          y(negErrors0(iN)), ...
          y(negErrors0(iN))-e0(negErrors0(iN)), ...
          y(negErrors0(iN))-e0(negErrors0(iN))];
    hS(cnt) = patch(xs,ys,'g');
    set(hS(cnt),'EdgeColor','g');
    set(hS(cnt),'FaceAlpha',.5);
    cnt = cnt+1;
end
uistack(h2,'top');
uistack(h1,'top');
fprintf('\nTherefore the sum of squared residuals for this model is equal to the')
fprintf('\n(unscaled) variance of the data.\n')
pause
fprintf('\nThe ratio of the area of the red boxes in Figure 1 to the area of')
fprintf('\nthe green boxes in Figure 2 is related to the important metric known')
fprintf('\nas the coefficient of determination, R^2. Specifically:\n')
fprintf('\nR^2 = 1 - Red/Green\n')
fprintf('\nNote that as the linear model fit improves, the area of the red')
fprintf('\nboxes decreases and the value of R^2 approaches one.')
pause
clc

% PART 3 - "SPRINGS" INTERPRETATION
figure(3)
h1 = scatter(x,y,'filled');
xlim([min(x)-1 max(x)+1]);
ylim([min(y)-1 max(y)+1]);
hold on
h2 = plot(x,yHat,'k-','Linewidth',2);
axis square
xlabel('x'); ylabel('y');
h3 = line([x;x],[y;yHat],'color','r','linestyle','--','Linewidth',2);
uistack(h1,'top');
fprintf('\nIt can also be helpful to think of errors corresponding to')
fprintf('\nindividual datapoints as springs (Figure 3, red dashes)')
fprintf('\nattached to a suspended bar (black line)\n')
fprintf('\nIf the springs are limited to only operate in the y-direction,')
fprintf('\nthen the sum of the potential energies stored in the springs when')
fprintf('\nthe bar has reached its equilibrium position directly corresponds')
fprintf('\nto the sum of squares error function:\n')
fprintf('\ne = y - yHat;\nU = integral(ke)de = 1/2ke^2\n')
fprintf('\nif k = 1, then:\n')
fprintf('\nU = 1/2(y - yHat)^2, \nthe least squares error function\n')
pause
clc
close all; clear all; clc
```