I'm trying to use the simple function train_test_split on the iris dataset and then make predictions with an SVM, but when I call `fit`, as shown below:

dat_iris = datasets.load_iris()
x1 = dat_iris.data[:,2]
y1 = dat_iris.target
x_train,y_train,x_test,y_test = train_test_split(x1, y1, test_size = 0.3,
random_state=0)
svm_model = SVC(kernel='linear',C=1.0, random_state=0)
svm_model.fit(x_train,y_train)
y_pred = svm_model.predict(x_train)


I get the following error:

ValueError                                Traceback (most recent call last)
<ipython-input-245-120527f222b3> in <module>()
      7
      8 svm_model = SVC(kernel='linear',C=1.0, random_state=0)
----> 9 svm_model.fit(x_train,y_train)
     10 y_pred = svm_model.predict(x_train)
     11 metrics.classification_report(y_pred, y_train)

~/anaconda3/lib/python3.6/site-packages/sklearn/svm/base.py in fit(self, X, y, sample_weight)
    147         self._sparse = sparse and not callable(self.kernel)
    148
--> 149         X, y = check_X_y(X, y, dtype=np.float64, order='C', accept_sparse='csr')
    150         y = self._validate_targets(y)
    151

~/anaconda3/lib/python3.6/site-packages/sklearn/utils/validation.py in check_X_y(X, y, accept_sparse, dtype, order, copy, force_all_finite, ensure_2d, allow_nd, multi_output, ensure_min_samples, ensure_min_features, y_numeric, warn_on_dtype, estimator)
    550         y = y.astype(np.float64)
    551
--> 552     check_consistent_length(X, y)
    553
    554     return X, y

~/anaconda3/lib/python3.6/site-packages/sklearn/utils/validation.py in check_consistent_length(*arrays)
    171     if len(uniques) > 1:
    172         raise ValueError("Found input variables with inconsistent numbers of"
--> 173                          " samples: %r" % [int(l) for l in lengths])
    174
    175

ValueError: Found input variables with inconsistent numbers of samples: [105, 45]


Could this be caused by the size of the target or the input? How can I fix it?
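A quick way to test the size hypothesis is to print the shape of each value that train_test_split returns, unpacked in the same order as the code above. This is a minimal diagnostic sketch; it assumes `train_test_split` is imported from `sklearn.model_selection` (older releases also provided it in `sklearn.cross_validation`). With 150 iris samples and `test_size=0.3`, the split is 105 training rows and 45 test rows, and the shapes make the mismatch visible: the variable named `y_train` actually holds the 45-row test features.

```python
from sklearn import datasets
from sklearn.model_selection import train_test_split

iris = datasets.load_iris()
x1 = iris.data[:, 2]   # petal length, 1-D array of 150 values
y1 = iris.target

# Same (incorrect) unpacking order as in the question.
# train_test_split actually returns: X_train, X_test, y_train, y_test
x_train, y_train, x_test, y_test = train_test_split(x1, y1, test_size=0.3,
                                                    random_state=0)

# Print the shape of each returned array to see which sizes disagree
for name, arr in [("x_train", x_train), ("y_train", y_train),
                  ("x_test", x_test), ("y_test", y_test)]:
    print(name, arr.shape)
```

The 105 and 45 printed for `x_train` and `y_train` are exactly the `[105, 45]` reported in the ValueError.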

Best answer

You mixed up the order of the returned values.
It should be:

X_train, X_test, y_train, y_test = train_test_split(x1, y1, test_size = 0.3,
random_state=0)
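For reference, here is a sketch of the whole workflow with the corrected unpacking order. A few choices here go beyond the original answer and are assumptions: the feature is sliced as `data[:, 2:3]` so it stays a 2-D array of shape (150, 1), since SVC expects a 2-D feature matrix; predictions are made on the held-out `X_test` instead of the training data; and `train_test_split` is imported from `sklearn.model_selection`.

```python
from sklearn import datasets, metrics
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC

iris = datasets.load_iris()
# Slicing 2:3 keeps column 2 (petal length) as a 2-D array of shape (150, 1)
X = iris.data[:, 2:3]
y = iris.target

# Correct order: features first (train, test), then targets (train, test)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3,
                                                    random_state=0)

svm_model = SVC(kernel='linear', C=1.0, random_state=0)
svm_model.fit(X_train, y_train)   # 105 feature rows, 105 labels

# Evaluate on the test split; classification_report takes (y_true, y_pred)
y_pred = svm_model.predict(X_test)
print(metrics.classification_report(y_test, y_pred))
```

Note that `classification_report` expects the true labels first, so the call in the question's traceback, `classification_report(y_pred, y_train)`, has its arguments reversed as well.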

Regarding "machine-learning - Error when making train/test sets from iris data via sklearn.train_test_split()", a similar question was found on Stack Overflow: https://stackoverflow.com/questions/47761901/

10-12 23:09