package com.dp.arts.biz.processor.post;
import java.util.Random;
public class AAA {
public double hiddenDistance(double[] v0, double[] v1) {
double dist = 0;
double t;
for (int i = 0; i < v0.length; i++) {
t = v0[i] - v1[i];
dist += (t*t);
}
return dist;
}
public static float hiddenDistance1(float[] v0, float[] v1) {
float dist = 0;
for (int i = 0; i < v0.length; ++i) {
float t1 = v0[i + 0] - v1[i + 0];
dist += t1 * t1;
}
return dist;
}
public static float hiddenDistance4(float[] v0, float[] v1) {
float dist = 0;
float dist1 = 0, dist2 = 0, dist3 = 0, dist4 = 0;
int loops = (v0.length / 4) * 4;
for (int i = 0; i < loops; i += 4) {
float t1 = v0[i + 0] - v1[i + 0];
float t2 = v0[i + 1] - v1[i + 1];
float t3 = v0[i + 2] - v1[i + 2];
float t4 = v0[i + 3] - v1[i + 3];
dist1 += t1*t1;
dist2 += t2*t2;
dist3 += t3*t3;
dist4 += t4*t4;
}
dist = dist1 + dist2 + dist3 + dist4;
for (int i = loops; i < v0.length; ++i) {
float t = v0[i] - v1[i];
dist += t * t;
}
return dist;
}
public static float hiddenDistance8(float[] v0, float[] v1) {
float dist = 0;
float dist1 = 0, dist2 = 0, dist3 = 0, dist4 = 0, dist5 = 0, dist6 = 0, dist7 = 0, dist8 = 0;
int loops = (v0.length / 8) * 8;
for (int i = 0; i < loops; i += 8) {
float t1 = v0[i + 0] - v1[i + 0];
float t2 = v0[i + 1] - v1[i + 1];
float t3 = v0[i + 2] - v1[i + 2];
float t4 = v0[i + 3] - v1[i + 3];
float t5 = v0[i + 4] - v1[i + 4];
float t6 = v0[i + 5] - v1[i + 5];
float t7 = v0[i + 6] - v1[i + 6];
float t8 = v0[i + 7] - v1[i + 7];
dist1 += t1*t1;
dist2 += t2*t2;
dist3 += t3*t3;
dist4 += t4*t4;
dist5 += t5*t5;
dist6 += t6*t6;
dist7 += t7*t7;
dist8 += t8*t8;
}
dist = dist1 + dist2 + dist3 + dist4 + dist5 + dist6 + dist7 + dist8;
for (int i = loops; i < v0.length; ++i) {
float t = v0[i] - v1[i];
dist += t * t;
}
return dist;
}
private static void test() {
int cnt = 100000000;
int dimensions = 300;
float[] v0 = new float[dimensions];
float[] v1 = new float[dimensions];
Random random = new Random();
for (int i = 0; i < dimensions; ++i) {
v0[i] = random.nextFloat() * 1000000;
v1[i] = random.nextFloat() * 1000000;
}
long old = System.currentTimeMillis();
float dist = 0;
for (int i = 0; i < cnt; ++i) {
dist += hiddenDistance4(v0, v1);
}
System.err.println("dist=" + dist + ", Use " + (System.currentTimeMillis() - old) + " ms.");
}
public static void main(String[] args) {
System.err.println("Warming up...");
test();
System.err.println("Testing...");
test();
}
}
在高维向量相似性检索中，大量时间花费在欧氏距离的计算上（上面的代码实际计算的是欧氏距离的平方，省略了开方），因此做了如下测试：
在我的 Mac 上，无论怎么调整 JVM 的 SIMD 相关选项（AVX、MMX、SSE 等），或者打开 -XX:+AggressiveOpts，300 维向量下 hiddenDistance1 平均每次耗时为 380ns，hiddenDistance4 和 hiddenDistance8 均减少到 290ns（可见手动展开一定程度上生效了，收益可能来自矢量化，也可能来自多个独立累加器减少了数据依赖）。
而用 gcc 编译的 C++ 版本，如果不打开优化开关，性能远远慢于 Java 版本（震惊）：
#include <cstdlib>
#include <ctime>
#include <iostream>
#include <xmmintrin.h>
#define PORTABLE_ALIGN16 __attribute__((aligned(16)))
using namespace std;
// Returns the squared Euclidean distance, sum((v0[i] - v1[i])^2), over the first n elements.
// The main loop is unrolled by 8 into eight independent accumulators; a scalar tail loop
// handles the remaining n % 8 elements. The summation order differs from distance1, so the
// result may differ from it in the last bits. Both inputs are only read; n <= 0 yields 0.
inline float distance8(const float* v0, const float* v1, int n) {
    float sum1 = 0.0f, sum2 = 0.0f, sum3 = 0.0f, sum4 = 0.0f;
    float sum5 = 0.0f, sum6 = 0.0f, sum7 = 0.0f, sum8 = 0.0f;
    // Largest multiple of 8 not exceeding n.
    const int loops = (n / 8) * 8;
    for (int i = 0; i < loops; i += 8) {
        const float d1 = v0[i] - v1[i];
        const float d2 = v0[i + 1] - v1[i + 1];
        const float d3 = v0[i + 2] - v1[i + 2];
        const float d4 = v0[i + 3] - v1[i + 3];
        const float d5 = v0[i + 4] - v1[i + 4];
        const float d6 = v0[i + 5] - v1[i + 5];
        const float d7 = v0[i + 6] - v1[i + 6];
        const float d8 = v0[i + 7] - v1[i + 7];
        sum1 += d1 * d1;
        sum2 += d2 * d2;
        sum3 += d3 * d3;
        sum4 += d4 * d4;
        sum5 += d5 * d5;
        sum6 += d6 * d6;
        sum7 += d7 * d7;
        sum8 += d8 * d8;
    }
    float sum = sum1 + sum2 + sum3 + sum4 + sum5 + sum6 + sum7 + sum8;
    for (int i = loops; i < n; ++i) {
        const float delta = v0[i] - v1[i];
        sum += delta * delta;
    }
    return sum;
}
// Returns the squared Euclidean distance, sum((v0[i] - v1[i])^2), over the first n elements.
// The main loop is unrolled by 4 into four independent accumulators; a scalar tail loop
// handles the remaining n % 4 elements. The summation order differs from distance1, so the
// result may differ from it in the last bits. Both inputs are only read; n <= 0 yields 0.
inline float distance4(const float* v0, const float* v1, int n) {
    float sum1 = 0.0f, sum2 = 0.0f, sum3 = 0.0f, sum4 = 0.0f;
    // Largest multiple of 4 not exceeding n.
    const int loops = (n / 4) * 4;
    for (int i = 0; i < loops; i += 4) {
        const float d1 = v0[i] - v1[i];
        const float d2 = v0[i + 1] - v1[i + 1];
        const float d3 = v0[i + 2] - v1[i + 2];
        const float d4 = v0[i + 3] - v1[i + 3];
        sum1 += d1 * d1;
        sum2 += d2 * d2;
        sum3 += d3 * d3;
        sum4 += d4 * d4;
    }
    float sum = sum1 + sum2 + sum3 + sum4;
    for (int i = loops; i < n; ++i) {
        const float delta = v0[i] - v1[i];
        sum += delta * delta;
    }
    return sum;
}
// Returns the squared Euclidean distance between the first qty elements of pVect1 and pVect2,
// using SSE intrinsics (four floats per __m128). The work is done in three stages:
// 16-element blocks in a 4x-unrolled loop, then remaining 4-element blocks, then a scalar
// tail for the last qty % 4 elements. Loads are unaligned (_mm_loadu_ps), so the inputs
// need no particular alignment. Every vector step adds into the single accumulator `sum`,
// so the unrolled adds still form one dependency chain. x86-only (requires <xmmintrin.h>).
// qty <= 0 yields 0.
inline float distance_SIMD(const float* pVect1, const float* pVect2, int qty) {
    // Guard against negative qty: the end pointers below would otherwise point before the
    // start of the arrays, which is undefined behavior.
    if (qty <= 0) {
        return 0.0f;
    }
    const int qty4 = qty / 4;
    const int qty16 = qty / 16;
    const float* pEnd1 = pVect1 + 16 * qty16;  // end of the 16-element blocks
    const float* pEnd2 = pVect1 + 4 * qty4;    // end of the 4-element blocks
    const float* pEnd3 = pVect1 + qty;         // end of the data

    __m128 diff, v1, v2;
    __m128 sum = _mm_set1_ps(0);
    while (pVect1 < pEnd1) {
        v1 = _mm_loadu_ps(pVect1); pVect1 += 4;
        v2 = _mm_loadu_ps(pVect2); pVect2 += 4;
        diff = _mm_sub_ps(v1, v2);
        sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
        v1 = _mm_loadu_ps(pVect1); pVect1 += 4;
        v2 = _mm_loadu_ps(pVect2); pVect2 += 4;
        diff = _mm_sub_ps(v1, v2);
        sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
        v1 = _mm_loadu_ps(pVect1); pVect1 += 4;
        v2 = _mm_loadu_ps(pVect2); pVect2 += 4;
        diff = _mm_sub_ps(v1, v2);
        sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
        v1 = _mm_loadu_ps(pVect1); pVect1 += 4;
        v2 = _mm_loadu_ps(pVect2); pVect2 += 4;
        diff = _mm_sub_ps(v1, v2);
        sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
    }
    while (pVect1 < pEnd2) {
        v1 = _mm_loadu_ps(pVect1); pVect1 += 4;
        v2 = _mm_loadu_ps(pVect2); pVect2 += 4;
        diff = _mm_sub_ps(v1, v2);
        sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
    }

    // Horizontal reduction of the four lanes. An unaligned store is used so the temporary
    // needs no alignment attribute.
    float lanes[4];
    _mm_storeu_ps(lanes, sum);
    float res = lanes[0] + lanes[1] + lanes[2] + lanes[3];

    // Scalar tail. The local is named `delta` so it does not shadow the __m128 `diff` above.
    while (pVect1 < pEnd3) {
        const float delta = *pVect1++ - *pVect2++;
        res += delta * delta;
    }
    return res;
}
// Reference implementation: returns the squared Euclidean distance,
// sum((v0[i] - v1[i])^2), over the first n elements. It uses a single float accumulator and
// sums strictly left to right. Both inputs are only read; n <= 0 yields 0.
inline float distance1(const float* v0, const float* v1, int n) {
    float sum = 0.0f;
    for (int i = 0; i < n; ++i) {
        const float delta = v0[i] - v1[i];
        sum += delta * delta;
    }
    return sum;
}
// Benchmarks distance4. It fills two 300-element vectors with pseudo-random values in
// [0, 1e6), calls distance4 cnt times, and prints the accumulated result together with the
// elapsed processor time (as measured by clock()) in milliseconds.
void test() {
    const int cnt = 100000000;
    // const makes the array bound a constant expression. With a plain int, v0/v1 would be
    // variable-length arrays, a GCC extension that is not standard C++.
    const int dimensions = 300;
    float v0[dimensions];
    float v1[dimensions];
    for (int i = 0; i < dimensions; ++i) {
        v0[i] = rand() / (RAND_MAX + 1.0) * 1000000;
        v1[i] = rand() / (RAND_MAX + 1.0) * 1000000;
    }
    // Accumulate in double: each call returns a large value (squared differences of
    // components up to 1e6). A float sink loses precision and eventually stops changing
    // long before cnt iterations.
    double sum = 0.0;
    const clock_t start = clock();
    for (int i = 0; i < cnt; ++i) {
        sum += distance4(v0, v1, dimensions);
    }
    const clock_t end = clock();
    // Convert in floating point so (end - start) * 1000 cannot overflow a 32-bit clock_t.
    // The value is milliseconds, not clock ticks.
    const double elapsedMs = (end - start) * 1000.0 / CLOCKS_PER_SEC;
    cout << "Sum = " << sum << ", Using " << elapsedMs << " ms." << endl;
}
// Program entry point. Seeds rand() from the current time, then runs the benchmark twice:
// first under the "Warming up..." banner, then under the "Testing..." banner.
int main(int argc, char** argv) {
    srand(static_cast<unsigned>(time(NULL)));
    const char* const banners[] = {"Warming up...", "Testing..."};
    for (int pass = 0; pass < 2; ++pass) {
        cout << banners[pass] << endl;
        test();
    }
    return 0;
}
启用 -mavx 等编译选项，性能没有变化。
启用 -O2 或 -O3 之后，distance4 的普通版本和 AVX 版本性能相同，每次都是 90ns。
JNI 调用的开销过大，因此决定使用 CriticalNative；目前测下来，Java 版本比原始的 nmslib 更高效。