Kaldi’s PLDA implementation is based on [1], which is the so-called two-covariance PLDA in the terminology of [2]. The authors derive a clean update formula for EM training and give a detailed comment in the source code. Here we add some explanations to make the derivation easier to follow.
A pdf version of this note can be found here
1. Background
Recall that PLDA assumes a two-stage generative process:
1) generate the class center according to
$$y \sim N(\mu, \Phi_b)$$
2) then, generate the observed data by:
$$x \sim N(y, \Phi_w)$$
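To make the generative assumptions concrete, here is a minimal numpy sketch that draws data from this two-stage model; `mu`, `Phi_b`, `Phi_w` and the class/sample counts are toy values chosen purely for illustration, not anything taken from Kaldi:

```python
# Toy sketch of the two-stage PLDA generative process (illustrative values only).
import numpy as np

rng = np.random.default_rng(0)
dim = 2
mu = np.zeros(dim)
Phi_b = np.array([[2.0, 0.3], [0.3, 1.0]])   # between-class covariance
Phi_w = np.array([[0.5, 0.0], [0.0, 0.2]])   # within-class covariance

K, n = 4, 100                                # number of classes, samples per class
data = []
for k in range(K):
    y_k = rng.multivariate_normal(mu, Phi_b)           # 1) sample the class center
    z_k = rng.multivariate_normal(y_k, Phi_w, size=n)  # 2) sample the observations
    data.append(z_k)
```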
Here, $\mu$ is estimated by the global mean value:
$$\mu = \frac{1}{N}\sum_{k=1}^{K}\sum_{i=1}^{n_k} z_{ki}$$
where $z_{ki}$ denotes the $i$-th sample of the $k$-th class, $n_k$ is the number of samples in class $k$, and $N = \sum_k n_k$ is the total number of samples.
So let’s turn to the estimation of $\Phi_b$ and $\Phi_w$.
Note that, since $\mu$ is fixed, we can simply remove it from all samples. Hereafter, we assume all samples have been pre-processed by subtracting $\mu$ from them.
The prior distribution of an arbitrary sample $z$ is:
$$p(z) = N(z \mid 0, \Phi_b + \Phi_w)$$
Let’s suppose the mean of a particular class is $m$, and that the class has $n$ examples. Then
$$m = \frac{1}{n}\sum_{i=1}^{n} z_i \sim N\!\left(0,\; \Phi_b + \frac{\Phi_w}{n}\right)$$
i.e. $m$ is Gaussian-distributed with zero mean and variance equal to the between-class variance plus $1/n$ times the within-class variance. Now, $m$ is observed (it is the average of the observed samples).
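As a quick sanity check of this claim, the following sketch (again with toy parameters, and with the global mean already removed, i.e. $\mu = 0$) generates many classes of $n$ samples each and compares the empirical covariance of the class means with $\Phi_b + \Phi_w/n$:

```python
# Empirical check that class means are distributed as N(0, Phi_b + Phi_w / n).
import numpy as np

rng = np.random.default_rng(1)
dim, n, num_classes = 2, 20, 5000
Phi_b = np.array([[2.0, 0.3], [0.3, 1.0]])
Phi_w = np.array([[0.5, 0.0], [0.0, 0.2]])

means = []
for _ in range(num_classes):
    y = rng.multivariate_normal(np.zeros(dim), Phi_b)   # class center
    z = rng.multivariate_normal(y, Phi_w, size=n)       # n observations of this class
    means.append(z.mean(axis=0))                        # class mean m
means = np.asarray(means)

print(np.cov(means.T))       # empirical covariance of m
print(Phi_b + Phi_w / n)     # should be close to the line above
```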
2. EM
We’re doing an E-M procedure where we treat $m$ as the sum of two variables:
$$m = x + y$$
where $x \sim N(0, \Phi_b)$ and $y \sim N(0, \Phi_w / n)$.
The distribution of $x$ will contribute to the stats of $\Phi_b$, and $y$ to $\Phi_w$.
2.1 E Step
Note that given $m$, there is only one latent variable in effect. Observe that $y = m - x$, so we can focus on working out the posterior distribution of $x$; the distribution of $y$ then follows immediately.
Given $m$, the posterior distribution of $x$ is:
$$p(x \mid m) \propto p(m \mid x)\, p_x(x) = p_y(m - x)\, p_x(x)$$
since, given $x$, observing $m$ is the same as observing $y = m - x$.
Hereafter, we drop the conditioning on $m$ for brevity.
$$p(x) \propto p_x(x)\, p_y(m - x) = N(x \mid 0, \Phi_b)\, N(x \mid m, \Phi_w / n)$$
Since the product of two Gaussians is (up to normalization) again a Gaussian, we get:
$$p(x) = N(x \mid w, \hat{\Phi})$$
where
$$\hat{\Phi} = (\Phi_b^{-1} + n\Phi_w^{-1})^{-1} \quad\text{and}\quad w = \hat{\Phi}\, n\Phi_w^{-1} m.$$
$\hat{\Phi}$ and $w$ can be obtained by comparing the first- and second-order coefficients with the standard form of the log Gaussian, as Kaldi’s comment does:
Note: the constant $C$ is different from line to line.
$$\ln p(x) = C - 0.5\left(x^T \Phi_b^{-1} x + (m - x)^T n\Phi_w^{-1} (m - x)\right) = C - 0.5\, x^T(\Phi_b^{-1} + n\Phi_w^{-1})\, x + x^T z$$
where $z = n\Phi_w^{-1} m$, and we can write this as:
$$\ln p(x) = C - 0.5\,(x - w)^T(\Phi_b^{-1} + n\Phi_w^{-1})(x - w)$$
where $x^T(\Phi_b^{-1} + n\Phi_w^{-1})\, w = x^T z$, i.e.
$$(\Phi_b^{-1} + n\Phi_w^{-1})\, w = z = n\Phi_w^{-1} m,$$
so
$$w = (\Phi_b^{-1} + n\Phi_w^{-1})^{-1}\, n\Phi_w^{-1} m$$
$$\hat{\Phi} = (\Phi_b^{-1} + n\Phi_w^{-1})^{-1}$$
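Put together, the E step for a single class reduces to two matrix operations. The sketch below follows the notation of this derivation (it is not Kaldi’s actual C++ code); `m` is the class mean with the global mean already removed:

```python
# E step for one class: posterior mean w and covariance Phi_hat of x given m.
import numpy as np

def posterior_of_x(m, Phi_b, Phi_w, n):
    Phi_b_inv = np.linalg.inv(Phi_b)
    Phi_w_inv = np.linalg.inv(Phi_w)
    Phi_hat = np.linalg.inv(Phi_b_inv + n * Phi_w_inv)  # posterior covariance
    w = Phi_hat @ (n * Phi_w_inv @ m)                   # posterior mean
    return w, Phi_hat
```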
2.2 M Step
The objective function of the EM update (for the part involving $\Phi_b$) is:
$$Q = E_x[\ln p_x(x)] = E_x\!\left[-0.5\ln|\Phi_b| - 0.5\, x^T \Phi_b^{-1} x\right] = -0.5\ln|\Phi_b| - 0.5\,\mathrm{tr}\!\left(E_x[x x^T]\, \Phi_b^{-1}\right)$$
The derivative w.r.t. $\Phi_b$ is:
$$\frac{\partial Q}{\partial \Phi_b} = -0.5\, \Phi_b^{-1} + 0.5\, \Phi_b^{-1} E[x x^T]\, \Phi_b^{-1}$$
Setting it to zero, we have:
$$\hat{\Phi}_b = E_x[x x^T] = \hat{\Phi} + E_x[x]\,E_x[x]^T = \hat{\Phi} + w w^T$$
Similarly, we have:
$$\hat{\Phi}_w / n = E_y[y y^T] = \hat{\Phi} + E_y[y]\,E_y[y]^T = \hat{\Phi} + (w - m)(w - m)^T$$
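In code, the per-class second-order statistics implied by these two formulas look like the following sketch (names follow the text above, not Kaldi’s source):

```python
# Per-class M-step statistics: E[x x^T] contributes to Phi_b,
# and n * E[y y^T] contributes to Phi_w.
import numpy as np

def per_class_stats(m, w, Phi_hat, n):
    stat_b = Phi_hat + np.outer(w, w)                  # E[x x^T]
    stat_w = n * (Phi_hat + np.outer(w - m, w - m))    # n * E[y y^T]
    return stat_b, stat_w
```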
3. Summary
To recap: given the samples of a certain class, we can calculate the following statistics:
$$\hat{\Phi} = (\Phi_b^{-1} + n\Phi_w^{-1})^{-1}$$
$$w = \hat{\Phi}\, n\Phi_w^{-1} m$$
$$\hat{\Phi}_b = \hat{\Phi} + w w^T$$
$$\hat{\Phi}_w = n\left(\hat{\Phi} + (w - m)(w - m)^T\right)$$
Given $K$ classes, the updated estimates in one EM iteration are:
$$\Phi_w = \frac{1}{N}\sum_k n_k\left(\hat{\Phi}_k + (w_k - m_k)(w_k - m_k)^T\right)$$
$$\Phi_b = \frac{1}{K}\sum_k \left(\hat{\Phi}_k + w_k w_k^T\right)$$
Finally, Kaldi uses the following update formula for $\Phi_w$:
$$\Phi_w = \frac{1}{2N - K}\left(S + \sum_k n_k\left(\hat{\Phi}_k + (w_k - m_k)(w_k - m_k)^T\right)\right)$$
where $S = \sum_k \sum_i (z_{ki} - c_k)(z_{ki} - c_k)^T$ is the within-class scatter matrix, and $c_k = \frac{1}{n_k}\sum_i z_{ki}$ is the mean of the samples of the $k$-th class.
Note that $S$ has to be added on top of the EM statistics derived above, since the decomposition $m = x + y$ only takes the pooled class means into account; the scatter of the samples around each class mean contributes to $\Phi_w$ directly.
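Assembling everything, one EM iteration can be sketched as follows. This assumes the data comes as a list `classes` of $(n_k, \text{dim})$ arrays with the global mean already removed, and follows the formulas above (Kaldi’s own implementation additionally supports per-class weights, which are omitted here):

```python
# One EM iteration over all classes (illustrative sketch, not Kaldi's code).
import numpy as np

def em_iteration(classes, Phi_b, Phi_w):
    dim = Phi_b.shape[0]
    N = sum(z.shape[0] for z in classes)   # total number of samples
    K = len(classes)                       # number of classes
    S = np.zeros((dim, dim))               # within-class scatter
    acc_b = np.zeros((dim, dim))           # accumulator for Phi_b stats
    acc_w = np.zeros((dim, dim))           # accumulator for Phi_w stats
    Phi_b_inv = np.linalg.inv(Phi_b)
    Phi_w_inv = np.linalg.inv(Phi_w)
    for z in classes:
        n = z.shape[0]
        c = z.mean(axis=0)                 # class mean (m in the text)
        S += (z - c).T @ (z - c)
        Phi_hat = np.linalg.inv(Phi_b_inv + n * Phi_w_inv)
        w = Phi_hat @ (n * Phi_w_inv @ c)
        acc_b += Phi_hat + np.outer(w, w)
        acc_w += n * (Phi_hat + np.outer(w - c, w - c))
    Phi_b_new = acc_b / K
    Phi_w_new = (S + acc_w) / (2 * N - K)
    return Phi_b_new, Phi_w_new
```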
For other variants of EM training, see [2] and the references therein.
References
- [1] S. Ioffe. Probabilistic Linear Discriminant Analysis. ECCV 2006.
- [2] A. Sizov, K. A. Lee, and T. Kinnunen. Unifying Probabilistic Linear Discriminant Analysis Variants in Biometric Authentication. 2014.