import numpy as np
from sklearn.metrics.pairwise import pairwise_kernels, pairwise_distances, nan_euclidean_distances
from sklearn import datasets
from numpy.linalg import LinAlgError
from scipy.spatial.distance import pdist, squareform
# Load the Iris feature matrix and blank out four entries (one per feature
# column, each in a different row) so the code below has gaps to impute.
X, y = datasets.load_iris(return_X_y=True)
# Use np.nan: the np.NaN alias was removed in NumPy 2.0 and raises
# AttributeError there.
X[10, 1] = np.nan
X[100, 0] = np.nan
X[5, 3] = np.nan
X[50, 2] = np.nan
# DM = pairwise_distances(X=X, metric='nan_euclidean')
# print(DM.shape)
def EM(Xm, max_iter=200, tol=1e-5, verbose=True):
    """Estimate a mean vector and covariance matrix from data containing NaNs.

    Missing entries are first filled with the column means. Each iteration
    then (1) replaces every missing entry with its conditional mean given the
    observed entries of the same row under the current ``mu``/``sigma``,
    accumulating the matching conditional covariance in ``B``, and
    (2) re-estimates ``mu`` and ``sigma`` from the completed data, adding
    ``B / N`` to the biased sample covariance.

    Iteration stops when the Gaussian negative log-likelihood of the completed
    data changes by less than ``tol``, when the iteration cap is reached, or
    when a matrix that must be inverted/factorised is singular or not
    positive definite.

    Parameters
    ----------
    Xm : array-like of shape (N, n)
        Data with missing values encoded as NaN. It is not modified.
    max_iter : int, default 200
        Iteration cap. The counter starts at 1 and the loop runs while it is
        below ``max_iter``, so at most ``max_iter - 1`` passes are made.
    tol : float, default 1e-5
        Absolute change in negative log-likelihood that counts as converged.
    verbose : bool, default True
        When true, print the initial column means and the per-iteration
        likelihood values.

    Returns
    -------
    mu : ndarray of shape (n, 1)
        Estimated mean as a column vector.
    sigma : ndarray of shape (n, n)
        Estimated covariance.
    success : int
        1 on normal termination; 0 if a ``LinAlgError`` was raised while
        inverting an observed-block covariance, factorising ``sigma`` or
        inverting ``sigma``.
    """
    success = 1
    Xm = np.asarray(Xm, dtype=float)
    nan_mask = np.isnan(Xm)
    N, n = Xm.shape

    # Starting point: mean-impute every gap with its column mean.
    col_means = np.nanmean(Xm, axis=0)
    if verbose:
        print("mu:", col_means)
    Xi = np.where(nan_mask, col_means, Xm)

    mu = np.nanmean(Xi, axis=0)
    # atleast_2d keeps sigma indexable with np.ix_ when there is one column.
    sigma = np.atleast_2d(np.cov(Xi, rowvar=False, bias=True))

    # Only rows that actually have gaps need the conditional update.
    rows_with_gaps = np.flatnonzero(nan_mask.any(axis=1))
    log_2pi_term = n * np.log(2 * np.pi)

    nll = np.inf
    iters = 1
    while iters < max_iter:
        prev_nll = nll
        iters += 1
        B = np.zeros((n, n))
        try:
            for j in rows_with_gaps:
                mis = np.flatnonzero(nan_mask[j])
                obs = np.flatnonzero(~nan_mask[j])
                # Regression coefficients of the missing block on the observed
                # block, computed once and shared by both updates below.
                gain = np.dot(sigma[np.ix_(mis, obs)],
                              np.linalg.inv(sigma[np.ix_(obs, obs)]))
                B[np.ix_(mis, mis)] += (sigma[np.ix_(mis, mis)]
                                        - np.dot(gain, sigma[np.ix_(obs, mis)]))
                Xi[j, mis] = mu[mis] + np.dot(gain, Xi[j, obs] - mu[obs])

            mu = np.nanmean(Xi, axis=0)
            sigma = np.atleast_2d(np.cov(Xi, rowvar=False, bias=True)) + B / N

            # Cholesky doubles as the positive-definiteness check.
            U = np.linalg.cholesky(sigma)
            sigma_inv = np.linalg.inv(sigma)
        except np.linalg.LinAlgError:
            success = 0
            break

        centred = Xi - mu
        log_det = 2 * np.sum(np.log(np.diag(U)))
        # Row-wise quadratic form; avoids materialising an N x N matrix only
        # to read its diagonal.
        maha = np.sum(np.dot(centred, sigma_inv) * centred, axis=1)
        nll = np.sum(0.5 * (log_det + maha + log_2pi_term))
        if verbose:
            print("r1:", nll, "r:", prev_nll)
        if abs(nll - prev_nll) < tol:
            break

    mu = mu.reshape(-1, 1)
    return mu, sigma, success
# Example usage
# Xm = ... # Input data set with missing values
# mu, sigma, success = EM(Xm)
def getcovariances(M, missing, notmissing):
    """Split matrix ``M`` into four sub-blocks selected by two index sets.

    Each block is what remains of ``M`` after deleting one index set from the
    rows and one from the columns:

    * block 0: ``notmissing`` removed from rows and from columns
    * block 1: ``notmissing`` removed from rows, ``missing`` from columns
    * block 2: ``missing`` removed from rows, ``notmissing`` from columns
    * block 3: ``missing`` removed from rows and from columns

    Returns the four blocks as a list in that order; ``M`` is left untouched.
    """
    matrix = np.asarray(M)
    all_rows = np.arange(matrix.shape[0])
    all_cols = np.arange(matrix.shape[1])

    # Indices that survive deleting each set, per axis.
    rows_minus_observed = np.delete(all_rows, notmissing)
    rows_minus_missing = np.delete(all_rows, missing)
    cols_minus_observed = np.delete(all_cols, notmissing)
    cols_minus_missing = np.delete(all_cols, missing)

    layout = (
        (rows_minus_observed, cols_minus_observed),
        (rows_minus_observed, cols_minus_missing),
        (rows_minus_missing, cols_minus_observed),
        (rows_minus_missing, cols_minus_missing),
    )
    # Fancy indexing returns fresh arrays, so the caller's matrix is not aliased.
    return [matrix[np.ix_(rows, cols)] for rows, cols in layout]
def ecmnmlefunc(X):
    """Impute NaN entries of ``X`` and report the variance of each imputation.

    ``EM`` is used to estimate a mean vector and covariance matrix. Every row
    that has both missing and observed entries then gets its missing entries
    replaced by their conditional mean given the observed ones, and the
    diagonal of the matching conditional covariance is stored in ``sx``.

    Parameters
    ----------
    X : array-like of shape (N, n)
        Data with missing values encoded as NaN. The input is not modified;
        all work is done on a float copy.

    Returns
    -------
    Xi : ndarray of shape (N, n)
        Copy of ``X`` with imputed values filled in. Rows that are entirely
        NaN are left as NaN. If ``EM`` reports failure, ``Xi`` is an
        un-imputed copy of ``X``.
    sx : ndarray of shape (N, n)
        Conditional variance of each imputed entry; zero everywhere else.
    success : int
        1 if ``X`` had no NaNs or ``EM`` succeeded, otherwise the falsy flag
        returned by ``EM``.
    """
    # Work on a copy: imputing the caller's array in place would make a second
    # call on the same data silently see "no missing values".
    Xi = np.array(X, dtype=float)
    sx = np.zeros(Xi.shape)
    nan_mask = np.isnan(Xi)
    if not nan_mask.any():
        return Xi, sx, 1

    # Suppress floating-point warnings only for this computation instead of
    # changing NumPy's error state for the whole process.
    with np.errstate(all='ignore'):
        Mu, Covariance, success = EM(Xi.copy())
        if not success:
            return Xi, sx, success

        # EM returns the mean as a column vector; flatten it so rows with
        # several missing entries can be assigned through a 1-D index.
        Mu = np.asarray(Mu, dtype=float).ravel()

        for i in np.flatnonzero(nan_mask.any(axis=1)):
            missing = np.flatnonzero(nan_mask[i])
            notmissing = np.flatnonzero(~nan_mask[i])
            if len(notmissing) == 0:
                # Nothing observed to condition on; leave the row as NaN.
                continue
            Cov = getcovariances(Covariance, missing, notmissing)
            # inv(Cov22) @ Cov21, shared by the mean and the variance update.
            weights = np.dot(np.linalg.inv(Cov[3]), Cov[2])
            Xi[i, missing] = Mu[missing] + np.dot(
                weights.T, Xi[i, notmissing] - Mu[notmissing])
            cov1 = Cov[0] - np.dot(Cov[1], weights)
            sx[i, missing] = np.diag(cov1)

    return Xi, sx, success
# mu, sigma, success = EM(Xm=X)
# print("mu:",mu)
# print("sigma:",sigma)
# print("success:",success)
# Xi, sx, success = ecmnmlefunc(X=X)
# print("sx",sx)
# print("success:",success)
def ESD(X):
    """Return condensed pairwise distances that account for imputation variance.

    ``X`` is first passed through ``ecmnmlefunc``. For every pair of rows
    ``(i, j)`` with ``i < j`` the result is
    ``sqrt(d_ij**2 + s_i + s_j)``, where ``d_ij`` is the Euclidean distance
    between the imputed rows and ``s_k`` is the sum of row ``k`` of the
    variance array returned by ``ecmnmlefunc``.

    The success flag from ``ecmnmlefunc`` is ignored, so if imputation failed
    and NaNs remain, the affected distances are NaN.

    Returns
    -------
    D : ndarray of shape (N * (N - 1) / 2,)
        Distances in the same condensed order as ``pdist``; expand with
        ``scipy.spatial.distance.squareform`` if a square matrix is needed.
    """
    X, sx, _ = ecmnmlefunc(X)
    row_var = np.sum(sx, axis=1)
    # Upper-triangle index pairs in row-major order, which is the order pdist
    # uses for its condensed output.
    first, second = np.triu_indices(X.shape[0], k=1)
    pair_var = row_var[first] + row_var[second]
    D = np.sqrt((pdist(X) ** 2) + pair_var)
    return D
# Compute the condensed distances once and reuse them: each ESD call runs a
# full EM fit, and a repeated call is not guaranteed to see the same input if
# the first call filled in the missing entries of X.
esd_condensed = ESD(X)
print("ESD::::::", esd_condensed.shape)
dist_matrix = squareform(esd_condensed)
print(dist_matrix.shape)