HMM Viterbi 算法

好久没更新了，写了一个 Viterbi 算法，主要是为了加深对它的理解。

# coding=utf-8
"""
Word segmentation with a character-level hidden Markov model (BMES tags).

First estimate the three HMM parameters from a segmented training corpus:
    initial state distribution pi
    state transition matrix A
    emission matrix B
then segment input text with the Viterbi algorithm.
"""

# Default training corpus file name. Not referenced in the visible code;
# presumably meant to be passed to train_hmm -- NOTE(review): confirm.
TRAIN_CORPUS = 'trainCorpus.txt_utf8'
# Probability tables written by train_hmm and read back by viterbi_seg.
PROB_INIT = 'prob_init.txt'    # initial state probabilities (pi)
PROB_EMIT = 'prob_emit.txt'    # per-state character emission probabilities (B)
PROB_TRANS = 'prob_trans.txt'  # state-to-state transition probabilities (A)


def _bmes_tags(words):
    """Return one BMES tag per character of *words*, in order.

    A one-character word is tagged ``S``; a longer word is tagged ``B``,
    then ``M`` for every inner character, then ``E``.
    """
    tags = []
    for word in words:
        if len(word) == 1:
            tags.append('S')
        else:
            tags.append('B')
            tags.extend(['M'] * (len(word) - 2))
            tags.append('E')
    return tags


def train_hmm(input_data, init_path=None, emit_path=None, trans_path=None):
    """Estimate the HMM parameters (pi, A, B) from a segmented corpus.

    Each line of *input_data* is one sentence whose words are separated by
    whitespace. Every character is tagged S, B, M or E (see ``_bmes_tags``).
    The relative frequencies of each sentence's first tag, of tag-to-tag
    transitions and of tag-to-character emissions are written as
    tab-separated UTF-8 text files, in the format ``viterbi_seg`` reads:

    * init file:  ``STATE<TAB>PROB`` per line
    * emit file:  ``STATE<TAB>CHAR:PROB<TAB>CHAR:PROB...`` per line
    * trans file: ``STATE<TAB>NEXT:PROB<TAB>...`` per line

    A state with no observations gets all-zero probabilities rather than
    raising ZeroDivisionError.

    :param input_data: path of the UTF-8 encoded, whitespace-segmented corpus
        (a leading BOM, if any, is ignored).
    :param init_path: output path for pi; ``None`` means ``PROB_INIT``.
    :param emit_path: output path for B; ``None`` means ``PROB_EMIT``.
    :param trans_path: output path for A; ``None`` means ``PROB_TRANS``.
    :raises IOError/OSError: if the corpus cannot be read or an output file
        cannot be written.
    :raises UnicodeDecodeError: if the corpus is not valid UTF-8.
    """
    import io

    states = ('S', 'B', 'M', 'E')
    init_dict = dict.fromkeys(states, 0)
    emit_dict = dict((state, {}) for state in states)
    trans_dict = dict((state, dict.fromkeys(states, 0)) for state in states)

    # 'utf-8-sig' strips a leading BOM so it is not counted as a character.
    with io.open(input_data, 'r', encoding='utf-8-sig') as fi:
        for line in fi:
            # Derive tags and characters from the same split so they always
            # line up, whatever whitespace (spaces, tabs...) separates words.
            words = line.split()
            sign = _bmes_tags(words)
            if not sign:
                continue  # blank line: nothing to count
            init_dict[sign[0]] += 1
            for char, tag in zip(u''.join(words), sign):
                emit_dict[tag][char] = emit_dict[tag].get(char, 0) + 1
            for prev, cur in zip(sign, sign[1:]):
                trans_dict[prev][cur] += 1

    def ratio(count, total):
        # A state that never occurs yields 0.0 instead of dividing by zero.
        return count / float(total) if total else 0.0

    # repr() keeps full float precision (Python 2's str() keeps 12 digits).
    init_total = sum(init_dict.values())
    with io.open(PROB_INIT if init_path is None else init_path,
                 'w', encoding='utf-8') as init_wr:
        for state in states:
            init_wr.write(u'%s\t%r\n' % (state, ratio(init_dict[state], init_total)))

    with io.open(PROB_EMIT if emit_path is None else emit_path,
                 'w', encoding='utf-8') as emit_wr:
        for state in states:
            counts = emit_dict[state]
            emit_total = sum(counts.values())
            fields = [state] + [u'%s:%r' % (char, ratio(n, emit_total))
                                for char, n in sorted(counts.items())]
            emit_wr.write(u'\t'.join(fields) + u'\n')

    with io.open(PROB_TRANS if trans_path is None else trans_path,
                 'w', encoding='utf-8') as trans_wr:
        for state in states:
            row = trans_dict[state]
            row_total = sum(row.values())
            fields = [state] + [u'%s:%r' % (nxt, ratio(row[nxt], row_total))
                                for nxt in ('E', 'S', 'B', 'M')]
            trans_wr.write(u'\t'.join(fields) + u'\n')


# Stand-in for log(0): finite, so paths through zero-probability steps stay
# comparable (they always lose to any fully possible path) instead of
# vanishing and leaving max() with nothing to choose from.
_MIN_LOG_PROB = -3.14e100


def _load_prob_table(path, nested):
    """Read one probability file in the format written by ``train_hmm``.

    Non-blank lines are ``STATE<TAB>VALUE`` when *nested* is false, or
    ``STATE<TAB>KEY:VALUE<TAB>KEY:VALUE...`` when *nested* is true.
    Returns ``{state: float}`` or ``{state: {key: float}}`` respectively.
    """
    import io

    table = {}
    with io.open(path, 'r', encoding='utf-8') as fh:
        for line in fh:
            fields = line.strip().split(u'\t')
            state = fields[0]
            if not state:
                continue  # blank line
            if nested:
                # rpartition: the key may itself be ':' (a corpus character).
                table[state] = dict(
                    (key, float(value))
                    for key, _, value in (field.rpartition(u':') for field in fields[1:]))
            else:
                table[state] = float(fields[1])
    return table


def _viterbi(obs, states, start_p, trans_p, emit_p):
    """Return ``(log_prob, tags)`` for the most probable state path over *obs*.

    *obs* must be non-empty. Scores are sums of log probabilities, so long
    inputs cannot underflow the way a product of raw probabilities does.
    A zero probability scores ``_MIN_LOG_PROB``. A character missing from
    every state's emission table is neutral (log 1 = 0 in all states), so
    the transition model alone decides its tag. Ties go to the state name
    that sorts last, like the ``max`` over ``(prob, state)`` tuples.
    """
    import math

    def log_p(p):
        return math.log(p) if p > 0 else _MIN_LOG_PROB

    def emission_scores(ch):
        if not any(ch in emit_p.get(y, {}) for y in states):
            return dict.fromkeys(states, 0.0)
        return dict((y, log_p(emit_p.get(y, {}).get(ch, 0))) for y in states)

    log_trans = dict(((y0, y), log_p(trans_p.get(y0, {}).get(y, 0)))
                     for y0 in states for y in states)

    first = emission_scores(obs[0])
    score = dict((y, log_p(start_p.get(y, 0)) + first[y]) for y in states)
    # Backpointers instead of copying whole paths keeps this O(len(obs)).
    backpointers = []
    for ch in obs[1:]:
        emit = emission_scores(ch)
        new_score = {}
        step = {}
        for y in states:
            best, prev = max((score[y0] + log_trans[y0, y], y0) for y0 in states)
            new_score[y] = best + emit[y]
            step[y] = prev
        backpointers.append(step)
        score = new_score

    best, last = max((score[y], y) for y in states)
    tags = [last]
    for step in reversed(backpointers):
        tags.append(step[tags[-1]])
    tags.reverse()
    return best, tags


def viterbi_seg(sentence, init_path=None, trans_path=None, emit_path=None):
    """Segment *sentence* into words with the HMM trained by ``train_hmm``.

    :param sentence: text to segment; UTF-8 ``bytes`` (Python 2 ``str``)
        is decoded first, text is used as is.
    :param init_path: pi table path; ``None`` means ``PROB_INIT``.
    :param trans_path: A table path; ``None`` means ``PROB_TRANS``.
    :param emit_path: B table path; ``None`` means ``PROB_EMIT``.
    :returns: the characters of *sentence*, each followed by one space when
        it is tagged S or E (so the result normally ends with a space);
        an empty string for empty input.
    :raises IOError/OSError: if a probability file cannot be read.
    """
    if isinstance(sentence, bytes):
        sentence = sentence.decode('utf-8')
    if not sentence:
        return u''

    init_dict = _load_prob_table(PROB_INIT if init_path is None else init_path, nested=False)
    trans_dict = _load_prob_table(PROB_TRANS if trans_path is None else trans_path, nested=True)
    emit_dict = _load_prob_table(PROB_EMIT if emit_path is None else emit_path, nested=True)

    _, tags = _viterbi(sentence, ('B', 'M', 'E', 'S'), init_dict, trans_dict, emit_dict)
    pieces = []
    for char, tag in zip(sentence, tags):
        pieces.append(char)
        if tag in ('S', 'E'):  # a word ends here
            pieces.append(u' ')
    return u''.join(pieces)


if __name__ == '__main__':
    # Interactive loop: segment each line typed at the prompt. Guarded so
    # importing this module does not start an endless prompt.
    try:
        read_line = raw_input  # Python 2
    except NameError:
        read_line = input  # Python 3 (Python 2's input() would eval the text)
    while True:
        try:
            a = read_line("input:")
        except EOFError:
            break  # Ctrl-D / end of piped input ends the session cleanly
        print(viterbi_seg(a))

语料地址

你可能感兴趣的:(HMM viterbi 算法)