欧氏距离计算

package com.dp.arts.biz.processor.post;


import java.util.Random;

public class AAA {
    public double hiddenDistance(double[] v0, double[] v1) {


        double dist = 0;
        double t;
        for (int i = 0; i < v0.length; i++) {
            t = v0[i] - v1[i];
            dist += (t*t);
        }


        return dist;
    }


    public static float hiddenDistance1(float[] v0, float[] v1) {
        float dist = 0;
        for (int i = 0; i < v0.length; ++i) {
            float t1 = v0[i + 0] - v1[i + 0];
            dist += t1 * t1;
        }
        return dist;
    }


    public static float hiddenDistance4(float[] v0, float[] v1) {
        float dist = 0;
        float dist1 = 0, dist2 = 0, dist3 = 0, dist4 = 0;
        int loops = (v0.length / 4) * 4;
        for (int i = 0; i < loops; i += 4) {
            float t1 = v0[i + 0] - v1[i + 0];
            float t2 = v0[i + 1] - v1[i + 1];
            float t3 = v0[i + 2] - v1[i + 2];
            float t4 = v0[i + 3] - v1[i + 3];
            dist1 += t1*t1;
            dist2 += t2*t2;
            dist3 += t3*t3;
            dist4 += t4*t4;
        }
        dist = dist1 + dist2 + dist3 + dist4;
        for (int i = loops; i < v0.length; ++i) {
            float t = v0[i] - v1[i];
            dist += t * t;
        }
        return dist;
    }


    public static float hiddenDistance8(float[] v0, float[] v1) {
        float dist = 0;
        float dist1 = 0, dist2 = 0, dist3 = 0, dist4 = 0, dist5 = 0, dist6 = 0, dist7 = 0, dist8 = 0;
        int loops = (v0.length / 8) * 8;
        for (int i = 0; i < loops; i += 8) {
            float t1 = v0[i + 0] - v1[i + 0];
            float t2 = v0[i + 1] - v1[i + 1];
            float t3 = v0[i + 2] - v1[i + 2];
            float t4 = v0[i + 3] - v1[i + 3];
            float t5 = v0[i + 4] - v1[i + 4];
            float t6 = v0[i + 5] - v1[i + 5];
            float t7 = v0[i + 6] - v1[i + 6];
            float t8 = v0[i + 7] - v1[i + 7];
            dist1 += t1*t1;
            dist2 += t2*t2;
            dist3 += t3*t3;
            dist4 += t4*t4;
            dist5 += t5*t5;
            dist6 += t6*t6;
            dist7 += t7*t7;
            dist8 += t8*t8;
        }
        dist = dist1 + dist2 + dist3 + dist4 + dist5 + dist6 + dist7 + dist8;
        for (int i = loops; i < v0.length; ++i) {
            float t = v0[i] - v1[i];
            dist += t * t;
        }
        return dist;
    }


    private static void test() {
        int cnt = 100000000;
        int dimensions = 300;
        float[] v0 = new float[dimensions];
        float[] v1 = new float[dimensions];
        Random random = new Random();
        for (int i = 0; i < dimensions; ++i) {
            v0[i] = random.nextFloat() * 1000000;
            v1[i] = random.nextFloat() * 1000000;
        }


        long old = System.currentTimeMillis();
        float dist = 0;
        for (int i = 0; i < cnt; ++i) {
            dist += hiddenDistance4(v0, v1);
        }
        System.err.println("dist=" + dist + ", Use " + (System.currentTimeMillis() - old) + " ms.");
    }


    public static void main(String[] args) {
        System.err.println("Warming up...");
        test();
        System.err.println("Testing...");
        test();
    }
}

在高维向量相似性检索中,大量时间花费在欧氏距离计算上,为此做了以下测试:

在我的 Mac 上,无论怎样调整 JVM 的 SIMD 相关选项(AVX、MMX、SSE 等)或打开 AggressiveOpts(-XX:+AggressiveOpts),对于 300 维向量,distance1 平均每次耗时约 380ns,而 distance4 和 distance8 均降至约 290ns(可见循环展开后,矢量化在一定程度上生效了)。

而用 gcc 编译的 C++ 版本,如果不打开优化开关,性能竟然远远慢于 Java 版本(令人震惊):

// Header names were lost when this post was extracted; reconstructed from the names the code uses.
#include <cstdlib>
#include <ctime>
#include <iostream>
#include <xmmintrin.h>

// Declares a variable 16-byte aligned using GCC/Clang __attribute__ syntax
// (aligned SSE stores such as _mm_store_ps require 16-byte alignment).
// NOTE(review): despite the name, this attribute form is not accepted by MSVC.
#define PORTABLE_ALIGN16 __attribute__((aligned(16)))

using namespace std;

// Squared Euclidean distance between two float vectors of length n
// (sum of squared differences; no square root is taken).
//
// The main loop is unrolled by 8 with eight independent accumulators so that
// consecutive additions do not form one serial dependency chain; the remaining
// n % 8 elements are handled by a scalar tail loop. Float addition is not
// associative, so the result may differ from distance1 in the last bits.
//
// v0, v1: inputs with at least n readable elements; they are only read, so
//         they are taken as const and read-only buffers can be passed.
// n:      number of elements; n <= 0 yields 0.
inline float distance8(const float* v0, const float* v1, int n) {
  float sum1 = 0.0f, sum2 = 0.0f, sum3 = 0.0f, sum4 = 0.0f;
  float sum5 = 0.0f, sum6 = 0.0f, sum7 = 0.0f, sum8 = 0.0f;
  const int loops = (n / 8) * 8;
  for (int i = 0; i < loops; i += 8) {
    const float d1 = v0[i] - v1[i];
    const float d2 = v0[i + 1] - v1[i + 1];
    const float d3 = v0[i + 2] - v1[i + 2];
    const float d4 = v0[i + 3] - v1[i + 3];
    const float d5 = v0[i + 4] - v1[i + 4];
    const float d6 = v0[i + 5] - v1[i + 5];
    const float d7 = v0[i + 6] - v1[i + 6];
    const float d8 = v0[i + 7] - v1[i + 7];
    sum1 += d1 * d1;
    sum2 += d2 * d2;
    sum3 += d3 * d3;
    sum4 += d4 * d4;
    sum5 += d5 * d5;
    sum6 += d6 * d6;
    sum7 += d7 * d7;
    sum8 += d8 * d8;
  }
  float sum = sum1 + sum2 + sum3 + sum4 + sum5 + sum6 + sum7 + sum8;
  for (int i = loops; i < n; ++i) {
    const float delta = v0[i] - v1[i];
    sum += delta * delta;
  }
  return sum;
}

// Squared Euclidean distance between two float vectors of length n
// (sum of squared differences; no square root is taken).
//
// The main loop is unrolled by 4 with four independent accumulators; the
// remaining n % 4 elements are handled by a scalar tail loop. Float addition
// is not associative, so the result may differ from distance1 in the last bits.
//
// v0, v1: inputs with at least n readable elements; they are only read, so
//         they are taken as const and read-only buffers can be passed.
// n:      number of elements; n <= 0 yields 0.
inline float distance4(const float* v0, const float* v1, int n) {
  float sum1 = 0.0f, sum2 = 0.0f, sum3 = 0.0f, sum4 = 0.0f;
  const int loops = (n / 4) * 4;
  for (int i = 0; i < loops; i += 4) {
    const float d1 = v0[i] - v1[i];
    const float d2 = v0[i + 1] - v1[i + 1];
    const float d3 = v0[i + 2] - v1[i + 2];
    const float d4 = v0[i + 3] - v1[i + 3];
    sum1 += d1 * d1;
    sum2 += d2 * d2;
    sum3 += d3 * d3;
    sum4 += d4 * d4;
  }
  float sum = sum1 + sum2 + sum3 + sum4;
  for (int i = loops; i < n; ++i) {
    const float delta = v0[i] - v1[i];
    sum += delta * delta;
  }
  return sum;
}

// Squared Euclidean distance between two float vectors using SSE intrinsics
// (sum of squared differences; no square root is taken).
//
// Processes 16 floats per iteration (four 4-wide loads per input), then the
// remaining groups of 4, then a scalar tail for the last qty % 4 elements.
// All vector work accumulates into a single __m128 whose four lanes are summed
// at the end. Unaligned loads/stores are used, so no input alignment is needed.
//
// pVect1, pVect2: inputs with at least qty readable elements; they are only
//                 read (the pointer parameters themselves are advanced as cursors).
// qty:            number of elements.
inline float distance_SIMD(const float* pVect1, const float* pVect2, int qty) {
    const int qty4  = qty / 4;
    const int qty16 = qty / 16;

    const float* const pEnd1 = pVect1 + 16 * qty16;
    const float* const pEnd2 = pVect1 + 4  * qty4;
    const float* const pEnd3 = pVect1 + qty;

    __m128 diff, v1, v2;
    __m128 sum = _mm_set1_ps(0);

    while (pVect1 < pEnd1) {
        v1   = _mm_loadu_ps(pVect1); pVect1 += 4;
        v2   = _mm_loadu_ps(pVect2); pVect2 += 4;
        diff = _mm_sub_ps(v1, v2);
        sum  = _mm_add_ps(sum, _mm_mul_ps(diff, diff));

        v1   = _mm_loadu_ps(pVect1); pVect1 += 4;
        v2   = _mm_loadu_ps(pVect2); pVect2 += 4;
        diff = _mm_sub_ps(v1, v2);
        sum  = _mm_add_ps(sum, _mm_mul_ps(diff, diff));

        v1   = _mm_loadu_ps(pVect1); pVect1 += 4;
        v2   = _mm_loadu_ps(pVect2); pVect2 += 4;
        diff = _mm_sub_ps(v1, v2);
        sum  = _mm_add_ps(sum, _mm_mul_ps(diff, diff));

        v1   = _mm_loadu_ps(pVect1); pVect1 += 4;
        v2   = _mm_loadu_ps(pVect2); pVect2 += 4;
        diff = _mm_sub_ps(v1, v2);
        sum  = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
    }

    while (pVect1 < pEnd2) {
        v1   = _mm_loadu_ps(pVect1); pVect1 += 4;
        v2   = _mm_loadu_ps(pVect2); pVect2 += 4;
        diff = _mm_sub_ps(v1, v2);
        sum  = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
    }

    // Unaligned store: no alignment attribute is needed on the scratch buffer.
    float TmpRes[4];
    _mm_storeu_ps(TmpRes, sum);
    float res = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];

    // Scalar tail; named 'd' so it does not shadow the __m128 'diff' above.
    while (pVect1 < pEnd3) {
        const float d = *pVect1++ - *pVect2++;
        res += d * d;
    }
    return res;
}

// Squared Euclidean distance between two float vectors of length n, computed
// with a plain scalar loop (sum of squared differences; no square root).
//
// v0, v1: inputs with at least n readable elements; they are only read, so
//         they are taken as const and read-only buffers can be passed.
// n:      number of elements; n <= 0 yields 0.
inline float distance1(const float* v0, const float* v1, int n) {
  float sum = 0.0f;
  for (int i = 0; i < n; ++i) {
    const float delta = v0[i] - v1[i];
    sum += delta * delta;
  }
  return sum;
}

// Benchmark driver. Fills two 300-element float vectors with pseudo-random
// values between 0 and 1000000 and calls distance4 on them cnt times. It then
// prints the accumulated sum (which keeps the results observable) and the
// elapsed processor time in milliseconds, as measured by clock().
//
// NOTE(review): the inputs never change between iterations, so at -O2/-O3 the
// compiler may be able to hoist part of the distance computation out of the
// loop; check the generated assembly before trusting per-call timings.
void test() {
  const int cnt = 100000000;
  // const makes this a constant expression, so v0/v1 below are standard C++
  // arrays rather than a GCC variable-length-array extension.
  const int dimensions = 300;
  float sum = 0.0;

  float v0[dimensions];
  float v1[dimensions];
  for (int i = 0; i < dimensions; ++i) {
    v0[i] = rand() / (RAND_MAX + 1.0) * 1000000;
    v1[i] = rand() / (RAND_MAX + 1.0) * 1000000;
  }

  const clock_t start = clock();
  for (int i = 0; i < cnt; ++i) {
    sum += distance4(v0, v1, dimensions);
  }
  const clock_t end = clock();
  // Widen before multiplying so "* 1000" cannot overflow a 32-bit clock_t.
  const long long elapsed_ms =
      static_cast<long long>(end - start) * 1000 / CLOCKS_PER_SEC;
  cout << "Sum = " << sum << ", Using " << elapsed_ms << " ms." << endl;
}

// Program entry: seeds rand() from the current time, then runs the benchmark
// in two phases, a warm-up pass followed by the measured pass.
int main(int argc, char** argv) {
  const char* const phases[] = {"Warming up...", "Testing..."};
  srand(static_cast<unsigned>(time(NULL)));
  for (int p = 0; p < 2; ++p) {
    cout << phases[p] << endl;
    test();
  }
  return 0;
}
 
  
不开优化时,C++ 版本中 distance1 为 1090ns,distance4 和 distance8 为 630ns,都不如 Java 版本快;即使代码改用 intrinsics 头文件(如 xmmintrin.h)中的内联函数,也同样如此。

启用 -mavx 等编译选项后,没有变化。

启用 -O2 或 -O3 之后,distance4 的普通版本和 AVX 版本耗时同样都是约 90ns。

由于 JNI 调用的开销过大,决定使用 CriticalNative;目前测下来,Java 版本比原始的 nmslib 更高效。


你可能感兴趣的:(欧式距离计算)