我正在尝试在线使用(out-of-core)学习算法,用于使用SGDClassifier的MNIST问题
但是,似乎精度并不会一直提高。
在这种情况下我该怎么办?最好地保存分类器?
SGDClassifier是否收敛到某个最佳解决方案?
这是我的代码:
import numpy as np
from sklearn.linear_model.stochastic_gradient import SGDClassifier
from sklearn.datasets import fetch_mldata
from sklearn.utils import shuffle
#use all digits
mnist = fetch_mldata("MNIST original")
X_train, y_train = mnist.data[:70000] / 255., mnist.target[:70000]
X_train, y_train = shuffle(X_train, y_train)
X_test, y_test = X_train[60000:70000], y_train[60000:70000]
step =1000
batches= np.arange(0,60000,step)
all_classes = np.array([0,1,2,3,4,5,6,7,8,9])
classifier = SGDClassifier()
for curr in batches:
X_curr, y_curr = X_train[curr:curr+step], y_train[curr:curr+step]
classifier.partial_fit(X_curr, y_curr, classes=all_classes)
score= classifier.score(X_test, y_test)
print score
print "all done"
我在MNIST上使用10k样本训练和10k样本测试了linearSVM与SGD,得到0.883 13,95和0.85 1,32,因此SGD更快但准确性更低。
#test linearSVM vs SGD
t0 = time.time()
clf = LinearSVC()
clf.fit(X_train, y_train)
score= clf.score(X_test, y_test)
print score
print (time.time()-t0)
t1 = time.time()
clf = SGDClassifier()
clf.fit(X_train, y_train)
score= clf.score(X_test, y_test)
print score
print (time.time()-t1)
我也在这里找到了一些信息
https://stats.stackexchange.com/a/14936/16843
更新:数据经过一遍(10遍)获得了90.8%的最佳准确性,因此可以解决。 SGD的另一种特殊性是,在将数据传递给分类器之前,必须先对数据进行混洗。
最佳答案
首先要说明:您正在使用SGDClassifier
和默认参数:它们可能不是此数据集的最佳值:也请尝试其他值(特别是对于alpha而言,是正则化参数)。
现在回答您的问题,线性模型在像数字图像分类任务MNIST这样的数据集上做得很好的可能性很小。您可能想尝试线性模型,例如:SVC(kernel='rbf')
(但不可扩展,请尝试训练集的一小部分),而不是增量/核心外ExtraTreesClassifier(n_estimator=100)
或更多,但也不在核外。次估计量越多,训练所需的时间就越长。
您也可以尝试使用SVC(kernel='rbf')
的Nystroem approximation,方法是使用拟合在数据小子集(例如10000个样本)上的Nystroem(n_components=1000, gamma=0.05)
转换数据集,然后将整个转换后的训练集传递给线性模型,例如:需要两次通过数据集。
github上还有一个pull request for 1 hidden layer perceptron,其计算速度应比SGDClassifier
更快,并且在MNIST上达到98%的测试集准确度(并且还提供用于偏芯学习的partial_fit API)。
编辑:ExtraTreesClassifier
分数的估计值的波动是预期的:SGD代表随机梯度下降,这意味着一次将示例视为一个:分类不当的样本可能会以某种方式导致模型权重的更新这对其他样本不利,因此您需要对数据进行多次传递,以使学习率降低得足够多,从而获得对验证准确性的更平滑估计。您可以在for循环中使用itertools.repeat对数据集进行多次传递(例如10次)。
关于python - MNIST和SGDClassifier分类器,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/18895553/