sklearn's PLSRegression: "ValueError: array must not contain infs or NaNs"

import numpy as np
import sklearn.cross_decomposition

pls2 = sklearn.cross_decomposition.PLSRegression()
xx = np.random.random((5,5))
yy = np.zeros((5,5))
yy[0,:] = [0,1,0,0,0]
yy[1,:] = [0,0,0,1,0]
yy[2,:] = [0,0,0,0,1]
#yy[3,:] = [1,0,0,0,0] # Uncommenting this line solves the issue
pls2.fit(xx, yy)

I get:

C:\Anaconda\lib\site-packages\sklearn\cross_decomposition\pls_.py:44: RuntimeWarning: invalid value encountered in divide
  x_weights = np.dot(X.T, y_score) / np.dot(y_score.T, y_score)
C:\Anaconda\lib\site-packages\sklearn\cross_decomposition\pls_.py:64: RuntimeWarning: invalid value encountered in less
  if np.dot(x_weights_diff.T, x_weights_diff) < tol or Y.shape[1] == 1:
C:\Anaconda\lib\site-packages\sklearn\cross_decomposition\pls_.py:67: UserWarning: Maximum number of iterations reached
  warnings.warn('Maximum number of iterations reached')
C:\Anaconda\lib\site-packages\sklearn\cross_decomposition\pls_.py:297: RuntimeWarning: invalid value encountered in less
  if np.dot(x_scores.T, x_scores) < np.finfo(np.double).eps:
C:\Anaconda\lib\site-packages\sklearn\cross_decomposition\pls_.py:275: RuntimeWarning: invalid value encountered in less
  if np.all(np.dot(Yk.T, Yk) < np.finfo(np.double).eps):

Traceback (most recent call last):
  File "C:\svn\hw4\code\test_plsr2.py", line 8, in <module>
    pls2.fit(xx, yy)
  File "C:\Anaconda\lib\site-packages\sklearn\cross_decomposition\pls_.py", line 335, in fit
    linalg.pinv(np.dot(self.x_loadings_.T, self.x_weights_)))
  File "C:\Anaconda\lib\site-packages\scipy\linalg\basic.py", line 889, in pinv
    a = _asarray_validated(a, check_finite=check_finite)
  File "C:\Anaconda\lib\site-packages\scipy\_lib\_util.py", line 135, in _asarray_validated
    a = np.asarray_chkfinite(a)
  File "C:\Anaconda\lib\site-packages\numpy\lib\function_base.py", line 613, in asarray_chkfinite
    "array must not contain infs or NaNs")
ValueError: array must not contain infs or NaNs

What could be the problem?

I'm aware of scikit-learn GitHub issue #2089, but since I'm using scikit-learn 0.16.1 (with Python 2.7.10 x64), that issue should already be fixed (the code snippet from the GitHub issue runs fine).

Best answer: Check whether any of the values you pass in are NaN or inf:

np.isnan(xx).any()
np.isnan(yy).any()
np.isinf(xx).any()
np.isinf(yy).any()
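The four checks above only say whether a problem exists. A small helper such as the hypothetical `report_nonfinite` below (not part of NumPy or scikit-learn) also prints where the bad entries are, using `np.isfinite`, which is False for NaN, +inf and -inf alike. This is a sketch, not a library API:

```python
import numpy as np

def report_nonfinite(name, a):
    """Hypothetical helper: print where `a` holds NaN/inf and return True if any exist."""
    a = np.asarray(a, dtype=float)
    bad = np.argwhere(~np.isfinite(a))  # row/column indices of NaN, +inf and -inf entries
    if bad.size:
        print("%s: %d non-finite value(s) at %s" % (name, len(bad), bad.tolist()))
    return bool(bad.size)

# Inputs built like the ones in the question: both should come back clean
xx = np.random.random((5, 5))
yy = np.zeros((5, 5))
results = {name: report_nonfinite(name, arr) for name, arr in [("xx", xx), ("yy", yy)]}
print(any(results.values()))  # False

# A deliberately broken array for comparison
report_nonfinite("demo", np.array([[1.0, np.nan], [np.inf, 2.0]]))
# prints: demo: 2 non-finite value(s) at [[0, 1], [1, 0]]
```

For the arrays in the question both checks come back False (`np.random.random` returns values in [0, 1) and `yy` holds only zeros and ones), which suggests the NaNs are produced inside `fit` rather than passed in.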

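If any of these returns True, remove the NaN or inf entries. For example, you can replace NaN with 0 (note that np.nan_to_num turns inf into a very large finite number, not 0):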

xx = np.nan_to_num(xx)
yy = np.nan_to_num(yy)
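A quick sketch of what `np.nan_to_num` does to each kind of non-finite value, plus two alternatives: masking with `np.isfinite` to force everything non-finite to 0, and dropping whole samples (rows) that have a non-finite value in either X or Y. The arrays here are made-up illustrations, not data from the question:

```python
import numpy as np

a = np.array([1.0, np.nan, np.inf, -np.inf])

# nan_to_num: NaN -> 0.0, but +inf/-inf -> the largest/most negative finite float64, not 0
cleaned = np.nan_to_num(a)
print(cleaned)

# To turn every non-finite value (NaN, +inf, -inf) into 0, mask with np.isfinite
zeroed = a.copy()
zeroed[~np.isfinite(zeroed)] = 0.0
print(zeroed)  # [1. 0. 0. 0.]

# Alternative: drop every sample (row) with a non-finite value in X or in Y
X = np.array([[1.0, 2.0], [np.nan, 3.0], [4.0, 5.0]])
Y = np.array([[0.0], [1.0], [np.inf]])
keep = np.isfinite(X).all(axis=1) & np.isfinite(Y).all(axis=1)
X_clean, Y_clean = X[keep], Y[keep]
print(X_clean.shape, Y_clean.shape)  # (1, 2) (1, 1)
```

Dropping rows keeps X and Y aligned sample by sample, which matters for a supervised model such as PLSRegression; replacing values with 0 keeps every sample but can bias the fit if many entries were missing.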

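It is also possible to feed NumPy values so large (positive or negative) or so close to zero that equations deep inside the library produce zero, NaN or inf. Oddly enough, one workaround is to pass smaller numbers, for example representative values between -1 and 1. One way to do this is standardization; see: https://stackoverflow.com/a/36390482/445131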
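A minimal sketch of standardization with scikit-learn's `StandardScaler`, on hypothetical data with large magnitudes. Each column is shifted to mean 0 and scaled to unit variance, which brings values down to the order of 1. Note that `PLSRegression` already centers and scales X and Y internally with its default `scale=True`, so explicit scaling like this is mostly relevant when you want to control the preprocessing yourself or reuse it across estimators:

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

rng = np.random.RandomState(0)
# Hypothetical data with large magnitudes (around 1e6), not the question's data
X = rng.normal(loc=1e6, scale=1e5, size=(20, 3))

scaler = StandardScaler()
X_std = scaler.fit_transform(X)  # each column now has mean ~0 and std ~1

print(np.allclose(X_std.mean(axis=0), 0.0), np.allclose(X_std.std(axis=0), 1.0))
```

Fit the scaler on training data only and reuse the same fitted object's `transform` on new data, so both are scaled identically.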

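If none of this solves the problem, you may be dealing with a low-level bug in the library you are using, or with some kind of singularity in your data. Create an SSCCE and post it on Stack Overflow, or file a new bug report with the maintainers of the library.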
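One concrete singularity worth checking is consistent with the first warning in the question: `x_weights = np.dot(X.T, y_score) / np.dot(y_score.T, y_score)` evaluates 0/0 (NaN) whenever `y_score` is a zero vector. In the question's `yy`, columns 0 and 2 are all zero, and the commented-out row 3 is exactly what makes column 0 non-zero. Assuming (this is an assumption, not something the answer states) that this scikit-learn version starts `y_score` from the first column of the centered Y, a constant first column would trigger the division by zero. The hypothetical helper `constant_columns` below lists such columns:

```python
import numpy as np

def constant_columns(a, eps=np.finfo(np.float64).eps):
    """Hypothetical helper: indices of columns that are all zero after centering."""
    a = np.asarray(a, dtype=float)
    centered = a - a.mean(axis=0)
    return [j for j in range(a.shape[1]) if np.all(np.abs(centered[:, j]) <= eps)]

# yy as built in the question
yy = np.zeros((5, 5))
yy[0, :] = [0, 1, 0, 0, 0]
yy[1, :] = [0, 0, 0, 1, 0]
yy[2, :] = [0, 0, 0, 0, 1]
print(constant_columns(yy))  # [0, 2]

yy[3, :] = [1, 0, 0, 0, 0]  # the line that is commented out in the question
print(constant_columns(yy))  # [2]
```

Under that assumption, a possible workaround is to drop constant columns of Y before fitting (they carry nothing to regress on) or to reorder Y's columns so a non-constant one comes first.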
