坑了一上午QAQ感觉自己好弱智啊
BZOJ上的题面坑死人,2N是闹哪样啊,明明是2^N,害的我还以为是水题,WA了好几次。
然后上COGS(好评)上看了下题,发现是2^N,然后论文里的省空间方法好麻烦,于是直接用vector+动态开节点水过去了。
话说我这个写得怎么这么像线段树2333333
#include<iostream> #include<cstdio> #include<cstring> #include<vector> using namespace std; const int inf=1e9; struct Node{ int lc,rc,pa,sum[1100]; vector<int>f[1100]; }tr[2200]; int n,f[1100][1100],tot; int bin[20]; void build(int &o,int l,int r,int depth){ o=++tot; memset(tr[o].sum,0,sizeof(tr[o].sum)); for(int i=0;i<=r-l+1;i++) tr[o].f[i].resize(bin[depth]); if(l==r){ for(int i=1;i<=n;i++) tr[o].sum[i]+=f[i][l]; return; } int mid=l+r>>1; build(tr[o].lc,l,mid,depth+1);build(tr[o].rc,mid+1,r,depth+1); for(int i=1;i<=n;i++) tr[o].sum[i]=tr[tr[o].lc].sum[i]+tr[tr[o].rc].sum[i]; tr[tr[o].lc].pa=tr[tr[o].rc].pa=o; } int calclayer(int x){ int ans=0; while(x)x=tr[x].pa,ans++; return ans; } int org[1100],change[1100]; void dp(int o,int l,int r){ int k=calclayer(o); if(l==r){ for(int j=0;j<=1;j++) for(int i=0;i<bin[k-1];i++){ int t=!j; tr[o].f[j][i]+=(t!=org[l])*change[l]; int x=k-1,y=o; while(x){ int c=(tr[tr[y].pa].lc)==y; tr[o].f[j][i]+=(((i&bin[x-1])?1:0)==t)*tr[c?(tr[tr[y].pa].rc):(tr[tr[y].pa].lc)].sum[l]; x--;y=tr[y].pa; } } }else{ int mid=l+r>>1; dp(tr[o].lc,l,mid);dp(tr[o].rc,mid+1,r); int len=r-l+1; for(int j=0;j<=len;j++){ int s=(j>=len-j)*bin[k-1]; for(int i=0;i<bin[k-1];i++){ tr[o].f[j][i]=inf; for(int u=0;u<=j;u++){ if(u>len-(len>>1)||j-u>(len>>1))continue; tr[o].f[j][i]=min(tr[o].f[j][i],tr[tr[o].lc].f[u][i|s]+tr[tr[o].rc].f[j-u][i|s]); //printf("%d %d %d\n",u,i|s,tr[lc].f[u][i|s]+tr[rc].f[j-u][i|s]); } //printf("%d %d %d %d %d\n",l,r,j,i,tr[o].f[j][i]); } } } } int main(){ //freopen("networkcost.in","r",stdin); //freopen("networkcost.out","w",stdout); bin[0]=1; for(int i=1;i<20;i++)bin[i]=bin[i-1]<<1; scanf("%d",&n);n=1<<n; for(int i=1;i<=n;i++)scanf("%d",&org[i]); for(int i=1;i<=n;i++)scanf("%d",&change[i]); for(int i=1;i<n;i++){ for(int j=1;i+j<=n;j++) scanf("%d",&f[i][i+j]); for(int j=i+1;j<=n;j++) f[j][i]=f[i][j]; } int root; build(root,1,n,0); dp(root,1,n); int ans=inf; for(int i=0;i<=n;i++) ans=min(ans,tr[root].f[i][0]); printf("%d\n",ans); return 0; }