给定两个升序排列的有序数组 A 和 B,长度分别为 m 和 n。
求这两个数组合并之后第 k 小的数(k 从 0 开始计数,即 k = 0 表示两者中最小的数)。
/*
 * find_kth - return the k-th smallest element (k counted from 0) of the
 * union of two ascending-sorted int arrays, without merging them.
 *
 * a, m: first sorted array and its length.
 * b, n: second sorted array and its length.
 *       A NULL pointer or a non-positive length means "empty array".
 * k:    0-based rank; k == 0 selects the smallest element.
 *
 * Returns INT_MAX if both arrays are empty; the overall minimum if k <= 0;
 * the overall maximum if k >= m + n; otherwise the k-th smallest element.
 *
 * Every loop iteration discards roughly half of one of the two ranges, so
 * the search takes O(log m + log n) steps.
 */
int find_kth(const int *a, int m, const int *b, int n, int k)
{
    /* Normalise "empty" so that no pointer arithmetic is ever done on NULL
     * and a (NULL, m > 0) argument is never dereferenced. */
    if (!a || m < 0)
        m = 0;
    if (!b || n < 0)
        n = 0;
    if (m == 0 && n == 0)
        return INT_MAX;

    /* Clamp out-of-range ranks. An empty array contributes INT_MAX / INT_MIN
     * so it never wins the min / max comparison. */
    int amin = m > 0 ? a[0] : INT_MAX, amax = m > 0 ? a[m - 1] : INT_MIN;
    int bmin = n > 0 ? b[0] : INT_MAX, bmax = n > 0 ? b[n - 1] : INT_MIN;
    if (k <= 0)
        return amin < bmin ? amin : bmin;
    if (k - m >= n) /* k >= m + n, written so that m + n cannot overflow */
        return amax > bmax ? amax : bmax;

    /* From here on 0 < k < m + n. */
    if (m == 0)
        return b[k];
    if (n == 0)
        return a[k];

    const int *afrom = a, *ato = a + m;
    const int *bfrom = b, *bto = b + n;

    /* Invariant: 0 <= k < (ato - afrom) + (bto - bfrom), and the answer is the
     * k-th smallest element of [afrom, ato) U [bfrom, bto). */
    while (afrom < ato && bfrom < bto) {
        const int *amid = afrom + (ato - afrom) / 2;
        const int *bmid = bfrom + (bto - bfrom) / 2;
        /* Number of elements in front of the two midpoints. */
        int mid = (int)((amid - afrom) + (bmid - bfrom));

        /* Boundary ranks are answered directly. */
        if (k == 0)
            return *afrom < *bfrom ? *afrom : *bfrom;
        if (k >= (ato - afrom) + (bto - bfrom) - 1)
            return ato[-1] > bto[-1] ? ato[-1] : bto[-1];
        assert(0 < k && k < (ato - afrom) + (bto - bfrom) - 1);

        if (*amid == *bmid) {
            /* Split as A: [afrom, amid), amid, [amid + 1, ato)
             *          B: [bfrom, bmid), bmid, [bmid + 1, bto).
             * The mid elements in front are <= v and everything behind is
             * >= v, so merged ranks mid and mid + 1 both hold v. */
            if (k < mid) {
                ato = amid;
                bto = bmid;
            } else if (k > mid + 1) {
                afrom = amid + 1;
                bfrom = bmid + 1;
                k -= mid + 2; /* mid elements in front plus both midpoints */
            } else {
                return *amid;
            }
        } else if (*amid > *bmid) {
            if (k < mid) {
                /* [afrom, amid) and [bfrom, bmid] give at least mid + 1
                 * elements <= *amid while all of [amid, ato) is >= *amid,
                 * so dropping [amid, ato) cannot change rank k < mid. */
                ato = amid;
            } else if (bmid == bfrom) {
                /* B holds a single element, so dropping [bfrom, bmid) would
                 * make no progress (e.g. A = [0, 24, 34, 45], B = [10]).
                 * Here k <= (ato - afrom) - 1, so afrom[k] is in bounds. */
                int candidate = afrom[k];
                if (candidate <= *bmid)
                    return candidate; /* ranks 0..k are all taken by A */
                ato = afrom + k;      /* answer is in afrom[0..k) or *bmid */
            } else {
                /* [bfrom, bmid) is <= *bmid and every element from amid /
                 * bmid onwards is >= *bmid; since k >= mid the answer lies
                 * beyond [bfrom, bmid), so drop it and shift k. */
                k -= (int)(bmid - bfrom);
                bfrom = bmid;
            }
        } else {
            /* *amid < *bmid: mirror image of the branch above. */
            if (k < mid) {
                bto = bmid;
            } else if (amid == afrom) {
                int candidate = bfrom[k];
                if (candidate <= *amid)
                    return candidate;
                bto = bfrom + k;
            } else {
                k -= (int)(amid - afrom);
                afrom = amid;
            }
        }
    }
    /* One range is now empty; the answer is at rank k of the other one. */
    return afrom < ato ? afrom[k] : bfrom[k];
}
下面是完整代码及测试部分:test() 利用随机数生成各种长度(包括长度为 0 和 1)的有序数组,并将 find_kth 的结果与合并后再排序的数组逐一比对验证。
#include <assert.h>
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>

/*
 * find_kth - return the k-th smallest element (k counted from 0) of the
 * union of two ascending-sorted int arrays, without merging them.
 *
 * a, m: first sorted array and its length.
 * b, n: second sorted array and its length.
 *       A NULL pointer or a non-positive length means "empty array".
 * k:    0-based rank; k == 0 selects the smallest element.
 *
 * Returns INT_MAX if both arrays are empty; the overall minimum if k <= 0;
 * the overall maximum if k >= m + n; otherwise the k-th smallest element.
 * Clamping invalid k to min / max avoids case analysis on empty arrays.
 *
 * Every loop iteration discards roughly half of one of the two ranges, so
 * the search takes O(log m + log n) steps.
 */
int find_kth(const int *a, int m, const int *b, int n, int k)
{
    /* Normalise "empty" so that no pointer arithmetic is ever done on NULL
     * and a (NULL, m > 0) argument is never dereferenced. */
    if (!a || m < 0)
        m = 0;
    if (!b || n < 0)
        n = 0;
    if (m == 0 && n == 0)
        return INT_MAX;

    /* Clamp out-of-range ranks. An empty array contributes INT_MAX / INT_MIN
     * so it never wins the min / max comparison. */
    int amin = m > 0 ? a[0] : INT_MAX, amax = m > 0 ? a[m - 1] : INT_MIN;
    int bmin = n > 0 ? b[0] : INT_MAX, bmax = n > 0 ? b[n - 1] : INT_MIN;
    if (k <= 0)
        return amin < bmin ? amin : bmin;
    if (k - m >= n) /* k >= m + n, written so that m + n cannot overflow */
        return amax > bmax ? amax : bmax;

    /* From here on 0 < k < m + n. */
    if (m == 0)
        return b[k];
    if (n == 0)
        return a[k];

    const int *afrom = a, *ato = a + m;
    const int *bfrom = b, *bto = b + n;

    /* Invariant: 0 <= k < (ato - afrom) + (bto - bfrom), and the answer is the
     * k-th smallest element of [afrom, ato) U [bfrom, bto). */
    while (afrom < ato && bfrom < bto) {
        const int *amid = afrom + (ato - afrom) / 2;
        const int *bmid = bfrom + (bto - bfrom) / 2;
        /* Number of elements in front of the two midpoints. */
        int mid = (int)((amid - afrom) + (bmid - bfrom));

        /* Boundary ranks are answered directly. */
        if (k == 0)
            return *afrom < *bfrom ? *afrom : *bfrom;
        if (k >= (ato - afrom) + (bto - bfrom) - 1)
            return ato[-1] > bto[-1] ? ato[-1] : bto[-1];
        assert(0 < k && k < (ato - afrom) + (bto - bfrom) - 1);

        if (*amid == *bmid) {
            /* Split as A: [afrom, amid), amid, [amid + 1, ato)
             *          B: [bfrom, bmid), bmid, [bmid + 1, bto).
             * The mid elements in front are <= v and everything behind is
             * >= v, so merged ranks mid and mid + 1 both hold v. */
            if (k < mid) {
                ato = amid;
                bto = bmid;
            } else if (k > mid + 1) {
                afrom = amid + 1;
                bfrom = bmid + 1;
                k -= mid + 2; /* mid elements in front plus both midpoints */
            } else {
                return *amid;
            }
        } else if (*amid > *bmid) {
            if (k < mid) {
                /* [afrom, amid) and [bfrom, bmid] give at least mid + 1
                 * elements <= *amid while all of [amid, ato) is >= *amid,
                 * so dropping [amid, ato) cannot change rank k < mid. */
                ato = amid;
            } else if (bmid == bfrom) {
                /* B holds a single element, so dropping [bfrom, bmid) would
                 * make no progress (e.g. A = [0, 24, 34, 45], B = [10]).
                 * Here k <= (ato - afrom) - 1, so afrom[k] is in bounds. */
                int candidate = afrom[k];
                if (candidate <= *bmid)
                    return candidate; /* ranks 0..k are all taken by A */
                ato = afrom + k;      /* answer is in afrom[0..k) or *bmid */
            } else {
                /* [bfrom, bmid) is <= *bmid and every element from amid /
                 * bmid onwards is >= *bmid; since k >= mid the answer lies
                 * beyond [bfrom, bmid), so drop it and shift k. */
                k -= (int)(bmid - bfrom);
                bfrom = bmid;
            }
        } else {
            /* *amid < *bmid: mirror image of the branch above. */
            if (k < mid) {
                bto = bmid;
            } else if (amid == afrom) {
                int candidate = bfrom[k];
                if (candidate <= *amid)
                    return candidate;
                bto = bfrom + k;
            } else {
                k -= (int)(amid - afrom);
                afrom = amid;
            }
        }
    }
    /* One range is now empty; the answer is at rank k of the other one.
     * Invalid k was already handled before the loop. */
    return afrom < ato ? afrom[k] : bfrom[k];
}

/* qsort comparator for ints in ascending order. Compares instead of
 * subtracting so that large-magnitude values cannot overflow. */
int cmp(const void *a, const void *b)
{
    int x = *(const int *)a;
    int y = *(const int *)b;
    return (x > y) - (x < y);
}

/* Print the n ints of a on one line, each followed by ", ". */
void aprint(int *a, int n)
{
    int i;
    for (i = 0; i < n; ++i)
        printf("%d, ", a[i]);
    printf("\n");
}

/* Reference answer for rank i over the sorted merged array c[0..len),
 * using the same clamping rules as find_kth. */
static int expected_kth(const int *c, int len, int i)
{
    if (len <= 0)
        return INT_MAX;
    if (i <= 0)
        return c[0];
    if (i >= len)
        return c[len - 1];
    return c[i];
}

/*
 * run_trial - one randomized check of find_kth.
 * Fills a (length m) with rand() % 100 values and b (length n) either with
 * fresh rand() % 100 values or, when b_copies_a is nonzero, with a copy of a.
 * Sorts a, b and their concatenation c, then compares find_kth against c for
 * every rank i in [-2, m + n + 10).
 * Returns 1 if every rank matched; 0 on the first mismatch (after printing
 * the inputs) or on allocation failure.
 */
static int run_trial(int m, int n, int b_copies_a)
{
    /* "+ 1" keeps each request non-zero, so NULL always means out of memory. */
    int *a = malloc(sizeof *a * ((size_t)m + 1));
    int *b = malloc(sizeof *b * ((size_t)n + 1));
    int *c = malloc(sizeof *c * ((size_t)m + (size_t)n + 1));
    int ok = 1, i;

    assert(!b_copies_a || n <= m);
    if (!a || !b || !c) {
        fprintf(stderr, "run_trial: out of memory\n");
        ok = 0;
        goto out;
    }

    for (i = 0; i < m; ++i) {
        a[i] = rand() % 100;
        c[i] = a[i];
    }
    for (i = 0; i < n; ++i) {
        b[i] = b_copies_a ? a[i] : rand() % 100;
        c[m + i] = b[i];
    }
    qsort(a, (size_t)m, sizeof *a, cmp);
    qsort(b, (size_t)n, sizeof *b, cmp);
    qsort(c, (size_t)m + (size_t)n, sizeof *c, cmp);

    for (i = -2; i < m + n + 10; ++i) {
        int cret = expected_kth(c, m + n, i); /* safe even when m + n == 0 */
        int ret = find_kth(a, m, b, n, i);
        if (ret != cret) {
            printf("Error i = %d, ret = %d, cret = %d\n", i, ret, cret);
            printf("a = ");
            aprint(a, m);
            printf("b = ");
            aprint(b, n);
            printf("c = ");
            aprint(c, m + n);
            ok = 0;
            break;
        }
    }

out:
    free(a);
    free(b);
    free(c);
    return ok;
}

/*
 * test - randomized self-check of find_kth over several shape families:
 * two random lengths, |A| = 1, |A| = |B| = 1, |B| = 1, and B identical to A
 * (many ties). Lengths come from rand() % 1000, so empty arrays occur too.
 * Each family runs up to 100 trials; testing stops at the first failure.
 */
void test(void)
{
    int ok = 1, iter;

    for (iter = 0; iter < 100 && ok; ++iter) {
        int m = rand() % 1000;
        int n = rand() % 1000;
        ok = run_trial(m, n, 0);
    }
    for (iter = 0; iter < 100 && ok; ++iter)
        ok = run_trial(1, rand() % 1000, 0);
    for (iter = 0; iter < 100 && ok; ++iter)
        ok = run_trial(1, 1, 0);
    for (iter = 0; iter < 100 && ok; ++iter)
        ok = run_trial(rand() % 1000, 1, 0);
    for (iter = 0; iter < 100 && ok; ++iter) {
        int m = rand() % 1000;
        ok = run_trial(m, m, 1);
    }
}

int main(void)
{
    int a[] = {1, 2, 3, 4, 5, 6};
    int m = (int)(sizeof a / sizeof a[0]);
    int *b = a; /* query the array merged with itself */
    int n = m;
    int x;

    test();
    /* scanf returns 0 (not EOF) on a non-numeric token; requiring exactly one
     * converted item stops the loop instead of spinning forever. */
    while (scanf("%d", &x) == 1)
        printf("%d\n", find_kth(a, m, b, n, x));
    return 0;
}