# Implementation of Viterbi decoding and beam search

import numpy as np
import random
import copy


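'''
Viterbi decoding and beam search over a fence (grid) of nodes.
Both searches in this file prefer the paths with the lowest total score.
'''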

class Fence:
    """A grid of nodes with a fixed, deterministic score for every edge.

    A node is written as a two-element list ``[row, column]``, for example
    ``[2, 1]`` or ``[3, 2]``. A score can be computed from the ``"start"``
    marker to a node, or between two nodes that sit in adjacent columns.
    """

    def __init__(self, n, h):
        """Store the grid size: ``n`` as ``width`` and ``h`` as ``height``."""
        self.width, self.height = n, h

    def score(self, node1, node2):
        """Return the fixed score of the edge going from ``node1`` to ``node2``.

        Args:
            node1: Either the string ``"start"`` or a ``[row, column]`` node.
            node2: The ``[row, column]`` node the edge leads to.

        Returns:
            The result of a true division, derived only from the node
            coordinates.

        Raises:
            AssertionError: If ``node1`` is not ``"start"`` and ``node2`` is
                not in the column immediately after ``node1``.
        """
        dst_row, dst_col = node2[0], node2[1]
        if node1 == "start":
            return (dst_row + dst_col + 1) / (dst_row * dst_col + 1)

        src_row, src_col = node1[0], node1[1]
        # The two nodes must be in neighbouring columns.
        assert src_col + 1 == dst_col
        numerator = (src_row + src_col + dst_row + dst_col) % 3 + 1
        denominator = src_row * 4 + src_col * 3 + dst_row * 2 + dst_col
        return numerator / denominator

class Path:
    """A route through the fence together with its accumulated score.

    ``nodes`` lists the visited nodes in order and always begins with the
    ``"start"`` marker; ``score`` is the running total for the route.
    """

    def __init__(self):
        """Create a path that contains only the start marker and scores 0."""
        self.nodes, self.score = ["start"], 0

    def __len__(self):
        """Return the number of entries in ``nodes`` (start marker included)."""
        entries = self.nodes
        return len(entries)

def beam_search(fence, beam_size):
    """Search the fence column by column, keeping only the best partial paths.

    For every column, each surviving path is extended to every row of that
    column. The candidates are then ranked by ascending total score and only
    the first ``beam_size`` of them survive to the next column.

    Args:
        fence: Object exposing ``width``, ``height`` and
            ``score(node1, node2)``.
        beam_size: Maximum number of paths kept after each column; must be
            at least 1.

    Returns:
        A list of at most ``beam_size`` ``Path`` objects ordered by ascending
        score. If ``fence.width`` is not positive the list contains only the
        initial start path; if ``fence.height`` is not positive (and the
        width is positive) the list is empty.

    Raises:
        ValueError: If ``beam_size`` is smaller than 1.
    """
    if beam_size < 1:
        # A zero beam leaves nothing to extend, and a negative one would make
        # the slice below drop paths from the wrong end.
        raise ValueError("beam_size must be at least 1")

    height = fence.height
    beam = [Path()]
    # Walking the columns with a bounded loop guarantees termination, even
    # for an empty fence, instead of waiting for a path length to match.
    for column in range(fence.width):
        candidates = []
        for path in beam:
            for row in range(height):
                node = [row, column]
                extended = copy.deepcopy(path)
                extended.score += fence.score(path.nodes[-1], node)
                extended.nodes.append(node)
                candidates.append(extended)
        # The sort is stable, so ties keep their generation order.
        candidates.sort(key=lambda candidate: candidate.score)
        beam = candidates[:beam_size]
    return beam

def viterbi(fence):
    """Find, for every node of the last column, the lowest-scoring path to it.

    The fence is processed one column at a time. For each node of the current
    column, the predecessor path that gives the smallest total score is
    selected and extended by that node, so exactly one path per row survives
    each column.

    Args:
        fence: Object exposing ``width``, ``height`` and
            ``score(node1, node2)``.

    Returns:
        A list of ``Path`` objects, one per row of the final column, ordered
        by ascending score. If ``fence.width`` is not positive the list
        contains only the initial start path; if ``fence.height`` is not
        positive (and the width is positive) the list is empty.
    """
    height = fence.height
    survivors = [Path()]
    # A bounded loop over the columns guarantees termination, even for an
    # empty fence, instead of waiting for a path length to match.
    for column in range(fence.width):
        next_survivors = []
        for row in range(height):
            node = [row, column]
            best_prev = None
            best_total = None
            for path in survivors:
                total = path.score + fence.score(path.nodes[-1], node)
                # Strict comparison keeps the first of several equal minima.
                if best_prev is None or total < best_total:
                    best_prev, best_total = path, total
            if best_prev is None:
                continue
            # Copy only the winning predecessor rather than every candidate.
            extended = copy.deepcopy(best_prev)
            extended.score = best_total
            extended.nodes.append(node)
            next_survivors.append(extended)
        survivors = next_survivors
    return sorted(survivors, key=lambda path: path.score)

# Demo: decode a fence of 6 columns and 4 rows with both algorithms and
# print every returned path followed by its total score.
width = 6
height = 4
fence = Fence(width, height)

beam_size = 1
res = beam_search(fence, beam_size)
# Iterate over the returned paths instead of indexing up to beam_size: the
# search returns a slice that can hold fewer than beam_size paths.
for path in res:
    print(path.nodes, path.score)
print("-----------")
res = viterbi(fence)
for path in res:
    print(path.nodes, path.score)

# You may also be interested in: (NLP, python, programming languages)