CategoryResourceRepost/极客时间专栏/数据分析实战45讲/第二模块:数据分析算法篇/25丨KNN(下):如何对手写数字进行识别?.md
louzefeng d3828a7aee mod
2024-07-11 05:50:32 +00:00

228 lines
11 KiB
Markdown
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

<audio id="audio" title="25丨KNN如何对手写数字进行识别" controls="" preload="none"><source id="mp3" src="https://static001.geekbang.org/resource/audio/bd/11/bd87b5d9179bfd3740fd19f6b7641b11.mp3"></audio>
今天我来带你进行KNN的实战。上节课我讲了KNN实际上是计算待分类物体与其他物体之间的距离然后通过统计最近的K个邻居的分类情况来决定这个物体的分类情况。
这节课我们先看下如何在sklearn中使用KNN算法然后通过sklearn中自带的手写数字数据集来进行实战。
之前我还讲过SVM、朴素贝叶斯和决策树分类我们还可以用这个数据集来做下训练对比下这四个分类器的训练结果。
## 如何在sklearn中使用KNN
在Python的sklearn工具包中有KNN算法。KNN既可以做分类器也可以做回归。如果是做分类你需要引用
```
from sklearn.neighbors import KNeighborsClassifier
```
如果是做回归,你需要引用:
```
from sklearn.neighbors import KNeighborsRegressor
```
从名字上你也能看出来Classifier对应的是分类Regressor对应的是回归。一般来说如果一个算法有Classifier类都能找到相应的Regressor类。比如在决策树分类中你可以使用DecisionTreeClassifier也可以使用决策树来做回归DecisionTreeRegressor。
好了我们看下如何在sklearn中创建KNN分类器。
这里我们使用构造函数KNeighborsClassifier(n_neighbors=5, weights=uniform, algorithm=auto, leaf_size=30),这里有几个比较主要的参数,我分别来讲解下:
1.n_neighbors即KNN中的K值代表的是邻居的数量。K值如果比较小会造成过拟合。如果K值比较大无法将未知物体分类出来。一般我们使用默认值5。
2.weights是用来确定邻居的权重有三种方式
<li>
weights=uniform代表所有邻居的权重相同
</li>
<li>
weights=distance代表权重是距离的倒数即与距离成反比
</li>
<li>
自定义函数,你可以自定义不同距离所对应的权重。大部分情况下不需要自己定义函数。
</li>
3.algorithm用来规定计算邻居的方法它有四种方式
<li>
algorithm=auto根据数据的情况自动选择适合的算法默认情况选择auto
</li>
<li>
algorithm=kd_tree也叫作KD树是多维空间的数据结构方便对关键数据进行检索不过KD树适用于维度少的情况一般维数不超过20如果维数大于20之后效率反而会下降
</li>
<li>
algorithm=ball_tree也叫作球树它和KD树一样都是多维空间的数据结果不同于KD树球树更适用于维度大的情况
</li>
<li>
algorithm=brute也叫作暴力搜索它和KD树不同的地方是在于采用的是线性扫描而不是通过构造树结构进行快速检索。当训练集大的时候效率很低。
</li>
4.leaf_size代表构造KD树或球树时的叶子数默认是30调整leaf_size会影响到树的构造和搜索速度。
创建完KNN分类器之后我们就可以输入训练集对它进行训练这里我们使用fit()函数传入训练集中的样本特征矩阵和分类标识会自动得到训练好的KNN分类器。然后可以使用predict()函数来对结果进行预测,这里传入测试集的特征矩阵,可以得到测试集的预测分类结果。
## 如何用KNN对手写数字进行识别分类
手写数字数据集是个非常有名的用于图像识别的数据集。数字识别的过程就是将这些图片与分类结果0-9一一对应起来。完整的手写数字数据集MNIST里面包括了60000个训练样本以及10000个测试样本。如果你学习深度学习的话MNIST基本上是你接触的第一个数据集。
今天我们用sklearn自带的手写数字数据集做KNN分类你可以把这个数据集理解成一个简版的MNIST数据集它只包括了1797幅数字图像每幅图像大小是8*8像素。
好了我们先来规划下整个KNN分类的流程
<img src="https://static001.geekbang.org/resource/image/8a/78/8af94562f6bd3ac42036ec47f5ad2578.jpg" alt=""><br>
整个训练过程基本上都会包括三个阶段:
<li>
数据加载我们可以直接从sklearn中加载自带的手写数字数据集
</li>
<li>
准备阶段在这个阶段中我们需要对数据集有个初步的了解比如样本的个数、图像长什么样、识别结果是怎样的。你可以通过可视化的方式来查看图像的呈现。通过数据规范化可以让数据都在同一个数量级的维度。另外因为训练集是图像每幅图像是个8*8的矩阵我们不需要对它进行特征选择将全部的图像数据作为特征值矩阵即可
</li>
<li>
分类阶段:通过训练可以得到分类器,然后用测试集进行准确率的计算。
</li>
好了,按照上面的步骤,我们一起来实现下这个项目。
首先是加载数据和对数据的探索:
```
# 加载数据
digits = load_digits()
data = digits.data
# 数据探索
print(data.shape)
# 查看第一幅图像
print(digits.images[0])
# 第一幅图像代表的数字含义
print(digits.target[0])
# 将第一幅图像显示出来
plt.gray()
plt.imshow(digits.images[0])
plt.show()
```
运行结果:
```
(1797, 64)
[[ 0. 0. 5. 13. 9. 1. 0. 0.]
[ 0. 0. 13. 15. 10. 15. 5. 0.]
[ 0. 3. 15. 2. 0. 11. 8. 0.]
[ 0. 4. 12. 0. 0. 8. 8. 0.]
[ 0. 5. 8. 0. 0. 9. 8. 0.]
[ 0. 4. 11. 0. 1. 12. 7. 0.]
[ 0. 2. 14. 5. 10. 12. 0. 0.]
[ 0. 0. 6. 13. 10. 0. 0. 0.]]
0
```
<img src="https://static001.geekbang.org/resource/image/62/3c/625b7e95a22c025efa545d7144ec5f3c.png" alt=""><br>
我们对原始数据集中的第一幅进行数据可视化可以看到图像是个8*8的像素矩阵上面这幅图像是一个“0”从训练集的分类标注中我们也可以看到分类标注为“0”。
sklearn自带的手写数字数据集一共包括了1797个样本每幅图像都是8*8像素的矩阵。因为并没有专门的测试集所以我们需要对数据集做划分划分成训练集和测试集。因为KNN算法和距离定义相关我们需要对数据进行规范化处理采用Z-Score规范化代码如下
```
# 分割数据将25%的数据作为测试集,其余作为训练集(你也可以指定其他比例的数据作为训练集)
train_x, test_x, train_y, test_y = train_test_split(data, digits.target, test_size=0.25, random_state=33)
# 采用Z-Score规范化
ss = preprocessing.StandardScaler()
train_ss_x = ss.fit_transform(train_x)
test_ss_x = ss.transform(test_x)
```
然后我们构造一个KNN分类器knn把训练集的数据传入构造好的knn并通过测试集进行结果预测与测试集的结果进行对比得到KNN分类器准确率代码如下
```
# 创建KNN分类器
knn = KNeighborsClassifier()
knn.fit(train_ss_x, train_y)
predict_y = knn.predict(test_ss_x)
print(&quot;KNN准确率: %.4lf&quot; % accuracy_score(test_y, predict_y))
```
运行结果:
```
KNN准确率: 0.9756
```
好了这样我们就构造好了一个KNN分类器。之前我们还讲过SVM、朴素贝叶斯和决策树分类。我们用手写数字数据集一起来训练下这些分类器然后对比下哪个分类器的效果更好。代码如下
```
# 创建SVM分类器
svm = SVC()
svm.fit(train_ss_x, train_y)
predict_y=svm.predict(test_ss_x)
print('SVM准确率: %0.4lf' % accuracy_score(test_y, predict_y))
# 采用Min-Max规范化
mm = preprocessing.MinMaxScaler()
train_mm_x = mm.fit_transform(train_x)
test_mm_x = mm.transform(test_x)
# 创建Naive Bayes分类器
mnb = MultinomialNB()
mnb.fit(train_mm_x, train_y)
predict_y = mnb.predict(test_mm_x)
print(&quot;多项式朴素贝叶斯准确率: %.4lf&quot; % accuracy_score(test_y, predict_y))
# 创建CART决策树分类器
dtc = DecisionTreeClassifier()
dtc.fit(train_mm_x, train_y)
predict_y = dtc.predict(test_mm_x)
print(&quot;CART决策树准确率: %.4lf&quot; % accuracy_score(test_y, predict_y))
```
运行结果如下:
```
SVM准确率: 0.9867
多项式朴素贝叶斯准确率: 0.8844
CART决策树准确率: 0.8556
```
这里需要注意的是我们在做多项式朴素贝叶斯分类的时候传入的数据不能有负数。因为Z-Score会将数值规范化为一个标准的正态分布即均值为0方差为1数值会包含负数。因此我们需要采用Min-Max规范化将数据规范化到[0,1]范围内。
好了我们整理下这4个分类器的结果。
<img src="https://static001.geekbang.org/resource/image/0f/e8/0f498e0197935bfe15d9b1209bad8fe8.png" alt=""><br>
你能看出来KNN的准确率还是不错的和SVM不相上下。
你可以自己跑一遍整个代码在运行前还需要import相关的工具包下面的这些工具包你都会用到所以都需要引用
```
from sklearn.model_selection import train_test_split
from sklearn import preprocessing
from sklearn.metrics import accuracy_score
from sklearn.datasets import load_digits
from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import SVC
from sklearn.naive_bayes import MultinomialNB
from sklearn.tree import DecisionTreeClassifier
import matplotlib.pyplot as plt
```
代码中我使用了train_test_split做数据集的拆分使用matplotlib.pyplot工具包显示图像使用accuracy_score进行分类器准确率的计算使用preprocessing中的StandardScaler和MinMaxScaler做数据的规范化。
完整的代码你可以从[GitHub](https://github.com/cystanford/knn)上下载。
## 总结
今天我带你一起做了手写数字分类识别的实战分别用KNN、SVM、朴素贝叶斯和决策树做分类器并统计了四个分类器的准确率。在这个过程中你应该对数据探索、数据可视化、数据规范化、模型训练和结果评估的使用过程有了一定的体会。在数据量不大的情况下使用sklearn还是方便的。
如果数据量很大比如MNIST数据集中的6万个训练数据和1万个测试数据那么采用深度学习+GPU运算的方式会更适合。因为深度学习的特点就是需要大量并行的重复计算GPU最擅长的就是做大量的并行计算。
<img src="https://static001.geekbang.org/resource/image/d0/e1/d08f489c3bffaacb6910f32a0fa600e1.png" alt=""><br>
最后留两道思考题吧请你说说项目中KNN分类器的常用构造参数功能函数都有哪些以及你对KNN使用的理解如果把KNN中的K值设置为200数据集还是sklearn中的手写数字数据集再跑一遍程序看看分类器的准确率是多少
欢迎在评论区与我分享你的答案,也欢迎点击“请朋友读”,把这篇文章分享给你的朋友或者同事。