DTW标准代码,在轨迹相似度有过应用

DTW 的理解思路还是按照动态规划的思路 ,和LeetCode的72题编辑距离以及求最短路径类似。DTW会重复使用序列中的点,从而达到扭曲对齐的.

一般都是用两个指针i,j分别指向两个列表的最后,然后一步步往前走,缩小问题的规模。先计算a[i]和b[j]的两点距离,然后开始移动指针i和j,可以i,j一起移动到i-1,j-1,也可以i或者j只移动一个即i-1,j和 i,j-1。那么dp[i,j]= distance(i,j)+min(dp[i-1,j-1],dp[i-1,j],dp(i,j-1))

  1. dp[i,j]的含义是存储两个序列a,b的最短路径距离
  2. dp[i,j]可以由dp[i-1,j],dp[i,j-1],dp[i-1,j-1]推导得到,从三者中找出最小值再加上a[i]和b[j]的两点距离
  3. base case就是i,j为0的时候,设为无穷大即可
import numpy as np
a = np.random.randint(0,5,5)
b = np.random.randint(0,5,2)
a,b
(array([3, 1, 3, 1, 4]), array([1, 2]))
l1 = len(a)
l2 = len(b)

dp table备忘录

dp = np.full((l1+1,l2+1),fill_value=float('inf'))
dp[0,0]=0

choices记录移动方向,初始化
最终要从dp[i,j]往dp[1,1]的回推

choices = np.full((l1+1,l2+1),fill_value='45')
choices
array([['45', '45', '45'],
       ['45', '45', '45'],
       ['45', '45', '45'],
       ['45', '45', '45'],
       ['45', '45', '45'],
       ['45', '45', '45']], dtype='

计算两点距离

def distance(m,n):
    return np.abs((m-n))
    

DTW

for i in range(1,l1+1):
    for j in range(1,l2+1):
        which = np.argmin((dp[i-1,j-1],dp[i-1,j],dp[i,j-1]))
        if which==0:
            pass
        elif which==1:
            choices[i,j]='up'
        else:
            choices[i,j] = 'lf'
        dp[i,j] = min(dp[i-1,j-1],dp[i-1,j],dp[i,j-1])+distance(a[i-1],b[j-1])
        
dp
array([[ 0., inf, inf],
       [inf,  2.,  3.],
       [inf,  2.,  3.],
       [inf,  4.,  3.],
       [inf,  4.,  4.],
       [inf,  7.,  6.]])
choices
array([['45', '45', '45'],
       ['45', '45', 'lf'],
       ['45', 'up', '45'],
       ['45', 'up', '45'],
       ['45', 'up', 'up'],
       ['45', 'up', '45']], dtype='

你可能感兴趣的:(动态规划,相似度,动态规划,python)