数轴上有m个生产车间可以生产零件。一共有n种零件,编号为1~n。第i个车间的坐标为xi,生产第pi种零件(1<=pi<=n)。你需要在数轴上的某个位置修建一个组装车间,把这些零件组装起来。为了节约运输成本,你需要最小化cost(1)+cost(2)+…+cost(n),其中cost(x)表示生产第x种零件的车间中,到组装车间距离的平方的最小值。
输入第一行为两个整数n, m,即零件的种类数和生产车间的个数。以下m行每行两个整数xi和pi(1<=pi<=n)。输入按照生产车间从左到右的顺序排列(即xi<=xi+1。注意车间位置可以重复)。输入保证每种零件都有车间生产。
输出仅一行,即组装车间的最优位置(可以和某个生产车间重合),四舍五入保留四位小数。输入保证最优位置惟一。
3 5
-1 3
0 1
2 3
4 2
5 2
2.0000
范围:n<=10000, m<=100000, xi<=100000
题目大意就是数轴上很多点,分为不同类,找一点使它到每一类中离它最近的点到它距离的平方和最小。网上很多人做这道题都是用的贪心,按照什么什么排序。我觉得这很不科学,如果不知道这道题是贪心,独立想这道题最直观的思路应该是考察目标函数的最小值吧(也可能是我太另类了)。。
根据题目描述,可以看出整个函数是一个定义域为R的分段的连续函数。并且有一个重要且明显的特征:在每一段内是一个可以求出系数的开口向下的二次函数,知道了系数a,b,c,求解二次函数区间最小值就行了。知道这些,问题的核心就集中在找到目标函数的所有分段,以及如何在分段发生变化的时候O(1)地维护二次函数的三个系数。
显然,目标函数分段的根本原因就是其中离当前点最近的某一类加工处发生了改变。也就是说,相邻两个同类加工处的中点必然是目标函数的一个分段点。只需要将所有的相邻的同类加工处的中点排序,即可得到目标函数所有分段点。
其次,如何维护二次函数的三个系数呢?观察目标函数的来源:f(x) = (x-p1)^2+(x-p2)^2... 其中每一个p是离当前x最近的各个类别的加工处的横坐标。每经过一个分段点,函数f(x)发生改变的其实只有一项,假设(x-pi)^2变为(x-pi')^2,系数变化为:b+=2*pi,c-=pi^2, b-=2*pi', c+=pi'^2。前两个式子为删除上一个被淘汰的加工处,后两个式子是引入新的更近的加工处。
然后就可以将所有分段点(不会超过m个)排序,以此枚举分段点,求二次函数区间最值即可。
我的代码其实有问题,没有给每一类的链表排序,这样实际会错的。但是这题良心数据,给出的点都是升序的,所以过了。。
#include<iostream> #include<cstdio> #include<cstring> #include<algorithm> using namespace std; #define DB double const int MAXN = 10005; const int MAXM = 200005; int N, M, cnt; struct Node { //我忘记给这个链表排序了,如果给出的点没有升序,将会出错。建议这部分换成set DB x; Node*next; } Edge[MAXM*2], *ecnt=Edge, *head[MAXN], *tail[MAXN]; void insert(int id, DB x) { (++ecnt)->x = x; if (!head[id]) head[id] = tail[id] = ecnt; else if (x < head[id]->x) { //将最小值放在最前面 ecnt->next = head[id]; head[id] = ecnt; } else { tail[id]->next = ecnt; tail[id] = ecnt; } } DB A, B, C; //二次函数系数 DB bestp, best; DB vrleft = 1e7; DB cur[MAXN]; struct TurnP { //转折点 Node*pre; DB xp; int tp; TurnP() {} TurnP(Node*p,DB a,int b) {pre=p; xp=a; tp=b;} bool operator < (const TurnP&t) const { if (xp == t.xp) return tp < t.tp; return xp < t.xp; } } pt[MAXM]; inline DB f(DB x) { //返回二次函数值f(x) return A*x*x + B*x + C; } inline DB minf(DB l, DB r, DB&x) { //返回f(x)在l到r的最小值 DB z = - B / (2.0*A); if (z>=l && z<=r) return f(x=z); if (z>r) return f(x=r); //单减 else return f(x=l); } int main() { int i, type; DB x, temp, val, L, R; scanf("%d%d", &N, &M); for (i = 1; i<=M; ++i) { scanf("%lf%d", &x, &type); if (x < vrleft) vrleft = x; insert(type, x); } for (i = 1; i<=N; ++i) { if (!head[i]) continue; x = head[i]->x; A = A + 1.0; cur[i] = x; B = B - 2 * x; C = C + x * x; for (Node*p = head[i]; p; p = p->next) { if (p->next) pt[++cnt] = TurnP(p, ((p->x)+(p->next->x))/2, i); } } sort(pt+1, pt+cnt+1); bestp = x = vrleft; best = f(vrleft); for (i = 1; i<=cnt; ++i) { temp = pt[i].pre->x; B = B + 2 * temp; C = C - temp * temp; temp = pt[i].pre->next->x; B = B - 2 * temp; C = C + temp * temp; temp = 0; L = pt[i].xp; R = (i==cnt?1e8:pt[i+1].xp); val = minf(L, R, temp); if (val < best) best = val, bestp = temp; } printf("%.4lf\n", bestp); return 0; }