两个有序数组,寻找第k小的数

题意

给定两个有序数组A和B,长度分别为m, n。

求这两个数组合并之后的第k小的数。注意:k从0开始计数,即k = 0表示合并后最小的数。

解答

/*
 * find_kth - select the k-th smallest value from two ascending-sorted arrays
 * without merging them.
 *
 * a, m: first sorted array and its element count.
 * b, n: second sorted array and its element count.
 * k:    0-based rank in the merged order (k = 0 is the overall minimum).
 *
 * An array whose pointer is NULL or whose count is <= 0 is treated as empty
 * and contributes no elements. An out-of-range k is clamped: k <= 0 yields
 * the minimum and k >= (total element count) yields the maximum.
 * Returns INT_MAX when both arrays are empty.
 */
int find_kth(const int *a, int m, const int *b, int n, int k){
	const int *afrom, *amid, *ato;
	const int *bfrom, *bmid, *bto;
	int aempty = (!a || m <= 0), bempty = (!b || n <= 0);
	int alen, blen, mid, kvalue;

	if (aempty && bempty) return INT_MAX;

	/*
	 * Only one array holds data: the answer is a clamped index into it.
	 * Handling this up front keeps a NULL pointer or a bogus (negative, or
	 * positive-with-NULL) length out of the pointer arithmetic and of the
	 * range checks below.
	 */
	if (aempty || bempty) {
		const int *only = aempty ? b : a;
		int len = aempty ? n : m;

		if (k <= 0) return only[0];
		if (k >= len) return only[len - 1];
		return only[k];
	}

	/* Both arrays are non-empty from here on; clamp out-of-range ranks. */
	if (k <= 0) return a[0] < b[0] ? a[0] : b[0];
	/* Same test as k >= m + n, written so the sum cannot overflow. */
	if (k - m >= n) return a[m-1] > b[n-1] ? a[m-1] : b[n-1];

	afrom = a; ato = a + m;
	bfrom = b; bto = b + n;

	/* Invariant: the answer has rank k within [afrom, ato) + [bfrom, bto). */
	while (afrom < ato && bfrom < bto) {
		alen = (int)(ato - afrom); blen = (int)(bto - bfrom);
		amid = afrom + (alen >> 1); bmid = bfrom + (blen >> 1);
		/* Count of elements in [afrom, amid) plus [bfrom, bmid). */
		mid = (alen >> 1) + (blen >> 1);

		/* Ranks on either boundary are answered directly. */
		if (0 == k) return *afrom < *bfrom ? *afrom : *bfrom;
		/* Same test as k >= alen + blen - 1, without forming the sum. */
		if (k - alen >= blen - 1)
			return *(ato-1) > *(bto-1) ? *(ato-1) : *(bto-1);

		assert(0 < k && k - alen < blen);

		if (*amid == *bmid) {
			/* A: [afrom, amid), amid, [amid + 1, ato) */
			/* B: [bfrom, bmid), bmid, [bmid + 1, bto) */
			if (k < mid) {
				/* The answer lies in the two front parts. */
				ato = amid; bto = bmid;
			} else if (k > (mid + 1)) {
				/* Drop both fronts and both middles: mid + 2 elements. */
				afrom = amid + 1; bfrom = bmid + 1;
				k -= mid + 2;
			} else {
				/* Ranks mid and mid + 1 are the two equal middles. */
				return *amid;
			}
		} else if (*amid > *bmid) {
			/* A: [afrom, amid), [amid, ato) */
			/* B: [bfrom, bmid), bmid, [bmid + 1, bto) */
			if (k < mid) {
				/* At least mid + 1 elements precede *amid: drop [amid, ato). */
				ato = amid;
			} else if (bmid == bfrom) {
				/*
				 * B's front part is empty, so nothing can be dropped
				 * from it. This happens when B is down to one element.
				 * k < alen holds here because of the boundary check.
				 */
				kvalue = afrom[k];
				if (kvalue <= *bmid) {
					return kvalue;
				} else {
					ato = afrom + k;
				}
			} else {
				/* Drop [bfrom, bmid): all of it ranks below k. */
				k -= (int)(bmid - bfrom);
				bfrom = bmid;
			}
		} else {
			/* *amid < *bmid: mirror image of the branch above. */
			if (k < mid) {
				bto = bmid;
			} else if (amid == afrom) {
				kvalue = bfrom[k];
				if (kvalue <= *amid) return kvalue;
				else bto = bfrom + k;
			} else {
				k -= (int)(amid - afrom);
				afrom = amid;
			}
		}
	}

	/* The loop ends when one range is exhausted; index into the other. */
	if (bfrom >= bto) {
		return afrom[k];
	}
	return bfrom[k];
}


下面是测试代码:test()利用随机函数生成各种长度的数组,并将find_kth的返回值与两个数组合并排序后的结果逐一比对验证。

测试代码:

#include <assert.h>
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>

/*
 * find_kth - select the k-th smallest value from two ascending-sorted arrays
 * without merging them.
 *
 * a, m: first sorted array and its element count.
 * b, n: second sorted array and its element count.
 * k:    0-based rank in the merged order; k = 0 means the smallest value
 *       of the two arrays.
 *
 * An array whose pointer is NULL or whose count is <= 0 is treated as empty
 * and contributes no elements. An out-of-range k is clamped: k <= 0 yields
 * the minimum and k >= (total element count) yields the maximum.
 * Returns INT_MAX when both arrays are empty.
 */
int find_kth(const int *a, int m, const int *b, int n, int k){
	const int *afrom, *amid, *ato;
	const int *bfrom, *bmid, *bto;
	int aempty = (!a || m <= 0), bempty = (!b || n <= 0);
	int alen, blen, mid, kvalue;

	if (aempty && bempty) return INT_MAX;

	/*
	 * Only one array holds data: the answer is a clamped index into it.
	 * Handling this up front keeps a NULL pointer or a bogus (negative, or
	 * positive-with-NULL) length out of the pointer arithmetic and of the
	 * range checks below.
	 */
	if (aempty || bempty) {
		const int *only = aempty ? b : a;
		int len = aempty ? n : m;

		if (k <= 0) return only[0];
		if (k >= len) return only[len - 1];
		return only[k];
	}

	/* Both arrays are non-empty from here on; clamp illegal ranks. */
	if (k <= 0) return a[0] < b[0] ? a[0] : b[0];
	/* Same test as k >= m + n, written so the sum cannot overflow. */
	if (k - m >= n) return a[m-1] > b[n-1] ? a[m-1] : b[n-1];

	afrom = a; ato = a + m;
	bfrom = b; bto = b + n;

	/*
	 * After the checks above 0 < k < total_length holds, and it is kept
	 * as the loop invariant for the remaining ranges.
	 */
	while (afrom < ato && bfrom < bto) {
		alen = (int)(ato - afrom); blen = (int)(bto - bfrom);
		amid = afrom + (alen >> 1); bmid = bfrom + (blen >> 1);
		/* Count of elements in [afrom, amid) plus [bfrom, bmid). */
		mid = (alen >> 1) + (blen >> 1);

		/* When k sits on a boundary, return the answer directly. */
		if (0 == k) return *afrom < *bfrom ? *afrom : *bfrom;
		/* Same test as k >= alen + blen - 1, without forming the sum. */
		if (k - alen >= blen - 1)
			return *(ato-1) > *(bto-1) ? *(ato-1) : *(bto-1);

		assert(0 < k && k - alen < blen);

		if (*amid == *bmid) {
			/*
			 * With equal middles both arrays are split three ways,
			 * amid and bmid each standing alone:
			 * A: [afrom, amid), amid, [amid + 1, ato)
			 * B: [bfrom, bmid), bmid, [bmid + 1, bto)
			 */
			if (k < mid) {
				/* k falls in the front parts: drop [amid, ato) and [bmid, bto). */
				ato = amid; bto = bmid;
			} else if (k > (mid + 1)) {
				/*
				 * The target is in the back parts: drop [afrom, amid]
				 * and [bfrom, bmid], which is mid + 2 elements.
				 */
				afrom = amid + 1; bfrom = bmid + 1;
				k -= mid + 2;
			} else {
				/* k == mid or k == mid + 1: the target is *amid (== *bmid). */
				return *amid;
			}
		} else if (*amid > *bmid) {
			/* A: [afrom, amid), [amid, ato) */
			/* B: [bfrom, bmid), bmid, [bmid + 1, bto) */
			if (k < mid) {
				/* At least mid + 1 elements precede *amid: drop [amid, ato). */
				ato = amid;
			} else if (bmid == bfrom) {
				/*
				 * The front part to delete has length 0. This usually
				 * happens when one array has a single element and the
				 * other is of normal length, e.g.
				 * A = [34], B = [0, 24, 34, 45].
				 * k < alen holds here because of the boundary check.
				 */
				kvalue = afrom[k];
				if (kvalue <= *bmid) {
					/* The single value ranks after k: A[k] is the answer. */
					return kvalue;
				} else {
					/* The single value ranks within k: drop A's surplus tail. */
					ato = afrom + k;
				}
			} else {
				/* The front part is non-empty: simply drop [bfrom, bmid). */
				k -= (int)(bmid - bfrom);
				bfrom = bmid;
			}
		} else {
			/* *amid < *bmid: mirror image of the branch above. */
			if (k < mid) {
				bto = bmid;
			} else if (amid == afrom) {
				kvalue = bfrom[k];
				if (kvalue <= *amid) return kvalue;
				else bto = bfrom + k;
			} else {
				k -= (int)(amid - afrom);
				afrom = amid;
			}
		}
	}

	/*
	 * One range is exhausted and the other is not; illegal k values were
	 * already handled above, so index the remaining range directly.
	 */
	if (bfrom >= bto) {
		return afrom[k];
	}
	return bfrom[k];
}

/*
 * qsort comparator for int elements, ascending order.
 * Returns a negative, zero, or positive value when *a is less than, equal
 * to, or greater than *b. Comparisons are used instead of subtraction so
 * that widely separated values (e.g. INT_MIN and 1) cannot overflow.
 */
int cmp(const void *a, const void *b) {
	int x = *(const int *)a;
	int y = *(const int *)b;

	return (x > y) - (x < y);
}

/*
 * Print the first n elements of a to stdout on a single line, each one
 * followed by ", ", and finish the line with a newline. Prints only the
 * newline when n <= 0.
 */
void aprint(int *a, int n) {
	int *cursor = a;
	int remaining = n;

	while (remaining > 0) {
		printf("%d, ", *cursor);
		++cursor;
		--remaining;
	}
	printf("\n");
}

/*
 * Run one randomized trial against find_kth.
 *
 * Fills A (m values) with rand() % 100. B (n values) is filled the same way,
 * or, when mirror is non-zero, copied element-for-element from the still
 * unsorted A (callers must then pass n == m). Both arrays and their
 * concatenation C are sorted, and find_kth is compared with C for every k
 * from -2 up to m + n + 9, using the clamping rules of find_kth.
 *
 * Returns 1 when every rank matched, 0 on the first mismatch (after printing
 * a report) or when memory could not be allocated.
 */
static int check_random_case(int m, int n, int mirror) {
	int total = m + n;
	/* Always request at least one element so a zero length never yields a
	 * zero-size allocation. */
	int *a = malloc(sizeof *a * (size_t)(m > 0 ? m : 1));
	int *b = malloc(sizeof *b * (size_t)(n > 0 ? n : 1));
	int *c = malloc(sizeof *c * (size_t)(total > 0 ? total : 1));
	int i, ret, cret, ok = 1;

	if (!a || !b || !c) {
		fprintf(stderr, "test: out of memory\n");
		free(a); free(b); free(c);
		return 0;
	}

	for (i = 0; i < m; ++i) { a[i] = rand() % 100; c[i] = a[i]; }
	for (i = 0; i < n; ++i) {
		b[i] = mirror ? a[i] : rand() % 100;
		c[m+i] = b[i];
	}
	qsort(a, (size_t)m, sizeof(int), cmp);
	qsort(b, (size_t)n, sizeof(int), cmp);
	qsort(c, (size_t)total, sizeof(int), cmp);

	for (i = -2; i < total + 10; ++i) {
		/* Expected value: find_kth returns INT_MAX when both inputs are
		 * empty; otherwise out-of-range ranks clamp to min / max. */
		if (total == 0) cret = INT_MAX;
		else if (i <= 0) cret = c[0];
		else if (i >= total) cret = c[total-1];
		else cret = c[i];

		ret = find_kth(a, m, b, n, i);
		if (ret != cret) {
			printf("Error i = %d, ret = %d, cret = %d\n", i, ret, cret);
			printf("a = "); aprint(a, m);
			printf("b = "); aprint(b, n);
			printf("c = "); aprint(c, total);
			ok = 0;
			break;
		}
	}

	free(a); free(b); free(c);
	return ok;
}

/*
 * Randomized self-test for find_kth. Runs up to 100 trials for each of five
 * length configurations and stops at the first mismatch, which is reported
 * on stdout by check_random_case.
 */
void test(void) {
	int iter, m, n;
	int ok = 1;

	/* Both lengths random in [0, 999]. */
	for (iter = 0; iter < 100 && ok; ++iter) {
		m = rand() % 1000; n = rand() % 1000;
		ok = check_random_case(m, n, 0);
	}

	/* A has exactly one element. */
	for (iter = 0; iter < 100 && ok; ++iter) {
		n = rand() % 1000;
		ok = check_random_case(1, n, 0);
	}

	/* Both arrays have exactly one element. */
	for (iter = 0; iter < 100 && ok; ++iter) {
		ok = check_random_case(1, 1, 0);
	}

	/* B has exactly one element. */
	for (iter = 0; iter < 100 && ok; ++iter) {
		m = rand() % 1000;
		ok = check_random_case(m, 1, 0);
	}

	/* Same length and same contents: every value appears in both arrays. */
	for (iter = 0; iter < 100 && ok; ++iter) {
		m = rand() % 1000;
		ok = check_random_case(m, m, 1);
	}
}
/*
 * Program entry: run the randomized self-test, then read integer ranks from
 * stdin and print find_kth of the fixed array {1..6} merged with itself for
 * each rank read.
 */
int main(void) {
	int a[] = {1, 2, 3, 4, 5, 6};
	int m = (int)(sizeof(a)/sizeof(a[0]));
	int *b = a;	/* B aliases A, so every value appears twice in the merge. */
	int n = m;
	int x;

	test();
	/*
	 * Continue only while a number was actually converted. Comparing
	 * against EOF alone would spin forever on non-numeric input, because
	 * scanf then returns 0 without consuming the offending characters.
	 */
	while (scanf("%d", &x) == 1) {
		printf("%d\n", find_kth(a, m, b, n, x));
	}
	return 0;
}


你可能感兴趣的:(算法导论,数据结构算法面试题精选及整理)