sklearn PLSRegression: "ValueError: array must not contain infs or NaNs"
Problem description
When using sklearn.cross_decomposition.PLSRegression:
import numpy as np
import sklearn.cross_decomposition
pls2 = sklearn.cross_decomposition.PLSRegression()
xx = np.random.random((5,5))
yy = np.zeros((5,5) )
yy[0,:] = [0,1,0,0,0]
yy[1,:] = [0,0,0,1,0]
yy[2,:] = [0,0,0,0,1]
#yy[3,:] = [1,0,0,0,0] # Uncommenting this line solves the issue
pls2.fit(xx, yy)
I get:
C:\Anaconda\lib\site-packages\sklearn\cross_decomposition\pls_.py:44: RuntimeWarning: invalid value encountered in divide
  x_weights = np.dot(X.T, y_score) / np.dot(y_score.T, y_score)
C:\Anaconda\lib\site-packages\sklearn\cross_decomposition\pls_.py:64: RuntimeWarning: invalid value encountered in less
  if np.dot(x_weights_diff.T, x_weights_diff) < tol or Y.shape[1] == 1:
C:\Anaconda\lib\site-packages\sklearn\cross_decomposition\pls_.py:67: UserWarning: Maximum number of iterations reached
  warnings.warn('Maximum number of iterations reached')
C:\Anaconda\lib\site-packages\sklearn\cross_decomposition\pls_.py:297: RuntimeWarning: invalid value encountered in less
  if np.dot(x_scores.T, x_scores) < np.finfo(np.double).eps:
C:\Anaconda\lib\site-packages\sklearn\cross_decomposition\pls_.py:275: RuntimeWarning: invalid value encountered in less
  if np.all(np.dot(Yk.T, Yk) < np.finfo(np.double).eps):
Traceback (most recent call last):
  File "C:\svn\hw4\code\test_plsr2.py", line 8, in <module>
    pls2.fit(xx, yy)
  File "C:\Anaconda\lib\site-packages\sklearn\cross_decomposition\pls_.py", line 335, in fit
    linalg.pinv(np.dot(self.x_loadings_.T, self.x_weights_)))
  File "C:\Anaconda\lib\site-packages\scipy\linalg\basic.py", line 889, in pinv
    a = _asarray_validated(a, check_finite=check_finite)
  File "C:\Anaconda\lib\site-packages\scipy\_lib\_util.py", line 135, in _asarray_validated
    a = np.asarray_chkfinite(a)
  File "C:\Anaconda\lib\site-packages\numpy\lib\function_base.py", line 613, in asarray_chkfinite
    "array must not contain infs or NaNs")
ValueError: array must not contain infs or NaNs
What could be the issue?
I am aware of scikit-learn GitHub issue #2089, but since I use scikit-learn 0.16.1 (with Python 2.7.10 x64) this problem should be solved (the code snippets mentioned in the GitHub issue work fine).
Recommended answer
The issue is caused by a bug in scikit-learn. I reported it on GitHub: https://github.com/scikit-learn/scikit-learn/issues/2089#issuecomment-152753095
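Until a fixed release is available, it may help to check the target matrix before fitting. The first warning points at a 0/0 division in np.dot(y_score.T, y_score), and in the example above column 0 of yy is all zeros unless the commented-out line is enabled. The sketch below is only a diagnostic under that assumption: constant_columns is a hypothetical helper (not part of scikit-learn), and whether dropping constant columns is enough to avoid the error in a given scikit-learn version should be verified.

```python
import numpy as np


def constant_columns(a, tol=1e-12):
    """Return indices of columns whose values never vary.

    Hypothetical helper for illustration; not a scikit-learn function.
    """
    a = np.asarray(a, dtype=float)
    spread = a.max(axis=0) - a.min(axis=0)  # per-column range
    return np.flatnonzero(spread <= tol)


# Same target matrix as in the question (line with [1,0,0,0,0] left out).
yy = np.zeros((5, 5))
yy[0, :] = [0, 1, 0, 0, 0]
yy[1, :] = [0, 0, 0, 1, 0]
yy[2, :] = [0, 0, 0, 0, 1]

bad = constant_columns(yy)               # columns with zero variance
yy_reduced = np.delete(yy, bad, axis=1)  # assumption: fit on these instead
print(bad, yy_reduced.shape)
```

If the dropped columns are needed in the output, their constant value can be re-attached to the predictions afterwards, since a constant target carries no information for the fit.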