# Why Linear Regression Estimates the Conditional Mean

Because you can never know too much about linear regression.

## Introduction

If you look at any textbook on linear regression, you will find that it says the following:

“Linear regression estimates the **conditional mean** of the response variable.”

This means that, for a given value of the predictor variable \(X\), linear regression will give you the **mean** value of the response variable \(Y\).

Now, in and of itself, it’s a pretty neat fact… but, *why* is it true?

Like me, you may have been tempted to take to Google for an answer. And, like me, you *may* have found the online explanations hard to follow.

This is my attempt to break down the explanation more simply.

## Recap on Linear Regression

Let’s begin with a quick recap on linear regression.

In (simple) linear regression, we are looking for a line of best fit to model the relationship between our predictor \(X\) and our response variable \(Y\).

The line of best fit takes the form of an equation

\[Y = \beta_{0} + \beta_{1}X\]

where \(\beta_{0}\) is the intercept and \(\beta_{1}\) is the slope coefficient.

To find the intercept and slope of the line of best fit, linear regression uses the least squares method, which seeks to minimise the sum of squared deviations between the \(n\) **observed** data points \(y_{1}...y_{n}\) and the **predicted** values, which we’ll call \(\hat{y}_{i}\):

\[\sum_{i=1}^{n} (y_{i} - \hat{y}_{i})^2\]
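As a quick illustration (a sketch using simulated data of my own, not from the text), we can let R’s built-in `lm()` find these coefficients and confirm that no nearby line does better on the sum of squared deviations:

```r
# simulated data for illustration only
set.seed(1)
x <- runif(50, 0, 10)
y <- 3 + 2 * x + rnorm(50, sd = 1.5)

# let lm() find the least-squares intercept and slope
fit <- lm(y ~ x)
b <- coef(fit)

# sum of squared deviations for any candidate (intercept, slope) pair
ssr <- function(b0, b1) sum((y - (b0 + b1 * x))^2)

# the least-squares line beats slightly perturbed lines
ssr(b[1], b[2]) < ssr(b[1] + 0.1, b[2])  # TRUE
ssr(b[1], b[2]) < ssr(b[1], b[2] - 0.1)  # TRUE
```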

And, as it turns out, the values for the coefficients that we obtain by minimising the sum of squared deviations *always* result in a line of best fit that estimates the conditional mean of the response variable \(Y\).

Why? Well, the “simple” answer is that it can be proved mathematically. That’s not a very satisfying or helpful answer though.

However, one thing I *do* think is helpful for understanding the “why” is exploring the sum of squared deviations in a slightly simpler context.

## The Sum of Squared Deviations Method

So far, we’ve talked about minimising the sum of squared deviations in the context of linear regression. But, minimising the sum of squared deviations is a general method that we can also apply in other contexts.

For instance, let’s generate a dataset of 1000 numbers, with a mean of ~20 and a standard deviation of 2.

```r
set.seed(8825)
sample_data <- rnorm(1000, mean = 20, sd = 2)
# confirm the mean is ~20
mean(sample_data)
```

`## [1] 19.92143`

Now, we could calculate the sum of the squared deviations of each of these data points from the mean…

`sum((sample_data - mean(sample_data))^2)`

`## [1] 4004.373`

…(which is the first step in calculating the standard deviation of the data)

And we could also calculate the sum of the squared deviations of these data points from any other value, such as the median, mode, or any other arbitrary value.

For instance, here’s what we get if we calculate the sum of squared deviations of each data point from the median:

`sum((sample_data - median(sample_data))^2)`

`## [1] 4008.821`

So now, let’s calculate the sum of the squared deviations using a variety of different values:

```r
# values from which to calculate deviations
dev_values <- c(0, 5, 10, 12, 15, 18, 19, 19.92, 21, 22, 25, 28, 30, 35, 40)
# pre-allocate a numeric vector for the results
squared_residuals <- rep(NA, length(dev_values))
# calculate the sum of squared deviations of the data from each value in dev_values
for (i in seq_along(dev_values)) {
  squared_residuals[i] <- sum((sample_data - dev_values[i])^2)
}
squared_residuals
```

```
## [1] 400867.938 226653.590 102439.242 66753.502 28224.893 7696.285
## [7] 4853.415 4004.375 5167.676 8324.806 29796.197 69267.588
## [13] 105581.849 231367.501 407153.153
```

Next, let’s plot the resulting sum of squared deviations obtained using each value:

```r
library(dplyr)    # for the pipe operator %>%
library(ggplot2)

data.frame(dev_values = dev_values, squared_residuals = squared_residuals) %>%
  ggplot(aes(dev_values, squared_residuals)) +
  stat_smooth(method = "lm",
              formula = y ~ poly(x, 2),
              se = FALSE,
              colour = "#FCC3B6",
              linetype = "dashed") +
  geom_point(col = "#C70039", size = 2.2, alpha = 0.7) +
  labs(title = "Sum of Squared Residuals (SSR) Loss Function",
       x = "Summary Value",
       y = "SSR") +
  theme_minimal() +
  # the Lato fonts are optional; drop this theme() call if they aren't installed
  theme(text = element_text(family = "Lato"),
        plot.title = element_text(family = "Lato Semibold", hjust = 0.5)) +
  scale_y_continuous(labels = scales::comma)
```

Notice that the value that gives us the *smallest* sum of squared deviations, the lowest point on our curve, turns out to be 19.92, the candidate closest to the mean of our dataset (19.92143)!

Now, this isn’t just a fun feature of our sample dataset; given any set of numbers \(x_{1}...x_{n}\), the value that results in the smallest sum of squared deviations will *always* be the mean.
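We don’t have to rely on a hand-picked grid of candidate values, either. As a quick check (a sketch regenerating the same simulated data as above), we can ask R’s built-in `optimize()` to search for the minimiser directly, and it lands on the sample mean:

```r
# same simulated data as above
set.seed(8825)
sample_data <- rnorm(1000, mean = 20, sd = 2)

# search numerically for the value that minimises the sum of squared deviations
ssr <- function(v) sum((sample_data - v)^2)
best <- optimize(ssr, interval = c(0, 40))

best$minimum       # ~19.92, i.e. the sample mean
mean(sample_data)
```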

And just in the same way, in linear regression, the predicted \(\hat{y}\) values that minimise the sum of squared deviations will always be the conditional mean of \(y\).
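To see the “conditional mean” claim concretely, here’s a small sketch (again with simulated data of my own choosing): when the predictor takes just two values, the fitted value from `lm()` at each \(x\) coincides with the mean of \(y\) among the observations with that \(x\):

```r
# simulated data where the predictor takes only two values
set.seed(42)
x <- rep(c(0, 1), each = 100)
y <- 5 + 3 * x + rnorm(200)

fit <- lm(y ~ x)

# the fitted value at each x equals the mean of y among observations with that x
all.equal(unname(predict(fit, data.frame(x = 0))), mean(y[x == 0]))  # TRUE
all.equal(unname(predict(fit, data.frame(x = 1))), mean(y[x == 1]))  # TRUE
```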

Now, this simulation might help you see how minimising the sum of squared deviations is equivalent to using the mean, but it still doesn’t explain *why* it’s the case.

For that, we need to look at the mathematical proof. Here, again, we’re going to focus on the slightly simpler use-case of minimising the sum of squares for a single set of values.

## Mathematical Proof: Background

When we calculate the sum of squared deviations between some sample data \(y_{1}...y_{n}\) and another value \(\hat{y}\), what we’re really doing is passing the data through a function:

\[f(y) = \sum_{i=1}^{n} (y_{i} - \hat{y})^2\]

And, in minimising the sum of squared deviations, our aim is to find the value of \(\hat{y}\) that minimises the output of the function.

Now, whenever we have a function whose output we want to minimise, we call the function a **loss** function, denoted as \(L(y)\).

So, we can write our sum of squared deviations function as this:

\[L(y) = \sum_{i=1}^{n} (y_{i} - \hat{y})^2\]

Whenever we want to find the value of \(\hat{y}\) that **minimises** a loss function, the way to solve this problem is by differentiation.

Why? Well, let’s take a look back at our plot, where we calculated the sum of squared deviations for different values of \(\hat{y}\).

The value that we want to find out, the one that minimises the sum of squared deviations, is the one at the lowest point of the curve, where the gradient of the curve is equal to zero.

And so, what we’re really doing here is asking, what value does my summary statistic take, at the point at which the **gradient** of the sum of squared deviations function is equal to zero?

And finding gradients? Well, that’s a job for differentiation!

So, we want to differentiate our loss function:

\[\displaystyle \frac{d}{d\hat{y}} \lbrace{L(y)}\rbrace = \frac{d}{d\hat{y}} \lbrace\sum_{i=1}^{n} (y_{i} - \hat{y})^2\rbrace\]

Differentiating the loss function gives us this:

\[\displaystyle \frac{dL}{d\hat{y}} = \sum_{i=1}^{n} -2(y_{i} - \hat{y})\]

It can be a little tricky to see what’s happened here, especially if you’re not used to differentiating expressions involving \(\sum\) symbols and \(y_{i}\) terms.

To make it a bit clearer what I’ve just done, I’m going to momentarily pause on differentiating our actual loss function and instead detour to a simpler problem, that of differentiating the equation \(y = (1 - \hat{y})^{2}\).

Now, we can write this equation like so:

\[\displaystyle y = (1 - \hat{y})(1 - \hat{y})\]

Which, expanded out, gives us the following:

\[\displaystyle y = (1 - 2\hat{y} + \hat{y}^{2})\]

Finally, differentiating the above gives us this:

\[\displaystyle \frac{dy}{d\hat{y}} = (2\hat{y} - 2)\]

which is also equivalent to this:

\[\displaystyle \frac{dy}{d\hat{y}} = -2(1 - \hat{y})\]

And the differentiation works pretty much the same way for our *actual* function, \(L(y) = \sum_{i=1}^{n} (y_{i} - \hat{y})^2\):

\[\displaystyle \frac{dL}{d\hat{y}} = \sum_{i=1}^{n} -2(y_{i} - \hat{y})\]
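If the algebra feels abstract, we can sanity-check this derivative numerically (a sketch with toy numbers of my own): the finite-difference slope of the loss function should match \(\sum -2(y_{i} - \hat{y})\) at any chosen \(\hat{y}\):

```r
# toy data, chosen arbitrarily
y <- c(2, 4, 9)
loss <- function(yhat) sum((y - yhat)^2)

yhat <- 3
# analytic gradient from the formula above
analytic <- sum(-2 * (y - yhat))
# central finite-difference approximation to the gradient
finite_diff <- (loss(yhat + 1e-6) - loss(yhat - 1e-6)) / 2e-6

analytic              # -12
round(finite_diff, 4) # -12
```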

Now, as mentioned earlier, to minimise the loss function, we need to find the value of \(\hat{y}\) when the gradient is zero, so let’s set this whole thing equal to zero:

\[\displaystyle 0 = \sum_{i=1}^{n} -2(y_{i} - \hat{y})\]

Divide both sides by -2 and we get this:

\[\displaystyle 0 = \sum_{i=1}^{n} (y_{i} - \hat{y})\]

Now, in the same way that \(3(5 - 4)\) is the same as \(3 \times 5 - 3 \times 4\), the sum of \(y_{i} - \hat{y}\) is the same as the sum of the \(y_{i}\) values minus the sum of \(\hat{y}\) added up \(n\) times:

\[\displaystyle 0 = \sum_{i=1}^{n} y_{i} - \sum_{i=1}^{n}\hat{y}\]

And, the sum of adding up \(\hat{y}\) n times can also be written like so:

\[\displaystyle 0 = \sum_{i=1}^{n} y_{i} - n\hat{y}\]
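That split of the sum is easy to verify on a toy vector (numbers of my own choosing):

```r
# toy check that the sum splits as claimed
y <- c(2, 4, 9)
yhat <- 5
n <- length(y)

sum(y - yhat) == sum(y) - n * yhat  # TRUE
```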

We want to find the value of \(\hat{y}\), so let’s rearrange the equation a little:

\[\displaystyle n\hat{y} = \sum_{i=1}^{n} y_{i}\]

Finally, let’s divide both sides by n to find the value of \(\hat{y}\).

\[\displaystyle \hat{y} = \frac{\sum_{i=1}^{n} y_{i}}{n}\]

And let’s look at what we’re left with here: \(\hat{y}\) is equal to the sum of all the \(y_{1}...y_{n}\) values in the dataset, divided by \(n\), the number of values in the dataset… otherwise known as the mean of the \(y\) values!
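As one last concrete check (toy numbers of my own choosing): for \(y = \{2, 4, 9\}\) the mean is 5, and the sum of squared deviations really is smallest there:

```r
# toy data: the mean is 5
y <- c(2, 4, 9)
ssr <- function(v) sum((y - v)^2)

ssr(mean(y))  # 26 -- the minimum
ssr(4)        # 29
ssr(6)        # 29
```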

And that’s that!