import math
import matplotlib.pyplot as plt
def load_data(path='./dataset.txt'):
    """Read one floating-point sample per line from a text file.

    Args:
        path: File to read. Defaults to ``'./dataset.txt'``, the location
            that used to be hard-coded.

    Returns:
        An ``Array`` wrapping the parsed samples in file order.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If a non-blank line is not a valid float literal.
    """
    samples = []
    with open(path, 'r') as fp:
        # Iterate the file object directly instead of loading every line
        # into memory first with readlines().
        for line in fp:
            text = line.strip()
            # Skip blank lines (e.g. an extra empty line at the end of the
            # file) instead of crashing on float('').
            if text:
                samples.append(float(text))
    return Array(samples)
class Array:
    """Minimal numeric vector with element-wise arithmetic.

    Binary operators return a new ``Array``. An ``int``/``float`` operand is
    applied to every element. Any other operand is indexed position by
    position for each element of this array, so an operand shorter than this
    array raises ``IndexError`` and surplus trailing items of a longer
    operand are ignored.
    """

    def __init__(self, value: list) -> None:
        # The list is stored by reference (not copied); ``size`` is captured
        # once, at construction time.
        self.value = value
        self.size = len(self.value)

    def __repr__(self):
        return 'Array({!r})'.format(self.value)

    def __getitem__(self, index):
        return self.value[index]

    def __len__(self):
        return self.size

    def __iter__(self):
        # Explicit iterator instead of relying on the implicit
        # __getitem__/IndexError fallback protocol.
        return iter(self.value)

    def _elementwise(self, other, op):
        """Return ``Array`` of ``op(own_item, other_item)`` for each position.

        ``other`` is broadcast when it is an ``int`` or ``float``; otherwise
        it is indexed with the position of each own element.
        """
        if isinstance(other, (int, float)):
            return Array([op(item, other) for item in self.value])
        return Array([op(item, other[i]) for i, item in enumerate(self.value)])

    def __add__(self, other):
        return self._elementwise(other, lambda mine, theirs: mine + theirs)

    def __radd__(self, other):
        return self.__add__(other)

    def __sub__(self, other):
        return self._elementwise(other, lambda mine, theirs: mine - theirs)

    def __rsub__(self, other):
        # Reflected form: the foreign operand is on the left-hand side.
        return self._elementwise(other, lambda mine, theirs: theirs - mine)

    def __mul__(self, other):
        return self._elementwise(other, lambda mine, theirs: mine * theirs)

    def __rmul__(self, other):
        return self.__mul__(other)

    def __truediv__(self, other):
        return self._elementwise(other, lambda mine, theirs: mine / theirs)

    def __rtruediv__(self, other):
        # Reflected form: the foreign operand is the numerator.
        return self._elementwise(other, lambda mine, theirs: theirs / mine)

    def __pow__(self, other):
        # The exponent is always applied as a single value to every element.
        return Array([item**other for item in self.value])
def sum(x):
    """Add up every item of the iterable ``x``, starting from integer 0.

    Note: this deliberately replaces the builtin of the same name for the
    rest of the module.
    """
    running = 0
    for term in x:
        running += term
    return running
def sum_0(x):
    """Sum a sequence of equally long rows position by position.

    The width is taken from the first row, so ``x`` must contain at least
    one row. Returns a plain list with one total per position.
    """
    width = len(x[0])
    totals = [0] * width
    # Rows are accumulated in order, so each position adds its terms in the
    # same sequence as a per-position loop over the rows would.
    for row in x:
        for col in range(width):
            totals[col] += row[col]
    return totals
def log(x):
    """Natural logarithm of a number, or element-wise over an iterable.

    An ``int``/``float`` argument yields a float; anything else is iterated
    and the per-item logarithms are returned wrapped in an ``Array``.
    """
    if not isinstance(x, (int, float)):
        return Array([math.log(item) for item in x])
    return math.log(x)
def sqrt(x):
    """Square root of a number, or element-wise over an iterable.

    An ``int``/``float`` argument yields a float; anything else is iterated
    and the per-item roots are returned wrapped in an ``Array``.
    """
    if not isinstance(x, (int, float)):
        return Array([math.sqrt(item) for item in x])
    return math.sqrt(x)
def abs(x):
    """Magnitude of a number, or element-wise over an iterable.

    A value is returned unchanged only when it is strictly positive; every
    other value (including zero) is negated. Non-scalar input is iterated
    and the results are returned wrapped in an ``Array``.

    Note: this deliberately replaces the builtin of the same name for the
    rest of the module.
    """
    if isinstance(x, (int, float)):
        if x > 0:
            return x
        return -x
    return Array([item if item > 0 else -item for item in x])
def norm_pdf(x: list, loc=0, scale=1):
    """Evaluate the normal density at every point of ``x``.

    Args:
        x: Iterable of evaluation points.
        loc: Mean of the distribution.
        scale: Standard deviation of the distribution (it is squared
            internally).

    Returns:
        An ``Array`` with one density value per point, in input order.

    Raises:
        ZeroDivisionError: If ``scale`` is 0.
    """
    # Both factors depend only on ``scale``; compute them once instead of
    # once per element. The arithmetic order matches the per-element form,
    # so the resulting values are unchanged.
    coefficient = 1 / (math.sqrt(2 * math.pi * scale**2))
    two_variance = 2 * scale**2
    return Array([
        coefficient * math.exp(-1 * ((point - loc)**2) / two_variance)
        for point in x
    ])
def E(d: Array, args: list):
    """E-step: share of each sample's weighted density owned by each component.

    Args:
        d: The samples.
        args: ``[a, m, s]`` - per-component weights, means and standard
            deviations (``s`` is passed to ``norm_pdf`` as ``scale``).

    Returns:
        A list with one ``Array`` per component. Entry ``k`` holds, for every
        sample, ``a[k] * pdf_k(sample)`` divided by the same quantity summed
        over all components.
    """
    a, m, s = args
    weighted = []
    for k, weight in enumerate(a):
        weighted.append(weight * norm_pdf(d, loc=m[k], scale=s[k]))
    # Per-sample normaliser: the weighted densities summed over components.
    normaliser = Array(sum_0(weighted))
    return [component / normaliser for component in weighted]
def M(d: Array, g: list):
    """M-step: re-estimate the mixture parameters from the per-sample shares.

    Args:
        d: The samples.
        g: One ``Array`` per component, as returned by ``E``.

    Returns:
        ``([a, m, s], q)`` where ``a`` are the new weights, ``m`` the new
        means, ``s`` the new standard deviations and ``q`` the value of the
        objective used by ``EM`` to detect convergence.
    """
    ks = range(len(g))
    # Sum of each component's shares; reused by every estimate below instead
    # of being recomputed for each one.
    mass = [sum(g[k]) for k in ks]
    m = [sum(g[k] * d) / mass[k] for k in ks]
    # Square root of the weighted mean squared deviation, i.e. s[k] is a
    # standard deviation, not a variance.
    s = [sqrt(sum(g[k] * ((d - m[k])**2)) / mass[k]) for k in ks]
    a = [mass[k] / len(d) for k in ks]
    log_norm = log(1 / sqrt(2 * math.pi))
    # Gaussian log-density terms written in terms of the standard deviation:
    # -log(sigma) - (x - mu)^2 / (2 * sigma^2). The previous expression used
    # log(sqrt(s)) and 1 / (2 * s), which is only valid if s were a variance.
    q = [
        mass[k] * log(a[k])
        + sum(g[k] * (log_norm - log(s[k]) - ((d - m[k])**2) / (2 * s[k]**2)))
        for k in ks
    ]
    q = sum(Array(q))
    return [a, m, s], q
def _print_mixture(a, m, s):
    """Print one ``GM[i] => ...`` line per mixture component."""
    for i in range(len(a)):
        print('GM[{}] => weight: {:.4f}, mu: {:.4f}, sigma: {:.4f}'.format(i, a[i], m[i], s[i]))


def EM(d: Array, args: list = None, tol=10e-8, max_rounds=None):
    """Alternate ``E`` and ``M`` steps until the objective stops changing.

    Args:
        d: The samples.
        args: Initial ``[weights, means, sigmas]``; required despite the
            ``None`` default.
        tol: Stop once the absolute change of the objective between two
            consecutive rounds drops below this value. The default literal
            ``10e-8`` equals ``1e-7``.
        max_rounds: Optional upper bound on the number of rounds; ``None``
            (default) means no bound.

    Returns:
        The final ``[weights, means, sigmas]``.

    Raises:
        TypeError: If ``args`` is ``None``.
    """
    if args is None:
        # Same exception type the tuple-unpacking below would raise, but with
        # a message that says what is missing.
        raise TypeError('EM() requires initial parameters: args=[weights, means, sigmas]')
    a, m, s = args
    print('\nInitial:\n')
    _print_mixture(a, m, s)
    q_old, q_new = 0, 0
    rounds = 0
    # At least two rounds are always run so that q_old and q_new both hold
    # real objective values before they are compared.
    while rounds <= 1 or abs(q_old - q_new) >= tol:
        if max_rounds is not None and rounds >= max_rounds:
            break
        args, q = M(d, E(d, args))
        q_old, q_new = q_new, q
        rounds += 1
        print('-------------------------')
    a, m, s = args
    print('\nPrediction:\n')
    _print_mixture(a, m, s)
    return args
def visual(args, data: Array = None):
    """Plot the fitted mixture over a histogram of the data.

    Draws one curve per weighted component, their sum (labelled
    'Prediction') and a density histogram of ``data`` (labelled 'True'),
    saves the figure to ``./em.png`` and then shows it.

    Args:
        args: ``[weights, means, sigmas]`` as returned by ``EM``.
        data: The samples; required despite the ``None`` default.
    """
    # Plot range: the data extremes truncated to integers.
    minx, maxx = int(min(data.value)), int(max(data.value))
    # Evaluation grid with step 0.1 starting at minx.
    x = [minx + 0.1 * i for i in range(int((maxx - minx) / 0.1))]
    a, m, s = args
    y_list = [a[k] * norm_pdf(x, loc=m[k], scale=s[k]) for k in range(len(a))]
    plt.figure()
    for i in range(len(a)):
        # Raw string: '\m' and '\s' are not valid Python escape sequences and
        # trigger a warning in a normal string literal.
        plt.plot(x, y_list[i].value,
                 label=r'$w={:.4f}, \mu={:.4f}, \sigma={:.4f}$'.format(a[i], m[i], s[i]))
    plt.plot(x, sum_0(y_list), label='Prediction')
    # One bin per integer unit, but never fewer than one: a bin count of 0
    # (data spanning less than one unit) is rejected by hist().
    bins = max(maxx - minx, 1)
    plt.hist(data.value, bins=bins, range=(minx, maxx), density=True, label='True', alpha=0.3)
    plt.legend(loc=2, prop={"size": 7})
    plt.savefig('./em.png')
    plt.show()
    plt.close()
if __name__ == '__main__':
    # Script entry point: load the samples, fit a three-component mixture
    # from a fixed starting guess, then plot the result.
    data = load_data()
    args = EM(
        data,
        args=[
            [0.3, 0.3, 0.4],  # initial weights
            [150, 170, 190],  # initial means
            [4, 4, 4],        # initial sigmas
        ],
    )
    visual(args, data)