LVQ (Learning Vector Quantization) is a three-layer network $f:\R^n\rightarrow\{0,1\}^C$, composed of two mappings: $f=g\circ h$.

The first layer $h(\cdot;W):\R^n\rightarrow\{0,1\}^m$ is the competitive layer. Its parameter $W=(w_1,\dots,w_m)\in\R^{n\times m}$ consists of $m$ vectors of the same dimension as the input, with $m\geq C$; they act as cluster centers that partition the space (see the honeycomb figure in [1]). An input $x$ is first compared against these centers, giving distances $d\in\R^m$ with $d_i=\operatorname{dist}(x,w_i)$, and is assigned to the cluster of the nearest center; the output is a one-hot vector $z\in\{0,1\}^m$ with $z_i=1$ iff $x$ is closest to $w_i$. The "competition" works on two levels: in the forward pass only one $z_i=1$; in the backward pass only the corresponding $w_i$ gets updated.

The second layer $g:\{0,1\}^m\rightarrow\{0,1\}^C$ aggregates the competitive layer's output. Each output neuron is connected to only a few of the hidden outputs, and those weights are fixed at 1, i.e. this layer's mapping matrix is a constant (see $W^2$ in [1]). Think of each class as occupying several clusters in feature space: $x$ is assigned to class $c$ iff its hidden-layer winner is one of the clusters belonging to class $c$. The final output $\hat y\in\{0,1\}^C$ is therefore one-hot, not a soft probability vector.
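To make the two mappings concrete, here is a minimal NumPy sketch of one forward pass (the centers `W`, the constant matrix `Q`, and the input are toy values, not the trained model below):

import numpy as np
W = np.array([[0., 0.], [0., 1.], [5., 5.], [6., 5.]])  # m = 4 centers in R^2
Q = np.array([[1., 0.], [1., 0.], [0., 1.], [0., 1.]])  # constant second layer, [m, C]
x = np.array([5.2, 4.9])
d = np.linalg.norm(W - x, axis=1)  # distances to all m centers
z = np.eye(len(W))[d.argmin()]     # competitive layer h: one-hot winner
y_hat = z @ Q                      # second layer g: class one-hot
print(z, y_hat)                    # [0. 0. 1. 0.] [0. 1.]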
When updating $W$, only the center $w_c$ nearest to $x$ is updated:

$$w_c := \begin{cases} w_c+\eta_t(x-w_c), & \text{classification correct} \\ w_c-\eta_t(x-w_c), & \text{classification wrong} \end{cases}$$

where $\eta_t$ is the learning rate at step $t$, decaying over the iterations: $\eta_t=\eta\left(1-\frac{t}{\text{\#max iter}}\right)$.
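In code, one step of this rule for a single sample could look like the following sketch (`lvq_step` is an illustrative helper, not part of the implementation below):

def lvq_step(w_c, x, lr_t, correct):
    # attract the winning center if the prediction was right, repel it otherwise
    sign = 1.0 if correct else -1.0
    return w_c + sign * lr_t * (x - w_c)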
Using the absolute-value loss $l=|y-\hat y|$, and noting that the second layer's weights are constant 1 so the gradient passes through it unchanged, we get

$$\frac{\partial l}{\partial z_i}=\begin{cases} 1, & \hat y_i=1\wedge y_i=0 \\ -1, & \hat y_i=0\wedge y_i=1 \\ 0, & \text{otherwise (including correct classification)} \end{cases}$$

For the winning unit, $\frac{\partial l}{\partial z_c}$ can only take the values 0 or 1, so the update rule can be rewritten as $w_c := w_c-\eta_t\,(x-w_c)\left(2\cdot\frac{\partial l}{\partial z_c}-1\right)$.
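Plugging in the two possible values confirms that this single formula reproduces both cases of the classic rule:

$$\frac{\partial l}{\partial z_c}=0:\quad w_c := w_c-\eta_t(x-w_c)\cdot(-1)=w_c+\eta_t(x-w_c)\quad\text{(pull toward }x\text{)}$$

$$\frac{\partial l}{\partial z_c}=1:\quad w_c := w_c-\eta_t(x-w_c)\cdot(+1)=w_c-\eta_t(x-w_c)\quad\text{(push away from }x\text{)}$$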
import argparse
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # must be set BEFORE importing tensorflow
import tensorflow as tf
from tensorflow import keras as K
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
parser = argparse.ArgumentParser()
parser.add_argument("--n_class", type=int, default=10)
parser.add_argument("--n_hid_pc", type=int, default=10)
parser.add_argument("--lr", type=float, default=0.001)
parser.add_argument("--epoch", type=int, default=5)
parser.add_argument("--batch_size", type=int, default=64)
args = parser.parse_args()
# data
mnist = K.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data() # [n, 28, 28], [n]
x_train, x_test = x_train / 255.0, x_test / 255.0
print("label max:", tf.reduce_max(y_test))
x_train = tf.reshape(x_train, [-1, 784])
x_test = tf.reshape(x_test, [-1, 784])
y_train = tf.one_hot(y_train, 10)
y_test = tf.one_hot(y_test, 10)
print("data:", type(x_train), x_train.shape, y_test.shape)
train_ds = tf.data.Dataset.from_tensor_slices(
    (x_train, y_train)).shuffle(10000).batch(args.batch_size)
test_ds = tf.data.Dataset.from_tensor_slices(
    (x_test, y_test)).batch(args.batch_size)
def euclidean(A, B=None, sqrt=False):
    """Pairwise (squared) Euclidean distances between rows of A and rows of B."""
    if (B is None) or (B is A):
        aTb = tf.matmul(A, tf.transpose(A))
        aTa = bTb = tf.linalg.diag_part(aTb)
    else:
        aTb = tf.matmul(A, tf.transpose(B))
        aTa = tf.linalg.diag_part(tf.matmul(A, tf.transpose(A)))
        bTb = tf.linalg.diag_part(tf.matmul(B, tf.transpose(B)))
    D = aTa[:, None] - 2.0 * aTb + bTb[None, :]
    D = tf.maximum(D, 0.0)  # clip tiny negatives caused by floating-point error
    if sqrt:
        # avoid NaN gradients at sqrt(0): add eps, take sqrt, then zero those cells
        mask = tf.cast(tf.equal(D, 0.0), "float32")
        D = D + mask * 1e-16
        D = tf.math.sqrt(D)
        D = D * (1.0 - mask)
    return D
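# Quick sanity check for `euclidean` (illustrative, commented out):
#   a = tf.constant([[0., 0.], [3., 4.]])
#   b = tf.constant([[0., 0.]])
#   euclidean(a, b, sqrt=True)  # -> [[0.], [5.]]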
def top_k_mask(D, k, rand_pick=False):
    """M[i][j] = 1 <=> D[i][j] is one of the BIGGEST k in the i-th row
    Args:
        D: (n, n), distance matrix
        k: param `k` of kNN
        rand_pick: true or false
            - if `True`, only ONE of the top-k elements in each row is selected at random;
            - if `False`, ALL the top-k elements are selected as usual.
    Ref:
        - https://cloud.tencent.com/developer/ask/196899
        - https://blog.csdn.net/HackerTom/article/details/103587415
    """
    n_row = tf.shape(D)[0]
    n_col = tf.shape(D)[1]
    k_val, k_idx = tf.math.top_k(D, k)
    if rand_pick:
        c_idx = tf.random.uniform([n_row, 1],
                                  minval=0, maxval=k,
                                  dtype="int32")
        r_idx = tf.range(n_row, dtype="int32")[:, None]
        idx = tf.concat([r_idx, c_idx], axis=1)
        k_idx = tf.gather_nd(k_idx, idx)[:, None]
    idx_offset = (tf.range(n_row) * n_col)[:, None]
    k_idx_linear = k_idx + idx_offset
    k_idx_flat = tf.reshape(k_idx_linear, [-1, 1])
    updates = tf.ones_like(k_idx_flat[:, 0], "int32")
    mask = tf.scatter_nd(k_idx_flat, updates, [n_row * n_col])
    mask = tf.reshape(mask, [-1, n_col])
    mask = tf.cast(mask, "float32")
    return mask
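# Example of `top_k_mask` with k=1, marking the largest entry per row
# (illustrative, commented out):
#   D = tf.constant([[1., 5., 3.], [9., 2., 4.]])
#   top_k_mask(D, 1)  # -> [[0., 1., 0.], [1., 0., 0.]]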
@tf.custom_gradient
def lvq(X, W):
    """X: [n, d]
    W: [m, d]
    """
    D = euclidean(X, W)  # [n, m]
    y = top_k_mask(- D, 1)  # [n, m], minus so the NEAREST center wins
    def grad(dy):
        # dy: [n, m], gradient of the loss w.r.t. y
        mask_sgn = tf.expand_dims((2 * dy - 1) * y, 2)  # [n, m, 1], nonzero only for winners
        X_minus_W = tf.expand_dims(X, 1) - tf.expand_dims(W, 0)  # [n, m, d]
        dW = tf.reduce_sum(X_minus_W * mask_sgn, 0)
        # average over the samples that actually selected each center
        dW = dW / tf.maximum(1., tf.reduce_sum(tf.abs(mask_sgn), 0))
        return tf.zeros_like(X), dW  # X is input data, no gradient needed w.r.t. it
    return y, grad
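# Note how `grad` realizes the rewritten update rule from the derivation above:
# for the winning center (y == 1), (2*dy - 1) is -1 when the prediction was
# correct (dy == 0) and +1 when it was wrong (dy == 1), so SGD's W -= lr * dW
# attracts or repels the winner; all other rows of dW stay zero.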
class LVQ(K.Model):
    def __init__(self, dim, n_class, n_hid_pc):
        super(LVQ, self).__init__()
        self.W = tf.Variable(tf.random.truncated_normal(
            [n_class * n_hid_pc, dim]))
        # constant second layer: class i owns prototypes [i*n_hid_pc, (i+1)*n_hid_pc)
        Q = np.zeros([n_hid_pc * n_class, n_class])
        for i in range(n_class):
            Q[i * n_hid_pc:(i + 1) * n_hid_pc, i] = 1
        self.Q = tf.constant(Q, dtype="float32")

    def call(self, x):
        z = lvq(x, self.W)  # competitive layer: one-hot over prototypes
        return tf.matmul(z, self.Q)  # constant second layer: class one-hot
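# With e.g. n_class=3, n_hid_pc=2, Q is the constant [6, 3] matrix
#   [[1, 0, 0],
#    [1, 0, 0],
#    [0, 1, 0],
#    [0, 1, 0],
#    [0, 0, 1],
#    [0, 0, 1]]
# i.e. each class "owns" n_hid_pc consecutive prototypes.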
model = LVQ(x_train.shape[1], args.n_class, args.n_hid_pc)
class LVQ_Schedule(K.optimizers.schedules.LearningRateSchedule):
    """lr(n) = lr(0) * (1 - n / N)
    ref: https://github.com/tensorflow/tensorflow/blob/v2.2.0/tensorflow/python/keras/optimizer_v2/learning_rate_schedule.py#L410-L511
    """
    def __init__(self, initial_learning_rate, n_iter):
        super(LVQ_Schedule, self).__init__()
        self.initial_learning_rate = tf.convert_to_tensor(
            initial_learning_rate, dtype="float32")
        self.n_iter = tf.convert_to_tensor(n_iter, dtype="float32")

    def __call__(self, step):
        return self.initial_learning_rate * (1. - step / self.n_iter)

    def get_config(self):
        return {
            "initial_learning_rate": self.initial_learning_rate,
            "n_iter": self.n_iter
        }
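# Sanity check: LVQ_Schedule(0.1, 100)(50) == 0.05, decaying linearly to 0 at
# step n_iter. Note that Keras passes the optimizer's iteration count (batches,
# not epochs) as `step`, so n_iter should be the total number of training steps
# when this schedule is plugged into the (commented-out) SGD below.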
# optimizer = K.optimizers.SGD(
# learning_rate=LVQ_Schedule(args.lr, args.epoch), momentum=0.9)
optimizer = K.optimizers.Adam()
train_loss = K.metrics.Mean(name='train_loss')
train_accuracy = K.metrics.CategoricalAccuracy(name='train_accuracy')
test_loss = K.metrics.Mean(name='test_loss')
test_accuracy = K.metrics.CategoricalAccuracy(name='test_accuracy')
#@tf.function
def train_step(images, labels):
    with tf.GradientTape() as tape:
        pred = model(images)
        loss = tf.reduce_sum(tf.math.abs(labels - pred))
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    train_loss(loss)
    train_accuracy(labels, pred)
    return loss
#@tf.function
def test_step(images, labels):
    pred = model(images)
    t_loss = tf.reduce_sum(tf.math.abs(labels - pred))
    test_loss(t_loss)
    test_accuracy(labels, pred)
    pred = tf.argmax(pred, axis=1)
    true = tf.argmax(labels, axis=1)
    n_correct = tf.reduce_sum(tf.cast(pred == true, "float32"))
    return n_correct
loss_list, acc_list = [], []
for epoch in range(args.epoch):
    # reset the metrics at the start of each epoch
    train_loss.reset_states()
    train_accuracy.reset_states()
    test_loss.reset_states()
    test_accuracy.reset_states()
    for images, labels in train_ds:
        l = train_step(images, labels)
        loss_list.append(l.numpy())
    n_corr = 0
    for images, labels in test_ds:
        _n_corr = test_step(images, labels)
        n_corr += _n_corr.numpy()
    acc_list.append(n_corr / y_test.shape[0])
    template = 'Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}'
    print(template.format(epoch + 1,
                          train_loss.result(),
                          train_accuracy.result() * 100,
                          test_loss.result(),
                          test_accuracy.result() * 100))
# plot loss
fig = plt.figure()
plt.title("loss")
plt.plot(np.arange(len(loss_list)), loss_list)
# plt.show()
fig.savefig("loss.png")
# plot accuracy
fig = plt.figure()
plt.title("accuracy")
plt.plot(np.arange(len(acc_list)), acc_list)
# plt.show()
fig.savefig("accuracy.png")
# t-SNE visualization of the learned prototypes
F = TSNE(n_components=2).fit_transform(model.W.numpy())  # embed [m, 784] -> [m, 2]
w_label = tf.repeat(tf.range(args.n_class), args.n_hid_pc).numpy()
fig = plt.figure()
plt.title("T-SNE")
# plt.text does not autoscale the axes, so set the limits explicitly
plt.xlim(F[:, 0].min() - 1, F[:, 0].max() + 1)
plt.ylim(F[:, 1].min() - 1, F[:, 1].max() + 1)
for i in range(F.shape[0]):
    plt.text(F[i, 0], F[i, 1], str(w_label[i]),
             color=plt.cm.Set1(w_label[i] / 10.))
fig.savefig("tsne.png")