# Pure-Python implementation (no NumPy) of the EM algorithm for fitting a Gaussian mixture model (GMM) to one-dimensional data.

import math
import matplotlib.pyplot as plt


def load_data():
    """Read one float per line from ``./dataset.txt`` and wrap them in an Array.

    Blank or whitespace-only lines (for example an empty line at the end of
    the file) are skipped instead of aborting the whole load with a
    ``ValueError`` from ``float('')``.

    Returns:
        Array: the parsed samples, in file order.

    Raises:
        OSError: if ``./dataset.txt`` cannot be opened.
        ValueError: if a non-blank line is not a valid float literal.
    """
    data = []
    with open('./dataset.txt', 'r') as fp:
        # Iterate the file object directly so the whole file is never held
        # in memory as a list of lines.
        for line in fp:
            text = line.strip()
            if text:
                data.append(float(text))
    return Array(data)


class Array:
    """Minimal list-backed vector supporting element-wise arithmetic.

    The right-hand (or, for reflected operators, left-hand) operand may be
    an ``int``/``float``, which is applied to every element, or any
    indexable object, which is paired with the elements position by
    position. Every operation returns a new ``Array``; the operands are
    left untouched.
    """

    def __init__(self, value: list) -> None:
        self.value = value
        # Length is captured once at construction time.
        self.size = len(self.value)

    def __getitem__(self, index):
        return self.value[index]

    def __len__(self):
        return self.size

    def _combine(self, other, op):
        """Build a new Array from ``op(element, operand)`` for each element."""
        if isinstance(other, (int, float)):
            return Array([op(item, other) for item in self.value])
        # Non-scalar operand: look it up by position; it must be at least as
        # long as this Array or indexing raises.
        return Array([op(item, other[pos]) for pos, item in enumerate(self.value)])

    def __add__(self, other):
        return self._combine(other, lambda mine, theirs: mine + theirs)

    def __radd__(self, other):
        return self.__add__(other)

    def __sub__(self, other):
        return self._combine(other, lambda mine, theirs: mine - theirs)

    def __rsub__(self, other):
        # Reflected form: the foreign operand is the minuend.
        return self._combine(other, lambda mine, theirs: theirs - mine)

    def __mul__(self, other):
        return self._combine(other, lambda mine, theirs: mine * theirs)

    def __rmul__(self, other):
        return self.__mul__(other)

    def __truediv__(self, other):
        return self._combine(other, lambda mine, theirs: mine / theirs)

    def __rtruediv__(self, other):
        # Reflected form: the foreign operand is the numerator.
        return self._combine(other, lambda mine, theirs: theirs / mine)

    def __pow__(self, other):
        # The exponent is always applied as-is to every element.
        return Array([item**other for item in self.value])


def sum(x):
    """Add up the items of ``x`` from left to right, starting from ``0``.

    Note that this module-level definition replaces the builtin ``sum``
    for the rest of the file. An empty iterable yields ``0``.
    """
    accumulator = 0
    for item in x:
        accumulator += item
    return accumulator


def sum_0(x):
    """Sum a sequence of rows column by column.

    The number of columns is taken from the first row, so ``x`` must hold
    at least one row and every row must be at least that long.

    Returns:
        list: one total per column.
    """
    width = len(x[0])
    column_totals = []
    for col in range(width):
        running = 0
        for row in x:
            running += row[col]
        column_totals.append(running)
    return column_totals


def log(x):
    """Natural logarithm of a number, or element-wise over an iterable.

    An ``int``/``float`` argument returns a plain float; anything else is
    iterated and the results are wrapped in an ``Array``.
    """
    if not isinstance(x, (int, float)):
        return Array([math.log(item) for item in x])
    return math.log(x)


def sqrt(x):
    """Square root of a number, or element-wise over an iterable.

    An ``int``/``float`` argument returns a plain float; anything else is
    iterated and the results are wrapped in an ``Array``.
    """
    if not isinstance(x, (int, float)):
        return Array([math.sqrt(item) for item in x])
    return math.sqrt(x)


def abs(x):
    """Absolute value of a number, or element-wise over an iterable.

    Note that this module-level definition replaces the builtin ``abs``
    for the rest of the file. Values that are not strictly positive are
    negated; strictly positive values are returned unchanged.
    """
    if not isinstance(x, (int, float)):
        return Array([item if item > 0 else -item for item in x])
    if x > 0:
        return x
    return -x


def norm_pdf(x: list, loc=0, scale=1):
    """Evaluate a normal density at every point of ``x``.

    Args:
        x: iterable of points.
        loc: mean of the distribution.
        scale: standard deviation of the distribution (it is squared in
            both the normalising constant and the exponent).

    Returns:
        Array: one density value per point, in input order.
    """

    def density(point):
        # Same operations, in the same order, as the textbook formula
        # 1 / sqrt(2*pi*scale^2) * exp(-(point - loc)^2 / (2*scale^2)).
        coefficient = 1 / (math.sqrt(2 * math.pi * scale**2))
        exponent = -1 * ((point - loc)**2) / (2 * scale**2)
        return coefficient * math.exp(exponent)

    return Array([density(point) for point in x])


def E(d: Array, args: list):
    """E-step: share each sample out among the mixture components.

    Args:
        d: the samples.
        args: ``[a, m, s]`` - per-component weights, means and scales
            (the scales are handed to ``norm_pdf`` as ``scale``).

    Returns:
        list: one Array per component; entry ``i`` is that component's
        weighted density at sample ``i`` divided by the sum of the weighted
        densities of all components at that sample.
    """
    weights, means, scales = args
    weighted = []
    for k in range(len(weights)):
        component_pdf = norm_pdf(d, loc=means[k], scale=scales[k])
        weighted.append(weights[k] * component_pdf)
    # Per-sample normaliser: the column-wise sum across components.
    evidence = Array(sum_0(weighted))
    return [weighted[k] / evidence for k in range(len(weights))]


def M(d: Array, g: list):
    """M-step: re-estimate the mixture parameters from the E-step output.

    Args:
        d: the samples.
        g: one Array per component, as returned by ``E``; ``g[k][i]`` is
            the share of sample ``i`` assigned to component ``k``.

    Returns:
        tuple: ``([a, m, s], q)`` where ``a`` are the component weights,
        ``m`` the means, ``s`` the standard deviations, and ``q`` the value
        of the objective that ``EM`` monitors for convergence.
    """
    components = range(len(g))
    # Total share received by each component; reused by every estimate below.
    n = [sum(g[k]) for k in components]
    m = [sum(g[k] * d) / n[k] for k in components]
    s = [sqrt(sum(g[k] * ((d - m[k])**2)) / n[k]) for k in components]
    a = [n[k] / len(d) for k in components]
    log_norm = log(1 / sqrt(2 * math.pi))
    # s[k] is a standard deviation (the square root is taken above and
    # norm_pdf consumes it as ``scale``), so the Gaussian log-density term is
    # -log(s) - (x - m)^2 / (2 * s^2). The previous code used log(sqrt(s))
    # and 1 / (2 * s), which treated s[k] as a variance.
    q = [
        n[k] * log(a[k])
        + sum(g[k] * (log_norm - log(s[k]) - 1 / (2 * s[k]**2) * ((d - m[k])**2)))
        for k in components
    ]
    q = sum(Array(q))
    return [a, m, s], q


def EM(d: Array, args: list = None):
    """Alternate E and M steps until the monitored objective stops changing.

    Prints the initial and the final parameters of every component.

    Args:
        d: the samples.
        args: ``[a, m, s]`` - initial weights, means and sigmas, one entry
            per component. Although the default is ``None``, a value must
            be supplied because it is unpacked immediately.

    Returns:
        list: the final ``[a, m, s]``.
    """
    row = 'GM[{}] => weight: {:.4f}, mu: {:.4f}, sigma: {:.4f}'  # made by xp

    def report(weights, means, sigmas):
        for idx in range(len(weights)):
            print(row.format(idx, weights[idx], means[idx], sigmas[idx]))

    a, m, s = args
    print('\nInitial:\n')
    report(a, m, s)

    previous_q, current_q = 0, 0
    iterations = 0
    # Always run at least two rounds so that both objective values come from
    # real iterations; after that, stop once the change drops below 10e-8.
    while iterations <= 1 or abs(previous_q - current_q) >= 10e-8:
        shares = E(d, args)
        args, latest_q = M(d, shares)
        previous_q = current_q
        current_q = latest_q
        iterations += 1

    print('-------------------------')
    a, m, s = args
    print('\nPrediction:\n')
    report(a, m, s)
    return args


def visual(args, data: Array = None):
    """Plot each fitted component, their sum, and a histogram of the data.

    The figure is written to ``./em.png`` and then shown.

    Args:
        args: ``[a, m, s]`` - weights, means and sigmas of the components.
        data: the samples. Although the default is ``None``, a value must
            be supplied because ``data.value`` is read immediately.
    """
    minx, maxx = int(min(data.value)), int(max(data.value))
    # Evaluation grid with a step of 0.1 starting at minx.
    x = [minx + 0.1 * i for i in range(int((maxx - minx) / 0.1))]
    a, m, s = args
    y_list = [a[k] * norm_pdf(x, loc=m[k], scale=s[k]) for k in range(len(a))]
    plt.figure()
    for i in range(len(a)):
        # Raw string: '\m' and '\s' are not valid escape sequences, so the
        # non-raw form triggers a warning on current Python versions. The
        # resulting text is unchanged.
        label = r'$w={:.4f}, \mu={:.4f}, \sigma={:.4f}$'.format(a[i], m[i], s[i])
        plt.plot(x, y_list[i].value, label=label)
    plt.plot(x, sum_0(y_list), label='Prediction')
    # Pass the underlying list rather than the Array wrapper, and keep the
    # bin count at one or more when the integer data range collapses to zero.
    plt.hist(data.value, bins=max(maxx - minx, 1), range=(minx, maxx), density=True, label='True', alpha=0.3)
    plt.legend(loc=2, prop={"size": 7})
    plt.savefig('./em.png')
    plt.show()
    plt.close()


if __name__ == '__main__':
    # Script entry point: load the samples, fit a three-component mixture
    # from a fixed starting guess, then plot the result.
    data = load_data()
    initial_weights = [0.3, 0.3, 0.4]
    initial_means = [150, 170, 190]
    initial_sigmas = [4, 4, 4]
    args = EM(data, args=[initial_weights, initial_means, initial_sigmas])
    visual(args, data)

# Related topics: python, algorithms, deep learning