最小二乘法
我们已经推导出了穿过橄榄球形散点图的回归线方程。但并非所有散点图都是橄榄球形的,即使是线性的散点图也不一定。每个散点图是否都存在一条穿过它的“最佳”直线?如果是,我们能否仍然使用上一节中推导的斜率和截距公式,还是需要新的公式?
要解决这些问题,我们需要对“最佳”给出一个合理的定义。回顾一下,该直线的目的是在给定 $x$ 值的情况下“预测”或“估计” $y$ 的值。估计通常并不完美。每个估计值与真实值之间都存在一个“误差”。一条直线被认为是“最佳”的合理标准是:它在所有直线中具有最小的总体误差。
在本节中,我们将使这个标准更加精确,并看看能否在该标准下找到最佳直线。
from datascience import *
%matplotlib inline
path_data = '../../../assets/data/'
import matplotlib.pyplot as plots
plots.style.use('fivethirtyeight')
import numpy as np
def standard_units(any_numbers):
"Convert any array of numbers to standard units."
return (any_numbers - np.mean(any_numbers))/np.std(any_numbers)
def correlation(t, x, y):
return np.mean(standard_units(t.column(x))*standard_units(t.column(y)))
def slope(table, x, y):
r = correlation(table, x, y)
return r * np.std(table.column(y))/np.std(table.column(x))
def intercept(table, x, y):
a = slope(table, x, y)
return np.mean(table.column(y)) - a * np.mean(table.column(x))
def fit(table, x, y):
"""Return the height of the regression line at each x value."""
a = slope(table, x, y)
b = intercept(table, x, y)
return a * table.column(x) + b
我们的第一个示例是一个数据集,其中小说《小妇人》的每一章对应一行。目标是根据句号的数量来估计字符数(即字母、空格、标点符号等)。回顾一下,我们在本课程的第一讲中就尝试过这样做。
little_women = Table.read_table(path_data + 'little_women.csv')
little_women = little_women.move_to_start('Periods')
little_women.show(3)
<IPython.core.display.HTML object>little_women.scatter('Periods', 'Characters')
Scatterplot with 'Periods' on the x-axis and 'Characters' on the y-axis. The x-axis ranges from about 100 to 400. The y-axis ranges from about 10000 to 40000. There is a positive correlation with no outliers.为了探索这些数据,我们需要使用上一节中定义的函数 correlation、slope、intercept 和 fit。
correlation(little_women, 'Periods', 'Characters')
0.9229576895854816散点图非常接近线性,相关性超过 0.92。
估计中的误差
下图显示了我们在上一节中建立的散点图和直线。我们尚不知道这是否是所有直线中的最佳选择。我们首先必须精确地说出“最佳”的含义。
lw_with_predictions = little_women.with_column('Linear Prediction', fit(little_women, 'Periods', 'Characters'))
lw_with_predictions.scatter('Periods')
Scatterplot with 'Periods' on the x-axis. Two colors of data points are shown. In dark blue is 'Characters,' and in gold is 'Linear Predictions.' The dark blue data points are the same as in the previous graph and the gold data points exist with one per x value of the blue data points. The gold data points are in the center of the dark blue data points and clearly show that as the x value increases so does the y value.对应于散点图上的每个点,存在一个预测误差,计算方式为实际值减去预测值。它是该点与直线之间的垂直距离,如果点在直线下方,则该距离为负值。
actual = lw_with_predictions.column('Characters')
predicted = lw_with_predictions.column('Linear Prediction')
errors = actual - predicted
lw_with_predictions.with_column('Error', errors)
Periods | Characters | Linear Prediction | Error
189 | 21759 | 21183.6 | 575.403
188 | 22148 | 21096.6 | 1051.38
231 | 20558 | 24836.7 | -4278.67
195 | 25526 | 21705.5 | 3820.54
255 | 23395 | 26924.1 | -3529.13
140 | 14622 | 16921.7 | -2299.68
131 | 14431 | 16138.9 | -1707.88
214 | 22476 | 23358 | -882.043
337 | 33767 | 34056.3 | -289.317
185 | 18508 | 20835.7 | -2327.69
... (37 rows omitted)我们可以使用 slope 和 intercept 来计算拟合线的斜率和截距。下图显示了该线(浅蓝色)。四个点对应的误差以红色显示。这四个点没有什么特别之处,只是为了显示清晰而选择的。函数 lw_errors 接受斜率和截距(按此顺序)作为其参数并绘制图形。
lw_reg_slope = slope(little_women, 'Periods', 'Characters')
lw_reg_intercept = intercept(little_women, 'Periods', 'Characters')
sample = [[131, 14431], [231, 20558], [392, 40935], [157, 23524]]
def lw_errors(slope, intercept):
little_women.scatter('Periods', 'Characters')
xlims = np.array([50, 450])
plots.plot(xlims, slope * xlims + intercept, lw=2)
for x, y in sample:
plots.plot([x, x], [y, slope * x + intercept], color='r', lw=2)
print('Slope of Regression Line: ', np.round(lw_reg_slope), 'characters per period')
print('Intercept of Regression Line:', np.round(lw_reg_intercept), 'characters')
lw_errors(lw_reg_slope, lw_reg_intercept)
Slope of Regression Line: 87.0 characters per period
Intercept of Regression Line: 4745.0 characters
Scatterplot with 'Periods' on the x-axis and 'Characters' on the y-axis. The same dark blue data points are reproduced here, along with a blue line that follows their shape. Red lines are drawn between some of the data points and the the light blue line. Some of the data points are above the light blue line and some of the data points are below it. The red lines have similar, but slightly varrying lengths.如果我们使用不同的直线来生成估计值,误差将会不同。下图显示了如果我们使用另一条直线进行估计,误差会有多大。第二个图显示了使用一条完全不合理的直线所得到的大误差。
lw_errors(50, 10000)
Scatterplot with 'Periods' on the x-axis and 'Characters' on the y-axis. The same dark blue data points are reproduced here, along with a blue line that does not follow their shape as well as the light blue line in the previous graph, but still has a positive slope. The light blue line here appears to be low for for higher x values. Red lines are again drawn between some of the data points and the the light blue line. Some of the data points are above the light blue line and some of the data points are below it. Some of the red lines are very short, and some of them are quite long.lw_errors(-100, 50000)
Scatterplot with 'Periods' on the x-axis and 'Characters' on the y-axis. The same dark blue data points are reproduced here, along with a different light blue line that does not follow their shape at all; it has a negative slope. The light blue line here appears to be high for for lower x values and low for higher x values. Red lines are again drawn between some of the data points and the the light blue line. Some of the data points are above the light blue line and some of the data points are below it. Most of the red lines are long, except towards the center of the graph.均方根误差
我们现在需要的是一个衡量误差大致大小的总体指标。你会认识到构建这个指标的方法——这与我们构建标准差的方法完全相同。
如果你使用任意一条直线来计算估计值,那么你的误差有些可能是正的,有些可能是负的。为了避免在度量误差的大致大小时正负抵消,我们将取平方误差的均值,而不是误差本身的均值。
估计的均方误差是衡量平方误差大致大小的指标,但正如我们之前指出的,其单位难以解释。取平方根得到均方根误差(rmse),它与被预测变量的单位相同,因此更容易理解。
最小化均方根误差
我们到目前为止的观察可以总结如下:
- 要基于 $x$ 得到 $y$ 的估计值,你可以使用任何你想要的直线。
- 每条直线都有一个估计的均方根误差。
- “更好”的直线具有更小的误差。
是否存在一条“最佳”直线?也就是说,是否存在一条在所有直线中最小化均方根误差的直线?
为了回答这个问题,我们首先定义一个函数 lw_rmse 来计算穿过《小妇人》散点图的任意直线的均方根误差。该函数接受斜率和截距(按此顺序)作为其参数。
def lw_rmse(slope, intercept):
lw_errors(slope, intercept)
x = little_women.column('Periods')
y = little_women.column('Characters')
fitted = slope * x + intercept
mse = np.mean((y - fitted) ** 2)
print("Root mean squared error:", mse ** 0.5)
lw_rmse(50, 10000)
Root mean squared error: 4322.167831766537
Scatterplot with 'Periods' on the x-axis and 'Characters' on the y-axis. The same dark blue data points are reproduced here, along with a blue line that does not follow their shape as well as the light blue line in the previous graph, but still has a positive slope. The light blue line here appears to be low for for higher x values. Red lines are again drawn between some of the data points and the the light blue line. Some of the data points are above the light blue line and some of the data points are below it. Some of the red lines are very short, and some of them are quite long.lw_rmse(-100, 50000)
Root mean squared error: 16710.11983735375
Scatterplot with 'Periods' on the x-axis and 'Characters' on the y-axis. The same dark blue data points are reproduced here, along with a different light blue line that does not follow their shape at all; it has a negative slope. The light blue line here appears to be high for for lower x values and low for higher x values. Red lines are again drawn between some of the data points and the the light blue line. Some of the data points are above the light blue line and some of the data points are below it. Most of the red lines are long, except towards the center of the graph.不出所料,不好的直线有较大的 rmse 值。但如果我们选择的斜率和截距接近回归线的斜率和截距,rmse 会小得多。
lw_rmse(90, 4000)
Root mean squared error: 2715.5391063834586
Scatterplot with 'Periods' on the x-axis and 'Characters' on the y-axis. The same dark blue data points are reproduced here, along with a blue line that follows their shape. Red lines are drawn between some of the data points and the the light blue line. Some of the data points are above the light blue line and some of the data points are below it. The red lines have similar, but slightly varrying lengths.以下是回归线对应的均方根误差。根据一个非凡的数学事实,没有其他直线能比这条线更好。
- 回归线是唯一一条在所有直线中使估计均方误差最小的直线。
lw_rmse(lw_reg_slope, lw_reg_intercept)
Root mean squared error: 2701.690785311856
Scatterplot with 'Periods' on the x-axis and 'Characters' on the y-axis. The same dark blue data points are reproduced here, along with a blue line that follows their shape. Red lines are drawn between some of the data points and the the light blue line. Some of the data points are above the light blue line and some of the data points are below it. The red lines have similar, but slightly varrying lengths.这个结论的证明需要抽象数学,超出了本课程的范围。另一方面,我们有一个强大的工具——Python——能够轻松执行大规模数值计算。因此,我们可以使用 Python 来确认回归线最小化了均方误差。
数值优化
首先注意,最小化均方根误差的直线同时也是最小化平方误差的直线。开平方对于最小化没有影响。因此,我们可以省去一步计算,直接最小化均方误差(mse)。
我们试图根据《小妇人》各章节的句号数量 ($x$) 来预测字符数 ($y$)。如果我们使用直线
$$ \mbox{prediction} ~=~ ax + b $$
它的 mse 将取决于斜率 $a$ 和截距 $b$。函数 lw_mse 接受斜率和截距作为其参数,并返回相应的 mse。
def lw_mse(any_slope, any_intercept):
x = little_women.column('Periods')
y = little_women.column('Characters')
fitted = any_slope*x + any_intercept
return np.mean((y - fitted) ** 2)
让我们检查一下 lw_mse 对于回归线的均方根误差是否得到了正确的答案。记住 lw_mse 返回的是均方误差,所以我们需要取平方根才能得到 rmse。
lw_mse(lw_reg_slope, lw_reg_intercept)**0.5
2701.690785311856这与我们之前使用 lw_rmse 得到的值相同:
lw_rmse(lw_reg_slope, lw_reg_intercept)
Root mean squared error: 2701.690785311856
Scatterplot with 'Periods' on the x-axis and 'Characters' on the y-axis. The same dark blue data points are reproduced here, along with a blue line that follows their shape. Red lines are drawn between some of the data points and the the light blue line. Some of the data points are above the light blue line and some of the data points are below it. The red lines have similar, but slightly varrying lengths.你也可以确认 lw_mse 对于其他斜率和截距也能返回正确的值。例如,以下是我们之前尝试的那条非常糟糕的直线的 rmse。
lw_mse(-100, 50000)**0.5
16710.11983735375以下是接近回归线的某条直线的 rmse。
lw_mse(90, 4000)**0.5
2715.5391063834586如果我们尝试不同的值,可以通过试错找到低误差的斜率和截距,但这需要一些时间。幸运的是,有一个 Python 函数可以为我们完成所有的试错工作。
minimize 函数可用于查找使函数返回最小值的参数。Python 使用类似的试错方法,沿着能够逐步降低输出值的方向进行调整。
minimize 的参数是一个本身接受数值参数并返回数值的函数。例如,函数 lw_mse 接受数值斜率和截距作为参数,并返回相应的 mse。
调用 minimize(lw_mse) 返回一个由最小化 mse 的斜率和截距组成的数组。这些最小化值是通过智能试错得到的极好近似值,而不是基于公式的精确值。
best = minimize(lw_mse)
best
array([ 86.97784117, 4744.78484535])这些值与之前使用 slope 和 intercept 函数计算的值相同。由于 minimize 的非精确性质,我们看到一些小的偏差,但值本质上是相同的。
print("slope from formula: ", lw_reg_slope)
print("slope from minimize: ", best.item(0))
print("intercept from formula: ", lw_reg_intercept)
print("intercept from minimize: ", best.item(1))
slope from formula: 86.97784125829821
slope from minimize: 86.97784116615884
intercept from formula: 4744.784796574928
intercept from minimize: 4744.784845352655
最小二乘线
因此,我们不仅发现回归线最小化了均方误差,而且最小化均方误差也恰好给出了回归线。回归线是唯一最小化均方误差的直线。
这就是为什么回归线有时被称为“最小二乘线”。