1. 数据集准备
from torchvision import datasets
mnist_set = datasets.MNIST(root="./MNIST", train=True, download=True)
具体参数说明请自行搜索。注意若donwload=True
,则torchvision会通过内置链接自动下载数据集,
但是有时会失效。因此可以自己去网络上下载并解压后排列成指定文件树,如下
MNIST
├── MNSIT
│ ├── raw
│ │ ├── t10k-images-idx3-ubyte.gz
│ │ ├── t10k-labels-idx1-ubyte.gz
│ │ ├── train-images-idx3-ubyte.gz
│ │ └── train-labels-idx1-ubyte.gz
然后使用如下语句去读取数据集
img, target = minst_set[0]
其中每个img类型为PILimage,target类型为int,代表该图片对应的数字。
x_, y_ = list(zip(*([(np.array(img).reshape(28*28), target) for img, target in mnist_set])))
上面的语句实现了将MNIST数据集转换成numpy数组的形式,其中x_是每个成员为[1, 784]的numpy数组,y_为对应的数字所组成的列表。
2. SVM训练
求解SVM是一个很复杂的问题,但是万幸的是sklearn中有封装的很好的模块,可以很简单的直接使用
from sklearn.svm import SVC
svc = SVC(kernel='rbf', C=1)
svc.fit(x_, y_)
其中fit接口接受两个参数,第一个参数为训练数据[batch_size, data],第二个参数为训练标签[batch_size,1]。
SVC的构造函数如下
SVC(C=1.0, kernel='rbf', degree=3, gamma='scale', coef0=0.0, shrinking=True, probability=False, tol=0.001, cache_size=200, class_weight=None, verbose=False, max_iter=-1, decision_function_shape='ovr', random_state=None)
4. 数字分割
这里就是使用opencv对拍摄的图像进行轮廓提取后拟合外接矩形,借此来提取数字部分的ROI。
这里选择进行Canny边缘检测后去进行轮廓提取,然后拟合外接矩形,因为相较于直接二值化后去提取数字部分的ROI,
边缘检测对数字与纸张的边界更加敏感,即便在光照不均匀的情况下,也能较好的提取出数字的边缘。鲁棒性强。
5. 杂项与代码
使用pickle模块对训练好的模型对象进行序列化保存与加载,可以将训练好的模型保存到本地,以便后续使用。
最后贴出代码
给出几个识别后的效果: