1 importnumpy as np2 importmatplotlib.pyplot as plt3 from .plot_helpers importcm2, cm3, discrete_scatter4
def _call_classifier_chunked(classifier_pred_or_decide, X, chunk_size=10000):
    """Apply a classifier method to ``X`` in row-wise chunks.

    The rows of ``X`` are split into consecutive chunks of at most
    ``chunk_size`` rows, ``classifier_pred_or_decide`` is called once per
    chunk, and the per-chunk results are concatenated along the first axis.
    Chunking is valid because the result for one row does not depend on the
    other rows.

    Why chunk at all: large arrays otherwise hit the < 2 GB allocation limit
    of x86 memory models. The default of 10000 rows is based on a measurement
    with ``MLPClassifier(solver='lbfgs', random_state=0,
    hidden_layer_sizes=[1000, 1000, 1000])``. Reducing the value trades time
    for memory. An intermediate version skipped chunking on 64 bit
    architectures, but testing showed that chunking is 3-5 times faster even
    there, largely because it avoids memory swapping.

    Parameters
    ----------
    classifier_pred_or_decide : callable
        Called with a chunk of rows of ``X``; must return an array whose
        first axis has one entry per row of the chunk.
    X : ndarray
        Input samples, one per row.
    chunk_size : int, default=10000
        Maximum number of rows handed to the callable at once.

    Returns
    -------
    ndarray
        The concatenated results for all rows of ``X``.

    Raises
    ------
    ValueError
        If ``chunk_size`` is smaller than 1.
    """
    if chunk_size < 1:
        raise ValueError("chunk_size must be a positive integer")

    # Row indices at which X is cut. An empty index array (X has at most
    # chunk_size rows) makes array_split return X as a single chunk.
    split_points = np.arange(chunk_size, X.shape[0], chunk_size,
                             dtype=np.intp)

    return np.concatenate([
        classifier_pred_or_decide(x_chunk)
        for x_chunk in np.array_split(X, split_points, axis=0)
    ])
33
def plot_2d_classification(classifier, X, fill=False, ax=None, eps=None,
                           alpha=1, cm=cm3):
    """Draw the predicted class of ``classifier`` over a 2D region as an image.

    The region spans the range of the first two columns of ``X``, widened by
    ``eps`` on every side, and is sampled on a 1000 x 1000 grid. Intended for
    multiclass classifiers.

    Parameters
    ----------
    classifier : object
        Must provide ``predict``.
    X : ndarray
        Data whose first two columns define the plotted region.
    fill : bool, default=False
        Accepted for signature compatibility; not used by this function.
    ax : matplotlib axes, optional
        Axes to draw on. Defaults to ``plt.gca()``.
    eps : float, optional
        Margin added around the data range. Defaults to ``X.std() / 2``.
    alpha : float, default=1
        Opacity passed to ``imshow``.
    cm : colormap, default=cm3
        Colormap passed to ``imshow``.
    """
    if eps is None:
        eps = X.std() / 2.

    if ax is None:
        ax = plt.gca()

    x_min, x_max = X[:, 0].min() - eps, X[:, 0].max() + eps
    y_min, y_max = X[:, 1].min() - eps, X[:, 1].max() + eps
    xx = np.linspace(x_min, x_max, 1000)
    yy = np.linspace(y_min, y_max, 1000)

    X1, X2 = np.meshgrid(xx, yy)
    X_grid = np.c_[X1.ravel(), X2.ravel()]
    # The grid has 1,000,000 rows. Predict in chunks, as plot_2d_separator
    # does for its grid of the same size, instead of one huge predict call.
    decision_values = _call_classifier_chunked(classifier.predict, X_grid)
    ax.imshow(decision_values.reshape(X1.shape),
              extent=(x_min, x_max, y_min, y_max),
              aspect='auto', origin='lower', alpha=alpha, cmap=cm)
    ax.set_xlim(x_min, x_max)
    ax.set_ylim(y_min, y_max)
    ax.set_xticks(())
    ax.set_yticks(())
59
def plot_2d_scores(classifier, X, ax=None, eps=None, alpha=1, cm="viridis",
                   function=None):
    """Draw a classifier's score over a 2D region as an image.

    The region spans the range of the first two columns of ``X``, widened by
    ``eps`` on every side, and is sampled on a 100 x 100 grid. Intended for
    binary classifiers.

    Parameters
    ----------
    classifier : object
        Must provide the scoring method selected by ``function``.
    X : ndarray
        Data whose first two columns define the plotted region.
    ax : matplotlib axes, optional
        Axes to draw on. Defaults to ``plt.gca()``.
    eps : float, optional
        Margin added around the data range. Defaults to ``X.std() / 2``.
    alpha : float, default=1
        Opacity passed to ``imshow``.
    cm : colormap or str, default="viridis"
        Colormap passed to ``imshow``.
    function : str, optional
        Name of the classifier method to call on the grid. When omitted,
        ``decision_function`` is used if the classifier has it, otherwise
        ``predict_proba``.

    Returns
    -------
    The object returned by ``ax.imshow``.

    Raises
    ------
    AttributeError
        If the requested (or fallback) method does not exist on
        ``classifier``.
    """
    if eps is None:
        eps = X.std() / 2.

    if ax is None:
        ax = plt.gca()

    x_min, x_max = X[:, 0].min() - eps, X[:, 0].max() + eps
    y_min, y_max = X[:, 1].min() - eps, X[:, 1].max() + eps
    xx = np.linspace(x_min, x_max, 100)
    yy = np.linspace(y_min, y_max, 100)

    X1, X2 = np.meshgrid(xx, yy)
    X_grid = np.c_[X1.ravel(), X2.ravel()]

    if function is None:
        # Only touch predict_proba when decision_function is missing.
        # A nested getattr default would be evaluated eagerly and raise
        # for classifiers that have decision_function but no predict_proba.
        if hasattr(classifier, "decision_function"):
            function = classifier.decision_function
        else:
            function = classifier.predict_proba
    else:
        function = getattr(classifier, function)

    decision_values = function(X_grid)
    if decision_values.ndim > 1 and decision_values.shape[1] > 1:
        # predict_proba style output: keep the column for the second class
        decision_values = decision_values[:, 1]

    grr = ax.imshow(decision_values.reshape(X1.shape),
                    extent=(x_min, x_max, y_min, y_max), aspect='auto',
                    origin='lower', alpha=alpha, cmap=cm)

    ax.set_xlim(x_min, x_max)
    ax.set_ylim(y_min, y_max)
    ax.set_xticks(())
    ax.set_yticks(())
    return grr
95
def plot_2d_separator(classifier, X, fill=False, ax=None, eps=None, alpha=1,
                      cm=cm2, linewidth=None, threshold=None,
                      linestyle="solid"):
    """Draw the decision boundary of a (binary) classifier over a 2D region.

    The region spans the range of the first two columns of ``X``, widened by
    ``eps`` (default ``X.std() / 2``) on every side, and is sampled on a
    1000 x 1000 grid. Scores come from ``decision_function`` when the
    classifier has one (boundary at 0), otherwise from column 1 of
    ``predict_proba`` (boundary at 0.5); ``threshold`` overrides the
    boundary level. With ``fill=True`` the two sides are drawn as filled
    contours using ``cm``; otherwise only the boundary is drawn as a black
    line using ``linewidth`` and ``linestyle``. Draws on ``ax``, or on
    ``plt.gca()`` when ``ax`` is None.
    """
    margin = X.std() / 2. if eps is None else eps
    axes = plt.gca() if ax is None else ax

    first, second = X[:, 0], X[:, 1]
    x_min = first.min() - margin
    x_max = first.max() + margin
    y_min = second.min() - margin
    y_max = second.max() + margin

    grid_x, grid_y = np.meshgrid(np.linspace(x_min, x_max, 1000),
                                 np.linspace(y_min, y_max, 1000))
    grid_points = np.c_[grid_x.ravel(), grid_y.ravel()]

    if hasattr(classifier, "decision_function"):
        scores = _call_classifier_chunked(classifier.decision_function,
                                          grid_points)
        default_level = 0
        lowest, highest = scores.min(), scores.max()
    else:
        # no decision_function: use the probability of the second class
        scores = _call_classifier_chunked(classifier.predict_proba,
                                          grid_points)[:, 1]
        default_level = .5
        lowest, highest = 0, 1

    levels = [default_level if threshold is None else threshold]
    fill_levels = [lowest] + levels + [highest]
    surface = scores.reshape(grid_x.shape)

    if not fill:
        axes.contour(grid_x, grid_y, surface, levels=levels,
                     colors="black", alpha=alpha, linewidths=linewidth,
                     linestyles=linestyle, zorder=5)
    else:
        axes.contourf(grid_x, grid_y, surface, levels=fill_levels,
                      alpha=alpha, cmap=cm)

    axes.set_xlim(x_min, x_max)
    axes.set_ylim(y_min, y_max)
    axes.set_xticks(())
    axes.set_yticks(())
138
if __name__ == '__main__':
    # Demo: fit a logistic regression on two blobs and show its filled
    # decision regions together with the training points.
    from sklearn.linear_model import LogisticRegression
    from sklearn.datasets import make_blobs

    points, labels = make_blobs(centers=2, random_state=42)
    model = LogisticRegression(solver='lbfgs')
    model.fit(points, labels)
    plot_2d_separator(model, points, fill=True)
    discrete_scatter(points[:, 0], points[:, 1], labels)
    plt.show()