最近邻
在本节中,我们将介绍“最近邻”分类方法。现在请先专注于理解概念,如果某些代码看起来有些复杂,不必担心。在本章后面,我们将看到如何将这些概念组织成执行分类的代码。
import matplotlib
#matplotlib.use('Agg')
path_data = '../../../assets/data/'
from datascience import *
%matplotlib inline
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
import math
import scipy.stats as stats
plt.style.use('fivethirtyeight')
def standard_units(x):
return (x - np.mean(x))/np.std(x)
def distance(point1, point2):
"""The distance between two arrays of numbers."""
return np.sqrt(np.sum((point1 - point2)**2))
def all_distances(training, point):
"""The distance between p (an array of numbers) and the numbers in row i of attribute_table."""
attributes = training.drop('Class')
def distance_from_point(row):
return distance(point, np.array(row))
return attributes.apply(distance_from_point)
def table_with_distances(training, point):
"""A copy of the training table with the distance from each row to array p."""
return training.with_column('Distance', all_distances(training, point))
def closest(training, point, k):
"""A table containing the k closest rows in the training table to array p."""
with_dists = table_with_distances(training, point)
sorted_by_distance = with_dists.sort('Distance')
topk = sorted_by_distance.take(np.arange(k))
return topk
def majority(topkclasses):
"""1 if the majority of the "Class" column is 1s, and 0 otherwise."""
ones = topkclasses.where('Class', are.equal_to(1)).num_rows
zeros = topkclasses.where('Class', are.equal_to(0)).num_rows
if ones > zeros:
return 1
else:
return 0
def classify(training, p, k):
"""Classify an example with attributes p using k-nearest neighbor classification with the given training table."""
closestk = closest(training, p, k)
topkclasses = closestk.select('Class')
return majority(topkclasses)
慢性肾脏疾病
让我们通过一个示例来学习。我们将使用一个为帮助医生诊断慢性肾脏疾病(CKD)而收集的数据集。数据集中的每一行代表一位过去接受过治疗且诊断结果已知的患者。对于每位患者,我们有一系列血液检测的测量值。我们希望找出哪些测量值对诊断 CKD 最有用,并开发一种方法,根据未来的患者的血液检测结果将其分类为“患有 CKD”或“未患有 CKD”。
ckd = Table.read_table(path_data + 'ckd.csv').relabeled('Blood Glucose Random', 'Glucose')
ckd
Age | Blood Pressure | Specific Gravity | Albumin | Sugar | Red Blood Cells | Pus Cell | Pus Cell clumps | Bacteria | Glucose | Blood Urea | Serum Creatinine | Sodium | Potassium | Hemoglobin | Packed Cell Volume | White Blood Cell Count | Red Blood Cell Count | Hypertension | Diabetes Mellitus | Coronary Artery Disease | Appetite | Pedal Edema | Anemia | Class
48 | 70 | 1.005 | 4 | 0 | normal | abnormal | present | notpresent | 117 | 56 | 3.8 | 111 | 2.5 | 11.2 | 32 | 6700 | 3.9 | yes | no | no | poor | yes | yes | 1
53 | 90 | 1.02 | 2 | 0 | abnormal | abnormal | present | notpresent | 70 | 107 | 7.2 | 114 | 3.7 | 9.5 | 29 | 12100 | 3.7 | yes | yes | no | poor | no | yes | 1
63 | 70 | 1.01 | 3 | 0 | abnormal | abnormal | present | notpresent | 380 | 60 | 2.7 | 131 | 4.2 | 10.8 | 32 | 4500 | 3.8 | yes | yes | no | poor | yes | no | 1
68 | 80 | 1.01 | 3 | 2 | normal | abnormal | present | present | 157 | 90 | 4.1 | 130 | 6.4 | 5.6 | 16 | 11000 | 2.6 | yes | yes | yes | poor | yes | no | 1
61 | 80 | 1.015 | 2 | 0 | abnormal | abnormal | notpresent | notpresent | 173 | 148 | 3.9 | 135 | 5.2 | 7.7 | 24 | 9200 | 3.2 | yes | yes | yes | poor | yes | yes | 1
48 | 80 | 1.025 | 4 | 0 | normal | abnormal | notpresent | notpresent | 95 | 163 | 7.7 | 136 | 3.8 | 9.8 | 32 | 6900 | 3.4 | yes | no | no | good | no | yes | 1
69 | 70 | 1.01 | 3 | 4 | normal | abnormal | notpresent | notpresent | 264 | 87 | 2.7 | 130 | 4 | 12.5 | 37 | 9600 | 4.1 | yes | yes | yes | good | yes | no | 1
73 | 70 | 1.005 | 0 | 0 | normal | normal | notpresent | notpresent | 70 | 32 | 0.9 | 125 | 4 | 10 | 29 | 18900 | 3.5 | yes | yes | no | good | yes | no | 1
73 | 80 | 1.02 | 2 | 0 | abnormal | abnormal | notpresent | notpresent | 253 | 142 | 4.6 | 138 | 5.8 | 10.5 | 33 | 7200 | 4.3 | yes | yes | yes | good | no | no | 1
46 | 60 | 1.01 | 1 | 0 | normal | normal | notpresent | notpresent | 163 | 92 | 3.3 | 141 | 4 | 9.8 | 28 | 14600 | 3.2 | yes | yes | no | good | no | no | 1
... (148 rows omitted)一些变量是分类变量(如“abnormal”等词语),一些是定量变量。定量变量都有不同的尺度。我们希望通过肉眼进行比较和距离估计,因此我们只选择少数几个变量,并以标准单位进行处理。这样我们就不必担心每个不同变量的尺度问题了。
ckd = Table().with_columns(
'Hemoglobin', standard_units(ckd.column('Hemoglobin')),
'Glucose', standard_units(ckd.column('Glucose')),
'White Blood Cell Count', standard_units(ckd.column('White Blood Cell Count')),
'Class', ckd.column('Class')
)
ckd
Hemoglobin | Glucose | White Blood Cell Count | Class
-0.865744 | -0.221549 | -0.569768 | 1
-1.45745 | -0.947597 | 1.16268 | 1
-1.00497 | 3.84123 | -1.27558 | 1
-2.81488 | 0.396364 | 0.809777 | 1
-2.08395 | 0.643529 | 0.232293 | 1
-1.35303 | -0.561402 | -0.505603 | 1
-0.413266 | 2.04928 | 0.360623 | 1
-1.28342 | -0.947597 | 3.34429 | 1
-1.10939 | 1.87936 | -0.409356 | 1
-1.35303 | 0.489051 | 1.96475 | 1
... (148 rows omitted)让我们特别关注两列:血红蛋白水平(患者血液中)和血糖水平(一天中随机时间;非空腹采血)。
我们将绘制散点图来可视化两个变量之间的关系。蓝色点代表 CKD 患者;金色点代表非 CKD 患者。哪些医学检测结果似乎指示 CKD?
color_table = Table().with_columns(
'Class', make_array(1, 0),
'Color', make_array('darkblue', 'gold')
)
ckd = ckd.join('Class', color_table)
ckd.scatter('Hemoglobin', 'Glucose', group='Color')
Scatterplot with 'Hemoglobin' on the x-axis and 'Glucose' on the y-axis. Data points are shown in dark blue or in gold. The dark blue data points exist all over the graph, but not where the gold data points do, from about x=0 to x=1.5 and y values between -1 and just above 0.假设 Alice 是一位不在数据集中的新患者。如果我告诉你 Alice 的血红蛋白水平和血糖水平,你能预测她是否患有 CKD 吗?看起来确实可以!你可以看到一个非常清晰的模式:右下角的点往往代表没有 CKD 的人,而其余点则倾向于代表 CKD 患者。对于人类来说,这个模式很明显。但我们如何编程让计算机自动检测这样的模式呢?
最近邻分类器
人们可能寻找的模式有很多种,分类算法也有很多。但我将告诉你一种出奇有效的方法。它被称为“最近邻分类”。思路是这样的:如果我们知道 Alice 的血红蛋白和血糖数值,我们可以将她放在这个散点图的某个位置上;血红蛋白是她的 x 坐标,血糖是她的 y 坐标。现在,为了预测她是否患有 CKD,我们找到散点图中最近的点,检查它是蓝色还是金色;我们预测 Alice 应该与该患者得到相同的诊断。
换句话说,为了将 Alice 分类为 CKD 或非 CKD,我们找到训练集中“最接近”Alice 的患者,然后用该患者的诊断作为我们对 Alice 的预测。直觉上,如果两点在散点图中彼此接近,那么相应的测量值就非常相似,因此我们可能预期它们得到相同的诊断(可能性较大)。我们不知道 Alice 的诊断结果,但我们知道训练集中所有患者的诊断结果,因此我们找到训练集中与 Alice 最相似的患者,用该患者的诊断来预测 Alice 的诊断。
在下面的图中,红点代表 Alice。它通过一条黑线连接到离它最近的点——即训练集中的“最近邻”。该图由一个名为 show_closest 的函数绘制。该函数接受一个表示 Alice 点的 $x$ 和 $y$ 坐标的数组。改变这些值,看看最近的点如何变化!特别要注意最近的点是蓝色还是金色。
def show_closest(point):
"""point = array([x,y])
gives the coordinates of a new point
shown in red"""
HemoGl = ckd.drop('White Blood Cell Count', 'Color')
t = closest(HemoGl, point, 1)
x_closest = t.row(0).item(1)
y_closest = t.row(0).item(2)
ckd.scatter('Hemoglobin', 'Glucose', group='Color')
plt.scatter(point.item(0), point.item(1), color='red', s=30)
plt.plot(make_array(point.item(0), x_closest), make_array(point.item(1), y_closest), color='k', lw=2);
# In this example, Alice's Hemoglobin attribute is 0 and her Glucose is 1.5.
alice = make_array(0, 1.5)
show_closest(alice)
The same scatterplot as before is reproduced here with an additional data point in red at about (0, 1.5). The red data point has a black line drawn to the nearest data point in dark blue. The lins isn't that long. The previous description is: Scatterplot with 'Hemoglobin' on the x-axis and 'Glucose' on the y-axis. Data points are shown in dark blue or in gold. The dark blue data points exist all over the graph, but not where the gold data points do, from about x=0 to x=1.5 and y values between -1 and just above 0.因此,我们的“最近邻分类器”的工作方式如下: - 找到训练集中离新点最近的点。 - 如果该最近点是“CKD”点,则将新点分类为“CKD”。如果最近点是“非 CKD”点,则将新点分类为“非 CKD”。
散点图表明,这个最近邻分类器应该相当准确。右下角的点倾向于得到“非 CKD”的诊断,因为它们的最近邻是金色点。其余的点倾向于得到“CKD”的诊断,因为它们的最近邻是蓝色点。因此,在这个例子中,最近邻策略似乎很好地抓住了我们的直觉。
决策边界
有时,可视化分类器的一种有用方法是标出分类器会预测为“CKD”的属性区域以及会预测为“非 CKD”的属性区域。我们在两者之间得到一些边界,边界一侧的点将被分类为“CKD”,另一侧的点将被分类为“非 CKD”。这个边界被称为“决策边界”。每个不同的分类器都有不同的决策边界;决策边界只是可视化分类器用于分类点的标准的一种方式。
例如,假设 Alice 点的坐标是 (0, 1.5)。注意最近邻是蓝色的。现在尝试降低该点的高度($y$ 坐标)。你会看到大约在 $y = 0.95$ 处,最近邻从蓝色变为金色。
alice = make_array(0, 0.97)
show_closest(alice)
The same scatterplot as before is reproduced here with an additional data point in red at about (0, 1). The red data point has a black line drawn to the nearest data point in dark blue. The lins is a bit longer than the previous black line. The previous description is: Scatterplot with 'Hemoglobin' on the x-axis and 'Glucose' on the y-axis. Data points are shown in dark blue or in gold. The dark blue data points exist all over the graph, but not where the gold data points do, from about x=0 to x=1.5 and y values between -1 and just above 0.这里有数百个新的未分类点,全部为红色。
x_array = make_array()
y_array = make_array()
for x in np.arange(-2, 2.1, 0.1):
for y in np.arange(-2, 2.1, 0.1):
x_array = np.append(x_array, x)
y_array = np.append(y_array, y)
test_grid = Table().with_columns(
'Hemoglobin', x_array,
'Glucose', y_array
)
test_grid.scatter('Hemoglobin', 'Glucose', color='red', alpha=0.4, s=30)
plt.scatter(ckd.column('Hemoglobin'), ckd.column('Glucose'), c=ckd.column('Color'), edgecolor='k')
plt.xlim(-2, 2)
plt.ylim(-2, 2);
A scatterplot with 'hemoglobin' on the x-axis and 'Glucose' on the y-axis. Red points form a grid over the graph. On top of the red points are dark blue data points, primarily on the left hand side of the graph. There are also gold data points, only on the right hand side of the graph. One dark blue data point has a high x value on the right hand side of the graph, but is above the gold data points with a higher y value.每个红点在训练集中都有一个最近邻(与之前相同的蓝点和金点)。对于某些红点,你可以轻易判断最近邻是蓝色还是金色。对于其他红点,用肉眼做出决定就有点棘手了。这些是靠近决策边界的点。
但计算机可以轻松确定每个点的最近邻。所以让我们让它将我们的最近邻分类器应用于每个红点:
对于每个红点,它必须找到训练集中最近的点;然后必须将红点的颜色更改为最近邻的颜色。
结果图显示了哪些点将被分类为“CKD”(所有蓝色的点),哪些将被分类为“非 CKD”(所有金色的点)。
def classify_grid(training, test, k):
c = make_array()
for i in range(test.num_rows):
# Run the classifier on the ith patient in the test set
c = np.append(c, classify(training, make_array(test.row(i)), k))
return c
c = classify_grid(ckd.drop('White Blood Cell Count', 'Color'), test_grid, 1)
test_grid = test_grid.with_column('Class', c).join('Class', color_table)
test_grid.scatter('Hemoglobin', 'Glucose', group='Color', alpha=0.4, s=30)
plt.scatter(ckd.column('Hemoglobin'), ckd.column('Glucose'), c=ckd.column('Color'), edgecolor='k')
plt.xlim(-2, 2)
plt.ylim(-2, 2);
Scatterplot with 'Hemoglobin' on the x-axis and 'Glucose' on the y-axis. The same dark blue and gold data points are drawn. More transparent data points are drawn in a grid behind these, either colored blue or gold, depending on if they're closest to a dark blue data point or a full gold data point. The left hand side of the graph is mostly blue, along with the top third. The rest of the graph is the lighter gold.决策边界就是分类器将红点从变为蓝色切换到变为金色的地方。
k-最近邻
然而,两个类别之间的分离并不总是如此清晰。例如,假设我们不考察血红蛋白水平,而是考察白细胞计数。看看会发生什么:
ckd.scatter('White Blood Cell Count', 'Glucose', group='Color')
A scatterplot with 'White Blood Cell Count' on the x-axis and 'Glucose' on the y-axis. There are gold data points in the bottom left corner of the graph, and dark blue data points all over. The gold and dark blue data points do overlap in this graph.如你所见,非 CKD 患者都聚集在左下角。大多数 CKD 患者位于该簇的上方或右侧……但并非全部。有一些 CKD 患者出现在上图的左下角(如散落在金色簇中的少数蓝点所示)。这意味着你无法仅凭这两项血液检测指标就确定某人是否患有 CKD。
如果我们知道 Alice 的血糖水平和白细胞计数,我们能预测她是否患有 CKD 吗?是的,我们可以做出预测,但我们不应期望它是 100% 准确的。直觉上,似乎有一种自然的预测策略:标出 Alice 在散点图中的位置;如果她在左下角,预测她没有 CKD,否则预测她有 CKD。
这并不完美——我们的预测有时会出错。(花一分钟思考一下:它对哪些患者会出错?)正如上面的散点图所示,有时 CKD 患者的血糖和白细胞水平与非 CKD 患者看起来完全相同,因此任何分类器都不可避免地会对他们做出错误的预测。
我们能在计算机上自动化这个过程吗?嗯,最近邻分类器在这里也是一个合理的选择。花一分钟思考一下:它的预测与上述直观策略相比如何?它们何时会不同?
它的预测将与我们的直观策略非常相似,但偶尔会做出不同的预测。特别是,如果 Alice 的血液检测结果恰好将她放在左下角的一个蓝点附近,直观策略会预测“非 CKD”,而最近邻分类器会预测“CKD”。
最近邻分类器有一个简单的推广形式可以解决这个异常。它被称为“k-最近邻分类器”。为了预测一个新点的类别,它找到训练集中最接近的 $k$ 个点($k$ 个最近邻),并让它们对新点的类别进行投票。通常,我们使用 $k$ 个最近邻中的多数类作为我们的预测。
这种方法平滑了单个离群点的影响,在实践中通常比简单的最近邻效果更好。让我们看看它如何运作,使用 $k=5$。