function [test_targets, a] = Perceptron(train_patterns, train_targets, test_patterns, alg_param)
% Classify using the Perceptron algorithm (fixed-increment single-sample perceptron).
%
% Inputs:
%   train_patterns - Train patterns, one column per sample
%   train_targets  - Train targets, one per sample; 0 marks the negative class,
%                    any other value is treated as the positive class
%   test_patterns  - Test patterns, one column per sample (same rows as train_patterns)
%   alg_param      - Either: number of iterations, a per-sample weights vector,
%                    or [weights, number of iterations]. Optional: when omitted
%                    or empty, unit weights and at most 5000 iterations are used.
%
% Outputs:
%   test_targets   - Predicted targets (logical row vector, true = positive class)
%   a              - Perceptron weights as a column vector; the last entry is the bias
%
% NOTE: Works for only two classes.
% NOTE: With exactly one training sample, a scalar alg_param has length equal
%       to the number of samples and is therefore interpreted as a weight.
%
% Example usage:
%   train_patterns = [-0.5, -0.5,  0.3, 0.1, -0.1, 0.8, 0.2, 0.3;
%                      0.3, -0.2, -0.6, 0.1, -0.5, 1.0, 0.3, 0.9];
%   train_targets  = [0, 0, 0, 1, 0, 1, 1, 1];
%   test_patterns  = [0.2 -0.3]'    % expected output: 0
%
%   alg_param = 100;                                                   % iterations only
%   alg_param = [0.01; 0.01; 0.0; 0.0; 0.01; 0.01; 0.01; 0.01];        % weights only
%   alg_param = [0.01; 0.01; 0.0; 0.0; 0.01; 0.01; 0.01; 0.01; 100];   % weights + iterations
%
%   [test_targets, a] = Perceptron(train_patterns, train_targets, test_patterns, alg_param)
%   plotpv(train_patterns, train_targets);   % plot the training points by class
%   a = a';
%   plotpc(a(1:end-1), a(end));              % plot the decision boundary

if nargin < 4
    alg_param = [];
end

r = size(train_patterns, 2);    % number of training samples

% Decode alg_param into per-sample step weights p and an iteration cap.
% MATLAB's switch takes the first matching case, so the order matters.
adaboost_form = false;
switch numel(alg_param)
    case r + 1
        % AdaBoost form: [weights, number of iterations]
        p             = alg_param(1:end-1);
        max_iter      = alg_param(end);
        adaboost_form = true;
    case 0
        % No parameter given
        p        = ones(1, r);
        max_iter = 5000;
    case r
        % Weights vector only: use the weights with the default iteration cap
        p        = alg_param;
        max_iter = 5000;
    case 1
        % Number of iterations given
        p        = ones(1, r);
        max_iter = alg_param;
    otherwise
        error('Perceptron:badParam', ...
              'alg_param must be empty, a scalar, or a vector of length %d or %d.', r, r + 1);
end

% Augment the patterns with a constant row of ones so the last weight acts as the bias.
train_patterns = [train_patterns; ones(1, r)];

% Preprocessing: negate the negative-class samples, so that sample k is
% correctly classified exactly when a' * y(:, k) > 0.
y = train_patterns;
is_zero = (train_targets == 0);
y(:, is_zero) = -y(:, is_zero);

% Initial weights: sum of the normalized samples. Summing along dim 2
% keeps a a column vector even when there is a single sample.
a = sum(y, 2);

% Iterate until every training sample has a strictly positive margin or
% the iteration cap is reached. The convergence test uses the same "<= 0"
% criterion as the update rule. It only changes when a changes, so it is
% re-evaluated only after an update.
misclassified = any(a' * y <= 0);
iter = 0;
while misclassified && (iter < max_iter)
    iter = iter + 1;
    k = randi(r);                        % pick one sample uniformly at random
    if a' * y(:, k) <= 0
        a = a + p(k) * y(:, k);          % fixed-increment update, scaled by the sample weight
        misclassified = any(a' * y <= 0);
    end
end

% Report non-convergence, except in the AdaBoost form (as in the original code).
if misclassified && ~adaboost_form
    disp(['Maximum iteration (' num2str(max_iter) ') reached']);
end

% Classify test patterns: positive class where a' * [x; 1] > 0.
test_targets = a' * [test_patterns; ones(1, size(test_patterns, 2))] > 0;
% Figure (感知机分类算法 — perceptron classification algorithm):
% http://img.e-com-net.com/image/product/d6858c023fc74b72bace1ea739f5fdc2.jpg