python – scikit-learn: query data dimension must match training data dimension...

I am trying to use the code from the scikit-learn website (the classifier comparison example):

I am using my own data.

My problem is that I have more than two features. If I want to "expand" from 2 to 3 or 4 features, I get:

"query data dimension must match training data dimension"
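By "expanding" I just mean taking more columns from my data array, e.g. (using the same array a as in the code below):

    X = a[:, :2]   # two features: this works
    X = a[:, :3]   # three features: this is where the error shows up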

# imports needed by this snippet (they are not shown in the original post)
import csv
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from sklearn.preprocessing import StandardScaler
from sklearn.cross_validation import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.lda import LDA
from sklearn.qda import QDA

def machine():
    # read the tab-separated file, skipping the header and incomplete rows
    liste = []  # not defined in the original snippet
    with open("test.txt", 'r') as csvr:
        reader = csv.reader(csvr, delimiter='\t')
        for i, row in enumerate(reader):
            if i == 0:
                pass
            elif '' in row[2:]:
                pass
            else:
                liste.append(map(float, row[2:]))

    a = np.array(liste)
    h = .02  # step size of the mesh

    names = ["Nearest Neighbors", "Linear SVM", "RBF SVM", "Decision Tree",
             "Random Forest", "AdaBoost", "Naive Bayes", "LDA", "QDA"]
    classifiers = [
        KNeighborsClassifier(1),
        SVC(kernel="linear", C=0.025),
        SVC(gamma=2, C=1),
        DecisionTreeClassifier(max_depth=5),
        RandomForestClassifier(max_depth=5, n_estimators=10, max_features=1),
        AdaBoostClassifier(),
        GaussianNB(),
        LDA(),
        QDA()]

    X = a[:, :3]            # three features
    y = np.ravel(a[:, 13])  # target column
    linearly_separable = (X, y)
    datasets = [linearly_separable]

    figure = plt.figure(figsize=(27, 9))
    i = 1
    for ds in datasets:
        X, y = ds
        X = StandardScaler().fit_transform(X)
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=.4)

        # mesh over the first two features only
        x_min, x_max = X[:, 0].min() - .5, X[:, 0].max() + .5
        y_min, y_max = X[:, 1].min() - .5, X[:, 1].max() + .5
        xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
                             np.arange(y_min, y_max, h))

        cm = plt.cm.RdBu
        cm_bright = ListedColormap(['#FF0000', '#0000FF'])
        ax = plt.subplot(len(datasets), len(classifiers) + 1, i)
        ax.scatter(X_train[:, 0], X_train[:, 1], c=y_train, cmap=cm_bright)
        ax.scatter(X_test[:, 0], X_test[:, 1], c=y_test, cmap=cm_bright, alpha=0.6)
        ax.set_xlim(xx.min(), xx.max())
        ax.set_ylim(yy.min(), yy.max())
        ax.set_xticks(())
        ax.set_yticks(())
        i += 1

        for name, clf in zip(names, classifiers):
            ax = plt.subplot(len(datasets), len(classifiers) + 1, i)
            print clf.fit(X_train, y_train)
            score = clf.score(X_test, y_test)
            print y.shape, X.shape
            if hasattr(clf, "decision_function"):
                Z = clf.decision_function(np.c_[xx.ravel(), yy.ravel()])
                print Z
            else:
                Z = clf.predict_proba(np.c_[xx.ravel(), yy.ravel()])[:, 1]
            Z = Z.reshape(xx.shape)
            ax.contourf(xx, yy, Z, cmap=cm, alpha=.8)
            ax.scatter(X_train[:, 0], X_train[:, 1], c=y_train, cmap=cm_bright)
            ax.scatter(X_test[:, 0], X_test[:, 1], c=y_test, cmap=cm_bright,
                       alpha=0.6)
            ax.set_xlim(xx.min(), xx.max())
            ax.set_ylim(yy.min(), yy.max())
            ax.set_xticks(())
            ax.set_yticks(())
            ax.set_title(name)
            ax.text(xx.max() - .3, yy.min() + .3, ('%.2f' % score).lstrip('0'),
                    size=15, horizontalalignment='right')
            i += 1

    figure.subplots_adjust(left=.02, right=.98)
    plt.show()

In this case I am using three features.

What am I doing wrong in the code? Is it something with the X_train and X_test data? With only two features, everything works fine.
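As far as I can tell, the grid that later gets passed to decision_function / predict_proba is built from only the first two columns of X, so it always has exactly two columns, no matter how many features the classifiers were fitted on. Here is a minimal sketch of the shape mismatch (this is separate from my script, using a small made-up 3-feature array):

    import numpy as np

    # hypothetical 3-feature matrix, mirroring the shapes in my data
    X = np.array([[1., 1., 0.],
                  [1., 0., 0.],
                  [3., 3., 0.],
                  [4., 4., 2.]])
    h = .02

    # the plotting mesh only spans the first two columns ...
    x_min, x_max = X[:, 0].min() - .5, X[:, 0].max() + .5
    y_min, y_max = X[:, 1].min() - .5, X[:, 1].max() + .5
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
                         np.arange(y_min, y_max, h))

    # ... so the query matrix has 2 columns, while the classifiers
    # were fitted on X, which has 3 columns
    grid = np.c_[xx.ravel(), yy.ravel()]
    print(grid.shape)   # (number of grid points, 2)
    print(X.shape)      # (4, 3)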

My X values:

(array([[ 1.,  1.,  0.],
        [ 1.,  0.,  0.],
        [ 1.,  0.,  0.],
        [ 1.,  0.,  0.],
        [ 1.,  1.,  0.],
        [ 1.,  0.,  0.],
        [ 1.,  0.,  0.],
        [ 3.,  3.,  0.],
        [ 1.,  1.,  0.],
        [ 1.,  1.,  0.],
        [ 0.,  0.,  0.],
        [ 0.,  0.,  0.],
        [ 0.,  0.,  0.],
        [ 0.,  0.,  0.],
        [ 0.,  0.,  0.],
        [ 0.,  0.,  0.],
        [ 4.,  4.,  2.],
        [ 0.,  0.,  0.],
        [ 6.,  3.,  0.],
        [ 5.,  3.,  2.],
        [ 2.,  2.,  0.],
        [ 4.,  4.,  2.],
        [ 2.,  1.,  0.],
        [ 2.,  2.,  0.]]),
 array([ 1.,  1.,  1.,  1.,  0.,  1.,  1.,  0.,  1.,  1.,  0.,  1.,  1.,
         1.,  1.,  1.,  0.,  1.,  1.,  0.,  1.,  0.,  1.,  1.]))

The first array is the X array and the second one is the y (target) array.

Sorry for the bad formatting. The error:

Traceback (most recent call last):
  File "allM.py", line 144, in <module>
    mainplot(namePlot,1,2)
  File "allM.py", line 117, in mainplot
    Z = clf.predict_proba(np.c_[xx.ravel(), yy.ravel()])[:, 1]
  File "/usr/local/lib/python2.7/dist-packages/sklearn/neighbors/classification.py", line 191, in predict_proba
    neigh_dist, neigh_ind = self.kneighbors(X)
  File "/usr/local/lib/python2.7/dist-packages/sklearn/neighbors/base.py", line 332, in kneighbors
    return_distance=return_distance)
  File "binary_tree.pxi", line 1298, in sklearn.neighbors.kd_tree.BinaryTree.query (sklearn/neighbors/kd_tree.c:10433)
ValueError: query data dimension must match training data dimension
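If I read the traceback correctly, the failure happens inside the neighbour query: the KD-tree was built on the 3-column training data but is being queried with the 2-column mesh. A standalone sketch (not my actual script) that should trigger the same kind of ValueError; the exact message may depend on the scikit-learn version:

    import numpy as np
    from sklearn.neighbors import KNeighborsClassifier

    # fit on 3 features ...
    X_train = np.array([[1., 1., 0.],
                        [1., 0., 0.],
                        [3., 3., 0.],
                        [4., 4., 2.]])
    y_train = np.array([1., 1., 0., 0.])
    clf = KNeighborsClassifier(1).fit(X_train, y_train)

    # ... then query with only 2 columns: the neighbour search rejects
    # the mismatch and raises a ValueError
    clf.predict_proba(np.array([[1., 1.], [2., 2.]]))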

This is the X array without putting it into the dataset "ds":

[[ 1.  1.  0.]
 [ 1.  0.  0.]
 [ 1.  0.  0.]
 [ 1.  0.  0.]
 [ 1.  1.  0.]
 [ 1.  0.  0.]
 [ 1.  0.  0.]
 [ 3.  3.  0.]
 [ 1.  1.  0.]
 [ 1.  1.  0.]
 [ 0.  0.  0.]
 [ 0.  0.  0.]
 [ 0.  0.  0.]
 [ 0.  0.  0.]
 [ 0.  0.  0.]
 [ 0.  0.  0.]
 [ 4.  4.  2.]
 [ 0.  0.  0.]
 [ 6.  3.  0.]
 [ 5.  3.  2.]
 [ 2.  2.  0.]
 [ 4.  4.  2.]
 [ 2.  1.  0.]
 [ 2.  2.  0.]]
