S中最接近中位数的k个元素

在《算法导论》第3版习题习题9.3-7提到,设计一个O(n)时间的算法,对于一个给定的包含n个互异元素的集合S和一个正整数 k<=n,该算法能够确定S中最接近中位数的k个元素。

步骤如下:

1: select A数组得到其中位数nmid,其下标为imid

2: 计算A中每个数到中位数的差值作为数组dis, 并拷贝到数组dis_copy

3: select dis_copy数组得到第k小的数nkmid

4: 遍历数组dis, 获取k个值小于等于nkmid的数

代码如下

int partition(int A[], int p, int r)
{
    int x = A[r];
    int i = p - 1;

    for (int j = p; j < r; ++j)
    {
        if (A[j] <= x)
        {
            ++i;
            swap(A[i], A[j]);
        }
    }

    swap(A[i + 1], A[r]);

    return i + 1;
}

int select(int A[], int p, int r, int k)
{
    assert(p <= r);
    assert(k <= r - p + 1);

    if (p == r)
        return A[p];

    int mid = partition(A, p, r);

    int count = mid - p + 1;
    if (k == count)
        return A[mid];
    else if (k < count)
        return select(A, p, mid - 1, k);
    else
        return select(A, mid + 1, r, k - count);
}

int* kth_select(int A[], int len, int k)
{
    assert(k <= len);

    int *dis = new int[len - 1];
    int *dis_cpy = new int[len - 1];
    int *res = new int[k];

    int nmid = select(A, 0, len - 1, len / 2);
    int imid = 0;
    int count = 0;

    for (int i = 0; i < len; ++i)
    {
        if (A[i] != nmid)
            dis[count++] = abs(A[i] - nmid);
        else
            imid = i;
    }

    memcpy(dis_cpy, dis, sizeof(int)*(len - 1));
    int nkmid = select(dis_cpy, 0, count - 1, k);

    delete dis_cpy;
    dis_cpy = NULL;

    int ik = 0;
    for (int i = 0; ik < k && i < count; ++i)
    {
        if (dis[i] <= nkmid)
        {
            if (i < imid)
                res[ik++] = nmid - dis[i];
            else
                res[ik++] = nmid + dis[i];
        }
    }

    delete dis;
    dis = NULL;

    return res;
}

你可能感兴趣的:(算法基础)