E. Permutation Separation,Educational Codeforces Round 81 (Rated for Div. 2),线段树

E. Permutation Separation

http://codeforces.com/contest/1295/problem/E
You are given a permutation $p_1, p_2, \ldots, p_n$ (an array where each integer from 1 to $n$ appears exactly once). The weight of the $i$-th element of this permutation is $a_i$.

At first, you separate your permutation into two non-empty sets — prefix and suffix. More formally, the first set contains elements $p_1, p_2, \ldots, p_k$, the second — $p_{k+1}, p_{k+2}, \ldots, p_n$, where $1 \le k < n$.

After that, you may move elements between sets. The operation you are allowed to do is to choose some element of the first set and move it to the second set, or vice versa (move from the second set to the first). You have to pay $a_i$ dollars to move the element $p_i$.

Your goal is to make it so that each element of the first set is less than each element of the second set. Note that if one of the sets is empty, this condition is met.

For example, if p=[3,1,2] and a=[7,1,4], then the optimal strategy is: separate p into two parts [3,1] and [2] and then move the 2-element into first set (it costs 4). And if p=[3,5,1,6,2,4], a=[9,1,9,9,1,9], then the optimal strategy is: separate p into two parts [3,5,1] and [6,2,4], and then move the 2-element into first set (it costs 1), and 5-element into second set (it also costs 1).

考场上想出来了,但没来得及写完……有点难受
思路:先假设最终两个集合为"空集"和"1 ~ n",算出k取1到n-1的代价(此时前缀中的元素都要移到右边,代价即前缀和 $w_k = a_1 + a_2 + \cdots + a_k$),然后每次将最终集合分界点往右移动一下(比如第一次移动后变为"1"和"2 ~ n"),假设第i次移动是将i从右集合移动到左集合,设原序列中 $p_j = i$,可以发现k取1 ~ j-1的代价都会加上 $a_j$,k取j ~ n-1的代价都会减去 $a_j$,直接用线段树(区间加、维护全局最小值)维护最小值就好

#include <algorithm>
#include <cstdio>
// Capacity for 1-indexed arrays: the problem allows n up to 2e5, plus slack.
constexpr int MAXN = 200005;
using ll = long long;
// Shorthands for segment-tree node k and its two children. They expand to
// expressions on whatever variable is named `k` at the point of use.
#define lson tr[k<<1]
#define rson tr[k<<1|1]
#define nw tr[k]
// Large sentinel (about 4.56e18); not used by the code shown in this file.
constexpr ll INF = 0x3f3f3f3f3f3f3f3f;
using namespace std;
// Segment-tree node: `minn` is the minimum over the node's range (with its own
// pending addition already applied); `lz` is an addition not yet pushed down
// to the children.
struct node
{
	ll minn,lz;
}tr[MAXN << 2];
int n,p[MAXN],id[MAXN];   // p: the permutation; id[v]: position of value v in p
ll a[MAXN],w[MAXN];       // a: weights; w[k]: initial cost of cutting after position k
inline void pushup(int k)
{
	nw.minn = min(lson.minn,rson.minn);
}
inline void pushdown(int k)
{
	lson.lz += nw.lz,rson.lz += nw.lz;
	lson.minn += nw.lz,rson.minn += nw.lz;
	nw.lz = 0;
}
// Build node k covering positions [l, r] from w[l..r] and clear all lazy tags
// in that subtree. Requires l <= r: with l > r the left child gets the same
// empty range again and the recursion never bottoms out.
inline void build(int k,int l,int r)
{
	nw.lz = 0;
	if(l != r)
	{
		const int mid = (l + r) >> 1;
		build(k << 1, l, mid);
		build(k << 1 | 1, mid + 1, r);
		pushup(k);
	}
	else
	{
		nw.minn = w[l]; // leaf: initial cost of cutting after position l
	}
}
// Add x to every position in [cl, cr] inside node k, which covers [l, r].
// Callers pass l <= cl <= cr <= r; the range is split at mid as it descends.
inline void add(int k,int cl,int cr,int l,int r,ll x)
{
	// Settle the pending tag before touching children (leaves have none).
	if(l != r && nw.lz != 0)
		pushdown(k);
	const bool coversWholeNode = (cl == l) && (cr == r);
	if(coversWholeNode)
	{
		nw.lz += x;
		nw.minn += x;
		return;
	}
	const int mid = (l + r) >> 1;
	const int leftChild = k << 1;
	const int rightChild = k << 1 | 1;
	if(cl > mid)
	{
		add(rightChild, cl, cr, mid + 1, r, x);
	}
	else if(cr <= mid)
	{
		add(leftChild, cl, cr, l, mid, x);
	}
	else
	{
		add(leftChild, cl, mid, l, mid, x);
		add(rightChild, mid + 1, cr, mid + 1, r, x);
	}
	pushup(k);
}
int main()
{
	while(~scanf("%d",&n))
	{
		for(int i = 1;i <= n;++i)
			scanf("%d",&p[i]),id[p[i]] = i;
		for(int i = 1;i <= n;++i)
			scanf("%lld",&a[i]);
		w[1] = a[1];
		for(int i = 2;i < n;++i)
			w[i] = w[i-1] + a[i];
		build(1,1,n-1);
		ll ans = tr[1].minn;
		for(int i = 1;i <= n;++i)
		{
			int tmp = id[i];
			if(tmp != 1)
				add(1,1,tmp-1,1,n-1,a[tmp]);
			if(tmp != n)
				add(1,tmp,n-1,1,n-1,-a[tmp]);
			ans = min(ans,tr[1].minn);
		}
		printf("%lld\n",ans);
	}
	return 0;
}

你可能感兴趣的:(数据结构)