K Smallest Sums
You're given k arrays, each containing k integers. There are $k^k$ ways to pick exactly one element from each array and add the chosen integers together. Your task is to find the k smallest of these sums.
Input
There will be several test cases. The first line of each case contains an integer k ($2 \le k \le 750$). Each of the following k lines contains the k positive integers of one array. None of these integers exceeds 1,000,000. The input is terminated by end-of-file (EOF). The size of the input file does not exceed 5MB.
Output
For each test case, print the k smallest sums, in ascending order.
Sample Input
3
1 8 5
9 2 5
10 7 6
2
1 1
1 2
Output for the Sample Input
9 10 12
2 2
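分析:本题同样是用优先队列解决多路归并问题。先把每个数组(输入中的每一行)从小到大排序,然后两两合并:先求出前两行的最小 k 个和,再把这 k 个和与下一行合并,依此类推,直到处理完所有行。每次只保留 k 个部分和就够了:若某个部分和不在当前最小的 k 个之中,则至少有 k 个部分和不大于它,它们与后续行的同一组选择组合得到的 k 个总和都不大于它对应的总和,因此它不会影响最终最小的 k 个和。合并两个数组 A、B(B 已排序)时,先把 A[i]+B[0](0≤i<k)全部放入优先队列,每次取出最小的和 A[i]+B[j] 后再放入 A[i]+B[j+1],重复 k 次即得到最小的 k 个和。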
初步代码如下:
#include <algorithm>
#include <cstdio>
#include <queue>
using namespace std;
// Capacity of each dimension of A; must be at least the largest k (750 per the problem statement).
const int maxn = 768;
// A[i] holds the i-th input array; after all merges, row 0 holds the k smallest sums.
int A[maxn][maxn];
// One candidate in the merge heap.
// `sum` is a value of the form A[i] + B[b], and `b` is the index into B that produced it,
// so the next candidate for the same A[i] is obtained by swapping B[b] for B[b + 1].
// The ordering is deliberately reversed (a node compares "less" when its sum is LARGER)
// so that std::priority_queue, which is a max-heap, hands out the smallest sum first.
struct Node
{
    int sum;
    int b;

    Node(int s, int idx) : sum(s), b(idx) {}

    bool operator<(const Node & rhs) const
    {
        return rhs.sum < sum;
    }
};
void merge(int *A,int *B,int *C,int n)
{
priority_queue pq;
for (int i = 0; i < n; i++)
pq.push(Node(A[i] + B[0],0));
for (int j = 0; j < n; j++)
{
Node item = pq.top();
pq.pop();
C[j] = item.sum;
int b = item.b;
if (b + 1 < n)
pq.push(Node(item.sum - B[b] + B[b + 1], b + 1));//取下一个数
}
}
// Reads test cases until EOF. Each case is n followed by n rows of n integers.
// Prints, in ascending order, the n smallest sums obtainable by picking one element per row.
int main()
{
    int n;
    while (scanf("%d",&n) == 1)
    {
        // Read and sort every row; merge() needs its second argument sorted.
        for (int row = 0; row < n; ++row)
        {
            int *cur = A[row];
            for (int col = 0; col < n; ++col)
                scanf("%d",&cur[col]);
            sort(cur,cur + n);
        }
        // Merge pairwise: fold each later row into row 0, which afterwards holds
        // the n smallest sums over all rows folded so far, in ascending order.
        int *best = A[0];
        for (int row = 1; row < n; ++row)
            merge(best,A[row],best,n);
        printf("%d",best[0]);
        for (int col = 1; col < n; ++col)
            printf(" %d",best[col]);
        printf("\n");
    }
    return 0;
}
仔细考虑一下,会发现以上代码还可以改进:其实一开始并不需要把全部 k×k 个元素都保存下来。可以每次只读入一行(k 个元素),并立即与之前的结果合并,这样空间复杂度就从 O(k²) 降到了 O(k)。改进后的代码如下:
#include <algorithm>
#include <cstdio>
#include <queue>
using namespace std;
// Capacity of each buffer; must be at least the largest k (750 per the problem statement).
const int maxn = 768;
// A: the row currently being read (sorted after reading).
// B: the k smallest sums over all rows read so far, in ascending order.
int A[maxn],B[maxn];
// One candidate in the merge heap.
// `sum` is a value of the form A[i] + B[b], and `b` is the index into B that produced it,
// so the next candidate for the same A[i] is obtained by swapping B[b] for B[b + 1].
// The ordering is deliberately reversed (a node compares "less" when its sum is LARGER)
// so that std::priority_queue, which is a max-heap, hands out the smallest sum first.
struct Node
{
    int sum;
    int b;

    Node(int s, int idx) : sum(s), b(idx) {}

    bool operator<(const Node & rhs) const
    {
        return rhs.sum < sum;
    }
};
void merge(int *A,int *B,int *C,int n)
{
priority_queue pq;
for (int i = 0; i < n; i++)
pq.push(Node(A[i] + B[0],0));
for (int j = 0; j < n; j++)
{
Node item = pq.top();
pq.pop();
C[j] = item.sum;
int b = item.b;
if (b + 1 < n)
pq.push(Node(item.sum - B[b] + B[b + 1], b + 1));
}
}
// Streaming variant: reads one row at a time into A and keeps in B only the n smallest
// sums over the rows read so far, so the input buffers are O(k) rather than O(k*k).
// Reads test cases until EOF and prints the n smallest sums of each case in ascending order.
int main()
{
    int n;
    while (scanf("%d",&n) == 1)
    {
        for (int row = 0; row < n; ++row)
        {
            // Read one row and sort it; merge() requires its second argument sorted.
            for (int col = 0; col < n; ++col)
                scanf("%d",A + col);
            sort(A,A + n);
            if (row == 0)
            {
                // First row: the n smallest "sums" are simply its sorted elements.
                for (int col = 0; col < n; ++col)
                    B[col] = A[col];
            }
            else
            {
                // Merge pairwise: combine the best sums so far with this row, in place.
                merge(B,A,B,n);
            }
        }
        printf("%d",B[0]);
        for (int col = 1; col < n; ++col)
            printf(" %d",B[col]);
        printf("\n");
    }
    return 0;
}