2023NOIP A层联测31-暴力操作

有一个长为 n n n 的序列 { a i } \{a_i\} {ai},你可以操作若干次:选择一个 i i i,花费 c x c_x cx 元将 a i a_i ai 变为 ⌊ a i x ⌋ ⌊\frac{a_i}x⌋ xai,你总共有 K K K 元。问最终序列的中位数最小是多少。

保证 n n n 为奇数, 1 ≤ a i ≤ m , 1 ≤ n , m ≤ 5 × 1 0 5 , 1 ≤ c i , k ≤ 1 0 9 1≤a_i≤m,1\le n,m\le5\times10^5,1\le c_i,k\le10^9 1aim,1n,m5×105,1ci,k109


先对 a a a 排序。然后二分答案。发现我们只需对 i ≤ n + 1 2 i\le \frac{n+1}2 i2n+1 a i a_i ai 进行操作。由于 ⌊ ⌊ a x ⌋ y ⌋ = ⌊ a x y ⌋ \left\lfloor\dfrac{\lfloor\frac{a}{x}\rfloor}{y}\right\rfloor=\left\lfloor\dfrac{a}{xy}\right\rfloor yxa=xya,所以考虑对于一个 a i a_i ai,它无论如何操作,实际上都只是除以了一个 x x x,考虑怎样选择若干数 b i b_i bi,和为 x x x,花费是 ∑ c b i \sum c_{b_i} cbi,花费最小。这可以用 dp 求。设 f i , j f_{i,j} fi,j 表示前 i i i 个数中选的数乘积为 j j j 的最小花费(乘积要枚举到 2 m 2m 2m),转移显然是 f i , j = min ⁡ ( f i − 1 , j , f i − 1 , j i k + k ⋅ c i ) f_{i,j}=\min(f_{i-1,j},f_{i-1,\frac{j}{i^k}}+k\cdot c_i) fi,j=min(fi1,j,fi1,ikj+kci),初始 f 1 , 1 = 0 f_{1,1}=0 f1,1=0。这部分的 dp 时间复杂度为 O ( ∑ i = 2 m ∑ j = 1 ∞ ⌊ m i j ⌋ ) = O ( m ln ⁡ m ) O(\sum\limits_{i=2}^m\sum\limits_{j=1}^{\infty}\lfloor\frac{m}{i^j}\rfloor)=O(m\ln m) O(i=2mj=1ijm⌋)=O(mlnm)

要使 a i a_i ai 比中位数 m i d mid mid 要小, a i a_i ai 要除以的值至少大于 ⌊ a i + m i d + 1 m i d + 1 ⌋ \lfloor\frac{a_i+mid+1}{mid+1}\rfloor mid+1ai+mid+1,因为有可能除以更大的值花费更小,所以我们要对 f f f 做后缀和。

时间复杂度 O ( n log ⁡ n m + m ln ⁡ m ) O(n\log nm+m\ln m) O(nlognm+mlnm)

#include
using namespace std;
#define ll long long
const int N=5e5+10;
int n,m,K,a[N],c[N];
ll f[N<<1],Min[N<<1];
bool check(int mid)
{
    ll sum=0;
    for(int i=1;i*2<=n+1;i++){
        if(a[i]<=mid) continue;
        int x=(a[i]+mid+1)/(mid+1);
        sum+=Min[x];
        if(sum>K) return 0;
    }
    return 1;
}
int main()
{
    freopen("opt.in","r",stdin);
    freopen("opt.out","w",stdout);
    cin.tie(0)->sync_with_stdio(0);
    cin>>n>>m>>K;
    for(int i=1;i<=n;i++) cin>>a[i];
    for(int i=1;i<=m;i++) cin>>c[i];
    sort(a+1,a+1+n);
    memset(f,0x3f,sizeof(f));
    memset(Min,0x3f,sizeof(Min));
    f[1]=0;
    for(int i=2;i<=m;i++){
        int x=2*m/i*i;
        for(int j=x;j;j-=i){
            ll pw=i;
            for(int k=1;pw<=2*m&&j%pw==0;k++,pw*=i){
                f[j]=min(f[j],f[j/pw]+1ll*k*c[i]);
            }
        }
    }
    for(int i=2*m;i>=1;i--) Min[i]=min(Min[i+1],f[i]);
    int l=0,r=m,ans=m;
    while(l<=r){
        int mid=l+r>>1;
        if(check(mid)) r=mid-1,ans=mid;
        else l=mid+1;
    }
    cout<<ans;
}

你可能感兴趣的:(算法)